/***********************************************/
/*           Copyright (c) 2025 Belmu          */
/*             All Rights Reserved             */
/***********************************************/

/*
    [Credits]
        Jessie - help with the Heitz Multiple-Scattering BSDF model and providing hemispherical albedo function (https://github.com/Jessie-LC)

    [References]:
        Heitz, E. (2014). Understanding the Masking-Shadowing Function in Microfacet-Based BRDFs. https://jcgt.org/published/0003/02/03/paper.pdf
        Heitz et al. (2016). Multiple-Scattering Microfacet BSDFs with the Smith Model. https://eheitzresearch.wordpress.com/240-2/
        Heitz et al. (2016). Multiple-Scattering Microfacet BSDFs with the Smith Model. https://jo.dreggn.org/home/2016_microfacets.pdf
        Heitz, E. (2017). A Simpler and Exact Sampling Routine for the GGX Distribution of Visible Normals. https://hal.science/hal-01509746/document
        Hammon, A., Jr. (2017). PBR Diffuse Lighting for GGX+Smith Microsurfaces. https://ubm-twvideo01.s3.amazonaws.com/o1/vault/gdc2017/Presentations/Hammon_Earl_PBR_Diffuse_Lighting.pdf
        Reed, N. (2021). Slope Space in BRDF Theory. https://www.reedbeta.com/blog/slope-space-in-brdf-theory/

    [Notes]:
        The Smith BSDF mostly works in slope space, so numerical instabilities can occur (division by zero, infinities, ...).
        We prevent them by special-casing directions whose cosine (z component) falls outside the [-0.9999; 0.9999] range.
*/

#include "/include/fragment/render/height/gaussian.glsl"

#include "/include/fragment/render/slope/ggx_trowbridge_reitz.glsl"
#include "/include/fragment/render/slope/distribution.glsl"

// Hemispherically averaged Fresnel reflectance of an interface with relative index of refraction n,
// clamped to [0, 1] (about 0.092 for n = 1.5).
// Only meaningful for n > 1: at n == 1 several denominators vanish, and for n < 1 the argument of
// the second logarithm, n(n + 1) / (n - 1), is negative.
float hemisphericalAlbedo(float n) {
    float nSq      = pow2(n);
    float nSqPlus  = nSq + 1.0;
    float nSqMinus = nSq - 1.0;

    float termA = (4.0 * (2.0 * n + 1.0)) / (3.0 * pow2(n + 1.0));
    float termB = (4.0 * pow3(n) * (nSq + 2.0 * n - 1.0)) / (pow2(nSqPlus) * nSqMinus);
    float termC = (2.0 * nSq * nSqPlus * log(n)) / pow2(nSqMinus);
    float termD = (2.0 * nSq * pow2(nSqMinus) * log((n * (n + 1.0)) / (n - 1.0))) / pow3(nSqPlus);

    // Keep the original association: termA + ((termB - termC) + termD)
    float sum = termA + ((termB - termC) + termD);
    return saturate(1.0 - 0.5 * sum);
}

//////////////////////////////////////////////////////////
/*-------------------- MICROSURFACE --------------------*/
//////////////////////////////////////////////////////////

// Smith masking term for direction wi: 1 / (1 + Lambda(wi)).
// Directions at or below the horizon are fully masked; near-normal directions are never masked.
float G1(vec3 wi, vec2 alpha) {
    if(wi.z <= 0.0)   return 0.0;
    if(wi.z > 0.9999) return 1.0;

    float lambda = lambdaSmith(wi, alpha);
    return 1.0 / (1.0 + lambda);
}

// Height-dependent Smith masking term for direction wi leaving height h0: C1(h0)^Lambda(wi),
// with C1 the height CDF (Heitz 2014/2016).
// Directions at or below the horizon are fully masked; near-normal directions are never masked.
float G1(vec3 wi, float h0, vec2 alpha) {
    if(wi.z <= 0.0)   return 0.0;
    if(wi.z > 0.9999) return 1.0;

    float heightCDF = C1(h0);
    float lambda    = lambdaSmith(wi, alpha);
    return pow(heightCDF, lambda);
}

