"""
"Parser" which, instead of serializing the AST, "parses" the original source code and annotates it.

This is basically a parser combinator, implementing the grammar for
the exact file we were created from. While this might seem worthless,
this allows *really* good syntax highlighting.
"""

import logging
from .base import Serializer
from muppet.puppet.ast import (
    Puppet,
    PuppetAccess,
    PuppetArray,
    PuppetBinaryOperator,
    PuppetBlock,
    PuppetCall,
    PuppetCallMethod,
    PuppetCase,
    PuppetClass,
    PuppetCollect,
    PuppetConcat,
    PuppetDeclaration,
    PuppetDefine,
    PuppetExportedQuery,
    PuppetFunction,
    PuppetHash,
    PuppetHeredoc,
    PuppetIfChain,
    PuppetInstanciationParameter,
    PuppetInvoke,
    PuppetKeyword,
    PuppetLambda,
    PuppetLiteral,
    PuppetLiteralHeredoc,
    PuppetNode,
    PuppetNop,
    PuppetNumber,
    PuppetParenthesis,
    PuppetParseError,
    PuppetQn,
    PuppetQr,
    PuppetRegex,
    PuppetResource,
    PuppetResourceDefaults,
    PuppetResourceOverride,
    PuppetSelector,
    PuppetString,
    PuppetUnaryOperator,
    PuppetUnless,
    PuppetVar,
    PuppetVirtualQuery,

    PuppetDeclarationParameter,
)

from muppet.parser_combinator import (
    MatchObject,
    ParseDirective,
    ParseError,
    ParserCombinator,
    char,
    complement,
    count,
    hexdig,
    line_comment,
    many,
    name,
    nop,
    not_,
    optional,
    s,
    tag,
    all_,
    space,
)
# from muppet.parser_combinator import ws as primitive_ws

from typing import (
    Callable,
    Literal,
    TypeVar,
    # Optional,
    Sequence,
)

from dataclasses import dataclass


# Parser for whatever separates tokens: a run of ``#`` line comments
# and whitespace characters, tagged and named 'ws'.
ws = name(
    'ws',
    tag('ws', many(line_comment('#') | space)),
)


F = TypeVar('F', bound=Callable[..., object])  # any callable

# TODO replace this decorator with
# from typing import override
# once the target python version is changed to 3.12


def override(f: F) -> F:
    """
    Hand *f* back untouched.

    Runtime no-op standing in for the ``@override`` marker on Python
    versions where :func:`typing.override` is not available. It only
    documents that the decorated method overrides a base class method.
    """
    unchanged: F = f
    return unchanged


logger = logging.getLogger(name=__name__)  # module-level logger


@dataclass
class rich_char(ParseDirective):
    """A single character character in a string with full escaping."""

    c: str

    def run(self, parser: 'ParserCombinator') -> list[MatchObject]:  # noqa: D102
        snapshot = parser.snapshot()
        try:
            return parser.get(s(rf"\u{ord(self.c):04X}") | [r'\u{', count(hexdig, 2, 6), '}'])
        except ParseError:
            parser.restore(snapshot)

        match self.c:
            case '\\':
                return parser.get(r'\\')
            case '\n':
                return parser.get(s('\n') | r'\n')
            case '\r':
                return parser.get(s('\r') | r'\r')
            case '\t':
                return parser.get(s('\t') | r'\t')
            case ' ':
                return parser.get(s(' ') | r'\s')
            case '"':
                return parser.get(r'\"')
            case "'":
                return parser.get(s("'") | r"\'")
            case '$':
                return parser.get(s('$') | r'\$')
            case _:
                # Literal linebreaks can apparently be escaped inside
                # "rich" strings (thankfully not inside 'plain' strings).
                try:
                    return parser.get(s("\\\n") & self.c)
                except ParseError:
                    return parser.get(self.c)


