diff --git a/hy/compiler.py b/hy/compiler.py index 1db58bd..dd6af76 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -1301,9 +1301,13 @@ class HyASTCompiler(object): @builds("setv") @checkargs(2) def compile_def_expression(self, expression): - expression.pop(0) - name = expression.pop(0) - result = self.compile(expression.pop(0)) + return self._compile_assign(expression[1], expression[2], + expression.start_line, + expression.start_column) + + def _compile_assign(self, name, result, + start_line, start_column): + result = self.compile(result) if result.temp_variables and isinstance(name, HyString): result.rename(name) @@ -1313,8 +1317,8 @@ class HyASTCompiler(object): st_name = self._storeize(ld_name) result += ast.Assign( - lineno=expression.start_line, - col_offset=expression.start_column, + lineno=start_line, + col_offset=start_column, targets=[st_name], value=result.force_expr) result += ld_name @@ -1440,6 +1444,56 @@ class HyASTCompiler(object): return ret + @builds("defclass") + @checkargs(min=1) + def compile_class_expression(self, expression): + expression.pop(0) # class + + class_name = expression.pop(0) + + if expression: + base_list = expression.pop(0) + if not isinstance(base_list, HyList): + raise HyTypeError(expression, + "Bases class must be a list") + bases_expr, bases = self._compile_collect(base_list) + else: + bases_expr = [] + bases = Result() + + body = Result() + + if expression: + try: + body_expression = iter(expression.pop(0)) + except TypeError: + raise HyTypeError( + expression, + "Wrong argument type for defclass slots definition.") + for b in body_expression: + if len(b) != 2: + raise HyTypeError( + expression, + "Wrong number of argument in defclass slot.") + body += self._compile_assign(b[0], b[1], + b.start_line, b.start_column) + body += body.expr_as_stmt() + + if not body.stmts: + body += ast.Pass(lineno=expression.start_line, + col_offset=expression.start_column) + + return bases + ast.ClassDef( + lineno=expression.start_line, + col_offset=expression.start_column, + decorator_list=[], + name=ast_str(class_name), + keywords=[], + starargs=None, + kwargs=None, + bases=bases_expr, + body=body.stmts) + @builds(HyInteger) def compile_integer(self, number): return ast.Num(n=int(number), diff --git a/tests/__init__.py b/tests/__init__.py index 6f87e4d..5738538 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,6 +2,7 @@ import hy # noqa +from .native_tests.defclass import * # noqa from .native_tests.math import * # noqa from .native_tests.language import * # noqa from .native_tests.unless import * # noqa diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index 318aa0e..958c7b8 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -197,6 +197,19 @@ def test_ast_bad_global(): cant_compile("(global foo bar)") +def test_ast_good_defclass(): + "Make sure AST can compile valid defclass" + hy_compile(tokenize("(defclass a)")) + hy_compile(tokenize("(defclass a [])")) + + +def test_ast_bad_defclass(): + "Make sure AST can't compile invalid defclass" + cant_compile("(defclass)") + cant_compile("(defclass a null)") + cant_compile("(defclass a null null)") + + def test_ast_good_lambda(): "Make sure AST can compile valid lambda" hy_compile(tokenize("(lambda [])")) diff --git a/tests/native_tests/defclass.hy b/tests/native_tests/defclass.hy new file mode 100644 index 0000000..4860d39 --- /dev/null +++ b/tests/native_tests/defclass.hy @@ -0,0 +1,62 @@ +(defn test-defclass [] + "NATIVE: test defclass simple mechanism" + (defclass A) + (assert (isinstance (A) A))) + + +(defn test-defclass-inheritance [] + "NATIVE: test defclass inheritance" + (defclass A []) + (assert (isinstance (A) object)) + (defclass A [object]) + (assert (isinstance (A) object)) + (defclass B [A]) + (assert (isinstance (B) A)) + (defclass C [object]) + (defclass D [B C]) + (assert (isinstance (D) A)) + (assert (isinstance (D) B)) + (assert (isinstance (D) C)) + (assert (not (isinstance (A) D)))) + + +(defn test-defclass-slots [] + "NATIVE: test defclass slots" + (defclass A [] + [[x 42]]) + (assert (= A.x 42)) + (assert (= (getattr (A) "x") 42))) + + +(defn test-defclass-slots-fn [] + "NATIVE: test defclass slots with fn" + (defclass B [] + [[x 42] + [y (fn [self value] + (+ self.x value))]]) + (assert (= B.x 42)) + (assert (= (.y (B) 5) 47)) + (let [[b (B)]] + (setv B.x 0) + (assert (= (.y b 1) 1)))) + + +(defn test-defclass-dynamic-inheritance [] + "NATIVE: test defclass with dynamic inheritance" + (defclass A [((fn [] (if true list dict)))] + [[x 42]]) + (assert (isinstance (A) list)) + (defclass A [((fn [] (if false list dict)))] + [[x 42]]) + (assert (isinstance (A) dict))) + + +(defn test-defclass-no-fn-leak [] + "NATIVE: test defclass slots with fn" + (defclass A [] + [[x (fn [] 1)]]) + (try + (do + (x) + (assert false)) + (except [NameError])))