// Finding next intersection on the microsurface
// Finds the height of the next intersection on the microsurface for a ray travelling along wr from height hr.
// U is a uniform random number in [0, 1). Returns FLT_MAX when the ray leaves the microsurface.
float sampleHeight(vec3 wr, float hr, float U, vec2 alpha) {
    // Pointing straight up: nothing left to hit
    if(wr.z > 0.9999) return FLT_MAX;

    // Pointing straight down: always hits, invert the height CDF directly
    if(wr.z < -0.9999) return invC1(U * C1(hr));

    // Horizontal: the intersection stays at the current height
    if(abs(wr.z) < EPS) return hr;

    // The ray escapes with probability G1(wr, hr)
    float escapeThreshold = 1.0 - G1(wr, hr, alpha);
    if(U > escapeThreshold) return FLT_MAX;

    float invLambda = 1.0 / lambdaSmith(wr, alpha);
    return invC1(C1(hr) / pow(1.0 - U, invLambda));
}

//////////////////////////////////////////////////////////
/*------------------------ BSDF ------------------------*/
//////////////////////////////////////////////////////////

// Samples one scattering event on an opaque microsurface and updates the path state in place.
//   estimate   - receives the emission picked up when the diffuse lobe is sampled
//   throughput - path weight, multiplied by the weight of the sampled lobe
//   wr         - on entry, the direction visible normals are sampled against; on exit, the new direction
// Materials passing the metal test (F0 * maxFloat8 > 229.5; presumably the labPBR ">= 230/255 means metal"
// convention, confirm) reflect as pure conductors. Everything else chooses stochastically between a
// Fresnel-weighted specular reflection and a cosine-distributed diffuse bounce.
void sampleMicrosurfaceOpaquePhase(inout float estimate, inout float throughput, inout vec3 wr, Material material, float wavelength) {
    complexFloat n1 = complexFloat(airIOR, 0.0);
    complexFloat n2 = complexFloat(material.N, material.K);

    vec3  wm        = sampleD_wi(wr, rand2F(), material.alpha);
    float fresnel_R = fresnelComplex_R(dot(wm, wr), n1, n2);

    /* Conductor Phase */
    if(material.F0 * maxFloat8 > 229.5) {
        wr          = reflect(-wr, wm);
        throughput *= fresnel_R;
        return;
    }

    float fresnel_T = fresnelComplex_T(dot(wm, wr), n1, n2);

    // Pick lobes proportionally to their energy so each branch ends up with weight (albedo * T + R)
    float specularBounceProbability = fresnel_R / (material.albedo * fresnel_T + fresnel_R);
 
    if(specularBounceProbability > randF()) {
        wr = reflect(-wr, wm);

        throughput /= specularBounceProbability;
        throughput *= fresnel_R;
    } else {
        wr = generateCosineVector(wm, rand2F());

        throughput /= 1.0 - specularBounceProbability;
        throughput *= fresnel_T;
        throughput /= 1.0 - hemisphericalAlbedo(n2.r / n1.r);
        throughput *= material.albedo * material.ao;

        // An undefined weight (e.g. 0/0 lobe probability, degenerate albedo term) terminates the path.
        // Resetting it to 1.0 would inject energy, and checking only afterwards would poison the estimate.
        if(isnan(throughput)) throughput = 0.0;

        estimate   += throughput * computeBlocklightEmission(material.id, material.albedo, wavelength) * material.emission;
        throughput *= fresnelComplex_T(dot(wm, wr), n1, n2);

        if(isnan(throughput)) throughput = 0.0;
    }
}

// Random walk on an opaque microsurface: alternates height sampling and phase sampling until the ray
// escapes or MAX_BSDF_SCATTERING_ORDERS + 1 scattering events have been processed.
// On return, wr holds the last sampled direction and throughput the path weight (NaN is reset to 0).
// estimate holds whatever emission sampleMicrosurfaceOpaquePhase accumulated along the way.
void sampleMicrosurfaceOpaque(inout float estimate, inout float throughput, inout vec3 wr, Material material, float wavelength) {
    // Start above the 99.9% height quantile
    float hr = 1.0 + invC1(0.999);

    for(int scatteringOrder = 0; scatteringOrder <= MAX_BSDF_SCATTERING_ORDERS; scatteringOrder++) {

        #if RUSSIAN_ROULETTE == 1
            // Survive with probability saturate(throughput), then compensate the weight
            if(throughput < randF()) { 
                throughput = 0.0;
                break;
            }
            throughput /= saturate(throughput);
        #endif

        hr = sampleHeight(wr, hr, randF(), material.alpha);

        // The ray left the microsurface
        if(hr == FLT_MAX) break;

        // The phase function samples against the reversed travel direction
        wr = -wr;
        sampleMicrosurfaceOpaquePhase(estimate, throughput, wr, material, wavelength);
    }

    if(isnan(throughput)) throughput = 0.0;
}

