/***********************************************/
/*           Copyright (c) 2025 Belmu          */
/*             All Rights Reserved             */
/***********************************************/

#if defined STAGE_VERTEX

	// Screen-space UV forwarded to the fragment stage; taken directly from gl_Vertex.xy
	// (presumably a full-screen quad with corners in [0, 1] — TODO confirm against the loader).
	out vec2 textureCoords;

	void main() {
		vec2 uv = gl_Vertex.xy;

		textureCoords = uv;
		// Map [0, 1] UV to [-1, 1] clip space, placing the quad at z = w = 1.
		gl_Position   = vec4(uv * 2.0 - 1.0, 1.0, 1.0);
	}

#elif defined STAGE_FRAGMENT

	#extension GL_ARB_shader_texture_lod : enable

	#include "/settings.glsl"
	#include "/include/common.glsl"

	// Fragment outputs. The magic comment in each branch is presumably read by the shader
	// loader (OptiFine/Iris convention) to map the `location = N` outputs, in order, onto
	// colortex indices — keep each one in sync with the declarations below it.
	#if RENDER_MODE == 0
		#if RESTIR_GI == 1 || TEMPORAL_ACCUMULATION == 1
			/* RENDERTARGETS: 3,6,12 */

			layout (location = 0) out vec4 radianceOut; // accumulated radiance (rgb) + accumulated frame count (a)
			layout (location = 1) out uint normalOut;   // packed visible normal copied from RASTER_DATA_BUFFER.w
			layout (location = 2) out vec4 momentsOut;  // luminance moments (rg), variance (b), log2 of a depth value (a)
		#else
			/* RENDERTARGETS: 3,12 */

			layout (location = 0) out vec4 radianceOut; // accumulated radiance (rgb) + accumulated frame count (a)
			layout (location = 1) out vec4 momentsOut;  // luminance moments (rg), variance (b), log2 of a depth value (a)
		#endif
	#else
		/* RENDERTARGETS: 3 */
		
		layout (location = 0) out vec4 radianceOut; // accumulated radiance (rgb) + accumulated frame count (a)
	#endif

	// Screen-space UV written by the vertex stage.
	in vec2 textureCoords;

	#if RENDER_MODE == 0
		#if ATROUS_FILTER == 1
			/*
				Spatial fallback for the luminance variance.

				Combines the centre moments (moments.x = mean luminance, moments.y = mean squared
				luminance, weight 1) with a 3x3 neighbourhood of `tex` around textureCoords, taps
				spaced 8 texels apart and weighted by a separable (1, 2/3, 1/6) wavelet kernel.
				Taps falling outside [0, 1] UV are skipped.

				Returns max(0, E[l^2] - E[l]^2).
			*/
			float estimateSpatialVariance(sampler2D tex, vec2 moments) {
				float sum = moments.r, sqSum = moments.g, totalWeight = 1.0;

				const float waveletKernel[3] = float[3](1.0, 2.0 / 3.0, 1.0 / 6.0);
				
				// Not `const`: texelSize is not guaranteed to be a constant expression (it depends on the
				// view size), and strict GLSL compilers reject non-constant `const` initializers.
				vec2 stepSize = 8.0 * texelSize;

				for(int x = -1; x <= 1; x++) {
					for(int y = -1; y <= 1; y++) {
						// The centre is already accounted for through `moments`.
						if(x == 0 && y == 0) continue;

						vec2 sampleCoords = textureCoords + vec2(x, y) * stepSize;
						if(saturate(sampleCoords) != sampleCoords) continue;

						float weight          = waveletKernel[abs(x)] * waveletKernel[abs(y)];
						float sampleLuminance = luminance(texture(tex, sampleCoords).rgb);
					
						sum   += sampleLuminance * weight;
						sqSum += sampleLuminance * sampleLuminance * weight;

						totalWeight += weight;
					}
				}
				// totalWeight >= 1.0 (centre weight), so the normalization is always safe.
				sum   /= totalWeight;
				sqSum /= totalWeight;
				return max0(sqSum - sum * sum);
			}
		#endif
		
		#if TEMPORAL_ACCUMULATION == 1

			/*
				Catmull-Rom cubic kernel evaluated at distance |x|.
				  |x| < 1      : 1 - 2.5|x|^2 + 1.5|x|^3
				  1 <= |x| < 2 : -0.5f^3 + f^2 - 0.5f, with f = fract(|x|)
				  otherwise    : 0
			*/
			float cubic(float x) {
				float dist = abs(x);
				float f    = fract(dist);

				if(dist < 1.0) {
					return 1.0 + f * f * (3.0 * f - 5.0) * 0.5;
				} else if(dist < 2.0) {
					return f * (f * (2.0 - f) - 1.0) * 0.5;
				}
				return 0.0;
			}

			// Component-wise Catmull-Rom weights for a 2D tap offset.
			vec2 cubic(vec2 coords) {
				vec2 weights;
				weights.x = cubic(coords.x);
				weights.y = cubic(coords.y);
				return weights;
			}

			/*
				Resamples the history in `tex` at `coords` with a 4x4 Catmull-Rom kernel.

				Each tap is additionally weighted by how well its stored depth (MOMENTS_BUFFER.a, log2-encoded)
				and normal (NORMAL_BUFFER) agree with `depth` and `normal`. The result is clamped to the
				per-channel min/max of the taps to suppress ringing from the negative cubic lobes.

				tex           : history buffer to resample
				coords        : reprojected UV in [0, 1]
				normal        : normal of the current fragment, compared against the decoded NORMAL_BUFFER normals
				depth         : linearized depth compared against the linearized MOMENTS_BUFFER.a depths
				rejectHistory : set to true when the summed tap weight is too small to be trusted
				returns       : the filtered history, or vec4(0.0) when rejected
			*/
			vec4 filterHistory(sampler2D tex, vec2 coords, vec3 normal, float depth, out bool rejectHistory) {
				// Size in texels of the region actually rendered (RENDER_SCALE appears to be a percentage).
				vec2  resolution = floor(viewSize * RENDER_SCALE * 0.01);
				// Last valid texel index; texelFetch outside the rendered region reads stale or undefined data.
				ivec2 maxCoords  = ivec2(resolution) - 1;

				// Shift by half a texel so that fragCoords is the top-left texel of the central 2x2 footprint.
				coords = coords * resolution - 0.5;

				ivec2 fragCoords = ivec2(floor(coords));

				coords = fract(coords);

				vec4  history     = vec4(0.0);
				float totalWeight = 0.0;

				vec4 minColor = vec4(1e10), maxColor = vec4(0.0);

				for(int x = -1; x <= 2; x++) {
					for(int y = -1; y <= 2; y++) {
						ivec2 sampleCoords = fragCoords + ivec2(x, y);

						if(clamp(sampleCoords, ivec2(0), maxCoords) != sampleCoords) continue;

						float sampleDepth  = linearizeDepth(exp2(texelFetch(MOMENTS_BUFFER, sampleCoords, 0).a));
						vec3  sampleNormal = decodeUnitVector((uvec2(texelFetch(NORMAL_BUFFER, sampleCoords, 0).r) >> uvec2(0, 16) & 65535u) * rcpMaxFloat16);

						// Equivalent to pow(exp(-|dz|), sigma), without the extra transcendental.
						float depthWeight  = exp(-abs(sampleDepth - depth) * TEMPORAL_DEPTH_SIGMA);
						float normalWeight = pow(max0(dot(sampleNormal, normal)), TEMPORAL_NORMAL_SIGMA);

						vec2 cubicWeights = cubic(abs(vec2(x, y) - coords));

						float weight = cubicWeights.x * cubicWeights.y * depthWeight * normalWeight;

						vec4 sampleColor = texelFetch(tex, sampleCoords, 0);

						history     += sampleColor * weight;
						totalWeight += weight;

						minColor = min(minColor, sampleColor);
						maxColor = max(maxColor, sampleColor);
					}
				}

				const float minTrustedWeight = 1e-3;
				rejectHistory = totalWeight <= minTrustedWeight;

				// Bail out before normalizing: a ~0 weight would divide by zero, and if no tap was taken
				// minColor > maxColor, for which clamp() is undefined.
				if(rejectHistory) return vec4(0.0);

				return clamp(history / totalWeight, minColor, maxColor);
			}

		#endif
	#endif

	/*
		Fragment entry point.

		Reads this frame's lighting from LIGHTING_BUFFER. When TEMPORAL_ACCUMULATION is enabled, it blends
		that lighting with the history in IRRADIANCE_BUFFER, reprojected through the velocity buffer.
		radianceOut.a holds the number of accumulated frames, and the blend weight is 1 / a.

		In RENDER_MODE 0 the pass also:
		  - forwards the packed visible normal (normalOut);
		  - writes momentsOut, namely luminance moments (rg) and variance (b) when ATROUS_FILTER is enabled,
		    plus log2 of the current depth (a), which filterHistory reads back next frame for its depth test.
	*/
	void main() {
		#if DEBUG_ALBEDO == 1 || DEBUG_NORMALS == 1 || DEBUG_HIT_POSITION == 1
			radianceOut.rgb = texture(LIGHTING_BUFFER, textureCoords).rgb;
			return;
		#endif

		float depth = texture(depthtex0, textureCoords).r;

		#if RENDER_MODE == 0
			// Nothing to accumulate on sky fragments.
			if(depth == 1.0) return;

			#if RESTIR_GI == 1 || TEMPORAL_ACCUMULATION == 1
				uvec4 dataTexture   = texelFetch(RASTER_DATA_BUFFER, ivec2(textureCoords * viewSize), 0);
				vec3  visibleNormal = decodeUnitVector((uvec2(dataTexture.w) >> uvec2(0, 16) & 65535u) * rcpMaxFloat16);

				normalOut = dataTexture.w;
			#endif
		#endif

		vec3 currPosition = vec3(textureCoords, depth);
		vec3 prevPosition = currPosition - getVelocity(getClosestFragment(currPosition));

		vec3 radiance = texture(LIGHTING_BUFFER, textureCoords).rgb;
	
		#if TEMPORAL_ACCUMULATION == 1
			if(saturate(prevPosition.xy) == prevPosition.xy) {
				#if RENDER_MODE == 0
					bool rejectHistory;
					radianceOut = filterHistory(IRRADIANCE_BUFFER, prevPosition.xy, visibleNormal, linearizeDepth(prevPosition.z), rejectHistory);

					if(!rejectHistory) {
						vec3 viewPosition   = screenToView(currPosition, false);
						vec3 sceneDirection = -normalize(mat3(gbufferModelViewInverse) * viewPosition);

						//float distanceFalloff = linearStep(0.0, 1.0, saturate((15.0 * far / length(viewPosition)) / far));

						vec3 geometricNormal = decodeUnitVector((uvec2(dataTexture.z) >> uvec2(2, 17) & 32767u) * rcpMaxFloat15);

						// Camera velocity heuristic from Snurf.
						// Not `const`: the initializer depends on uniforms, and strict GLSL compilers
						// reject non-constant `const` initializers.
						float cameraVelocityHeuristic = smoothstep(0.0, 0.14, dot(cameraPosition - previousCameraPosition, geometricNormal));

						float NdotV    = step(0.1, dot(geometricNormal, sceneDirection));
						radianceOut.a *= max0(mix(1.0, NdotV, cameraVelocityHeuristic));

						radianceOut.a = min(radianceOut.a + 1.0, MAX_ACCUMULATED_FRAMES);
					} else {
						radianceOut = vec4(0.0);
					}
				#else
					radianceOut = texture(IRRADIANCE_BUFFER, prevPosition.xy);

					// Resets the accumulation while the GUI is visible.
					radianceOut *= float(hideGUI);

					radianceOut.a++;
				#endif
			} else {
				// No history available. Start from a defined value: mix(undefined, radiance, 1.0)
				// below can still yield NaN, because NaN * 0 is NaN.
				radianceOut = vec4(0.0, 0.0, 0.0, 1.0);
			}

			#if RENDER_MODE == 1 && CHECKERBOARD_RENDERING == 1
				ivec2 coords   = ivec2(gl_FragCoord.xy / int(viewResolution.x * 0.1));
				bool canRender = (coords.x + coords.y & 1) == (frameCounter & 1);
				
				float weight = canRender ? saturate(1.0 / max(radianceOut.a, 1.0)) : 0.0;
			#else
				float weight = saturate(1.0 / max(radianceOut.a, 1.0));
			#endif

			radianceOut.rgb = mix(radianceOut.rgb, radiance, weight);

			/*
			uvec4 dataTexture   = texture(RASTER_DATA_BUFFER, textureCoords);
			vec3  visibleNormal = decodeUnitVector(vec2(dataTexture.w & 65535u, (dataTexture.w >> 16u) & 65535u) * rcpMaxFloat16);
			float visibleDepth  = linearizeDepth(prevPosition.z);

			float sampleDepth    = linearizeDepth(exp2(texture(MOMENTS_BUFFER, prevPosition.xy).a));
			vec3  previousNormal = texture(NORMAL_BUFFER, prevPosition.xy).rgb * 2.0 - 1.0;

			//radianceOut.rgb = textureCoords.x > 0.5 ? vec3(abs(sampleDepth - visibleDepth)) : abs(previousNormal - visibleNormal);
			//radianceOut.rgb = vec3(abs(sampleDepth - visibleDepth));
			//radianceOut.rgb = abs(previousNormal - visibleNormal);
			*/
		#else
			// Without temporal accumulation every frame stands alone, so one accumulated frame.
			radianceOut = vec4(radiance, 1.0);
		#endif

		#if RENDER_MODE == 0
			#if ATROUS_FILTER == 1
				float currLuminance = luminance(radianceOut.rgb);
				vec2  moments       = vec2(currLuminance, currLuminance * currLuminance);

				#if TEMPORAL_ACCUMULATION == 1
					// Fragment outputs cannot be read back before being written, so the moment history
					// must come from the moments buffer at the reprojected position.
					vec2 historyMoments = texture(MOMENTS_BUFFER, prevPosition.xy).rg;
					momentsOut.rg = mix(historyMoments, moments, weight);
				#else
					momentsOut.rg = moments;
				#endif

				// Too few accumulated frames for stable temporal moments: fall back to a spatial estimate.
				if(radianceOut.a < VARIANCE_STABILIZATION_THRESHOLD) {
					momentsOut.b = estimateSpatialVariance(IRRADIANCE_BUFFER, moments);
				} else { 
					momentsOut.b = max0(momentsOut.g - momentsOut.r * momentsOut.r);
				}
			#endif

			// Store the current frame's depth. Next frame, filterHistory compares it against the
			// fragment's reprojected depth, which lives in this frame's depth space.
			momentsOut.a = log2(depth);
		#endif
	}
	
#endif
