From b8ccdca63393c1b83a3335a5f9fbefc40db282e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hugo=20H=C3=B6rnquist?= <hugo@lysator.liu.se>
Date: Sun, 14 May 2023 21:06:25 +0200
Subject: [PATCH] Fix types in main.py.

---
 Makefile  |   2 +-
 main.py   | 251 ++++++++++++++++++++++++++++--------------------------
 mypy.ini  |  10 +++
 reflow.py |   4 +-
 4 files changed, 144 insertions(+), 123 deletions(-)
 create mode 100644 mypy.ini

diff --git a/Makefile b/Makefile
index e867d37..317a4b4 100644
--- a/Makefile
+++ b/Makefile
@@ -12,7 +12,7 @@ $(CACHE_DIR)/output.json:
 $(CACHE_DIR)/%.json:
 	cd $< && puppet strings generate --format json --out $(CURDIR)/$@
 
-index.html: $(CACHE_DIR)/output.json
+index.html: $(CACHE_DIR)/output.json *.py
 	python3 main.py $< > $@
 
 check:
diff --git a/main.py b/main.py
index 8b24465..2ac5c79 100644
--- a/main.py
+++ b/main.py
@@ -13,7 +13,17 @@ from subprocess import CalledProcessError
 import html
 import json
 import sys
-
+from typing import (
+    Union,
+    Literal,
+    Any,
+    TypeAlias,
+    Tuple,
+)
+
+HashEntry: TypeAlias  = Union[Tuple[Literal['=>'], str, Any],
+                              Tuple[Literal['+>'], str, Any],
+                              Tuple[Literal['splat-hash'], Any]]
 
 match sys.argv:
     case [_, d, *_]:
@@ -28,8 +38,9 @@ data = info
 
 param_doc: dict[str, str] = {}
 
-
-def print_hash(hash, indent=0):
+def print_hash(hash: list[HashEntry],
+               indent: int,
+               context: list[str]) -> None:
     """Print the contents of a puppet hash literal."""
     if not hash:
         return
@@ -38,20 +49,20 @@ def print_hash(hash, indent=0):
         match item:
             case ['=>', key, value]:
                 print(' '*indent*2, end='')
-                parse(key, indent=indent)
+                parse(key, indent, context)
                 # print(' =&gt; ', end='')
                 print(' ⇒ ', end='')
-                parse(value, indent=indent)
+                parse(value, indent, context)
             case ['splat-hash', value]:
                 print(' '*indent*2, end='')
                 print('* =&gt; ', end='')
-                parse(value, indent=indent)
+                parse(value, indent, context)
             case _:
                 print(f'<span class="parse-error">[|[{item}]|]</span>')
         print(',')
 
 
-def ops_namelen(ops) -> int:
+def ops_namelen(ops: list[HashEntry]) -> int:
     """Calculate max key length a list of puppet operators."""
     namelen = 0
     for item in ops:
@@ -67,7 +78,7 @@ def ops_namelen(ops) -> int:
     return namelen
 
 
-def print_array(arr, indent=0):
+def print_array(arr: list[Any], indent: int, context: list[str]) -> None:
     """Print a puppet array literal."""
     if not arr:
         print('[]', end='')
@@ -75,12 +86,12 @@ def print_array(arr, indent=0):
     print('[')
     for item in arr:
         print(' '*(indent+1)*2, end='')
-        parse(item)
+        parse(item, indent, context)
         print(',')
     print(' '*indent*2 + ']', end='')
 
 
-def print_var(x, dollar=True):
+def print_var(x: str, dollar: bool = True) -> None:
     """
     Print the given variable.
 
@@ -118,7 +129,7 @@ symbols: dict[str, str] = {
 }
 
 
-def parse(form, indent=0):
+def parse(form: Any, indent: int, context: list[str]) -> None:
     """
     Print everything from a puppet parse tree.
 
