#pragma kernel Clear
#pragma kernel ClearBuffer
#pragma kernel FillUVMap
#pragma kernel JumpFlooding
#pragma kernel FinalPass
#pragma kernel VoxelizeProbeVolumeData
#pragma kernel Subdivide

#pragma only_renderers d3d11 playstation xboxone xboxseries vulkan metal switch

// #pragma enable_d3d11_debug_symbols

#include "Packages/com.unity.render-pipelines.core/ShaderLibrary/Color.hlsl"
#include "Packages/com.unity.render-pipelines.core/Editor/Lighting/ProbeVolume/ProbePlacement.cs.hlsl"

// Source 3D texture read by FillUVMap, JumpFlooding, FinalPass and Subdivide.
Texture3D _Input;
// Per-brick data read by Subdivide; _SubdivisionLevel is passed as the mip index of the Load.
Texture3D _ProbeVolumeData;
// Destination 3D texture written by Clear, FillUVMap, JumpFlooding, FinalPass and
// VoxelizeProbeVolumeData. JumpFlooding also reads neighbors from it.
RWTexture3D<float4> _Output;
// NOTE(review): not referenced by any kernel visible in this file.
RWTexture3D<float4> _FinalOutput;

// Receives one float3 per brick selected by Subdivide.
AppendStructuredBuffer<float3> _Bricks;
// Buffer zeroed element by element in ClearBuffer.
RWStructuredBuffer<float3> _BricksToClear;

// Probe volume boxes tested against each brick in VoxelizeProbeVolumeData.
StructuredBuffer<GPUProbeVolumeOBB> _ProbeVolumes;

// xyz: bounds used by Clear. x alone is used as the divisor / wrap size on all
// three axes in FillUVMap, JumpFlooding and FinalPass.
float4 _Size;
// xyz: thread bounds of VoxelizeProbeVolumeData and Subdivide; also normalizes brick ids in Subdivide.
float4 _MaxBrickSize;
// xyz: scale applied to the normalized brick id in Subdivide.
float4 _VolumeSizeInBricks;
// xyz: offset added to the scaled brick id in Subdivide.
float4 _VolumeOffsetInBricks;
// xyz: offset added to the brick corner in VoxelizeProbeVolumeData.
float4 _VolumeWorldOffset;
// xyz: scales a normalized position into an _Input texel index in Subdivide.
float4 _SDFSize;
// Multiplier applied to the neighbor offsets in JumpFlooding (the jump step).
float _Offset;
// Edge length of the axis-aligned brick box built in VoxelizeProbeVolumeData.
float _BrickSize;
// Level processed by Subdivide: mip index into _ProbeVolumeData and value compared against the limits.
float _SubdivisionLevel;
// Subtracted from in Subdivide to convert the stored controller levels before comparison.
float _MaxSubdivisionLevel;
// Number of entries of _ProbeVolumes iterated in VoxelizeProbeVolumeData.
float _ProbeVolumeCount;
// Value written to every texel by Clear.
float _ClearValue;

// Fills every texel of _Output that lies inside _Size.xyz with _ClearValue
// (replicated to all four channels).
[numthreads(8,8,1)]
void Clear(uint3 id : SV_DispatchThreadID)
{
    const uint3 extent = uint3(_Size.xyz);

    // Threads past the texture extent on any axis do nothing.
    if (all(id < extent))
    {
        _Output[id] = float4(_ClearValue, _ClearValue, _ClearValue, _ClearValue);
    }
}

// Zeroes one element of _BricksToClear per thread.
// NOTE(review): there is no bounds check here; presumably the dispatch size is
// chosen so that id.x never exceeds the buffer length - confirm against the caller.
[numthreads(64,1,1)]
void ClearBuffer(uint3 id : SV_DispatchThreadID)
{
    const uint element = id.x;
    _BricksToClear[element] = float3(0, 0, 0);
}

// Seeds the jump flooding input: texels whose _Input alpha is above 0.5 store
// their own normalized coordinate (id / _Size.x) with w = 1; all others store 0,
// where w = 0 marks the UV as invalid.
[numthreads(8,8,1)]
void FillUVMap(uint3 id : SV_DispatchThreadID)
{
    const bool isSeed = _Input[id].a > 0.5;

    _Output[id] = isSeed ? float4(id / _Size.x, 1) : float4(0, 0, 0, 0);
}

