diff --git a/NEWS.rst b/NEWS.rst index d56dce5..476177c 100644 --- a/NEWS.rst +++ b/NEWS.rst @@ -11,6 +11,11 @@ Removals * Literal keywords are no longer parsed differently in calls to functions with certain names. +New Features +------------------------------ +* All augmented assignment operators (except `%=` and `^=`) now allow + more than two arguments. + Bug Fixes ------------------------------ * Statements in the second argument of `assert` are now executed. diff --git a/hy/compiler.py b/hy/compiler.py index 4387366..6d09afa 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -1265,19 +1265,21 @@ class HyASTCompiler(object): return ret + asty.Compare( expr, left=exprs[0], ops=ops, comparators=exprs[1:]) - m_ops = {"+": ast.Add, - "/": ast.Div, - "//": ast.FloorDiv, - "*": ast.Mult, - "-": ast.Sub, - "%": ast.Mod, - "**": ast.Pow, - "<<": ast.LShift, - ">>": ast.RShift, - "|": ast.BitOr, - "^": ast.BitXor, - "&": ast.BitAnd, - "@": ast.MatMult} + # The second element of each tuple below is an aggregation operator + # that's used for augmented assignment with three or more arguments. + m_ops = {"+": (ast.Add, "+"), + "/": (ast.Div, "*"), + "//": (ast.FloorDiv, "*"), + "*": (ast.Mult, "*"), + "-": (ast.Sub, "+"), + "%": (ast.Mod, None), + "**": (ast.Pow, "**"), + "<<": (ast.LShift, "+"), + ">>": (ast.RShift, "+"), + "|": (ast.BitOr, "|"), + "^": (ast.BitXor, None), + "&": (ast.BitAnd, "&"), + "@": (ast.MatMult, "@")} @special(["+", "*", "|"], [many(FORM)]) @special(["-", "/", "&", "@"], [oneplus(FORM)]) @@ -1302,7 +1304,7 @@ class HyASTCompiler(object): # Return the argument unchanged. return self.compile(args[0]) - op = self.m_ops[root] + op = self.m_ops[root][0] right_associative = root == "**" ret = self.compile(args[-1 if right_associative else 0]) for child in args[-2 if right_associative else 1 :: @@ -1318,11 +1320,16 @@ class HyASTCompiler(object): a_ops = {x + "=": v for x, v in m_ops.items()} - @special(list(a_ops.keys()), [FORM, FORM]) - def compile_augassign_expression(self, expr, root, target, value): - op = self.a_ops[root] + @special([x for x, (_, v) in a_ops.items() if v is not None], [FORM, oneplus(FORM)]) + @special([x for x, (_, v) in a_ops.items() if v is None], [FORM, times(1, 1, FORM)]) + def compile_augassign_expression(self, expr, root, target, values): + if len(values) > 1: + return self.compile(mkexpr(root, [target], + mkexpr(self.a_ops[root][1], rest=values)).replace(expr)) + + op = self.a_ops[root][0] target = self._storeize(target, self.compile(target)) - ret = self.compile(value) + ret = self.compile(values[0]) return ret + asty.AugAssign( expr, target=target, value=ret.force_expr, op=op()) diff --git a/tests/native_tests/mathematics.hy b/tests/native_tests/mathematics.hy deleted file mode 100644 index 5069644..0000000 --- a/tests/native_tests/mathematics.hy +++ /dev/null @@ -1,199 +0,0 @@ -;; Copyright 2019 the authors. -;; This file is part of Hy, which is free software licensed under the Expat -;; license. See the LICENSE. - -(setv square (fn [x] - (* x x))) - - -(setv test_basic_math (fn [] - "NATIVE: Test basic math." - (assert (= (+ 2 2) 4)))) - -(setv test_mult (fn [] - "NATIVE: Test multiplication." - (assert (= 4 (square 2))) - (assert (= 8 (* 8))) - (assert (= 1 (*))))) - - -(setv test_sub (fn [] - "NATIVE: Test subtraction" - (assert (= 4 (- 8 4))) - (assert (= -8 (- 8))))) - - -(setv test_add (fn [] - "NATIVE: Test addition" - (assert (= 4 (+ 1 1 1 1))) - (assert (= 8 (+ 8))) - (assert (= 0 (+))))) - - -(defn test-add-unary [] - "NATIVE: test that unary + calls __pos__" - - (defclass X [object] - (defn __pos__ [self] "called __pos__")) - (assert (= (+ (X)) "called __pos__")) - - ; Make sure the shadowed version works, too. - (setv f +) - (assert (= (f (X)) "called __pos__"))) - - -(setv test_div (fn [] - "NATIVE: Test division" - (assert (= 25 (/ 100 2 2))) - ; Commented out until float constants get implemented - ; (assert (= 0.5 (/ 1 2))) - (assert (= 1 (* 2 (/ 1 2)))))) - -(setv test_int_div (fn [] - "NATIVE: Test integer division" - (assert (= 25 (// 101 2 2))))) - -(defn test-modulo [] - "NATIVE: test mod" - (assert (= (% 10 2) 0))) - -(defn test-pow [] - "NATIVE: test pow" - (assert (= (** 10 2) 100))) - -(defn test-lshift [] - "NATIVE: test lshift" - (assert (= (<< 1 2) 4))) - -(defn test-rshift [] - "NATIVE: test lshift" - (assert (= (>> 8 1) 4))) - -(defn test-bitor [] - "NATIVE: test lshift" - (assert (= (| 1 2) 3))) - -(defn test-bitxor [] - "NATIVE: test xor" - (assert (= (^ 1 2) 3))) - -(defn test-bitand [] - "NATIVE: test lshift" - (assert (= (& 1 2) 0))) - -(defn test-augassign-add [] - "NATIVE: test augassign add" - (setv x 1) - (+= x 41) - (assert (= x 42))) - -(defn test-augassign-sub [] - "NATIVE: test augassign sub" - (setv x 1) - (-= x 41) - (assert (= x -40))) - -(defn test-augassign-mult [] - "NATIVE: test augassign mult" - (setv x 1) - (*= x 41) - (assert (= x 41))) - -(defn test-augassign-div [] - "NATIVE: test augassign div" - (setv x 42) - (/= x 2) - (assert (= x 21))) - -(defn test-augassign-floordiv [] - "NATIVE: test augassign floordiv" - (setv x 42) - (//= x 2) - (assert (= x 21))) - -(defn test-augassign-mod [] - "NATIVE: test augassign mod" - (setv x 42) - (%= x 2) - (assert (= x 0))) - -(defn test-augassign-pow [] - "NATIVE: test augassign pow" - (setv x 2) - (**= x 3) - (assert (= x 8))) - -(defn test-augassign-lshift [] - "NATIVE: test augassign lshift" - (setv x 2) - (<<= x 2) - (assert (= x 8))) - -(defn test-augassign-rshift [] - "NATIVE: test augassign rshift" - (setv x 8) - (>>= x 1) - (assert (= x 4))) - -(defn test-augassign-bitand [] - "NATIVE: test augassign bitand" - (setv x 8) - (&= x 1) - (assert (= x 0))) - -(defn test-augassign-bitor [] - "NATIVE: test augassign bitand" - (setv x 0) - (|= x 2) - (assert (= x 2))) - -(defn test-augassign-bitxor [] - "NATIVE: test augassign bitand" - (setv x 1) - (^= x 1) - (assert (= x 0))) - -(defn overflow-int-to-long [] - "NATIVE: test if int does not raise an overflow exception" - (assert (integer? (+ 1 1000000000000000000000000)))) - - -(defclass HyTestMatrix [list] - (defn --matmul-- [self other] - (setv n (len self) - m (len (. other [0])) - result []) - (for [i (range m)] - (setv result-row []) - (for [j (range n)] - (setv dot-product 0) - (for [k (range (len (. self [0])))] - (+= dot-product (* (. self [i] [k]) - (. other [k] [j])))) - (.append result-row dot-product)) - (.append result result-row)) - result)) - -(setv first-test-matrix (HyTestMatrix [[1 2 3] - [4 5 6] - [7 8 9]])) - -(setv second-test-matrix (HyTestMatrix [[2 0 0] - [0 2 0] - [0 0 2]])) - -(setv product-of-test-matrices (HyTestMatrix [[ 2 4 6] - [ 8 10 12] - [14 16 18]])) - -(defn test-matmul [] - "NATIVE: test matrix multiplication" - (assert (= (@ first-test-matrix second-test-matrix) - product-of-test-matrices))) - -(defn test-augassign-matmul [] - "NATIVE: test augmented-assignment matrix multiplication" - (setv matrix first-test-matrix - matmul-attempt (try (@= matrix second-test-matrix) - (except [e [Exception]] e))) - (assert (= product-of-test-matrices matrix))) diff --git a/tests/native_tests/operators.hy b/tests/native_tests/operators.hy index 7c9b18b..4445138 100644 --- a/tests/native_tests/operators.hy +++ b/tests/native_tests/operators.hy @@ -303,3 +303,43 @@ (assert (= (f "hello" 1) "e")) (assert (= (f [[1 2 3] [4 5 6] [7 8 9]] 1 2) 6)) (assert (= (f {"x" {"y" {"z" 12}}} "x" "y" "z") 12))) + + +(defn test-augassign [] + (setv b 2 c 3 d 4) + (defmacro same-as [expr1 expr2 expected-value] + `(do + (setv a 4) + ~expr1 + (setv expr1-value a) + (setv a 4) + ~expr2 + (assert (= expr1-value a ~expected-value)))) + (same-as (+= a b c d) (+= a (+ b c d)) 13) + (same-as (-= a b c d) (-= a (+ b c d)) -5) + (same-as (*= a b c d) (*= a (* b c d)) 96) + (same-as (**= a b c) (**= a (** b c)) 65,536) + (same-as (/= a b c d) (/= a (* b c d)) (/ 1 6)) + (same-as (//= a b c d) (//= a (* b c d)) 0) + (same-as (<<= a b c d) (<<= a (+ b c d)) 0b10_00000_00000) + (same-as (>>= a b c d) (>>= a (+ b c d)) 0) + (same-as (&= a b c d) (&= a (& b c d)) 0) + (same-as (|= a b c d) (|= a (| b c d)) 0b111) + + (defclass C [object] + (defn __init__ [self content] (setv self.content content)) + (defn __matmul__ [self other] (C (+ self.content other.content)))) + (setv a (C "a") b (C "b") c (C "c") d (C "d")) + (@= a b c d) + (assert (= a.content "abcd")) + (setv a (C "a")) + (@= a (@ b c d)) + (assert (= a.content "abcd")) + + (setv a 15) + (%= a 9) + (assert (= a 6)) + + (setv a 0b1100) + (^= a 0b1010) + (assert (= a 0b0110)))