﻿using Implab;
using System;
using System.Collections.Generic;
using System.Linq;

namespace Implab.Automaton {
    /// <summary>
    /// Collects the transitions, final states and initial state of a deterministic finite
    /// automaton and builds a dense transition table (one row per state, one column per
    /// input symbol) from them.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this class is not generic, yet it uses the type parameter <c>TTag</c>
    /// (tags of final states). It implements the non-generic <c>IDFATableBuilder</c>, while
    /// <c>Optimize</c> targets <c>IDFATableBuilder&lt;TTag&gt;</c>. This looks like an unfinished
    /// refactoring: the class presumably has to become <c>DFATransitionTable&lt;TTag&gt;</c>, or
    /// drop tags entirely. Confirm with the callers.
    /// </remarks>
    public class DFATransitionTable : IDFATableBuilder {
        // Built lazily by GetTransitionTable(); once it is set, DefineTransition and
        // MarkFinalState throw InvalidOperationException.
        DFAStateDescriptior[] m_dfaTable;

        // One past the largest state index / input symbol seen so far.
        int m_stateCount;
        int m_symbolCount;
        int m_initialState;

        // Final states mapped to the tags they were marked with.
        readonly Dictionary<int, TTag[]> m_finalStates = new Dictionary<int, TTag[]>();
        readonly HashSet<AutomatonTransition> m_transitions = new HashSet<AutomatonTransition>();


        #region IDFADefinition implementation

        /// <summary>
        /// Returns the transition table, building it on the first call.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// No states or no input symbols have been defined yet.
        /// </exception>
        public DFAStateDescriptior[] GetTransitionTable() {
            if (m_dfaTable == null) {
                if (m_stateCount <= 0)
                    throw new InvalidOperationException(String.Format("Invalid automaton definition: states count = {0}", m_stateCount));
                if (m_symbolCount <= 0)
                    throw new InvalidOperationException(String.Format("Invalid automaton definition: symbols count = {0}", m_symbolCount));

                m_dfaTable = ConstructTransitionTable();
            }
            return m_dfaTable;
        }

        /// <summary>
        /// Tells whether state <paramref name="s"/> is final. The check reads the built table if
        /// it exists, otherwise the states marked so far.
        /// </summary>
        public bool IsFinalState(int s) {
            Safe.ArgumentInRange(s, 0, m_stateCount, "s");

            return m_dfaTable != null ? m_dfaTable[s].final : m_finalStates.ContainsKey(s);
        }

        /// <summary>The indices of all states marked as final.</summary>
        public IEnumerable<int> FinalStates {
            get {
                return m_finalStates.Keys;
            }
        }

        /// <summary>One past the largest state index defined so far.</summary>
        public int StateCount {
            get { return m_stateCount; }
        }

        /// <summary>One past the largest input symbol used by a transition.</summary>
        public int AlphabetSize {
            get { return m_symbolCount; }
        }

        /// <summary>The state set by <see cref="SetInitialState"/> (0 by default).</summary>
        public int InitialState {
            get { return m_initialState; }
        }

        #endregion

        /// <summary>
        /// Builds the dense transition table.
        /// </summary>
        /// <remarks>
        /// A state gets a transition row only if it has at least one outgoing transition; rows of
        /// other states stay <c>null</c>. Inside a row, every symbol without an explicit transition
        /// holds <c>DFAConst.UNREACHABLE_STATE</c>.
        /// </remarks>
        protected virtual DFAStateDescriptior[] ConstructTransitionTable() {
            var dfaTable = new DFAStateDescriptior[m_stateCount];

            foreach (var pair in m_finalStates) {
                var idx = pair.Key;

                dfaTable[idx].final = true;
                // NOTE(review): requires DFAStateDescriptior to expose a 'tag' member.
                dfaTable[idx].tag = pair.Value;
            }

            foreach (var t in m_transitions) {
                if (dfaTable[t.s1].transitions == null) {
                    var row = new int[m_symbolCount];
                    for (int i = 0; i < row.Length; i++)
                        row[i] = DFAConst.UNREACHABLE_STATE;
                    dfaTable[t.s1].transitions = row;
                }

                dfaTable[t.s1].transitions[t.edge] = t.s2;
            }

            return dfaTable;
        }

        #region IDFADefinitionBuilder

        /// <summary>
        /// Adds the transition <paramref name="s1"/> --<paramref name="symbol"/>--&gt; <paramref name="s2"/>
        /// and grows the state and symbol counts to cover it.
        /// </summary>
        /// <exception cref="InvalidOperationException">The transition table is already built.</exception>
        public void DefineTransition(int s1, int s2, int symbol) {
            if (m_dfaTable != null)
                throw new InvalidOperationException("The transition table is already built");

            // States are numbered from 0, as everywhere else in this class
            // (SetInitialState, IsFinalState, Optimize).
            Safe.ArgumentAssert(s1 >= 0, "s1");
            Safe.ArgumentAssert(s2 >= 0, "s2");
            Safe.ArgumentAssert(symbol >= 0, "symbol");

            m_stateCount = Math.Max(m_stateCount, Math.Max(s1, s2) + 1);
            m_symbolCount = Math.Max(m_symbolCount, symbol + 1);

            m_transitions.Add(new AutomatonTransition(s1, s2, symbol));
        }

