/***********************************************/
/*           Copyright (c) 2025 Belmu          */
/*             All Rights Reserved             */
/***********************************************/

// One reservoir of the resampling scheme: the currently selected sample plus
// the bookkeeping needed to weight it.
// This struct is the element type of std430 shader-storage buffers below, so
// its field order defines the byte layout the host must allocate. Under std430
// each vec3 is 16-byte aligned; the expected offsets are radiance 0,
// position 16, normal 32, weight 44, age 48, struct stride 64 bytes.
// NOTE(review): verify these offsets against the host-side buffer allocation.
struct Reservoir {
	vec3 radiance; // selected sample's radiance (copied from the candidate in mergeReservoir)
	vec3 position; // selected sample's position
	vec3 normal;   // selected sample's normal
	
	float weight;  // running sum of resampling weights; replaced by the clamped final weight in finalizeReservoir*
	uint age;      // accumulated candidate count, used as M by finalizeReservoir
};

// Reservoirs used for spatial reuse, one slot per pixel.
// Accessed through readSpatialReservoir/writeSpatialReservoir with
// getPixelIndex(), so it needs at least uint(viewWidth) * uint(viewHeight)
// entries. NOTE(review): confirm the host allocation matches.
layout(std430, binding = 3) restrict buffer spatialReservoirBuffer {
	Reservoir spatialReservoir[];
};

// Double-buffered reservoirs for temporal reuse.
// The accessors add currframeOffset or prevFrameOffset (0 or ssboSize,
// alternating with frameCounter parity), so the buffer must hold
// 2 * ssboSize entries: one half is written this frame while the other
// half still holds the previous frame's data.
layout(std430, binding = 4) restrict buffer temporalReservoirBuffer {
	Reservoir temporalReservoir[];
};

// Double-buffered per-pixel positions, indexed with the same
// currframeOffset / prevFrameOffset scheme as temporalReservoir, so it needs
// 2 * ssboSize entries.
// Under std430 an array of vec3 has a 16-byte element stride (vec3 is
// 16-byte aligned), so the host must allocate 16 bytes per entry, not 12.
layout(std430, binding = 5) restrict buffer visiblePositionBuffer {
	vec3 visiblePosition[];
};

// Maps integer pixel coordinates to a linear, row-major buffer index.
// Each row spans uint(viewWidth) entries. Negative coordinates wrap around
// under the unsigned conversion, so callers are expected to pass on-screen pixels.
uint getPixelIndex(ivec2 coords) {
    uint rowStride = uint(viewWidth);
    return uint(coords.y) * rowStride + uint(coords.x);
}

// Fetches the spatial reservoir stored for the pixel at coords.
Reservoir readSpatialReservoir(ivec2 coords) {
    uint slot = getPixelIndex(coords);
    return spatialReservoir[slot];
}

// Stores reservoir as the spatial reservoir of the pixel at coords.
void writeSpatialReservoir(Reservoir reservoir, ivec2 coords) {
    uint slot = getPixelIndex(coords);
    spatialReservoir[slot] = reservoir;
}

// Entries per half of the double-buffered SSBOs (one per pixel).
// The "current" and "previous" halves swap every frame based on frameCounter
// parity: one offset is always 0 and the other is always ssboSize.
// NOTE(review): these globals are initialized from uniforms (non-constant
// initializers); fine for desktop GLSL, but GLSL ES needs an extension for it.
uint ssboSize        = uint(viewWidth * viewHeight);
uint currframeOffset = ssboSize * uint(frameCounter % 2);
uint prevFrameOffset = ssboSize * uint((frameCounter + 1) % 2);

// Fetches this frame's temporal reservoir for the pixel at coords.
Reservoir readTemporalReservoirCurrent(ivec2 coords) {
    uint slot = currframeOffset + getPixelIndex(coords);
    return temporalReservoir[slot];
}

// Stores reservoir in this frame's half of the temporal reservoir buffer.
void writeTemporalReservoirCurrent(Reservoir reservoir, ivec2 coords) {
    uint slot = currframeOffset + getPixelIndex(coords);
    temporalReservoir[slot] = reservoir;
}

// Fetches the temporal reservoir written during the previous frame.
Reservoir readTemporalReservoirPrevious(ivec2 coords) {
    uint slot = prevFrameOffset + getPixelIndex(coords);
    return temporalReservoir[slot];
}

