using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Burst.Compiler.IL.Syntax;
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.Cecil.Rocks;

namespace zzzUnity.Burst.CodeGen
{
    /// <summary>
    /// Transforms a direct invoke on a burst function pointer into an calli, avoiding the need to marshal the delegate back.
    /// </summary>
    internal class FunctionPointerInvokeTransform
    {
        private struct CaptureInformation
        {
            public MethodReference Operand;

            public List<Instruction> Captured;
        }

        private Dictionary<TypeReference, (MethodDefinition method, Instruction instruction)> _needsNativeFunctionPointer;
        private Dictionary<MethodDefinition, TypeReference> _needsIl2cppInvoke;
        private Dictionary<MethodDefinition, List<CaptureInformation>> _capturedSets;
        private MethodDefinition _monoPInvokeAttributeCtorDef;
        private MethodDefinition _nativePInvokeAttributeCtorDef;
        private MethodDefinition _unmanagedFunctionPointerAttributeCtorDef;
        private TypeReference _burstFunctionPointerType;
        private TypeReference _burstCompilerType;
        private TypeReference _systemType;
        private TypeReference _callingConventionType;

        private LogDelegate _debugLog;
        private int _logLevel;

        private AssemblyResolver _loader;

        private ErrorDiagnosticDelegate _errorReport;

        public readonly static bool enableInvokeAttribute = true;
#if UNITY_DOTSPLAYER
        public readonly static bool enableCalliOptimisation = true;
        public readonly static bool enableUnmangedFunctionPointerInject = false;
#else
        public readonly static bool enableCalliOptimisation = false;		// For now only run the pass on dots player/tiny
        public readonly static bool enableUnmangedFunctionPointerInject = true;
#endif

        public FunctionPointerInvokeTransform(AssemblyResolver loader,ErrorDiagnosticDelegate error, LogDelegate log = null, int logLevel = 0)
        {
            _loader = loader;

            _needsNativeFunctionPointer = new Dictionary<TypeReference, (MethodDefinition, Instruction)>();
            _needsIl2cppInvoke = new Dictionary<MethodDefinition, TypeReference>();
            _capturedSets = new Dictionary<MethodDefinition, List<CaptureInformation>>();
            _monoPInvokeAttributeCtorDef = null;
            _unmanagedFunctionPointerAttributeCtorDef = null;
            _nativePInvokeAttributeCtorDef = null;  // Only present on DOTS_PLAYER
            _burstFunctionPointerType = null;
            _burstCompilerType = null;
            _systemType = null;
            _callingConventionType = null;
            _debugLog = log;
            _logLevel = logLevel;
            _errorReport = error;
        }

        private AssemblyDefinition GetAsmDefinitionFromFile(AssemblyResolver loader, string assemblyName)
        {
            if (loader.TryResolve(AssemblyNameReference.Parse(assemblyName), out var result))
            {
                return result;
            }
            return null;
        }

        public void Initialize(AssemblyResolver loader, AssemblyDefinition assemblyDefinition, TypeSystem typeSystem)
        {
            if (_monoPInvokeAttributeCtorDef == null)
            {
                var burstAssembly = GetAsmDefinitionFromFile(loader, "Unity.Burst");

                _burstFunctionPointerType = burstAssembly.MainModule.GetType("Unity.Burst.FunctionPointer`1");
                _burstCompilerType = burstAssembly.MainModule.GetType("Unity.Burst.BurstCompiler");

                var corLibrary = loader.Resolve(typeSystem.CoreLibrary as AssemblyNameReference);
                _systemType = corLibrary.MainModule.Types.FirstOrDefault(x => x.FullName == "System.Type"); // Only needed for MonoPInvokeCallback constructor in Unity

                if (enableUnmangedFunctionPointerInject)
                {
                    var unmanagedFunctionPointerAttribute = corLibrary.MainModule.GetType("System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute");
                    _callingConventionType = corLibrary.MainModule.GetType("System.Runtime.InteropServices.CallingConvention");
                    _unmanagedFunctionPointerAttributeCtorDef = unmanagedFunctionPointerAttribute.GetConstructors().Single(c => c.Parameters.Count == 1 && c.Parameters[0].ParameterType.MetadataType == _callingConventionType.MetadataType);
                }

#if UNITY_DOTSPLAYER
                var asmDef = assemblyDefinition;
                if (asmDef.Name.Name != "Unity.Runtime")
                    asmDef = GetAsmDefinitionFromFile(loader, "Unity.Runtime");

                if (asmDef == null)
                    return;

                var monoPInvokeAttribute = asmDef.MainModule.GetType("Unity.Jobs.MonoPInvokeCallbackAttribute");
                _monoPInvokeAttributeCtorDef = monoPInvokeAttribute.GetConstructors().First();

                var nativePInvokeAttribute = asmDef.MainModule.GetType("NativePInvokeCallbackAttribute");
                _nativePInvokeAttributeCtorDef = nativePInvokeAttribute.GetConstructors().First();
#else
                var asmDef = GetAsmDefinitionFromFile(loader, "UnityEngine.CoreModule");
                // bail if we can't find a reference, handled gracefully later
                if (asmDef == null)
                    return;

                var monoPInvokeAttribute = asmDef.MainModule.GetType("AOT.MonoPInvokeCallbackAttribute");
                _monoPInvokeAttributeCtorDef = monoPInvokeAttribute.GetConstructors().First();
#endif
            }

        }

