#pragma multi_compile GROUP_SIZE_1 GROUP_SIZE_2 GROUP_SIZE_4 GROUP_SIZE_8 GROUP_SIZE_16 GROUP_SIZE_32 GROUP_SIZE_64 GROUP_SIZE_128 GROUP_SIZE_256 GROUP_SIZE_512 GROUP_SIZE_1024

// Warning: Keep this kernel order in sync with WaveKernel enum in WaveEmulationTests.cs
#pragma kernel kAllEqual
#pragma kernel kAllTrue
#pragma kernel kAnyTrue
#pragma kernel kBallot
#pragma kernel kBitAnd
#pragma kernel kBitOr
#pragma kernel kBitXor
#pragma kernel kCountBits
#pragma kernel kMax
#pragma kernel kMin
#pragma kernel kProduct
#pragma kernel kSum
#pragma kernel kGetThreadCount
#pragma kernel kGetThreadIndex
#pragma kernel kIsFirstThread
#pragma kernel kPrefixCountBits
#pragma kernel kPrefixProduct
#pragma kernel kPrefixSum
#pragma kernel kReadThreadAtBroadcast
#pragma kernel kReadThreadAtShuffle
#pragma kernel kReadThreadFirst


// Translate the GROUP_SIZE_* keyword of the current variant into the numeric
// THREADING_BLOCK_SIZE used by the numthreads attribute of every kernel in this file.
// Required to define before using the threading library.
// There is no fallback branch: if none of the keywords is defined, THREADING_BLOCK_SIZE
// stays undefined.
#if defined(GROUP_SIZE_1)
    #define THREADING_BLOCK_SIZE 1
#elif defined(GROUP_SIZE_2)
    #define THREADING_BLOCK_SIZE 2
#elif defined(GROUP_SIZE_4)
    #define THREADING_BLOCK_SIZE 4
#elif defined(GROUP_SIZE_8)
    #define THREADING_BLOCK_SIZE 8
#elif defined(GROUP_SIZE_16)
    #define THREADING_BLOCK_SIZE 16
#elif defined(GROUP_SIZE_32)
    #define THREADING_BLOCK_SIZE 32
#elif defined(GROUP_SIZE_64)
    #define THREADING_BLOCK_SIZE 64
#elif defined(GROUP_SIZE_128)
    #define THREADING_BLOCK_SIZE 128
#elif defined(GROUP_SIZE_256)
    #define THREADING_BLOCK_SIZE 256
#elif defined(GROUP_SIZE_512)
    #define THREADING_BLOCK_SIZE 512
#elif defined(GROUP_SIZE_1024)
    #define THREADING_BLOCK_SIZE 1024
#endif

// Defined ahead of the Threading.hlsl include.
// NOTE(review): presumably selects the library's emulated code path instead of native
// wave intrinsics - confirm against Threading.hlsl.
#define THREADING_FORCE_WAVE_EMULATION

#include "Packages/com.unity.render-pipelines.core/ShaderLibrary/Common.hlsl"
#include "Packages/com.unity.render-pipelines.core/ShaderLibrary/Threading.hlsl"

// Short aliases for the threading library types used by the kernels below.
typedef Threading::Group Group;
typedef Threading::GroupBallot GroupBallot;

// Raw input buffer; only read through the DATA macro.
ByteAddressBuffer   _Input;
// Raw output buffer; every kernel stores one uint per thread at OUTPUT_ADDR.
RWByteAddressBuffer _Output;

// Macros
// Both macros expand to expressions that reference a variable named `group`, so they can
// only be used inside a kernel whose Group parameter has exactly that name.
// DATA: the uint loaded from _Input at byte offset (group.dispatchID.x * 4).
// OUTPUT_ADDR: the byte offset (group.groupIndex * 4).
// NOTE(review): DATA is indexed by dispatchID.x while OUTPUT_ADDR is indexed by groupIndex;
// presumably the test dispatches a single thread group so both indices agree - confirm
// against WaveEmulationTests.cs.
#define DATA ((uint)_Input.Load(group.dispatchID.x << 2))
#define OUTPUT_ADDR ((uint)(group.groupIndex << 2))

// AllEqual kernel: passes (input - groupIndex) to Group::AllEqual and stores the returned
// flag, converted to uint, in this thread's output slot.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kAllEqual(Group group)
{
    const uint delta = DATA - group.groupIndex;
    const uint allSame = (uint)group.AllEqual(delta);

    _Output.Store(OUTPUT_ADDR, allSame);
}