// Evaluates the opaque microsurface phase function for wi (pointing away from the scattering point)
// and outgoing direction wo.
//   Specular: Fresnel * D_wi(wi, wh) / (4 * dot(wi, wh)), with wh = normalize(wi + wo).
//   Dielectrics add a diffuse term, evaluated stochastically against one visible normal drawn from
//   D_wi(wi); this consumes rand2F().
float evaluateMicrosurfaceOpaquePhase(vec3 wi, vec3 wo, Material material) {
    complexFloat n1 = complexFloat(airIOR, 0.0);
    complexFloat n2 = complexFloat(material.N, material.K);

    vec3  halfwayVector = normalize(wi + wo);
    float VdotH         = dot(wi, halfwayVector);

    // Specular lobe shared by conductors and dielectrics. A half vector below the macrosurface
    // cannot be a visible microfacet normal, so it contributes nothing. The reference code of
    // Heitz et al. (2016) rejects it the same way.
    float specular = 0.0;
    if(halfwayVector.z >= 0.0) {
        float distribution = D_wi(wi, halfwayVector, material.alpha);
        float fresnel      = fresnelComplex_R(VdotH, n1, n2);
        specular = fresnel * distribution / (4.0 * VdotH);
    }

    /* Conductor Phase */
    if(material.F0 * maxFloat8 > 229.5) return specular;

    // Only the diffuse lobe needs a sampled microfacet normal, so conductors skip this work
    vec3 wm = sampleD_wi(wi, rand2F(), material.alpha);

    float NdotL = saturate(dot(wm, wo));
    float NdotV = dot(wm, wi);

    float diffuse  = material.albedo * NdotL * RCP_PI;
          diffuse *= fresnelComplex_T(NdotV, n1, n2);
          diffuse *= fresnelComplex_T(NdotL, n1, n2);
          // Same ratio as the sampling routine (n2 / n1 > 1 for dielectrics). The inverse ratio falls outside
          // hemisphericalAlbedo's valid domain (n > 1) and produced NaN.
          diffuse /= 1.0 - hemisphericalAlbedo(n2.r / n1.r);

    return diffuse + specular;
}

// Stochastic evaluation of the multiple-scattering opaque BSDF for wo.
// wi is the initial travel direction of the walk (it is flipped before each phase evaluation).
// Each scattering event adds a next-event estimate phase(wi, wo) * G1(wo, hr), weighted by the path throughput.
// The sum is divided by wo.z. Returns 0 when wo is at or below the horizon, or when the estimate is NaN.
float evaluateMicrosurfaceOpaque(vec3 wi, vec3 wo, Material material, float wavelength) {
    float hr = 1.0 + invC1(0.999);

    float throughput = 1.0;
    float estimate   = 0.0;

    int scatteringOrder = 0;
    while(scatteringOrder <= MAX_BSDF_SCATTERING_ORDERS) {

        #if RUSSIAN_ROULETTE == 1
            if(throughput < randF()) { 
                throughput = 0.0;
                break;
            }
            throughput /= saturate(throughput);
        #endif

        float U = randF();
        hr = sampleHeight(wi, hr, U, material.alpha);

        if(hr == FLT_MAX) break;

        wi = -wi;

        float phase     = evaluateMicrosurfaceOpaquePhase(wi, wo, material);
        float shadowing = G1(wo, hr, material.alpha);

        // Only finite contributions are accumulated, so one degenerate event cannot wipe out the estimate
        float contribution = throughput * phase * shadowing;
        if(!isnan(contribution) && !isinf(contribution)) estimate += contribution;

        if(isnan(throughput)) throughput = 0.0;

        sampleMicrosurfaceOpaquePhase(estimate, throughput, wi, material, wavelength);
        scatteringOrder++;
    }
    
    // wo at or below the horizon is fully shadowed (G1 == 0). Dividing by wo.z there would
    // produce NaN (0 / 0) or a negative value.
    if(wo.z <= 0.0 || isnan(estimate)) return 0.0;
    return estimate / wo.z;
}

