/***********************************************/
/*           Copyright (c) 2025 Belmu          */
/*             All Rights Reserved             */
/***********************************************/

// Only passes with SPATIAL_REUSE_PASS_INDEX > 0 write a colour attachment
// (irradianceOut). Pass 0 stores its result through writeSpatialReservoir()
// instead, so it declares no fragment output.
#if SPATIAL_REUSE_PASS_INDEX > 0
    /* RENDERTARGETS: 1 */
    
    layout (location = 0) out vec3 irradianceOut;
#endif

// Normalised [0, 1] screen coordinates of the current fragment, interpolated
// from the previous shader stage.
in vec2 textureCoords;

#include "/include/fragment/restir/common.glsl"

#if SPATIAL_REUSE_PASS_INDEX == 1

	#include "/include/voxels/raytracer.glsl"

	/*
		Occlusion test for a resampled light sample.

		Casts a voxel ray from the shading point (nudged 1e-3 along its normal to
		avoid self-intersection) towards the reservoir's sample position.

		NOTE: despite the name, this returns TRUE when the sample is OCCLUDED.
		That is the case when the ray hits voxel geometry at least
		SPATIAL_VISIBILITY_LENIENCE closer than the sample position. It returns
		FALSE when the path is considered clear.
	*/
	bool visibilityValidation(vec3 visiblePosition, vec3 visibleNormal, vec3 reservoirPosition) {
		vec3 rayOrigin    = visiblePosition + visibleNormal * 1e-3;
		vec3 rayDirection = normalize(reservoirPosition - rayOrigin);

		VoxelIntersection hit = raytraceVoxel2(rayOrigin, rayDirection, VISIBILITY_VALIDATION_STEPS, true, true);

		// Nothing was hit along the ray: the sample is not occluded.
		if(!hit.intersect) return false;

		// Count the hit as a blocker only if it lies sufficiently in front of the sample.
		float sampleDistance = distance(rayOrigin, reservoirPosition);
		return hit.dists.x + SPATIAL_VISIBILITY_LENIENCE <= sampleDistance;
	}
	
#endif

