Use model patterns for comparison and math ops

This commit is contained in:
Kodi Arfer 2018-05-08 21:05:45 -07:00
parent 57b5fa49b1
commit 41d3f26001

View File

@ -7,11 +7,11 @@ from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex,
HyString, HyBytes, HySymbol, HyFloat, HyList, HySet, HyString, HyBytes, HySymbol, HyFloat, HyList, HySet,
HyDict, HySequence, wrap_value) HyDict, HySequence, wrap_value)
from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr, from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr,
dolike, pexpr) dolike, pexpr, times)
from funcparserlib.parser import some, many, oneplus, maybe, NoParseError from funcparserlib.parser import some, many, oneplus, maybe, NoParseError
from hy.errors import HyCompileError, HyTypeError from hy.errors import HyCompileError, HyTypeError
from hy.lex.parser import mangle from hy.lex.parser import mangle, unmangle
import hy.macros import hy.macros
from hy._compat import ( from hy._compat import (
@ -35,6 +35,7 @@ if PY3:
else: else:
import __builtin__ as builtins import __builtin__ as builtins
Inf = float('inf')
_compile_time_ns = {} _compile_time_ns = {}
@ -670,11 +671,7 @@ class HyASTCompiler(object):
@special(["quote", "quasiquote"], [FORM]) @special(["quote", "quasiquote"], [FORM])
def compile_quote(self, expr, root, arg): def compile_quote(self, expr, root, arg):
if root == "quote": level = Inf if root == "quote" else 0 # Only quasiquotes can unquote
# Never allow unquoting
level = float("inf")
else:
level = 0
imports, stmts, splice = self._render_quoted_form(arg, level) imports, stmts, splice = self._render_quoted_form(arg, level)
ret = self.compile(stmts) ret = self.compile(stmts)
ret.add_imports("hy", imports) ret.add_imports("hy", imports)
@ -1276,39 +1273,27 @@ class HyASTCompiler(object):
values=[value.force_expr for value in values]) values=[value.force_expr for value in values])
return ret return ret
ops = {"=": ast.Eq, "!=": ast.NotEq, c_ops = {"=": ast.Eq, "!=": ast.NotEq,
"<": ast.Lt, "<=": ast.LtE, "<": ast.Lt, "<=": ast.LtE,
">": ast.Gt, ">=": ast.GtE, ">": ast.Gt, ">=": ast.GtE,
"is": ast.Is, "is-not": ast.IsNot, "is": ast.Is, "is-not": ast.IsNot,
"in": ast.In, "not-in": ast.NotIn} "in": ast.In, "not-in": ast.NotIn}
ops = {ast_str(k): v for k, v in ops.items()} c_ops = {ast_str(k): v for k, v in c_ops.items()}
def _compile_compare_op_expression(self, expr, root, args):
inv = ast_str(root)
ops = [self.ops[inv]() for _ in args[1:]]
exprs, ret, _ = self._compile_collect(args)
return ret + asty.Compare(
expr, left=exprs[0], ops=ops, comparators=exprs[1:])
@special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)]) @special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)])
@special(["!=", "is-not"], [times(2, Inf, FORM)])
@special(["in", "not-in"], [times(2, 2, FORM)])
def compile_compare_op_expression(self, expr, root, args): def compile_compare_op_expression(self, expr, root, args):
if len(args) == 1: if len(args) == 1:
return (self.compile(args[0]) + return (self.compile(args[0]) +
asty.Name(expr, id="True", ctx=ast.Load())) asty.Name(expr, id="True", ctx=ast.Load()))
return self._compile_compare_op_expression(expr, root, args)
@special(["!=", "is-not"], [FORM, oneplus(FORM)]) ops = [self.c_ops[ast_str(root)]() for _ in args[1:]]
def compile_compare_op_expression_coll(self, expr, root, arg1, args): exprs, ret, _ = self._compile_collect(args)
return self._compile_compare_op_expression(expr, root, [arg1] + args) return ret + asty.Compare(
expr, left=exprs[0], ops=ops, comparators=exprs[1:])
@special(["in", "not-in"], [FORM, FORM]) m_ops = {"+": ast.Add,
def compile_compare_op_expression_binary(self, expr, root, needle, haystack):
return self._compile_compare_op_expression(expr, root, [needle, haystack])
def _compile_maths_expression(self, expr):
ops = {"+": ast.Add,
"/": ast.Div, "/": ast.Div,
"//": ast.FloorDiv, "//": ast.FloorDiv,
"*": ast.Mult, "*": ast.Mult,
@ -1321,13 +1306,38 @@ class HyASTCompiler(object):
"^": ast.BitXor, "^": ast.BitXor,
"&": ast.BitAnd} "&": ast.BitAnd}
if PY35: if PY35:
ops.update({"@": ast.MatMult}) m_ops["@"] = ast.MatMult
op = ops[expr.pop(0)] @special(["+", "*", "|"], [many(FORM)])
right_associative = op is ast.Pow @special(["-", "/", "&", (PY35, "@")], [oneplus(FORM)])
@special(["**", "//", "<<", ">>"], [times(2, Inf, FORM)])
@special(["%", "^"], [times(2, 2, FORM)])
def compile_maths_expression(self, expr, root, args):
root = unmangle(ast_str(root))
ret = self.compile(expr.pop(-1 if right_associative else 0)) if len(args) == 0:
for child in expr[:: -1 if right_associative else 1]: # Return the identity element for this operator.
return asty.Num(expr, n=long_type(
{"+": 0, "|": 0, "*": 1}[root]))
if len(args) == 1:
if root == "/":
# Compute the reciprocal of the argument.
args = [HyInteger(1).replace(expr), args[0]]
elif root in ("+", "-"):
# Apply unary plus or unary minus to the argument.
op = {"+": ast.UAdd, "-": ast.USub}[root]()
ret = self.compile(args[0])
return ret + asty.UnaryOp(expr, op=op, operand=ret.force_expr)
else:
# Return the argument unchanged.
return self.compile(args[0])
op = self.m_ops[root]
right_associative = root == "**"
ret = self.compile(args[-1 if right_associative else 0])
for child in args[-2 if right_associative else 1 ::
-1 if right_associative else 1]:
left_expr = ret.force_expr left_expr = ret.force_expr
ret += self.compile(child) ret += self.compile(child)
right_expr = ret.force_expr right_expr = ret.force_expr
@ -1337,68 +1347,6 @@ class HyASTCompiler(object):
return ret return ret
@builds("**", "//", "<<", ">>", "&")
@checkargs(min=2)
def compile_maths_expression_2_or_more(self, expression):
return self._compile_maths_expression(expression)
@builds("%", "^")
@checkargs(2)
def compile_maths_expression_exactly_2(self, expression):
return self._compile_maths_expression(expression)
@builds("*", "|")
def compile_maths_expression_mul(self, expression):
id_elem = {"*": 1, "|": 0}[expression[0]]
if len(expression) == 1:
return asty.Num(expression, n=long_type(id_elem))
elif len(expression) == 2:
return self.compile(expression[1])
else:
return self._compile_maths_expression(expression)
@builds("/")
@checkargs(min=1)
def compile_maths_expression_div(self, expression):
if len(expression) == 2:
expression = HyExpression([HySymbol("/"),
HyInteger(1),
expression[1]]).replace(expression)
return self._compile_maths_expression(expression)
def _compile_maths_expression_additive(self, expression):
if len(expression) > 2:
return self._compile_maths_expression(expression)
else:
op = {"+": ast.UAdd, "-": ast.USub}[expression.pop(0)]()
ret = self.compile(expression.pop(0))
return ret + asty.UnaryOp(
expression, op=op, operand=ret.force_expr)
@builds("&")
@builds("@", iff=PY35)
@checkargs(min=1)
def compile_maths_expression_unary_idempotent(self, expression):
if len(expression) == 2:
# Used as a unary operator, this operator simply
# returns its argument.
return self.compile(expression[1])
else:
return self._compile_maths_expression(expression)
@builds("+")
def compile_maths_expression_add(self, expression):
if len(expression) == 1:
# Nullary +
return asty.Num(expression, n=long_type(0))
else:
return self._compile_maths_expression_additive(expression)
@builds("-")
@checkargs(min=1)
def compile_maths_expression_sub(self, expression):
return self._compile_maths_expression_additive(expression)
@builds("+=", "/=", "//=", "*=", "-=", "%=", "**=", "<<=", ">>=", "|=", @builds("+=", "/=", "//=", "*=", "-=", "%=", "**=", "<<=", ">>=", "|=",
"^=", "&=") "^=", "&=")
@builds("@=", iff=PY35) @builds("@=", iff=PY35)