        public bool Run(AssemblyDefinition assemblyDefinition)
        {
            Initialize(_loader, assemblyDefinition, assemblyDefinition.MainModule.TypeSystem);

            var types = assemblyDefinition.MainModule.GetTypes().ToArray();
            foreach (var type in types)
            {
                CollectDelegateInvokesFromType(type);
            }

            return Finish();
        }

        public void CollectDelegateInvokesFromType(TypeDefinition type)
        {
            foreach (var m in type.Methods)
            {
                if (m.HasBody)
                {
                    CollectDelegateInvokes(m);
                }
            }
        }

        private bool ProcessUnmanagedAttributeFixups()
        {
            if (_unmanagedFunctionPointerAttributeCtorDef == null)
                return false;

            bool modified = false;

            foreach (var kp in _needsNativeFunctionPointer)
            {
                var delegateType = kp.Key;
                var instruction = kp.Value.instruction;
                var method = kp.Value.method;
                var delegateDef = delegateType.Resolve();

                var hasAttributeAlready = delegateDef.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == _unmanagedFunctionPointerAttributeCtorDef.DeclaringType.FullName);

                // If there is already an an attribute present
                if (hasAttributeAlready!=null)
                {
                    if (hasAttributeAlready.ConstructorArguments.Count==1)
                    {
                        var cc = (System.Runtime.InteropServices.CallingConvention)hasAttributeAlready.ConstructorArguments[0].Value;
                        if (cc == System.Runtime.InteropServices.CallingConvention.Cdecl)
                        {
                            if (_logLevel > 2) _debugLog?.Invoke($"UnmanagedAttributeFixups Skipping appending unmanagedFunctionPointerAttribute as already present aand calling convention matches");
                        }
                        else
                        {
                            // constructor with non cdecl calling convention
                            _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer(CallingConvention.{ Enum.GetName(typeof(System.Runtime.InteropServices.CallingConvention), cc) })]` please remove the attribute if you wish to use this function with Burst.");
                        }
                    }
                    else
                    {
                        // Empty constructor which defaults to Winapi which is incompatable
                        _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer]` please remove the attribute if you wish to use this function with Burst.");
                    }
                    continue;
                }

                var attribute = new CustomAttribute(delegateType.Module.ImportReference(_unmanagedFunctionPointerAttributeCtorDef));
                attribute.ConstructorArguments.Add(new CustomAttributeArgument(delegateType.Module.ImportReference(_callingConventionType), System.Runtime.InteropServices.CallingConvention.Cdecl));
                delegateDef.CustomAttributes.Add(attribute);
                modified = true;
            }

            return modified;
        }

