diff --git a/hy/compiler.py b/hy/compiler.py index 7bc88f6..1db58bd 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -841,31 +841,6 @@ class HyASTCompiler(object): lineno=e.start_line, col_offset=e.start_column) - @builds("lambda") - @checkargs(2) - def compile_lambda_expression(self, expr): - expr.pop(0) - sig = expr.pop(0) - body = self.compile(expr.pop(0)) - - body += ast.Lambda( - lineno=expr.start_line, - col_offset=expr.start_column, - args=ast.arguments(args=[ - ast.Name(arg=ast_str(x), id=ast_str(x), - ctx=ast.Param(), - lineno=x.start_line, - col_offset=x.start_column) - for x in sig], - vararg=None, - kwarg=None, - defaults=[], - kwonlyargs=[], - kw_defaults=[]), - body=body.force_expr) - - return body - @builds("yield") @checkargs(max=1) def compile_yield_expression(self, expr): @@ -1406,16 +1381,37 @@ class HyASTCompiler(object): col_offset=expression.start_column) return ret + @builds("lambda") @builds("fn") @checkargs(min=1) def compile_function_def(self, expression): - expression.pop(0) - - name = self.get_anon_fn() + called_as = expression.pop(0) arglist = expression.pop(0) ret, args, defaults, stararg, kwargs = self._parse_lambda_list(arglist) + + args = ast.arguments( + args=[ast.Name(arg=ast_str(x), id=ast_str(x), + ctx=ast.Param(), + lineno=x.start_line, + col_offset=x.start_column) + for x in args], + vararg=stararg, + kwarg=kwargs, + kwonlyargs=[], + kw_defaults=[], + defaults=defaults) + body = self._compile_branch(expression) + if not body.stmts and called_as == "lambda": + ret += ast.Lambda( + lineno=expression.start_line, + col_offset=expression.start_column, + args=args, + body=body.force_expr) + + return ret + if body.expr: body += ast.Return(value=body.expr, lineno=body.expr.lineno, @@ -1425,22 +1421,12 @@ class HyASTCompiler(object): body += ast.Pass(lineno=expression.start_line, col_offset=expression.start_column) + name = self.get_anon_fn() + ret += ast.FunctionDef(name=name, lineno=expression.start_line, col_offset=expression.start_column, - args=ast.arguments( - args=[ - ast.Name( - arg=ast_str(x), id=ast_str(x), - ctx=ast.Param(), - lineno=x.start_line, - col_offset=x.start_column) - for x in args], - vararg=stararg, - kwarg=kwargs, - kwonlyargs=[], - kw_defaults=[], - defaults=defaults), + args=args, body=body.stmts, decorator_list=[]) diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index e06613b..318aa0e 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -199,13 +199,13 @@ def test_ast_bad_global(): def test_ast_good_lambda(): "Make sure AST can compile valid lambda" + hy_compile(tokenize("(lambda [])")) hy_compile(tokenize("(lambda [] 1)")) def test_ast_bad_lambda(): "Make sure AST can't compile invalid lambda" cant_compile("(lambda)") - cant_compile("(lambda [])") def test_ast_good_yield():