// AllTrue kernel: the per-thread predicate is "input equals groupIndex"; the value
// returned by Group::AllTrue is stored as a uint in this thread's output slot.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kAllTrue(Group group)
{
    const bool matchesIndex = (DATA == group.groupIndex);
    const uint allMatched = (uint)group.AllTrue(matchesIndex);

    _Output.Store(OUTPUT_ADDR, allMatched);
}

// AnyTrue kernel: the per-thread predicate is true when the input, masked with
// (thread count - 1), equals half the thread count. The value returned by Group::AnyTrue
// is stored as a uint in this thread's output slot.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kAnyTrue(Group group)
{
    const uint threadCount = group.GetThreadCount();
    const uint maskedInput = DATA & (threadCount - 1);
    const bool isMidpoint = (maskedInput == (threadCount / 2));
    const uint anyHit = (uint)group.AnyTrue(isMidpoint);

    _Output.Store(OUTPUT_ADDR, anyHit);
}

// Ballot kernel: builds a GroupBallot from the predicate "input equals groupIndex" and
// stores 1 when the ballot's bit count equals Group::GetThreadCount(), otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kBallot(Group group)
{
    const bool matchesIndex = (DATA == group.groupIndex);

    GroupBallot votes = group.Ballot(matchesIndex);
    uint setBits = votes.CountBits();

    const uint everyBitSet = (setBits == group.GetThreadCount());
    _Output.Store(OUTPUT_ADDR, everyBitSet);
}

// And kernel: passes the predicate "input equals groupIndex" to Group::And and stores the
// returned value unchanged in this thread's output slot.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kBitAnd(Group group)
{
    const bool matchesIndex = (DATA == group.groupIndex);
    const uint reduced = group.And(matchesIndex);

    _Output.Store(OUTPUT_ADDR, reduced);
}

// Or kernel: passes the predicate "input differs from groupIndex" to Group::Or and stores
// 1 when the returned value is zero, otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kBitOr(Group group)
{
    const bool mismatch = (DATA != group.groupIndex);
    const uint noneSet = (group.Or(mismatch) == 0);

    _Output.Store(OUTPUT_ADDR, noneSet);
}

// Xor kernel: passes the predicate "input differs from groupIndex" to Group::Xor and stores
// 1 when the returned value is zero, otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kBitXor(Group group)
{
    const bool mismatch = (DATA != group.groupIndex);
    const uint isZero = (group.Xor(mismatch) == 0);

    _Output.Store(OUTPUT_ADDR, isZero);
}

// CountBits kernel: passes the predicate "input equals groupIndex" to Group::CountBits and
// stores 1 when the returned count equals Group::GetThreadCount(), otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kCountBits(Group group)
{
    const bool matchesIndex = (DATA == group.groupIndex);
    const uint countMatches = (group.CountBits(matchesIndex) == group.GetThreadCount());

    _Output.Store(OUTPUT_ADDR, countMatches);
}

// Max kernel: each thread contributes its input masked with (thread count - 1); stores 1
// when Group::Max returns exactly (thread count - 1), otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kMax(Group group)
{
    const uint lastIndex = group.GetThreadCount() - 1;
    const uint isLastIndex = (group.Max(DATA & lastIndex) == lastIndex);

    _Output.Store(OUTPUT_ADDR, isLastIndex);
}

// Min kernel: each thread contributes its input masked with (thread count - 1); stores 1
// when Group::Min returns zero, otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kMin(Group group)
{
    const uint indexMask = group.GetThreadCount() - 1;
    const uint isZero = (group.Min(DATA & indexMask) == 0);

    _Output.Store(OUTPUT_ADDR, isZero);
}

// Product kernel: each thread contributes 1 when its input equals groupIndex and 0
// otherwise; the value returned by Group::Product is stored unchanged.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kProduct(Group group)
{
    const uint factor = (uint)(DATA == group.groupIndex);
    const uint product = group.Product(factor);

    _Output.Store(OUTPUT_ADDR, product);
}

// Sum kernel: each thread contributes 1 when its input equals groupIndex and 0 otherwise;
// stores 1 when Group::Sum returns Group::GetThreadCount(), otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kSum(Group group)
{
    const uint addend = (uint)(DATA == group.groupIndex);
    const uint sumMatches = (group.Sum(addend) == group.GetThreadCount());

    _Output.Store(OUTPUT_ADDR, sumMatches);
}

// GetThreadCount kernel: stores 1 when Group::GetThreadCount() equals the compile-time
// group size THREADING_BLOCK_SIZE, otherwise 0.
//
// The guard tests THREADING_BLOCK_SIZE because that is the macro the comparison uses and
// the one the fallback comment refers to. Guarding on an unrelated macro would let the
// kernel fall through to the unconditional pass without ever calling GetThreadCount().
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kGetThreadCount(Group group)
{
#if defined(THREADING_BLOCK_SIZE)
    const uint result = (THREADING_BLOCK_SIZE == group.GetThreadCount());
#else
    // If we don't know the group size at compile time, we don't really have any way of verifying this functionality.
    // Unconditionally return 1 so this case always passes.
    const uint result = 1;
#endif

    _Output.Store(OUTPUT_ADDR, result);
}

