using System;
using System.Collections;
using NUnit.Framework;
using Unity.Mathematics;
using UnityEngine;
using UnityEngine.Rendering;
using UnityEngine.TestTools;
using Random = UnityEngine.Random;

/// <summary>
/// Tests for the brick decomposition of <see cref="VolumeData"/>: the CPU-side brick data returned by
/// <c>ComputeBrickData</c> (atlas, indirection, per-brick ranges and range mips), the <c>Buf3D</c>
/// helper, and GPU lookups performed by kernels of the <c>TestShader</c> compute shader.
/// </summary>
public class VolumeBrickTest
{
    private ComputeShader testShader;
    private const uint bricks = 16;
    private const uint size = VolumeData.BrickSize * bricks;
    // Fixed seed so a failing run can be reproduced with identical volume data.
    private const int randomSeed = 1234;
    private Buf3D<float> rawTestData;
    private BrickDataResult expectedData;
    private VolumeData testVolume;
    private Random.State savedRandomState;

    /// <summary>
    /// Loads the test compute shader, fills a <c>size</c>^3 volume with seeded random values in [0, 1)
    /// and computes the brick data every test compares against.
    /// </summary>
    [SetUp]
    public void SetUp()
    {
        testShader = Resources.Load<ComputeShader>("TestShader");
        Assert.IsNotNull(testShader, "Failed to load test shader!");

        // Seed Unity's global RNG for reproducible data; the previous state is restored in TearDown.
        savedRandomState = Random.state;
        Random.InitState(randomSeed);

        rawTestData = new Buf3D<float>(new uint3(size));
        RandomData(rawTestData, 0f, 1f);
        testVolume = ScriptableObject.CreateInstance<VolumeData>();
        testVolume.Init(rawTestData, "test_data", Vector3.one);
        testVolume.BuildTextures();

        expectedData = testVolume.ComputeBrickData();
    }

    /// <summary>
    /// Destroys the per-test <see cref="VolumeData"/> instance and restores Unity's RNG state.
    /// </summary>
    [TearDown]
    public void TearDown()
    {
        if (testVolume != null)
        {
            // NOTE(review): textures created by BuildTextures() are only freed here if VolumeData
            // destroys them itself when it is destroyed — confirm.
            UnityEngine.Object.DestroyImmediate(testVolume);
            testVolume = null;
        }

        Random.state = savedRandomState;
    }

    /// <summary>
    /// Reconstructs every voxel on the CPU through indirection -> atlas -> per-brick range and checks it
    /// against the globally normalized raw sample.
    /// </summary>
    [Test]
    public void TestBrickGenerationSimple()
    {
        uint brickSize = VolumeData.BrickSize;

        for (uint z = 0; z < size; ++z)
        {
            for (uint y = 0; y < size; ++y)
            {
                for (uint x = 0; x < size; ++x)
                {
                    var index = new uint3(x, y, z);
                    var brick = index / brickSize;

                    var range = new float2(expectedData.Range[brick].r, expectedData.Range[brick].g);
                    var indirection = expectedData.Indirection[brick];
                    var ptr = new uint3(indirection.r, indirection.g, indirection.b);

                    // The atlas value is remapped through the brick's (r = min, g = max) range to recover
                    // the globally normalized sample.
                    var lookup = ptr * brickSize + index % brickSize;
                    var stored = expectedData.Atlas[lookup].r;
                    var value = range.x + stored * (range.y - range.x);

                    Assert.AreEqual(NormalizedSample(index), value, 1e-5);
                }
            }
        }
    }

    /// <summary>
    /// Checks that <c>ToIndex</c> and <c>ToCoord</c> round-trip for every cell of a 10^3 buffer.
    /// </summary>
    [Test]
    public void TestBuf3DIndex()
    {
        var dims = new uint3(10, 10, 10);
        var buffer = new Buf3D<float>(dims);

        // Exhaustive instead of a single random sample: 1000 cells is cheap and fully deterministic.
        for (uint z = 0; z < dims.z; ++z)
        {
            for (uint y = 0; y < dims.y; ++y)
            {
                for (uint x = 0; x < dims.x; ++x)
                {
                    var index = new uint3(x, y, z);
                    Assert.AreEqual(index, buffer.ToCoord((int) buffer.ToIndex(index)));
                }
            }
        }
    }

    /// <summary>
    /// Checks that a cubic buffer reports and allocates <c>size</c>^3 elements.
    /// </summary>
    [Test]
    public void TestBuf3DSize()
    {
        var buffer = new Buf3D<float>(new uint3(size, size, size));
        Assert.AreEqual(size * size * size, buffer.GetSize());
        Assert.AreEqual(size * size * size, buffer.Data.Length);
    }

