diff --git a/hy/compiler.py b/hy/compiler.py index fd31ec8..4bfbd56 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -1,4 +1,7 @@ +# -*- encoding: utf-8 -*- +# # Copyright (c) 2013 Paul Tagliamonte +# Copyright (c) 2013 Julien Danjou # # Permission is hereby granted, free of charge, to any person obtaining a # copy of this software and associated documentation files (the "Software"), @@ -191,29 +194,77 @@ class HyASTCompiler(object): orelse=[]) @builds("catch") - @checkargs(min=2) def compile_catch_expression(self, expr): expr.pop(0) # catch - _type = self.compile(expr.pop(0)) - name = expr.pop(0) - if sys.version_info[0] >= 3: - # Python3 features a change where the Exception handler - # moved the name from a Name() to a pure Python String type. - # - # We'll just make sure it's a pure "string", and let it work - # it's magic. - name = ast_str(name) + try: + exceptions = expr.pop(0) + except IndexError: + exceptions = [] + # exceptions catch should be either: + # [[list of exceptions]] + # or + # [variable [list of exceptions]] + # or + # [variable exception] + # or + # [exception] + # or + # [] + if len(exceptions) > 2: + raise TypeError("`catch' exceptions list is too long") + + # [variable [list of exceptions]] + # let's pop variable and use it as name + if len(exceptions) == 2: + name = exceptions.pop(0) + if sys.version_info[0] >= 3: + # Python3 features a change where the Exception handler + # moved the name from a Name() to a pure Python String type. + # + # We'll just make sure it's a pure "string", and let it work + # it's magic. + name = ast_str(name) + else: + # Python2 requires an ast.Name, set to ctx Store. + name = self._storeize(self.compile(name)) else: - # Python2 requires an ast.Name, set to ctx Store. - name = self._storeize(self.compile(name)) + name = None + + try: + exceptions_list = exceptions.pop(0) + except IndexError: + exceptions_list = [] + + if isinstance(exceptions_list, list): + if len(exceptions_list): + # [FooBar BarFoo] → catch Foobar and BarFoo exceptions + _type = ast.Tuple(elts=[self.compile(x) + for x in exceptions_list], + lineno=expr.start_line, + col_offset=expr.start_column, + ctx=ast.Load()) + else: + # [] → all exceptions catched + _type = None + elif isinstance(exceptions_list, HySymbol): + _type = self.compile(exceptions_list) + else: + raise TypeError("`catch' needs a valid exception list to catch") + + if len(expr) == 0: + # No body + body = [ast.Pass(lineno=expr.start_line, + col_offset=expr.start_column)] + else: + body = self._code_branch([self.compile(x) for x in expr]) return ast.ExceptHandler( lineno=expr.start_line, col_offset=expr.start_column, type=_type, name=name, - body=self._code_branch([self.compile(x) for x in expr])) + body=body) def _code_branch(self, branch): if isinstance(branch, list): diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index 6b59ee3..3633851 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -116,13 +116,20 @@ def test_ast_bad_try(): def test_ast_good_catch(): "Make sure AST can compile valid catch" - hy_compile(tokenize("(catch 1 2)")) + hy_compile(tokenize("(catch)")) + hy_compile(tokenize("(catch [])")) + hy_compile(tokenize("(catch [Foobar])")) + # hy_compile(tokenize("(catch [[]])")) + # hy_compile(tokenize("(catch [x FooBar])")) + # hy_compile(tokenize("(catch [x [FooBar BarFoo]])")) + # hy_compile(tokenize("(catch [x [FooBar BarFoo]])")) def test_ast_bad_catch(): "Make sure AST can't compile invalid catch" - cant_compile("(catch)") cant_compile("(catch 1)") + cant_compile("(catch [1 3])") + cant_compile("(catch [x [FooBar] BarBar]])") def test_ast_good_assert(): diff --git a/tests/native_tests/language.hy b/tests/native_tests/language.hy index cd3e1b4..c137cc7 100644 --- a/tests/native_tests/language.hy +++ b/tests/native_tests/language.hy @@ -164,8 +164,56 @@ "NATIVE: test Exceptions" (try (throw (KeyError)) - (catch IOError e (assert (= 2 1))) - (catch KeyError e (+ 1 1) (assert (= 1 1))))) + (catch [[IOError]] (assert false)) + (catch [e [KeyError]] (assert e))) + + (try + (get [1] 3) + (catch [IndexError] (assert true)) + (catch [IndexError] (pass))) + + (try + (print foobar42ofthebaz) + (catch [IndexError] (assert false)) + (catch [NameError] (pass))) + + (try + (get [1] 3) + (catch [e IndexError] (assert (isinstance e IndexError)))) + + (try + (get [1] 3) + (catch [e [IndexError NameError]] (assert (isinstance e IndexError)))) + + (try + (print foobar42ofthebaz) + (catch [e [IndexError NameError]] (assert (isinstance e NameError)))) + + (try + (print foobar42) + (catch [[IndexError NameError]] (pass))) + + (try + (get [1] 3) + (catch [[IndexError NameError]] (pass))) + + (try + (print foobar42ofthebaz) + (catch)) + + (try + (print foobar42ofthebaz) + (catch [])) + + (try + (print foobar42ofthebaz) + (catch [] (pass))) + + (try + (print foobar42ofthebaz) + (catch [] + (setv foobar42ofthebaz 42) + (assert (= foobar42ofthebaz 42))))) (defn test-earmuffs [] "NATIVE: Test earmuffs" @@ -327,7 +375,7 @@ 6)) (try (assert (= x 42)) ; This ain't true - (catch NameError e (assert e))) + (catch [e [NameError]] (assert e))) (assert (= y 123)))