/***********************************************/
/*           Copyright (c) 2025 Belmu          */
/*             All Rights Reserved             */
/***********************************************/

#include "/include/fragment/cache/common.glsl"

// Converts a face normal into a cube-face index.
// Positive components are weighted (0, 2, 4) and negated components (1, 3, 5), so an
// axis-aligned unit normal yields: +X = 0, -X = 1, +Y = 2, -Y = 3, +Z = 4, -Z = 5.
// The weighted sum is truncated to uint; normals that are not axis-aligned therefore
// produce a blended value that may not correspond to a single face.
uint getFaceIndex(vec3 normal) {
    vec3 positivePart = max0(normal);
    vec3 negativePart = max0(-normal);

    float weightedSum = dot(positivePart, vec3(0, 2, 4)) + dot(negativePart, vec3(1, 3, 5));
    return uint(weightedSum);
}

// Stores `irradiance` for one face of the voxel at `position` into the `_A` face fields
// of the cache.
//
// The slot is located by probing from the hashed position for at most
// MAX_CACHE_PROBING_ATTEMPTS steps. A slot is used when its packedPosition was 0u
// (it is then claimed atomically) or already equals this position's packed key.
// If no such slot is found, or `faceIndex` is not in 0..5, no face value is written.
void writeCurrentCacheEntry(ivec3 position, vec3 irradiance, uint faceIndex) {
    uint key     = packPosition(position, 1);
    uint payload = packUnormArb(logLuvEncode(irradiance), uvec4(8));
    uint slot    = moduloArraySize(tridimensionalHash(position), CACHE_ARRAY_SIZE);

    for (uint attempt = 0; attempt < MAX_CACHE_PROBING_ATTEMPTS; attempt++, incrementHash(slot, CACHE_ARRAY_SIZE)) {
        uint owner = atomicCompSwap(cacheEntries[slot].packedPosition, 0u, key);

        // The slot is held by a different position: move on to the next probe.
        if (owner != 0u && owner != key) continue;

        if      (faceIndex == 0u) atomicExchange(cacheEntries[slot].facePosX_A, payload); // Pos X
        else if (faceIndex == 1u) atomicExchange(cacheEntries[slot].faceNegX_A, payload); // Neg X
        else if (faceIndex == 2u) atomicExchange(cacheEntries[slot].facePosY_A, payload); // Pos Y
        else if (faceIndex == 3u) atomicExchange(cacheEntries[slot].faceNegY_A, payload); // Neg Y
        else if (faceIndex == 4u) atomicExchange(cacheEntries[slot].facePosZ_A, payload); // Pos Z
        else if (faceIndex == 5u) atomicExchange(cacheEntries[slot].faceNegZ_A, payload); // Neg Z

        return;
    }
}

// Fetches the irradiance stored in the `_A` face fields for one face of the voxel at `position`.
//
// Probes from the hashed position for at most MAX_CACHE_PROBING_ATTEMPTS steps.
// Returns the decoded value of the first matching entry whose requested face is not ~0u.
// Returns vec3(0.0) when a slot with packedPosition == ~0u is reached, when probing is
// exhausted, or when `faceIndex` is not in 0..5.
vec3 readCurrentCacheEntry(ivec3 position, uint faceIndex) {
    uint key  = packPosition(position, 1);
    uint slot = moduloArraySize(tridimensionalHash(position), CACHE_ARRAY_SIZE);

    for (uint attempt = 0; attempt < MAX_CACHE_PROBING_ATTEMPTS; attempt++, incrementHash(slot, CACHE_ARRAY_SIZE)) {
        CacheEntry candidate = cacheEntries[slot];

        if (candidate.packedPosition == key) {
            // ~0u marks "no value"; it is kept when faceIndex matches none of the six faces.
            uint packedFace = ~0u;

            if      (faceIndex == 0u) packedFace = candidate.facePosX_A; // Pos X
            else if (faceIndex == 1u) packedFace = candidate.faceNegX_A; // Neg X
            else if (faceIndex == 2u) packedFace = candidate.facePosY_A; // Pos Y
            else if (faceIndex == 3u) packedFace = candidate.faceNegY_A; // Neg Y
            else if (faceIndex == 4u) packedFace = candidate.facePosZ_A; // Pos Z
            else if (faceIndex == 5u) packedFace = candidate.faceNegZ_A; // Neg Z

            if (packedFace != ~0u) {
                return logLuvDecode(unpackUnormArb(packedFace, uvec4(8))).rgb;
            }
        }

        // A packedPosition of ~0u ends the search.
        if (candidate.packedPosition == ~0u) break;
    }

    return vec3(0.0);
}

