/***********************************************/
/*           Copyright (c) 2025 Belmu          */
/*             All Rights Reserved             */
/***********************************************/

/*
    [Credits]:
        Jakemichie97 (https://twitter.com/jakemichie97)
        Samuel (https://github.com/swr06)
        L4mbads
        sixthsurge (https://github.com/sixthsurge)
        
    [References]:
        Schied et al. (2017). Spatiotemporal Variance-Guided Filtering: Real-Time Reconstruction for Path-Traced Global Illumination. https://research.nvidia.com/publication/2017-07_spatiotemporal-variance-guided-filtering-real-time-reconstruction-path-traced
        Dundr, J. (2018). Progressive Spatiotemporal Variance-Guided Filtering. https://cescg.org/wp-content/uploads/2018/04/Dundr-Progressive-Spatiotemporal-Variance-Guided-Filtering-2.pdf
        Galvan, A. (2020). Ray Tracing Denoising. https://alain.xyz/blog/ray-tracing-denoising
*/

#include "/settings.glsl"

#if RENDER_MODE == 0 && ATROUS_FILTER == 1 && DEBUG_ALBEDO == 0 && DEBUG_NORMALS == 0 && DEBUG_HIT_POSITION == 0
    #if defined STAGE_VERTEX

	    // gl_Vertex.xy handed to the fragment stage (presumably in [0, 1] for this fullscreen pass — confirm).
	    out vec2 textureCoords;

	    /*
	        Pass-through vertex stage: forwards gl_Vertex.xy as the texture coordinates and
	        remaps it with * 2.0 - 1.0 to build the output position (z and w fixed to 1.0).
	    */
	    void main() {
		    vec2 vertexCoords = gl_Vertex.xy;

		    gl_Position   = vec4(vertexCoords * 2.0 - 1.0, 1.0, 1.0);
		    textureCoords = vertexCoords;
	    }

    #elif defined STAGE_FRAGMENT

        // INPUT_BUFFER resolves to IRRADIANCE_BUFFER for pass index 0 and to LIGHTING_BUFFER for every other pass.
        #if ATROUS_PASS_INDEX == 0
            #define INPUT_BUFFER IRRADIANCE_BUFFER
        #else
            #define INPUT_BUFFER LIGHTING_BUFFER
        #endif

        // NOTE(review): the block comment below is presumably a directive read by the shader loader to bind
        // the two outputs declared after it — keep it exactly as written and confirm before editing.
        /* RENDERTARGETS: 1,12 */

        // Output 0: filtered irradiance (aTrousFilter rewrites .rgb, .a is passed through from INPUT_BUFFER).
        layout (location = 0) out vec4 irradiance;
        // Output 1: moments (aTrousFilter rewrites .b only, the other channels are passed through from MOMENTS_BUFFER).
        layout (location = 1) out vec4 moments;

        // Coordinates written by the vertex stage.
        in vec2 textureCoords;

        #include "/include/common.glsl"

        // Per-axis tap weights indexed by |offset|; the weight of a 2D tap is the product of both axes' entries.
        // aTrousFilter only iterates offsets in [-1, 1], so index 2 is not read there.
        const float waveletKernel[3] = float[3](1.0, 2.0 / 3.0, 1.0 / 6.0);

        // Spacing between filter taps: ATROUS_STEP_SIZE scaled by 0.5^(4 - pass index), i.e. it doubles with
        // each pass index and equals ATROUS_STEP_SIZE at pass index 4. Multiplied by texelSize where it is used.
        const float stepSize = ATROUS_STEP_SIZE * pow(0.5, 4 - ATROUS_PASS_INDEX);

        /*
            Normal edge-stopping weight: the dot product of both normals (their cosine when both are unit length),
            passed through max0 (presumably clamps negatives to zero) and raised to ATROUS_NORMAL_WEIGHT_SIGMA.
        */
        float calculateATrousNormalWeight(vec3 normal, vec3 sampleNormal) {
            float alignment = max0(dot(normal, sampleNormal));
            return pow(alignment, ATROUS_NORMAL_WEIGHT_SIGMA);
        }

        /*
            Depth edge-stopping weight: exp(-|linear depth difference| / tolerance).

            depth / sampleDepth are linearized here with linearizeDepth before being compared.
            The tolerance grows with the depth change predicted along the tap offset
            (|ATROUS_DEPTH_WEIGHT_SIGMA * dot(depthGradient, offset)|) and never drops below 0.8.

            NOTE(review): aTrousFilter passes an offset scaled by texelSize while depthGradient comes from
            dFdx / dFdy (change per fragment), so the gradient term may be much smaller than intended — confirm.
        */
        float calculateATrousDepthWeight(float depth, float sampleDepth, vec2 depthGradient, vec2 offset) {
            float depthDelta = abs(linearizeDepth(depth) - linearizeDepth(sampleDepth));
            float tolerance  = abs(ATROUS_DEPTH_WEIGHT_SIGMA * dot(depthGradient, offset)) + 0.8;

            return exp(-depthDelta / tolerance);
        }

        /*
            Luminance edge-stopping weight: exp(-|luminance difference| / tolerance), where the tolerance is
            ATROUS_LUMINANCE_WEIGHT_SIGMA times the standard deviation (sqrt of the variance) plus a 0.01 bias
            that keeps the divisor away from zero.
        */
        float calculateATrousLuminanceWeight(float luminance, float sampleLuminance, float variance) {
            float luminanceDelta = abs(luminance - sampleLuminance);
            float tolerance      = ATROUS_LUMINANCE_WEIGHT_SIGMA * sqrt(variance) + 0.01;

            return exp(-luminanceDelta / tolerance);
        }

        /*
            Weighted 3x3 average of the .b channel of MOMENTS_BUFFER around textureCoords, multiplied by 3.0.

            Taps whose coordinates are changed by saturate (presumably a clamp to [0, 1], i.e. off-screen taps)
            are skipped; the sum is renormalized by the weights of the taps that were kept.
        */
        float gaussianVariance() {
            // Per-axis weights indexed by |offset|; the loops below only read indices 0 and 1
            const float kernel[3] = float[3](0.25, 0.125, 0.0625);

            float weightedVariance = 0.0;
            float weightSum        = EPS; // Starts at EPS so the final division never hits zero

            for(int x = -1; x <= 1; x++) {
                for(int y = -1; y <= 1; y++) {
                    vec2 tapCoords = textureCoords + vec2(x, y) * texelSize;

                    if(saturate(tapCoords) != tapCoords) continue;

                    float tapWeight = kernel[abs(x)] * kernel[abs(y)];

                    weightedVariance += texture(MOMENTS_BUFFER, tapCoords).b * tapWeight;
                    weightSum        += tapWeight;
                }
            }
            return (weightedVariance / weightSum) * 3.0;
        }

        /*
            One edge-stopping À-Trous iteration over a 3x3 footprint whose taps are spaced by stepSize texels.

            irradiance: in  -> irradiance of the current fragment (it acts as the center tap, weight 1.0)
                        out -> weighted average of the center and the accepted neighbor taps
            moments:    only .b (the variance) is touched; it is filtered with the squared weights and
                        normalized by the squared weight sum.

            Each neighbor is weighted by the product of the normal, depth and luminance edge-stopping
            functions and of the wavelet kernel. Fragments whose depth is exactly 1.0 are left untouched.
        */
        void aTrousFilter(inout vec3 irradiance, inout vec3 moments) {
            float depth = texture(depthtex0, textureCoords).r;

            // The screen-space derivatives are taken before the early-out below: dFdx / dFdy are undefined
            // in non-uniform control flow, which is what follows a per-fragment return
            float linearDepth   = linearizeDepth(depth);
            vec2  depthGradient = vec2(dFdx(linearDepth), dFdy(linearDepth));

            // Depth of exactly 1.0 (presumably sky / nothing rasterized): nothing to filter
            if(depth == 1.0) return;

            // The normal is packed as two 16-bit values in the .w component of the raster data
            uvec4 dataTexture = texelFetch(RASTER_DATA_BUFFER, ivec2(textureCoords * viewSize), 0);
            vec3  normal      = decodeUnitVector((uvec2(dataTexture.w) >> uvec2(0, 16) & 65535u) * rcpMaxFloat16);

            float centerLuminance  = luminance(irradiance);
            float filteredVariance = gaussianVariance();

            // The center tap is already held by irradiance / moments.b with a weight of 1.0
            float totalWeight = 1.0;

            for(int x = -1; x <= 1; x++) {
                for(int y = -1; y <= 1; y++) {
                    if(x == 0 && y == 0) continue;

                    vec2 offset       = vec2(x, y) * stepSize * texelSize;
                    vec2 sampleCoords = textureCoords + offset;

                    // Skip taps falling outside of the screen
                    if(saturate(sampleCoords) != sampleCoords) continue;

                    uvec4 sampleDataTexture = texelFetch(RASTER_DATA_BUFFER, ivec2(sampleCoords * viewSize), 0);
                    vec3  sampleNormal      = decodeUnitVector((uvec2(sampleDataTexture.w) >> uvec2(0, 16) & 65535u) * rcpMaxFloat16);

                    float sampleDepth = texture(depthtex0, sampleCoords).r;

                    vec3  sampleIrradiance = texture(INPUT_BUFFER  , sampleCoords).rgb;
                    float sampleVariance   = texture(MOMENTS_BUFFER, sampleCoords).b;

                    float normalWeight    = calculateATrousNormalWeight(normal, sampleNormal);
                    float depthWeight     = calculateATrousDepthWeight(depth, sampleDepth, depthGradient, offset);
                    float luminanceWeight = calculateATrousLuminanceWeight(centerLuminance, luminance(sampleIrradiance), filteredVariance);

                    // NOTE(review): the luminance weight is applied at full strength. The previous code blended it
                    // with a constant factor of 1.0 (a no-op) next to an unused accumulated-samples threshold
                    // (MIN_FRAMES_LUMINANCE_WEIGHT); reintroduce that gate here if it was meant to be active.
                    float weight  = normalWeight * depthWeight * luminanceWeight;
                          weight *= waveletKernel[abs(x)] * waveletKernel[abs(y)];

                    irradiance  += sampleIrradiance * weight;
                    moments.b   += sampleVariance   * weight * weight;
                    totalWeight += weight;
                }
            }
            irradiance /= totalWeight;
            moments.b  /= (totalWeight * totalWeight);
        }

        /*
            Loads the texel under the current fragment from INPUT_BUFFER and MOMENTS_BUFFER into both outputs,
            then runs one À-Trous iteration in place on their .rgb components.
        */
        void main() {
            ivec2 texel = ivec2(gl_FragCoord.xy);

            moments    = texelFetch(MOMENTS_BUFFER, texel, 0);
            irradiance = texelFetch(INPUT_BUFFER  , texel, 0);

            aTrousFilter(irradiance.rgb, moments.rgb);
        }
        
    #endif
#else
    #include "/programs/discard.glsl"
#endif
