diff --git a/hy/compiler.py b/hy/compiler.py index 5563233..20eb094 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -50,6 +50,38 @@ def builds(_type): return _dec +def _raise_wrong_args_number(expression, error): + err = TypeError(error % (expression.pop(0), + len(expression))) + err.start_line = expression.start_line + err.start_column = expression.start_column + raise err + + +def checkargs(exact=None, min=None, max=None): + def _dec(fn): + def checker(self, expression): + if exact is not None and (len(expression) - 1) != exact: + _raise_wrong_args_number(expression, + "`%%s' needs %d arguments, got %%d" % + exact) + + if min is not None and (len(expression) - 1) < min: + _raise_wrong_args_number(expression, + "`%%s' needs at least %d arguments, got %%d" % + min) + + if max is not None and (len(expression) - 1) > max: + _raise_wrong_args_number(expression, + "`%%s' needs at most %d arguments, got %%d" % + max) + + return fn(self, expression) + + return checker + return _dec + + class HyASTCompiler(object): def __init__(self): @@ -57,9 +89,16 @@ class HyASTCompiler(object): self.anon_fn_count = 0 def compile(self, tree): - for _type in _compile_table: - if type(tree) == _type: - return _compile_table[_type](self, tree) + try: + for _type in _compile_table: + if type(tree) == _type: + return _compile_table[_type](self, tree) + except Exception as e: + err = HyCompileError(str(e)) + err.exception = e + err.start_line = getattr(e, "start_line", None) + err.start_column = getattr(e, "start_column", None) + raise err raise HyCompileError("Unknown type - `%s'" % (str(type(tree)))) @@ -100,10 +139,12 @@ class HyASTCompiler(object): @builds("do") @builds("progn") + @checkargs(min=1) def compile_do_expression(self, expr): return [self.compile(x) for x in expr[1:]] @builds("throw") + @checkargs(min=1) def compile_throw_expression(self, expr): expr.pop(0) exc = self.compile(expr.pop(0)) @@ -116,6 +157,7 @@ class HyASTCompiler(object): tback=None) @builds("try") + @checkargs(min=1) def compile_try_expression(self, expr): expr.pop(0) # try @@ -134,6 +176,7 @@ class HyASTCompiler(object): orelse=[]) @builds("catch") + @checkargs(min=2) def compile_catch_expression(self, expr): expr.pop(0) # catch _type = self.compile(expr.pop(0)) @@ -163,25 +206,16 @@ class HyASTCompiler(object): return self._mangle_branch([branch]) @builds("if") + @checkargs(min=2, max=3) def compile_if_expression(self, expr): - expr.pop(0) - try: - test = expr.pop(0) - except IndexError: - raise TypeError("if expects at least 2 arguments, got 0") - test = self.compile(test) - try: - body = expr.pop(0) - except IndexError: - raise TypeError("if expects at least 2 arguments, got 1") - body = self._code_branch(self.compile(body)) - orel = [] + expr.pop(0) # if + test = self.compile(expr.pop(0)) + body = self._code_branch(self.compile(expr.pop(0))) if len(expr) == 1: orel = self._code_branch(self.compile(expr.pop(0))) - elif len(expr) > 1: - raise TypeError("if expects 2 or 3 arguments, got %d" % ( - len(expr) + 2)) + else: + orel = [] return ast.If(test=test, body=body, @@ -210,6 +244,7 @@ class HyASTCompiler(object): nl=True) @builds("assert") + @checkargs(1) def compile_assert_expression(self, expr): expr.pop(0) # assert e = expr.pop(0) @@ -219,6 +254,7 @@ class HyASTCompiler(object): col_offset=e.start_column) @builds("lambda") + @checkargs(min=2) def compile_lambda_expression(self, expr): expr.pop(0) sig = expr.pop(0) @@ -241,10 +277,12 @@ class HyASTCompiler(object): body=self.compile(body)) @builds("pass") + @checkargs(0) def compile_pass_expression(self, expr): return ast.Pass(lineno=expr.start_line, col_offset=expr.start_column) @builds("yield") + @checkargs(1) def compile_yield_expression(self, expr): expr.pop(0) return ast.Yield( @@ -272,6 +310,7 @@ class HyASTCompiler(object): asname=str(x[1])) for x in modlist]) @builds("import_from") + @checkargs(min=1) def compile_import_from_expression(self, expr): expr.pop(0) # index return ast.ImportFrom( @@ -282,6 +321,7 @@ class HyASTCompiler(object): level=0) @builds("get") + @checkargs(2) def compile_index_expression(self, expr): expr.pop(0) # index val = self.compile(expr.pop(0)) # target @@ -295,6 +335,7 @@ class HyASTCompiler(object): ctx=ast.Load()) @builds("slice") + @checkargs(min=1, max=3) def compile_slice_expression(self, expr): expr.pop(0) # index val = self.compile(expr.pop(0)) # target @@ -317,6 +358,7 @@ class HyASTCompiler(object): ctx=ast.Load()) @builds("assoc") + @checkargs(3) def compile_assoc_expression(self, expr): expr.pop(0) # assoc # (assoc foo bar baz) => foo[bar] = baz @@ -337,6 +379,7 @@ class HyASTCompiler(object): value=self.compile(val)) @builds("decorate_with") + @checkargs(min=1) def compile_decorate_expression(self, expr): expr.pop(0) # decorate-with fn = self.compile(expr.pop(-1)) @@ -346,6 +389,7 @@ class HyASTCompiler(object): return fn @builds("with_as") + @checkargs(min=2) def compile_with_as_expression(self, expr): expr.pop(0) # with-as ctx = self.compile(expr.pop(0)) @@ -372,6 +416,7 @@ class HyASTCompiler(object): ctx=ast.Load()) @builds("list_comp") + @checkargs(min=2, max=3) def compile_list_comprehension(self, expr): # (list-comp expr (target iter) cond?) expr.pop(0) @@ -406,6 +451,7 @@ class HyASTCompiler(object): return name @builds("kwapply") + @checkargs(2) def compile_kwapply_expression(self, expr): expr.pop(0) # kwapply call = self.compile(expr.pop(0)) @@ -421,10 +467,8 @@ class HyASTCompiler(object): @builds("not") @builds("~") + @checkargs(1) def compile_unary_operator(self, expression): - if len(expression) != 2: - raise TypeError("Unary operator expects only 1 argument, got %d" - % (len(expression) - 1)) ops = {"not": ast.Not, "~": ast.Invert} operator = expression.pop(0) @@ -444,6 +488,7 @@ class HyASTCompiler(object): @builds("in") @builds("is_not") @builds("not_in") + @checkargs(min=2) def compile_compare_op_expression(self, expression): ops = {"=": ast.Eq, "!=": ast.NotEq, "<": ast.Lt, "<=": ast.LtE, @@ -467,6 +512,7 @@ class HyASTCompiler(object): @builds("-") @builds("/") @builds("*") + @checkargs(min=2) def compile_maths_expression(self, expression): # operator = Mod | Pow | LShift | RShift | BitOr | # BitXor | BitAnd | FloorDiv @@ -535,6 +581,7 @@ class HyASTCompiler(object): @builds("def") @builds("setf") @builds("setv") + @checkargs(2) def compile_def_expression(self, expression): expression.pop(0) # "def" name = expression.pop(0) @@ -556,6 +603,7 @@ class HyASTCompiler(object): targets=[name], value=what) @builds("foreach") + @checkargs(min=1) def compile_for_expression(self, expression): ret_status = self.returnable self.returnable = False @@ -579,17 +627,10 @@ class HyASTCompiler(object): return ret @builds("while") + @checkargs(min=2) def compile_while_expression(self, expr): expr.pop(0) # "while" - - try: - test = expr.pop(0) - except IndexError: - raise TypeError("while expects at least 2 arguments, got 0") - test = self.compile(test) - - if not expr: - raise TypeError("while expects a body") + test = self.compile(expr.pop(0)) return ast.While(test=test, body=self._mangle_branch([ @@ -607,6 +648,7 @@ class HyASTCompiler(object): col_offset=expr.start_column) @builds("fn") + @checkargs(min=2) def compile_fn_expression(self, expression): expression.pop(0) # fn diff --git a/hy/importer.py b/hy/importer.py index 1569d38..21b6d74 100644 --- a/hy/importer.py +++ b/hy/importer.py @@ -48,7 +48,13 @@ def import_file_to_hst(fpath): def import_file_to_ast(fpath): tree = import_file_to_hst(fpath) - ast = hy_compile(tree) + try: + ast = hy_compile(tree) + except Exception as e: + print("Compilation error at %s:%d,%d" + % (fpath, e.start_line, e.start_column)) + print("Compilation error: " + e.message) + raise e.exception return ast diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index 5f03672..23368aa 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -38,7 +38,7 @@ def cant_compile(expr): try: hy_compile(tokenize(expr)) assert False - except TypeError: + except HyCompileError: pass @@ -56,7 +56,7 @@ def test_ast_bad_if_0_arg(): try: hy_compile(tokenize("(if)")) assert False - except TypeError: + except HyCompileError: pass @@ -65,7 +65,7 @@ def test_ast_bad_if_1_arg(): try: hy_compile(tokenize("(if foobar)")) assert False - except TypeError: + except HyCompileError: pass @@ -74,7 +74,7 @@ def test_ast_bad_if_too_much_arg(): try: hy_compile(tokenize("(if 1 2 3 4 5)")) assert False - except TypeError: + except HyCompileError: pass @@ -103,7 +103,7 @@ def test_ast_bad_while_0_arg(): try: hy_compile(tokenize("(while)")) assert False - except TypeError: + except HyCompileError: pass @@ -112,10 +112,143 @@ def test_ast_bad_while_1_arg(): try: hy_compile(tokenize("(while (true))")) assert False - except TypeError: + except HyCompileError: pass +def test_ast_good_do(): + "Make sure AST can compile valid do" + hy_compile(tokenize("(do 1)")) + + +def test_ast_bad_do(): + "Make sure AST can't compile invalid do" + cant_compile("(do)") + + +def test_ast_good_throw(): + "Make sure AST can compile valid throw" + hy_compile(tokenize("(throw 1)")) + + +def test_ast_bad_throw(): + "Make sure AST can't compile invalid throw" + cant_compile("(throw)") + + +def test_ast_good_try(): + "Make sure AST can compile valid try" + hy_compile(tokenize("(try 1)")) + + +def test_ast_bad_try(): + "Make sure AST can't compile invalid try" + cant_compile("(try)") + + +def test_ast_good_catch(): + "Make sure AST can compile valid catch" + hy_compile(tokenize("(catch 1 2)")) + + +def test_ast_bad_catch(): + "Make sure AST can't compile invalid catch" + cant_compile("(catch)") + cant_compile("(catch 1)") + + +def test_ast_good_assert(): + "Make sure AST can compile valid assert" + hy_compile(tokenize("(assert 1)")) + + +def test_ast_bad_assert(): + "Make sure AST can't compile invalid assert" + cant_compile("(assert)") + cant_compile("(assert 1 2)") + + +def test_ast_good_lambda(): + "Make sure AST can compile valid lambda" + hy_compile(tokenize("(lambda [] 1)")) + + +def test_ast_bad_lambda(): + "Make sure AST can't compile invalid lambda" + cant_compile("(lambda)") + cant_compile("(lambda [])") + + +def test_ast_good_pass(): + "Make sure AST can compile valid pass" + hy_compile(tokenize("(pass)")) + + +def test_ast_bad_pass(): + "Make sure AST can't compile invalid pass" + cant_compile("(pass 1)") + cant_compile("(pass 1 2)") + + +def test_ast_good_yield(): + "Make sure AST can compile valid yield" + hy_compile(tokenize("(yield 1)")) + + +def test_ast_bad_yield(): + "Make sure AST can't compile invalid yield" + cant_compile("(yield)") + cant_compile("(yield 1 2)") + + +def test_ast_good_import_from(): + "Make sure AST can compile valid import-from" + hy_compile(tokenize("(import-from x y)")) + + +def test_ast_bad_import_from(): + "Make sure AST can't compile invalid import-from" + cant_compile("(import-from)") + + +def test_ast_good_get(): + "Make sure AST can compile valid get" + hy_compile(tokenize("(get x y)")) + + +def test_ast_bad_get(): + "Make sure AST can't compile invalid get" + cant_compile("(get)") + cant_compile("(get 1)") + cant_compile("(get 1 2 3)") + + +def test_ast_good_slice(): + "Make sure AST can compile valid slice" + hy_compile(tokenize("(slice x)")) + hy_compile(tokenize("(slice x y)")) + hy_compile(tokenize("(slice x y z)")) + + +def test_ast_bad_slice(): + "Make sure AST can't compile invalid slice" + cant_compile("(slice)") + cant_compile("(slice 1 2 3 4)") + + +def test_ast_good_assoc(): + "Make sure AST can compile valid assoc" + hy_compile(tokenize("(assoc x y z)")) + + +def test_ast_bad_assoc(): + "Make sure AST can't compile invalid assoc" + cant_compile("(assoc)") + cant_compile("(assoc 1)") + cant_compile("(assoc 1 2)") + cant_compile("(assoc 1 2 3 4)") + + def test_ast_valid_while(): "Make sure AST can't compile invalid while" hy_compile(tokenize("(while foo bar)")) @@ -151,7 +284,7 @@ def test_ast_non_decoratable(): try: hy_compile(tokenize("(decorate-with (foo) (* x x))")) assert True is False - except TypeError: + except HyCompileError: pass @@ -162,7 +295,7 @@ def test_ast_non_kwapplyable(): try: hy_compile(code) assert True is False - except TypeError: + except HyCompileError: pass diff --git a/tests/native_tests/language.hy b/tests/native_tests/language.hy index 1db6def..2f539ec 100644 --- a/tests/native_tests/language.hy +++ b/tests/native_tests/language.hy @@ -119,7 +119,7 @@ (defn test-index [] "NATIVE: Test that dict access works" - (assert (get {"one" "two"} "one") "two") + (assert (= (get {"one" "two"} "one") "two")) (assert (= (get [1 2 3 4 5] 1) 2)))