    /// <summary>
    /// Checks every range mip cell against the min of <c>r</c> / max of <c>g</c> over the base-level
    /// brick ranges it covers.
    /// </summary>
    [Test]
    public void TestRangeMips()
    {
        for (var mip = 0; mip < expectedData.RangeMipmaps.Length; ++mip)
        {
            var rangeBuf = expectedData.RangeMipmaps[mip];
            var dim = rangeBuf.Dimensions;
            // RangeMipmaps[mip] is level mip + 1, so each cell covers factor^3 base-level bricks.
            var factor = 1u << (mip + 1);

            for (uint bz = 0; bz < dim.z; ++bz)
            {
                for (uint by = 0; by < dim.y; ++by)
                {
                    for (uint bx = 0; bx < dim.x; ++bx)
                    {
                        var index = new uint3(bx, by, bz);
                        float rangeMin = float.MaxValue, rangeMax = float.MinValue;
                        for (uint z = 0; z < factor; ++z)
                        {
                            for (uint y = 0; y < factor; ++y)
                            {
                                for (uint x = 0; x < factor; ++x)
                                {
                                    var val = expectedData.Range[index * factor + new uint3(x, y, z)];
                                    rangeMin = Math.Min(rangeMin, val.r);
                                    rangeMax = Math.Max(rangeMax, val.g);
                                }
                            }
                        }

                        var actual = rangeBuf[index];
                        Assert.AreEqual(rangeMin, actual.r, 1e-3);
                        Assert.AreEqual(rangeMax, actual.g, 1e-3);
                    }
                }
            }
        }
    }

    /// <summary>
    /// Runs <c>test_lookup_density</c> and compares each output voxel with the normalized raw sample.
    /// </summary>
    [UnityTest]
    public IEnumerator TestBrickLookupShader()
    {
        yield return RunShader<float>(
            "test_lookup_density",
            size,
            "result",
            RenderTextureFormat.RFloat,
            (index, value) => Assert.AreEqual(NormalizedSample(index), value, 1e-3));
    }

    /// <summary>
    /// Runs <c>test_lookup_range_direct</c> and compares each output with the base-level brick range.
    /// </summary>
    [UnityTest]
    public IEnumerator TestBrickLookupRangeShader()
    {
        yield return RunShader<float2>(
            "test_lookup_range_direct",
            bricks,
            "result_range",
            RenderTextureFormat.RGFloat,
            (index, value) =>
            {
                Assert.AreEqual(expectedData.Range[index].r, value.x, 1e-3);
                Assert.AreEqual(expectedData.Range[index].g, value.y, 1e-3);
            });
    }

    /// <summary>
    /// Runs <c>test_lookup_range_mip</c> for level 0 (the base brick ranges) and every range mip, with
    /// the level passed through the <c>range_mip</c> shader uniform.
    /// </summary>
    [UnityTest]
    public IEnumerator TestBrickLookupRangeMipShader()
    {
        for (var i = 0; i <= expectedData.RangeMipmaps.Length; ++i)
        {
            testShader.SetInt("range_mip", i);
            var mip = i;
            yield return RunShader<float2>(
                "test_lookup_range_mip",
                bricks >> i,
                "result_range",
                RenderTextureFormat.RGFloat,
                (index, value) =>
                {
                    if (mip == 0)
                    {
                        Assert.AreEqual(expectedData.Range[index].r, value.x, 1e-3);
                        Assert.AreEqual(expectedData.Range[index].g, value.y, 1e-3);
                    }
                    else
                    {
                        Assert.AreEqual(expectedData.RangeMipmaps[mip - 1][index].r, value.x, 1e-3);
                        Assert.AreEqual(expectedData.RangeMipmaps[mip - 1][index].g, value.y, 1e-3);
                    }
                });
        }
    }

    /// <summary>
    /// Runs <c>test_lookup_index</c> and compares each output with the brick indirection entry.
    /// </summary>
    [UnityTest]
    public IEnumerator TestBrickLookupIndexShader()
    {
        yield return RunShader<Color32>(
            "test_lookup_index",
            bricks,
            "result_index",
            RenderTextureFormat.ARGB32,
            (index, value) =>
            {
                Assert.AreEqual(expectedData.Indirection[index], value);
            });
    }