// Samples one scattering event on a translucent (dielectric) microsurface.
// A visible normal is drawn against wr; on the inside, the mirrored distribution is used and the normal is flipped.
// The event then either reflects (with probability equal to the Fresnel reflectance) or refracts.
// Refraction toggles wi_outside and scales throughput by the transmittance and material.albedo.
// On exit wr holds the new direction.
void sampleMicrosurfaceTranslucentPhase(inout float throughput, inout vec3 wr, inout bool wi_outside, Material material) {
    complexFloat n1;
    complexFloat n2;
    vec3 wm;

    if(wi_outside) {
        n1 = complexFloat(airIOR, 0.0);
        n2 = complexFloat(material.N, material.K);
        wm = sampleD_wi(wr, rand2F(), material.alpha);
    } else {
        // Inside the medium: the media swap and the distribution is sampled upside down
        n1 = complexFloat(material.N, material.K);
        n2 = complexFloat(airIOR, 0.0);
        wm = -sampleD_wi(-wr, rand2F(), material.alpha);
    }

    float cosTheta      = dot(wm, wr);
    float reflectance   = fresnelComplex_R(cosTheta, n1, n2);
    float transmittance = fresnelComplex_T(cosTheta, n1, n2);

    // Reflection is chosen with probability equal to the Fresnel reflectance
    if(reflectance > randF()) {
        throughput /= reflectance;
        throughput *= reflectance;

        wr = reflect(-wr, wm);
        return;
    }

    wi_outside = !wi_outside;

    throughput /= 1.0 - reflectance;
    throughput *= transmittance;
    throughput *= material.albedo;

    wr = refract(-wr, wm, n1.r / n2.r);
}

// Random walk on a translucent microsurface. Height sampling runs on the mirrored height field
// while the ray is inside the medium. The walk stops when the ray escapes on either side or after
// MAX_BSDF_SCATTERING_ORDERS + 1 events.
// On return, wr and wi_outside describe the last sampled direction and throughput holds the path weight.
void sampleMicrosurfaceTranslucent(inout float throughput, inout vec3 wr, inout bool wi_outside, Material material) {
    float hr = 1.0 + invC1(0.999);

    // Walks that start inside the medium begin on the mirrored height field
    if(!wi_outside) hr = -hr;

    int scatteringOrder = 0;
    while(scatteringOrder <= MAX_BSDF_SCATTERING_ORDERS) {

        #if RUSSIAN_ROULETTE == 1
            if(throughput < randF()) { 
                throughput = 0.0;
                break;
            }
            throughput /= saturate(throughput);
        #endif

        float U = randF();
        hr = wi_outside ? sampleHeight(wr, hr, U, material.alpha) : -sampleHeight(-wr, -hr, U, material.alpha);

        if(hr == FLT_MAX || hr == -FLT_MAX) break;

        wr = -wr;
        sampleMicrosurfaceTranslucentPhase(throughput, wr, wi_outside, material);
        scatteringOrder++;
    }

    // Never hand an undefined path weight back to the caller (mirrors the opaque walk)
    if(isnan(throughput)) throughput = 0.0;
}

