From 7c91913122232f9a6e02788fb54fdd7e326bfe98 Mon Sep 17 00:00:00 2001 From: Paul Tagliamonte Date: Sun, 14 Jul 2013 13:03:08 -0400 Subject: [PATCH] Fix yielding to not suck (#151) This adds a class to avoid returning when we have a Yieldable expression contained in the body of the function. This breaks Python 2.x, and ought to break Python 3.x, but doesn't. We need this fo' context managers, etc. This commit also has work from @rwtolbert adding new testcases and fixes for yielded entries behind a while / for. --- hy/compiler.py | 32 ++++++++++++++++++++++------- tests/native_tests/native_macros.hy | 30 +++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/hy/compiler.py b/hy/compiler.py index ad67c28..01d7e17 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -154,7 +154,8 @@ class Result(object): The Result object is interoperable with python AST objects: when an AST object gets added to a Result object, it gets converted on-the-fly. """ - __slots__ = ("imports", "stmts", "temp_variables", "_expr", "__used_expr") + __slots__ = ("imports", "stmts", "temp_variables", + "_expr", "__used_expr", "contains_yield") def __init__(self, *args, **kwargs): if args: @@ -165,12 +166,14 @@ class Result(object): self.stmts = [] self.temp_variables = [] self._expr = None + self.contains_yield = False self.__used_expr = False # XXX: Make sure we only have AST where we should. for kwarg in kwargs: - if kwarg not in ["imports", "stmts", "expr", "temp_variables"]: + if kwarg not in ["imports", "contains_yield", "stmts", "expr", + "temp_variables"]: raise TypeError( "%s() got an unexpected keyword argument '%s'" % ( self.__class__.__name__, kwarg)) @@ -282,13 +285,21 @@ class Result(object): result.stmts = self.stmts + other.stmts result.expr = other.expr result.temp_variables = other.temp_variables + result.contains_yield = False + if self.contains_yield or other.contains_yield: + result.contains_yield = True + return result def __str__(self): - return "Result(imports=[%s], stmts=[%s], expr=%s)" % ( + return ( + "Result(imports=[%s], stmts=[%s], " + "expr=%s, contains_yield=%s)" + ) % ( ", ".join(ast.dump(x) for x in self.imports), ", ".join(ast.dump(x) for x in self.stmts), ast.dump(self.expr) if self.expr else None, + self.contains_yield ) @@ -1011,7 +1022,7 @@ class HyASTCompiler(object): @checkargs(max=1) def compile_yield_expression(self, expr): expr.pop(0) - ret = Result() + ret = Result(contains_yield=True) value = None if expr != []: @@ -1541,6 +1552,8 @@ class HyASTCompiler(object): body=body.stmts, orelse=orel.stmts) + ret.contains_yield = body.contains_yield + return ret @builds("while") @@ -1558,6 +1571,8 @@ class HyASTCompiler(object): lineno=expr.start_line, col_offset=expr.start_column) + ret.contains_yield = body.contains_yield + return ret @builds(HyList) @@ -1601,9 +1616,12 @@ class HyASTCompiler(object): return ret if body.expr: - body += ast.Return(value=body.expr, - lineno=body.expr.lineno, - col_offset=body.expr.col_offset) + if body.contains_yield: + body += body.expr_as_stmt() + else: + body += ast.Return(value=body.expr, + lineno=body.expr.lineno, + col_offset=body.expr.col_offset) if not body.stmts: body += ast.Pass(lineno=expression.start_line, diff --git a/tests/native_tests/native_macros.hy b/tests/native_tests/native_macros.hy index 2fb23d3..439ee03 100644 --- a/tests/native_tests/native_macros.hy +++ b/tests/native_tests/native_macros.hy @@ -46,6 +46,36 @@ "NATIVE: test macro calling a plain function" (assert (= 3 (bar 1 2)))) +(defn test-midtree-yield [] + "NATIVE: test yielding with a returnable" + (defn kruft [] (yield) (+ 1 1))) + +(defn test-midtree-yield-in-for [] + "NATIVE: test yielding in a for with a return" + (defn kruft-in-for [] + (for [i (range 5)] + (yield i)) + (+ 1 2))) + +(defn test-midtree-yield-in-while [] + "NATIVE: test yielding in a while with a return" + (defn kruft-in-while [] + (setv i 0) + (while (< i 5) + (yield i) + (setv i (+ i 1))) + (+ 2 3))) + +(defn test-multi-yield [] + "NATIVE: testing multiple yields" + (defn multi-yield [] + (for [i (range 3)] + (yield i)) + (yield "a") + (yield "end")) + (assert (= (list (multi-yield)) [0 1 2 "a" "end"]))) + + ; Macro that checks a variable defined at compile or load time (setv phase "load") (eval-when-compile