// Stores `irradiance` for one face of the voxel at `position` into the `_B` face fields
// of the cache.
//
// The slot is located by probing from the hashed position for at most
// MAX_CACHE_PROBING_ATTEMPTS steps. A slot is used when its packedPosition was 0u
// (it is then claimed atomically) or already equals this position's packed key.
// If no such slot is found, or `faceIndex` is not in 0..5, no face value is written.
void writePreviousCacheEntry(ivec3 position, vec3 irradiance, uint faceIndex) {
    uint key     = packPosition(position, 1);
    uint payload = packUnormArb(logLuvEncode(irradiance), uvec4(8));
    uint slot    = moduloArraySize(tridimensionalHash(position), CACHE_ARRAY_SIZE);

    for (uint attempt = 0; attempt < MAX_CACHE_PROBING_ATTEMPTS; attempt++, incrementHash(slot, CACHE_ARRAY_SIZE)) {
        uint owner = atomicCompSwap(cacheEntries[slot].packedPosition, 0u, key);

        // The slot is held by a different position: move on to the next probe.
        if (owner != 0u && owner != key) continue;

        if      (faceIndex == 0u) atomicExchange(cacheEntries[slot].facePosX_B, payload); // Pos X
        else if (faceIndex == 1u) atomicExchange(cacheEntries[slot].faceNegX_B, payload); // Neg X
        else if (faceIndex == 2u) atomicExchange(cacheEntries[slot].facePosY_B, payload); // Pos Y
        else if (faceIndex == 3u) atomicExchange(cacheEntries[slot].faceNegY_B, payload); // Neg Y
        else if (faceIndex == 4u) atomicExchange(cacheEntries[slot].facePosZ_B, payload); // Pos Z
        else if (faceIndex == 5u) atomicExchange(cacheEntries[slot].faceNegZ_B, payload); // Neg Z

        return;
    }
}

// Fetches the irradiance stored in the `_B` face fields for one face of the voxel at `position`.
//
// Probes from the hashed position for at most MAX_CACHE_PROBING_ATTEMPTS steps.
// A slot with packedPosition == ~0u ends the search before any match test.
// Returns the decoded value of the first matching entry whose requested face is not ~0u,
// and vec3(0.0) otherwise (including when `faceIndex` is not in 0..5).
vec3 readPreviousCacheEntry(ivec3 position, uint faceIndex) {
    uint key  = packPosition(position, 1);
    uint slot = moduloArraySize(tridimensionalHash(position), CACHE_ARRAY_SIZE);

    for (uint attempt = 0; attempt < MAX_CACHE_PROBING_ATTEMPTS; attempt++, incrementHash(slot, CACHE_ARRAY_SIZE)) {
        CacheEntry candidate = cacheEntries[slot];

        // A packedPosition of ~0u ends the search.
        if (candidate.packedPosition == ~0u) break;

        // Slot belongs to another position: keep probing.
        if (candidate.packedPosition != key) continue;

        // ~0u marks "no value"; it is kept when faceIndex matches none of the six faces.
        uint packedFace = ~0u;

        if      (faceIndex == 0u) packedFace = candidate.facePosX_B; // Pos X
        else if (faceIndex == 1u) packedFace = candidate.faceNegX_B; // Neg X
        else if (faceIndex == 2u) packedFace = candidate.facePosY_B; // Pos Y
        else if (faceIndex == 3u) packedFace = candidate.faceNegY_B; // Neg Y
        else if (faceIndex == 4u) packedFace = candidate.facePosZ_B; // Pos Z
        else if (faceIndex == 5u) packedFace = candidate.faceNegZ_B; // Neg Z

        if (packedFace != ~0u) {
            return logLuvDecode(unpackUnormArb(packedFace, uvec4(8))).rgb;
        }
    }

    return vec3(0.0);
}

// Atomically increments the counter FIELD of cacheEntries[hash] and stores the
// post-increment value in `count`. If that value exceeds MAX_CACHE_SAMPLES, the stored
// counter is pulled back down to MAX_CACHE_SAMPLES with atomicMin.
// The expansion site must have locals named `hash` and `count` in scope.
#define READ_ADD_CLAMP_COUNTER(FIELD) \
    count = atomicAdd(cacheEntries[hash].FIELD, 1u) + 1u; \
    if (count > MAX_CACHE_SAMPLES) atomicMin(cacheEntries[hash].FIELD, MAX_CACHE_SAMPLES);

// Increments the sample counter of one face for the voxel at `position` and returns the
// new count, clamped to MAX_CACHE_SAMPLES.
// Returns 0 when `faceIndex` is not in 0..5, because no counter is touched in that case.
//
// NOTE(review): the counter is addressed at the hashed slot directly, without the probing
// or packedPosition comparison used by the cache entry read/write functions — confirm
// that sharing counters between positions with the same hash is intended.
uint readAddFaceCounter(ivec3 position, uint faceIndex) {
    uint hash = moduloArraySize(tridimensionalHash(position), CACHE_ARRAY_SIZE);

    // Stays 0 unless one of the cases below runs; the macro writes to `count` and reads `hash`.
    uint count = 0u;

    switch (faceIndex) {
        case 0: READ_ADD_CLAMP_COUNTER(counterPosX); break; // Pos X
        case 1: READ_ADD_CLAMP_COUNTER(counterNegX); break; // Neg X
        case 2: READ_ADD_CLAMP_COUNTER(counterPosY); break; // Pos Y
        case 3: READ_ADD_CLAMP_COUNTER(counterNegY); break; // Neg Y
        case 4: READ_ADD_CLAMP_COUNTER(counterPosZ); break; // Pos Z
        case 5: READ_ADD_CLAMP_COUNTER(counterNegZ); break; // Neg Z
        default: break;
    }

    // The local value may exceed the limit even though the stored counter was clamped.
    return min(count, MAX_CACHE_SAMPLES);
}

