From 535fdcd95283c59771eaa805df5f6b16c1353aff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hugo=20H=C3=B6rnquist?= <hugo@lysator.liu.se>
Date: Tue, 19 Sep 2023 07:35:56 +0200
Subject: [PATCH] Fix parsing of nested if statements.

---
 muppet/puppet/format/parser.py    | 35 ++++++++++++++++++-------------
 muppet/puppet/strings/__init__.py |  2 ++
 static-src/highlight/muppet.yaml  |  2 ++
 tests/test_parser_combinator.py   |  2 +-
 4 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/muppet/puppet/format/parser.py b/muppet/puppet/format/parser.py
index 6c4a3e7..7802286 100644
--- a/muppet/puppet/format/parser.py
+++ b/muppet/puppet/format/parser.py
@@ -186,7 +186,7 @@ class ParserFormatter(Serializer[ParseDirective]):
                 ensure => present,
             }
         """
-        return (ws & param.k &
+        return (ws & tag('qn', param.k) &
                 ws & param.arrow &
                 ws & self.s(param.v) &
                 # Technically only optional for final entry
@@ -214,7 +214,7 @@ class ParserFormatter(Serializer[ParseDirective]):
                 for item in xs:
                     parser &= ws & ',' & self.declaration_parameter(item)
                 parser &= ws & optional(s(',')) & ws & delim[1]
-        return name('declaration-parameters', parser)
+        return tag('declaration-parameters', parser)
 
     def known_array(self, delim: str, in_items: list[Puppet]) -> ParseDirective:
         """
@@ -257,8 +257,6 @@ class ParserFormatter(Serializer[ParseDirective]):
                 raise ValueError(f'Unexpected extra forms after else: {rest!r}')
 
             case [(test, body), *rest]:
-                # logger.warning("elsif clause, test: %s, body: %s", test, body)
-
                 # Recursive calls wrapped in lambdas, since they NEED
                 # to be lazily evaluated, since they are only valid in
                 # their branch ('else'/'elsif')
@@ -268,12 +266,12 @@ class ParserFormatter(Serializer[ParseDirective]):
                                 ws & self.s(body) &
                                 ws & '}') & (lambda: self.if_chain(rest))
 
+                inner = PuppetIfChain([(test, body), *rest])
                 else_parser = (ws & tag('keyword', 'else') &
                                ws & '{' &
-                               ws & (lambda: self.s(PuppetIfChain(rest))) &
+                               ws & (lambda: self.s(inner)) &
                                ws & '}')
 
-                # return elsif_parser | else_parser
                 return else_parser | elsif_parser
 
         raise ValueError(f"Bad if-chain: {chain!r}")
@@ -323,7 +321,7 @@ class ParserFormatter(Serializer[ParseDirective]):
     def _puppet_case(self, it: PuppetCase) -> ParseDirective:
         parser = ws & tag('keyword', 'case') & ws & self.s(it.test) & ws & '{'
 
-        for ((x, *xs), body) in it.cases:
+        for ([x, *xs], body) in it.cases:
             parser &= ws & self.s(x)
             for x in xs:
                 parser &= ws & ',' & ws & self.s(x)
@@ -333,11 +331,10 @@ class ParserFormatter(Serializer[ParseDirective]):
 
     @override
     def _puppet_class(self, it: PuppetClass) -> ParseDirective:
-        parser = (ws & tag('keyword', 'class') & ws & it.name &
+        parser = (ws & tag('keyword', 'class') & ws & tag('name', it.name) &
                   optional(ws & self.declaration_parameters('()', it.params)))
         parser &= optional(ws & 'inherits' & ws & tag('inherits', it.parent))
         parser &= ws & '{' & ws & self.s(it.body) & ws & '}'
-        # logger.warning(parser)
         return parser
 
     @override
@@ -355,13 +352,14 @@ class ParserFormatter(Serializer[ParseDirective]):
                          & optional(s('{'))
                          & ws
                          & optional(s('$'))
-                         & x
+                         & tag('var', x)
                          & ws
                          & optional(s('}')))
                     parser &= f
                 case PuppetString(st):
                     try:
