Implement lfor, sfor, gfor, dfor

This commit is contained in:
Kodi Arfer 2018-06-12 10:43:17 -07:00
parent 7a40561db8
commit ba1dc55e96
2 changed files with 217 additions and 1 deletions

View File

@ -7,7 +7,7 @@ from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex,
HyString, HyBytes, HySymbol, HyFloat, HyList, HySet, HyString, HyBytes, HySymbol, HyFloat, HyList, HySet,
HyDict, HySequence, wrap_value) HyDict, HySequence, wrap_value)
from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr, from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr,
dolike, pexpr, times) dolike, pexpr, times, Tag, tag)
from funcparserlib.parser import some, many, oneplus, maybe, NoParseError from funcparserlib.parser import some, many, oneplus, maybe, NoParseError
from hy.errors import HyCompileError, HyTypeError from hy.errors import HyCompileError, HyTypeError
@ -1055,6 +1055,116 @@ class HyASTCompiler(object):
value=value.force_expr, value=value.force_expr,
generators=gen) generators=gen)
_loopers = many(
tag('setv', sym(":setv") + FORM + FORM) |
tag('if', sym(":if") + FORM) |
tag('do', sym(":do") + FORM) |
tag('afor', sym(":async") + FORM + FORM) |
tag('for', FORM + FORM))
@special(["lfor", "sfor", "gfor"], [_loopers, FORM])
@special(["dfor"], [_loopers, brackets(FORM, FORM)])
def compile_new_comp(self, expr, root, parts, final):
node_class = dict(
lfor=asty.ListComp,
dfor=asty.DictComp,
sfor=asty.SetComp,
gfor=asty.GeneratorExp)[str(root)]
# Compile the final value (and for dictionary comprehensions, the final
# key).
if node_class is asty.DictComp:
key, elt = map(self.compile, final)
else:
key = None
elt = self.compile(final)
# Compile the parts.
parts = [
Tag(p.tag, self.compile(p.value) if p.tag in ["if", "do"] else [
self._storeize(p.value[0], self.compile(p.value[0])),
self.compile(p.value[1])])
for p in parts]
# Produce a result.
if (elt.stmts or (key is not None and key.stmts) or
any(p.tag == 'do' or (p.value[1].stmts if p.tag in ("for", "afor", "setv") else p.value.stmts)
for p in parts)):
# The desired comprehension can't be expressed as a
# real Python comprehension. We'll write it as a nested
# loop in a function instead.
def f(parts):
# This function is called recursively to construct
# the nested loop.
if not parts:
if node_class is asty.DictComp:
ret = key + elt
val = asty.Tuple(
key, ctx=ast.Load(),
elts=[key.force_expr, elt.force_expr])
else:
ret = elt
val = elt.force_expr
return ret + asty.Expr(
elt, value=asty.Yield(elt, value=val))
(tagname, v), parts = parts[0], parts[1:]
if tagname in ("for", "afor"):
node = asty.AsyncFor if tagname == "afor" else asty.For
return v[1] + node(
v[1], target=v[0], iter=v[1].force_expr, body=f(parts).stmts,
orelse=[])
elif tagname == "setv":
return v[1] + asty.Assign(
v[1], targets=[v[0]], value=v[1].force_expr) + f(parts)
elif tagname == "if":
return v + asty.If(
v, test=v.force_expr, body=f(parts).stmts, orelse=[])
elif tagname == "do":
return v + v.expr_as_stmt() + f(parts)
else:
raise ValueError("can't happen")
fname = self.get_anon_var()
# Define the generator function.
ret = Result() + asty.FunctionDef(
expr,
name=fname,
args=ast.arguments(
args=[], vararg=None, kwarg=None,
kwonlyargs=[], kw_defaults=[], defaults=[]),
body=f(parts).stmts,
decorator_list=[])
# Immediately call the new function. Unless the user asked
# for a generator, wrap the call in `[].__class__(...)` or
# `{}.__class__(...)` or `{1}.__class__(...)` to get the
# right type. We don't want to just use e.g. `list(...)`
# because the name `list` might be rebound.
return ret + Result(expr=ast.parse(
"{}({}())".format(
{asty.ListComp: "[].__class__",
asty.DictComp: "{}.__class__",
asty.SetComp: "{1}.__class__",
asty.GeneratorExp: ""}[node_class],
fname)).body[0].value)
# We can produce a real comprehension.
generators = []
for tagname, v in parts:
if tagname in ("for", "afor"):
generators.append(ast.comprehension(
target=v[0], iter=v[1].expr, ifs=[],
is_async=int(tagname == "afor")))
elif tagname == "setv":
generators.append(ast.comprehension(
target=v[0],
iter=asty.Tuple(v[1], elts=[v[1].expr], ctx=ast.Load()),
ifs=[], is_async=0))
elif tagname == "if":
generators[-1].ifs.append(v.expr)
else:
raise ValueError("can't happen")
if node_class is asty.DictComp:
return asty.DictComp(expr, key=key.expr, value=elt.expr, generators=generators)
return node_class(expr, elt=elt.expr, generators=generators)
@special(["not", "~"], [FORM]) @special(["not", "~"], [FORM])
def compile_unary_operator(self, expr, root, arg): def compile_unary_operator(self, expr, root, arg):
ops = {"not": ast.Not, ops = {"not": ast.Not,

View File

@ -0,0 +1,106 @@
(import
types
pytest)
(defn test-comprehension-types []
; Forms that get compiled to real comprehensions
(assert (is (type (lfor x "abc" x)) list))
(assert (is (type (sfor x "abc" x)) set))
(assert (is (type (dfor x "abc" [x x])) dict))
(assert (is (type (gfor x "abc" x)) types.GeneratorType))
; Forms that get compiled to loops
(assert (is (type (lfor x "abc" :do (setv y 1) x)) list))
(assert (is (type (sfor x "abc" :do (setv y 1) x)) set))
(assert (is (type (dfor x "abc" :do (setv y 1) [x x])) dict))
(assert (is (type (gfor x "abc" :do (setv y 1) x)) types.GeneratorType)))
#@ ((pytest.mark.parametrize "specialop" ["lfor" "sfor" "gfor" "dfor"])
(defn test-comprehensions [specialop]
(setv cases [
['(f x [] x)
[]]
['(f j [1 2 3] j)
[1 2 3]]
['(f x (range 3) (* x 2))
[0 2 4]]
['(f x (range 2) y (range 2) (, x y))
[(, 0 0) (, 0 1) (, 1 0) (, 1 1)]]
['(f (, x y) (.items {"1" 1 "2" 2}) (* y 2))
[2 4]]
['(f x (do (setv s "x") "ab") y (do (+= s "y") "def") (+ x y s))
["adxy" "aexy" "afxy" "bdxyy" "bexyy" "bfxyy"]]
['(f x (range 4) :if (% x 2) (* x 2))
[2 6]]
['(f x "abc" :setv y (.upper x) (+ x y))
["aA" "bB" "cC"]]
['(f x "abc" :do (setv y (.upper x)) (+ x y))
["aA" "bB" "cC"]]
['(f
x (range 3)
y (range 3)
:if (> y x)
z [7 8 9]
:setv s (+ x y z)
:if (!= z 8)
(, x y z s))
[(, 0 1 7 8) (, 0 1 9 10) (, 0 2 7 9) (, 0 2 9 11)
(, 1 2 7 10) (, 1 2 9 12)]]
['(f
x [0 1]
:setv l []
y (range 4)
:do (.append l (, x y))
:if (>= y 2)
z [7 8 9]
:if (!= z 8)
(, x y (tuple l) z))
[(, 0 2 (, (, 0 0) (, 0 1) (, 0 2)) 7)
(, 0 2 (, (, 0 0) (, 0 1) (, 0 2)) 9)
(, 0 3 (, (, 0 0) (, 0 1) (, 0 2) (, 0 3)) 7)
(, 0 3 (, (, 0 0) (, 0 1) (, 0 2) (, 0 3)) 9)
(, 1 2 (, (, 1 0) (, 1 1) (, 1 2)) 7)
(, 1 2 (, (, 1 0) (, 1 1) (, 1 2)) 9)
(, 1 3 (, (, 1 0) (, 1 1) (, 1 2) (, 1 3)) 7)
(, 1 3 (, (, 1 0) (, 1 1) (, 1 2) (, 1 3)) 9)]]
['(f x (range 4) :do (unless (% x 2) (continue)) (* x 2))
[2 6]]
['(f x (range 4) :setv p 9 :do (unless (% x 2) (continue)) (* x 2))
[2 6]]
['(f x (range 20) :do (when (= x 3) (break)) (* x 2))
[0 2 4]]
['(f x (range 20) :setv p 9 :do (when (= x 3) (break)) (* x 2))
[0 2 4]]
['(f x [4 5] y (range 20) :do (when (> y 1) (break)) z [8 9] (, x y z))
[(, 4 0 8) (, 4 0 9) (, 4 1 8) (, 4 1 9)
(, 5 0 8) (, 5 0 9) (, 5 1 8) (, 5 1 9)]]])
(for [[expr answer] cases]
; Mutate the case as appropriate for the operator before
; evaluating it.
(setv expr (+ (HyExpression [(HySymbol specialop)]) (cut expr 1)))
(when (= specialop "dfor")
(setv expr (+ (cut expr 0 -1) `([~(get expr -1) 1]))))
(setv result (eval expr))
(when (= specialop "dfor")
(setv result (.keys result)))
(assert (= (sorted result) answer) (str expr)))))
(defn test-raise-in-comp []
(defclass E [Exception] [])
(setv l [])
(import pytest)
(with [(pytest.raises E)]
(lfor
x (range 10)
:do (.append l x)
:do (when (= x 5)
(raise (E)))
x))
(assert (= l [0 1 2 3 4 5])))