From 3d8a3d589c9e25d6c4dbd66b579eed2230b56015 Mon Sep 17 00:00:00 2001 From: Nicolas Dandrimont Date: Sat, 4 May 2013 09:16:01 +0200 Subject: [PATCH] Refactor compiler using a result carrying object This object allows to coerce statements to an expression, if we need to use them that way, which, with a lisp, is often. This was collaborative work that has been rebased to make it bisectable. Helped-by: Paul Tagliamonte Helped-by: Julien Danjou --- hy/compiler.py | 1161 ++++++++++++++++++++++------------- tests/compilers/test_ast.py | 3 +- 2 files changed, 751 insertions(+), 413 deletions(-) diff --git a/hy/compiler.py b/hy/compiler.py index 47e3eb8..c45a972 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -2,6 +2,7 @@ # # Copyright (c) 2013 Paul Tagliamonte # Copyright (c) 2013 Julien Danjou +# Copyright (c) 2013 Nicolas Dandrimont # # Permission is hereby granted, free of charge, to any person obtaining a # copy of this software and associated documentation files (the "Software"), @@ -34,14 +35,15 @@ from hy.models.float import HyFloat from hy.models.list import HyList from hy.models.dict import HyDict -from hy.util import flatten_literal_list, str_type, temporary_attribute_value +from hy.util import str_type -from collections import defaultdict -import traceback import codecs +import traceback import ast import sys +from collections import defaultdict + class HyCompileError(HyError): def __init__(self, exception, traceback=None): @@ -70,6 +72,7 @@ class HyTypeError(TypeError): % (self.expression.start_line, self.expression.start_column)) + _compile_table = {} @@ -94,6 +97,202 @@ def builds(_type): return _dec +class Result(object): + """ + Smart representation of the result of a hy->AST compilation + + This object tries to reconcile the hy world, where everything can be used + as an expression, with the Python world, where statements and expressions + need to coexist. 
+ + To do so, we represent a compiler result as a list of statements `stmts`, + terminated by an expression context `expr`. The expression context is used + when the compiler needs to use the result as an expression. + + Results are chained by addition: adding two results together returns a + Result representing the succession of the two Results' statements, with + the second Result's expression context. + + We make sure that a non-empty expression context does not get clobbered by + adding more results, by checking accesses to the expression context. We + assume that the context has been used, or deliberately ignored, if it has + been accessed. + + The Result object is interoperable with python AST objects: when an AST + object gets added to a Result object, it gets converted on-the-fly. + """ + __slots__ = ("imports", "stmts", "temp_variables", "_expr", "__used_expr") + + def __init__(self, *args, **kwargs): + if args: + # emulate kw-only args for future bits. + raise TypeError("Yo: Hacker: don't pass me real args, dingus") + + self.imports = defaultdict(set) + self.stmts = [] + self.temp_variables = [] + self._expr = None + + self.__used_expr = False + + # XXX: Make sure we only have AST where we should. + for kwarg in kwargs: + if kwarg not in ["imports", "stmts", "expr", "temp_variables"]: + raise TypeError( + "%s() got an unexpected keyword argument '%s'" % ( + self.__class__.__name__, kwarg)) + setattr(self, kwarg, kwargs[kwarg]) + + @property + def expr(self): + self.__used_expr = True + return self._expr + + @expr.setter + def expr(self, value): + self.__used_expr = False + self._expr = value + + def add_imports(self, mod, imports): + """Autoimport `imports` from `mod`""" + self.imports[mod].update(imports) + + def is_expr(self): + """Check whether I am a pure expression""" + return self._expr and not (self.imports or self.stmts) + + @property + def force_expr(self): + """Force the expression context of the Result. 
+ + If there is no expression context, we return a "None" expression. + """ + if not self.expr: + # Spoof the position of the last statement for our generated None + lineno = 0 + col_offset = 0 + if self.stmts: + lineno = self.stmts[-1].lineno + col_offset = self.stmts[-1].col_offset + + return ast.Name(id=ast_str("None"), + arg=ast_str("None"), + ctx=ast.Load(), + lineno=lineno, + col_offset=col_offset) + # XXX: Likely raise Exception here - this will assertionfail + # pypy since the ast will be out of numerical order. + else: + return self.expr + + def expr_as_stmt(self): + """Convert the Result's expression context to a statement + + This is useful when we want to use the stored expression in a + statement context (for instance in a code branch). + + We drop bare names because they can't have any side effect, and they + make the generated code ugly. + + If there is no expression context, return an empty result. + """ + if self.expr and not isinstance(self.expr, ast.Name): + return Result() + ast.Expr(lineno=self.expr.lineno, + col_offset=self.expr.col_offset, + value=self.expr) + else: + return Result() + + def rename(self, new_name): + """Rename the Result's temporary variables to a `new_name`. + + We know how to handle ast.Names and ast.FunctionDefs. + """ + new_name = ast_str(new_name) + for var in self.temp_variables: + if isinstance(var, ast.Name): + var.id = new_name + var.arg = new_name + elif isinstance(var, ast.FunctionDef): + var.name = new_name + else: + raise TypeError("Don't know how to rename a %s!" 
% ( + var.__class__.__name__)) + self.temp_variables = [] + + def __add__(self, other): + # If we add an ast statement, convert it first + if isinstance(other, ast.stmt): + return self + Result(stmts=[other]) + + # If we add an ast expression, clobber the expression context + if isinstance(other, ast.expr): + return self + Result(expr=other) + + if isinstance(other, ast.excepthandler): + return self + Result(stmts=[other]) + + if not isinstance(other, Result): + raise TypeError("Can't add %r with non-compiler result %r" % ( + self, other)) + + # Check for expression context clobbering + if self.expr and not self.__used_expr: + traceback.print_stack() + print("Bad boy clobbered expr %s with %s" % ( + ast.dump(self.expr), + ast.dump(other.expr))) + + # Fairly obvious addition + result = Result() + result.imports = other.imports + result.stmts = self.stmts + other.stmts + result.expr = other.expr + result.temp_variables = other.temp_variables + return result + + def __str__(self): + return "Result(imports=[%s], stmts=[%s], expr=%s)" % ( + ", ".join(ast.dump(x) for x in self.imports), + ", ".join(ast.dump(x) for x in self.stmts), + ast.dump(self.expr) if self.expr else None, + ) + + +def _collect(results): + """Collect the expression contexts from a list of results + + This returns a list of the expression contexts, and the sum of the Result + objects passed as arguments. + """ + compiled_exprs = [] + ret = Result() + for result in results: + ret += result + compiled_exprs.append(ret.force_expr) + return compiled_exprs, ret + + +def _branch(results): + """Make a branch out of a list of Result objects + + This generates a Result from the given sequence of Results, forcing each + expression context as a statement before the next result is used. 
+ + We keep the expression context of the last argument for the returned Result + """ + results = list(results) + ret = Result() + for result in results[:-1]: + ret += result + ret += result.expr_as_stmt() + + for result in results[-1:]: + ret += result + + return ret + + def _raise_wrong_args_number(expression, error): raise HyTypeError(expression, error % (expression.pop(0), @@ -128,16 +327,54 @@ class HyASTCompiler(object): def __init__(self): self.returnable = False self.anon_fn_count = 0 - self.imports = defaultdict(list) + self.anon_var_count = 0 + self.imports = defaultdict(set) - def is_returnable(self, v): - return temporary_attribute_value(self, "returnable", v) + def get_anon_var(self): + self.anon_var_count += 1 + return "_hy_anon_var_%s" % self.anon_var_count + + def get_anon_fn(self): + self.anon_fn_count += 1 + return "_hy_anon_fn_%d" % self.anon_fn_count + + def update_imports(self, result): + """Retrieve the imports from the result object""" + for mod in result.imports: + self.imports[mod].update(result.imports[mod]) + + def imports_as_stmts(self, expr): + """Convert the Result's imports to statements""" + ret = Result() + for module, names in self.imports.items(): + ret += self.compile([ + HyExpression([ + HySymbol("import"), + HyList([ + HySymbol(module), + HyList([HySymbol(name) for name in sorted(names)]) + ]) + ]).replace(expr) + ]) + self.imports = defaultdict(set) + return ret.stmts + + def compile_atom(self, atom_type, atom): + if atom_type in _compile_table: + ret = _compile_table[atom_type](self, atom) + if not isinstance(ret, Result): + ret = Result() + ret + return ret + else: + return None def compile(self, tree): try: _type = type(tree) - if _type in _compile_table: - return _compile_table[_type](self, tree) + ret = self.compile_atom(_type, tree) + if ret: + self.update_imports(ret) + return ret except HyCompileError: # compile calls compile, so we're going to have multiple raise # nested; so let's re-raise this exception, let's not 
wrap it in @@ -146,50 +383,17 @@ class HyASTCompiler(object): except Exception as e: raise HyCompileError(e, sys.exc_info()[2]) - raise HyCompileError( - Exception("Unknown type: `%s'" % (str(type(tree))))) + raise HyCompileError(Exception("Unknown type: `%s'" % _type)) - def _mangle_branch(self, tree, start_line, start_column): - tree = list(flatten_literal_list(tree)) - tree = list(filter(bool, tree)) # Remove empty statements + def _compile_collect(self, exprs): + return _collect(self.compile(expr) for expr in exprs) - # If tree is empty, just return a pass statement - if tree == []: - return [ast.Pass(lineno=start_line, col_offset=start_column)] - - tree.reverse() - - ret = [] - - if self.returnable: - el = tree[0] - if not isinstance(el, ast.stmt): - el = tree.pop(0) - ret.append(ast.Return(value=el, - lineno=el.lineno, - col_offset=el.col_offset)) - if isinstance(el, ast.FunctionDef): - ret.append(ast.Return( - value=ast.Name( - arg=el.name, id=el.name, ctx=ast.Load(), - lineno=el.lineno, col_offset=el.col_offset), - lineno=el.lineno, col_offset=el.col_offset)) - - for el in tree: - if isinstance(el, ast.stmt): - ret.append(el) - continue - - ret.append(ast.Expr( - value=el, - lineno=el.lineno, - col_offset=el.col_offset)) - - ret.reverse() - return ret + def _compile_branch(self, exprs): + return _branch(self.compile(expr) for expr in exprs) def _parse_lambda_list(self, exprs): """ Return FunctionDef parameter values from lambda list.""" + ret = Result() args = [] defaults = [] varargs = None @@ -236,7 +440,8 @@ class HyASTCompiler(object): # defining keyword arguments. for k, v in expr.items(): args.append(k) - defaults.append(self.compile(v)) + ret += self.compile(v) + defaults.append(ret.force_expr) elif lambda_keyword == "&optional": # not implemented yet. 
pass @@ -246,58 +451,106 @@ class HyASTCompiler(object): "&kwargs argument") kwargs = str(expr) - return args, defaults, varargs, kwargs + return ret, args, defaults, varargs, kwargs + + def _storeize(self, name): + """Return a new `name` object with an ast.Store() context""" + if isinstance(name, Result): + if not name.is_expr(): + raise TypeError("Can't assign to a non-expr") + name = name.expr + + if isinstance(name, (ast.Tuple, ast.List)): + typ = type(name) + new_elts = [] + for x in name.elts: + new_elts.append(self._storeize(x)) + new_name = typ(elts=new_elts) + elif isinstance(name, ast.Name): + new_name = ast.Name(id=name.id, arg=name.arg) + elif isinstance(name, ast.Subscript): + new_name = ast.Subscript(value=name.value, slice=name.slice) + elif isinstance(name, ast.Attribute): + new_name = ast.Attribute(value=name.value, attr=name.attr) + else: + raise TypeError("Can't assign to a %s object" % type(name)) + + new_name.ctx = ast.Store() + ast.copy_location(new_name, name) + return new_name @builds(list) def compile_raw_list(self, entries): - return [self.compile(x) for x in entries] + ret = self._compile_branch(entries) + ret += ret.expr_as_stmt() + return ret def _render_quoted_form(self, form): name = form.__class__.__name__ - self.imports["hy"].append((name, form)) + imports = [name] if isinstance(form, HyList): - return HyExpression( + contents = [] + for x in form: + form_imports, form_contents = self._render_quoted_form(x) + imports += form_imports + contents.append(form_contents) + return imports, HyExpression( [HySymbol(name), - HyList([self._render_quoted_form(x) for x in form])] + HyList(contents)] ).replace(form) elif isinstance(form, HySymbol): - return HyExpression([HySymbol(name), HyString(form)]).replace(form) - return HyExpression([HySymbol(name), form]).replace(form) + return imports, HyExpression([HySymbol(name), + HyString(form)]).replace(form) + return imports, HyExpression([HySymbol(name), form]).replace(form) @builds("quote") 
@checkargs(exact=1) def compile_quote(self, entries): - return self.compile(self._render_quoted_form(entries[1])) + imports, stmts = self._render_quoted_form(entries[1]) + ret = self.compile(stmts) + ret.add_imports("hy", imports) + return ret @builds("eval") @checkargs(exact=1) def compile_eval(self, expr): expr.pop(0) - self.imports["hy.importer"].append(("hy_eval", expr)) - return self.compile(HyExpression([ + ret = self.compile(HyExpression([ HySymbol("hy_eval")] + expr + [ HyExpression([HySymbol("locals")])]).replace(expr)) + ret.add_imports("hy.importer", ["hy_eval"]) + + return ret + @builds("do") @builds("progn") - def compile_do_expression(self, expr): - return [self.compile(x) for x in expr[1:]] + def compile_progn(self, expression): + expression.pop(0) + return self._compile_branch(expression) @builds("throw") @builds("raise") @checkargs(max=1) def compile_throw_expression(self, expr): expr.pop(0) - exc = self.compile(expr.pop(0)) if expr else None - return ast.Raise( + ret = Result() + if expr: + ret += self.compile(expr.pop(0)) + + # Use ret.expr to get a literal `None` + ret += ast.Raise( lineno=expr.start_line, col_offset=expr.start_column, - type=exc, - exc=exc, + type=ret.expr, + exc=ret.expr, inst=None, - tback=None) + tback=None, + cause=None) + + return ret @builds("try") def compile_try_expression(self, expr): @@ -309,38 +562,47 @@ class HyASTCompiler(object): body = [] # (try something…) - body = self._code_branch(self.compile(body), - expr.start_line, - expr.start_column) + body = self.compile(body) + + # XXX we will likely want to make this a tempvar + body += body.expr_as_stmt() + body = body.stmts + if not body: + body = [ast.Pass(lineno=expr.start_line, + col_offset=expr.start_column)] orelse = [] finalbody = [] handlers = [] + handler_results = Result() for e in expr: if not len(e): raise HyTypeError(e, "Empty list not allowed in `try'") if e[0] in (HySymbol("except"), HySymbol("catch")): - handlers.append(self.compile(e)) + 
handler_results += self.compile(e) + handlers.append(handler_results.stmts.pop()) elif e[0] == HySymbol("else"): if orelse: raise HyTypeError( e, "`try' cannot have more than one `else'") else: - orelse = self._code_branch(self.compile(e[1:]), - e.start_line, - e.start_column) + orelse = self._compile_branch(e[1:]) + # XXX tempvar magic + orelse += orelse.expr_as_stmt() + orelse = orelse.stmts elif e[0] == HySymbol("finally"): if finalbody: raise HyTypeError( e, "`try' cannot have more than one `finally'") else: - finalbody = self._code_branch(self.compile(e[1:]), - e.start_line, - e.start_column) + finalbody = self._compile_branch(e[1:]) + # XXX tempvar magic + finalbody += finalbody.expr_as_stmt() + finalbody = finalbody.stmts else: raise HyTypeError(e, "Unknown expression in `try'") @@ -361,9 +623,11 @@ class HyASTCompiler(object): body=[ast.Pass(lineno=expr.start_line, col_offset=expr.start_column)])] + ret = handler_results + if sys.version_info[0] >= 3 and sys.version_info[1] >= 3: # Python 3.3 features a merge of TryExcept+TryFinally into Try. 
- return ast.Try( + return ret + ast.Try( lineno=expr.start_line, col_offset=expr.start_column, body=body, @@ -373,7 +637,7 @@ class HyASTCompiler(object): if finalbody: if handlers: - return ast.TryFinally( + return ret + ast.TryFinally( lineno=expr.start_line, col_offset=expr.start_column, body=[ast.TryExcept( @@ -384,13 +648,13 @@ class HyASTCompiler(object): orelse=orelse)], finalbody=finalbody) - return ast.TryFinally( + return ret + ast.TryFinally( lineno=expr.start_line, col_offset=expr.start_column, body=body, finalbody=finalbody) - return ast.TryExcept( + return ret + ast.TryExcept( lineno=expr.start_line, col_offset=expr.start_column, handlers=handlers, @@ -448,90 +712,133 @@ class HyASTCompiler(object): if isinstance(exceptions_list, list): if len(exceptions_list): # [FooBar BarFoo] → catch Foobar and BarFoo exceptions - _type = ast.Tuple(elts=[self.compile(x) - for x in exceptions_list], - lineno=expr.start_line, - col_offset=expr.start_column, - ctx=ast.Load()) + elts, _type = self._compile_collect(exceptions_list) + _type += ast.Tuple(elts=elts, + lineno=expr.start_line, + col_offset=expr.start_column, + ctx=ast.Load()) else: # [] → all exceptions catched - _type = None + _type = Result() elif isinstance(exceptions_list, HySymbol): _type = self.compile(exceptions_list) else: raise HyTypeError(exceptions, "`%s' needs a valid exception list" % catch) - body = self._code_branch([self.compile(x) for x in expr], - expr.start_line, - expr.start_column) + body = self._compile_branch(expr) + # XXX tempvar handling magic + body += body.expr_as_stmt() - return ast.ExceptHandler( + body = body.stmts + if not body: + body = [ast.Pass(lineno=expr.start_line, + col_offset=expr.start_column)] + + # use _type.expr to get a literal `None` + return _type + ast.ExceptHandler( lineno=expr.start_line, col_offset=expr.start_column, - type=_type, + type=_type.expr, name=name, body=body) - def _code_branch(self, branch, start_line, start_column): - return 
self._mangle_branch((branch - if isinstance(branch, list) - else [branch]), - start_line, - start_column) - @builds("if") @checkargs(min=2, max=3) - def compile_if_expression(self, expr): - expr.pop(0) # if - test = self.compile(expr.pop(0)) - body = self._code_branch(self.compile(expr.pop(0)), - expr.start_line, - expr.start_column) + def compile_if(self, expression): + expression.pop(0) + cond = self.compile(expression.pop(0)) - if len(expr) == 1: - orel = self._code_branch(self.compile(expr.pop(0)), - expr.start_line, - expr.start_column) + body = self.compile(expression.pop(0)) + orel = Result() + if expression: + orel = self.compile(expression.pop(0)) + + # We want to hoist the statements from the condition + ret = cond + + if body.stmts or orel.stmts: + # We have statements in our bodies + # Get a temporary variable for the result storage + var = self.get_anon_var() + name = ast.Name(id=ast_str(var), arg=ast_str(var), + ctx=ast.Store(), + lineno=expression.start_line, + col_offset=expression.start_column) + + # Store the result of the body + body += ast.Assign(targets=[name], + value=body.force_expr, + lineno=expression.start_line, + col_offset=expression.start_column) + + # and of the else clause + orel += ast.Assign(targets=[name], + value=orel.force_expr, + lineno=expression.start_line, + col_offset=expression.start_column) + + # Then build the if + ret += ast.If(test=ret.force_expr, + body=body.stmts, + orelse=orel.stmts, + lineno=expression.start_line, + col_offset=expression.start_column) + + # And make our expression context our temp variable + expr_name = ast.Name(id=ast_str(var), arg=ast_str(var), + ctx=ast.Load(), + lineno=expression.start_line, + col_offset=expression.start_column) + + ret += Result(expr=expr_name, temp_variables=[expr_name, name]) else: - orel = [] - - return ast.If(test=test, - body=body, - orelse=orel, - lineno=expr.start_line, - col_offset=expr.start_column) + # Just make that an if expression + ret += 
ast.IfExp(test=ret.force_expr, + body=body.force_expr, + orelse=orel.force_expr, + lineno=expression.start_line, + col_offset=expression.start_column) + return ret @builds("print") def compile_print_expression(self, expr): call = expr.pop(0) # print + values, ret = self._compile_collect(expr) + if sys.version_info[0] >= 3: call = self.compile(call) - # AST changed with Python 3, we now just call it. - return ast.Call( - keywords=[], - func=call, - args=[self.compile(x) for x in expr], + ret += call + ret += ast.Call(func=call.expr, + args=values, + keywords=[], + starargs=None, + kwargs=None, + lineno=expr.start_line, + col_offset=expr.start_column) + else: + ret += ast.Print( lineno=expr.start_line, - col_offset=expr.start_column) + col_offset=expr.start_column, + dest=None, + values=values, + nl=True) - return ast.Print( - lineno=expr.start_line, - col_offset=expr.start_column, - dest=None, - values=[self.compile(x) for x in expr], - nl=True) + return ret @builds("assert") @checkargs(1) def compile_assert_expression(self, expr): expr.pop(0) # assert e = expr.pop(0) - return ast.Assert(test=self.compile(e), + ret = self.compile(e) + ret += ast.Assert(test=ret.force_expr, msg=None, lineno=e.start_line, col_offset=e.start_column) + return ret + @builds("global") @checkargs(1) def compile_global_expression(self, expr): @@ -542,13 +849,13 @@ class HyASTCompiler(object): col_offset=e.start_column) @builds("lambda") - @checkargs(min=2) + @checkargs(2) def compile_lambda_expression(self, expr): expr.pop(0) sig = expr.pop(0) - body = expr.pop(0) - # assert expr is empty - return ast.Lambda( + body = self.compile(expr.pop(0)) + + body += ast.Lambda( lineno=expr.start_line, col_offset=expr.start_column, args=ast.arguments(args=[ @@ -562,36 +869,42 @@ class HyASTCompiler(object): defaults=[], kwonlyargs=[], kw_defaults=[]), - body=self.compile(body)) + body=body.force_expr) + + return body @builds("yield") @checkargs(max=1) def compile_yield_expression(self, expr): 
expr.pop(0) + ret = Result() + value = None if expr != []: - value = self.compile(expr.pop(0)) - return ast.Yield( + ret += self.compile(expr.pop(0)) + value = ret.force_expr + + ret += ast.Yield( value=value, lineno=expr.start_line, col_offset=expr.start_column) + return ret + @builds("import") def compile_import_expression(self, expr): def _compile_import(expr, module, names=None, importer=ast.Import): - return [ - importer( - lineno=expr.start_line, - col_offset=expr.start_column, - module=ast_str(module), - names=names or [ - ast.alias(name=ast_str(module), asname=None) - ], - level=0) - ] + if not names: + names = [ast.alias(name=ast_str(module), asname=None)] + ret = importer(lineno=expr.start_line, + col_offset=expr.start_column, + module=ast_str(module), + names=names, + level=0) + return Result() + ret expr.pop(0) # index - rimports = [] + rimports = Result() while len(expr) > 0: iexpr = expr.pop(0) @@ -607,17 +920,14 @@ class HyASTCompiler(object): module = iexpr.pop(0) entry = iexpr[0] if isinstance(entry, HyKeyword) and entry == HyKeyword(":as"): - assert len(iexpr) == 2, "garbage after aliased import" + if not len(iexpr) == 2: + raise HyTypeError(iexpr, + "garbage after aliased import") iexpr.pop(0) # :as alias = iexpr.pop(0) - rimports += _compile_import( - expr, - ast_str(module), - [ - ast.alias(name=ast_str(module), - asname=ast_str(alias)) - ] - ) + names = [ast.alias(name=ast_str(module), + asname=ast_str(alias))] + rimports += _compile_import(expr, ast_str(module), names) continue if isinstance(entry, HyList): @@ -629,21 +939,19 @@ class HyASTCompiler(object): alias = ast_str(entry.pop(0)) else: alias = None - names += [ - ast.alias(name=ast_str(sym), - asname=alias) - ] + names.append(ast.alias(name=ast_str(sym), + asname=alias)) rimports += _compile_import(expr, module, names, ast.ImportFrom) continue - raise TypeError("Unknown entry (`%s`) in the HyList" % (entry)) + raise HyTypeError( + entry, + "Unknown entry (`%s`) in the HyList" % (entry) 
+ ) - if len(rimports) == 1: - return rimports[0] - else: - return rimports + return rimports @builds("get") @checkargs(2) @@ -652,34 +960,39 @@ class HyASTCompiler(object): val = self.compile(expr.pop(0)) # target sli = self.compile(expr.pop(0)) # slice - return ast.Subscript( + return val + sli + ast.Subscript( lineno=expr.start_line, col_offset=expr.start_column, - value=val, - slice=ast.Index(value=sli), + value=val.force_expr, + slice=ast.Index(value=sli.force_expr), ctx=ast.Load()) @builds("slice") - @checkargs(min=1, max=3) + @checkargs(min=1, max=4) def compile_slice_expression(self, expr): expr.pop(0) # index val = self.compile(expr.pop(0)) # target - low = None + low = Result() if expr != []: low = self.compile(expr.pop(0)) - high = None + high = Result() if expr != []: high = self.compile(expr.pop(0)) - return ast.Subscript( + step = Result() + if expr != []: + step = self.compile(expr.pop(0)) + + # use low.expr, high.expr and step.expr to use a literal `None`. + return val + low + high + step + ast.Subscript( lineno=expr.start_line, col_offset=expr.start_column, - value=val, - slice=ast.Slice(lower=low, - upper=high, - step=None), + value=val.force_expr, + slice=ast.Slice(lower=low.expr, + upper=high.expr, + step=step.expr), ctx=ast.Load()) @builds("assoc") @@ -687,31 +1000,32 @@ class HyASTCompiler(object): def compile_assoc_expression(self, expr): expr.pop(0) # assoc # (assoc foo bar baz) => foo[bar] = baz - target = expr.pop(0) - key = expr.pop(0) - val = expr.pop(0) + target = self.compile(expr.pop(0)) + key = self.compile(expr.pop(0)) + val = self.compile(expr.pop(0)) - return ast.Assign( + return target + key + val + ast.Assign( lineno=expr.start_line, col_offset=expr.start_column, targets=[ ast.Subscript( lineno=expr.start_line, col_offset=expr.start_column, - value=self.compile(target), - slice=ast.Index(value=self.compile(key)), + value=target.force_expr, + slice=ast.Index(value=key.force_expr), ctx=ast.Store())], - value=self.compile(val)) + 
value=val.force_expr) @builds("with_decorator") @checkargs(min=1) def compile_decorate_expression(self, expr): expr.pop(0) # with-decorator fn = self.compile(expr.pop(-1)) - if type(fn) != ast.FunctionDef: + if not fn.stmts or not isinstance(fn.stmts[-1], ast.FunctionDef): raise HyTypeError(expr, "Decorated a non-function") - fn.decorator_list = [self.compile(x) for x in expr] - return fn + decorators, ret = self._compile_collect(expr) + fn.stmts[-1].decorator_list = decorators + return ret + fn @builds("with") @checkargs(min=2) @@ -729,27 +1043,34 @@ class HyASTCompiler(object): if args != []: thing = self._storeize(self.compile(args.pop(0))) - ret = ast.With(context_expr=ctx, - lineno=expr.start_line, - col_offset=expr.start_column, - optional_vars=thing, - body=self._code_branch( - [self.compile(x) for x in expr], - expr.start_line, - expr.start_column)) + body = self._compile_branch(expr) + body += body.expr_as_stmt() + + if not body.stmts: + body += ast.Pass(lineno=expr.start_line, + col_offset=expr.start_column) + + the_with = ast.With(context_expr=ctx.force_expr, + lineno=expr.start_line, + col_offset=expr.start_column, + optional_vars=thing, + body=body.stmts) if sys.version_info[0] >= 3 and sys.version_info[1] >= 3: - ret.items = [ast.withitem(context_expr=ctx, optional_vars=thing)] + the_with.items = [ast.withitem(context_expr=ctx.force_expr, + optional_vars=thing)] - return ret + return ctx + the_with @builds(",") def compile_tuple(self, expr): expr.pop(0) - return ast.Tuple(elts=[self.compile(x) for x in expr], + elts, ret = self._compile_collect(expr) + ret += ast.Tuple(elts=elts, lineno=expr.start_line, col_offset=expr.start_column, ctx=ast.Load()) + return ret @builds("list_comp") @checkargs(min=2, max=3) @@ -760,32 +1081,32 @@ class HyASTCompiler(object): tar_it = iter(expr.pop(0)) targets = zip(tar_it, tar_it) - cond = self.compile(expr.pop(0)) if expr != [] else None - - ret = ast.ListComp( - lineno=expr.start_line, - col_offset=expr.start_column, 
- elt=self.compile(expression), - generators=[]) + cond = self.compile(expr.pop(0)) if expr != [] else Result() + generator_res = Result() + generators = [] for target, iterable in targets: - ret.generators.append(ast.comprehension( - target=self._storeize(self.compile(target)), - iter=self.compile(iterable), + comp_target = self.compile(target) + target = self._storeize(comp_target) + generator_res += self.compile(iterable) + generators.append(ast.comprehension( + target=target, + iter=generator_res.force_expr, ifs=[])) - if cond: - ret.generators[-1].ifs.append(cond) + if cond.expr: + generators[-1].ifs.append(cond.expr) + + compiled_expression = self.compile(expression) + ret = compiled_expression + generator_res + cond + ret += ast.ListComp( + lineno=expr.start_line, + col_offset=expr.start_column, + elt=compiled_expression.force_expr, + generators=generators) return ret - def _storeize(self, name): - if isinstance(name, ast.Tuple): - for x in name.elts: - x.ctx = ast.Store() - name.ctx = ast.Store() - return name - @builds("kwapply") @checkargs(2) def compile_kwapply_expression(self, expr): @@ -793,16 +1114,22 @@ class HyASTCompiler(object): call = self.compile(expr.pop(0)) kwargs = expr.pop(0) - if type(call) != ast.Call: + if type(call.expr) != ast.Call: raise HyTypeError(expr, "kwapplying a non-call") if type(kwargs) != HyDict: raise TypeError("kwapplying with a non-dict") - call.keywords = [ast.keyword(arg=ast_str(x), - value=self.compile(kwargs[x])) for x in kwargs] + keywords = [] + ret = Result() + for x in kwargs: + ret += self.compile(kwargs[x]) + keywords.append(ast.keyword(arg=ast_str(x), + value=ret.force_expr)) - return call + call.expr.keywords = keywords + + return ret + call @builds("not") @builds("~") @@ -811,11 +1138,13 @@ class HyASTCompiler(object): ops = {"not": ast.Not, "~": ast.Invert} operator = expression.pop(0) - operand = expression.pop(0) - return ast.UnaryOp(op=ops[operator](), - operand=self.compile(operand), - 
lineno=operator.start_line, - col_offset=operator.start_column) + operand = self.compile(expression.pop(0)) + + operand += ast.UnaryOp(op=ops[operator](), + operand=operand.expr, + lineno=operator.start_line, + col_offset=operator.start_column) + return operand @builds("and") @builds("or") @@ -824,13 +1153,13 @@ class HyASTCompiler(object): ops = {"and": ast.And, "or": ast.Or} operator = expression.pop(0) - values = [] - for child in expression: - values.append(self.compile(child)) - return ast.BoolOp(op=ops[operator](), + values, ret = self._compile_collect(expression) + + ret += ast.BoolOp(op=ops[operator](), lineno=operator.start_line, col_offset=operator.start_column, values=values) + return ret @builds("=") @builds("!=") @@ -853,13 +1182,15 @@ class HyASTCompiler(object): inv = expression.pop(0) op = ops[inv] ops = [op() for x in range(1, len(expression))] - e = expression.pop(0) - return ast.Compare(left=self.compile(e), - ops=ops, - comparators=[self.compile(x) for x in expression], - lineno=e.start_line, - col_offset=e.start_column) + e = expression[0] + exprs, ret = self._compile_collect(expression) + + return ret + ast.Compare(left=exprs[0], + ops=ops, + comparators=exprs[1:], + lineno=e.start_line, + col_offset=e.start_column) @builds("+") @builds("%") @@ -890,16 +1221,17 @@ class HyASTCompiler(object): inv = expression.pop(0) op = ops[inv] - left = self.compile(expression.pop(0)) - calc = None + ret = self.compile(expression.pop(0)) for child in expression: - calc = ast.BinOp(left=left, + left_expr = ret.force_expr + ret += self.compile(child) + right_expr = ret.force_expr + ret += ast.BinOp(left=left_expr, op=op(), - right=self.compile(child), + right=right_expr, lineno=child.start_line, col_offset=child.start_column) - left = calc - return calc + return ret @builds("-") @checkargs(min=1) @@ -908,10 +1240,12 @@ class HyASTCompiler(object): return self.compile_maths_expression(expression) else: arg = expression[1] - return ast.UnaryOp(op=ast.USub(), - 
operand=self.compile(arg), + ret = self.compile(arg) + ret += ast.UnaryOp(op=ast.USub(), + operand=ret.force_expr, lineno=arg.start_line, col_offset=arg.start_column) + return ret @builds("+=") @builds("/=") @@ -943,115 +1277,113 @@ class HyASTCompiler(object): op = ops[expression[0]] target = self._storeize(self.compile(expression[1])) - value = self.compile(expression[2]) + ret = self.compile(expression[2]) - return ast.AugAssign( + ret += ast.AugAssign( target=target, - value=value, + value=ret.force_expr, op=op(), lineno=expression.start_line, col_offset=expression.start_column) - def compile_dotted_expression(self, expr): - ofn = expr.pop(0) # .join - - fn = HySymbol(ofn[1:]) - fn.replace(ofn) - - obj = expr.pop(0) # [1 2 3 4] - - return ast.Call( - func=ast.Attribute( - lineno=expr.start_line, - col_offset=expr.start_column, - value=self.compile(obj), - attr=ast_str(fn), - ctx=ast.Load()), - args=[self.compile(x) for x in expr], - keywords=[], - lineno=expr.start_line, - col_offset=expr.start_column, - starargs=None, - kwargs=None) + return ret @builds(HyExpression) def compile_expression(self, expression): fn = expression[0] + func = None if isinstance(fn, HyString): - if fn in _compile_table: - return _compile_table[fn](self, expression) + ret = self.compile_atom(fn, expression) + if ret: + return ret + if fn.startswith("."): + # (.split "test test") -> "test test".split() - if expression[0].startswith("."): - return self.compile_dotted_expression(expression) - if isinstance(fn, HyKeyword): - new_expr = HyExpression(["get", expression[1], fn]) - new_expr.start_line = expression.start_line - new_expr.start_column = expression.start_column - return self.compile_index_expression(new_expr) + # Get the attribute name + ofn = fn + fn = HySymbol(ofn[1:]) + fn.replace(ofn) - return ast.Call(func=self.compile(fn), - args=[self.compile(x) for x in expression[1:]], + # Get the object we want to take an attribute from + func = self.compile(expression.pop(1)) + + # And 
get the attribute + func += ast.Attribute(lineno=fn.start_line, + col_offset=fn.start_column, + value=func.force_expr, + attr=ast_str(fn), + ctx=ast.Load()) + + if not func: + func = self.compile(fn) + args, ret = self._compile_collect(expression[1:]) + + ret += ast.Call(func=func.expr, + args=args, keywords=[], starargs=None, kwargs=None, lineno=expression.start_line, col_offset=expression.start_column) + return func + ret + @builds("def") @builds("setf") @builds("setv") @checkargs(2) def compile_def_expression(self, expression): - expression.pop(0) # "def" + expression.pop(0) name = expression.pop(0) + result = self.compile(expression.pop(0)) - what = self.compile(expression.pop(0)) + if result.temp_variables and isinstance(name, HyString): + result.rename(name) + return result - if type(what) == ast.FunctionDef: - # We special case a FunctionDef, since we can define by setting - # FunctionDef's .name attribute, rather then foo == anon_fn. This - # helps keep things clean. - what.name = ast_str(name) - return what + ld_name = self.compile(name) + st_name = self._storeize(ld_name) - name = self._storeize(self.compile(name)) - - return ast.Assign( + result += ast.Assign( lineno=expression.start_line, col_offset=expression.start_column, - targets=[name], value=what) + targets=[st_name], value=result.force_expr) + + result += ld_name + return result @builds("foreach") @checkargs(min=1) def compile_for_expression(self, expression): - with self.is_returnable(False): - expression.pop(0) # for - name, iterable = expression.pop(0) - target = self._storeize(self.compile_symbol(name)) + expression.pop(0) # for + target_name, iterable = expression.pop(0) + target = self._storeize(self.compile(target_name)) - orelse = [] - # (foreach [] body (else …)) - if expression and expression[-1][0] == HySymbol("else"): - else_expr = expression.pop() - if len(else_expr) > 2: - raise HyTypeError( - else_expr, - "`else' statement in `foreach' is too long") - elif len(else_expr) == 2: - 
orelse = self._code_branch( - self.compile(else_expr[1]), - else_expr[1].start_line, - else_expr[1].start_column) + ret = Result() - ret = ast.For(lineno=expression.start_line, - col_offset=expression.start_column, - target=target, - iter=self.compile(iterable), - body=self._code_branch( - [self.compile(x) for x in expression], - expression.start_line, - expression.start_column), - orelse=orelse) + orel = Result() + # (foreach [] body (else …)) + if expression and expression[-1][0] == HySymbol("else"): + else_expr = expression.pop() + if len(else_expr) > 2: + raise HyTypeError( + else_expr, + "`else' statement in `foreach' is too long") + elif len(else_expr) == 2: + orel += self.compile(else_expr[1]) + orel += orel.expr_as_stmt() + + ret += self.compile(iterable) + + body = self._compile_branch(expression) + body += body.expr_as_stmt() + + ret += ast.For(lineno=expression.start_line, + col_offset=expression.start_column, + target=target, + iter=ret.force_expr, + body=body.stmts, + orelse=orel.stmts) return ret @@ -1059,69 +1391,73 @@ class HyASTCompiler(object): @checkargs(min=2) def compile_while_expression(self, expr): expr.pop(0) # "while" - test = self.compile(expr.pop(0)) + ret = self.compile(expr.pop(0)) - return ast.While(test=test, - body=self._code_branch( - [self.compile(x) for x in expr], - expr.start_line, - expr.start_column), + body = self._compile_branch(expr) + body += body.expr_as_stmt() + + ret += ast.While(test=ret.force_expr, + body=body.stmts, orelse=[], lineno=expr.start_line, col_offset=expr.start_column) + return ret + @builds(HyList) - def compile_list(self, expr): - return ast.List( - elts=[self.compile(x) for x in expr], - ctx=ast.Load(), - lineno=expr.start_line, - col_offset=expr.start_column) + def compile_list(self, expression): + elts, ret = self._compile_collect(expression) + ret += ast.List(elts=elts, + ctx=ast.Load(), + lineno=expression.start_line, + col_offset=expression.start_column) + return ret @builds("fn") @checkargs(min=1) 
- def compile_fn_expression(self, expression): - expression.pop(0) # fn + def compile_function_def(self, expression): + expression.pop(0) - self.anon_fn_count += 1 - name = "_hy_anon_fn_%d" % (self.anon_fn_count) - sig = expression.pop(0) + name = self.get_anon_fn() - body = [] - if expression != []: - with self.is_returnable(True): - tailop = self.compile(expression.pop(-1)) - with self.is_returnable(False): - for el in expression: - body.append(self.compile(el)) - body.append(tailop) + arglist = expression.pop(0) + ret, args, defaults, stararg, kwargs = self._parse_lambda_list(arglist) + body = self._compile_branch(expression) + if body.expr: + body += ast.Return(value=body.expr, + lineno=body.expr.lineno, + col_offset=body.expr.col_offset) - with self.is_returnable(True): - body = self._code_branch(body, - expression.start_line, - expression.start_column) + if not body.stmts: + body += ast.Pass(lineno=expression.start_line, + col_offset=expression.start_column) - args, defaults, stararg, kwargs = self._parse_lambda_list(sig) + ret += ast.FunctionDef(name=name, + lineno=expression.start_line, + col_offset=expression.start_column, + args=ast.arguments( + args=[ + ast.Name( + arg=ast_str(x), id=ast_str(x), + ctx=ast.Param(), + lineno=x.start_line, + col_offset=x.start_column) + for x in args], + vararg=stararg, + kwarg=kwargs, + kwonlyargs=[], + kw_defaults=[], + defaults=defaults), + body=body.stmts, + decorator_list=[]) - ret = ast.FunctionDef( - name=name, - lineno=expression.start_line, - col_offset=expression.start_column, - args=ast.arguments( - args=[ - ast.Name( - arg=ast_str(x), id=ast_str(x), - ctx=ast.Param(), - lineno=x.start_line, - col_offset=x.start_column) - for x in args], - vararg=stararg, - kwarg=kwargs, - kwonlyargs=[], - kw_defaults=[], - defaults=defaults), - body=body, - decorator_list=[]) + ast_name = ast.Name(id=name, + arg=name, + ctx=ast.Load(), + lineno=expression.start_line, + col_offset=expression.start_column) + + ret += 
Result(expr=ast_name, temp_variables=[ast_name, ret.stmts[-1]]) return ret @@ -1147,16 +1483,17 @@ class HyASTCompiler(object): def compile_symbol(self, symbol): if "." in symbol: glob, local = symbol.rsplit(".", 1) - glob = HySymbol(glob) - glob.replace(symbol) + glob = HySymbol(glob).replace(symbol) + ret = self.compile_symbol(glob) - return ast.Attribute( + ret = ast.Attribute( lineno=symbol.start_line, col_offset=symbol.start_column, - value=self.compile_symbol(glob), + value=ret, attr=ast_str(local), ctx=ast.Load() ) + return ret return ast.Name(id=ast_str(symbol), arg=ast_str(symbol), @@ -1166,67 +1503,67 @@ class HyASTCompiler(object): @builds(HyString) def compile_string(self, string): - return ast.Str(s=str_type(string), lineno=string.start_line, + return ast.Str(s=str_type(string), + lineno=string.start_line, col_offset=string.start_column) @builds(HyKeyword) def compile_keyword(self, keyword): - return ast.Str(s=str_type(keyword), lineno=keyword.start_line, + return ast.Str(s=str_type(keyword), + lineno=keyword.start_line, col_offset=keyword.start_column) @builds(HyDict) def compile_dict(self, m): - keys = [] - vals = [] - for entry in m: - keys.append(self.compile(entry)) - vals.append(self.compile(m[entry])) + keyvalues, ret = self._compile_collect(sum(m.items(), ())) - return ast.Dict( - lineno=m.start_line, - col_offset=m.start_column, - keys=keys, - values=vals) + ret += ast.Dict(lineno=m.start_line, + col_offset=m.start_column, + keys=keyvalues[::2], + values=keyvalues[1::2]) + return ret -def hy_compile(tree, root=None): - " Compile a HyObject tree into a Python AST tree. " - compiler = HyASTCompiler() - tlo = root - if root is None: - tlo = ast.Module +def hy_compile(tree, root=ast.Module, get_expr=False): + """ + Compile a HyObject tree into a Python AST Module. 
- _ast = compiler.compile(tree) - if type(_ast) == list: - _ast = compiler._mangle_branch(_ast, 0, 0) + If `get_expr` is True, return a tuple (module, last_expression), where + `last_expression` is the tree's final expression as an `ast.Expression`. + """ - if hasattr(sys, "subversion"): - implementation = sys.subversion[0].lower() - elif hasattr(sys, "implementation"): - implementation = sys.implementation.name.lower() + if hasattr(sys, "subversion"): + implementation = sys.subversion[0].lower() + elif hasattr(sys, "implementation"): + implementation = sys.implementation.name.lower() - imports = [] - for package in compiler.imports: - imported = set() - syms = compiler.imports[package] - for entry, form in syms: - if entry in imported: - continue + body = [] + expr = None - replace = form - if implementation != "cpython": - # using form causes pypy to blow up; let's conditionally - # add this for cpython, since it won't go through and make - # sure the AST makes sense. Muhahaha. - PRT - replace = tree[0] + if tree: + compiler = HyASTCompiler() + result = compiler.compile(tree) + expr = result.force_expr - imported.add(entry) - imports.append(HyExpression([ - HySymbol("import"), - HyList([HySymbol(package), HyList([HySymbol(entry)])]) - ]).replace(replace)) + if not get_expr: + result += result.expr_as_stmt() - _ast = compiler.compile(imports) + _ast + if isinstance(tree, list): + spoof_tree = tree[0] + else: + spoof_tree = tree + body = compiler.imports_as_stmts(spoof_tree) + result.stmts + + ret = root(body=body) + + # PyPy _really_ doesn't like the ast going backwards... 
+ if implementation != "cpython": + for node in ast.walk(ret): + node.lineno = 1 + node.col_offset = 1 + + if get_expr: + expr = ast.Expression(body=expr) + ret = (ret, expr) - ret = tlo(body=_ast) return ret diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index 3f1893b..e06613b 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -240,12 +240,13 @@ def test_ast_good_slice(): hy_compile(tokenize("(slice x)")) hy_compile(tokenize("(slice x y)")) hy_compile(tokenize("(slice x y z)")) + hy_compile(tokenize("(slice x y z t)")) def test_ast_bad_slice(): "Make sure AST can't compile invalid slice" cant_compile("(slice)") - cant_compile("(slice 1 2 3 4)") + cant_compile("(slice 1 2 3 4 5)") def test_ast_good_take():