"""
A basic JSON parser.

This parser shouldn't be used, but is instead here to demonstrate how
to use the parser combinator library.

Besides the obvious, take note of
- handlers and their type transformations
- lambdas for lazy evaluation.
"""

from muppet.parser_combinator import (
    MatchCompound,
    MatchObject,
    ParseDirective,
    complement,
    const,
    count,
    digit,
    discard,
    hexdig,
    many,
    many1,
    name,
    optional,
    or_,
    s,
    space,
    tag,
)
from typing import Optional, TypeVar
import math


T = TypeVar('T')


def force(t: Optional[T]) -> T:
    """
    Discard the None part of an optional value.

    Only use this when you /know/ that the value exists.

    Only ``None`` is rejected. Other falsy values, such as ``0``,
    ``""``, ``[]`` and ``False``, are perfectly valid values and are
    returned unchanged.

    :param t:
        The possibly missing value.
    :returns:
        ``t`` itself, now known not to be ``None``.
    :raises AssertionError:
        If the value was ``None`` after all.
    """
    # Raised explicitly, rather than through an ``assert`` statement,
    # so the check isn't stripped when Python runs with ``-O``.
    if t is None:
        raise AssertionError("force() called on a None value")
    return t


def handle_int(xs: list[MatchObject]) -> list[int]:
    """
    Turn a matched string of digits into an integer.

    Intended as a handler on parsers which match integers, such as::

        (many digit) @ handle_int

    Only the first element of the match is looked at, so this relies
    on adjacent joining having merged the individual digits into a
    single string.

    :param xs:
        The match result, whose first element must be a string.
    :returns:
        A single element list containing the parsed integer.
    """
    head = xs[0]
    assert type(head) is str
    value = int(head)
    return [value]


def _handle_exp(parts: list[MatchObject]) -> list[int]:
    """
    Convert the exponent part of a float to its (signed) integer value.

    :param parts:
        The components of the exponent. Must contain a compound
        tagged ``dig``, holding the magnitude, and may contain a
        compound tagged ``sign``, holding the matched sign string.
    :returns:
        A single element list containing the exponent, negated if the
        sign started with ``-``.
    """
    magnitude = __find('dig', parts)
    assert isinstance(magnitude, MatchCompound)
    value = magnitude.matched[0]

    sign = __find('sign', parts)
    if not sign:
        # No explicit sign, the exponent is used as is.
        return [value]

    assert isinstance(sign, MatchCompound)
    if sign.matched[0][0] == '-':
        value *= -1
    return [value]


def _handle_fraction(xs: list[MatchObject]) -> list[float]:
    """
    Convert the digits after the decimal point to their fractional value.

    The conversion is done on the digit *string*. Both leading zeros
    (``05`` in ``1.05``) and an all-zero fraction (``0`` in ``1.0``)
    matter for the result, and that information is lost once the
    digits have been turned into an integer.

    Like `handle_int`, this only looks at the first element of the
    match, and therefore relies on adjacent joining.

    :param xs:
        The match result, whose first element must be the string of
        digits following the decimal point.
    :returns:
        A single element list containing the fraction, in the range
        ``[0, 1)``.
    """
    digits = xs[0]
    assert isinstance(digits, str)
    return [int(digits) / 10**len(digits)]


def _handle_number(parts: list[MatchObject]) -> list[float]:
    """
    Construct a number from its components.

    A number is structured as ``±{base}.{fractional}e{exp}``, where
    everything but the base is optional.

    :param parts:
        The tagged components of the number. The following tags are
        looked for:

        ``base``
            The integer part, as an integer.
        ``fractional``
            The part after the decimal point, already converted to
            its value (``0.05`` for the digits ``05``).
        ``exp``
            The signed exponent, as an integer.
        ``minus``
            Only its presence matters, and marks the number as
            negative.
    :returns:
        A single element list containing the constructed number.
        Numbers written without fraction and exponent stay integers.
    """
    total: float = 0
    if base := __find('base', parts):
        assert isinstance(base, MatchCompound)
        total += base.matched[0]

    if frac := __find('fractional', parts):
        assert isinstance(frac, MatchCompound)
        total += frac.matched[0]

    if exp := __find('exp', parts):
        assert isinstance(exp, MatchCompound)
        total *= 10**exp.matched[0]

    # Applied last, so that the sign covers the whole number.
    if __find('minus', parts):
        total *= -1

    return [total]


