diff --git a/hy/compiler.py b/hy/compiler.py index ef5268b..e7119b4 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -1132,20 +1132,20 @@ class HyASTCompiler(object): return node(expr, names=names) @builds("yield") - @builds("yield_from", iff=PY3) @checkargs(max=1) def compile_yield_expression(self, expr): ret = Result(contains_yield=(not PY3)) if len(expr) > 1: ret += self.compile(expr[1]) - node = asty.Yield if expr[0] == "yield" else asty.YieldFrom - return ret + node(expr, value=ret.force_expr) + return ret + asty.Yield(expr, value=ret.force_expr) + @builds("yield_from", iff=PY3) @builds("await", iff=PY35) @checkargs(1) - def compile_await_expression(self, expr): + def compile_yield_from_or_await_expression(self, expr): ret = Result() + self.compile(expr[1]) - return ret + asty.Await(expr, value=ret.force_expr) + node = asty.YieldFrom if expr[0] == "yield_from" else asty.Await + return ret + node(expr, value=ret.force_expr) @builds("import") def compile_import_expression(self, expr): diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index a82cacc..e85f6be 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -14,6 +14,7 @@ from hy.lex.exceptions import LexException from hy._compat import PY3 import ast +import pytest def _ast_spotcheck(arg, root, secondary): @@ -651,3 +652,15 @@ def test_compiler_macro_tag_try(): # https://github.com/hylang/hy/issues/1350 can_compile("(defmacro foo [] (try None (except [] None)) `())") can_compile("(deftag foo [] (try None (except [] None)) `())") + + +@pytest.mark.skipif(not PY3, reason="Python 3 required") +def test_ast_good_yield_from(): + "Make sure AST can compile valid yield-from" + can_compile("(yield-from [1 2])") + + +@pytest.mark.skipif(not PY3, reason="Python 3 required") +def test_ast_bad_yield_from(): + "Make sure AST can't compile invalid yield-from" + cant_compile("(yield-from)")