Check the number of arguments for each function

Signed-off-by: Julien Danjou <julien@danjou.info>
This commit is contained in:
Julien Danjou 2013-04-06 16:33:06 +02:00
parent 4b57fd0a51
commit 0eb795b4a5
4 changed files with 221 additions and 40 deletions

View File

@ -50,6 +50,38 @@ def builds(_type):
return _dec return _dec
def _raise_wrong_args_number(expression, error):
err = TypeError(error % (expression.pop(0),
len(expression)))
err.start_line = expression.start_line
err.start_column = expression.start_column
raise err
def checkargs(exact=None, min=None, max=None):
def _dec(fn):
def checker(self, expression):
if exact is not None and (len(expression) - 1) != exact:
_raise_wrong_args_number(expression,
"`%%s' needs %d arguments, got %%d" %
exact)
if min is not None and (len(expression) - 1) < min:
_raise_wrong_args_number(expression,
"`%%s' needs at least %d arguments, got %%d" %
min)
if max is not None and (len(expression) - 1) > max:
_raise_wrong_args_number(expression,
"`%%s' needs at most %d arguments, got %%d" %
max)
return fn(self, expression)
return checker
return _dec
class HyASTCompiler(object): class HyASTCompiler(object):
def __init__(self): def __init__(self):
@ -57,9 +89,16 @@ class HyASTCompiler(object):
self.anon_fn_count = 0 self.anon_fn_count = 0
def compile(self, tree): def compile(self, tree):
for _type in _compile_table: try:
if type(tree) == _type: for _type in _compile_table:
return _compile_table[_type](self, tree) if type(tree) == _type:
return _compile_table[_type](self, tree)
except Exception as e:
err = HyCompileError(str(e))
err.exception = e
err.start_line = getattr(e, "start_line", None)
err.start_column = getattr(e, "start_column", None)
raise err
raise HyCompileError("Unknown type - `%s'" % (str(type(tree)))) raise HyCompileError("Unknown type - `%s'" % (str(type(tree))))
@ -100,10 +139,12 @@ class HyASTCompiler(object):
@builds("do") @builds("do")
@builds("progn") @builds("progn")
@checkargs(min=1)
def compile_do_expression(self, expr): def compile_do_expression(self, expr):
return [self.compile(x) for x in expr[1:]] return [self.compile(x) for x in expr[1:]]
@builds("throw") @builds("throw")
@checkargs(min=1)
def compile_throw_expression(self, expr): def compile_throw_expression(self, expr):
expr.pop(0) expr.pop(0)
exc = self.compile(expr.pop(0)) exc = self.compile(expr.pop(0))
@ -116,6 +157,7 @@ class HyASTCompiler(object):
tback=None) tback=None)
@builds("try") @builds("try")
@checkargs(min=1)
def compile_try_expression(self, expr): def compile_try_expression(self, expr):
expr.pop(0) # try expr.pop(0) # try
@ -134,6 +176,7 @@ class HyASTCompiler(object):
orelse=[]) orelse=[])
@builds("catch") @builds("catch")
@checkargs(min=2)
def compile_catch_expression(self, expr): def compile_catch_expression(self, expr):
expr.pop(0) # catch expr.pop(0) # catch
_type = self.compile(expr.pop(0)) _type = self.compile(expr.pop(0))
@ -163,25 +206,16 @@ class HyASTCompiler(object):
return self._mangle_branch([branch]) return self._mangle_branch([branch])
@builds("if") @builds("if")
@checkargs(min=2, max=3)
def compile_if_expression(self, expr): def compile_if_expression(self, expr):
expr.pop(0) expr.pop(0) # if
try: test = self.compile(expr.pop(0))
test = expr.pop(0) body = self._code_branch(self.compile(expr.pop(0)))
except IndexError:
raise TypeError("if expects at least 2 arguments, got 0")
test = self.compile(test)
try:
body = expr.pop(0)
except IndexError:
raise TypeError("if expects at least 2 arguments, got 1")
body = self._code_branch(self.compile(body))
orel = []
if len(expr) == 1: if len(expr) == 1:
orel = self._code_branch(self.compile(expr.pop(0))) orel = self._code_branch(self.compile(expr.pop(0)))
elif len(expr) > 1: else:
raise TypeError("if expects 2 or 3 arguments, got %d" % ( orel = []
len(expr) + 2))
return ast.If(test=test, return ast.If(test=test,
body=body, body=body,
@ -210,6 +244,7 @@ class HyASTCompiler(object):
nl=True) nl=True)
@builds("assert") @builds("assert")
@checkargs(1)
def compile_assert_expression(self, expr): def compile_assert_expression(self, expr):
expr.pop(0) # assert expr.pop(0) # assert
e = expr.pop(0) e = expr.pop(0)
@ -219,6 +254,7 @@ class HyASTCompiler(object):
col_offset=e.start_column) col_offset=e.start_column)
@builds("lambda") @builds("lambda")
@checkargs(min=2)
def compile_lambda_expression(self, expr): def compile_lambda_expression(self, expr):
expr.pop(0) expr.pop(0)
sig = expr.pop(0) sig = expr.pop(0)
@ -241,10 +277,12 @@ class HyASTCompiler(object):
body=self.compile(body)) body=self.compile(body))
@builds("pass") @builds("pass")
@checkargs(0)
def compile_pass_expression(self, expr): def compile_pass_expression(self, expr):
return ast.Pass(lineno=expr.start_line, col_offset=expr.start_column) return ast.Pass(lineno=expr.start_line, col_offset=expr.start_column)
@builds("yield") @builds("yield")
@checkargs(1)
def compile_yield_expression(self, expr): def compile_yield_expression(self, expr):
expr.pop(0) expr.pop(0)
return ast.Yield( return ast.Yield(
@ -272,6 +310,7 @@ class HyASTCompiler(object):
asname=str(x[1])) for x in modlist]) asname=str(x[1])) for x in modlist])
@builds("import_from") @builds("import_from")
@checkargs(min=1)
def compile_import_from_expression(self, expr): def compile_import_from_expression(self, expr):
expr.pop(0) # index expr.pop(0) # index
return ast.ImportFrom( return ast.ImportFrom(
@ -282,6 +321,7 @@ class HyASTCompiler(object):
level=0) level=0)
@builds("get") @builds("get")
@checkargs(2)
def compile_index_expression(self, expr): def compile_index_expression(self, expr):
expr.pop(0) # index expr.pop(0) # index
val = self.compile(expr.pop(0)) # target val = self.compile(expr.pop(0)) # target
@ -295,6 +335,7 @@ class HyASTCompiler(object):
ctx=ast.Load()) ctx=ast.Load())
@builds("slice") @builds("slice")
@checkargs(min=1, max=3)
def compile_slice_expression(self, expr): def compile_slice_expression(self, expr):
expr.pop(0) # index expr.pop(0) # index
val = self.compile(expr.pop(0)) # target val = self.compile(expr.pop(0)) # target
@ -317,6 +358,7 @@ class HyASTCompiler(object):
ctx=ast.Load()) ctx=ast.Load())
@builds("assoc") @builds("assoc")
@checkargs(3)
def compile_assoc_expression(self, expr): def compile_assoc_expression(self, expr):
expr.pop(0) # assoc expr.pop(0) # assoc
# (assoc foo bar baz) => foo[bar] = baz # (assoc foo bar baz) => foo[bar] = baz
@ -337,6 +379,7 @@ class HyASTCompiler(object):
value=self.compile(val)) value=self.compile(val))
@builds("decorate_with") @builds("decorate_with")
@checkargs(min=1)
def compile_decorate_expression(self, expr): def compile_decorate_expression(self, expr):
expr.pop(0) # decorate-with expr.pop(0) # decorate-with
fn = self.compile(expr.pop(-1)) fn = self.compile(expr.pop(-1))
@ -346,6 +389,7 @@ class HyASTCompiler(object):
return fn return fn
@builds("with_as") @builds("with_as")
@checkargs(min=2)
def compile_with_as_expression(self, expr): def compile_with_as_expression(self, expr):
expr.pop(0) # with-as expr.pop(0) # with-as
ctx = self.compile(expr.pop(0)) ctx = self.compile(expr.pop(0))
@ -372,6 +416,7 @@ class HyASTCompiler(object):
ctx=ast.Load()) ctx=ast.Load())
@builds("list_comp") @builds("list_comp")
@checkargs(min=2, max=3)
def compile_list_comprehension(self, expr): def compile_list_comprehension(self, expr):
# (list-comp expr (target iter) cond?) # (list-comp expr (target iter) cond?)
expr.pop(0) expr.pop(0)
@ -406,6 +451,7 @@ class HyASTCompiler(object):
return name return name
@builds("kwapply") @builds("kwapply")
@checkargs(2)
def compile_kwapply_expression(self, expr): def compile_kwapply_expression(self, expr):
expr.pop(0) # kwapply expr.pop(0) # kwapply
call = self.compile(expr.pop(0)) call = self.compile(expr.pop(0))
@ -421,10 +467,8 @@ class HyASTCompiler(object):
@builds("not") @builds("not")
@builds("~") @builds("~")
@checkargs(1)
def compile_unary_operator(self, expression): def compile_unary_operator(self, expression):
if len(expression) != 2:
raise TypeError("Unary operator expects only 1 argument, got %d"
% (len(expression) - 1))
ops = {"not": ast.Not, ops = {"not": ast.Not,
"~": ast.Invert} "~": ast.Invert}
operator = expression.pop(0) operator = expression.pop(0)
@ -444,6 +488,7 @@ class HyASTCompiler(object):
@builds("in") @builds("in")
@builds("is_not") @builds("is_not")
@builds("not_in") @builds("not_in")
@checkargs(min=2)
def compile_compare_op_expression(self, expression): def compile_compare_op_expression(self, expression):
ops = {"=": ast.Eq, "!=": ast.NotEq, ops = {"=": ast.Eq, "!=": ast.NotEq,
"<": ast.Lt, "<=": ast.LtE, "<": ast.Lt, "<=": ast.LtE,
@ -467,6 +512,7 @@ class HyASTCompiler(object):
@builds("-") @builds("-")
@builds("/") @builds("/")
@builds("*") @builds("*")
@checkargs(min=2)
def compile_maths_expression(self, expression): def compile_maths_expression(self, expression):
# operator = Mod | Pow | LShift | RShift | BitOr | # operator = Mod | Pow | LShift | RShift | BitOr |
# BitXor | BitAnd | FloorDiv # BitXor | BitAnd | FloorDiv
@ -535,6 +581,7 @@ class HyASTCompiler(object):
@builds("def") @builds("def")
@builds("setf") @builds("setf")
@builds("setv") @builds("setv")
@checkargs(2)
def compile_def_expression(self, expression): def compile_def_expression(self, expression):
expression.pop(0) # "def" expression.pop(0) # "def"
name = expression.pop(0) name = expression.pop(0)
@ -556,6 +603,7 @@ class HyASTCompiler(object):
targets=[name], value=what) targets=[name], value=what)
@builds("foreach") @builds("foreach")
@checkargs(min=1)
def compile_for_expression(self, expression): def compile_for_expression(self, expression):
ret_status = self.returnable ret_status = self.returnable
self.returnable = False self.returnable = False
@ -579,17 +627,10 @@ class HyASTCompiler(object):
return ret return ret
@builds("while") @builds("while")
@checkargs(min=2)
def compile_while_expression(self, expr): def compile_while_expression(self, expr):
expr.pop(0) # "while" expr.pop(0) # "while"
test = self.compile(expr.pop(0))
try:
test = expr.pop(0)
except IndexError:
raise TypeError("while expects at least 2 arguments, got 0")
test = self.compile(test)
if not expr:
raise TypeError("while expects a body")
return ast.While(test=test, return ast.While(test=test,
body=self._mangle_branch([ body=self._mangle_branch([
@ -607,6 +648,7 @@ class HyASTCompiler(object):
col_offset=expr.start_column) col_offset=expr.start_column)
@builds("fn") @builds("fn")
@checkargs(min=2)
def compile_fn_expression(self, expression): def compile_fn_expression(self, expression):
expression.pop(0) # fn expression.pop(0) # fn

View File

@ -48,7 +48,13 @@ def import_file_to_hst(fpath):
def import_file_to_ast(fpath): def import_file_to_ast(fpath):
tree = import_file_to_hst(fpath) tree = import_file_to_hst(fpath)
ast = hy_compile(tree) try:
ast = hy_compile(tree)
except Exception as e:
print("Compilation error at %s:%d,%d"
% (fpath, e.start_line, e.start_column))
print("Compilation error: " + e.message)
raise e.exception
return ast return ast

View File

@ -38,7 +38,7 @@ def cant_compile(expr):
try: try:
hy_compile(tokenize(expr)) hy_compile(tokenize(expr))
assert False assert False
except TypeError: except HyCompileError:
pass pass
@ -56,7 +56,7 @@ def test_ast_bad_if_0_arg():
try: try:
hy_compile(tokenize("(if)")) hy_compile(tokenize("(if)"))
assert False assert False
except TypeError: except HyCompileError:
pass pass
@ -65,7 +65,7 @@ def test_ast_bad_if_1_arg():
try: try:
hy_compile(tokenize("(if foobar)")) hy_compile(tokenize("(if foobar)"))
assert False assert False
except TypeError: except HyCompileError:
pass pass
@ -74,7 +74,7 @@ def test_ast_bad_if_too_much_arg():
try: try:
hy_compile(tokenize("(if 1 2 3 4 5)")) hy_compile(tokenize("(if 1 2 3 4 5)"))
assert False assert False
except TypeError: except HyCompileError:
pass pass
@ -103,7 +103,7 @@ def test_ast_bad_while_0_arg():
try: try:
hy_compile(tokenize("(while)")) hy_compile(tokenize("(while)"))
assert False assert False
except TypeError: except HyCompileError:
pass pass
@ -112,10 +112,143 @@ def test_ast_bad_while_1_arg():
try: try:
hy_compile(tokenize("(while (true))")) hy_compile(tokenize("(while (true))"))
assert False assert False
except TypeError: except HyCompileError:
pass pass
def test_ast_good_do():
"Make sure AST can compile valid do"
hy_compile(tokenize("(do 1)"))
def test_ast_bad_do():
"Make sure AST can't compile invalid do"
cant_compile("(do)")
def test_ast_good_throw():
"Make sure AST can compile valid throw"
hy_compile(tokenize("(throw 1)"))
def test_ast_bad_throw():
"Make sure AST can't compile invalid throw"
cant_compile("(throw)")
def test_ast_good_try():
"Make sure AST can compile valid try"
hy_compile(tokenize("(try 1)"))
def test_ast_bad_try():
"Make sure AST can't compile invalid try"
cant_compile("(try)")
def test_ast_good_catch():
"Make sure AST can compile valid catch"
hy_compile(tokenize("(catch 1 2)"))
def test_ast_bad_catch():
"Make sure AST can't compile invalid catch"
cant_compile("(catch)")
cant_compile("(catch 1)")
def test_ast_good_assert():
"Make sure AST can compile valid assert"
hy_compile(tokenize("(assert 1)"))
def test_ast_bad_assert():
"Make sure AST can't compile invalid assert"
cant_compile("(assert)")
cant_compile("(assert 1 2)")
def test_ast_good_lambda():
"Make sure AST can compile valid lambda"
hy_compile(tokenize("(lambda [] 1)"))
def test_ast_bad_lambda():
"Make sure AST can't compile invalid lambda"
cant_compile("(lambda)")
cant_compile("(lambda [])")
def test_ast_good_pass():
"Make sure AST can compile valid pass"
hy_compile(tokenize("(pass)"))
def test_ast_bad_pass():
"Make sure AST can't compile invalid pass"
cant_compile("(pass 1)")
cant_compile("(pass 1 2)")
def test_ast_good_yield():
"Make sure AST can compile valid yield"
hy_compile(tokenize("(yield 1)"))
def test_ast_bad_yield():
"Make sure AST can't compile invalid yield"
cant_compile("(yield)")
cant_compile("(yield 1 2)")
def test_ast_good_import_from():
"Make sure AST can compile valid import-from"
hy_compile(tokenize("(import-from x y)"))
def test_ast_bad_import_from():
"Make sure AST can't compile invalid import-from"
cant_compile("(import-from)")
def test_ast_good_get():
"Make sure AST can compile valid get"
hy_compile(tokenize("(get x y)"))
def test_ast_bad_get():
"Make sure AST can't compile invalid get"
cant_compile("(get)")
cant_compile("(get 1)")
cant_compile("(get 1 2 3)")
def test_ast_good_slice():
"Make sure AST can compile valid slice"
hy_compile(tokenize("(slice x)"))
hy_compile(tokenize("(slice x y)"))
hy_compile(tokenize("(slice x y z)"))
def test_ast_bad_slice():
"Make sure AST can't compile invalid slice"
cant_compile("(slice)")
cant_compile("(slice 1 2 3 4)")
def test_ast_good_assoc():
"Make sure AST can compile valid assoc"
hy_compile(tokenize("(assoc x y z)"))
def test_ast_bad_assoc():
"Make sure AST can't compile invalid assoc"
cant_compile("(assoc)")
cant_compile("(assoc 1)")
cant_compile("(assoc 1 2)")
cant_compile("(assoc 1 2 3 4)")
def test_ast_valid_while(): def test_ast_valid_while():
"Make sure AST can't compile invalid while" "Make sure AST can't compile invalid while"
hy_compile(tokenize("(while foo bar)")) hy_compile(tokenize("(while foo bar)"))
@ -151,7 +284,7 @@ def test_ast_non_decoratable():
try: try:
hy_compile(tokenize("(decorate-with (foo) (* x x))")) hy_compile(tokenize("(decorate-with (foo) (* x x))"))
assert True is False assert True is False
except TypeError: except HyCompileError:
pass pass
@ -162,7 +295,7 @@ def test_ast_non_kwapplyable():
try: try:
hy_compile(code) hy_compile(code)
assert True is False assert True is False
except TypeError: except HyCompileError:
pass pass

View File

@ -119,7 +119,7 @@
(defn test-index [] (defn test-index []
"NATIVE: Test that dict access works" "NATIVE: Test that dict access works"
(assert (get {"one" "two"} "one") "two") (assert (= (get {"one" "two"} "one") "two"))
(assert (= (get [1 2 3 4 5] 1) 2))) (assert (= (get [1 2 3 4 5] 1) 2)))