#pragma kernel DilateCell

#pragma only_renderers d3d11 playstation xboxone xboxseries vulkan metal switch webgpu

#include "Packages/com.unity.render-pipelines.core/ShaderLibrary/Common.hlsl"
#include "Packages/com.unity.render-pipelines.core/ShaderLibrary/CommonLighting.hlsl"
#include "Packages/com.unity.render-pipelines.core/ShaderLibrary/Color.hlsl"
#include "Packages/com.unity.render-pipelines.core/ShaderLibrary/EntityLighting.hlsl"

#include "Packages/com.unity.render-pipelines.core/Editor/Lighting/ProbeVolume/ProbeGIBaking.Dilate.cs.hlsl"
#include "Packages/com.unity.render-pipelines.core/Runtime/Lighting/ProbeVolume/ProbeVolume.hlsl"

// Constant buffer holding the dilation settings. The values are packed into
// float4s; the #defines below give the individual components readable names.
CBUFFER_START(APVDilation)
float4 _DilationParameters;
float4 _DilationParameters2; // Not referenced anywhere in this section of the file.
CBUFFER_END

// x: number of probes to process; DilateCell threads whose id.x is not below it do nothing.
#define _ProbeCount (uint)_DilationParameters.x
// y: validity threshold. NOTE(review): not referenced in this section of the file.
#define _ValidityThreshold _DilationParameters.y
// z: maximum distance from the probe at which a neighbouring sample is accepted.
#define _SearchRadius _DilationParameters.z
// w: when > 0, samples are weighted by 1/d^2 instead of 1/d.
#define _SquaredDistanceWeight (_DilationParameters.w > 0)

// Per-probe flag; only probes whose entry is > 0 are dilated.
StructuredBuffer<int> _NeedDilating;
// Per-probe positions; DilateCell subtracts _APVWorldOffset before using them.
StructuredBuffer<float3> _ProbePositionsBuffer;

// Result buffer, written at the probe's own index and only for probes that were dilated.
RWStructuredBuffer<DilatedProbe> _OutputProbes;


// Accumulates one pool sample into the running weighted sums of a probe that is
// being dilated. Sampling is copied from SampleAPV().
//
//   apvRes   - APV resources the textures/buffers are fetched from.
//   uvw      - pool coordinates of the sample to add.
//   probe    - running weighted sums; the caller divides by the weight totals afterwards.
//   weight   - weight of this sample.
//   shWeight - running weight total for the SH data.
//   soWeight - running weight total for the sky occlusion data.
//   poWeight - running weight total for the probe occlusion data.
//
// Each kind of data has its own validity check and its own weight total, so a
// sample may contribute to some of them and not to others.
void AddProbeSample(APVResources apvRes, float3 uvw, inout DilatedProbe probe, float weight, inout float shWeight, inout float soWeight, inout float poWeight)
{
    // SH data
    float4 L0_L1Rx = SAMPLE_TEXTURE3D_LOD(apvRes.L0_L1Rx, s_point_clamp_sampler, uvw, 0).rgba;
    float4 L1G_L1Ry = SAMPLE_TEXTURE3D_LOD(apvRes.L1G_L1Ry, s_point_clamp_sampler, uvw, 0).rgba;
    float4 L1B_L1Rz = SAMPLE_TEXTURE3D_LOD(apvRes.L1B_L1Rz, s_point_clamp_sampler, uvw, 0).rgba;

    float4 l2_R = SAMPLE_TEXTURE3D_LOD(apvRes.L2_0, s_point_clamp_sampler, uvw, 0).rgba;
    float4 l2_G = SAMPLE_TEXTURE3D_LOD(apvRes.L2_1, s_point_clamp_sampler, uvw, 0).rgba;
    float4 l2_B = SAMPLE_TEXTURE3D_LOD(apvRes.L2_2, s_point_clamp_sampler, uvw, 0).rgba;
    float4 l2_C = SAMPLE_TEXTURE3D_LOD(apvRes.L2_3, s_point_clamp_sampler, uvw, 0).rgba;

    // Skip the SH data when all three L0 channels are below the threshold.
    if (!all(L0_L1Rx.xyz < 0.0001f))
    {
        probe.L0 += L0_L1Rx.xyz * weight;
        probe.L1_0 += float3(L0_L1Rx.w, L1G_L1Ry.x, L1B_L1Rz.x) * weight;
        probe.L1_1 += float3(L1G_L1Ry.w, L1G_L1Ry.y, L1B_L1Rz.y) * weight;
        probe.L1_2 += float3(L1B_L1Rz.w, L1G_L1Ry.z, L1B_L1Rz.z) * weight;

        probe.L2_0 += float3(l2_R.x, l2_G.x, l2_B.x) * weight;
        probe.L2_1 += float3(l2_R.y, l2_G.y, l2_B.y) * weight;
        probe.L2_2 += float3(l2_R.z, l2_G.z, l2_B.z) * weight;
        probe.L2_3 += float3(l2_R.w, l2_G.w, l2_B.w) * weight;
        probe.L2_4 += float3(l2_C.x, l2_C.y, l2_C.z) * weight;

        shWeight += weight;
    }

    // Probe occlusion data
    float4 probeOcclusion = SAMPLE_TEXTURE3D_LOD(apvRes.ProbeOcclusion, s_linear_clamp_sampler, uvw, 0).rgba;
    if (!all(probeOcclusion < 0.0001f))
    {
        // Accumulate (not assign): the caller divides the sum by poWeight, so
        // every accepted sample has to be part of it.
        probe.ProbeOcclusion += probeOcclusion * weight;
        poWeight += weight;
    }

    // Sky occlusion data
    if (_APVSkyOcclusionWeight > 0)
    {
        // Keep the raw value here: the validity test must not depend on the
        // weight, and the weight must be applied exactly once so that dividing
        // by soWeight yields a proper weighted average.
        float4 SO_L0L1 = SAMPLE_TEXTURE3D_LOD(apvRes.SkyOcclusionL0L1, s_linear_clamp_sampler, uvw, 0).rgba;

        if (SO_L0L1.x >= 0.0001f)
        {
            probe.SO_L0L1 += SO_L0L1 * weight;

            if (_APVSkyDirectionWeight > 0)
            {
                int3 texCoord = uvw * _APVPoolDim - 0.5f; // No interpolation for sky shading indices
                uint index = LOAD_TEXTURE3D(apvRes.SkyShadingDirectionIndices, texCoord).x * 255.0;

                // Index 255 contributes no direction. Every other sample adds its
                // weighted direction; the caller normalizes the sum.
                if (index != 255)
                    probe.SO_Direction += apvRes.SkyPrecomputedDirections[index].rgb * weight;
            }

            soWeight += weight;
        }
    }
}