-                        parser &= st
+                        for c in st:
+                            parser &= rich_char(c)
                     except ParseError:
                         for c in st:
                             parser &= rich_char(c)
@@ -373,6 +371,7 @@ class ParserFormatter(Serializer[ParseDirective]):
 
     @override
     def _puppet_declaration(self, it: PuppetDeclaration) -> ParseDirective:
+        # TODO tag with declaration
         return ws & self.s(it.k) & ws & '=' & ws & self.s(it.v)
 
     @override
@@ -410,7 +409,8 @@ class ParserFormatter(Serializer[ParseDirective]):
         # logger.warning("clauses: %s", it.clauses)
         (test1, body1), *rest = it.clauses
         assert test1 != 'else', f"Unexpected else clause: {it.clauses}"
-        parser = (ws & 'if'
+        # assert test1 != 'elsif', f"Unexpected elsif clause: {it.clauses}"
+        parser = (ws & tag('keyword', 'if')
                   & ws & self.s(test1)
                   & ws & '{'
                   & ws & self.s(body1)
@@ -444,7 +444,7 @@ class ParserFormatter(Serializer[ParseDirective]):
     def _puppet_lambda(self, it: PuppetLambda) -> ParseDirective:
         return tag('lambda',
                    self.declaration_parameters('||', it.params) &
-                   '{' & self.s(it.body) & '}')
+                   ws & '{' & self.s(it.body) & ws & '}')
 
     @override
     def _puppet_literal(self, it: PuppetLiteral) -> ParseDirective:
@@ -574,7 +574,10 @@ class ParserFormatter(Serializer[ParseDirective]):
         # A string with " as delimiter
         double_quoted = s('"') & [rich_char(c) for c in it.s] & '"'
 
-        parser = ws & (raw_string | single_quoted | double_quoted)
+        if it.s == '':
+            parser = ws & (single_quoted | double_quoted)
+        else:
+            parser = ws & (raw_string | single_quoted | double_quoted)
         return tag('string', parser)
 
     @override
@@ -588,7 +591,9 @@ class ParserFormatter(Serializer[ParseDirective]):
 
     @override
     def _puppet_var(self, it: PuppetVar) -> ParseDirective:
-        return name(f'${it.name}', ws & '$' & it.name)
+        # TODO highlight entire declaration
+        # TODO hyperlink?
+        return name(f'${it.name}', ws & '$' & tag('var', it.name))
 
     @override
     def _puppet_virtual_query(self, it: PuppetVirtualQuery) -> ParseDirective:
diff --git a/muppet/puppet/strings/__init__.py b/muppet/puppet/strings/__init__.py
index 42d558d..9826658 100644
--- a/muppet/puppet/strings/__init__.py
+++ b/muppet/puppet/strings/__init__.py
@@ -53,6 +53,8 @@ class DocStringTag(Deserializable):
 
     :param text:
         Freeform text content of the tag.
+
+    TODO types
     """
 
     tag_name:   str
diff --git a/static-src/highlight/muppet.yaml b/static-src/highlight/muppet.yaml
index 93491ad..b60e6ce 100644
--- a/static-src/highlight/muppet.yaml
+++ b/static-src/highlight/muppet.yaml
@@ -1,6 +1,7 @@
 # Muppet's built in output
 comment:
   - comment
+  - line-comment
 error:
   - parse-error
 interpolate:
@@ -23,3 +24,4 @@ string:
 variable:
   - qn
   - var
+  - name
diff --git a/tests/test_parser_combinator.py b/tests/test_parser_combinator.py
index 19a104b..ea45940 100644
--- a/tests/test_parser_combinator.py
+++ b/tests/test_parser_combinator.py
@@ -182,7 +182,7 @@ def test_complement():
 def test_stringifiers():
     assert "'a'" == str(s("a"))
     assert "~ 'a'" == repr(~ s("a"))
-    assert "x" == str(name("x", space & space))
+    assert "[x]" == str(name("x", space & space))
     assert "('a' & 'b')" == str(s('a') & s('b'))
     assert "('a' | 'b')" == str(s('a') | s('b'))
     assert "char" == str(char)
-- 
GitLab