/*
	Spatial ReSTIR resampling pass.

	Combines this pixel's reservoir with up to SPATIAL_REUSE_ITERATIONS neighbour
	reservoirs chosen on a blue-noise-jittered disk. Each sample is evaluated at
	the other pixel with a Jacobian-corrected shift and weighted with pairwise MIS.

	SPATIAL_REUSE_PASS_INDEX == 0 : the resampled reservoir is written back with
	                                writeSpatialReservoir().
	otherwise                     : the resampled sample is tested for occlusion
	                                with a voxel ray and shaded into irradianceOut
	                                (zero when occluded).

	Fragments at depth 1.0 (presumably sky) are discarded. Pixels whose visible
	position lies outside the voxel volume skip resampling. In pass 1 they
	forward the temporal reservoir's radiance unchanged.
*/
void main() {
	float depth = texture(depthtex0, textureCoords).r;
	if(depth == 1.0) { discard; return; }

	ivec2 pixelCoords = ivec2(gl_FragCoord.xy);

	vec3 visiblePosition = readVisiblePositionCurrent(pixelCoords);
	bool inVoxelBounds   = isInVoxelBounds(ivec3(floor(visiblePosition)));

	if(!inVoxelBounds) {
		#if SPATIAL_REUSE_PASS_INDEX == 1
			irradianceOut = readTemporalReservoirCurrent(pixelCoords).radiance;
		#endif
		return;
	}

	Material material = getRasterMaterial(textureCoords);

	Reservoir centerReservoir = readSpatialReservoir(pixelCoords);

	float p_hat_center  = computeIntegrand(centerReservoir.radiance, centerReservoir.position, material, visiblePosition, material.normal);
	// Pairwise MIS weight of the center sample: starts at 1 and gains (1 - m) per accepted neighbour.
	float weight_center = 1.0;

	// Blue-noise jitter fed to sampleDisk() so the neighbour pattern varies per pixel.
	float radiusOffset = temporalBlueNoise(gl_FragCoord.xy).r;
	float angleOffset  = temporalBlueNoise(gl_FragCoord.xy / SPATIAL_NOISE_TILES_SCALE).g;

	Reservoir reservoir = Reservoir(vec3(0.0), vec3(0.0), vec3(0.0), 0.0, 0);

	// Target function at THIS pixel of the currently selected sample. Initialised so
	// finalizeReservoirMIS never reads an undefined value if no merge selects a sample.
	float p_hat = 0.0;

	// The center reservoir always counts as one valid sample.
	int validSampleCount = 1;

	for(int sampleCount = 0; sampleCount < SPATIAL_REUSE_ITERATIONS; sampleCount++) {
		vec2 offset          = sampleDisk(sampleCount, SPATIAL_REUSE_ITERATIONS, radiusOffset, angleOffset) * SPATIAL_NEIGHBORHOOD_SIZE * texelSize;
		vec2 neighbourCoords = textureCoords + offset;

		// Skip off-screen neighbours and (near-)zero offsets that would resample the center
		// pixel against itself. abs() is required: comparing the signed offset also rejected
		// every offset whose two components are both negative, biasing the reuse pattern.
		bool offScreen  = saturate(neighbourCoords) != neighbourCoords;
		bool zeroOffset = all(lessThanEqual(abs(offset), vec2(EPS)));
		if(offScreen || zeroOffset) continue;

		ivec2 neighbourPixel = ivec2(neighbourCoords * viewSize);

		Reservoir neighbourReservoir = readSpatialReservoir(neighbourPixel);

		// Neighbour reservoirs with age 0 are not reused.
		if(neighbourReservoir.age == 0) continue;

		float neighbourDepth = texture(depthtex0, neighbourCoords).r;

		vec3 neighbourVisiblePosition = readVisiblePositionCurrent(neighbourPixel);

		Material neighbourMaterial = getRasterMaterial(neighbourCoords);

		if(!isGeometricallySimilar(material.normal, neighbourMaterial.normal, depth, neighbourDepth)) continue;

		neighbourReservoir.age = min(neighbourReservoir.age, SPATIAL_M_CLAMP);

		/* Shift Mappings */

		// Jacobians for moving each reservoir's sample from its own shading point to the other one.
		float jacobian_center    = computeJacobianDeterminant(centerReservoir.position   , centerReservoir.normal   , neighbourVisiblePosition, visiblePosition);
		float jacobian_neighbour = computeJacobianDeterminant(neighbourReservoir.position, neighbourReservoir.normal, visiblePosition, neighbourVisiblePosition);

		// Target function of the neighbour's sample at the neighbour pixel.
		float p_hat_neighbour = computeIntegrand(neighbourReservoir.radiance, neighbourReservoir.position, neighbourMaterial, neighbourVisiblePosition, neighbourMaterial.normal);

		// Target function of each sample evaluated at the OTHER pixel.
		float shifted_center    = computeIntegrand(centerReservoir.radiance   , centerReservoir.position   , neighbourMaterial, neighbourVisiblePosition, neighbourMaterial.normal);
		float shifted_neighbour = computeIntegrand(neighbourReservoir.radiance, neighbourReservoir.position, material         , visiblePosition         , material.normal);

		float p_hat_center_to_neighbour = shifted_center    * jacobian_center;
		float p_hat_neighbour_to_center = shifted_neighbour * jacobian_neighbour;

		/* Pairwise MIS */

		float n_k = float(neighbourReservoir.age * SPATIAL_REUSE_ITERATIONS);

		float weight_neighbour = balanceHeuristic(p_hat_neighbour, n_k, p_hat_neighbour_to_center, centerReservoir.age);
		float center_heuristic = balanceHeuristic(p_hat_center_to_neighbour, n_k, p_hat_center, centerReservoir.age);

		weight_center += 1.0 - center_heuristic;

		weight_neighbour = p_hat_neighbour_to_center * neighbourReservoir.weight * weight_neighbour;

		bool selected_neighbour = mergeReservoir(reservoir, neighbourReservoir, weight_neighbour);
		validSampleCount++;

		// Track the selected sample's target function evaluated at this pixel.
		if(selected_neighbour) p_hat = shifted_neighbour;
	}

	bool selected_center = mergeReservoir(reservoir, centerReservoir, p_hat_center * centerReservoir.weight * weight_center);

	if(selected_center) p_hat = p_hat_center;
	
	finalizeReservoirMIS(reservoir, p_hat, validSampleCount);

	#if SPATIAL_REUSE_PASS_INDEX == 0
		writeSpatialReservoir(reservoir, pixelCoords);
	#else
		// visibilityValidation() returns TRUE when voxel geometry blocks the path to the sample.
		bool occluded = visibilityValidation(visiblePosition, material.normal, reservoir.position);

		// An occluded sample contributes no light. Write zero explicitly: a fragment
		// output that is never written holds an undefined value.
		irradianceOut = vec3(0.0);

		if(!occluded) {
			vec3 rayDirection   = normalize(-(visiblePosition - cameraPosition));
			vec3 lightDirection = normalize(reservoir.position - visiblePosition);

			float NdotL = saturate(dot(material.normal, lightDirection));

			// F0 above 229/255 is treated as metal (presumably the labPBR metal range — confirm);
			// metals get no diffuse lobe.
			float isMetal = float(material.F0 * maxFloat8 > 229.5);

			float diffuse  = NdotL * RCP_PI * (1.0 - isMetal);
			vec3  specular = computeSpecularBRDF(material, rayDirection, lightDirection);

			irradianceOut = reservoir.radiance * reservoir.weight * (diffuse + specular);
		}
	#endif
}
