Implement lfor, sfor, gfor, dfor

This commit is contained in:
Kodi Arfer 2018-06-12 10:43:17 -07:00
parent 7a40561db8
commit ba1dc55e96
2 changed files with 217 additions and 1 deletions

View File

@ -7,7 +7,7 @@ from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex,
HyString, HyBytes, HySymbol, HyFloat, HyList, HySet,
HyDict, HySequence, wrap_value)
from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr,
dolike, pexpr, times)
dolike, pexpr, times, Tag, tag)
from funcparserlib.parser import some, many, oneplus, maybe, NoParseError
from hy.errors import HyCompileError, HyTypeError
@ -1055,6 +1055,116 @@ class HyASTCompiler(object):
value=value.force_expr,
generators=gen)
_loopers = many(
tag('setv', sym(":setv") + FORM + FORM) |
tag('if', sym(":if") + FORM) |
tag('do', sym(":do") + FORM) |
tag('afor', sym(":async") + FORM + FORM) |
tag('for', FORM + FORM))
@special(["lfor", "sfor", "gfor"], [_loopers, FORM])
@special(["dfor"], [_loopers, brackets(FORM, FORM)])
def compile_new_comp(self, expr, root, parts, final):
node_class = dict(
lfor=asty.ListComp,
dfor=asty.DictComp,
sfor=asty.SetComp,
gfor=asty.GeneratorExp)[str(root)]
# Compile the final value (and for dictionary comprehensions, the final
# key).
if node_class is asty.DictComp:
key, elt = map(self.compile, final)
else:
key = None
elt = self.compile(final)
# Compile the parts.
parts = [
Tag(p.tag, self.compile(p.value) if p.tag in ["if", "do"] else [
self._storeize(p.value[0], self.compile(p.value[0])),
self.compile(p.value[1])])
for p in parts]
# Produce a result.
if (elt.stmts or (key is not None and key.stmts) or
any(p.tag == 'do' or (p.value[1].stmts if p.tag in ("for", "afor", "setv") else p.value.stmts)
for p in parts)):
# The desired comprehension can't be expressed as a
# real Python comprehension. We'll write it as a nested
# loop in a function instead.
def f(parts):
# This function is called recursively to construct
# the nested loop.
if not parts:
if node_class is asty.DictComp:
ret = key + elt
val = asty.Tuple(
key, ctx=ast.Load(),
elts=[key.force_expr, elt.force_expr])
else:
ret = elt
val = elt.force_expr
return ret + asty.Expr(
elt, value=asty.Yield(elt, value=val))
(tagname, v), parts = parts[0], parts[1:]
if tagname in ("for", "afor"):
node = asty.AsyncFor if tagname == "afor" else asty.For
return v[1] + node(
v[1], target=v[0], iter=v[1].force_expr, body=f(parts).stmts,
orelse=[])
elif tagname == "setv":
return v[1] + asty.Assign(
v[1], targets=[v[0]], value=v[1].force_expr) + f(parts)
elif tagname == "if":
return v + asty.If(
v, test=v.force_expr, body=f(parts).stmts, orelse=[])
elif tagname == "do":
return v + v.expr_as_stmt() + f(parts)
else:
raise ValueError("can't happen")
fname = self.get_anon_var()
# Define the generator function.
ret = Result() + asty.FunctionDef(
expr,
name=fname,
args=ast.arguments(
args=[], vararg=None, kwarg=None,
kwonlyargs=[], kw_defaults=[], defaults=[]),
body=f(parts).stmts,
decorator_list=[])
# Immediately call the new function. Unless the user asked
# for a generator, wrap the call in `[].__class__(...)` or
# `{}.__class__(...)` or `{1}.__class__(...)` to get the
# right type. We don't want to just use e.g. `list(...)`
# because the name `list` might be rebound.
return ret + Result(expr=ast.parse(
"{}({}())".format(
{asty.ListComp: "[].__class__",
asty.DictComp: "{}.__class__",
asty.SetComp: "{1}.__class__",
asty.GeneratorExp: ""}[node_class],
fname)).body[0].value)
# We can produce a real comprehension.
generators = []
for tagname, v in parts:
if tagname in ("for", "afor"):
generators.append(ast.comprehension(
target=v[0], iter=v[1].expr, ifs=[],
is_async=int(tagname == "afor")))
elif tagname == "setv":
generators.append(ast.comprehension(
target=v[0],
iter=asty.Tuple(v[1], elts=[v[1].expr], ctx=ast.Load()),
ifs=[], is_async=0))
elif tagname == "if":
generators[-1].ifs.append(v.expr)
else:
raise ValueError("can't happen")
if node_class is asty.DictComp:
return asty.DictComp(expr, key=key.expr, value=elt.expr, generators=generators)
return node_class(expr, elt=elt.expr, generators=generators)
@special(["not", "~"], [FORM])
def compile_unary_operator(self, expr, root, arg):
ops = {"not": ast.Not,

View File

@ -0,0 +1,106 @@
(import
types
pytest)
(defn test-comprehension-types []
; Forms that get compiled to real comprehensions
(assert (is (type (lfor x "abc" x)) list))
(assert (is (type (sfor x "abc" x)) set))
(assert (is (type (dfor x "abc" [x x])) dict))
(assert (is (type (gfor x "abc" x)) types.GeneratorType))
; Forms that get compiled to loops
(assert (is (type (lfor x "abc" :do (setv y 1) x)) list))
(assert (is (type (sfor x "abc" :do (setv y 1) x)) set))
(assert (is (type (dfor x "abc" :do (setv y 1) [x x])) dict))
(assert (is (type (gfor x "abc" :do (setv y 1) x)) types.GeneratorType)))
#@ ((pytest.mark.parametrize "specialop" ["lfor" "sfor" "gfor" "dfor"])
(defn test-comprehensions [specialop]
(setv cases [
['(f x [] x)
[]]
['(f j [1 2 3] j)
[1 2 3]]
['(f x (range 3) (* x 2))
[0 2 4]]
['(f x (range 2) y (range 2) (, x y))
[(, 0 0) (, 0 1) (, 1 0) (, 1 1)]]
['(f (, x y) (.items {"1" 1 "2" 2}) (* y 2))
[2 4]]
['(f x (do (setv s "x") "ab") y (do (+= s "y") "def") (+ x y s))
["adxy" "aexy" "afxy" "bdxyy" "bexyy" "bfxyy"]]
['(f x (range 4) :if (% x 2) (* x 2))
[2 6]]
['(f x "abc" :setv y (.upper x) (+ x y))
["aA" "bB" "cC"]]
['(f x "abc" :do (setv y (.upper x)) (+ x y))
["aA" "bB" "cC"]]
['(f
x (range 3)
y (range 3)
:if (> y x)
z [7 8 9]
:setv s (+ x y z)
:if (!= z 8)
(, x y z s))
[(, 0 1 7 8) (, 0 1 9 10) (, 0 2 7 9) (, 0 2 9 11)
(, 1 2 7 10) (, 1 2 9 12)]]
['(f
x [0 1]
:setv l []
y (range 4)
:do (.append l (, x y))
:if (>= y 2)
z [7 8 9]
:if (!= z 8)
(, x y (tuple l) z))
[(, 0 2 (, (, 0 0) (, 0 1) (, 0 2)) 7)
(, 0 2 (, (, 0 0) (, 0 1) (, 0 2)) 9)
(, 0 3 (, (, 0 0) (, 0 1) (, 0 2) (, 0 3)) 7)
(, 0 3 (, (, 0 0) (, 0 1) (, 0 2) (, 0 3)) 9)
(, 1 2 (, (, 1 0) (, 1 1) (, 1 2)) 7)
(, 1 2 (, (, 1 0) (, 1 1) (, 1 2)) 9)
(, 1 3 (, (, 1 0) (, 1 1) (, 1 2) (, 1 3)) 7)
(, 1 3 (, (, 1 0) (, 1 1) (, 1 2) (, 1 3)) 9)]]
['(f x (range 4) :do (unless (% x 2) (continue)) (* x 2))
[2 6]]
['(f x (range 4) :setv p 9 :do (unless (% x 2) (continue)) (* x 2))
[2 6]]
['(f x (range 20) :do (when (= x 3) (break)) (* x 2))
[0 2 4]]
['(f x (range 20) :setv p 9 :do (when (= x 3) (break)) (* x 2))
[0 2 4]]
['(f x [4 5] y (range 20) :do (when (> y 1) (break)) z [8 9] (, x y z))
[(, 4 0 8) (, 4 0 9) (, 4 1 8) (, 4 1 9)
(, 5 0 8) (, 5 0 9) (, 5 1 8) (, 5 1 9)]]])
(for [[expr answer] cases]
; Mutate the case as appropriate for the operator before
; evaluating it.
(setv expr (+ (HyExpression [(HySymbol specialop)]) (cut expr 1)))
(when (= specialop "dfor")
(setv expr (+ (cut expr 0 -1) `([~(get expr -1) 1]))))
(setv result (eval expr))
(when (= specialop "dfor")
(setv result (.keys result)))
(assert (= (sorted result) answer) (str expr)))))
(defn test-raise-in-comp []
(defclass E [Exception] [])
(setv l [])
(import pytest)
(with [(pytest.raises E)]
(lfor
x (range 10)
:do (.append l x)
:do (when (= x 5)
(raise (E)))
x))
(assert (= l [0 1 2 3 4 5])))