    /// <summary>
    /// Binds the test volume, dispatches <paramref name="kernelName"/> so it covers a
    /// <paramref name="usedSize"/>^3 result texture, reads the texture back asynchronously and invokes
    /// <paramref name="callback"/> with the coordinate and value of every read-back voxel.
    /// </summary>
    /// <param name="kernelName">Kernel in the test shader to run.</param>
    /// <param name="usedSize">Edge length of the cubic 3D result texture.</param>
    /// <param name="resultName">Shader property name the result texture is bound to.</param>
    /// <param name="resultFormat">Format of the result texture; must match <typeparamref name="T"/>.</param>
    /// <param name="callback">Per-voxel check receiving the voxel coordinate and its read-back value.</param>
    private IEnumerator RunShader<T>(string kernelName, uint usedSize, string resultName,
        RenderTextureFormat resultFormat, Action<uint3, T> callback)
        where T : struct
    {
        var resultTexture = new RenderTexture((int) usedSize, (int) usedSize, 0, resultFormat)
        {
            dimension = TextureDimension.Tex3D,
            volumeDepth = (int) usedSize,
            enableRandomWrite = true
        };

        // try/finally (allowed around yield return) so the GPU texture is freed even when a check fails.
        try
        {
            resultTexture.Create();

            var kernelHandle = testShader.FindKernel(kernelName);
            var indices = new ShaderIndices(testShader, false);
            testVolume.Bind(testShader, kernelHandle, indices);
            testShader.SetTexture(kernelHandle, resultName, resultTexture);

            // Dispatch takes thread-group counts, so divide by the kernel's numthreads (rounding up).
            // For a numthreads(1,1,1) kernel this is exactly usedSize groups per axis.
            testShader.GetKernelThreadGroupSizes(kernelHandle, out var groupX, out var groupY, out var groupZ);
            testShader.Dispatch(kernelHandle,
                GroupCount(usedSize, groupX), GroupCount(usedSize, groupY), GroupCount(usedSize, groupZ));

            var finishedComputing = false;
            var readbackFailed = false;
            T[] rawResult = null;

            AsyncGPUReadback.Request(resultTexture, 0, request =>
            {
                // Never assert inside the callback: a throw here would skip setting the flag and leave
                // the WaitUntil below waiting forever.
                readbackFailed = request.hasError;
                if (!readbackFailed)
                {
                    rawResult = request.GetData<T>().ToArray();
                }

                finishedComputing = true;
            });

            yield return new WaitUntil(() => finishedComputing);

            Assert.IsFalse(readbackFailed, "Failed to run compute shader");
            Assert.IsNotNull(rawResult, "Failed to run compute shader");

            var sliceLength = usedSize * usedSize;
            Assert.GreaterOrEqual((long) rawResult.Length, (long) sliceLength,
                "Read-back data is smaller than one slice of the result texture");

            // Check every depth slice the readback returned (at most usedSize).
            var depth = Math.Min(usedSize, (uint) rawResult.Length / sliceLength);
            var result = new Buf3D<T>(new uint3(usedSize, usedSize, depth));
            result.Data = rawResult;

            for (uint z = 0; z < depth; ++z)
            {
                for (uint y = 0; y < usedSize; ++y)
                {
                    for (uint x = 0; x < usedSize; ++x)
                    {
                        var index = new uint3(x, y, z);
                        callback(index, result[index]);
                    }
                }
            }
        }
        finally
        {
            resultTexture.Release();
            UnityEngine.Object.DestroyImmediate(resultTexture);
        }
    }

    /// <summary>
    /// Number of thread groups of size <paramref name="groupSize"/> needed to cover
    /// <paramref name="extent"/> threads along one axis.
    /// </summary>
    private static int GroupCount(uint extent, uint groupSize)
    {
        return (int) ((extent + groupSize - 1) / groupSize);
    }

    /// <summary>
    /// Raw test sample at <paramref name="index"/> rescaled with the brick data's global range,
    /// treating <c>MinMax.x</c> as the minimum and <c>MinMax.y</c> as the maximum.
    /// </summary>
    private float NormalizedSample(uint3 index)
    {
        return (rawTestData[index] - expectedData.MinMax.x) / (expectedData.MinMax.y - expectedData.MinMax.x);
    }

    /// <summary>
    /// Fills every cell of <paramref name="data"/> with <c>UnityEngine.Random.Range(min, max)</c>.
    /// </summary>
    private void RandomData(Buf3D<float> data, float min, float max)
    {
        for (uint bz = 0; bz < data.Dimensions.z; ++bz)
        {
            for (uint by = 0; by < data.Dimensions.y; ++by)
            {
                for (uint bx = 0; bx < data.Dimensions.x; ++bx)
                {
                    data[new uint3(bx, by, bz)] = Random.Range(min, max);
                }
            }
        }
    }
}
