Add a version of for parallel to lfor etc.

This commit is contained in:
Kodi Arfer 2018-06-12 10:54:08 -07:00
parent ba1dc55e96
commit 3256932b13
6 changed files with 129 additions and 131 deletions

View File

@ -1061,24 +1061,39 @@ class HyASTCompiler(object):
tag('do', sym(":do") + FORM) | tag('do', sym(":do") + FORM) |
tag('afor', sym(":async") + FORM + FORM) | tag('afor', sym(":async") + FORM + FORM) |
tag('for', FORM + FORM)) tag('for', FORM + FORM))
@special(["for"], [brackets(_loopers),
many(notpexpr("else")) + maybe(dolike("else"))])
@special(["lfor", "sfor", "gfor"], [_loopers, FORM]) @special(["lfor", "sfor", "gfor"], [_loopers, FORM])
@special(["dfor"], [_loopers, brackets(FORM, FORM)]) @special(["dfor"], [_loopers, brackets(FORM, FORM)])
def compile_new_comp(self, expr, root, parts, final): def compile_new_comp(self, expr, root, parts, final):
node_class = dict( root = unmangle(ast_str(root))
lfor=asty.ListComp, node_class = {
dfor=asty.DictComp, "for": asty.For,
sfor=asty.SetComp, "lfor": asty.ListComp,
gfor=asty.GeneratorExp)[str(root)] "dfor": asty.DictComp,
"sfor": asty.SetComp,
"gfor": asty.GeneratorExp}[root]
is_for = root == "for"
# Compile the final value (and for dictionary comprehensions, the final orel = []
# key). if is_for:
if node_class is asty.DictComp: # Get the `else`.
key, elt = map(self.compile, final) body, else_expr = final
if else_expr is not None:
orel.append(self._compile_branch(else_expr))
orel[0] += orel[0].expr_as_stmt()
else: else:
key = None # Get the final value (and for dictionary
elt = self.compile(final) # comprehensions, the final key).
if node_class is asty.DictComp:
key, elt = map(self.compile, final)
else:
key = None
elt = self.compile(final)
# Compile the parts. # Compile the parts.
if is_for:
parts = parts[0]
parts = [ parts = [
Tag(p.tag, self.compile(p.value) if p.tag in ["if", "do"] else [ Tag(p.tag, self.compile(p.value) if p.tag in ["if", "do"] else [
self._storeize(p.value[0], self.compile(p.value[0])), self._storeize(p.value[0], self.compile(p.value[0])),
@ -1086,16 +1101,24 @@ class HyASTCompiler(object):
for p in parts] for p in parts]
# Produce a result. # Produce a result.
if (elt.stmts or (key is not None and key.stmts) or if (is_for or elt.stmts or (key is not None and key.stmts) or
any(p.tag == 'do' or (p.value[1].stmts if p.tag in ("for", "afor", "setv") else p.value.stmts) any(p.tag == 'do' or (p.value[1].stmts if p.tag in ("for", "afor", "setv") else p.value.stmts)
for p in parts)): for p in parts)):
# The desired comprehension can't be expressed as a # The desired comprehension can't be expressed as a
# real Python comprehension. We'll write it as a nested # real Python comprehension. We'll write it as a nested
# loop in a function instead. # loop in a function instead.
contains_yield = []
def f(parts): def f(parts):
# This function is called recursively to construct # This function is called recursively to construct
# the nested loop. # the nested loop.
if not parts: if not parts:
if is_for:
if body:
bd = self._compile_branch(body)
if bd.contains_yield:
contains_yield.append(True)
return bd + bd.expr_as_stmt()
return Result(stmts=[asty.Pass(expr)])
if node_class is asty.DictComp: if node_class is asty.DictComp:
ret = key + elt ret = key + elt
val = asty.Tuple( val = asty.Tuple(
@ -1108,10 +1131,11 @@ class HyASTCompiler(object):
elt, value=asty.Yield(elt, value=val)) elt, value=asty.Yield(elt, value=val))
(tagname, v), parts = parts[0], parts[1:] (tagname, v), parts = parts[0], parts[1:]
if tagname in ("for", "afor"): if tagname in ("for", "afor"):
orelse = orel and orel.pop().stmts
node = asty.AsyncFor if tagname == "afor" else asty.For node = asty.AsyncFor if tagname == "afor" else asty.For
return v[1] + node( return v[1] + node(
v[1], target=v[0], iter=v[1].force_expr, body=f(parts).stmts, v[1], target=v[0], iter=v[1].force_expr, body=f(parts).stmts,
orelse=[]) orelse=orelse)
elif tagname == "setv": elif tagname == "setv":
return v[1] + asty.Assign( return v[1] + asty.Assign(
v[1], targets=[v[0]], value=v[1].force_expr) + f(parts) v[1], targets=[v[0]], value=v[1].force_expr) + f(parts)
@ -1122,6 +1146,10 @@ class HyASTCompiler(object):
return v + v.expr_as_stmt() + f(parts) return v + v.expr_as_stmt() + f(parts)
else: else:
raise ValueError("can't happen") raise ValueError("can't happen")
if is_for:
ret = f(parts)
ret.contains_yield = bool(contains_yield)
return ret
fname = self.get_anon_var() fname = self.get_anon_var()
# Define the generator function. # Define the generator function.
ret = Result() + asty.FunctionDef( ret = Result() + asty.FunctionDef(

View File

@ -122,14 +122,6 @@ used as the result."
`(~node [(, ~@alist) (genexpr (, ~@alist) [~@args])] (do ~@body) ~@belse)))) `(~node [(, ~@alist) (genexpr (, ~@alist) [~@args])] (do ~@body) ~@belse))))
(defmacro for [args &rest body]
"Build a for-loop with `args` as a [element coll] bracket pair and run `body`.
Args may contain multiple pairs, in which case it executes a nested for-loop
in order of the given pairs."
(_for 'for* args body))
(defmacro for/a [args &rest body] (defmacro for/a [args &rest body]
"Build a for/a-loop with `args` as a [element coll] bracket pair and run `body`. "Build a for/a-loop with `args` as a [element coll] bracket pair and run `body`.
@ -163,6 +155,7 @@ the second form, the second result is inserted into the third form, and so on."
~@(map build-form expressions) ~@(map build-form expressions)
~f)) ~f))
(defmacro ->> [head &rest args] (defmacro ->> [head &rest args]
"Thread `head` last through the `rest` of the forms. "Thread `head` last through the `rest` of the forms.

View File

@ -531,9 +531,7 @@ def test_for_compile_error():
can_compile("(fn [] (for)))") can_compile("(fn [] (for)))")
assert excinfo.value.message == "Ran into a RPAREN where it wasn't expected." assert excinfo.value.message == "Ran into a RPAREN where it wasn't expected."
with pytest.raises(HyTypeError) as excinfo: cant_compile("(fn [] (for [x] x))")
can_compile("(fn [] (for [x] x))")
assert excinfo.value.message == "`for' requires an even number of args."
def test_attribute_access(): def test_attribute_access():

View File

@ -18,8 +18,8 @@
(assert (is (type (gfor x "abc" :do (setv y 1) x)) types.GeneratorType))) (assert (is (type (gfor x "abc" :do (setv y 1) x)) types.GeneratorType)))
#@ ((pytest.mark.parametrize "specialop" ["lfor" "sfor" "gfor" "dfor"]) #@ ((pytest.mark.parametrize "specialop" ["for" "lfor" "sfor" "gfor" "dfor"])
(defn test-comprehensions [specialop] (defn test-fors [specialop]
(setv cases [ (setv cases [
['(f x [] x) ['(f x [] x)
@ -86,6 +86,12 @@
(setv expr (+ (HyExpression [(HySymbol specialop)]) (cut expr 1))) (setv expr (+ (HyExpression [(HySymbol specialop)]) (cut expr 1)))
(when (= specialop "dfor") (when (= specialop "dfor")
(setv expr (+ (cut expr 0 -1) `([~(get expr -1) 1])))) (setv expr (+ (cut expr 0 -1) `([~(get expr -1) 1]))))
(when (= specialop "for")
(setv expr `(do
(setv out [])
(for [~@(cut expr 1 -1)]
(.append out ~(get expr -1)))
out)))
(setv result (eval expr)) (setv result (eval expr))
(when (= specialop "dfor") (when (= specialop "dfor")
(setv result (.keys result))) (setv result (.keys result)))
@ -104,3 +110,80 @@
(raise (E))) (raise (E)))
x)) x))
(assert (= l [0 1 2 3 4 5]))) (assert (= l [0 1 2 3 4 5])))
(defn test-for-loop []
"NATIVE: test for loops"
(setv count1 0 count2 0)
(for [x [1 2 3 4 5]]
(setv count1 (+ count1 x))
(setv count2 (+ count2 x)))
(assert (= count1 15))
(assert (= count2 15))
(setv count 0)
(for [x [1 2 3 4 5]
y [1 2 3 4 5]]
(setv count (+ count x y))
(else
(+= count 1)))
(assert (= count 151))
(setv count 0)
; multiple statements in the else branch should work
(for [x [1 2 3 4 5]
y [1 2 3 4 5]]
(setv count (+ count x y))
(else
(+= count 1)
(+= count 10)))
(assert (= count 161))
; don't be fooled by constructs that look like else
(setv s "")
(setv else True)
(for [x "abcde"]
(+= s x)
[else (+= s "_")])
(assert (= s "a_b_c_d_e_"))
(setv s "")
(with [(pytest.raises TypeError)]
(for [x "abcde"]
(+= s x)
("else" (+= s "z"))))
(assert (= s "az"))
(assert (= (list ((fn [] (for [x [[1] [2 3]] y x] (yield y)))))
(list-comp y [x [[1] [2 3]] y x])))
(assert (= (list ((fn [] (for [x [[1] [2 3]] y x z (range 5)] (yield z)))))
(list-comp z [x [[1] [2 3]] y x z (range 5)]))))
(defn test-nasty-for-nesting []
"NATIVE: test nesting for loops harder"
;; This test and feature is dedicated to @nedbat.
;; OK. This next test will ensure that we call the else branch exactly
;; once.
(setv flag 0)
(for [x (range 2)
y (range 2)]
(+ 1 1)
(else (setv flag (+ flag 2))))
(assert (= flag 2)))
(defn test-empty-for []
(setv l [])
(defn f []
(for [x (range 3)]
(.append l "a")
(yield x)))
(for [x (f)])
(assert (= l ["a" "a" "a"]))
(setv l [])
(for [x (f)]
(else (.append l "z")))
(assert (= l ["a" "a" "a" "z"])))

View File

@ -46,7 +46,7 @@
(assert (= (macroexpand-all '(with [a 1])) (assert (= (macroexpand-all '(with [a 1]))
'(with* [a 1] (do)))) '(with* [a 1] (do))))
(assert (= (macroexpand-all '(with [a 1 b 2 c 3] (for [d c] foo))) (assert (= (macroexpand-all '(with [a 1 b 2 c 3] (for [d c] foo)))
'(with* [a 1] (with* [b 2] (with* [c 3] (do (for* [d c] (do foo)))))))) '(with* [a 1] (with* [b 2] (with* [c 3] (do (for [d c] foo)))))))
(assert (= (macroexpand-all '(with [a 1] (assert (= (macroexpand-all '(with [a 1]
'(with [b 2]) '(with [b 2])
`(with [c 3] `(with [c 3]

View File

@ -176,110 +176,6 @@
(with [(pytest.raises TypeError)] ("when" 1 2))) ; A macro (with [(pytest.raises TypeError)] ("when" 1 2))) ; A macro
(defn test-for-loop []
"NATIVE: test for loops"
(setv count1 0 count2 0)
(for [x [1 2 3 4 5]]
(setv count1 (+ count1 x))
(setv count2 (+ count2 x)))
(assert (= count1 15))
(assert (= count2 15))
(setv count 0)
(for [x [1 2 3 4 5]
y [1 2 3 4 5]]
(setv count (+ count x y))
(else
(+= count 1)))
(assert (= count 151))
(setv count 0)
; multiple statements in the else branch should work
(for [x [1 2 3 4 5]
y [1 2 3 4 5]]
(setv count (+ count x y))
(else
(+= count 1)
(+= count 10)))
(assert (= count 161))
; don't be fooled by constructs that look like else
(setv s "")
(setv else True)
(for [x "abcde"]
(+= s x)
[else (+= s "_")])
(assert (= s "a_b_c_d_e_"))
(setv s "")
(setv else True)
(with [(pytest.raises TypeError)]
(for [x "abcde"]
(+= s x)
("else" (+= s "z"))))
(assert (= s "az"))
(assert (= (list ((fn [] (for [x [[1] [2 3]] y x] (yield y)))))
(list-comp y [x [[1] [2 3]] y x])))
(assert (= (list ((fn [] (for [x [[1] [2 3]] y x z (range 5)] (yield z)))))
(list-comp z [x [[1] [2 3]] y x z (range 5)])))
(setv l [])
(defn f []
(for [x [4 9 2]]
(.append l (* 10 x))
(yield x)))
(for [_ (f)])
(assert (= l [40 90 20])))
(defn test-nasty-for-nesting []
"NATIVE: test nesting for loops harder"
;; This test and feature is dedicated to @nedbat.
;; let's ensure empty iterating is an implicit do
(setv t 0)
(for [] (setv t 1))
(assert (= t 1))
;; OK. This first test will ensure that the else is hooked up to the
;; for when we break out of it.
(for [x (range 2)
y (range 2)]
(break)
(else (raise Exception)))
;; OK. This next test will ensure that the else is hooked up to the
;; "inner" iteration
(for [x (range 2)
y (range 2)]
(if (= y 1) (break))
(else (raise Exception)))
;; OK. This next test will ensure that the else is hooked up to the
;; "outer" iteration
(for [x (range 2)
y (range 2)]
(if (= x 1) (break))
(else (raise Exception)))
;; OK. This next test will ensure that we call the else branch exactly
;; once.
(setv flag 0)
(for [x (range 2)
y (range 2)]
(+ 1 1)
(else (setv flag (+ flag 2))))
(assert (= flag 2))
(setv l [])
(defn f []
(for [x [4 9 2]]
(.append l (* 10 x))
(yield x)))
(for [_ (f)])
(assert (= l [40 90 20])))
(defn test-while-loop [] (defn test-while-loop []
"NATIVE: test while loops?" "NATIVE: test while loops?"
(setv count 5) (setv count 5)