        private bool ProcessIl2cppInvokeFixups()
        {
            if (_monoPInvokeAttributeCtorDef == null)
                return false;

            bool modified = false;
            foreach (var invokeNeeded in _needsIl2cppInvoke)
            {
                var declaringType = invokeNeeded.Value;
                var implementationMethod = invokeNeeded.Key;

#if UNITY_DOTSPLAYER

                // At present always uses monoPInvokeAttribute because we don't currently know if the burst is enabled
                var attribute = new CustomAttribute(implementationMethod.Module.ImportReference(_monoPInvokeAttributeCtorDef));
                implementationMethod.CustomAttributes.Add(attribute);
                modified = true;

#else
                // Unity requires a type parameter for the attributecallback
                if (declaringType == null)
                {
                    _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing declaringType for {implementationMethod}");
                    continue;
                }

                var attribute = new CustomAttribute(implementationMethod.Module.ImportReference(_monoPInvokeAttributeCtorDef));
                attribute.ConstructorArguments.Add(new CustomAttributeArgument(implementationMethod.Module.ImportReference(_systemType), implementationMethod.Module.ImportReference(declaringType)));
                implementationMethod.CustomAttributes.Add(attribute);
                modified = true;

                if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Added InvokeCallbackAttribute to  {implementationMethod}");
#endif
            }

            return modified;
        }

        private bool ProcessFunctionPointerInvokes()
        {
            var madeChange = false;
            foreach (var capturedData in _capturedSets)
            {
                var latePatchMethod = capturedData.Key;
                var capturedList = capturedData.Value;

                latePatchMethod.Body.SimplifyMacros();  // De-optimise short branches, since we will end up inserting instructions

                foreach(var capturedInfo in capturedList)
                {
                    var captured = capturedInfo.Captured;
                    var operand = capturedInfo.Operand;

                    if (captured.Count!=2)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.Finish: expected 2 instructions - Unable to optimise this reference");
                        continue;
                    }

                    if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish:{Environment.NewLine}  latePatchMethod:{latePatchMethod}{Environment.NewLine}  captureList:{capturedList}{Environment.NewLine}  capture0:{captured[0]}{Environment.NewLine}  operand:{operand}");

                    var processor = latePatchMethod.Body.GetILProcessor();

                    var genericContext = GenericContext.From(operand, operand.DeclaringType);
                    CallSite callsite;
                    try
                    {
                        callsite = new CallSite(genericContext.Resolve(operand.ReturnType))
                        {
                            CallingConvention = MethodCallingConvention.C
                        };

                        for (int oo = 0; oo < operand.Parameters.Count; oo++)
                        {
                            var param = operand.Parameters[oo];
                            var ty = genericContext.Resolve(param.ParameterType);
                            callsite.Parameters.Add(new ParameterDefinition(param.Name, param.Attributes, ty));
                        }
                    }
                    catch (NullReferenceException)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Failed to resolve the generic context of `{operand}`");
                        continue;
                    }

                    // Make sure everything is in order before we make a change

                    var originalGetInvoke = captured[0];

                    if (originalGetInvoke.Operand is MethodReference mmr)
                    {
                        var genericMethodDef = mmr.Resolve();

                        var genericInstanceType = mmr.DeclaringType as GenericInstanceType;
                        var genericInstanceDef = genericInstanceType.Resolve();

                        // Locate the correct instance method - we know already at this point we have an instance of Function
                        MethodReference mr = default;
                        bool failed = true;
                        foreach (var m in genericInstanceDef.Methods)
                        {
                            if (m.FullName.Contains("get_Value"))
                            {
                                mr = m;
                                failed = false;
                                break;
                            }
                        }
                        if (failed)
                        {
                            _debugLog?.Invoke($"FunctionPtrInvoke.Finish: failed to locate get_Value method on {genericInstanceDef} - Unable to optimise this reference");
                            continue;
                        }

                        var newGenericRef = new MethodReference(mr.Name, mr.ReturnType, genericInstanceType)
                        {
                            HasThis = mr.HasThis,
                            ExplicitThis = mr.ExplicitThis,
                            CallingConvention = mr.CallingConvention
                        };
                        foreach (var param in mr.Parameters)
                            newGenericRef.Parameters.Add(new ParameterDefinition(param.ParameterType));
                        foreach (var gparam in mr.GenericParameters)
                            newGenericRef.GenericParameters.Add(new GenericParameter(gparam.Name, newGenericRef));
                        var importRef = latePatchMethod.Module.ImportReference(newGenericRef);
                        var newMethodCall = processor.Create(OpCodes.Call, importRef);

                        // Replace get_invoke with get_Value - Don't use replace though as if the original call is target of a branch
                        //the branch doesn't get updated.
                        originalGetInvoke.OpCode = newMethodCall.OpCode;
                        originalGetInvoke.Operand = newMethodCall.Operand;

                        // Add local to capture result
                        var newLocal = new VariableDefinition(mr.ReturnType);
                        latePatchMethod.Body.Variables.Add(newLocal);

                        // Store result of get_Value
                        var storeInst = processor.Create(OpCodes.Stloc, newLocal);
                        processor.InsertAfter(originalGetInvoke, storeInst);

                        // Swap invoke with calli
                        var calli = processor.Create(OpCodes.Calli, callsite);
                        // We can use replace here, since we already checked this is in the same Basic Block, and thus can't be target of a branch
                        processor.Replace(captured[1], calli);

                        // Insert load local prior to calli
                        var loadValue = processor.Create(OpCodes.Ldloc, newLocal);
                        processor.InsertBefore(calli, loadValue);

                        if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Optimised {originalGetInvoke} with {newMethodCall}");

                        madeChange = true;
                    }
                }