ws = discard(name('ws', many(space)))
digit_19 = or_(*(chr(x + ord('0')) for x in range(1, 10)))

# NOTE(review): JSON requires exactly four hex digits after ``\u``.
# Confirm that these bounds to `count` select exactly four.
_hex_esc = (discard(r'\u') & count(hexdig, 3, 5)) @ (lambda x: chr(int(x[0], 16)))

# Each two character escape sequence is replaced by the character it
# stands for.
_json_esc = (_hex_esc |
             s(r'\"') @ const('"') |
             s(r'\/') @ const("/") |
             s(r'\b') @ const("\b") |
             s(r'\f') @ const("\f") |
             s(r'\n') @ const("\n") |
             s(r'\r') @ const("\r") |
             s(r'\t') @ const("\t") |
             s(r'\\') @ const("\\"))

_json_char = name('_json_char', _json_esc | complement(r'\"'))

json_string = name('json_string',
                   discard('"')
                   & many(_json_char)
                   & discard('"')) \
               @ (lambda x: x[0])

# The fraction is converted from the digit string, not through
# `handle_int`, to keep leading zeros significant.
_fraction = discard(".") & tag('fractional', many1(digit) @ _handle_fraction)

_exponent = tag('exp',
                (discard(s("e") | "E")
                 & optional(tag('sign', s("-") | "+"))
                 & tag('dig', many1(digit) @ handle_int)) @ _handle_exp)

json_number = (optional(tag('minus', "-"))
               & tag('base', (s("0") | digit_19 & many(digit)) @ handle_int)
               & optional(_fraction)
               & optional(_exponent)) @ _handle_number

# The three literal keywords of JSON. Each handler ignores what was
# matched, and instead returns a single element list holding the
# corresponding Python value.
_json_keyword = s("true")  @ (lambda _: [True]) \
              | s("false") @ (lambda _: [False]) \
              | s("null")  @ (lambda _: [None])


# A single ``key: value`` pair of an object, tagged 'kv'.
#
# `json_value` isn't defined until further down, so the reference to
# it is wrapped in a lambda, which delays the lookup until the lambda
# is called.
_json_kv = tag('kv',
               ws & tag('key', json_string) &
               ws & discard(":")
               & tag('value', lambda: json_value))

# An object, tagged 'object'. The braces are discarded, and the
# contents are either a comma separated sequence of key-value pairs,
# or only whitespace.
#
# Note that ``&`` binds tighter than ``|``, so the alternatives are
# "pairs" or ``ws``.
json_object = tag(
    'object', discard("{") & (_json_kv & many(discard(",") & _json_kv) | ws) & discard("}"))

# An array, tagged 'array'. The brackets are discarded, and the
# contents are either a comma separated sequence of values, or only
# whitespace.
#
# The body of the lambda extends all the way to the closing
# parenthesis, meaning that the ``| ws`` alternative is part of it.
# As above, the lambda delays the lookup of `json_value`.
json_array = tag('array', discard("[")
                 & (lambda: (json_value & many(discard(",") & json_value)) | ws)
                 & discard("]"))

# Any JSON value, with `ws` allowed on both sides of it.
# This is the entry point of the grammar, and what objects and arrays
# refer back to for their contents.
json_value: ParseDirective \
        = (ws & (json_string |
                 json_number |
                 _json_keyword |
                 json_object |
                 json_array)
           & ws)


def __find(key: str, objs: list[MatchObject]) -> Optional[MatchObject]:
    """
    Locate the first compound match tagged with the given key.

    :param key:
        The tag to search for.
    :param objs:
        The match objects to search through. Items which aren't
        compounds are skipped.
    :returns:
        The first compound whose type equals ``key``, or ``None`` if
        there is none.
    """
    candidates = (
        obj for obj in objs
        if isinstance(obj, MatchCompound) and obj.type == key
    )
    return next(candidates, None)