diff --git a/hy/compiler.py b/hy/compiler.py index c55fbbc..2dd13c4 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -1811,18 +1811,7 @@ class HyASTCompiler(object): values=[value.force_expr for value in values]) return ret - @builds("=") - @builds("!=") - @builds("<") - @builds("<=") - @builds(">") - @builds(">=") - @builds("is") - @builds("in") - @builds("is_not") - @builds("not_in") - @checkargs(min=2) - def compile_compare_op_expression(self, expression): + def _compile_compare_op_expression(self, expression): ops = {"=": ast.Eq, "!=": ast.NotEq, "<": ast.Lt, "<=": ast.LtE, ">": ast.Gt, ">=": ast.GtE, @@ -1842,6 +1831,32 @@ class HyASTCompiler(object): lineno=e.start_line, col_offset=e.start_column) + @builds("=") + @builds("!=") + @builds("<") + @builds("<=") + @builds(">") + @builds(">=") + @checkargs(min=1) + def compile_compare_op_expression(self, expression): + if len(expression) == 2: + rval = "True" + if expression[0] == "!=": + rval = "False" + return ast.Name(id=rval, + ctx=ast.Load(), + lineno=expression.start_line, + col_offset=expression.start_column) + return self._compile_compare_op_expression(expression) + + @builds("is") + @builds("in") + @builds("is_not") + @builds("not_in") + @checkargs(min=2) + def compile_compare_op_expression_coll(self, expression): + return self._compile_compare_op_expression(expression) + @builds("%") @builds("**") @builds("<<") diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index adb0fa6..3a1f92d 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -468,9 +468,9 @@ def test_ast_unicode_strings(): def test_compile_error(): """Ensure we get compile error in tricky cases""" try: - can_compile("(fn [] (= 1))") + can_compile("(fn [] (in [1 2 3]))") except HyTypeError as e: - assert(e.message == "`=' needs at least 2 arguments, got 1.") + assert(e.message == "`in' needs at least 2 arguments, got 1.") else: assert(False) diff --git a/tests/native_tests/language.hy b/tests/native_tests/language.hy index 7821e8d..9986ff7 100644 --- a/tests/native_tests/language.hy +++ b/tests/native_tests/language.hy @@ -208,15 +208,26 @@ (defn test-noteq [] "NATIVE: not eq" - (assert (!= 2 3))) + (assert (!= 2 3)) + (assert (not (!= 1)))) + + +(defn test-eq [] + "NATIVE: eq" + (assert (= 1 1)) + (assert (= 1))) (defn test-numops [] "NATIVE: test numpos" (assert (> 5 4 3 2 1)) + (assert (> 1)) (assert (< 1 2 3 4 5)) + (assert (< 1)) (assert (<= 5 5 5 5 )) - (assert (>= 5 5 5 5 ))) + (assert (<= 1)) + (assert (>= 5 5 5 5 )) + (assert (>= 1))) (defn test-is []