diff --git a/hy/compiler.py b/hy/compiler.py index 348f80f..82be63b 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -7,7 +7,7 @@ from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex, HyString, HyBytes, HySymbol, HyFloat, HyList, HySet, HyDict, HySequence, wrap_value) from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr, - dolike, pexpr, times) + dolike, pexpr, times, Tag, tag) from funcparserlib.parser import some, many, oneplus, maybe, NoParseError from hy.errors import HyCompileError, HyTypeError @@ -1055,6 +1055,116 @@ class HyASTCompiler(object): value=value.force_expr, generators=gen) + _loopers = many( + tag('setv', sym(":setv") + FORM + FORM) | + tag('if', sym(":if") + FORM) | + tag('do', sym(":do") + FORM) | + tag('afor', sym(":async") + FORM + FORM) | + tag('for', FORM + FORM)) + @special(["lfor", "sfor", "gfor"], [_loopers, FORM]) + @special(["dfor"], [_loopers, brackets(FORM, FORM)]) + def compile_new_comp(self, expr, root, parts, final): + node_class = dict( + lfor=asty.ListComp, + dfor=asty.DictComp, + sfor=asty.SetComp, + gfor=asty.GeneratorExp)[str(root)] + + # Compile the final value (and for dictionary comprehensions, the final + # key). + if node_class is asty.DictComp: + key, elt = map(self.compile, final) + else: + key = None + elt = self.compile(final) + + # Compile the parts. + parts = [ + Tag(p.tag, self.compile(p.value) if p.tag in ["if", "do"] else [ + self._storeize(p.value[0], self.compile(p.value[0])), + self.compile(p.value[1])]) + for p in parts] + + # Produce a result. + if (elt.stmts or (key is not None and key.stmts) or + any(p.tag == 'do' or (p.value[1].stmts if p.tag in ("for", "afor", "setv") else p.value.stmts) + for p in parts)): + # The desired comprehension can't be expressed as a + # real Python comprehension. We'll write it as a nested + # loop in a function instead. + def f(parts): + # This function is called recursively to construct + # the nested loop. + if not parts: + if node_class is asty.DictComp: + ret = key + elt + val = asty.Tuple( + key, ctx=ast.Load(), + elts=[key.force_expr, elt.force_expr]) + else: + ret = elt + val = elt.force_expr + return ret + asty.Expr( + elt, value=asty.Yield(elt, value=val)) + (tagname, v), parts = parts[0], parts[1:] + if tagname in ("for", "afor"): + node = asty.AsyncFor if tagname == "afor" else asty.For + return v[1] + node( + v[1], target=v[0], iter=v[1].force_expr, body=f(parts).stmts, + orelse=[]) + elif tagname == "setv": + return v[1] + asty.Assign( + v[1], targets=[v[0]], value=v[1].force_expr) + f(parts) + elif tagname == "if": + return v + asty.If( + v, test=v.force_expr, body=f(parts).stmts, orelse=[]) + elif tagname == "do": + return v + v.expr_as_stmt() + f(parts) + else: + raise ValueError("can't happen") + fname = self.get_anon_var() + # Define the generator function. + ret = Result() + asty.FunctionDef( + expr, + name=fname, + args=ast.arguments( + args=[], vararg=None, kwarg=None, + kwonlyargs=[], kw_defaults=[], defaults=[]), + body=f(parts).stmts, + decorator_list=[]) + # Immediately call the new function. Unless the user asked + # for a generator, wrap the call in `[].__class__(...)` or + # `{}.__class__(...)` or `{1}.__class__(...)` to get the + # right type. We don't want to just use e.g. `list(...)` + # because the name `list` might be rebound. + return ret + Result(expr=ast.parse( + "{}({}())".format( + {asty.ListComp: "[].__class__", + asty.DictComp: "{}.__class__", + asty.SetComp: "{1}.__class__", + asty.GeneratorExp: ""}[node_class], + fname)).body[0].value) + + # We can produce a real comprehension. + generators = [] + for tagname, v in parts: + if tagname in ("for", "afor"): + generators.append(ast.comprehension( + target=v[0], iter=v[1].expr, ifs=[], + is_async=int(tagname == "afor"))) + elif tagname == "setv": + generators.append(ast.comprehension( + target=v[0], + iter=asty.Tuple(v[1], elts=[v[1].expr], ctx=ast.Load()), + ifs=[], is_async=0)) + elif tagname == "if": + generators[-1].ifs.append(v.expr) + else: + raise ValueError("can't happen") + if node_class is asty.DictComp: + return asty.DictComp(expr, key=key.expr, value=elt.expr, generators=generators) + return node_class(expr, elt=elt.expr, generators=generators) + @special(["not", "~"], [FORM]) def compile_unary_operator(self, expr, root, arg): ops = {"not": ast.Not, diff --git a/tests/native_tests/comprehensions.hy b/tests/native_tests/comprehensions.hy new file mode 100644 index 0000000..9c1a264 --- /dev/null +++ b/tests/native_tests/comprehensions.hy @@ -0,0 +1,106 @@ +(import + types + pytest) + + +(defn test-comprehension-types [] + + ; Forms that get compiled to real comprehensions + (assert (is (type (lfor x "abc" x)) list)) + (assert (is (type (sfor x "abc" x)) set)) + (assert (is (type (dfor x "abc" [x x])) dict)) + (assert (is (type (gfor x "abc" x)) types.GeneratorType)) + + ; Forms that get compiled to loops + (assert (is (type (lfor x "abc" :do (setv y 1) x)) list)) + (assert (is (type (sfor x "abc" :do (setv y 1) x)) set)) + (assert (is (type (dfor x "abc" :do (setv y 1) [x x])) dict)) + (assert (is (type (gfor x "abc" :do (setv y 1) x)) types.GeneratorType))) + + +#@ ((pytest.mark.parametrize "specialop" ["lfor" "sfor" "gfor" "dfor"]) +(defn test-comprehensions [specialop] + + (setv cases [ + ['(f x [] x) + []] + ['(f j [1 2 3] j) + [1 2 3]] + ['(f x (range 3) (* x 2)) + [0 2 4]] + ['(f x (range 2) y (range 2) (, x y)) + [(, 0 0) (, 0 1) (, 1 0) (, 1 1)]] + ['(f (, x y) (.items {"1" 1 "2" 2}) (* y 2)) + [2 4]] + ['(f x (do (setv s "x") "ab") y (do (+= s "y") "def") (+ x y s)) + ["adxy" "aexy" "afxy" "bdxyy" "bexyy" "bfxyy"]] + ['(f x (range 4) :if (% x 2) (* x 2)) + [2 6]] + ['(f x "abc" :setv y (.upper x) (+ x y)) + ["aA" "bB" "cC"]] + ['(f x "abc" :do (setv y (.upper x)) (+ x y)) + ["aA" "bB" "cC"]] + ['(f + x (range 3) + y (range 3) + :if (> y x) + z [7 8 9] + :setv s (+ x y z) + :if (!= z 8) + (, x y z s)) + [(, 0 1 7 8) (, 0 1 9 10) (, 0 2 7 9) (, 0 2 9 11) + (, 1 2 7 10) (, 1 2 9 12)]] + ['(f + x [0 1] + :setv l [] + y (range 4) + :do (.append l (, x y)) + :if (>= y 2) + z [7 8 9] + :if (!= z 8) + (, x y (tuple l) z)) + [(, 0 2 (, (, 0 0) (, 0 1) (, 0 2)) 7) + (, 0 2 (, (, 0 0) (, 0 1) (, 0 2)) 9) + (, 0 3 (, (, 0 0) (, 0 1) (, 0 2) (, 0 3)) 7) + (, 0 3 (, (, 0 0) (, 0 1) (, 0 2) (, 0 3)) 9) + (, 1 2 (, (, 1 0) (, 1 1) (, 1 2)) 7) + (, 1 2 (, (, 1 0) (, 1 1) (, 1 2)) 9) + (, 1 3 (, (, 1 0) (, 1 1) (, 1 2) (, 1 3)) 7) + (, 1 3 (, (, 1 0) (, 1 1) (, 1 2) (, 1 3)) 9)]] + + ['(f x (range 4) :do (unless (% x 2) (continue)) (* x 2)) + [2 6]] + ['(f x (range 4) :setv p 9 :do (unless (% x 2) (continue)) (* x 2)) + [2 6]] + ['(f x (range 20) :do (when (= x 3) (break)) (* x 2)) + [0 2 4]] + ['(f x (range 20) :setv p 9 :do (when (= x 3) (break)) (* x 2)) + [0 2 4]] + ['(f x [4 5] y (range 20) :do (when (> y 1) (break)) z [8 9] (, x y z)) + [(, 4 0 8) (, 4 0 9) (, 4 1 8) (, 4 1 9) + (, 5 0 8) (, 5 0 9) (, 5 1 8) (, 5 1 9)]]]) + + (for [[expr answer] cases] + ; Mutate the case as appropriate for the operator before + ; evaluating it. + (setv expr (+ (HyExpression [(HySymbol specialop)]) (cut expr 1))) + (when (= specialop "dfor") + (setv expr (+ (cut expr 0 -1) `([~(get expr -1) 1])))) + (setv result (eval expr)) + (when (= specialop "dfor") + (setv result (.keys result))) + (assert (= (sorted result) answer) (str expr))))) + + +(defn test-raise-in-comp [] + (defclass E [Exception] []) + (setv l []) + (import pytest) + (with [(pytest.raises E)] + (lfor + x (range 10) + :do (.append l x) + :do (when (= x 5) + (raise (E))) + x)) + (assert (= l [0 1 2 3 4 5])))