diff --git a/hy/macros.py b/hy/macros.py index 231d66e..d3530fd 100644 --- a/hy/macros.py +++ b/hy/macros.py @@ -20,7 +20,13 @@ from hy.models.expression import HyExpression from hy.models.string import HyString +from hy.models.symbol import HySymbol from hy.models.list import HyList +from hy.models.integer import HyInteger +from hy.models.float import HyFloat +from hy.models.complex import HyComplex +from hy.models.dict import HyDict +from hy.util import str_type from collections import defaultdict @@ -44,12 +50,32 @@ def require(source_module_name, target_module_name): refs[name] = macro +def _wrap_value(x): + wrapper = _wrappers.get(type(x)) + if wrapper is None: + return x + else: + return wrapper(x) + +_wrappers = { + int: HyInteger, + bool: lambda x: HySymbol("True") if x else HySymbol("False"), + float: HyFloat, + complex: HyComplex, + str_type: HyString, + dict: lambda d: HyDict(_wrap_value(x) for x in sum(d.items(), ())), + list: lambda l: HyList(_wrap_value(x) for x in l) +} + + def process(tree, module_name): if isinstance(tree, HyExpression): fn = tree[0] if fn in ("quote", "quasiquote"): return tree - ntree = HyExpression([fn] + [process(x, module_name) for x in tree[1:]]) + ntree = HyExpression( + [fn] + [process(x, module_name) for x in tree[1:]] + ) ntree.replace(tree) if isinstance(fn, HyString): @@ -57,7 +83,7 @@ def process(tree, module_name): if m is None: m = _hy_macros[None].get(fn) if m is not None: - obj = m(*ntree[1:]) + obj = _wrap_value(m(*ntree[1:])) obj.replace(tree) return obj diff --git a/tests/native_tests/native_macros.hy b/tests/native_tests/native_macros.hy index dcb3904..276685f 100644 --- a/tests/native_tests/native_macros.hy +++ b/tests/native_tests/native_macros.hy @@ -8,3 +8,29 @@ (setv x []) (rev (.append x 1) (.append x 2) (.append x 3)) (assert (= x [3 2 1]))) + + +; Macros returning constants + +(defmacro an-int [] 42) +(assert (= (an-int) 42)) + +(defmacro a-true [] True) +(assert (= (a-true) True)) +(defmacro a-false [] False) +(assert (= (a-false) False)) + +(defmacro a-float [] 42.) +(assert (= (a-float) 42.)) + +(defmacro a-complex [] 42j) +(assert (= (a-complex) 42j)) + +(defmacro a-string [] "foo") +(assert (= (a-string) "foo")) + +(defmacro a-list [] [1 2]) +(assert (= (a-list) [1 2])) + +(defmacro a-dict [] {1 2}) +(assert (= (a-dict) {1 2}))