Merge pull request #243 from paultag/paultag/bugfix/fix-yield

Fix yielding to not suck (#151)

yay team
This commit is contained in:
rwtolbert 2013-07-16 18:27:14 -07:00
commit 348eaaa0f4
2 changed files with 55 additions and 7 deletions

View File

@ -154,7 +154,8 @@ class Result(object):
The Result object is interoperable with python AST objects: when an AST
object gets added to a Result object, it gets converted on-the-fly.
"""
__slots__ = ("imports", "stmts", "temp_variables", "_expr", "__used_expr")
__slots__ = ("imports", "stmts", "temp_variables",
"_expr", "__used_expr", "contains_yield")
def __init__(self, *args, **kwargs):
if args:
@ -165,12 +166,14 @@ class Result(object):
self.stmts = []
self.temp_variables = []
self._expr = None
self.contains_yield = False
self.__used_expr = False
# XXX: Make sure we only have AST where we should.
for kwarg in kwargs:
if kwarg not in ["imports", "stmts", "expr", "temp_variables"]:
if kwarg not in ["imports", "contains_yield", "stmts", "expr",
"temp_variables"]:
raise TypeError(
"%s() got an unexpected keyword argument '%s'" % (
self.__class__.__name__, kwarg))
@ -282,13 +285,21 @@ class Result(object):
result.stmts = self.stmts + other.stmts
result.expr = other.expr
result.temp_variables = other.temp_variables
result.contains_yield = False
if self.contains_yield or other.contains_yield:
result.contains_yield = True
return result
def __str__(self):
return "Result(imports=[%s], stmts=[%s], expr=%s)" % (
return (
"Result(imports=[%s], stmts=[%s], "
"expr=%s, contains_yield=%s)"
) % (
", ".join(ast.dump(x) for x in self.imports),
", ".join(ast.dump(x) for x in self.stmts),
ast.dump(self.expr) if self.expr else None,
self.contains_yield
)
@ -1011,7 +1022,7 @@ class HyASTCompiler(object):
@checkargs(max=1)
def compile_yield_expression(self, expr):
expr.pop(0)
ret = Result()
ret = Result(contains_yield=True)
value = None
if expr != []:
@ -1540,6 +1551,8 @@ class HyASTCompiler(object):
body=body.stmts,
orelse=orel.stmts)
ret.contains_yield = body.contains_yield
return ret
@builds("while")
@ -1557,6 +1570,8 @@ class HyASTCompiler(object):
lineno=expr.start_line,
col_offset=expr.start_column)
ret.contains_yield = body.contains_yield
return ret
@builds(HyList)
@ -1600,9 +1615,12 @@ class HyASTCompiler(object):
return ret
if body.expr:
body += ast.Return(value=body.expr,
lineno=body.expr.lineno,
col_offset=body.expr.col_offset)
if body.contains_yield:
body += body.expr_as_stmt()
else:
body += ast.Return(value=body.expr,
lineno=body.expr.lineno,
col_offset=body.expr.col_offset)
if not body.stmts:
body += ast.Pass(lineno=expression.start_line,

View File

@ -46,6 +46,36 @@
"NATIVE: test macro calling a plain function"
(assert (= 3 (bar 1 2))))
(defn test-midtree-yield []
"NATIVE: test yielding with a returnable"
(defn kruft [] (yield) (+ 1 1)))
(defn test-midtree-yield-in-for []
"NATIVE: test yielding in a for with a return"
(defn kruft-in-for []
(for [i (range 5)]
(yield i))
(+ 1 2)))
(defn test-midtree-yield-in-while []
"NATIVE: test yielding in a while with a return"
(defn kruft-in-while []
(setv i 0)
(while (< i 5)
(yield i)
(setv i (+ i 1)))
(+ 2 3)))
(defn test-multi-yield []
"NATIVE: testing multiple yields"
(defn multi-yield []
(for [i (range 3)]
(yield i))
(yield "a")
(yield "end"))
(assert (= (list (multi-yield)) [0 1 2 "a" "end"])))
; Macro that checks a variable defined at compile or load time
(setv phase "load")
(eval-when-compile