From d9012b0784b1f6e1c2957fcecf880f60a7ae4dcf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hugo=20H=C3=B6rnquist?= <hugo@lysator.liu.se>
Date: Wed, 20 Sep 2023 23:02:21 +0200
Subject: [PATCH] Fix access and funcalls in string interpolation.

---
 muppet/parser_combinator.py    |   2 +-
 muppet/puppet/format/parser.py |  99 +++++++++++++------------
 tests/test_parse_elsif.py      | 132 ++++++++++++++++++++++++++++++++-
 3 files changed, 181 insertions(+), 52 deletions(-)

diff --git a/muppet/parser_combinator.py b/muppet/parser_combinator.py
index 1191088..0d99c04 100644
--- a/muppet/parser_combinator.py
+++ b/muppet/parser_combinator.py
@@ -723,7 +723,7 @@ Parser for a single hexadecimal digit.
 Both upper and lower case are supported
 """
 
-space = s(' ') | '\t' | '\n' | '\r'
+space = name('space', s(' ') | '\t' | '\n' | '\r')
 """
 Parses a single whitespace token.
 
diff --git a/muppet/puppet/format/parser.py b/muppet/puppet/format/parser.py
index ee270d2..c942ee8 100644
--- a/muppet/puppet/format/parser.py
+++ b/muppet/puppet/format/parser.py
@@ -65,9 +65,9 @@ from muppet.parser_combinator import (
     hexdig,
     line_comment,
     many,
+    many1,
     name,
     nop,
-    not_,
     optional,
     s,
     tag,
@@ -135,7 +135,7 @@ class rich_char(ParseDirective):
             case ' ':
                 return parser.get(s(' ') | r'\s')
             case '"':
-                return parser.get(r'\"')
+                return parser.get(s(r'\"') | '"')
             case "'":
                 return parser.get(s("'") | r"\'")
             case '$':
@@ -294,39 +294,6 @@ class ParserFormatter(Serializer[ParseDirective]):
             case _:
                 return nop
 
-    def concat_access(self, item: PuppetAccess) -> ParseDirective:
-        """
-        Parse an access inside an interpolated string.
-
-        The following string
-
-        .. code-block:: puppet
-
-            "Hello ${people['name']['first']}!"
-
-        serializes more or less ass
-
-        .. code-block:: lisp
-
-            (concat
-              (str "Hello ")
-              (access (access (var people) 'name') 'first'))
-
-        And the regular PuppetAccess procedure can't be used, since
-        the leading '$' is optional here.
-
-        Note that the delimiting "${" and "}" should be handled
-        outside this procedure.
-        """
-        parser: ParseDirective = ws
-        match item:
-            case PuppetAccess(PuppetVar(name), _):
-                parser &= optional(s('$')) & name
-            case PuppetAccess(PuppetAccess() as next, _):
-                parser &= self.concat_access(next)
-        parser &= ws & self.known_array("[]", item.args)
-        return parser
-
     # --------------------------------------------------
 
     @override
@@ -403,20 +370,18 @@ class ParserFormatter(Serializer[ParseDirective]):
                          & optional(s('}')))
                     parser &= f
                 case PuppetString(st):
+                    # TODO what am I doing here? ParseError should
+                    # never be throw, since we aren't running the
+                    # parser, simply building it. And why is the
+                    # except block the exact same as the try block?
                     try:
                         for c in st:
                             parser &= rich_char(c)
                     except ParseError:
                         for c in st:
                             parser &= rich_char(c)
-                case PuppetAccess():
-                    # Needs to be separate from the "regular"
-                    # PuppetAccess rule, since these variables
-                    parser &= ws & "${" & self.concat_access(fragment)
-                    parser &= ws & "}"
                 case _:
-                    # TODO "${x[10][20]}"
-                    parser &= ws & "${" & ws & self.s(fragment) & ws & "}"
+                    parser &= many(space) & "${" & ws & self.s(fragment) & ws & "}"
         parser &= s('"') & ws
         return parser
 
