Check the number of arguments for each function

Signed-off-by: Julien Danjou <julien@danjou.info>
This commit is contained in:
Julien Danjou 2013-04-06 16:33:06 +02:00
parent 4b57fd0a51
commit 0eb795b4a5
4 changed files with 221 additions and 40 deletions

View File

@ -50,6 +50,38 @@ def builds(_type):
return _dec
def _raise_wrong_args_number(expression, error):
err = TypeError(error % (expression.pop(0),
len(expression)))
err.start_line = expression.start_line
err.start_column = expression.start_column
raise err
def checkargs(exact=None, min=None, max=None):
def _dec(fn):
def checker(self, expression):
if exact is not None and (len(expression) - 1) != exact:
_raise_wrong_args_number(expression,
"`%%s' needs %d arguments, got %%d" %
exact)
if min is not None and (len(expression) - 1) < min:
_raise_wrong_args_number(expression,
"`%%s' needs at least %d arguments, got %%d" %
min)
if max is not None and (len(expression) - 1) > max:
_raise_wrong_args_number(expression,
"`%%s' needs at most %d arguments, got %%d" %
max)
return fn(self, expression)
return checker
return _dec
class HyASTCompiler(object):
def __init__(self):
@ -57,9 +89,16 @@ class HyASTCompiler(object):
self.anon_fn_count = 0
def compile(self, tree):
for _type in _compile_table:
if type(tree) == _type:
return _compile_table[_type](self, tree)
try:
for _type in _compile_table:
if type(tree) == _type:
return _compile_table[_type](self, tree)
except Exception as e:
err = HyCompileError(str(e))
err.exception = e
err.start_line = getattr(e, "start_line", None)
err.start_column = getattr(e, "start_column", None)
raise err
raise HyCompileError("Unknown type - `%s'" % (str(type(tree))))
@ -100,10 +139,12 @@ class HyASTCompiler(object):
@builds("do")
@builds("progn")
@checkargs(min=1)
def compile_do_expression(self, expr):
return [self.compile(x) for x in expr[1:]]
@builds("throw")
@checkargs(min=1)
def compile_throw_expression(self, expr):
expr.pop(0)
exc = self.compile(expr.pop(0))
@ -116,6 +157,7 @@ class HyASTCompiler(object):
tback=None)
@builds("try")
@checkargs(min=1)
def compile_try_expression(self, expr):
expr.pop(0) # try
@ -134,6 +176,7 @@ class HyASTCompiler(object):
orelse=[])
@builds("catch")
@checkargs(min=2)
def compile_catch_expression(self, expr):
expr.pop(0) # catch
_type = self.compile(expr.pop(0))
@ -163,25 +206,16 @@ class HyASTCompiler(object):
return self._mangle_branch([branch])
@builds("if")
@checkargs(min=2, max=3)
def compile_if_expression(self, expr):
expr.pop(0)
try:
test = expr.pop(0)
except IndexError:
raise TypeError("if expects at least 2 arguments, got 0")
test = self.compile(test)
try:
body = expr.pop(0)
except IndexError:
raise TypeError("if expects at least 2 arguments, got 1")
body = self._code_branch(self.compile(body))
orel = []
expr.pop(0) # if
test = self.compile(expr.pop(0))
body = self._code_branch(self.compile(expr.pop(0)))
if len(expr) == 1:
orel = self._code_branch(self.compile(expr.pop(0)))
elif len(expr) > 1:
raise TypeError("if expects 2 or 3 arguments, got %d" % (
len(expr) + 2))
else:
orel = []
return ast.If(test=test,
body=body,
@ -210,6 +244,7 @@ class HyASTCompiler(object):
nl=True)
@builds("assert")
@checkargs(1)
def compile_assert_expression(self, expr):
expr.pop(0) # assert
e = expr.pop(0)
@ -219,6 +254,7 @@ class HyASTCompiler(object):
col_offset=e.start_column)
@builds("lambda")
@checkargs(min=2)
def compile_lambda_expression(self, expr):
expr.pop(0)
sig = expr.pop(0)
@ -241,10 +277,12 @@ class HyASTCompiler(object):
body=self.compile(body))
@builds("pass")
@checkargs(0)
def compile_pass_expression(self, expr):
return ast.Pass(lineno=expr.start_line, col_offset=expr.start_column)
@builds("yield")
@checkargs(1)
def compile_yield_expression(self, expr):
expr.pop(0)
return ast.Yield(
@ -272,6 +310,7 @@ class HyASTCompiler(object):
asname=str(x[1])) for x in modlist])
@builds("import_from")
@checkargs(min=1)
def compile_import_from_expression(self, expr):
expr.pop(0) # index
return ast.ImportFrom(
@ -282,6 +321,7 @@ class HyASTCompiler(object):
level=0)
@builds("get")
@checkargs(2)
def compile_index_expression(self, expr):
expr.pop(0) # index
val = self.compile(expr.pop(0)) # target
@ -295,6 +335,7 @@ class HyASTCompiler(object):
ctx=ast.Load())
@builds("slice")
@checkargs(min=1, max=3)
def compile_slice_expression(self, expr):
expr.pop(0) # index
val = self.compile(expr.pop(0)) # target
@ -317,6 +358,7 @@ class HyASTCompiler(object):
ctx=ast.Load())
@builds("assoc")
@checkargs(3)
def compile_assoc_expression(self, expr):
expr.pop(0) # assoc
# (assoc foo bar baz) => foo[bar] = baz
@ -337,6 +379,7 @@ class HyASTCompiler(object):
value=self.compile(val))
@builds("decorate_with")
@checkargs(min=1)
def compile_decorate_expression(self, expr):
expr.pop(0) # decorate-with
fn = self.compile(expr.pop(-1))
@ -346,6 +389,7 @@ class HyASTCompiler(object):
return fn
@builds("with_as")
@checkargs(min=2)
def compile_with_as_expression(self, expr):
expr.pop(0) # with-as
ctx = self.compile(expr.pop(0))
@ -372,6 +416,7 @@ class HyASTCompiler(object):
ctx=ast.Load())
@builds("list_comp")
@checkargs(min=2, max=3)
def compile_list_comprehension(self, expr):
# (list-comp expr (target iter) cond?)
expr.pop(0)
@ -406,6 +451,7 @@ class HyASTCompiler(object):
return name
@builds("kwapply")
@checkargs(2)
def compile_kwapply_expression(self, expr):
expr.pop(0) # kwapply
call = self.compile(expr.pop(0))
@ -421,10 +467,8 @@ class HyASTCompiler(object):
@builds("not")
@builds("~")
@checkargs(1)
def compile_unary_operator(self, expression):
if len(expression) != 2:
raise TypeError("Unary operator expects only 1 argument, got %d"
% (len(expression) - 1))
ops = {"not": ast.Not,
"~": ast.Invert}
operator = expression.pop(0)
@ -444,6 +488,7 @@ class HyASTCompiler(object):
@builds("in")
@builds("is_not")
@builds("not_in")
@checkargs(min=2)
def compile_compare_op_expression(self, expression):
ops = {"=": ast.Eq, "!=": ast.NotEq,
"<": ast.Lt, "<=": ast.LtE,
@ -467,6 +512,7 @@ class HyASTCompiler(object):
@builds("-")
@builds("/")
@builds("*")
@checkargs(min=2)
def compile_maths_expression(self, expression):
# operator = Mod | Pow | LShift | RShift | BitOr |
# BitXor | BitAnd | FloorDiv
@ -535,6 +581,7 @@ class HyASTCompiler(object):
@builds("def")
@builds("setf")
@builds("setv")
@checkargs(2)
def compile_def_expression(self, expression):
expression.pop(0) # "def"
name = expression.pop(0)
@ -556,6 +603,7 @@ class HyASTCompiler(object):
targets=[name], value=what)
@builds("foreach")
@checkargs(min=1)
def compile_for_expression(self, expression):
ret_status = self.returnable
self.returnable = False
@ -579,17 +627,10 @@ class HyASTCompiler(object):
return ret
@builds("while")
@checkargs(min=2)
def compile_while_expression(self, expr):
expr.pop(0) # "while"
try:
test = expr.pop(0)
except IndexError:
raise TypeError("while expects at least 2 arguments, got 0")
test = self.compile(test)
if not expr:
raise TypeError("while expects a body")
test = self.compile(expr.pop(0))
return ast.While(test=test,
body=self._mangle_branch([
@ -607,6 +648,7 @@ class HyASTCompiler(object):
col_offset=expr.start_column)
@builds("fn")
@checkargs(min=2)
def compile_fn_expression(self, expression):
expression.pop(0) # fn

View File

@ -48,7 +48,13 @@ def import_file_to_hst(fpath):
def import_file_to_ast(fpath):
tree = import_file_to_hst(fpath)
ast = hy_compile(tree)
try:
ast = hy_compile(tree)
except Exception as e:
print("Compilation error at %s:%d,%d"
% (fpath, e.start_line, e.start_column))
print("Compilation error: " + e.message)
raise e.exception
return ast

View File

@ -38,7 +38,7 @@ def cant_compile(expr):
try:
hy_compile(tokenize(expr))
assert False
except TypeError:
except HyCompileError:
pass
@ -56,7 +56,7 @@ def test_ast_bad_if_0_arg():
try:
hy_compile(tokenize("(if)"))
assert False
except TypeError:
except HyCompileError:
pass
@ -65,7 +65,7 @@ def test_ast_bad_if_1_arg():
try:
hy_compile(tokenize("(if foobar)"))
assert False
except TypeError:
except HyCompileError:
pass
@ -74,7 +74,7 @@ def test_ast_bad_if_too_much_arg():
try:
hy_compile(tokenize("(if 1 2 3 4 5)"))
assert False
except TypeError:
except HyCompileError:
pass
@ -103,7 +103,7 @@ def test_ast_bad_while_0_arg():
try:
hy_compile(tokenize("(while)"))
assert False
except TypeError:
except HyCompileError:
pass
@ -112,10 +112,143 @@ def test_ast_bad_while_1_arg():
try:
hy_compile(tokenize("(while (true))"))
assert False
except TypeError:
except HyCompileError:
pass
def test_ast_good_do():
"Make sure AST can compile valid do"
hy_compile(tokenize("(do 1)"))
def test_ast_bad_do():
"Make sure AST can't compile invalid do"
cant_compile("(do)")
def test_ast_good_throw():
"Make sure AST can compile valid throw"
hy_compile(tokenize("(throw 1)"))
def test_ast_bad_throw():
"Make sure AST can't compile invalid throw"
cant_compile("(throw)")
def test_ast_good_try():
"Make sure AST can compile valid try"
hy_compile(tokenize("(try 1)"))
def test_ast_bad_try():
"Make sure AST can't compile invalid try"
cant_compile("(try)")
def test_ast_good_catch():
"Make sure AST can compile valid catch"
hy_compile(tokenize("(catch 1 2)"))
def test_ast_bad_catch():
"Make sure AST can't compile invalid catch"
cant_compile("(catch)")
cant_compile("(catch 1)")
def test_ast_good_assert():
"Make sure AST can compile valid assert"
hy_compile(tokenize("(assert 1)"))
def test_ast_bad_assert():
"Make sure AST can't compile invalid assert"
cant_compile("(assert)")
cant_compile("(assert 1 2)")
def test_ast_good_lambda():
"Make sure AST can compile valid lambda"
hy_compile(tokenize("(lambda [] 1)"))
def test_ast_bad_lambda():
"Make sure AST can't compile invalid lambda"
cant_compile("(lambda)")
cant_compile("(lambda [])")
def test_ast_good_pass():
"Make sure AST can compile valid pass"
hy_compile(tokenize("(pass)"))
def test_ast_bad_pass():
"Make sure AST can't compile invalid pass"
cant_compile("(pass 1)")
cant_compile("(pass 1 2)")
def test_ast_good_yield():
"Make sure AST can compile valid yield"
hy_compile(tokenize("(yield 1)"))
def test_ast_bad_yield():
"Make sure AST can't compile invalid yield"
cant_compile("(yield)")
cant_compile("(yield 1 2)")
def test_ast_good_import_from():
"Make sure AST can compile valid import-from"
hy_compile(tokenize("(import-from x y)"))
def test_ast_bad_import_from():
"Make sure AST can't compile invalid import-from"
cant_compile("(import-from)")
def test_ast_good_get():
"Make sure AST can compile valid get"
hy_compile(tokenize("(get x y)"))
def test_ast_bad_get():
"Make sure AST can't compile invalid get"
cant_compile("(get)")
cant_compile("(get 1)")
cant_compile("(get 1 2 3)")
def test_ast_good_slice():
"Make sure AST can compile valid slice"
hy_compile(tokenize("(slice x)"))
hy_compile(tokenize("(slice x y)"))
hy_compile(tokenize("(slice x y z)"))
def test_ast_bad_slice():
"Make sure AST can't compile invalid slice"
cant_compile("(slice)")
cant_compile("(slice 1 2 3 4)")
def test_ast_good_assoc():
"Make sure AST can compile valid assoc"
hy_compile(tokenize("(assoc x y z)"))
def test_ast_bad_assoc():
"Make sure AST can't compile invalid assoc"
cant_compile("(assoc)")
cant_compile("(assoc 1)")
cant_compile("(assoc 1 2)")
cant_compile("(assoc 1 2 3 4)")
def test_ast_valid_while():
"Make sure AST can't compile invalid while"
hy_compile(tokenize("(while foo bar)"))
@ -151,7 +284,7 @@ def test_ast_non_decoratable():
try:
hy_compile(tokenize("(decorate-with (foo) (* x x))"))
assert True is False
except TypeError:
except HyCompileError:
pass
@ -162,7 +295,7 @@ def test_ast_non_kwapplyable():
try:
hy_compile(code)
assert True is False
except TypeError:
except HyCompileError:
pass

View File

@ -119,7 +119,7 @@
(defn test-index []
"NATIVE: Test that dict access works"
(assert (get {"one" "two"} "one") "two")
(assert (= (get {"one" "two"} "one") "two"))
(assert (= (get [1 2 3 4 5] 1) 2)))