// Produces a float from `seed` after passing it through pcg(), computed as the
// seed converted to float divided by float(0xffffffffu).
// `seed` is an inout parameter; presumably pcg() advances it in place, so successive
// calls on the same variable give different values — confirm against pcg's definition.
float randF(inout uint seed) {
    pcg(seed);

    float numerator   = float(seed);
    float denominator = float(0xffffffffu);
    return numerator / denominator;
}

// Draws two consecutive randF() values from a local copy of `seed` and returns them
// as (first, second). The caller's seed is not modified, since it is passed by value.
vec2 rand2F(uint seed) {
    float first  = randF(seed);
    float second = randF(seed);
    return vec2(first, second);
}

#if defined STAGE_COMPUTE

    // Traces one ray from a voxel face and blends the gathered irradiance into that
    // face's cache entry.
    //
    // origin: point whose floor() selects the cached voxel; the ray starts at
    //         origin + normal * 0.501.
    // normal: face normal; selects the cached face through getFaceIndex().
    // seed:   random seed used to pick the initial ray direction.
    //
    // The path runs for at most MAX_CACHE_TRACE_BOUNCES iterations. Each hit adds direct
    // lighting, emission and the hit face's current (`_A`) cache value; a miss adds the
    // atmosphere sample and ends the path. The result is averaged with the previous (`_B`)
    // value using the per-face sample counter and written to the current (`_A`) entry.
    void traceCacheRay(vec3 origin, vec3 normal, uint seed) {
        ivec3 originPosition = ivec3(floor(origin));
        uint  originFace     = getFaceIndex(normal);

        vec3 rayPosition  = origin + normal * 0.501;
        vec3 rayDirection = generateCosineVector(normal, rand2F(seed));

        vec3 throughput     = vec3(1.0);
        vec3 currIrradiance = vec3(0.0);

        for (int i = 0; i < MAX_CACHE_TRACE_BOUNCES; i++) {
            // Stop once the direction no longer lies in the hemisphere of the origin normal.
            if (dot(normal, rayDirection) <= 0.0) break;

            VoxelIntersection hit = raytraceVoxel(rayPosition, rayDirection, 6, true, true);

            if (!hit.intersect) {
                // Nothing was hit: add the atmosphere contribution and end the path.
                vec3 atmosphere = texture(ATMOSPHERE_BUFFER, projectSphere(rayDirection)).rgb;
                currIrradiance += throughput * atmosphere;
                break;
            }

            Material material = getVoxelMaterial(hit.voxel.packedData, hit.textureCoords, hit.position, hit.normal, 0.0);

            // Continue from just off the hit surface, along its normal.
            rayPosition = hit.position + hit.normal * 1e-3;

            // Must run before the lighting terms below, which all use `throughput`.
            sampleMicrosurfaceOpaqueDiffuse(throughput, rayDirection, material);

            float visibility = sampleVisibility(shadowtex0, hit.position - cameraPosition);

            vec3 brdf = evaluateMicrosurfaceOpaqueDiffuse(material, shadowLightVector);

            // Direct lighting, weighted by shadow visibility.
            currIrradiance += throughput * brdf * directIlluminance * visibility;

            // Emission of the hit material.
            currIrradiance += throughput * material.emission * EMISSIVE_INTENSITY * EMISSIVE_INTENSITY_MULTIPLIER;

            // Cached value for the hit face, keyed by the voxel containing the offset position.
            ivec3 hitPosition = ivec3(floor(rayPosition));
            uint  hitFace     = getFaceIndex(hit.normal);
            vec3  currEntry   = readCurrentCacheEntry(hitPosition, hitFace);

            if (currEntry != vec3(0.0))
                currIrradiance += throughput * currEntry;
        }

        vec3 prevIrradiance = readPreviousCacheEntry(originPosition, originFace);

        uint count = readAddFaceCounter(originPosition, originFace);

        // A zero count means no face counter was incremented (face index outside 0..5).
        // Blending would then wrap `count - 1` as unsigned and divide by zero, so skip the write.
        if (count == 0u) return;

        // Running average over `count` samples.
        vec3 irradiance = (prevIrradiance * float(count - 1u) + currIrradiance) / float(count);

        writeCurrentCacheEntry(originPosition, irradiance, originFace);
    }
    
#endif