// Evaluates the translucent (dielectric) microsurface phase function for directions wi and wo.
// wi_outside / wo_outside tell on which side of the interface each direction lies.
// n1 is the medium on wi's side and n2 the medium across the interface.
//   Same side      -> reflection lobe: Fresnel_R * D_wi / (4 * dot(wi, wh)).
//   Opposite sides -> transmission lobe: albedo * eta^2 * Fresnel_T * D_wi * max(0, -dot(wo, wh))
//                     / (dot(wi, wh) + eta * dot(wo, wh))^2.
// On the inside, the distribution is evaluated on the mirrored (negated) directions.
float evaluateMicrosurfaceTranslucentPhase(vec3 wi, vec3 wo, bool wi_outside, bool wo_outside, Material material) {
    complexFloat n1 = complexFloat(airIOR, 0.0);
    complexFloat n2 = complexFloat(material.N, material.K);

    // Inside the medium the roles of the two media are swapped
    if(!wi_outside) {
        n1 = complexFloat(material.N, material.K);
        n2 = complexFloat(airIOR, 0.0);
    }

    if(wi_outside == wo_outside) {
        // Reflection: classic half vector
        vec3 halfwayVector = normalize(wi + wo);

        float VdotH;
        float distribution;

        if(wi_outside) {
            VdotH        = dot(wi, halfwayVector);
            distribution = D_wi(wi, halfwayVector, material.alpha);
        } else {
            // Mirrored configuration: evaluate the distribution on the flipped directions
            VdotH        = dot(-wi, -halfwayVector);
            distribution = D_wi(-wi, -halfwayVector, material.alpha);
        }
        
        float fresnel = fresnelComplex_R(VdotH, n1, n2);
        
        return fresnel * distribution / (4.0 * VdotH);
    } else {
        // Transmission: the directions are swapped and negated before evaluation (wi <- -wo, wo <- -wi),
        // and eta = n1 / n2 is the incident-side over transmitted-side ratio.
        // NOTE(review): this differs in form from the reference code of Heitz et al. (2016), which uses
        // eta = n_t / n_i on the unswapped directions and rejects half vectors with dot(wh, wi) < 0.
        // Confirm the two are equivalent.
        vec3 tmp = wi;
        wi = -wo;
        wo = -tmp;

        float eta = n1.r / n2.r;

        // Generalized half vector for refraction
        vec3  halfwayVector = -normalize(wi + wo * eta);

        float VdotH;
        float LdotH;
        float distribution;

        if(wi_outside) {
            // Orient the half vector into the upper hemisphere
            halfwayVector *= sign(halfwayVector.z);

            VdotH = dot(wi, halfwayVector);
            LdotH = dot(wo, halfwayVector);

            distribution = D_wi(wi, halfwayVector, material.alpha);
        } else {
            // Orient the half vector into the lower hemisphere and evaluate on the mirrored directions
            halfwayVector *= -sign(halfwayVector.z);

            VdotH = dot(-wi, -halfwayVector);
            LdotH = dot(-wo, -halfwayVector);

            distribution = D_wi(-wi, -halfwayVector, material.alpha);
        }

        float numerator   = eta * eta * fresnelComplex_T(VdotH, n1, n2) * distribution;
        float denominator = pow2(VdotH + eta * LdotH);

        return material.albedo * (numerator / denominator) * max0(-LdotH);
    }
}

// Stochastic evaluation of the multiple-scattering translucent BSDF for wo.
// wi is the initial travel direction of the walk; wi_outside / wo_outside give the side of each direction.
// Each scattering event adds a next-event estimate phase * G1(wo, hr) (on the mirrored height field when
// wo is inside), weighted by the path throughput. Unlike the opaque evaluator, the sum is not divided by wo.z.
// NOTE(review): confirm callers expect the cosine-weighted value here.
float evaluateMicrosurfaceTranslucent(vec3 wi, vec3 wo, bool wi_outside, bool wo_outside, Material material) {
    float hr = 1.0 + invC1(0.999);

    if(!wi_outside) hr = -hr;

    float throughput = 1.0;
    float estimate   = 0.0;

    int scatteringOrder = 0;
    while(scatteringOrder <= MAX_BSDF_SCATTERING_ORDERS) {

        #if RUSSIAN_ROULETTE == 1
            if(throughput < randF()) { 
                throughput = 0.0;
                break;
            }
            throughput /= saturate(throughput);
        #endif

        float U = randF();
        hr = wi_outside ? sampleHeight(wi, hr, U, material.alpha) : -sampleHeight(-wi, -hr, U, material.alpha);

        if(hr == FLT_MAX || hr == -FLT_MAX) break;

        wi = -wi;

        float phase     = evaluateMicrosurfaceTranslucentPhase(wi, wo, wi_outside, wo_outside, material);
        float shadowing = wo_outside ? G1(wo, hr, material.alpha) : G1(-wo, -hr, material.alpha);

        // Only finite contributions are accumulated, so a degenerate event (e.g. a zero refraction
        // denominator) cannot turn the whole estimate into NaN/inf
        float contribution = throughput * phase * shadowing;
        if(!isnan(contribution) && !isinf(contribution)) estimate += contribution;

        sampleMicrosurfaceTranslucentPhase(throughput, wi, wi_outside, material);

        // A path with an undefined weight cannot contribute anymore
        if(isnan(throughput)) break;

        scatteringOrder++;
    }
    return estimate;
}
