Use model patterns for comparison and math ops
This commit is contained in:
parent
57b5fa49b1
commit
41d3f26001
170
hy/compiler.py
170
hy/compiler.py
@ -7,11 +7,11 @@ from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex,
|
|||||||
HyString, HyBytes, HySymbol, HyFloat, HyList, HySet,
|
HyString, HyBytes, HySymbol, HyFloat, HyList, HySet,
|
||||||
HyDict, HySequence, wrap_value)
|
HyDict, HySequence, wrap_value)
|
||||||
from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr,
|
from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr,
|
||||||
dolike, pexpr)
|
dolike, pexpr, times)
|
||||||
from funcparserlib.parser import some, many, oneplus, maybe, NoParseError
|
from funcparserlib.parser import some, many, oneplus, maybe, NoParseError
|
||||||
from hy.errors import HyCompileError, HyTypeError
|
from hy.errors import HyCompileError, HyTypeError
|
||||||
|
|
||||||
from hy.lex.parser import mangle
|
from hy.lex.parser import mangle, unmangle
|
||||||
|
|
||||||
import hy.macros
|
import hy.macros
|
||||||
from hy._compat import (
|
from hy._compat import (
|
||||||
@ -35,6 +35,7 @@ if PY3:
|
|||||||
else:
|
else:
|
||||||
import __builtin__ as builtins
|
import __builtin__ as builtins
|
||||||
|
|
||||||
|
Inf = float('inf')
|
||||||
|
|
||||||
_compile_time_ns = {}
|
_compile_time_ns = {}
|
||||||
|
|
||||||
@ -670,11 +671,7 @@ class HyASTCompiler(object):
|
|||||||
|
|
||||||
@special(["quote", "quasiquote"], [FORM])
|
@special(["quote", "quasiquote"], [FORM])
|
||||||
def compile_quote(self, expr, root, arg):
|
def compile_quote(self, expr, root, arg):
|
||||||
if root == "quote":
|
level = Inf if root == "quote" else 0 # Only quasiquotes can unquote
|
||||||
# Never allow unquoting
|
|
||||||
level = float("inf")
|
|
||||||
else:
|
|
||||||
level = 0
|
|
||||||
imports, stmts, splice = self._render_quoted_form(arg, level)
|
imports, stmts, splice = self._render_quoted_form(arg, level)
|
||||||
ret = self.compile(stmts)
|
ret = self.compile(stmts)
|
||||||
ret.add_imports("hy", imports)
|
ret.add_imports("hy", imports)
|
||||||
@ -1276,58 +1273,71 @@ class HyASTCompiler(object):
|
|||||||
values=[value.force_expr for value in values])
|
values=[value.force_expr for value in values])
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
ops = {"=": ast.Eq, "!=": ast.NotEq,
|
c_ops = {"=": ast.Eq, "!=": ast.NotEq,
|
||||||
"<": ast.Lt, "<=": ast.LtE,
|
"<": ast.Lt, "<=": ast.LtE,
|
||||||
">": ast.Gt, ">=": ast.GtE,
|
">": ast.Gt, ">=": ast.GtE,
|
||||||
"is": ast.Is, "is-not": ast.IsNot,
|
"is": ast.Is, "is-not": ast.IsNot,
|
||||||
"in": ast.In, "not-in": ast.NotIn}
|
"in": ast.In, "not-in": ast.NotIn}
|
||||||
ops = {ast_str(k): v for k, v in ops.items()}
|
c_ops = {ast_str(k): v for k, v in c_ops.items()}
|
||||||
|
|
||||||
def _compile_compare_op_expression(self, expr, root, args):
|
|
||||||
inv = ast_str(root)
|
|
||||||
ops = [self.ops[inv]() for _ in args[1:]]
|
|
||||||
|
|
||||||
exprs, ret, _ = self._compile_collect(args)
|
|
||||||
|
|
||||||
return ret + asty.Compare(
|
|
||||||
expr, left=exprs[0], ops=ops, comparators=exprs[1:])
|
|
||||||
|
|
||||||
@special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)])
|
@special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)])
|
||||||
|
@special(["!=", "is-not"], [times(2, Inf, FORM)])
|
||||||
|
@special(["in", "not-in"], [times(2, 2, FORM)])
|
||||||
def compile_compare_op_expression(self, expr, root, args):
|
def compile_compare_op_expression(self, expr, root, args):
|
||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
return (self.compile(args[0]) +
|
return (self.compile(args[0]) +
|
||||||
asty.Name(expr, id="True", ctx=ast.Load()))
|
asty.Name(expr, id="True", ctx=ast.Load()))
|
||||||
return self._compile_compare_op_expression(expr, root, args)
|
|
||||||
|
|
||||||
@special(["!=", "is-not"], [FORM, oneplus(FORM)])
|
ops = [self.c_ops[ast_str(root)]() for _ in args[1:]]
|
||||||
def compile_compare_op_expression_coll(self, expr, root, arg1, args):
|
exprs, ret, _ = self._compile_collect(args)
|
||||||
return self._compile_compare_op_expression(expr, root, [arg1] + args)
|
return ret + asty.Compare(
|
||||||
|
expr, left=exprs[0], ops=ops, comparators=exprs[1:])
|
||||||
|
|
||||||
@special(["in", "not-in"], [FORM, FORM])
|
m_ops = {"+": ast.Add,
|
||||||
def compile_compare_op_expression_binary(self, expr, root, needle, haystack):
|
"/": ast.Div,
|
||||||
return self._compile_compare_op_expression(expr, root, [needle, haystack])
|
"//": ast.FloorDiv,
|
||||||
|
"*": ast.Mult,
|
||||||
|
"-": ast.Sub,
|
||||||
|
"%": ast.Mod,
|
||||||
|
"**": ast.Pow,
|
||||||
|
"<<": ast.LShift,
|
||||||
|
">>": ast.RShift,
|
||||||
|
"|": ast.BitOr,
|
||||||
|
"^": ast.BitXor,
|
||||||
|
"&": ast.BitAnd}
|
||||||
|
if PY35:
|
||||||
|
m_ops["@"] = ast.MatMult
|
||||||
|
|
||||||
def _compile_maths_expression(self, expr):
|
@special(["+", "*", "|"], [many(FORM)])
|
||||||
ops = {"+": ast.Add,
|
@special(["-", "/", "&", (PY35, "@")], [oneplus(FORM)])
|
||||||
"/": ast.Div,
|
@special(["**", "//", "<<", ">>"], [times(2, Inf, FORM)])
|
||||||
"//": ast.FloorDiv,
|
@special(["%", "^"], [times(2, 2, FORM)])
|
||||||
"*": ast.Mult,
|
def compile_maths_expression(self, expr, root, args):
|
||||||
"-": ast.Sub,
|
root = unmangle(ast_str(root))
|
||||||
"%": ast.Mod,
|
|
||||||
"**": ast.Pow,
|
|
||||||
"<<": ast.LShift,
|
|
||||||
">>": ast.RShift,
|
|
||||||
"|": ast.BitOr,
|
|
||||||
"^": ast.BitXor,
|
|
||||||
"&": ast.BitAnd}
|
|
||||||
if PY35:
|
|
||||||
ops.update({"@": ast.MatMult})
|
|
||||||
|
|
||||||
op = ops[expr.pop(0)]
|
if len(args) == 0:
|
||||||
right_associative = op is ast.Pow
|
# Return the identity element for this operator.
|
||||||
|
return asty.Num(expr, n=long_type(
|
||||||
|
{"+": 0, "|": 0, "*": 1}[root]))
|
||||||
|
|
||||||
ret = self.compile(expr.pop(-1 if right_associative else 0))
|
if len(args) == 1:
|
||||||
for child in expr[:: -1 if right_associative else 1]:
|
if root == "/":
|
||||||
|
# Compute the reciprocal of the argument.
|
||||||
|
args = [HyInteger(1).replace(expr), args[0]]
|
||||||
|
elif root in ("+", "-"):
|
||||||
|
# Apply unary plus or unary minus to the argument.
|
||||||
|
op = {"+": ast.UAdd, "-": ast.USub}[root]()
|
||||||
|
ret = self.compile(args[0])
|
||||||
|
return ret + asty.UnaryOp(expr, op=op, operand=ret.force_expr)
|
||||||
|
else:
|
||||||
|
# Return the argument unchanged.
|
||||||
|
return self.compile(args[0])
|
||||||
|
|
||||||
|
op = self.m_ops[root]
|
||||||
|
right_associative = root == "**"
|
||||||
|
ret = self.compile(args[-1 if right_associative else 0])
|
||||||
|
for child in args[-2 if right_associative else 1 ::
|
||||||
|
-1 if right_associative else 1]:
|
||||||
left_expr = ret.force_expr
|
left_expr = ret.force_expr
|
||||||
ret += self.compile(child)
|
ret += self.compile(child)
|
||||||
right_expr = ret.force_expr
|
right_expr = ret.force_expr
|
||||||
@ -1337,68 +1347,6 @@ class HyASTCompiler(object):
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@builds("**", "//", "<<", ">>", "&")
|
|
||||||
@checkargs(min=2)
|
|
||||||
def compile_maths_expression_2_or_more(self, expression):
|
|
||||||
return self._compile_maths_expression(expression)
|
|
||||||
|
|
||||||
@builds("%", "^")
|
|
||||||
@checkargs(2)
|
|
||||||
def compile_maths_expression_exactly_2(self, expression):
|
|
||||||
return self._compile_maths_expression(expression)
|
|
||||||
|
|
||||||
@builds("*", "|")
|
|
||||||
def compile_maths_expression_mul(self, expression):
|
|
||||||
id_elem = {"*": 1, "|": 0}[expression[0]]
|
|
||||||
if len(expression) == 1:
|
|
||||||
return asty.Num(expression, n=long_type(id_elem))
|
|
||||||
elif len(expression) == 2:
|
|
||||||
return self.compile(expression[1])
|
|
||||||
else:
|
|
||||||
return self._compile_maths_expression(expression)
|
|
||||||
|
|
||||||
@builds("/")
|
|
||||||
@checkargs(min=1)
|
|
||||||
def compile_maths_expression_div(self, expression):
|
|
||||||
if len(expression) == 2:
|
|
||||||
expression = HyExpression([HySymbol("/"),
|
|
||||||
HyInteger(1),
|
|
||||||
expression[1]]).replace(expression)
|
|
||||||
return self._compile_maths_expression(expression)
|
|
||||||
|
|
||||||
def _compile_maths_expression_additive(self, expression):
|
|
||||||
if len(expression) > 2:
|
|
||||||
return self._compile_maths_expression(expression)
|
|
||||||
else:
|
|
||||||
op = {"+": ast.UAdd, "-": ast.USub}[expression.pop(0)]()
|
|
||||||
ret = self.compile(expression.pop(0))
|
|
||||||
return ret + asty.UnaryOp(
|
|
||||||
expression, op=op, operand=ret.force_expr)
|
|
||||||
|
|
||||||
@builds("&")
|
|
||||||
@builds("@", iff=PY35)
|
|
||||||
@checkargs(min=1)
|
|
||||||
def compile_maths_expression_unary_idempotent(self, expression):
|
|
||||||
if len(expression) == 2:
|
|
||||||
# Used as a unary operator, this operator simply
|
|
||||||
# returns its argument.
|
|
||||||
return self.compile(expression[1])
|
|
||||||
else:
|
|
||||||
return self._compile_maths_expression(expression)
|
|
||||||
|
|
||||||
@builds("+")
|
|
||||||
def compile_maths_expression_add(self, expression):
|
|
||||||
if len(expression) == 1:
|
|
||||||
# Nullary +
|
|
||||||
return asty.Num(expression, n=long_type(0))
|
|
||||||
else:
|
|
||||||
return self._compile_maths_expression_additive(expression)
|
|
||||||
|
|
||||||
@builds("-")
|
|
||||||
@checkargs(min=1)
|
|
||||||
def compile_maths_expression_sub(self, expression):
|
|
||||||
return self._compile_maths_expression_additive(expression)
|
|
||||||
|
|
||||||
@builds("+=", "/=", "//=", "*=", "-=", "%=", "**=", "<<=", ">>=", "|=",
|
@builds("+=", "/=", "//=", "*=", "-=", "%=", "**=", "<<=", ">>=", "|=",
|
||||||
"^=", "&=")
|
"^=", "&=")
|
||||||
@builds("@=", iff=PY35)
|
@builds("@=", iff=PY35)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user