// GetThreadIndex kernel: stores 1 when the input, masked with (thread count - 1), equals
// Group::GetThreadIndex(), otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kGetThreadIndex(Group group)
{
    const uint indexMask = group.GetThreadCount() - 1;
    const uint expectedIndex = DATA & indexMask;
    const uint indexMatches = (expectedIndex == group.GetThreadIndex());

    _Output.Store(OUTPUT_ADDR, indexMatches);
}

// IsFirstThread kernel: stores the XOR of "thread index is non-zero" with the value
// returned by Group::IsFirstThread().
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kIsFirstThread(Group group)
{
    const bool hasNonZeroIndex = (group.GetThreadIndex() != 0);
    const uint xorResult = hasNonZeroIndex ^ group.IsFirstThread();

    _Output.Store(OUTPUT_ADDR, xorResult);
}

// PrefixCountBits kernel: passes the predicate "input equals groupIndex" to
// Group::PrefixCountBits and stores 1 when the returned value equals
// Group::GetThreadIndex(), otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kPrefixCountBits(Group group)
{
    const bool matchesIndex = (DATA == group.groupIndex);
    const uint prefixMatches = (group.PrefixCountBits(matchesIndex) == group.GetThreadIndex());

    _Output.Store(OUTPUT_ADDR, prefixMatches);
}

// PrefixProduct kernel: threads whose index is a multiple of 64 contribute 2, all others
// contribute 1. The value returned by Group::PrefixProduct is multiplied by the thread's
// own factor and compared against 2^((threadIndex / 64) + 1); stores 1 on a match,
// otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kPrefixProduct(Group group)
{
    const uint threadIndex = group.GetThreadIndex();

    // Only every 64th thread contributes a factor other than 1.
    const uint factor = ((threadIndex % 64) == 0) ? 2u : 1u;

    const uint exponent = (threadIndex / 64u) + 1;
    const uint expected = (1u << exponent);

    const uint productMatches = ((group.PrefixProduct(factor) * factor) == expected);
    _Output.Store(OUTPUT_ADDR, productMatches);
}

// PrefixSum kernel: every thread contributes its own thread index. The value returned by
// Group::PrefixSum plus the thread's own contribution is compared against
// threadIndex * (threadIndex + 1) / 2; stores 1 on a match, otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kPrefixSum(Group group)
{
    const uint threadIndex = group.GetThreadIndex();
    const uint expected = (threadIndex * (threadIndex + 1u)) / 2u;

    const uint sumMatches = ((group.PrefixSum(threadIndex) + threadIndex) == expected);
    _Output.Store(OUTPUT_ADDR, sumMatches);
}

// ReadThreadAt kernel: every thread offers its input masked with (thread count - 1) and
// asks Group::ReadThreadAt for index (thread count - 1); stores 1 when the returned value
// equals (thread count - 1), otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kReadThreadAtBroadcast(Group group)
{
    const uint lastIndex = group.GetThreadCount() - 1;
    const uint laneValue = DATA & lastIndex;

    const uint readMatches = (group.ReadThreadAt(laneValue, lastIndex) == lastIndex);
    _Output.Store(OUTPUT_ADDR, readMatches);
}

// ReadThreadShuffle kernel: every thread offers its input masked with (thread count - 1)
// and uses that same value as the index argument of Group::ReadThreadShuffle; stores 1
// when the returned value equals the thread's own masked input, otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kReadThreadAtShuffle(Group group)
{
    const uint indexMask = group.GetThreadCount() - 1;
    const uint laneValue = DATA & indexMask;

    const uint readMatches = (group.ReadThreadShuffle(laneValue, laneValue) == laneValue);
    _Output.Store(OUTPUT_ADDR, readMatches);
}

// ReadThreadFirst kernel: every thread offers its input masked with (thread count - 1);
// stores 1 when Group::ReadThreadFirst returns zero, otherwise 0.
[numthreads(THREADING_BLOCK_SIZE, 1, 1)]
void kReadThreadFirst(Group group)
{
    const uint indexMask = group.GetThreadCount() - 1;
    const uint maskedInput = DATA & indexMask;

    const uint isZero = (group.ReadThreadFirst(maskedInput) == 0);
    _Output.Store(OUTPUT_ADDR, isZero);
}