/***********************************************/
/*           Copyright (c) 2025 Belmu          */
/*             All Rights Reserved             */
/***********************************************/

/*
    [References]:
        Heitz et al. (2017). Microfacet-based Normal Mapping for Robust Monte Carlo Path Tracing. https://jo.dreggn.org/home/2017_normalmap.pdf
*/

//////////////////////////////////////////////////////////
/*-------------------- MICROSURFACE --------------------*/
//////////////////////////////////////////////////////////

/*
    Normal wt of the tangent facet T associated with the perturbed facet normal wp
    (Heitz et al. 2017). wt is horizontal (wt.z == 0) and opposite to the projection of wp
    onto the XY plane, i.e. normalize(wg * dot(wg, wp) / dot(wg, wg) - wp) with wg = (0, 0, 1).

    When wp is (numerically) aligned with wg, that projection vanishes and normalize() would
    return NaN. In that case T has zero projected area (sqrt(1 - wp.z^2) == 0 in projected_t),
    so any horizontal unit vector is a valid, finite choice.
*/
vec3 normal_t(vec3 wp) {
    vec2 horizontal = -wp.xy;

    // Flat perturbed normal: avoid normalize(vec2(0)) == NaN.
    if(dot(horizontal, horizontal) < 1e-12) {
        return vec3(1.0, 0.0, 0.0);
    }
    return vec3(normalize(horizontal), 0.0);
}

/*
    Projected area ap of the facet with normal wm onto direction wi, relative to the
    geometric surface: <wi, wm> / wm.z, where <a, b> = max(0, dot(a, b)) is the clamped
    dot product used by Heitz et al. (2017).

    The clamp matters: a facet facing away from wi has no projected area. Without it,
    ap + at reduces to wi.z for every direction, so lambda_p is no longer in [0, 1] and
    G1 degenerates to chi+.
    The denominator keeps the sign of wm.z, so callers using flipped normals (-wp, -wt)
    still get consistently-signed ratios. wm.z must be non-zero.
*/
float projected_p(vec3 wi, vec3 wm) {
    return max(dot(wi, wm), 0.0) / wm.z;
}

/*
    Projected area at of the tangent facet wt onto direction wi, for the microsurface whose
    perturbed facet is wm: <wi, wt> * sqrt(1 - wm.z^2) / wm.z, with <a, b> = max(0, dot(a, b))
    (Heitz et al. 2017). wm.z must be non-zero.
*/
float projected_t(vec3 wi, vec3 wm, vec3 wt) {
    // max() guards against a slightly negative radicand when |wm.z| rounds to >= 1.
    float sinThetaM = sqrt(max(0.0, 1.0 - wm.z * wm.z));
    return max(dot(wi, wt), 0.0) * sinThetaM / wm.z;
}

/*
    lambda_p = ap / (ap + at): the share of the microsurface's projected area seen from wi
    that belongs to the facet wm. Callers compare it against randF() to pick the first facet hit.

    If the total projected area is exactly zero, the ratio is undefined (Inf/NaN). The function
    then returns 1, which selects the perturbed facet, so lambda_p + lambda_t == 1 still holds.
*/
float lambda_p(vec3 wi, vec3 wm, vec3 wt) {
    float ap    = projected_p(wi, wm);
    float total = ap + projected_t(wi, wm, wt);
    return total != 0.0 ? ap / total : 1.0;
}

/*
    lambda_t = at / (ap + at): the share of the microsurface's projected area seen from wi
    that belongs to the tangent facet wt.

    If the total projected area is exactly zero, the function returns 0 instead of Inf/NaN.
    This is the complement of lambda_p's fallback of 1.
*/
float lambda_t(vec3 wi, vec3 wm, vec3 wt) {
    float at    = projected_t(wi, wm, wt);
    float total = projected_p(wi, wm) + at;
    return total != 0.0 ? at / total : 0.0;
}

/*
    Masking/shadowing term for an opaque microsurface:
        G1 = chi+(dot(wi, wm)) * min(1, wi.z / (ap(wi) + at(wi)))
    Directions below the geometric surface (wi.z < 0) can never escape and return 0.

    The projected areas are computed from (wm, wt). The result therefore only matches
    Heitz et al.'s G1 when wm is the perturbed facet wp and wt == normal_t(wp). For wm = wt
    the denominator divides by wt.z == 0.
    A zero denominator returns 0 instead of propagating 0/0 == NaN.
*/
float G1_opaque(vec3 wi, vec3 wm, vec3 wt) {
    if(wi.z < 0.0 || dot(wi, wm) < 0.0) {
        return 0.0;
    }
    float projectedArea = projected_p(wi, wm) + projected_t(wi, wm, wt);
    return projectedArea != 0.0 ? min(1.0, wi.z / projectedArea) : 0.0;
}

/*
    Masking/shadowing term for a translucent microsurface:
        G1 = chi+(dot(wi, wm)) * min(1, wi.z / (ap(wi) + at(wi)))
    Unlike G1_opaque, directions with wi.z < 0 are allowed. With flipped facets (wm.z < 0)
    both wi.z and the projected areas are negative, so the ratio stays positive.

    As in G1_opaque, the projected areas come from (wm, wt). This is only meaningful when
    wm is the (possibly flipped) perturbed facet, not the tangent facet (wt.z == 0).
    A zero denominator returns 0 instead of propagating 0/0 == NaN.
*/
float G1_translucent(vec3 wi, vec3 wm, vec3 wt) {
    if(dot(wi, wm) < 0.0) {
        return 0.0;
    }
    float projectedArea = projected_p(wi, wm) + projected_t(wi, wm, wt);
    return projectedArea != 0.0 ? min(1.0, wi.z / projectedArea) : 0.0;
}

//////////////////////////////////////////////////////////
/*------------------------ BRDF ------------------------*/
//////////////////////////////////////////////////////////

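/*
    Naming used below (tangent space, Z-up):
    G (wg) is the geometric surface normal, (0, 0, 1)
    P (wp) is the facet oriented by the perturbed (normal-mapped) normal
    T (wt) is the tangent facet, perpendicular to the surface: its normal lies in the XY plane
*/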

/*
    Samples the normal-mapped opaque microsurface (P + T facets) with a random walk
    (Heitz et al. 2017).

    estimate   : forwarded to sampleMicrosurfaceOpaque at every bounce.
    throughput : path weight. It is updated by Russian roulette and by sampleMicrosurfaceOpaque.
                 It is set to 0 when the walk is killed by Russian roulette, or when it runs out of
                 bounces without escaping. In the latter case the last direction is still trapped
                 (possibly below the opaque surface) and must not be used.
    wr         : on entry, the incident ray direction (pointing towards the surface, hence the -wr
                 passed to lambda_p). On exit, the direction that escaped the microsurface.
*/
void sampleMicrosurfaceOpaqueMBNM(inout float estimate, inout float throughput, inout vec3 wr, Material material, float wavelength) {
    vec3 wp = material.tangentNormal;
    vec3 wt = normal_t(wp);

    mat3 tbnMatrixP = getTBNMatrix(wp);
    mat3 tbnMatrixT = getTBNMatrix(wt);

    // The first facet hit is picked proportionally to its projected area seen from -wr.
    bool intersectP = lambda_p(-wr, wp, wt) > randF();
    bool escaped    = false;

    // NOTE(review): `<=` allows MAX_MBMN_SCATTERING_ORDERS + 1 bounces; kept unchanged.
    int scatteringOrder = 0;
    while(scatteringOrder <= MAX_MBMN_SCATTERING_ORDERS) {
        scatteringOrder++;

        // Russian roulette; survivors are reweighted by 1 / saturate(throughput)
        // (presumably min(1, throughput) -- confirm saturate clamps to [0, 1]).
        if(throughput < randF()) {
            throughput = 0.0;
            break;
        }
        throughput /= saturate(throughput);

        vec3 wm         = intersectP ? wp : wt;
        mat3 tbnMatrixM = intersectP ? tbnMatrixP : tbnMatrixT;

        // Move into the facet's frame, scatter there, then go back to tangent space.
        wr = wr * tbnMatrixM;
        sampleMicrosurfaceOpaque(estimate, throughput, wr, material, wavelength);
        wr = tbnMatrixM * wr;

        // Escape probability G1(wr, wm) = chi+(wr.wm) * min(1, wr.z / (ap(wr) + at(wr))).
        // The facet only enters through chi+; the projected areas always describe the whole
        // microsurface (wp, wt). Computing them from wm = wt would divide by wt.z == 0 (NaN).
        float projectedArea     = projected_p(wr, wp) + projected_t(wr, wp, wt);
        float escapeProbability = (wr.z <= 0.0 || dot(wr, wm) < 0.0 || projectedArea <= 0.0)
                                ? 0.0
                                : min(1.0, wr.z / projectedArea);

        if(escapeProbability > randF()) {
            escaped = true;
            break;
        }

        // Blocked: with only two facets, the next hit is necessarily the other one.
        intersectP = !intersectP;
    }

    // A walk that never escaped would hand back a trapped direction; absorb it instead.
    if(!escaped) {
        throughput = 0.0;
    }
}

/*
    Stochastic estimate of the normal-mapped opaque microsurface BRDF for (wi, wo)
    (Heitz et al. 2017). A random walk is traced from wi. At every facet it reaches, the facet
    BRDF towards wo is accumulated, weighted by the walk throughput and by
    G1(wo, wm) * dot(wo, wm).

    wi : incident ray direction (pointing towards the surface, hence -wi in lambda_p).
    wo : outgoing direction.
    Returns the accumulated estimate.
*/
float evaluateMicrosurfaceOpaqueMBNM(vec3 wi, vec3 wo, Material material, float wavelength) {
    vec3 wp = material.tangentNormal;
    vec3 wt = normal_t(wp);

    mat3 tbnMatrixP = getTBNMatrix(wp);
    mat3 tbnMatrixT = getTBNMatrix(wt);

    bool intersectP = lambda_p(-wi, wp, wt) > randF();

    // wo never changes, so its projected area on the microsurface (wp, wt) is loop-invariant.
    float projectedAreaWo = projected_p(wo, wp) + projected_t(wo, wp, wt);

    float throughput = 1.0;
    float estimate   = 0.0;

    // NOTE(review): `<=` allows MAX_MBMN_SCATTERING_ORDERS + 1 bounces; kept unchanged.
    int scatteringOrder = 0;
    while(scatteringOrder <= MAX_MBMN_SCATTERING_ORDERS) {
        scatteringOrder++;

        // Russian roulette: a terminated walk contributes nothing more.
        if(throughput < randF()) {
            break;
        }
        throughput /= saturate(throughput);

        vec3 wm         = intersectP ? wp : wt;
        mat3 tbnMatrixM = intersectP ? tbnMatrixP : tbnMatrixT;

        // G1(wo, wm): the facet only enters through chi+; the projected areas use (wp, wt),
        // because computing them from wm = wt would divide by wt.z == 0 (NaN).
        float maskingWo = (wo.z <= 0.0 || dot(wo, wm) < 0.0 || projectedAreaWo <= 0.0)
                        ? 0.0
                        : min(1.0, wo.z / projectedAreaWo);
        float masking_shadowing_wo = maskingWo * dot(wo, wm);

        estimate += throughput * evaluateMicrosurfaceOpaque(wi * tbnMatrixM, wo * tbnMatrixM, material, wavelength) * masking_shadowing_wo;

        // Continue the walk: scatter on the facet (in its frame) to get the next direction.
        wi = wi * tbnMatrixM;
        sampleMicrosurfaceOpaque(estimate, throughput, wi, material, wavelength);
        wi = tbnMatrixM * wi;

        // Escape test G1(wi, wm), same formulation as for wo.
        float projectedAreaWi = projected_p(wi, wp) + projected_t(wi, wp, wt);
        float maskingWi = (wi.z <= 0.0 || dot(wi, wm) < 0.0 || projectedAreaWi <= 0.0)
                        ? 0.0
                        : min(1.0, wi.z / projectedAreaWi);

        if(maskingWi > randF()) {
            break;
        }

        // Blocked: the next hit is the other facet.
        intersectP = !intersectP;
    }
    return estimate;
}

/*
    Samples the normal-mapped translucent microsurface (P + T facets) with a random walk
    (Heitz et al. 2017).

    throughput : path weight, forwarded to sampleMicrosurfaceTranslucent at every bounce.
                 It is set to 0 when Russian roulette terminates the walk.
    wr         : lambda_p is evaluated on -wr, so wr is presumably the incident ray direction
                 (pointing towards the surface). On exit it holds the last scattered direction.
    wi_outside : selects the side the microsurface is entered from. From inside, the facet
                 selection uses the flipped normals (-wp, -wt).

    NOTE(review): when wm == wt, G1_translucent(wr, wm, ...) evaluates projected_p(wr, wt),
    which divides by wt.z == 0 (wt comes from normal_t, whose z is 0), giving Inf/NaN.
    Per Heitz et al., the projected areas should always be computed from the (possibly
    flipped) perturbed facet (wp, wt), with wm only entering through chi+. Confirm and fix.
    NOTE(review): wi_outside is never updated after a refraction, and the escape test uses the
    unflipped wm even when wi_outside is false (only wtNew is flipped). Verify the intended
    side conventions.
    NOTE(review): the loop allows MAX_MBMN_SCATTERING_ORDERS + 1 bounces. A walk that exhausts
    it without escaping keeps its throughput and its last (possibly still trapped) wr.
*/
void sampleMicrosurfaceTranslucentMBNM(inout float throughput, inout vec3 wr, bool wi_outside, Material material) {
    // NOTE(review): wg is declared but not used in this function.
    const vec3 wg = vec3(0.0, 0.0, 1.0);
    vec3 wp = material.tangentNormal;
    vec3 wt = normal_t(wp);

    mat3 tbnMatrixP = getTBNMatrix(wp);
    mat3 tbnMatrixT = getTBNMatrix(wt);

    // First facet hit, chosen proportionally to projected area; facets are flipped when entering from inside.
    float probabilityP = wi_outside ? lambda_p(-wr, wp, wt) : lambda_p(-wr, -wp, -wt);
    bool  intersectP   = probabilityP > randF();

    vec3 wm = intersectP ? wp : wt;

    int scatteringOrder = 0;
    while(scatteringOrder <= MAX_MBMN_SCATTERING_ORDERS) {
        scatteringOrder++;

        // Russian roulette; survivors are reweighted by 1 / saturate(throughput).
        if(throughput < randF()) { 
            throughput = 0.0;
            break;
        }
        throughput /= saturate(throughput);

        mat3 tbnMatrixM = intersectP ? tbnMatrixP : tbnMatrixT;

        // Move into the current facet's frame, scatter there, then go back to tangent space.
        wr = wr * tbnMatrixM;
        sampleMicrosurfaceTranslucent(throughput, wr, wi_outside, material);
        wr = tbnMatrixM * wr;

        // Escape test: flip the tangent argument when the walk is inside.
        vec3  wtNew      = wi_outside ? normal_t(wm) : -normal_t(wm);
        float masking_wr = G1_translucent(wr, wm, wtNew);

        if(masking_wr > randF()) {
            break;
        } else {
            // Blocked: the next hit is the other facet.
            wm = intersectP ? wt : wp;
            intersectP = !intersectP;
        }
    }
}

/*
    Stochastic estimate of the normal-mapped translucent microsurface BSDF for (wi, wo)
    (Heitz et al. 2017). A random walk is traced from wi. At every facet reached, the facet BSDF
    towards wo is accumulated, weighted by the walk throughput and by
    G1_translucent(wo, wmNew, wtNew) * dot(wo, wmNew).

    wi, wo                 : lambda_p is evaluated on -wi, so wi is presumably the incident ray
                             direction (pointing towards the surface); wo is the outgoing direction.
    wi_outside, wo_outside : side flags. wi_outside flips the facets used for the facet selection
                             and for wo's masking; both flags are forwarded to
                             evaluateMicrosurfaceTranslucent.
    Returns the accumulated estimate.

    NOTE(review): when wm == wt, the G1_translucent calls evaluate projected_p(., +/-wt), which
    divides by wt.z == 0 (wt comes from normal_t, whose z is 0), giving Inf/NaN. The projected
    areas should come from the (possibly flipped) perturbed facet (wp, wt). Confirm and fix.
    NOTE(review): masking_wi mixes the unflipped wm with wtNew, which is derived from the flipped
    wmNew. wi_outside is also never updated after a refraction. Verify the side conventions.
    NOTE(review): the loop allows MAX_MBMN_SCATTERING_ORDERS + 1 bounces.
*/
float evaluateMicrosurfaceTranslucentMBNM(vec3 wi, vec3 wo, bool wi_outside, bool wo_outside, Material material) {
    // NOTE(review): wg is declared but not used in this function.
    const vec3 wg = vec3(0.0, 0.0, 1.0);
    vec3 wp = material.tangentNormal;
    vec3 wt = normal_t(wp);

    mat3 tbnMatrixP = getTBNMatrix(wp);
    mat3 tbnMatrixT = getTBNMatrix(wt);

    // First facet hit, chosen proportionally to projected area; facets are flipped when entering from inside.
    float probabilityP = wi_outside ? lambda_p(-wi, wp, wt) : lambda_p(-wi, -wp, -wt);
    bool  intersectP   = probabilityP > randF();

    vec3 wm = intersectP ? wp : wt;

    float throughput = 1.0;
    float estimate   = 0.0;

    int scatteringOrder = 0;
    while(scatteringOrder <= MAX_MBMN_SCATTERING_ORDERS) {
        scatteringOrder++;

        // Russian roulette; survivors are reweighted by 1 / saturate(throughput).
        if(throughput < randF()) { 
            throughput = 0.0;
            break;
        }
        throughput /= saturate(throughput);

        mat3 tbnMatrixM = intersectP ? tbnMatrixP : tbnMatrixT;

        // Facet and tangent oriented towards the side wi enters from.
        vec3 wmNew = wi_outside ? wm : -wm;
        vec3 wtNew = normal_t(wmNew);

        float masking_shadowing_wo = G1_translucent(wo, wmNew, wtNew) * dot(wo, wmNew);

        // wi is moved into the facet frame before the explicit evaluation below.
        wi = wi * tbnMatrixM;

        estimate += throughput * evaluateMicrosurfaceTranslucent(wi, wo * tbnMatrixM, wi_outside, wo_outside, material) * masking_shadowing_wo;

        // Continue the walk: scatter on the facet, then go back to tangent space.
        sampleMicrosurfaceTranslucent(throughput, wi, wi_outside, material);
        wi = tbnMatrixM * wi;

        float masking_wi = G1_translucent(wi, wm, wtNew);

        // Stop when wo lies behind the current facet, or when the walk escapes.
        if(dot(wm, wo) < 0.0 || masking_wi > randF()) {
            break;
        } else {
            // Blocked: the next hit is the other facet.
            wm = intersectP ? wt : wp;
            intersectP = !intersectP;
        }
    }
    return estimate;
}