@@ -141,43 +152,43 @@ def parse(form, indent=0):
 
         case ['access', how, *args]:
             print('<span class="compound-type">', end='')
-            parse(how, indent=indent)
+            parse(how, indent, context)
             print('[', end='')
             first = True
             for arg in args:
                 if not first:
                     print(', ', end='')
                 # TODO newlines?
-                parse(arg, indent=indent)
+                parse(arg, indent, context)
                 first = False
             print(']', end='')
             print('</span>', end='')
 
         case ['and', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' and ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['array', *items]:
-            print_array(items, indent=indent+1)
+            print_array(items, indent+1, context)
 
         case ['call', {'functor': func,
                        'args': args}]:
             print('<span class="call">', end='')
-            parse(func, indent=indent)
+            parse(func, indent, context)
             print('(', end='')
             first = True
             for arg in args:
                 if not first:
                     print(', ', end='')
                 first = False
-                parse(arg, indent=indent)
+                parse(arg, indent, context)
             print(')', end='')
             print('</span>', end='')
 
         case ['call-method', func]:
             print('<span class="call-method">', end='')
-            parse(func['functor'], indent=indent)
+            parse(func['functor'], indent, context)
 
             first = True
             if not ('block' in func and func['args'] == []):
@@ -186,12 +197,12 @@ def parse(form, indent=0):
                     if not first:
                         print(', ', end='')
                     first = False
-                    parse(x, indent=indent)
-                # print(', '.join(parse(x, indent=indent) for x in func['args']), end='')
+                    parse(x, indent, context)
+                # print(', '.join(parse(x, indent, context) for x in func['args']), end='')
                 print(')', end='')
 
             if 'block' in func:
-                parse(func['block'], indent=indent+1)
+                parse(func['block'], indent+1, context)
 
             print('</span>', end='')
 
@@ -199,7 +210,7 @@ def parse(form, indent=0):
 
         case ['case', test, forms]:
             print('case ', end='')
-            parse(test, indent=indent)
+            parse(test, indent, context)
             print(' {')
             for form in forms:
                 when = form['when']
@@ -212,11 +223,11 @@ def parse(form, indent=0):
                     if not first:
                         print(', ', end='')
                     first = False
-                    parse(item, indent=indent+1)
+                    parse(item, indent+1, context)
                 print(': {')
                 for item in then:
                     print(' '*(indent+2)*2, end='')
-                    parse(item, indent=indent+2)
+                    parse(item, indent+2, context)
                     print()
                 print(' '*(indent+1)*2+'},')
             print(' '*indent*2+'}', end='')
@@ -232,12 +243,12 @@ def parse(form, indent=0):
                     print(' '*(indent+1)*2, end='')
                     if 'type' in data:
                         print('<span class="type">', end='')
-                        parse(data['type'], indent=indent)
+                        parse(data['type'], indent, context)
                         print('</span> ', end='')
                     print(f'<span class="var">${name}</span>', end='')
                     if 'value' in data:
                         print(' = ', end='')
-                        parse(data.get('value'), indent=indent)
+                        parse(data.get('value'), indent, context)
                         print(',', end='')
                     print()
                 print(' '*indent*2 + ') {')
@@ -246,7 +257,7 @@ def parse(form, indent=0):
 
             for entry in body:
                 print(' '*(indent+1)*2, end='')
-                parse(entry, indent=indent+1)
+                parse(entry, indent+1, context)
                 print()
 
             print(' '*indent*2+'}')
@@ -258,7 +269,7 @@ def parse(form, indent=0):
                     case ['str', thingy]:
                         print('<span class="str-var">${', end='')
                         # print_var(x, dollar=False)
-                        parse(thingy, indent=indent)
+                        parse(thingy, indent, context)
                         print('}</span>', end='')
                     case s:
                         # print(s, file=sys.stderr)
@@ -271,9 +282,9 @@ def parse(form, indent=0):
 
         case ['collect', {'type': t,
                           'query': q}]:
-            parse(t, indent=indent)
+            parse(t, indent, context)
             print(' ', end='')
-            parse(q, indent=indent)
+            parse(q, indent, context)
 
         case ['default']:
             print('<span class="default">default</span>', end='')
@@ -290,13 +301,13 @@ def parse(form, indent=0):
                     print(' '*(indent+1)*2, end='')
                     if 'type' in data:
                         print('<span class="type">', end='')
-                        parse(data['type'], indent=indent)
+                        parse(data['type'], indent, context)
                         print('</span> ', end='')
                     # print(f'<span class="var">${name}</span>', end='')
                     print_var(name)
                     if 'value' in data:
                         print(' = ', end='')
-                        parse(data.get('value'), indent=indent)
+                        parse(data.get('value'), indent, context)
                         print(',', end='')
                     print()
 
@@ -305,7 +316,7 @@ def parse(form, indent=0):
 
             for entry in body:
                 print(' '*(indent+1)*2, end='')
-                parse(entry, indent=indent+1)
+                parse(entry, indent+1, context)
                 print()
 
             print(' '*indent*2 + '}', end='')
@@ -315,7 +326,7 @@ def parse(form, indent=0):
 
         case ['exported-query', arg]:
             print('<<| ', end='')
-            parse(arg, indent=indent)
+            parse(arg, indent, context)
             print(' |>>', end='')
 
         case ['function', {'name': name,
@@ -326,21 +337,21 @@ def parse(form, indent=0):
                 for name, attributes in rest['params'].items():
                     print(' '*(indent+1)*2, end='')
                     if 'type' in attributes:
-                        parse(attributes['type'])
+                        parse(attributes['type'], indent, context)
                         print(' ', end='')
                     print(f'${name}', end='')
                     if 'value' in attributes:
                         print(' = ', end='')
-                        parse(attributes['value'])
+                        parse(attributes['value'], indent, context)
                     print(',')
             print(')', end='')
             if 'returns' in rest:
                 print(' >> ', end='')
-                parse(rest['returns'])
+                parse(rest['returns'], indent, context)
             print(' {')
             for item in body:
                 print(' '*(indent+1)*2, end='')
-                parse(item, indent=indent+1)
+                parse(item, indent+1, context)
                 print()
             print('}')
 
@@ -349,24 +360,24 @@ def parse(form, indent=0):
                 print('{}', end='')
             else:
                 print('{')
-                print_hash(hash, indent=indent+1)
+                print_hash(hash, indent+1, context)
                 print(' '*indent*2, end='')
                 print('}', end='')
 
         case ['heredoc', {'text': text}]:
             print('@("EOF")')
-            parse(text, indent=indent)
+            parse(text, indent, context)
             print(' '*indent*2 + '|')
 
         case ['if', {'test': test,
                      **rest}]:
             print('if ', end='')
-            parse(test, indent=indent)
+            parse(test, indent, context)
             print(' {')
             if 'then' in rest:
                 for item in rest['then']:
                     print(' '*(indent+1)*2, end='')
-                    parse(item, indent=indent+1)
+                    parse(item, indent+1, context)
                     print()
             print(' '*indent*2 + '} ', end='')
 
@@ -374,27 +385,27 @@ def parse(form, indent=0):
                 match rest['else']:
                     case [['if', *rest]]:
                         print('els', end='')
-                        parse(['if', *rest], indent=indent)
+                        parse(['if', *rest], indent, context)
                     case el:
                         print('else {')
                         for item in el:
                             print(' '*(indent+1)*2, end='')
-                            parse(item, indent=indent+1)
+                            parse(item, indent+1, context)
                             print()
                         print(' '*indent*2+'}', end='')
 
         case ['in', needle, stack]:
-            parse(needle, indent=indent)
+            parse(needle, indent, context)
             print(' in ', end='')
-            parse(stack, indent=indent)
+            parse(stack, indent, context)
 
         case ['invoke', {'functor': func,
                          'args': args}]:
             print('<span class="invoke">', end='')
-            parse(func)
+            parse(func, indent, context)
             print(' ', end='')
             if len(args) == 1:
-                parse(args[0], indent=indent+1)
+                parse(args[0], indent+1, context)
             else:
                 print(args)
                 print('(', end='')
@@ -403,7 +414,7 @@ def parse(form, indent=0):
                     if not first:
                         print(', ', end='')
                     first = False
-                    parse(arg, indent=indent+1)
+                    parse(arg, indent+1, context)
                 # print(' '*indent*2, end='')
                 print(')', end='')
             # print()
@@ -420,24 +431,24 @@ def parse(form, indent=0):
             print(f' |{args}| {{')
             for entry in body:
                 print(' '*indent*2, end='')
-                parse(entry, indent=indent)
+                parse(entry, indent, context)
                 print()
             print(' '*(indent-1)*2 + '}', end='')
             print('</span>', end='')
         case ['and', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' and ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['or', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' or ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['paren', *forms]:
             print('(', end='')
             for form in forms:
-                parse(form, indent=indent+1)
+                parse(form, indent+1, context)
             print(')', end='')
 
         # Qualified name?
@@ -454,9 +465,9 @@ def parse(form, indent=0):
 
         case ['resource', {'type': t,
                            'bodies': [body]}]:
-            parse(t, indent=indent)
+            parse(t, indent, context)
             print(' { ', end='')
-            parse(body['title'])
+            parse(body['title'], indent, context)
             print(':')
             ops = body['ops']
 
@@ -470,7 +481,7 @@ def parse(form, indent=0):
                         print(f'<span class="parameter">{key}</span>', end='')
                         # print(' '*pad + ' => ', end='')
                         print(' '*pad + ' ⇒ ', end='')
-                        parse(value, indent=indent+1)
+                        parse(value, indent+1, context)
                         print(',')
 
                     case ['splat-hash', value]:
@@ -478,7 +489,7 @@ def parse(form, indent=0):
                         print('<span class="parameter splat">*</span>', end='')
                         # print(' '*(namelen - 1) + ' => ', end='')
                         print(' '*(namelen - 1) + ' ⇒ ', end='')
-                        parse(value, indent=indent+1)
+                        parse(value, indent+1, context)
                         print(',')
 
                     case _:
@@ -487,11 +498,11 @@ def parse(form, indent=0):
 
         case ['resource', {'type': t,
                            'bodies': bodies}]:
-            parse(t, indent=indent)
+            parse(t, indent, context)
             print(' {')
             for body in bodies:
                 print(' '*(indent+1)*2, end='')
-                parse(body['title'])
+                parse(body['title'], indent, context)
                 print(':')
                 ops = body['ops']
 
@@ -505,7 +516,7 @@ def parse(form, indent=0):
                             print(f'<span class="parameter">{key}</span>', end='')
                             # print(' '*pad + ' => ', end='')
                             print(' '*pad + ' ⇒ ', end='')
-                            parse(value, indent=indent+2)
+                            parse(value, indent+2, context)
                             print(',')
 
                         case ['splat-hash', value]:
@@ -513,7 +524,7 @@ def parse(form, indent=0):
                             print('<span class="parameter splat">*</span>', end='')
                             # print(' '*(namelen - 1) + ' => ', end='')
                             print(' '*(namelen - 1) + ' ⇒ ', end='')
-                            parse(value, indent=indent+2)
+                            parse(value, indent+2, context)
                             print(',')
 
                         case _:
@@ -524,7 +535,7 @@ def parse(form, indent=0):
 
         case ['resource-defaults', {'type': t,
                                     'ops': ops}]:
-            parse(t, indent=indent)
+            parse(t, indent, context)
             print(' {')
             namelen = ops_namelen(ops)
             for op in ops:
@@ -535,7 +546,7 @@ def parse(form, indent=0):
                         print(f'<span class="parameter">{key}</span>', end='')
                         # print(' '*pad + ' => ', end='')
                         print(' '*pad + ' ⇒ ', end='')
-                        parse(value, indent=indent+3)
+                        parse(value, indent+3, context)
                         print(',')
 
                     case ['splat-hash', value]:
@@ -543,7 +554,7 @@ def parse(form, indent=0):
                         pad = namelen - 1
                         print('<span class="parameter splat">*</span>', end=' '*pad)
                         print(' '*(namelen - 1) + ' ⇒ ', end='')
-                        parse(value, indent=indent+2)
+                        parse(value, indent+2, context)
                         print(',')
 
                     case x:
@@ -552,7 +563,7 @@ def parse(form, indent=0):
 
         case ['resource-override', {'resources': resources,
                                     'ops': ops}]:
-            parse(resources, indent=indent)
+            parse(resources, indent, context)
             print(' {')
             namelen = ops_namelen(ops)
             for op in ops:
@@ -563,7 +574,7 @@ def parse(form, indent=0):
                         print(f'<span class="parameter">{key}</span>', end='')
                         # print(' '*pad + ' => ', end='')
                         print(' '*pad + ' ⇒ ', end='')
-                        parse(value, indent=indent+3)
+                        parse(value, indent+3, context)
                         print(',')
 
                     case ['+>', key, value]:
@@ -572,7 +583,7 @@ def parse(form, indent=0):
                         print(f'<span class="parameter">{key}</span>', end='')
                         # print(' '*pad + ' => ', end='')
                         print(' '*pad + ' +> ', end='')
-                        parse(value, indent=indent+2)
+                        parse(value, indent+2, context)
                         print(',')
 
                     case ['splat-hash', value]:
@@ -580,7 +591,7 @@ def parse(form, indent=0):
                         pad = namelen - 1
                         print('<span class="parameter splat">*</span>', end=' '*pad)
                         print(' '*(namelen - 1) + ' ⇒ ', end='')
-                        parse(value, indent=indent+2)
+                        parse(value, indent+2, context)
                         print(',')
 
                     case _:
@@ -591,11 +602,11 @@ def parse(form, indent=0):
         case ['unless', {'test': test,
                          'then': then}]:
             print('unless ', end='')
-            parse(test, indent=indent)
+            parse(test, indent, context)
             print(' {')
             for item in then:
                 print(' '*(indent+1)*2, end='')
-                parse(item, indent=indent+1)
+                parse(item, indent+1, context)
                 print()
 
             print(' '*indent*2 + '}', end='')
@@ -605,7 +616,7 @@ def parse(form, indent=0):
 
         case ['virtual-query', q]:
             print('<| ', end='')
-            parse(q)
+            parse(q, indent, context)
             print(' |>', end='')
 
         case ['virtual-query']:
@@ -616,118 +627,118 @@ def parse(form, indent=0):
         case ['!', x]:
             # print('! ', end='')
             print('¬ ', end='')
-            parse(x, indent=indent)
+            parse(x, indent, context)
 
         case ['!=', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             # print(' != ', end='')
             print(' ≠ ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['+', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' + ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['-', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' - ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['-', a]:
             print('- ', end='')
             parse(a)
 
         case ['*', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' × ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['%', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' % ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['<<', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' << ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['>>', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' >> ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['>=', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' ≥ ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['<=', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' ≤ ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['>', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' > ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['<', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' < ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['~>', left, right]:
-            parse(left, indent=indent)
+            parse(left, indent, context)
             print(f'\n{" "*indent*2}⤳ ', end='')
             # print(f'\n{" "*indent*2}~&gt; ', end='')
-            parse(right, indent=indent)
+            parse(right, indent, context)
 
         case ['->', left, right]:
-            parse(left, indent=indent)
+            parse(left, indent, context)
             # print(f'\n{" "*indent*2}-&gt; ', end='')
             print(f'\n{" "*indent*2}→ ', end='')
-            parse(right, indent=indent)
+            parse(right, indent, context)
 
         case ['.', left, right]:
-            parse(left, indent=indent)
+            parse(left, indent, context)
             print()
             print(' '*indent*2, end='.')
-            parse(right, indent=indent+1)
+            parse(right, indent+1, context)
 
         case ['/', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' / ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['=', field, value]:
             # print('  ', end='')
-            parse(field, indent=indent)
+            parse(field, indent, context)
             print(' = ', end='')
-            parse(value, indent=indent)
+            parse(value, indent, context)
 
         case ['==', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             # print(' == ', end='')
             print(' ≡ ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['=~', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' =~ ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['!~', a, b]:
-            parse(a, indent=indent)
+            parse(a, indent, context)
             print(' ≁ ', end='')
-            parse(b, indent=indent)
+            parse(b, indent, context)
 
         case ['?', condition, cases]:
             print('<span class="case">', end='')
-            parse(condition, indent=indent)
+            parse(condition, indent, context)
             print(' ? {')
-            print_hash(cases, indent=indent+1)
+            print_hash(cases, indent+1, context)
             print(' '*indent*2 + '}', end='')
             print('</span>', end='')
 
@@ -741,7 +752,7 @@ def parse(form, indent=0):
                 print(f'<span class="parse-error">[|[{form}]|]</span>', end='')
 
 
-def print_docstring(docstring):
+def print_docstring(docstring: dict[str, Any]) -> None:
     """
     Format docstrings as they appear in some puppet types.
 
@@ -805,7 +816,7 @@ for d_type in data['puppet_classes']:
     print('<pre><code class="puppet">')
     tree = parse_puppet(d_type['source'])
     t = traverse(tree)
-    parse(t)
+    parse(t, 0, [])
     print('</code></pre>')
 
     print('<hr/>')
@@ -823,7 +834,7 @@ for d_type in data['data_type_aliases']:
     print('<pre><code class="puppet">')
     tree = parse_puppet(d_type['alias_of'])
     t = traverse(tree)
-    parse(t)
+    parse(t, 0, [])
     print('</code></pre>')
 
     print('<hr/>')
@@ -838,7 +849,7 @@ for d_type in data['defined_types']:
     print('<pre><code class="puppet">')
     tree = parse_puppet(d_type['source'])
     t = traverse(tree)
-    parse(t)
+    parse(t, 0, [])
     print('</code></pre>')
 
     print('<hr/>')
@@ -887,7 +898,7 @@ for function in data['puppet_functions']:
         try:
             tree = parse_puppet(function['source'])
             t = traverse(tree)
-            parse(t)
+            parse(t, 0, [])
         except CalledProcessError as e:
             print(e)
         print('</code></pre>')
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000..a47647c
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,10 @@
+[mypy]
+# Disabled since `match` breaks it.
+disable_error_code = used-before-def
+
+disallow_untyped_calls = True
+disallow_untyped_defs = True
+disallow_incomplete_defs = True
+check_untyped_defs = True
+warn_return_any = True
+# warn_unrearchable = True
diff --git a/reflow.py b/reflow.py
index 7d01ce0..2d39f23 100755
--- a/reflow.py
+++ b/reflow.py
@@ -28,7 +28,7 @@ def tagged_list_to_dict(lst: list[Any]) -> dict[Any, Any]:
             for i in range(0, len(lst), 2)}
 
 
-def traverse(tree):
+def traverse(tree: Any) -> Any:
     """
     Reflow a puppet parse output tree.
 
@@ -61,7 +61,7 @@ def traverse(tree):
         return tree
 
 
-def __main():
+def __main() -> None:
     import json
     import sys
 
-- 
GitLab