        /// <summary>
        /// Marks <paramref name="state"/> as final with the given tags, replacing any tags
        /// it had before.
        /// </summary>
        /// <exception cref="InvalidOperationException">The transition table is already built.</exception>
        public void MarkFinalState(int state, params TTag[] tags) {
            if (m_dfaTable != null)
                throw new InvalidOperationException("The transition table is already built");

            Safe.ArgumentAssert(state >= 0, "state");

            // A final state may have no transitions at all, but it still needs a row in the table.
            m_stateCount = Math.Max(m_stateCount, state + 1);
            m_finalStates[state] = tags ?? new TTag[0];
        }

        /// <summary>Sets the initial state of the automaton.</summary>
        public void SetInitialState(int s) {
            Safe.ArgumentAssert(s >= 0, "s");
            m_initialState = s;
        }


        #endregion

        /// <summary>
        /// Minimizes this automaton and writes the result into <paramref name="optimalDFA"/>.
        /// </summary>
        /// <remarks>
        /// States are minimized with Hopcroft's partition refinement. The initial partition holds
        /// the final states grouped by equal tag sets, plus all non-final states. Input symbols
        /// that behave the same in every minimized state are then merged into one class.
        /// </remarks>
        /// <param name="optimalDFA">Receives the minimized automaton.</param>
        /// <param name="inputAlphabet">Alphabet of this automaton; its size must equal <see cref="AlphabetSize"/>.</param>
        /// <param name="optimalInputAlphabet">Receives the merged input symbol classes.</param>
        /// <param name="stateAlphabet">Alphabet of states; its size must equal <see cref="StateCount"/>.</param>
        /// <param name="optimalStateAlphabet">Receives the merged state classes.</param>
        /// <exception cref="InvalidOperationException">An alphabet size does not match this automaton.</exception>
        protected void Optimize<TInput, TState>(
            IDFATableBuilder<TTag> optimalDFA,
            IAlphabet<TInput> inputAlphabet,
            IAlphabetBuilder<TInput> optimalInputAlphabet,
            IAlphabet<TState> stateAlphabet,
            IAlphabetBuilder<TState> optimalStateAlphabet
        ) {
            Safe.ArgumentNotNull(optimalDFA, "optimalDFA");
            Safe.ArgumentNotNull(optimalInputAlphabet, "optimalInputAlphabet");
            Safe.ArgumentNotNull(optimalStateAlphabet, "optimalStateAlphabet");
            Safe.ArgumentNotNull(inputAlphabet, "inputAlphabet");
            Safe.ArgumentNotNull(stateAlphabet, "stateAlphabet");

            if (inputAlphabet.Count != m_symbolCount)
                throw new InvalidOperationException("The input symbols aphabet mismatch");
            if (stateAlphabet.Count != m_stateCount)
                throw new InvalidOperationException("The states alphabet mismatch");

            // Set-semantics comparers. Hashes ignore element order and multiplicity (to agree with
            // SetEquals) and are accumulated unchecked, because Enumerable.Sum throws on overflow.
            var setComparer = new CustomEqualityComparer<HashSet<int>>(
                (x, y) => x.SetEquals(y),
                s => s.Aggregate(0, (h, x) => unchecked(h + x.GetHashCode()))
            );

            var arrayComparer = new CustomEqualityComparer<TTag[]>(
                (x, y) => new HashSet<TTag>(x).SetEquals(y),
                a => a.Distinct().Aggregate(0, (h, x) => unchecked(h + EqualityComparer<TTag>.Default.GetHashCode(x)))
            );

            // Initial partition: the final states grouped by their tag sets...
            var optimalStates = new HashSet<HashSet<int>>(setComparer);
            optimalStates.UnionWith(
                m_finalStates
                .GroupBy(pair => pair.Value, arrayComparer)
                .Select(g => new HashSet<int>(g.Select(pair => pair.Key)))
            );

            // ...plus every non-final state (Range takes a count, so this covers 0..m_stateCount-1).
            var nonFinalStates = new HashSet<int>(
                Enumerable
                .Range(0, m_stateCount)
                .Where(i => !m_finalStates.ContainsKey(i))
            );
            if (nonFinalStates.Count > 0)
                optimalStates.Add(nonFinalStates);

            // Every initial block must be used as a splitter. Otherwise, states that lead to final
            // states with different tags (or to no state at all) are never told apart.
            var queue = new HashSet<HashSet<int>>(optimalStates, setComparer);

            // Reverse transitions: target state -> (symbol -> source states).
            var rmap = m_transitions
                .GroupBy(t => t.s2)
                .ToDictionary(
                    g => g.Key,
                    g => g.ToLookup(t => t.edge, t => t.s1)
                );

            while (queue.Count > 0) {
                var stateA = queue.First();
                queue.Remove(stateA);

                for (int c = 0; c < m_symbolCount; c++) {
                    // all states from which 'c' leads into stateA
                    var stateX = new HashSet<int>();
                    foreach (var a in stateA) {
                        ILookup<int, int> incoming;
                        if (rmap.TryGetValue(a, out incoming))
                            stateX.UnionWith(incoming[c]);
                    }

                    if (stateX.Count == 0)
                        continue;

                    foreach (var stateY in optimalStates.ToArray()) {
                        // split stateY only if stateX cuts it into two non-empty parts
                        if (stateX.Overlaps(stateY) && !stateY.IsSubsetOf(stateX)) {
                            var stateR1 = new HashSet<int>(stateY);
                            var stateR2 = new HashSet<int>(stateY);

                            stateR1.IntersectWith(stateX);
                            stateR2.ExceptWith(stateX);

                            optimalStates.Remove(stateY);
                            optimalStates.Add(stateR1);
                            optimalStates.Add(stateR2);

                            if (queue.Contains(stateY)) {
                                queue.Remove(stateY);
                                queue.Add(stateR1);
                                queue.Add(stateR2);
                            } else {
                                // Hopcroft's rule: queueing the smaller half is sufficient
                                queue.Add(stateR1.Count <= stateR2.Count ? stateR1 : stateR2);
                            }
                        }
                    }
                }
            }

            // Maps each original state to its corresponding optimal state.
            var statesMap = stateAlphabet.Reclassify(optimalStateAlphabet, optimalStates);

            // Minimal input alphabet: symbols a1 and a2 are indistinguishable when
            // Move(s, a1) == Move(s, a2) for every optimal state s. After minimization, the set of
            // (optimal source, optimal target) pairs of a symbol describes its whole Move function,
            // so symbols are grouped by that set. Each symbol ends up in exactly one class.
            var movesComparer = new CustomEqualityComparer<HashSet<Tuple<int, int>>>(
                (x, y) => x.SetEquals(y),
                s => s.Aggregate(0, (h, p) => unchecked(h + p.GetHashCode()))
            );

            var optimalAlphabet = m_transitions
                .GroupBy(t => t.edge, t => Tuple.Create(statesMap[t.s1], statesMap[t.s2]))
                .GroupBy(g => new HashSet<Tuple<int, int>>(g), g => g.Key, movesComparer);

            var alphabetMap = inputAlphabet.Reclassify(optimalInputAlphabet, optimalAlphabet);

            // Merged states all carry the same tag set, so duplicate tags are dropped.
            var optimalTags = m_finalStates
                .GroupBy(pair => statesMap[pair.Key])
                .ToDictionary(
                    g => g.Key,
                    g => g.SelectMany(pair => pair.Value).Distinct().ToArray()
                );

            // Build the minimized automaton.
            optimalDFA.SetInitialState(statesMap[m_initialState]);

            foreach (var pair in optimalTags)
                optimalDFA.MarkFinalState(pair.Key, pair.Value);

            var optimalTransitions = m_transitions
                .Select(t => new AutomatonTransition(statesMap[t.s1], statesMap[t.s2], alphabetMap[t.edge]))
                .Distinct();

            foreach (var t in optimalTransitions)
                optimalDFA.DefineTransition(t.s1, t.s2, t.edge);
        }

        /// <summary>
        /// Writes a debug dump to the console: first every input class with its symbols, then
        /// every transition. Targets that are final states are suffixed with "$".
        /// </summary>
        protected void PrintDFA<TInput, TState>(IAlphabet<TInput> inputAlphabet, IAlphabet<TState> stateAlphabet) {
            Safe.ArgumentNotNull(inputAlphabet, "inputAlphabet");
            Safe.ArgumentNotNull(stateAlphabet, "stateAlphabet");

            var inputMap = inputAlphabet.CreateReverseMap();
            var stateMap = stateAlphabet.CreateReverseMap();

            for (int i = 0; i < inputMap.Length; i++)
                Console.WriteLine("C{0}: {1}", i, String.Join(",", inputMap[i]));

            foreach (var t in m_transitions)
                Console.WriteLine(
                    "[{0}] -{{{1}}}-> [{2}]{3}",
                    stateMap[t.s1],
                    String.Join(",", inputMap[t.edge]),
                    stateMap[t.s2],
                    m_finalStates.ContainsKey(t.s2) ? "$" : ""
                );
        }

    }
}
