diff --git a/NEWS.rst b/NEWS.rst index 71da3c9..be8eb70 100644 --- a/NEWS.rst +++ b/NEWS.rst @@ -17,6 +17,7 @@ New Features ------------------------------ * Added special forms ``py`` to ``pys`` that allow Hy programs to include inline Python code. +* Added a special form ``cmp`` for chained comparisons. * All augmented assignment operators (except `%=` and `^=`) now allow more than two arguments. * PEP 3107 and PEP 526 function and variable annotations are now supported. diff --git a/docs/language/api.rst b/docs/language/api.rst index e675126..cb1d70d 100644 --- a/docs/language/api.rst +++ b/docs/language/api.rst @@ -299,6 +299,38 @@ as the user enters *k*. (print "Try again"))) +cmp +--- + +``cmp`` creates a :ref:`comparison expression `. It isn't +required for unchained comparisons, which have only one comparison operator, +nor for chains of the same operator. For those cases, you can use the +comparison operators directly with Hy's usual prefix syntax, as in ``(= x 1)`` +or ``(< 1 2 3)``. The use of ``cmp`` is to construct chains of heterogeneous +operators, such as ``x <= y < z``. It uses an infix syntax with the general +form + +:: + + (cmp ARG OP ARG OP ARG…) + +Hence, ``(cmp x <= y < z)`` is equivalent to ``(and (<= x y) (< y z))``, +including short-circuiting, except that ``y`` is only evaluated once. + +Each ``ARG`` is an arbitrary form, which does not itself use infix syntax. Use +:ref:`py-specialform` if you want fully Python-style operator syntax. You can +also nest ``cmp`` forms, although this is rarely useful. Each ``OP`` is a +literal comparison operator; other forms that resolve to a comparison operator +are not allowed. + +At least two ``ARG``\ s and one ``OP`` are required, and every ``OP`` must be +followed by an ``ARG``. + +As elsewhere in Hy, the equality operator is spelled ``=``, not ``==`` as in +Python. + + + comment ------- diff --git a/hy/compiler.py b/hy/compiler.py index 2b8bb6e..840f9f7 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -1256,12 +1256,18 @@ class HyASTCompiler(object): values=[value.force_expr for value in values]) return ret - c_ops = {"=": ast.Eq, "!=": ast.NotEq, + _c_ops = {"=": ast.Eq, "!=": ast.NotEq, "<": ast.Lt, "<=": ast.LtE, ">": ast.Gt, ">=": ast.GtE, "is": ast.Is, "is-not": ast.IsNot, "in": ast.In, "not-in": ast.NotIn} - c_ops = {ast_str(k): v for k, v in c_ops.items()} + _c_ops = {ast_str(k): v for k, v in _c_ops.items()} + def _get_c_op(self, sym): + k = ast_str(sym) + if k not in self._c_ops: + raise self._syntax_error(sym, + "Illegal comparison operator: " + str(sym)) + return self._c_ops[k]() @special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)]) @special(["!=", "is-not"], [times(2, Inf, FORM)]) @@ -1271,11 +1277,23 @@ class HyASTCompiler(object): return (self.compile(args[0]) + asty.Name(expr, id="True", ctx=ast.Load())) - ops = [self.c_ops[ast_str(root)]() for _ in args[1:]] + ops = [self._get_c_op(root) for _ in args[1:]] exprs, ret, _ = self._compile_collect(args) return ret + asty.Compare( expr, left=exprs[0], ops=ops, comparators=exprs[1:]) + @special("cmp", [FORM, many(SYM + FORM)]) + def compile_chained_comparison(self, expr, root, arg1, args): + ret = self.compile(arg1) + arg1 = ret.force_expr + + ops = [self._get_c_op(op) for op, _ in args] + args, ret2, _ = self._compile_collect( + [x for _, x in args]) + + return ret + ret2 + asty.Compare(expr, + left=arg1, ops=ops, comparators=args) + # The second element of each tuple below is an aggregation operator # that's used for augmented assignment with three or more arguments. m_ops = {"+": (ast.Add, "+"), diff --git a/tests/native_tests/operators.hy b/tests/native_tests/operators.hy index e8c858d..e62fa57 100644 --- a/tests/native_tests/operators.hy +++ b/tests/native_tests/operators.hy @@ -305,6 +305,21 @@ (assert (= (f {"x" {"y" {"z" 12}}} "x" "y" "z") 12))) +(defn test-chained-comparison [] + (assert (cmp 2 = (+ 1 1) = (- 3 1))) + (assert (not (cmp 2 = (+ 1 1) = (+ 3 1)))) + + (assert (cmp 2 = 2 > 1)) + (assert (cmp 2 = (+ 1 1) > 1)) + (setv x 2) + (assert (cmp 2 = x > 1)) + (assert (cmp 2 = x > (> 4 3))) + (assert (not (cmp (> 4 3) = x > 1))) + + (assert (cmp 1 in [1] in [[1] [2 3]] not-in [5])) + (assert (not (cmp 1 in [1] not-in [[1] [2 3]] not-in [5])))) + + (defn test-augassign [] (setv b 2 c 3 d 4) (defmacro same-as [expr1 expr2 expected-value]