From 210086c7caf95b064f0190c79bc636947e12282f Mon Sep 17 00:00:00 2001 From: Kodi Arfer Date: Sun, 22 Apr 2018 12:48:37 -0700 Subject: [PATCH] Clean up the decorators used in the compiler --- hy/compiler.py | 189 +++++++++++-------------------- hy/completer.py | 2 +- hy/contrib/walk.hy | 4 +- hy/extra/reserved.hy | 4 +- tests/compilers/test_compiler.py | 2 +- 5 files changed, 73 insertions(+), 128 deletions(-) diff --git a/hy/compiler.py b/hy/compiler.py index 389ec91..70127df 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -19,7 +19,6 @@ from hy._compat import ( raise_empty) from hy.macros import require, macroexpand, tag_macroexpand import hy.importer -import hy.inspect import traceback import importlib @@ -70,7 +69,8 @@ def ast_str(x, piecewise=False): return x if PY3 else x.encode('UTF8') -_compile_table = {} +_special_form_compilers = {} +_model_compilers = {} _decoratables = (ast.FunctionDef, ast.ClassDef) if PY35: _decoratables += (ast.AsyncFunctionDef,) @@ -81,23 +81,9 @@ _bad_roots = tuple(ast_str(x) for x in ( "unquote", "unquote-splice", "unpack-mapping", "except")) -def builds(*types, **kwargs): - # A decorator that adds the decorated method to _compile_table for - # compiling `types`, but only if kwargs['iff'] (if provided) is - # true. - if not kwargs.get('iff', True): - return lambda fn: fn - - def _dec(fn): - for t in types: - if isinstance(t, string_types): - t = ast_str(t) - _compile_table[t] = fn - return fn - return _dec - - def special(names, pattern): + """Declare special operators. The decorated method and the given pattern + is assigned to _special_form_compilers for each of the listed names.""" pattern = whole(pattern) def dec(fn): for name in names if isinstance(names, list) else [names]: @@ -105,11 +91,20 @@ def special(names, pattern): condition, name = name if not condition: continue - _compile_table[ast_str(name)] = (fn, pattern) + _special_form_compilers[ast_str(name)] = (fn, pattern) return fn return dec +def builds_model(*model_types): + "Assign the decorated method to _model_compilers for the given types." + def _dec(fn): + for t in model_types: + _model_compilers[t] = fn + return fn + return _dec + + def spoof_positions(obj): if not isinstance(obj, HyObject): return @@ -338,44 +333,6 @@ def _nargs(n): return "%d argument%s" % (n, ("" if n == 1 else "s")) -def checkargs(exact=None, min=None, max=None, even=None, multiple=None): - def _dec(fn): - def checker(self, expression): - if exact is not None and (len(expression) - 1) != exact: - _raise_wrong_args_number( - expression, "`%%s' needs %s, got %%d" % _nargs(exact)) - if min is not None and (len(expression) - 1) < min: - _raise_wrong_args_number( - expression, - "`%%s' needs at least %s, got %%d." % _nargs(min)) - - if max is not None and (len(expression) - 1) > max: - _raise_wrong_args_number( - expression, - "`%%s' needs at most %s, got %%d" % _nargs(max)) - - is_even = not((len(expression) - 1) % 2) - if even is not None and is_even != even: - even_str = "even" if even else "odd" - _raise_wrong_args_number( - expression, - "`%%s' needs an %s number of arguments, got %%d" - % (even_str)) - - if multiple is not None: - if not (len(expression) - 1) in multiple: - choices = ", ".join([str(val) for val in multiple[:-1]]) - choices += " or %s" % multiple[-1] - _raise_wrong_args_number( - expression, - "`%%s' needs %s arguments, got %%d" % choices) - - return fn(self, expression) - - return checker - return _dec - - def is_unpack(kind, x): return (isinstance(x, HyExpression) and len(x) > 0 @@ -436,47 +393,21 @@ class HyASTCompiler(object): self.imports = defaultdict(set) return ret.stmts - def compile_atom(self, atom_type, atom): - if isinstance(atom_type, string_types): - atom_type = ast_str(atom_type) - if atom_type in _bad_roots: - raise HyTypeError(atom, "The special form '{}' " - "is not allowed here".format(atom_type)) - if atom_type in _compile_table: - # _compile_table[atom_type] is a method for compiling this - # type of atom, so call it. If it has an extra parameter, - # pass in `atom_type`. - atom_compiler = _compile_table[atom_type] - if isinstance(atom_compiler, tuple): - # This build method has a pattern. - build_method, pattern = atom_compiler - try: - parse_tree = pattern.parse(atom[1:]) - except NoParseError as e: - raise HyTypeError(atom, - "parse error for special form '{}': {}'".format( - atom[0], str(e))) - ret = build_method(self, atom, atom[0], *parse_tree) - else: - arity = hy.inspect.get_arity(atom_compiler) - # Compliation methods may mutate the atom, so copy it first. - atom = copy.copy(atom) - ret = (atom_compiler(self, atom, atom_type) - if arity == 3 - else atom_compiler(self, atom)) - if not isinstance(ret, Result): - ret = Result() + ret - return ret + def compile_atom(self, atom): if not isinstance(atom, HyObject): atom = wrap_value(atom) - if isinstance(atom, HyObject): - spoof_positions(atom) - return self.compile_atom(type(atom), atom) + if not isinstance(atom, HyObject): + return + spoof_positions(atom) + if type(atom) not in _model_compilers: + return + # Compilation methods may mutate the atom, so copy it first. + atom = copy.copy(atom) + return Result() + _model_compilers[type(atom)](self, atom) def compile(self, tree): try: - _type = type(tree) - ret = self.compile_atom(_type, tree) + ret = self.compile_atom(tree) if ret: self.update_imports(ret) return ret @@ -1633,13 +1564,7 @@ class HyASTCompiler(object): if ast_str(root) == "eval_and_compile" else Result()) - @checkargs(1) - def _compile_keyword_call(self, expression): - expression.append(expression.pop(0)) - expression.insert(0, HySymbol("get")) - return self.compile(expression) - - @builds(HyExpression) + @builds_model(HyExpression) def compile_expression(self, expression): # Perform macro expansions expression = macroexpand(expression, self) @@ -1648,23 +1573,45 @@ class HyASTCompiler(object): return self.compile(expression) if expression == []: - return self.compile_list(expression, HyList) + return self.compile_atom(HyList().replace(expression)) fn = expression[0] func = None if isinstance(fn, HyKeyword): - return self._compile_keyword_call(expression) + if len(expression) > 2: + raise HyTypeError( + expression, "keyword calls take only 1 argument") + expression.append(expression.pop(0)) + expression.insert(0, HySymbol("get")) + return self.compile(expression) if isinstance(fn, HySymbol): - # First check if `fn` is a special form, unless it has an + + # First check if `fn` is a special operator, unless it has an # `unpack-iterable` in it, since Python's operators (`+`, # etc.) can't unpack. An exception to this exception is that # tuple literals (`,`) can unpack. - if fn == "," or not ( - any(is_unpack("iterable", x) for x in expression[1:])): - ret = self.compile_atom(fn, expression) - if ret: - return ret + sfn = ast_str(fn) + if (sfn in _special_form_compilers or sfn in _bad_roots) and ( + sfn == mangle(",") or + not any(is_unpack("iterable", x) for x in expression[1:])): + if sfn in _bad_roots: + raise HyTypeError( + expression, + "The special form '{}' is not allowed here".format(fn)) + # `sfn` is a special operator. Get the build method and + # pattern-match the arguments. + build_method, pattern = _special_form_compilers[sfn] + try: + parse_tree = pattern.parse(expression[1:]) + except NoParseError as e: + raise HyTypeError( + expression[min(e.state.pos + 1, len(expression) - 1)], + "parse error for special form '{}': {}".format( + expression[0], + e.msg.replace("", "end of form"))) + return Result() + build_method( + self, expression, expression[0], *parse_tree) if fn.startswith("."): # (.split "test test") -> "test test".split() @@ -1718,11 +1665,11 @@ class HyASTCompiler(object): expression, func=func.expr, args=args, keywords=keywords, starargs=oldpy_star, kwargs=oldpy_kw) - @builds(HyInteger, HyFloat, HyComplex) - def compile_numeric_literal(self, x, building): + @builds_model(HyInteger, HyFloat, HyComplex) + def compile_numeric_literal(self, x): f = {HyInteger: long_type, HyFloat: float, - HyComplex: complex}[building] + HyComplex: complex}[type(x)] # Work around https://github.com/berkerpeksag/astor/issues/85 : # astor can't generate Num nodes with NaN, so we have # to build an expression that evaluates to NaN. @@ -1741,7 +1688,7 @@ class HyASTCompiler(object): return nan() return nn(f(x)) - @builds(HySymbol) + @builds_model(HySymbol) def compile_symbol(self, symbol): if "." in symbol: glob, local = symbol.rsplit(".", 1) @@ -1772,7 +1719,7 @@ class HyASTCompiler(object): return asty.Name(symbol, id=ast_str(symbol), ctx=ast.Load()) - @builds(HyKeyword) + @builds_model(HyKeyword) def compile_keyword(self, obj): ret = Result() ret += asty.Call( @@ -1783,19 +1730,19 @@ class HyASTCompiler(object): ret.add_imports("hy", {"HyKeyword"}) return ret - @builds(HyString, HyBytes) - def compile_string(self, string, building): - node = asty.Bytes if PY3 and building is HyBytes else asty.Str - f = bytes_type if building is HyBytes else str_type + @builds_model(HyString, HyBytes) + def compile_string(self, string): + node = asty.Bytes if PY3 and type(string) is HyBytes else asty.Str + f = bytes_type if type(string) is HyBytes else str_type return node(string, s=f(string)) - @builds(HyList, HySet) - def compile_list(self, expression, building): + @builds_model(HyList, HySet) + def compile_list(self, expression): elts, ret, _ = self._compile_collect(expression) - node = {HyList: asty.List, HySet: asty.Set}[building] + node = {HyList: asty.List, HySet: asty.Set}[type(expression)] return ret + node(expression, elts=elts, ctx=ast.Load()) - @builds(HyDict) + @builds_model(HyDict) def compile_dict(self, m): keyvalues, ret, _ = self._compile_collect(m, dict_display=True) return ret + asty.Dict(m, keys=keyvalues[::2], values=keyvalues[1::2]) diff --git a/hy/completer.py b/hy/completer.py index 9cfc668..7748c3d 100644 --- a/hy/completer.py +++ b/hy/completer.py @@ -37,7 +37,7 @@ class Completer(object): if not isinstance(namespace, dict): raise TypeError('namespace must be a dictionary') self.namespace = namespace - self.path = [hy.compiler._compile_table, + self.path = [hy.compiler._special_form_compilers, builtins.__dict__, hy.macros._hy_macros[None], namespace] diff --git a/hy/contrib/walk.hy b/hy/contrib/walk.hy index b50ace8..8e9b3c3 100644 --- a/hy/contrib/walk.hy +++ b/hy/contrib/walk.hy @@ -72,9 +72,7 @@ ;; TODO: move to hy.extra.reserved? (import hy) -(setv special-forms (list-comp k - [k (.keys hy.compiler._compile-table)] - (isinstance k hy._compat.string-types))) +(setv special-forms (list (.keys hy.compiler._special-form-compilers))) (defn lambda-list [form] diff --git a/hy/extra/reserved.hy b/hy/extra/reserved.hy index d7ae23c..90ffee5 100644 --- a/hy/extra/reserved.hy +++ b/hy/extra/reserved.hy @@ -18,6 +18,6 @@ hy.core.shadow.EXPORTS (list (.keys (get hy.macros._hy_macros None))) keyword.kwlist - (list-comp k [k (.keys hy.compiler.-compile-table)] - (isinstance k hy._compat.string-types)))))))) + (list (.keys hy.compiler._special_form_compilers)) + (list hy.compiler._bad_roots))))))) _cache) diff --git a/tests/compilers/test_compiler.py b/tests/compilers/test_compiler.py index 64082a7..cbce371 100644 --- a/tests/compilers/test_compiler.py +++ b/tests/compilers/test_compiler.py @@ -55,7 +55,7 @@ def test_compiler_yield_return(): HyExpression([HySymbol("+"), HyInteger(1), HyInteger(1)])) - ret = compiler.HyASTCompiler('test').compile_atom("fn", e) + ret = compiler.HyASTCompiler('test').compile_atom(e) assert len(ret.stmts) == 1 stmt, = ret.stmts