From 41d3f2600120e13dbea62355d120c5cd5d5e71d6 Mon Sep 17 00:00:00 2001 From: Kodi Arfer Date: Tue, 8 May 2018 21:05:45 -0700 Subject: [PATCH] Use model patterns for comparison and math ops --- hy/compiler.py | 170 +++++++++++++++++-------------------------------- 1 file changed, 59 insertions(+), 111 deletions(-) diff --git a/hy/compiler.py b/hy/compiler.py index d5d085b..2675c77 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -7,11 +7,11 @@ from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex, HyString, HyBytes, HySymbol, HyFloat, HyList, HySet, HyDict, HySequence, wrap_value) from hy.model_patterns import (FORM, SYM, STR, sym, brackets, whole, notpexpr, - dolike, pexpr) + dolike, pexpr, times) from funcparserlib.parser import some, many, oneplus, maybe, NoParseError from hy.errors import HyCompileError, HyTypeError -from hy.lex.parser import mangle +from hy.lex.parser import mangle, unmangle import hy.macros from hy._compat import ( @@ -35,6 +35,7 @@ if PY3: else: import __builtin__ as builtins +Inf = float('inf') _compile_time_ns = {} @@ -670,11 +671,7 @@ class HyASTCompiler(object): @special(["quote", "quasiquote"], [FORM]) def compile_quote(self, expr, root, arg): - if root == "quote": - # Never allow unquoting - level = float("inf") - else: - level = 0 + level = Inf if root == "quote" else 0 # Only quasiquotes can unquote imports, stmts, splice = self._render_quoted_form(arg, level) ret = self.compile(stmts) ret.add_imports("hy", imports) @@ -1276,58 +1273,71 @@ class HyASTCompiler(object): values=[value.force_expr for value in values]) return ret - ops = {"=": ast.Eq, "!=": ast.NotEq, - "<": ast.Lt, "<=": ast.LtE, - ">": ast.Gt, ">=": ast.GtE, - "is": ast.Is, "is-not": ast.IsNot, - "in": ast.In, "not-in": ast.NotIn} - ops = {ast_str(k): v for k, v in ops.items()} - - def _compile_compare_op_expression(self, expr, root, args): - inv = ast_str(root) - ops = [self.ops[inv]() for _ in args[1:]] - - exprs, ret, _ = self._compile_collect(args) - - return ret + asty.Compare( - expr, left=exprs[0], ops=ops, comparators=exprs[1:]) + c_ops = {"=": ast.Eq, "!=": ast.NotEq, + "<": ast.Lt, "<=": ast.LtE, + ">": ast.Gt, ">=": ast.GtE, + "is": ast.Is, "is-not": ast.IsNot, + "in": ast.In, "not-in": ast.NotIn} + c_ops = {ast_str(k): v for k, v in c_ops.items()} @special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)]) + @special(["!=", "is-not"], [times(2, Inf, FORM)]) + @special(["in", "not-in"], [times(2, 2, FORM)]) def compile_compare_op_expression(self, expr, root, args): if len(args) == 1: return (self.compile(args[0]) + - asty.Name(expr, id="True", ctx=ast.Load())) - return self._compile_compare_op_expression(expr, root, args) + asty.Name(expr, id="True", ctx=ast.Load())) - @special(["!=", "is-not"], [FORM, oneplus(FORM)]) - def compile_compare_op_expression_coll(self, expr, root, arg1, args): - return self._compile_compare_op_expression(expr, root, [arg1] + args) + ops = [self.c_ops[ast_str(root)]() for _ in args[1:]] + exprs, ret, _ = self._compile_collect(args) + return ret + asty.Compare( + expr, left=exprs[0], ops=ops, comparators=exprs[1:]) - @special(["in", "not-in"], [FORM, FORM]) - def compile_compare_op_expression_binary(self, expr, root, needle, haystack): - return self._compile_compare_op_expression(expr, root, [needle, haystack]) + m_ops = {"+": ast.Add, + "/": ast.Div, + "//": ast.FloorDiv, + "*": ast.Mult, + "-": ast.Sub, + "%": ast.Mod, + "**": ast.Pow, + "<<": ast.LShift, + ">>": ast.RShift, + "|": ast.BitOr, + "^": ast.BitXor, + "&": ast.BitAnd} + if PY35: + m_ops["@"] = ast.MatMult - def _compile_maths_expression(self, expr): - ops = {"+": ast.Add, - "/": ast.Div, - "//": ast.FloorDiv, - "*": ast.Mult, - "-": ast.Sub, - "%": ast.Mod, - "**": ast.Pow, - "<<": ast.LShift, - ">>": ast.RShift, - "|": ast.BitOr, - "^": ast.BitXor, - "&": ast.BitAnd} - if PY35: - ops.update({"@": ast.MatMult}) + @special(["+", "*", "|"], [many(FORM)]) + @special(["-", "/", "&", (PY35, "@")], [oneplus(FORM)]) + @special(["**", "//", "<<", ">>"], [times(2, Inf, FORM)]) + @special(["%", "^"], [times(2, 2, FORM)]) + def compile_maths_expression(self, expr, root, args): + root = unmangle(ast_str(root)) - op = ops[expr.pop(0)] - right_associative = op is ast.Pow + if len(args) == 0: + # Return the identity element for this operator. + return asty.Num(expr, n=long_type( + {"+": 0, "|": 0, "*": 1}[root])) - ret = self.compile(expr.pop(-1 if right_associative else 0)) - for child in expr[:: -1 if right_associative else 1]: + if len(args) == 1: + if root == "/": + # Compute the reciprocal of the argument. + args = [HyInteger(1).replace(expr), args[0]] + elif root in ("+", "-"): + # Apply unary plus or unary minus to the argument. + op = {"+": ast.UAdd, "-": ast.USub}[root]() + ret = self.compile(args[0]) + return ret + asty.UnaryOp(expr, op=op, operand=ret.force_expr) + else: + # Return the argument unchanged. + return self.compile(args[0]) + + op = self.m_ops[root] + right_associative = root == "**" + ret = self.compile(args[-1 if right_associative else 0]) + for child in args[-2 if right_associative else 1 :: + -1 if right_associative else 1]: left_expr = ret.force_expr ret += self.compile(child) right_expr = ret.force_expr @@ -1337,68 +1347,6 @@ class HyASTCompiler(object): return ret - @builds("**", "//", "<<", ">>", "&") - @checkargs(min=2) - def compile_maths_expression_2_or_more(self, expression): - return self._compile_maths_expression(expression) - - @builds("%", "^") - @checkargs(2) - def compile_maths_expression_exactly_2(self, expression): - return self._compile_maths_expression(expression) - - @builds("*", "|") - def compile_maths_expression_mul(self, expression): - id_elem = {"*": 1, "|": 0}[expression[0]] - if len(expression) == 1: - return asty.Num(expression, n=long_type(id_elem)) - elif len(expression) == 2: - return self.compile(expression[1]) - else: - return self._compile_maths_expression(expression) - - @builds("/") - @checkargs(min=1) - def compile_maths_expression_div(self, expression): - if len(expression) == 2: - expression = HyExpression([HySymbol("/"), - HyInteger(1), - expression[1]]).replace(expression) - return self._compile_maths_expression(expression) - - def _compile_maths_expression_additive(self, expression): - if len(expression) > 2: - return self._compile_maths_expression(expression) - else: - op = {"+": ast.UAdd, "-": ast.USub}[expression.pop(0)]() - ret = self.compile(expression.pop(0)) - return ret + asty.UnaryOp( - expression, op=op, operand=ret.force_expr) - - @builds("&") - @builds("@", iff=PY35) - @checkargs(min=1) - def compile_maths_expression_unary_idempotent(self, expression): - if len(expression) == 2: - # Used as a unary operator, this operator simply - # returns its argument. - return self.compile(expression[1]) - else: - return self._compile_maths_expression(expression) - - @builds("+") - def compile_maths_expression_add(self, expression): - if len(expression) == 1: - # Nullary + - return asty.Num(expression, n=long_type(0)) - else: - return self._compile_maths_expression_additive(expression) - - @builds("-") - @checkargs(min=1) - def compile_maths_expression_sub(self, expression): - return self._compile_maths_expression_additive(expression) - @builds("+=", "/=", "//=", "*=", "-=", "%=", "**=", "<<=", ">>=", "|=", "^=", "&=") @builds("@=", iff=PY35)