// Offsets of the 3x3x3 neighborhood (including the center, index 13) sampled by
// JumpFlooding. Entries are grouped by z slice (-1, 0, +1); within a slice x
// varies slowest and y fastest. JumpFlooding only replaces its current nearest on
// a strictly smaller distance, so the order decides which sample wins a tie.
static float3 offset3D[27] =
{
    // z = -1
    float3(-1, -1, -1),
    float3(-1,  0, -1),
    float3(-1,  1, -1),
    float3( 0, -1, -1),
    float3( 0,  0, -1),
    float3( 0,  1, -1),
    float3( 1, -1, -1),
    float3( 1,  0, -1),
    float3( 1,  1, -1),

    // z = 0
    float3(-1, -1,  0),
    float3(-1,  0,  0),
    float3(-1,  1,  0),
    float3( 0, -1,  0),
    float3( 0,  0,  0),
    float3( 0,  1,  0),
    float3( 1, -1,  0),
    float3( 1,  0,  0),
    float3( 1,  1,  0),

    // z = +1
    float3(-1, -1,  1),
    float3(-1,  0,  1),
    float3(-1,  1,  1),
    float3( 0, -1,  1),
    float3( 0,  0,  1),
    float3( 0,  1,  1),
    float3( 1, -1,  1),
    float3( 1,  0,  1),
    float3( 1,  1,  1),
};

// Distance metric shared by the jump flooding kernels: the Euclidean length of pos.
float Distance(float3 pos)
{
    const float euclidean = length(pos);
    return euclidean;
}

// One jump flooding step. Starting from the seed stored in _Input[id], examines
// the 27 texels of _Output located at id + offset3D[i] * _Offset (wrapped around
// _Size.x on every axis) and keeps whichever valid seed coordinate is closest to
// this texel's own normalized coordinate. The winner is written to _Output[id].
// A sample is valid when its w component is >= 0.5.
[numthreads(8,8,1)]
void JumpFlooding(uint3 id : SV_DispatchThreadID)
{
    float4 nearest = _Input[id];

    // No valid seed yet: use a sentinel position so that any valid neighbor
    // compares as closer. w stays 0 so the result is still flagged invalid if
    // no neighbor qualifies.
    if (nearest.w < 0.5)
        nearest = float4(-10, -10, -10, 0);

    // Normalized coordinate of this texel; identical for every neighbor, so it
    // is computed once outside the loop.
    float3 uv = id / _Size.x;

    for (int i = 0; i < 27; i++)
    {
        float3 o = id + offset3D[i] * _Offset;

        // Wrap the neighbor coordinate into [0, _Size.x): the float modulo keeps
        // the sign of the dividend, so negative results are shifted back up.
        o %= _Size.x;
        o.x += (o.x < 0) ? _Size.x : 0;
        o.y += (o.y < 0) ? _Size.x : 0;
        o.z += (o.z < 0) ? _Size.x : 0;
        float4 n1 = _Output[o];

        // Discard invalid samples
        if (n1.w < 0.5)
            continue;

        // Strict comparison: on a tie the earlier candidate is kept.
        if (Distance(uv - n1.xyz) < Distance(uv - nearest.xyz))
            nearest = n1;
    }

    _Output[id] = nearest;
}

// Converts the nearest-seed coordinate stored in _Input[id].xyz into a distance:
// the length of the vector between that coordinate and this texel's own
// normalized coordinate (id / _Size.x). The distance is replicated into rgb and
// alpha is set to 1.
[numthreads(8, 8, 1)]
void FinalPass(uint3 id : SV_DispatchThreadID)
{
    float3 ownUV = id / _Size.x;
    float3 seedUV = _Input[id].xyz;
    float seedDistance = Distance(seedUV - ownUV);

    _Output[id] = float4(seedDistance, seedDistance, seedDistance, 1);
}

// Projects the eight corners of obb (corner + any combination of its X, Y and Z
// edge vectors) onto axis and returns the covered interval as (min, max).
float2 ProjectOBB(GPUProbeVolumeOBB obb, float3 axis)
{
    float lowest = dot(axis, obb.corner);
    float highest = lowest;

    // Bits 2, 1 and 0 of the corner index select whether the X, Y and Z edge
    // vectors are added to the base corner.
    for (int cornerIndex = 0; cornerIndex < 8; cornerIndex++)
    {
        int useX = (cornerIndex >> 2) & 1;
        int useY = (cornerIndex >> 1) & 1;
        int useZ = cornerIndex & 1;

        float3 vert = obb.corner + obb.X * useX + obb.Y * useY + obb.Z * useZ;
        float proj = dot(axis, vert);

        if (proj < lowest)
        {
            lowest = proj;
        }
        else if (proj > highest)
        {
            highest = proj;
        }
    }

    return float2(lowest, highest);
}