// Fetches this frame's stored visible position for the pixel at coords.
vec3 readVisiblePositionCurrent(ivec2 coords) {
    uint slot = currframeOffset + getPixelIndex(coords);
    return visiblePosition[slot];
}

// Stores position in this frame's half of the visible-position buffer.
void writeVisiblePositionCurrent(vec3 position, ivec2 coords) {
    uint slot = currframeOffset + getPixelIndex(coords);
    visiblePosition[slot] = position;
}

// Fetches the visible position written during the previous frame.
vec3 readVisiblePositionPrevious(ivec2 coords) {
    uint slot = prevFrameOffset + getPixelIndex(coords);
    return visiblePosition[slot];
}

// Returns true when a resampling weight cannot be used: NaN, infinite,
// zero or negative. Only finite, strictly positive weights are accepted.
bool isSampleInvalid(float weight) {
	bool isFinite = !isnan(weight) && !isinf(weight);
	return !(isFinite && weight > 0.0);
}

// Streams one candidate into a reservoir (weighted reservoir sampling).
//   reservoir - accumulating reservoir, updated in place
//   candidate - sample offered for selection; its age is always accumulated
//   weight    - resampling weight of the candidate
// Returns true if the candidate replaced the reservoir's selected sample.
// Invalid weights (see isSampleInvalid) still add the candidate's age but
// never change the weight sum or the selection, and do not consume randF().
bool mergeReservoir(inout Reservoir reservoir, Reservoir candidate, float weight) {
	reservoir.age += candidate.age;

	if(!isSampleInvalid(weight)) {
		reservoir.weight += weight;

		// Accept with probability weight / (running weight sum),
		// assuming randF() is uniform in [0, 1).
		if(randF() * reservoir.weight <= weight) {
			reservoir.radiance = candidate.radiance;
			reservoir.position = candidate.position;
			reservoir.normal   = candidate.normal;
			return true;
		}
	}

	return false;
}

// Decides whether a neighbouring sample's surface is close enough to the
// current pixel's surface to be reused.
//   visibleNormal, sampleVisibleNormal - surface normals (assumed normalized)
//   depth, sampleDepth                 - depths as accepted by linearizeDepth()
// Two tests must pass:
//   1. relative linear-depth difference below DEPTH_THRESHOLD;
//   2. angle between the normals below NORMAL_THRESHOLD, treated as an angle
//      in degrees (as the radians() conversion implies).
bool isGeometricallySimilar(vec3 visibleNormal, vec3 sampleVisibleNormal, float depth, float sampleDepth) {
	float linearDepth       = linearizeDepth(depth);
	float linearSampleDepth = linearizeDepth(sampleDepth);

	float relativeDepthError = abs(linearSampleDepth - linearDepth) / max(linearSampleDepth, linearDepth);

	// The dot product of two unit normals is the cosine of the angle between
	// them, so it must be compared against cos(threshold). The previous code
	// compared it against the angle itself in radians, which is dimensionally
	// wrong and accepted much larger deviations than NORMAL_THRESHOLD.
	float minNormalCosine = cos(radians(NORMAL_THRESHOLD));

	return relativeDepthError < DEPTH_THRESHOLD
		&& dot(visibleNormal, sampleVisibleNormal) > minNormalCosine;
}

// Converts a reservoir's accumulated weight sum into its final unbiased
// contribution weight:  W = weightSum / (p_hat * M), clamped to [0, WEIGHT_CLAMP].
//   reservoir - reservoir whose .weight holds the weight sum; overwritten with W
//   p_hat     - target function value of the selected sample
//   M         - number of candidates / normalization count
// W is set to 0 when p_hat is not strictly positive (this includes NaN,
// which the old "<= 0.0" test let through) or when M is zero (the old code
// divided by zero, producing NaN for 0/0 or +inf clamped to WEIGHT_CLAMP).
void finalizeReservoirMIS(inout Reservoir reservoir, float p_hat, uint M) {
	if(!(p_hat > 0.0) || M == 0u) {
		reservoir.weight = 0.0;
		return;
	}
	reservoir.weight = clamp(reservoir.weight / (p_hat * float(M)), 0.0, WEIGHT_CLAMP);
}

// Same as finalizeReservoirMIS, using the reservoir's own accumulated age as M.
void finalizeReservoir(inout Reservoir reservoir, float p_hat) {
	finalizeReservoirMIS(reservoir, p_hat, reservoir.age);
}