                latePatchMethod.Body.OptimizeMacros();  // Re-optimise branches
            }
            return madeChange;
        }

        public bool Finish()
        {
            bool madeChange = false;

            if (enableInvokeAttribute)
            {
                madeChange |= ProcessIl2cppInvokeFixups();
            }

            if (enableUnmangedFunctionPointerInject)
            {
                madeChange |= ProcessUnmanagedAttributeFixups();
            }

            if (enableCalliOptimisation)
            {
                madeChange |= ProcessFunctionPointerInvokes();
            }

            return madeChange;
        }

        private bool IsBurstFunctionPointerMethod(MethodReference methodRef, string method, out GenericInstanceType methodInstance)
        {
            methodInstance = methodRef?.DeclaringType as GenericInstanceType;
            return (methodInstance != null && methodInstance.ElementType.FullName == _burstFunctionPointerType.FullName && methodRef.Name == method);
        }

        private bool IsBurstCompilerMethod(GenericInstanceMethod methodRef, string method)
        {
            var methodInstance = methodRef?.DeclaringType as TypeReference;
            return (methodInstance != null && methodInstance.FullName == _burstCompilerType.FullName && methodRef.Name == method);
        }

        private void LocateFunctionPointerTCreation(MethodDefinition m, Instruction i)
        {
            if (i.OpCode == OpCodes.Call)
            {
                var genInstMethod = i.Operand as GenericInstanceMethod;

                var isBurstCompilerCompileFunctionPointer = IsBurstCompilerMethod(genInstMethod, "CompileFunctionPointer");
                var isBurstFunctionPointerGetInvoke = IsBurstFunctionPointerMethod(i.Operand as MethodReference, "get_Invoke", out var instanceType);
                if (!(isBurstCompilerCompileFunctionPointer || isBurstFunctionPointerGetInvoke)) return;

                if (enableUnmangedFunctionPointerInject)
                {
                    var delegateType = isBurstCompilerCompileFunctionPointer ? genInstMethod.GenericArguments[0].Resolve() : instanceType.GenericArguments[0].Resolve();
                    // We check for null, since unfortunately it is possible that the call is wrapped inside
                    //another open delegate and we cannot determine the delegate type
                    if (delegateType != null && !_needsNativeFunctionPointer.ContainsKey(delegateType))
                    {
                        _needsNativeFunctionPointer.Add(delegateType, (m, i));
                    }
                }

                // No need to process further if its not a CompileFunctionPointer method
                if (!isBurstCompilerCompileFunctionPointer) return;

                if (enableInvokeAttribute)
                {
                    // Currently only handles the following pre-pattern (which should cover most common uses)
                    // ldftn ...
                    // newobj ...

                    if (i.Previous?.OpCode != OpCodes.Newobj)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding NewObj {i.Previous}");
                        return;
                    }

                    var newObj = i.Previous;
                    if (newObj.Previous?.OpCode != OpCodes.Ldftn)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding LdFtn {newObj.Previous}");
                        return;
                    }

                    var ldFtn = newObj.Previous;

                    // Determine the delegate type
                    var methodDefinition = newObj.Operand as MethodDefinition;
                    var declaringType = methodDefinition?.DeclaringType;

                    // Fetch the implementation method
                    var implementationMethod = ldFtn.Operand as MethodDefinition;

                    var hasInvokeAlready = implementationMethod?.CustomAttributes.FirstOrDefault(x =>
                         (x.AttributeType.FullName == _monoPInvokeAttributeCtorDef.DeclaringType.FullName)
                         || (_nativePInvokeAttributeCtorDef != null && x.AttributeType.FullName == _nativePInvokeAttributeCtorDef.DeclaringType.FullName));

                    if (hasInvokeAlready != null)
                    {
                        if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Skipping appending Callback Attribute as already present {hasInvokeAlready}");
                        return;
                    }

                    if (implementationMethod == null)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing method from {ldFtn} {ldFtn.Operand}");
                        return;
                    }

                    if (implementationMethod.CustomAttributes.FirstOrDefault(x => x.Constructor.DeclaringType.Name == "BurstCompileAttribute") == null)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing burst attribute from {implementationMethod}");
                        return;
                    }

                    // Need to add the custom attribute
                    if (!_needsIl2cppInvoke.ContainsKey(implementationMethod))
                    {
                        _needsIl2cppInvoke.Add(implementationMethod, declaringType);
                    }
                }
            }
        }

        [Obsolete("Will be removed in a future Burst verison")]
        public bool IsInstructionForFunctionPointerInvoke(MethodDefinition m, Instruction i)
        {
            throw new NotImplementedException();
        }

        private void CollectDelegateInvokes(MethodDefinition m)
        {
            if (!(enableCalliOptimisation || enableInvokeAttribute || enableUnmangedFunctionPointerInject))
                return;

            bool hitGetInvoke = false;
            TypeDefinition delegateType = null;
            List<Instruction> captured = null;

            foreach (var inst in m.Body.Instructions)
            {
                if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: CurrentInstruction {inst} {inst.Operand}");

                // Check for a FunctionPointerT creation
                if (enableUnmangedFunctionPointerInject || enableInvokeAttribute)
                {
                    LocateFunctionPointerTCreation(m, inst);
                }

                if (enableCalliOptimisation)
                {
                    if (!hitGetInvoke)
                    {
                        if (inst.OpCode != OpCodes.Call) continue;
                        if (!IsBurstFunctionPointerMethod(inst.Operand as MethodReference, "get_Invoke", out var methodInstance)) continue;

                        // At this point we have a call to a FunctionPointer.Invoke
                        hitGetInvoke = true;

                        delegateType = methodInstance.GenericArguments[0].Resolve();

                        captured = new List<Instruction>();

                        captured.Add(inst); // Capture the get_invoke, we will swap this for get_value and a store to local
                    }
                    else
                    {
                        if (!(inst.OpCode.FlowControl == FlowControl.Next || inst.OpCode.FlowControl == FlowControl.Call))
                        {
                            // Don't perform transform across blocks
                            hitGetInvoke = false;
                        }
                        else
                        {
                            if (inst.OpCode == OpCodes.Callvirt)
                            {
                                if (inst.Operand is MethodReference mref)
                                {
                                    var method = mref.Resolve();

                                    if (method.DeclaringType == delegateType)
                                    {
                                        hitGetInvoke = false;

                                        List<CaptureInformation> storage = null;
                                        if (!_capturedSets.TryGetValue(m, out storage))
                                        {
                                            storage = new List<CaptureInformation>();
                                            _capturedSets.Add(m, storage);
                                        }

                                        // Capture the invoke - which we will swap for a load local (stored from the get_value) and a calli
                                        captured.Add(inst);
                                        var captureInfo = new CaptureInformation { Captured = captured, Operand = mref };
                                        if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: captureInfo:{captureInfo}{Environment.NewLine}capture0{captured[0]}");
                                        storage.Add(captureInfo);
                                    }
                                }
                                else
                                {
                                    hitGetInvoke = false;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}