class ParserFormatter(Serializer[ParseDirective]):
    """
    Reserialize AST by highlighting the original source code.

    :param source:
        The original source code. *Must* be the exact same source as
        used to construct the corresponinding Puppet ast object.
    :param seek:
        Current parsing position in the string.

        TODO make this private.
    """

    # parser: ParserCombinator

    # def __init__(self, source: str, file: Optional[str] = None):
    def __init__(self, source: str, file: str):
        # self.parser = ParserCombinator(source=source, file=file)
        self.file = file
        pass

    def declaration_parameter(self, item: PuppetDeclarationParameter) -> ParseDirective:
        """Build parser for the given declaration parameter."""
        type: ParseDirective
        value: ParseDirective

        type = self.s(item.type)
        value = optional(ws & '=' & ws & self.s(item.v))
        return name(f'decl-${item.k}', ws & type & ws & '$' & item.k & value)

    def instanciation_parameter(self, param: PuppetInstanciationParameter) -> ParseDirective:
        """
        Parse a declaration parameter.

        In the example below, `ensure => present,` is the part parsed.

        .. code-block:: puppet

            file { '/':
                ensure => present,
            }
        """
        return (ws & tag('qn', param.k) &
                ws & param.arrow &
                ws & self.s(param.v) &
                # Technically only optional for final entry
                ws & optional(s(',')))

    def declaration_parameters(
            self,
            delim: str,
            in_items: list[PuppetDeclarationParameter] | None) -> ParseDirective:
        """
        Build parser for the given list of declaration parameters.

        :param delim:
            A string of length two, containing a start and end delimiter.
        :param in_items:
            None and empty lists are both treated as empty lists. It's
            laid out like thit due to how the puppet parser works.
        """
        parser = ws & delim[0]
        match in_items:
            case [] | None:
                parser &= ws & delim[1]
            case [x, *xs]:
                parser &= self.declaration_parameter(x)
                for item in xs:
                    parser &= ws & ',' & self.declaration_parameter(item)
                parser &= ws & optional(s(',')) & ws & delim[1]
        return tag('declaration-parameters', parser)

    def known_array(self, delim: str, in_items: list[Puppet]) -> ParseDirective:
        """
        Read a delimted, comma separated, array.

        Reads the starting delimiter, a comma separated list of Puppet
        items, an optional ending comma, and the ending delimiter.

        :param delim:
            A string of length two, containing the starting and ending delimiter.
        :param in_items:
        """
        assert len(delim) == 2, "Delimiter should be the start and end character used."
        parser = ws & delim[0]
        match in_items:
            case []:
                parser &= ws & delim[1]
            case [x, *xs]:
                parser &= ws & self.s(x)
                for item in xs:
                    parser &= ws & ',' & ws & self.s(item)
                parser &= ws & optional(s(',')) & ws & delim[1]
        return parser

    def if_chain(self,
                 chain: list[tuple[Puppet | Literal['else'], list[Puppet]]]
                 ) -> ParseDirective:
        """Handle all trailing clauses in an if chain."""
        # logger.warning("chain: %a", chain)
        match chain:
            case []:
                return nop

            case [('else', body)]:
                # logger.warning("else clause, body: %s", body)
                return (ws & tag('keyword', 'else')
                        & ws & '{' & ws & self.s(body) & ws & '}')

            case [('else', body), *rest]:
                raise ValueError(f'Unexpected extra forms after else: {rest!r}')

            case [(test, body), *rest]:
                # Recursive calls wrapped in lambdas, since they NEED
                # to be lazily evaluated, since they are only valid in
                # their branch ('else'/'elsif')
                elsif_parser = (ws & tag('keyword', 'elsif') &
                                ws & self.s(test) &  # type: ignore
                                ws & '{' &
                                ws & self.s(body) &
                                ws & '}') & (lambda: self.if_chain(rest))

                inner = PuppetIfChain([(test, body), *rest])
                else_parser = (ws & tag('keyword', 'else') &
                               ws & '{' &
                               ws & (lambda: self.s(inner)) &
                               ws & '}')

                return else_parser | elsif_parser

        raise ValueError(f"Bad if-chain: {chain!r}")

    def s(self, it: Puppet | Sequence[Puppet] | None) -> ParseDirective:
        """Shorthand for self.serialize, but also handles None and lists."""
        match it:
            case Puppet():
                return self.serialize(it)
            case [x, *xs]:
                parser = ws & self.s(x)
                for x in xs:
                    parser &= ws & self.s(x)
                return parser
            case _:
                return nop

    def concat_access(self, item: PuppetAccess) -> ParseDirective:
        """
        Parse an access inside an interpolated string.

        The following string

        .. code-block:: puppet

            "Hello ${people['name']['first']}!"

        serializes more or less ass

        .. code-block:: lisp

            (concat
              (str "Hello ")
              (access (access (var people) 'name') 'first'))

        And the regular PuppetAccess procedure can't be used, since
        the leading '$' is optional here.

        Note that the delimiting "${" and "}" should be handled
        outside this procedure.
        """
        parser = ws
        match item:
            case PuppetAccess(PuppetVar(name), _):
                parser &= optional(s('$')) & name
            case PuppetAccess(PuppetAccess() as next, _):
                parser &= self.concat_access(next)
        parser &= ws & self.known_array("[]", item.args)
        return parser

    # --------------------------------------------------

    @override
    def _puppet_access(self, it: PuppetAccess) -> ParseDirective:
        return tag('access', ws & self.s(it.how) & ws & self.known_array('[]', it.args))

    @override
    def _puppet_array(self, it: PuppetArray) -> ParseDirective:
        return tag('array', ws & self.known_array('[]', it.items))

    @override
    def _puppet_binary_operator(self, it: PuppetBinaryOperator) -> ParseDirective:
        return ws & self.s(it.lhs) & ws & it.op & ws & self.s(it.rhs)

    @override
    def _puppet_block(self, it: PuppetBlock) -> ParseDirective:
        return ws & self.s(it.entries)

    @override
    def _puppet_call(self, it: PuppetCall) -> ParseDirective:
        return ws & self.s(it.func) & ws & self.known_array('()', it.args)

    @override
    def _puppet_call_method(self, it: PuppetCallMethod) -> ParseDirective:
        return ws & self.s(it.func) & \
                optional(ws & self.known_array('()', it.args)) & \
                optional(ws & self.s(it.block))

    @override
    def _puppet_case(self, it: PuppetCase) -> ParseDirective:
        parser = ws & tag('keyword', 'case') & ws & self.s(it.test) & ws & '{'

        for ([x, *xs], body) in it.cases:
            parser &= ws & self.s(x)
            for x in xs:
                parser &= ws & ',' & ws & self.s(x)
            parser &= ws & ':' & ws & '{' & ws & self.s(body) & ws & '}'
        parser &= ws & '}'
        return parser

    @override
    def _puppet_class(self, it: PuppetClass) -> ParseDirective:
        parser = (ws & tag('keyword', 'class') & ws & tag('name', it.name) &
                  optional(ws & self.declaration_parameters('()', it.params)))
        parser &= optional(ws & 'inherits' & ws & tag('inherits', it.parent))
        parser &= ws & '{' & ws & self.s(it.body) & ws & '}'
        return parser

    @override
    def _puppet_collect(self, it: PuppetCollect) -> ParseDirective:
        parser = ws & self.s(it.type) & ws & self.s(it.query)

        sub = ws & "{"
        for param in it.ops:
            sub &= self.instanciation_parameter(param)
        sub &= ws & '}'

        parser &= optional(sub)
        return parser

    @override
    def _puppet_concat(self, it: PuppetConcat) -> ParseDirective:
        parser = ws & '"'
        for fragment in it.fragments:
            match fragment:
                case PuppetVar(x):
                    f = (ws
                         & '$'
                         & optional(s('{'))
                         & ws
                         & optional(s('$'))
                         & tag('var', x)
                         & ws
                         & optional(s('}')))
                    parser &= f
                case PuppetString(st):
                    try:
                        for c in st:
                            parser &= rich_char(c)
                    except ParseError:
                        for c in st:
                            parser &= rich_char(c)
                case PuppetAccess():
                    # Needs to be separate from the "regular"
                    # PuppetAccess rule, since these variables
                    parser &= ws & "${" & self.concat_access(fragment)
                    parser &= ws & "}"
                case _:
                    # TODO "${x[10][20]}"
                    parser &= ws & "${" & ws & self.s(fragment) & ws & "}"
        parser &= s('"') & ws
        return parser

    @override
    def _puppet_declaration(self, it: PuppetDeclaration) -> ParseDirective:
        # TODO tag with declaration
        return ws & self.s(it.k) & ws & '=' & ws & self.s(it.v)

    @override
    def _puppet_define(self, it: PuppetDefine) -> ParseDirective:
        return (ws & tag('keyword', 'define') & ws & it.name &
                optional(ws & self.declaration_parameters('()', it.params)) &
                ws & '{' & ws & self.s(it.body) & ws & '}')

    @override
    def _puppet_exported_query(self, it: PuppetExportedQuery) -> ParseDirective:
        return ws & '<<|' & ws & self.s(it.filter) & ws & '|>>'

    @override
    def _puppet_function(self, it: PuppetFunction) -> ParseDirective:
        return (ws & tag('keyword', 'function') & ws & it.name &
                optional(ws & self.declaration_parameters('()', it.params)) &
                optional(ws & '>>' & self.s(it.returns)) &
                ws & '{' & ws & self.s(it.body) & ws & '}')

    @override
    def _puppet_hash(self, it: PuppetHash) -> ParseDirective:
        parser = ws & '{'
        for entry in it.entries:
            parser &= (ws & self.s(entry.k) &
                       ws & '=>' &
                       ws & self.s(entry.v) &
                       optional(ws & ','))
        parser &= ws & '}'
        return parser

    @override
    def _puppet_if_chain(self, it: PuppetIfChain) -> ParseDirective:
        if not it.clauses:
            return nop
        # logger.warning("clauses: %s", it.clauses)
        (test1, body1), *rest = it.clauses
        assert test1 != 'else', f"Unexpected else clause: {it.clauses}"
        # assert test1 != 'elsif', f"Unexpected elsif clause: {it.clauses}"
        parser = (ws & tag('keyword', 'if')
                  & ws & self.s(test1)
                  & ws & '{'
                  & ws & self.s(body1)
                  & ws & '}')

        # logger.warning("rest: %s", it.clauses)
        if rest:
            parser &= self.if_chain(rest)
        return parser

    @override
    def _puppet_instanciation_parameter(self, it: PuppetInstanciationParameter) -> ParseDirective:
        return ws & it.k & ws & it.arrow & ws & self.s(it.v) & optional(ws & ' &')

    @override
    def _puppet_invoke(self, it: PuppetInvoke) -> ParseDirective:
        parser = ws & self.s(it.func) & optional(ws & '(')
        match it.args:
            case [x, *xs]:
                parser &= ws & self.s(x)
                for x in xs:
                    parser &= ws & ',' & ws & self.s(x)
        parser &= optional(ws & ',') & optional(ws & ')')
        parser &= optional(ws & self.s(it.block))
        return tag('invoke', parser)

    @override
    def _puppet_keyword(self, it: PuppetKeyword) -> ParseDirective:
        return tag('keyword', ws & it.name)

    @override
    def _puppet_lambda(self, it: PuppetLambda) -> ParseDirective:
        return tag('lambda',
                   self.declaration_parameters('||', it.params) &
                   ws & '{' & self.s(it.body) & ws & '}')

    @override
    def _puppet_literal(self, it: PuppetLiteral) -> ParseDirective:
        return tag('literal', ws & it.literal)

    @override
    def _puppet_heredoc(self, it: PuppetHeredoc) -> ParseDirective:
        return nop

    @override
    def _puppet_literal_heredoc(self, it: PuppetLiteralHeredoc) -> ParseDirective:
        escape_switches = s('/') & many(s('n') | 'r' | 't' | 's' | '$' | 'u' | 'L')
        switches = optional(ws & escape_switches)

        if it.syntax:
            syntax = ws & ':' & ws & it.syntax
        else:
            syntax = nop

        unquoted_heredoc = many(complement('):/')) & syntax & switches
        quoted_heredoc = s('"') & many(complement('"')) & '"' & syntax & switches & ws & ')'

        heredoc_declarator = ws & '@(' & ws & (quoted_heredoc | unquoted_heredoc)

        # delim = stringify_match(delim_parts)

        parser = heredoc_declarator

        for line in it.content.split('\n'):
            parser &= ws & line.lstrip() & '\n'
        parser &= ws & '|' & optional(s('-')) & ws & many(all_(not_(ws), char))

        return parser

    @override
    def _puppet_node(self, it: PuppetNode) -> ParseDirective:
        parser = ws & 'node' & ws
        for match in it.matches:
            parser &= ws & match & ws & ","
        parser &= ws & "{" & ws & self.s(it.body) & "}"
        return parser

    @override
    def _puppet_nop(self, it: PuppetNop) -> ParseDirective:
        return nop

    @override
    def _puppet_number(self, it: PuppetNumber) -> ParseDirective:
        parser: ParseDirective = ws
        match (it.x, it.radix):
            case int(x), 8:
                parser &= s('0') & oct(x)[2:]
            case int(x), 16:
                parser &= s('0') & 'x' & hex(x)[2:]
            case x, None:
                parser &= str(x)
            case _:
                raise ValueError(f"Unexpected radix: {it.radix}")

        return parser

    @override
    def _puppet_parenthesis(self, it: PuppetParenthesis) -> ParseDirective:
        return ws & '(' & ws & self.s(it.form) & ws & ')'

    @override
    def _puppet_qn(self, it: PuppetQn) -> ParseDirective:
        return tag('qn', ws & it.name)

    @override
    def _puppet_qr(self, it: PuppetQr) -> ParseDirective:
        return tag('qr', ws & it.name)

    @override
    def _puppet_regex(self, it: PuppetRegex) -> ParseDirective:
        return tag('rx', ws & '/' & it.s.replace('/', r'\/') & '/')

    @override
    def _puppet_resource(self, it: PuppetResource) -> ParseDirective:
        parser = ws & self.s(it.type) & ws & '{'
        for key, params in it.bodies:
            parser &= ws & self.s(key) & ws & ':'
            for param in params:
                parser &= self.instanciation_parameter(param)
            parser &= ws & optional(s(';'))
        parser &= ws & '}'
        return parser

    @override
    def _puppet_resource_defaults(self, it: PuppetResourceDefaults) -> ParseDirective:
        parser = ws & self.s(it.type) & ws & '{' & ws
        for param in it.ops:
            parser &= self.instanciation_parameter(param)
        parser &= ws & '}'
        return parser

    @override
    def _puppet_resource_override(self, it: PuppetResourceOverride) -> ParseDirective:
        parser = ws & self.s(it.resource) & ws & '{' & ws
        for param in it.ops:
            parser &= self.instanciation_parameter(param)
        parser &= ws & '}'
        return parser

    @override
    def _puppet_selector(self, it: PuppetSelector) -> ParseDirective:
        parser = ws & self.s(it.resource) & ws & '?' & ws & '{'
        for key, body in it.cases:
            parser &= ws & self.s(key) & ws & '=>' & ws & self.s(body) & ws & optional(s(','))
        parser &= ws & '}'
        return parser

    @override
    def _puppet_string(self, it: PuppetString) -> ParseDirective:
        # get one char to find delimiter
        # Then read chars until matching delimiter (or parse expected
        # string)

        # A string without any delimiters at all
        raw_string = s(it.s)

        # A string with ' as delimiter
        single_quoted = s("'") & [{"'": r"\'",
                                   '\\': s(r'\\') | '\\'
                                   }.get(c, c) for c in it.s] & "'"

        # A string with " as delimiter
        double_quoted = s('"') & [rich_char(c) for c in it.s] & '"'

        if it.s == '':
            parser = (single_quoted | double_quoted)
        else:
            parser = (raw_string | single_quoted | double_quoted)
        # TODO should the whitespace really be here?
        return ws & tag('string', name(it.s, parser))

    @override
    def _puppet_unary_operator(self, it: PuppetUnaryOperator) -> ParseDirective:
        return ws & it.op & ws & self.s(it.x)

    @override
    def _puppet_unless(self, it: PuppetUnless) -> ParseDirective:
        parser = (ws & 'unless' & ws & self.s(it.condition) & ws & '{' &
                  ws & self.s(it.consequent) & ws & '}')
        parser &= optional(ws & 'else' & ws & '{' & ws & self.s(it.alternative) &
                           ws & '}')
        return parser

    @override
    def _puppet_var(self, it: PuppetVar) -> ParseDirective:
        # TODO highlight entire decalaration
        # TODO hyperlink?
        return name(f'${it.name}', ws & '$' & tag('var', it.name))

    @override
    def _puppet_virtual_query(self, it: PuppetVirtualQuery) -> ParseDirective:
        return ws & '<|' & ws & self.s(it.q) & ws & '|>'

    @override
    def _puppet_parse_error(self, it: PuppetParseError) -> ParseDirective:
        logger.fatal(it)
        raise Exception(it)
        # return MatchObject('', self.parser.get())