diff --git a/hy/compiler.py b/hy/compiler.py index f4692fb..c04969c 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -154,7 +154,8 @@ class Result(object): The Result object is interoperable with python AST objects: when an AST object gets added to a Result object, it gets converted on-the-fly. """ - __slots__ = ("imports", "stmts", "temp_variables", "_expr", "__used_expr") + __slots__ = ("imports", "stmts", "temp_variables", + "_expr", "__used_expr", "contains_yield") def __init__(self, *args, **kwargs): if args: @@ -165,12 +166,14 @@ class Result(object): self.stmts = [] self.temp_variables = [] self._expr = None + self.contains_yield = False self.__used_expr = False # XXX: Make sure we only have AST where we should. for kwarg in kwargs: - if kwarg not in ["imports", "stmts", "expr", "temp_variables"]: + if kwarg not in ["imports", "contains_yield", "stmts", "expr", + "temp_variables"]: raise TypeError( "%s() got an unexpected keyword argument '%s'" % ( self.__class__.__name__, kwarg)) @@ -282,13 +285,21 @@ class Result(object): result.stmts = self.stmts + other.stmts result.expr = other.expr result.temp_variables = other.temp_variables + result.contains_yield = False + if self.contains_yield or other.contains_yield: + result.contains_yield = True + return result def __str__(self): - return "Result(imports=[%s], stmts=[%s], expr=%s)" % ( + return ( + "Result(imports=[%s], stmts=[%s], " + "expr=%s, contains_yield=%s)" + ) % ( ", ".join(ast.dump(x) for x in self.imports), ", ".join(ast.dump(x) for x in self.stmts), ast.dump(self.expr) if self.expr else None, + self.contains_yield ) @@ -1011,7 +1022,7 @@ class HyASTCompiler(object): @checkargs(max=1) def compile_yield_expression(self, expr): expr.pop(0) - ret = Result() + ret = Result(contains_yield=True) value = None if expr != []: @@ -1540,6 +1551,8 @@ class HyASTCompiler(object): body=body.stmts, orelse=orel.stmts) + ret.contains_yield = body.contains_yield + return ret @builds("while") @@ -1557,6 +1570,8 @@ class HyASTCompiler(object): lineno=expr.start_line, col_offset=expr.start_column) + ret.contains_yield = body.contains_yield + return ret @builds(HyList) @@ -1600,9 +1615,12 @@ class HyASTCompiler(object): return ret if body.expr: - body += ast.Return(value=body.expr, - lineno=body.expr.lineno, - col_offset=body.expr.col_offset) + if body.contains_yield: + body += body.expr_as_stmt() + else: + body += ast.Return(value=body.expr, + lineno=body.expr.lineno, + col_offset=body.expr.col_offset) if not body.stmts: body += ast.Pass(lineno=expression.start_line, diff --git a/tests/native_tests/native_macros.hy b/tests/native_tests/native_macros.hy index 2fb23d3..439ee03 100644 --- a/tests/native_tests/native_macros.hy +++ b/tests/native_tests/native_macros.hy @@ -46,6 +46,36 @@ "NATIVE: test macro calling a plain function" (assert (= 3 (bar 1 2)))) +(defn test-midtree-yield [] + "NATIVE: test yielding with a returnable" + (defn kruft [] (yield) (+ 1 1))) + +(defn test-midtree-yield-in-for [] + "NATIVE: test yielding in a for with a return" + (defn kruft-in-for [] + (for [i (range 5)] + (yield i)) + (+ 1 2))) + +(defn test-midtree-yield-in-while [] + "NATIVE: test yielding in a while with a return" + (defn kruft-in-while [] + (setv i 0) + (while (< i 5) + (yield i) + (setv i (+ i 1))) + (+ 2 3))) + +(defn test-multi-yield [] + "NATIVE: testing multiple yields" + (defn multi-yield [] + (for [i (range 3)] + (yield i)) + (yield "a") + (yield "end")) + (assert (= (list (multi-yield)) [0 1 2 "a" "end"]))) + + ; Macro that checks a variable defined at compile or load time (setv phase "load") (eval-when-compile