From 3e9a2178c52faa3f89a4b4af4ba10675515657cb Mon Sep 17 00:00:00 2001 From: Julien Danjou Date: Sat, 6 Apr 2013 10:37:21 +0200 Subject: [PATCH] Add support for unary operators (not, ~) Signed-off-by: Julien Danjou --- hy/compiler.py | 16 ++++++++++++++++ tests/compilers/test_ast.py | 23 +++++++++++++++++++++++ tests/native_tests/language.hy | 13 +++++++++++++ 3 files changed, 52 insertions(+) diff --git a/hy/compiler.py b/hy/compiler.py index 4a72b1a..d6168a5 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -415,6 +415,22 @@ class HyASTCompiler(object): return call + @builds("not") + @builds("~") + def compile_unary_operator(self, expression): + if len(expression) != 2: + raise TypeError("Unary operator expects only 1 argument, got %d" + % (len(expression) - 1)) + ops = {"not": ast.Not, + "~": ast.Invert} + operator = expression.pop(0) + operand = expression.pop(0) + return ast.UnaryOp(op=ops[operator](), + operand=self.compile(operand), + lineno=operator.start_line, + col_offset=operator.start_column) + + @builds("=") @builds("!=") @builds("<") diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index 88208d1..7db9b57 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -34,6 +34,14 @@ def _ast_spotcheck(arg, root, secondary): assert getattr(root, arg) == getattr(secondary, arg) +def cant_compile(expr): + try: + hy_compile(tokenize(expr)) + assert False + except TypeError: + pass + + def test_ast_bad_type(): "Make sure AST breakage can happen" try: @@ -66,6 +74,21 @@ def test_ast_valid_if(): hy_compile(tokenize("(if foo bar)")) +def test_ast_valid_unary_op(): + "Make sure AST can compile valid unary operator" + hy_compile(tokenize("(not 2)")) + hy_compile(tokenize("(~ 1)")) + + +def test_ast_invalid_unary_op(): + "Make sure AST can't compile invalid unary operator" + cant_compile("(not 2 3 4)") + cant_compile("(not)") + cant_compile("(not 2 3 4)") + cant_compile("(~ 2 2 3 4)") + cant_compile("(~)") + + def test_ast_bad_while_0_arg(): "Make sure AST can't compile invalid while" try: diff --git a/tests/native_tests/language.hy b/tests/native_tests/language.hy index 4366d06..3d320c6 100644 --- a/tests/native_tests/language.hy +++ b/tests/native_tests/language.hy @@ -43,6 +43,19 @@ (assert (= fact 120))) +(defn test-not [] + "NATIVE: test not" + (assert (not (= 1 2))) + (assert (= true (not false))) + (assert (= false (not 42))) ) + + +(defn test-inv [] + "NATIVE: test inv" + (assert (= (~ 1) -2)) + (assert (= (~ -2) 1))) + + (defn test-in [] "NATIVE: test in" (assert (in "a" ["a" "b" "c" "d"]))