@@ -529,7 +494,41 @@ class ParserFormatter(Serializer[ParseDirective]):
 
     @override
     def _puppet_heredoc(self, it: PuppetHeredoc) -> ParseDirective:
-        return nop
+        # TODO The header and footer is mostly shared wiht literal
+        # heredoc. Merge these
+        parser = ws & '@(' & ws & '"' & ws & many(complement('"')) & '"'
+        parser &= optional(ws & ':' & ws & it.syntax)
+        switches = name('switches', many(s('n') | 'r' | 't' | 's' | '$' | 'u' | 'L'))
+        parser &= optional(ws & '/' & switches)
+        parser &= ws & ')'
+
+        # TODO the contents shares muttch with concat, merge these
+        for fragment in it.fragments:
+            match fragment:
+                case PuppetVar(x):
+                    f = (many(space)
+                         & '$'
+                         & optional(s('{'))
+                         & ws
+                         & optional(s('$'))
+                         & tag('var', x)
+                         & ws
+                         & optional(s('}')))
+                    parser &= f
+                case PuppetString(st):
+                    word = many(space)
+                    for line in st.split('\n'):
+                        word &= many(space)
+                        for c in line.strip():
+                            word &= rich_char(c)
+                        word &= optional(s('\n'))
+                    parser &= name(repr(st), word)
+                    # parser &= many(space) & st.strip()
+                case _:
+                    parser &= many(space) & "${" & ws & self.s(fragment) & ws & "}"
+        parser &= many(space) & optional(s('|')) & ws & optional(s('-'))
+        parser &= ws & many1(all_(~ space, char))
+        return parser
 
     @override
     def _puppet_literal_heredoc(self, it: PuppetLiteralHeredoc) -> ParseDirective:
@@ -542,17 +541,18 @@ class ParserFormatter(Serializer[ParseDirective]):
             syntax = nop
 
         unquoted_heredoc = many(complement('):/')) & syntax & switches
-        quoted_heredoc = s('"') & many(complement('"')) & '"' & syntax & switches & ws & ')'
+        quoted_heredoc = s('"') & many(complement('"')) & '"' & syntax & switches
 
-        heredoc_declarator = ws & '@(' & ws & (quoted_heredoc | unquoted_heredoc)
+        heredoc_declarator = ws & '@(' & ws & (quoted_heredoc | unquoted_heredoc) & ws & ')'
 
         # delim = stringify_match(delim_parts)
 
         parser = heredoc_declarator
 
         for line in it.content.split('\n'):
-            parser &= ws & line.lstrip() & '\n'
-        parser &= ws & '|' & optional(s('-')) & ws & many(all_(not_(ws), char))
+            parser &= many(space) & line.strip()
+        parser &= many(space) & optional(s('|')) & ws & optional(s('-'))
+        parser &= ws & many1(all_(~ space, char))
 
         return parser
 
@@ -674,7 +674,10 @@ class ParserFormatter(Serializer[ParseDirective]):
     def _puppet_var(self, it: PuppetVar) -> ParseDirective:
         # TODO highlight entire decalaration
         # TODO hyperlink?
-        return name(f'${it.name}', ws & '$' & tag('var', it.name))
+
+        # The leading '$' is optional, since it's optional for
+        # variables in string interpolations, e.g. "${x}".
+        return name(f'${it.name}', ws & optional(s('$')) & tag('var', it.name))
 
     @override
     def _puppet_virtual_query(self, it: PuppetVirtualQuery) -> ParseDirective:
diff --git a/tests/test_parse_elsif.py b/tests/test_parse_elsif.py
index dcd629c..db92640 100644
--- a/tests/test_parse_elsif.py
+++ b/tests/test_parse_elsif.py
@@ -31,20 +31,30 @@ import pytest
 from muppet.puppet.format.parser import ParserFormatter
 from muppet.puppet.ast import (
     build_ast,
-    PuppetAST,
     HashEntry,
+    PuppetAST,
+    PuppetAccess,
+    PuppetBinaryOperator,
+    PuppetCallMethod,
+    PuppetConcat,
     PuppetHash,
+    PuppetInvoke,
+    PuppetLiteralHeredoc,
     PuppetNumber,
     PuppetQn,
     PuppetString,
-    PuppetInvoke,
+    PuppetVar,
 )
 from muppet.puppet.parser import puppet_parser
