diff --git a/hy/compiler.py b/hy/compiler.py index 82be63b..fb47f4b 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -1061,24 +1061,39 @@ class HyASTCompiler(object): tag('do', sym(":do") + FORM) | tag('afor', sym(":async") + FORM + FORM) | tag('for', FORM + FORM)) + @special(["for"], [brackets(_loopers), + many(notpexpr("else")) + maybe(dolike("else"))]) @special(["lfor", "sfor", "gfor"], [_loopers, FORM]) @special(["dfor"], [_loopers, brackets(FORM, FORM)]) def compile_new_comp(self, expr, root, parts, final): - node_class = dict( - lfor=asty.ListComp, - dfor=asty.DictComp, - sfor=asty.SetComp, - gfor=asty.GeneratorExp)[str(root)] + root = unmangle(ast_str(root)) + node_class = { + "for": asty.For, + "lfor": asty.ListComp, + "dfor": asty.DictComp, + "sfor": asty.SetComp, + "gfor": asty.GeneratorExp}[root] + is_for = root == "for" - # Compile the final value (and for dictionary comprehensions, the final - # key). - if node_class is asty.DictComp: - key, elt = map(self.compile, final) + orel = [] + if is_for: + # Get the `else`. + body, else_expr = final + if else_expr is not None: + orel.append(self._compile_branch(else_expr)) + orel[0] += orel[0].expr_as_stmt() else: - key = None - elt = self.compile(final) + # Get the final value (and for dictionary + # comprehensions, the final key). + if node_class is asty.DictComp: + key, elt = map(self.compile, final) + else: + key = None + elt = self.compile(final) # Compile the parts. + if is_for: + parts = parts[0] parts = [ Tag(p.tag, self.compile(p.value) if p.tag in ["if", "do"] else [ self._storeize(p.value[0], self.compile(p.value[0])), @@ -1086,16 +1101,24 @@ class HyASTCompiler(object): for p in parts] # Produce a result. - if (elt.stmts or (key is not None and key.stmts) or + if (is_for or elt.stmts or (key is not None and key.stmts) or any(p.tag == 'do' or (p.value[1].stmts if p.tag in ("for", "afor", "setv") else p.value.stmts) for p in parts)): # The desired comprehension can't be expressed as a # real Python comprehension. We'll write it as a nested # loop in a function instead. + contains_yield = [] def f(parts): # This function is called recursively to construct # the nested loop. if not parts: + if is_for: + if body: + bd = self._compile_branch(body) + if bd.contains_yield: + contains_yield.append(True) + return bd + bd.expr_as_stmt() + return Result(stmts=[asty.Pass(expr)]) if node_class is asty.DictComp: ret = key + elt val = asty.Tuple( @@ -1108,10 +1131,11 @@ class HyASTCompiler(object): elt, value=asty.Yield(elt, value=val)) (tagname, v), parts = parts[0], parts[1:] if tagname in ("for", "afor"): + orelse = orel and orel.pop().stmts node = asty.AsyncFor if tagname == "afor" else asty.For return v[1] + node( v[1], target=v[0], iter=v[1].force_expr, body=f(parts).stmts, - orelse=[]) + orelse=orelse) elif tagname == "setv": return v[1] + asty.Assign( v[1], targets=[v[0]], value=v[1].force_expr) + f(parts) @@ -1122,6 +1146,10 @@ class HyASTCompiler(object): return v + v.expr_as_stmt() + f(parts) else: raise ValueError("can't happen") + if is_for: + ret = f(parts) + ret.contains_yield = bool(contains_yield) + return ret fname = self.get_anon_var() # Define the generator function. ret = Result() + asty.FunctionDef( diff --git a/hy/core/macros.hy b/hy/core/macros.hy index a948fd8..b87f110 100644 --- a/hy/core/macros.hy +++ b/hy/core/macros.hy @@ -122,14 +122,6 @@ used as the result." `(~node [(, ~@alist) (genexpr (, ~@alist) [~@args])] (do ~@body) ~@belse)))) -(defmacro for [args &rest body] - "Build a for-loop with `args` as a [element coll] bracket pair and run `body`. - -Args may contain multiple pairs, in which case it executes a nested for-loop -in order of the given pairs." - (_for 'for* args body)) - - (defmacro for/a [args &rest body] "Build a for/a-loop with `args` as a [element coll] bracket pair and run `body`. @@ -163,6 +155,7 @@ the second form, the second result is inserted into the third form, and so on." ~@(map build-form expressions) ~f)) + (defmacro ->> [head &rest args] "Thread `head` last through the `rest` of the forms. diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index d232649..ad41b01 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -531,9 +531,7 @@ def test_for_compile_error(): can_compile("(fn [] (for)))") assert excinfo.value.message == "Ran into a RPAREN where it wasn't expected." - with pytest.raises(HyTypeError) as excinfo: - can_compile("(fn [] (for [x] x))") - assert excinfo.value.message == "`for' requires an even number of args." + cant_compile("(fn [] (for [x] x))") def test_attribute_access(): diff --git a/tests/native_tests/comprehensions.hy b/tests/native_tests/comprehensions.hy index 9c1a264..b450e2f 100644 --- a/tests/native_tests/comprehensions.hy +++ b/tests/native_tests/comprehensions.hy @@ -18,8 +18,8 @@ (assert (is (type (gfor x "abc" :do (setv y 1) x)) types.GeneratorType))) -#@ ((pytest.mark.parametrize "specialop" ["lfor" "sfor" "gfor" "dfor"]) -(defn test-comprehensions [specialop] +#@ ((pytest.mark.parametrize "specialop" ["for" "lfor" "sfor" "gfor" "dfor"]) +(defn test-fors [specialop] (setv cases [ ['(f x [] x) @@ -86,6 +86,12 @@ (setv expr (+ (HyExpression [(HySymbol specialop)]) (cut expr 1))) (when (= specialop "dfor") (setv expr (+ (cut expr 0 -1) `([~(get expr -1) 1])))) + (when (= specialop "for") + (setv expr `(do + (setv out []) + (for [~@(cut expr 1 -1)] + (.append out ~(get expr -1))) + out))) (setv result (eval expr)) (when (= specialop "dfor") (setv result (.keys result))) @@ -104,3 +110,80 @@ (raise (E))) x)) (assert (= l [0 1 2 3 4 5]))) + + +(defn test-for-loop [] + "NATIVE: test for loops" + (setv count1 0 count2 0) + (for [x [1 2 3 4 5]] + (setv count1 (+ count1 x)) + (setv count2 (+ count2 x))) + (assert (= count1 15)) + (assert (= count2 15)) + (setv count 0) + (for [x [1 2 3 4 5] + y [1 2 3 4 5]] + (setv count (+ count x y)) + (else + (+= count 1))) + (assert (= count 151)) + + (setv count 0) + ; multiple statements in the else branch should work + (for [x [1 2 3 4 5] + y [1 2 3 4 5]] + (setv count (+ count x y)) + (else + (+= count 1) + (+= count 10))) + (assert (= count 161)) + + ; don't be fooled by constructs that look like else + (setv s "") + (setv else True) + (for [x "abcde"] + (+= s x) + [else (+= s "_")]) + (assert (= s "a_b_c_d_e_")) + + (setv s "") + (with [(pytest.raises TypeError)] + (for [x "abcde"] + (+= s x) + ("else" (+= s "z")))) + (assert (= s "az")) + + (assert (= (list ((fn [] (for [x [[1] [2 3]] y x] (yield y))))) + (list-comp y [x [[1] [2 3]] y x]))) + (assert (= (list ((fn [] (for [x [[1] [2 3]] y x z (range 5)] (yield z))))) + (list-comp z [x [[1] [2 3]] y x z (range 5)])))) + + +(defn test-nasty-for-nesting [] + "NATIVE: test nesting for loops harder" + ;; This test and feature is dedicated to @nedbat. + + ;; OK. This next test will ensure that we call the else branch exactly + ;; once. + (setv flag 0) + (for [x (range 2) + y (range 2)] + (+ 1 1) + (else (setv flag (+ flag 2)))) + (assert (= flag 2))) + + +(defn test-empty-for [] + + (setv l []) + (defn f [] + (for [x (range 3)] + (.append l "a") + (yield x))) + (for [x (f)]) + (assert (= l ["a" "a" "a"])) + + (setv l []) + (for [x (f)] + (else (.append l "z"))) + (assert (= l ["a" "a" "a" "z"]))) diff --git a/tests/native_tests/contrib/walk.hy b/tests/native_tests/contrib/walk.hy index e2d7a2a..7facd59 100644 --- a/tests/native_tests/contrib/walk.hy +++ b/tests/native_tests/contrib/walk.hy @@ -46,7 +46,7 @@ (assert (= (macroexpand-all '(with [a 1])) '(with* [a 1] (do)))) (assert (= (macroexpand-all '(with [a 1 b 2 c 3] (for [d c] foo))) - '(with* [a 1] (with* [b 2] (with* [c 3] (do (for* [d c] (do foo)))))))) + '(with* [a 1] (with* [b 2] (with* [c 3] (do (for [d c] foo))))))) (assert (= (macroexpand-all '(with [a 1] '(with [b 2]) `(with [c 3] diff --git a/tests/native_tests/language.hy b/tests/native_tests/language.hy index 1aa6bef..8dd978b 100644 --- a/tests/native_tests/language.hy +++ b/tests/native_tests/language.hy @@ -176,110 +176,6 @@ (with [(pytest.raises TypeError)] ("when" 1 2))) ; A macro -(defn test-for-loop [] - "NATIVE: test for loops" - (setv count1 0 count2 0) - (for [x [1 2 3 4 5]] - (setv count1 (+ count1 x)) - (setv count2 (+ count2 x))) - (assert (= count1 15)) - (assert (= count2 15)) - (setv count 0) - (for [x [1 2 3 4 5] - y [1 2 3 4 5]] - (setv count (+ count x y)) - (else - (+= count 1))) - (assert (= count 151)) - - (setv count 0) - ; multiple statements in the else branch should work - (for [x [1 2 3 4 5] - y [1 2 3 4 5]] - (setv count (+ count x y)) - (else - (+= count 1) - (+= count 10))) - (assert (= count 161)) - - ; don't be fooled by constructs that look like else - (setv s "") - (setv else True) - (for [x "abcde"] - (+= s x) - [else (+= s "_")]) - (assert (= s "a_b_c_d_e_")) - - (setv s "") - (setv else True) - (with [(pytest.raises TypeError)] - (for [x "abcde"] - (+= s x) - ("else" (+= s "z")))) - (assert (= s "az")) - - (assert (= (list ((fn [] (for [x [[1] [2 3]] y x] (yield y))))) - (list-comp y [x [[1] [2 3]] y x]))) - (assert (= (list ((fn [] (for [x [[1] [2 3]] y x z (range 5)] (yield z))))) - (list-comp z [x [[1] [2 3]] y x z (range 5)]))) - - (setv l []) - (defn f [] - (for [x [4 9 2]] - (.append l (* 10 x)) - (yield x))) - (for [_ (f)]) - (assert (= l [40 90 20]))) - - -(defn test-nasty-for-nesting [] - "NATIVE: test nesting for loops harder" - ;; This test and feature is dedicated to @nedbat. - - ;; let's ensure empty iterating is an implicit do - (setv t 0) - (for [] (setv t 1)) - (assert (= t 1)) - - ;; OK. This first test will ensure that the else is hooked up to the - ;; for when we break out of it. - (for [x (range 2) - y (range 2)] - (break) - (else (raise Exception))) - - ;; OK. This next test will ensure that the else is hooked up to the - ;; "inner" iteration - (for [x (range 2) - y (range 2)] - (if (= y 1) (break)) - (else (raise Exception))) - - ;; OK. This next test will ensure that the else is hooked up to the - ;; "outer" iteration - (for [x (range 2) - y (range 2)] - (if (= x 1) (break)) - (else (raise Exception))) - - ;; OK. This next test will ensure that we call the else branch exactly - ;; once. - (setv flag 0) - (for [x (range 2) - y (range 2)] - (+ 1 1) - (else (setv flag (+ flag 2)))) - (assert (= flag 2)) - - (setv l []) - (defn f [] - (for [x [4 9 2]] - (.append l (* 10 x)) - (yield x))) - (for [_ (f)]) - (assert (= l [40 90 20]))) - - (defn test-while-loop [] "NATIVE: test while loops?" (setv count 5)