Use model patterns for comparison and math ops
This commit is contained in:
parent
57b5fa49b1
commit
41d3f26001
138
hy/compiler.py
138
hy/compiler.py
@ -7,11 +7,11 @@ from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex,
|
||||
HyString, HyBytes, HySymbol, HyFloat, HyList, HySet,
|
||||
HyDict, HySequence, wrap_value)
|
||||
from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr,
|
||||
dolike, pexpr)
|
||||
dolike, pexpr, times)
|
||||
from funcparserlib.parser import some, many, oneplus, maybe, NoParseError
|
||||
from hy.errors import HyCompileError, HyTypeError
|
||||
|
||||
from hy.lex.parser import mangle
|
||||
from hy.lex.parser import mangle, unmangle
|
||||
|
||||
import hy.macros
|
||||
from hy._compat import (
|
||||
@ -35,6 +35,7 @@ if PY3:
|
||||
else:
|
||||
import __builtin__ as builtins
|
||||
|
||||
Inf = float('inf')
|
||||
|
||||
_compile_time_ns = {}
|
||||
|
||||
@ -670,11 +671,7 @@ class HyASTCompiler(object):
|
||||
|
||||
@special(["quote", "quasiquote"], [FORM])
|
||||
def compile_quote(self, expr, root, arg):
|
||||
if root == "quote":
|
||||
# Never allow unquoting
|
||||
level = float("inf")
|
||||
else:
|
||||
level = 0
|
||||
level = Inf if root == "quote" else 0 # Only quasiquotes can unquote
|
||||
imports, stmts, splice = self._render_quoted_form(arg, level)
|
||||
ret = self.compile(stmts)
|
||||
ret.add_imports("hy", imports)
|
||||
@ -1276,39 +1273,27 @@ class HyASTCompiler(object):
|
||||
values=[value.force_expr for value in values])
|
||||
return ret
|
||||
|
||||
ops = {"=": ast.Eq, "!=": ast.NotEq,
|
||||
c_ops = {"=": ast.Eq, "!=": ast.NotEq,
|
||||
"<": ast.Lt, "<=": ast.LtE,
|
||||
">": ast.Gt, ">=": ast.GtE,
|
||||
"is": ast.Is, "is-not": ast.IsNot,
|
||||
"in": ast.In, "not-in": ast.NotIn}
|
||||
ops = {ast_str(k): v for k, v in ops.items()}
|
||||
|
||||
def _compile_compare_op_expression(self, expr, root, args):
|
||||
inv = ast_str(root)
|
||||
ops = [self.ops[inv]() for _ in args[1:]]
|
||||
|
||||
exprs, ret, _ = self._compile_collect(args)
|
||||
|
||||
return ret + asty.Compare(
|
||||
expr, left=exprs[0], ops=ops, comparators=exprs[1:])
|
||||
c_ops = {ast_str(k): v for k, v in c_ops.items()}
|
||||
|
||||
@special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)])
|
||||
@special(["!=", "is-not"], [times(2, Inf, FORM)])
|
||||
@special(["in", "not-in"], [times(2, 2, FORM)])
|
||||
def compile_compare_op_expression(self, expr, root, args):
|
||||
if len(args) == 1:
|
||||
return (self.compile(args[0]) +
|
||||
asty.Name(expr, id="True", ctx=ast.Load()))
|
||||
return self._compile_compare_op_expression(expr, root, args)
|
||||
|
||||
@special(["!=", "is-not"], [FORM, oneplus(FORM)])
|
||||
def compile_compare_op_expression_coll(self, expr, root, arg1, args):
|
||||
return self._compile_compare_op_expression(expr, root, [arg1] + args)
|
||||
ops = [self.c_ops[ast_str(root)]() for _ in args[1:]]
|
||||
exprs, ret, _ = self._compile_collect(args)
|
||||
return ret + asty.Compare(
|
||||
expr, left=exprs[0], ops=ops, comparators=exprs[1:])
|
||||
|
||||
@special(["in", "not-in"], [FORM, FORM])
|
||||
def compile_compare_op_expression_binary(self, expr, root, needle, haystack):
|
||||
return self._compile_compare_op_expression(expr, root, [needle, haystack])
|
||||
|
||||
def _compile_maths_expression(self, expr):
|
||||
ops = {"+": ast.Add,
|
||||
m_ops = {"+": ast.Add,
|
||||
"/": ast.Div,
|
||||
"//": ast.FloorDiv,
|
||||
"*": ast.Mult,
|
||||
@ -1321,13 +1306,38 @@ class HyASTCompiler(object):
|
||||
"^": ast.BitXor,
|
||||
"&": ast.BitAnd}
|
||||
if PY35:
|
||||
ops.update({"@": ast.MatMult})
|
||||
m_ops["@"] = ast.MatMult
|
||||
|
||||
op = ops[expr.pop(0)]
|
||||
right_associative = op is ast.Pow
|
||||
@special(["+", "*", "|"], [many(FORM)])
|
||||
@special(["-", "/", "&", (PY35, "@")], [oneplus(FORM)])
|
||||
@special(["**", "//", "<<", ">>"], [times(2, Inf, FORM)])
|
||||
@special(["%", "^"], [times(2, 2, FORM)])
|
||||
def compile_maths_expression(self, expr, root, args):
|
||||
root = unmangle(ast_str(root))
|
||||
|
||||
ret = self.compile(expr.pop(-1 if right_associative else 0))
|
||||
for child in expr[:: -1 if right_associative else 1]:
|
||||
if len(args) == 0:
|
||||
# Return the identity element for this operator.
|
||||
return asty.Num(expr, n=long_type(
|
||||
{"+": 0, "|": 0, "*": 1}[root]))
|
||||
|
||||
if len(args) == 1:
|
||||
if root == "/":
|
||||
# Compute the reciprocal of the argument.
|
||||
args = [HyInteger(1).replace(expr), args[0]]
|
||||
elif root in ("+", "-"):
|
||||
# Apply unary plus or unary minus to the argument.
|
||||
op = {"+": ast.UAdd, "-": ast.USub}[root]()
|
||||
ret = self.compile(args[0])
|
||||
return ret + asty.UnaryOp(expr, op=op, operand=ret.force_expr)
|
||||
else:
|
||||
# Return the argument unchanged.
|
||||
return self.compile(args[0])
|
||||
|
||||
op = self.m_ops[root]
|
||||
right_associative = root == "**"
|
||||
ret = self.compile(args[-1 if right_associative else 0])
|
||||
for child in args[-2 if right_associative else 1 ::
|
||||
-1 if right_associative else 1]:
|
||||
left_expr = ret.force_expr
|
||||
ret += self.compile(child)
|
||||
right_expr = ret.force_expr
|
||||
@ -1337,68 +1347,6 @@ class HyASTCompiler(object):
|
||||
|
||||
return ret
|
||||
|
||||
@builds("**", "//", "<<", ">>", "&")
|
||||
@checkargs(min=2)
|
||||
def compile_maths_expression_2_or_more(self, expression):
|
||||
return self._compile_maths_expression(expression)
|
||||
|
||||
@builds("%", "^")
|
||||
@checkargs(2)
|
||||
def compile_maths_expression_exactly_2(self, expression):
|
||||
return self._compile_maths_expression(expression)
|
||||
|
||||
@builds("*", "|")
|
||||
def compile_maths_expression_mul(self, expression):
|
||||
id_elem = {"*": 1, "|": 0}[expression[0]]
|
||||
if len(expression) == 1:
|
||||
return asty.Num(expression, n=long_type(id_elem))
|
||||
elif len(expression) == 2:
|
||||
return self.compile(expression[1])
|
||||
else:
|
||||
return self._compile_maths_expression(expression)
|
||||
|
||||
@builds("/")
|
||||
@checkargs(min=1)
|
||||
def compile_maths_expression_div(self, expression):
|
||||
if len(expression) == 2:
|
||||
expression = HyExpression([HySymbol("/"),
|
||||
HyInteger(1),
|
||||
expression[1]]).replace(expression)
|
||||
return self._compile_maths_expression(expression)
|
||||
|
||||
def _compile_maths_expression_additive(self, expression):
|
||||
if len(expression) > 2:
|
||||
return self._compile_maths_expression(expression)
|
||||
else:
|
||||
op = {"+": ast.UAdd, "-": ast.USub}[expression.pop(0)]()
|
||||
ret = self.compile(expression.pop(0))
|
||||
return ret + asty.UnaryOp(
|
||||
expression, op=op, operand=ret.force_expr)
|
||||
|
||||
@builds("&")
|
||||
@builds("@", iff=PY35)
|
||||
@checkargs(min=1)
|
||||
def compile_maths_expression_unary_idempotent(self, expression):
|
||||
if len(expression) == 2:
|
||||
# Used as a unary operator, this operator simply
|
||||
# returns its argument.
|
||||
return self.compile(expression[1])
|
||||
else:
|
||||
return self._compile_maths_expression(expression)
|
||||
|
||||
@builds("+")
|
||||
def compile_maths_expression_add(self, expression):
|
||||
if len(expression) == 1:
|
||||
# Nullary +
|
||||
return asty.Num(expression, n=long_type(0))
|
||||
else:
|
||||
return self._compile_maths_expression_additive(expression)
|
||||
|
||||
@builds("-")
|
||||
@checkargs(min=1)
|
||||
def compile_maths_expression_sub(self, expression):
|
||||
return self._compile_maths_expression_additive(expression)
|
||||
|
||||
@builds("+=", "/=", "//=", "*=", "-=", "%=", "**=", "<<=", ">>=", "|=",
|
||||
"^=", "&=")
|
||||
@builds("@=", iff=PY35)
|
||||
|
Loading…
Reference in New Issue
Block a user