/***********************************************/
/*           Copyright (c) 2025 Belmu          */
/*             All Rights Reserved             */
/***********************************************/

#include "/include/camera/lens/common.glsl"

/*
	Traces one ray, for a single wavelength, from a jittered point on the sensor through
	every element of lensSystem towards the scene.

	Elements are visited from the last entry of lensSystem (the rear element, whose aperture
	is sampled to aim the initial ray) down to index 0. When crossing element i, the ray goes
	from the medium of element i into the medium of element i - 1 (air past element 0).

	coords         : screen coordinates (presumably in [0, 1]); jittered by up to one texel and
	                 remapped to [-1, 1] before being scaled by the sensor extent
	wavelength     : forwarded to computeRayTransferMatrix() and sellmeier()
	throughput     : ray weight; multiplied by the Fresnel transmittance at every refracting
	                 surface (and reweighted by Russian roulette when RUSSIAN_ROULETTE == 1)
	sensorPosition : on success, the ray's final position passed through viewToWorld()
	sceneDirection : on success, the ray's final direction rotated by gbufferModelViewInverse

	Returns false when an intersection distance is negative, the ray falls outside an
	element's aperture, total internal reflection occurs, or more than
	MAX_LENS_SCATTERING_ORDERS interactions are needed. When Russian roulette kills the ray,
	throughput is set to 0 and the function still returns true.
*/
bool traceLensSystem(vec2 coords, float wavelength, inout float throughput, out vec3 sensorPosition, out vec3 sceneDirection) {
	// Sensor extent adjusted by the projection's aspect ratio
	vec2 sensorExtent = sensorSize * vec2(1.0, gbufferProjection[0][0] / gbufferProjection[1][1]);

	// Jittered sensor point at z = 0; the negated extent mirrors the sensor on both axes
	vec3 sensorRayOrigin = vec3(((coords + rand2F() * texelSize) * 2.0 - 1.0) * -sensorExtent, 0.0);

	float focalLength = focal(computeRayTransferMatrix(wavelength));

	int rearIndex = lensSystem.length() - 1;

	// Aim the ray at a point sampled on the rear element's aperture disk
	vec3 rearPosition = vec3(sampleDisk(rand2F()) * lensSystem[rearIndex].aperture * 0.5, -focalLength);

	vec3 position  = sensorRayOrigin;
	vec3 direction = normalize(rearPosition - sensorRayOrigin);

	float opticalAxisZ = -focalLength;

	int i = rearIndex;

	int scatteringOrder = 0;

	while (i >= 0 && i < lensSystem.length()) {

		#if RUSSIAN_ROULETTE == 1
			// randF() is evaluated first so RNG consumption is unchanged; a zero throughput is
			// always terminated so the renormalization below can never compute 0 / 0
			if(throughput < randF() || throughput <= 0.0) {
				throughput = 0.0;
				break;
			}
			throughput /= saturate(throughput);
		#endif

		if(scatteringOrder >= MAX_LENS_SCATTERING_ORDERS) {
			return false;
		}

		LensElement element = lensSystem[i];

		// Zero curvature marks the aperture stop: a flat opening that only clips rays
		bool isApertureStop = element.curvature == 0.0;

		// Step along the optical axis to this element's vertex
		opticalAxisZ -= element.thickness;

		float t;
		vec3 center;

		if (isApertureStop) {
			// Ray-Plane intersection
			t = (opticalAxisZ - position.z) / direction.z;
		}
		else {
			// Ray-Sphere intersection; the sphere's center sits on the optical axis, one curvature radius from the vertex
			center = vec3(0.0, 0.0, opticalAxisZ + element.curvature);

			vec2 sphere = intersectSphere(position - center, direction, element.curvature);

			// Choose the root lying on the lens surface from the travel direction and the sign of the radius
			t = direction.z * element.curvature > 0.0 ? sphere.x : sphere.y;
		}

		if (t < 0.0)
			return false;

		position += direction * t;

		// Outside of aperture
		if (length(position.xy) > element.aperture * 0.5)
			return false;

		if (!isApertureStop) {
			// Orient the normal against the incoming ray, as reflect()/refract() expect
			vec3 normal  = normalize(position - center);
				 normal *= -sign(dot(direction, normal));

			// n1: medium the ray is leaving (element i)
			complexFloat n1;
			n1.r = sellmeier(element.coefficients, wavelength);
			n1.i = 0.0;

			// n2: medium the ray enters (element i - 1, or air beyond the front element)
			// NOTE(review): elements using NULL_COEFFS (e.g. the aperture stop) are not mapped to airIOR here — confirm sellmeier() handles them
			complexFloat n2;
			n2.r = i > 0 ? sellmeier(lensSystem[i - 1].coefficients, wavelength) : airIOR;
			n2.i = 0.0;

			float NdotV = dot(normal, -direction);

			float fresnel_R = fresnelComplex_R(NdotV, n1, n2);
			float fresnel_T = fresnelComplex_T(NdotV, n1, n2);

			// Reflections are currently disabled: every interaction is a transmission.
			// NOTE: enabling them also requires stepping i and opticalAxisZ according to the travel direction.
			float specularBounceProbability = 0.0;

			if(specularBounceProbability > randF()) {
				throughput /= specularBounceProbability;
				throughput *= fresnel_R;

				direction = reflect(direction, normal);
			} else {
				throughput /= 1.0 - specularBounceProbability;
				throughput *= fresnel_T;

				direction = refract(direction, normal, n1.r / n2.r);

				// refract() returns a zero vector on total internal reflection
				if (direction == vec3(0.0))
					return false;
			}
		}

		// Continue towards the front of the lens (lower indices)
		i--;
		scatteringOrder++;
	}

	sensorPosition = viewToWorld(position);
	sceneDirection = normalize(mat3(gbufferModelViewInverse) * direction);

	return true;
}

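/*
Disabled variant of the scattering step in traceLensSystem. It chooses reflection with
probability fresnel_R instead of always transmitting. For neighbouring elements whose
coefficients are NULL_COEFFS, it falls back to airIOR.
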
float transmittedEta = i > 0 ? (lensSystem[i - 1].coefficients == NULL_COEFFS ? airIOR : sellmeier(lensSystem[i - 1].coefficients, wavelength)) : airIOR;

float specularBounceProbability = fresnel_R;

if(specularBounceProbability > randF()) {
	throughput /= specularBounceProbability;
	throughput *= fresnel_R;

	direction = reflect(direction, normal);
} else {
	throughput /= 1.0 - specularBounceProbability;
	throughput *= fresnel_T;
	
	direction = refract(direction, normal, n1.r / n2.r);
}
*/