-from muppet.parser_combinator import ParserCombinator
+from muppet.parser_combinator import ParserCombinator, MatchCompound
 from pprint import pprint
 from typing import Any, Optional
 
 
+def ws(x):
+    return MatchCompound(type='ws', matched=x)
+
+
 def parse_string(s: str, *,
                  ast: Optional[PuppetAST] = None,
                  matched: Optional[list[Any]] = None):
@@ -235,6 +245,83 @@ def test_string_interpolation_deep_access():
     """)
 
 
+def test_string_interpolation_call_method():
+    # var, but looks like qn
+    s = """
+    "${x.y}"
+    """
+    parse_string(
+        s,
+        ast=PuppetConcat([
+            PuppetCallMethod(
+                func=PuppetBinaryOperator(
+                    op='.',
+                    lhs=PuppetVar(name='x'),
+                    rhs=PuppetQn(name='y')),
+                args=[],
+                block=None)]))
+
+
+def test_string_interpolation_deep_call_method():
+    # var, but looks like qn
+    s = """
+    "${x.y.z}"
+    """
+    parse_string(
+        s,
+        ast=PuppetConcat([
+            PuppetCallMethod(
+                func=PuppetBinaryOperator(
+                    op='.',
+                    lhs=PuppetCallMethod(
+                        func=PuppetBinaryOperator(
+                            op='.',
+                            lhs=PuppetVar(name='x'),
+                            rhs=PuppetQn(name='y')),
+                        args=[],
+                        block=None),
+                    rhs=PuppetQn(name='z')),
+                args=[],
+                block=None)]))
+
+
+def test_string_interpolation_access_call():
+    s = """
+    "${x.y('Hello')[1]}"
+    """
+
+    parse_string(
+        s,
+        ast=PuppetConcat([
+            PuppetAccess(
+                how=PuppetCallMethod(
+                    func=PuppetBinaryOperator(
+                        op='.',
+                        lhs=PuppetVar('x'),
+                        rhs=PuppetQn('y')),
+                    args=[PuppetString('Hello')],
+                    block=None),
+                args=[PuppetNumber(1)])]))
+
+
+def test_string_interpolation_call_access():
+    s = """
+    "${x[1].y('Hello')}"
+    """
+
+    parse_string(
+        s,
+        ast=PuppetConcat([
+            PuppetCallMethod(
+                func=PuppetBinaryOperator(
+                    op='.',
+                    lhs=PuppetAccess(how=PuppetVar('x'),
+                                     args=[PuppetNumber(1)]),
+                    rhs=PuppetQn('y')),
+                args=[PuppetString('Hello')],
+                block=None)]))
+
+
 def test_collect():
     parse_string("""
     Exec <| title=='apt_update' |> {
@@ -285,3 +372,42 @@ def test_bare_hash():
                                            PuppetNumber(1)),
                                  HashEntry(PuppetQn('y'),
                                            PuppetNumber(2))])]))
+
+
+def test_literal_heredoc():
+    s = """
+    @(EOF)
+    load_module Heredoc
+    | EOF
+    """
+    parse_string(
+        s,
+        ast=PuppetLiteralHeredoc(content='load_module Heredoc\n')
+    )
+
+
+def test_heredoc():
+    s = """
+    @("EOF")
+    load_module /usr/lib/nginx/modules/${modname}.so;
+    | EOF
+    """
+
+    parse_string(s)
+
+
+def test_heredoc_2():
+    """
+    Tests.
+
+    - a variable alone on a line
+    - A line starting with '#' (previously ignored as a comment ...)
+    """
+    s = """
+    @("EOF")
+       # File managed by puppet
+       ${lines}
+       | EOF
+    """
+
+    parse_string(s)
-- 
GitLab