// Despite its name, this is a box-vs-box overlap test based on separating axes:
// both boxes are projected onto each candidate axis and the function returns
// false as soon as the two projected intervals do not overlap on one of them.
//
// The candidate axes are the normalized edge directions of both boxes. All six
// are tested; previously the loop stopped after the three axes of obbA, so a box
// pair separated only along one of obbB's axes was wrongly reported as
// overlapping. Edge-edge cross product axes are not tested, so the result can
// still be a conservative "true" for some disjoint pairs, but never a wrong
// "false" for overlapping ones.
bool PointInsideOBB(GPUProbeVolumeOBB obbA, GPUProbeVolumeOBB obbB)
{
    float3 axes[6] =
    {
        normalize(obbA.X),
        normalize(obbA.Y),
        normalize(obbA.Z),
        normalize(obbB.X),
        normalize(obbB.Y),
        normalize(obbB.Z),
    };

    for (int i = 0; i < 6; i++)
    {
        float2 obbAProjection = ProjectOBB(obbA, axes[i]);
        float2 obbBProjection = ProjectOBB(obbB, axes[i]);

        // x = interval min, y = interval max: disjoint intervals mean this axis separates the boxes.
        if (obbAProjection.y < obbBProjection.x || obbBProjection.y < obbAProjection.x)
            return false;
    }

    return true;
}

// For each brick inside _MaxBrickSize.xyz, builds the brick's axis-aligned box
// and gathers, over every probe volume that PointInsideOBB reports as touching
// it, the component-wise maximum of:
//   x = minControllerSubdivLevel, y = maxControllerSubdivLevel,
//   z = geometryDistanceOffset,   w = maxSubdivLevelInsideVolume.
// A brick touched by no volume stores (0, 0, -1e20, 0).
[numthreads(8, 8, 1)]
void VoxelizeProbeVolumeData(uint3 id : SV_DispatchThreadID)
{
    if (any(id >= uint3(_MaxBrickSize.xyz)))
        return;

    // Axis-aligned box covering this brick
    GPUProbeVolumeOBB brick;
    brick.corner = id * _BrickSize + _VolumeWorldOffset.xyz;
    brick.X = float3(_BrickSize, 0, 0);
    brick.Y = float3(0, _BrickSize, 0);
    brick.Z = float3(0, 0, _BrickSize);

    // Running maxima; the starting values are what an untouched brick outputs.
    float4 gathered = float4(0, 0, -1e20, 0);

    for (int volumeIndex = 0; volumeIndex < _ProbeVolumeCount; volumeIndex++)
    {
        GPUProbeVolumeOBB volume = _ProbeVolumes[volumeIndex];

        if (!PointInsideOBB(volume, brick))
            continue;

        float4 volumeSettings = float4(
            volume.minControllerSubdivLevel,
            volume.maxControllerSubdivLevel,
            volume.geometryDistanceOffset,
            volume.maxSubdivLevelInsideVolume);

        gathered = max(gathered, volumeSettings);
    }

    _Output[id] = gathered;
}

// Decides whether the brick at id is emitted for the current _SubdivisionLevel
// and, if so, appends its position to _Bricks.
//
// The per-brick limits come from _ProbeVolumeData at mip _SubdivisionLevel
// (x, y = controller subdivision levels, z = distance offset, w = max level
// inside the volume). A brick whose level is outside the allowed range is
// dropped. Otherwise it is kept when the distance sampled from _Input is within
// the detection distance, or when the level is above the converted x limit.
[numthreads(8, 8, 1)]
void Subdivide(uint3 id : SV_DispatchThreadID)
{
    if (any(id >= uint3(_MaxBrickSize.xyz)))
        return;

    // Local subdivision limits for this brick
    float4 limits = _ProbeVolumeData.Load(uint4(id, _SubdivisionLevel));

    // Reject levels outside the range allowed by the probe volumes
    int lowestAllowedLevel = _MaxSubdivisionLevel - limits.y;
    int highestLevelInsideVolume = limits.w;
    bool levelTooLow = _SubdivisionLevel < lowestAllowedLevel;
    bool levelTooHigh = _SubdivisionLevel > highestLevelInsideVolume;
    if (levelTooLow || levelTooHigh)
        return;

    // Sample the distance at the center of the voxel
    float3 center01 = (float3(id) + 0.5) / _MaxBrickSize.xyz;
    uint3 sdfCoord = floor(center01 * _SDFSize.xyz);
    float sdfDistance = _Input[sdfCoord].r;

    // Half the diagonal of a cube whose edge is 1 / _MaxBrickSize.x
    float detectionDistance = rcp(_MaxBrickSize.x);
    detectionDistance *= detectionDistance;
    detectionDistance = sqrt(detectionDistance + detectionDistance + detectionDistance) / 2.0;

    // Extra margin from the probe volume, expressed in voxel edges
    detectionDistance += limits.z * rcp(_MaxBrickSize.x);

    int forcedAboveLevel = _MaxSubdivisionLevel - limits.x - 1;

    bool closeToGeometry = sdfDistance <= detectionDistance;
    bool forcedByController = _SubdivisionLevel > forcedAboveLevel;

    if (closeToGeometry || forcedByController)
    {
        // Map the normalized brick id into the volume's brick range
        float3 brickPosition = _VolumeOffsetInBricks.xyz + (id / _MaxBrickSize.xyz) * _VolumeSizeInBricks.xyz;
        _Bricks.Append(brickPosition);
    }
}