// Multiplies every SH coefficient of the probe (L0, L1_0..L1_2, L2_0..L2_4) by
// weight. The sky occlusion, sky direction and probe occlusion fields are not
// touched.
void ScaleProbe(inout DilatedProbe probe, float weight)
{
    // L0
    probe.L0 = probe.L0 * weight;

    // L1
    probe.L1_0 = probe.L1_0 * weight;
    probe.L1_1 = probe.L1_1 * weight;
    probe.L1_2 = probe.L1_2 * weight;

    // L2
    probe.L2_0 = probe.L2_0 * weight;
    probe.L2_1 = probe.L2_1 * weight;
    probe.L2_2 = probe.L2_2 * weight;
    probe.L2_3 = probe.L2_3 * weight;
    probe.L2_4 = probe.L2_4 * weight;
}

// Kernel entry point, one thread per probe.
//
// For each probe flagged in _NeedDilating, walks outwards from the probe
// position along the 26 directions of the surrounding 3x3x3 neighbourhood,
// gathers every pool sample found within _SearchRadius, and writes their
// inverse-distance (or inverse-squared-distance) weighted average to
// _OutputProbes. Probes that are not flagged are left unwritten.
[numthreads(64, 1, 1)]
void DilateCell(uint3 id : SV_DispatchThreadID)
{
    uint probeIdx = id.x;

    // Threads past the last probe have nothing to do. This is tested on its own,
    // before _NeedDilating is indexed, so the buffer is never read out of range.
    if (probeIdx >= _ProbeCount)
        return;

    if (_NeedDilating[probeIdx] <= 0)
        return;

    APVResources apvRes = FillAPVResources();

    float3 centralPosition = _ProbePositionsBuffer[probeIdx] - _APVWorldOffset;
    DilatedProbe probe = (DilatedProbe)0;
    float shWeight = 0, soWeight = 0, poWeight = 0;

    // Only the success flag of this lookup is used; the out values exist to
    // satisfy the signature.
    float3 uvw;
    uint subdiv;
    float3 biasedPosWS;
    if (TryToGetPoolUVWAndSubdiv(apvRes, centralPosition, 0, 0, uvw, subdiv, biasedPosWS))
    {
        float stepSize = _APVMinBrickSize / 3.0f;

        // Inflate search radius a bit.
        float radius = 1.5f * _SearchRadius;

        int iterationCount = ceil(radius / stepSize);

        for (int x = -1; x <= 1; ++x)
        {
            for (int y = -1; y <= 1; ++y)
            {
                for (int z = -1; z <= 1; ++z)
                {
                    // The centre offset has no direction: normalizing it divides
                    // by a zero length and would feed a NaN position to the pool
                    // lookup. It never produced an accepted sample, so skip it.
                    if (x == 0 && y == 0 && z == 0)
                        continue;

                    float3 dir = normalize(float3(x, y, z));

                    // Step 0 is the probe itself (distance 0), which is always
                    // rejected below, so marching starts at the first real step.
                    for (int s = 1; s < iterationCount; ++s)
                    {
                        float3 samplePos = centralPosition + dir * (stepSize * s);
                        float3 sampleUVW;
                        if (TryToGetPoolUVW(apvRes, samplePos, 0, 0, sampleUVW))
                        {
                            float d = distance(samplePos, centralPosition);
                            if (d > 0 && d <= _SearchRadius)
                            {
                                if (_SquaredDistanceWeight)
                                {
                                    d *= d;
                                }
                                AddProbeSample(apvRes, sampleUVW, probe, rcp(d), shWeight, soWeight, poWeight);
                            }
                        }
                    }
                }
            }
        }
    }

    // Turn the weighted sums into weighted averages. Each kind of data is
    // normalized by its own total, and left at zero when nothing was gathered.
    if (shWeight > 0)
        ScaleProbe(probe, rcp(shWeight));

    if (soWeight > 0)
        probe.SO_L0L1 *= rcp(soWeight);
    probe.SO_Direction = SafeNormalize(probe.SO_Direction);

    if (poWeight > 0)
    {
        probe.ProbeOcclusion *= rcp(poWeight);
        probe.ProbeOcclusion = saturate(probe.ProbeOcclusion);
    }

    _OutputProbes[probeIdx] = probe;
}