Clean up the decorators used in the compiler
This commit is contained in:
parent
98fbdcfc50
commit
210086c7ca
189
hy/compiler.py
189
hy/compiler.py
@ -19,7 +19,6 @@ from hy._compat import (
|
|||||||
raise_empty)
|
raise_empty)
|
||||||
from hy.macros import require, macroexpand, tag_macroexpand
|
from hy.macros import require, macroexpand, tag_macroexpand
|
||||||
import hy.importer
|
import hy.importer
|
||||||
import hy.inspect
|
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import importlib
|
import importlib
|
||||||
@ -70,7 +69,8 @@ def ast_str(x, piecewise=False):
|
|||||||
return x if PY3 else x.encode('UTF8')
|
return x if PY3 else x.encode('UTF8')
|
||||||
|
|
||||||
|
|
||||||
_compile_table = {}
|
_special_form_compilers = {}
|
||||||
|
_model_compilers = {}
|
||||||
_decoratables = (ast.FunctionDef, ast.ClassDef)
|
_decoratables = (ast.FunctionDef, ast.ClassDef)
|
||||||
if PY35:
|
if PY35:
|
||||||
_decoratables += (ast.AsyncFunctionDef,)
|
_decoratables += (ast.AsyncFunctionDef,)
|
||||||
@ -81,23 +81,9 @@ _bad_roots = tuple(ast_str(x) for x in (
|
|||||||
"unquote", "unquote-splice", "unpack-mapping", "except"))
|
"unquote", "unquote-splice", "unpack-mapping", "except"))
|
||||||
|
|
||||||
|
|
||||||
def builds(*types, **kwargs):
|
|
||||||
# A decorator that adds the decorated method to _compile_table for
|
|
||||||
# compiling `types`, but only if kwargs['iff'] (if provided) is
|
|
||||||
# true.
|
|
||||||
if not kwargs.get('iff', True):
|
|
||||||
return lambda fn: fn
|
|
||||||
|
|
||||||
def _dec(fn):
|
|
||||||
for t in types:
|
|
||||||
if isinstance(t, string_types):
|
|
||||||
t = ast_str(t)
|
|
||||||
_compile_table[t] = fn
|
|
||||||
return fn
|
|
||||||
return _dec
|
|
||||||
|
|
||||||
|
|
||||||
def special(names, pattern):
|
def special(names, pattern):
|
||||||
|
"""Declare special operators. The decorated method and the given pattern
|
||||||
|
is assigned to _special_form_compilers for each of the listed names."""
|
||||||
pattern = whole(pattern)
|
pattern = whole(pattern)
|
||||||
def dec(fn):
|
def dec(fn):
|
||||||
for name in names if isinstance(names, list) else [names]:
|
for name in names if isinstance(names, list) else [names]:
|
||||||
@ -105,11 +91,20 @@ def special(names, pattern):
|
|||||||
condition, name = name
|
condition, name = name
|
||||||
if not condition:
|
if not condition:
|
||||||
continue
|
continue
|
||||||
_compile_table[ast_str(name)] = (fn, pattern)
|
_special_form_compilers[ast_str(name)] = (fn, pattern)
|
||||||
return fn
|
return fn
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
|
|
||||||
|
def builds_model(*model_types):
|
||||||
|
"Assign the decorated method to _model_compilers for the given types."
|
||||||
|
def _dec(fn):
|
||||||
|
for t in model_types:
|
||||||
|
_model_compilers[t] = fn
|
||||||
|
return fn
|
||||||
|
return _dec
|
||||||
|
|
||||||
|
|
||||||
def spoof_positions(obj):
|
def spoof_positions(obj):
|
||||||
if not isinstance(obj, HyObject):
|
if not isinstance(obj, HyObject):
|
||||||
return
|
return
|
||||||
@ -338,44 +333,6 @@ def _nargs(n):
|
|||||||
return "%d argument%s" % (n, ("" if n == 1 else "s"))
|
return "%d argument%s" % (n, ("" if n == 1 else "s"))
|
||||||
|
|
||||||
|
|
||||||
def checkargs(exact=None, min=None, max=None, even=None, multiple=None):
|
|
||||||
def _dec(fn):
|
|
||||||
def checker(self, expression):
|
|
||||||
if exact is not None and (len(expression) - 1) != exact:
|
|
||||||
_raise_wrong_args_number(
|
|
||||||
expression, "`%%s' needs %s, got %%d" % _nargs(exact))
|
|
||||||
if min is not None and (len(expression) - 1) < min:
|
|
||||||
_raise_wrong_args_number(
|
|
||||||
expression,
|
|
||||||
"`%%s' needs at least %s, got %%d." % _nargs(min))
|
|
||||||
|
|
||||||
if max is not None and (len(expression) - 1) > max:
|
|
||||||
_raise_wrong_args_number(
|
|
||||||
expression,
|
|
||||||
"`%%s' needs at most %s, got %%d" % _nargs(max))
|
|
||||||
|
|
||||||
is_even = not((len(expression) - 1) % 2)
|
|
||||||
if even is not None and is_even != even:
|
|
||||||
even_str = "even" if even else "odd"
|
|
||||||
_raise_wrong_args_number(
|
|
||||||
expression,
|
|
||||||
"`%%s' needs an %s number of arguments, got %%d"
|
|
||||||
% (even_str))
|
|
||||||
|
|
||||||
if multiple is not None:
|
|
||||||
if not (len(expression) - 1) in multiple:
|
|
||||||
choices = ", ".join([str(val) for val in multiple[:-1]])
|
|
||||||
choices += " or %s" % multiple[-1]
|
|
||||||
_raise_wrong_args_number(
|
|
||||||
expression,
|
|
||||||
"`%%s' needs %s arguments, got %%d" % choices)
|
|
||||||
|
|
||||||
return fn(self, expression)
|
|
||||||
|
|
||||||
return checker
|
|
||||||
return _dec
|
|
||||||
|
|
||||||
|
|
||||||
def is_unpack(kind, x):
|
def is_unpack(kind, x):
|
||||||
return (isinstance(x, HyExpression)
|
return (isinstance(x, HyExpression)
|
||||||
and len(x) > 0
|
and len(x) > 0
|
||||||
@ -436,47 +393,21 @@ class HyASTCompiler(object):
|
|||||||
self.imports = defaultdict(set)
|
self.imports = defaultdict(set)
|
||||||
return ret.stmts
|
return ret.stmts
|
||||||
|
|
||||||
def compile_atom(self, atom_type, atom):
|
def compile_atom(self, atom):
|
||||||
if isinstance(atom_type, string_types):
|
|
||||||
atom_type = ast_str(atom_type)
|
|
||||||
if atom_type in _bad_roots:
|
|
||||||
raise HyTypeError(atom, "The special form '{}' "
|
|
||||||
"is not allowed here".format(atom_type))
|
|
||||||
if atom_type in _compile_table:
|
|
||||||
# _compile_table[atom_type] is a method for compiling this
|
|
||||||
# type of atom, so call it. If it has an extra parameter,
|
|
||||||
# pass in `atom_type`.
|
|
||||||
atom_compiler = _compile_table[atom_type]
|
|
||||||
if isinstance(atom_compiler, tuple):
|
|
||||||
# This build method has a pattern.
|
|
||||||
build_method, pattern = atom_compiler
|
|
||||||
try:
|
|
||||||
parse_tree = pattern.parse(atom[1:])
|
|
||||||
except NoParseError as e:
|
|
||||||
raise HyTypeError(atom,
|
|
||||||
"parse error for special form '{}': {}'".format(
|
|
||||||
atom[0], str(e)))
|
|
||||||
ret = build_method(self, atom, atom[0], *parse_tree)
|
|
||||||
else:
|
|
||||||
arity = hy.inspect.get_arity(atom_compiler)
|
|
||||||
# Compliation methods may mutate the atom, so copy it first.
|
|
||||||
atom = copy.copy(atom)
|
|
||||||
ret = (atom_compiler(self, atom, atom_type)
|
|
||||||
if arity == 3
|
|
||||||
else atom_compiler(self, atom))
|
|
||||||
if not isinstance(ret, Result):
|
|
||||||
ret = Result() + ret
|
|
||||||
return ret
|
|
||||||
if not isinstance(atom, HyObject):
|
if not isinstance(atom, HyObject):
|
||||||
atom = wrap_value(atom)
|
atom = wrap_value(atom)
|
||||||
if isinstance(atom, HyObject):
|
if not isinstance(atom, HyObject):
|
||||||
spoof_positions(atom)
|
return
|
||||||
return self.compile_atom(type(atom), atom)
|
spoof_positions(atom)
|
||||||
|
if type(atom) not in _model_compilers:
|
||||||
|
return
|
||||||
|
# Compilation methods may mutate the atom, so copy it first.
|
||||||
|
atom = copy.copy(atom)
|
||||||
|
return Result() + _model_compilers[type(atom)](self, atom)
|
||||||
|
|
||||||
def compile(self, tree):
|
def compile(self, tree):
|
||||||
try:
|
try:
|
||||||
_type = type(tree)
|
ret = self.compile_atom(tree)
|
||||||
ret = self.compile_atom(_type, tree)
|
|
||||||
if ret:
|
if ret:
|
||||||
self.update_imports(ret)
|
self.update_imports(ret)
|
||||||
return ret
|
return ret
|
||||||
@ -1633,13 +1564,7 @@ class HyASTCompiler(object):
|
|||||||
if ast_str(root) == "eval_and_compile"
|
if ast_str(root) == "eval_and_compile"
|
||||||
else Result())
|
else Result())
|
||||||
|
|
||||||
@checkargs(1)
|
@builds_model(HyExpression)
|
||||||
def _compile_keyword_call(self, expression):
|
|
||||||
expression.append(expression.pop(0))
|
|
||||||
expression.insert(0, HySymbol("get"))
|
|
||||||
return self.compile(expression)
|
|
||||||
|
|
||||||
@builds(HyExpression)
|
|
||||||
def compile_expression(self, expression):
|
def compile_expression(self, expression):
|
||||||
# Perform macro expansions
|
# Perform macro expansions
|
||||||
expression = macroexpand(expression, self)
|
expression = macroexpand(expression, self)
|
||||||
@ -1648,23 +1573,45 @@ class HyASTCompiler(object):
|
|||||||
return self.compile(expression)
|
return self.compile(expression)
|
||||||
|
|
||||||
if expression == []:
|
if expression == []:
|
||||||
return self.compile_list(expression, HyList)
|
return self.compile_atom(HyList().replace(expression))
|
||||||
|
|
||||||
fn = expression[0]
|
fn = expression[0]
|
||||||
func = None
|
func = None
|
||||||
if isinstance(fn, HyKeyword):
|
if isinstance(fn, HyKeyword):
|
||||||
return self._compile_keyword_call(expression)
|
if len(expression) > 2:
|
||||||
|
raise HyTypeError(
|
||||||
|
expression, "keyword calls take only 1 argument")
|
||||||
|
expression.append(expression.pop(0))
|
||||||
|
expression.insert(0, HySymbol("get"))
|
||||||
|
return self.compile(expression)
|
||||||
|
|
||||||
if isinstance(fn, HySymbol):
|
if isinstance(fn, HySymbol):
|
||||||
# First check if `fn` is a special form, unless it has an
|
|
||||||
|
# First check if `fn` is a special operator, unless it has an
|
||||||
# `unpack-iterable` in it, since Python's operators (`+`,
|
# `unpack-iterable` in it, since Python's operators (`+`,
|
||||||
# etc.) can't unpack. An exception to this exception is that
|
# etc.) can't unpack. An exception to this exception is that
|
||||||
# tuple literals (`,`) can unpack.
|
# tuple literals (`,`) can unpack.
|
||||||
if fn == "," or not (
|
sfn = ast_str(fn)
|
||||||
any(is_unpack("iterable", x) for x in expression[1:])):
|
if (sfn in _special_form_compilers or sfn in _bad_roots) and (
|
||||||
ret = self.compile_atom(fn, expression)
|
sfn == mangle(",") or
|
||||||
if ret:
|
not any(is_unpack("iterable", x) for x in expression[1:])):
|
||||||
return ret
|
if sfn in _bad_roots:
|
||||||
|
raise HyTypeError(
|
||||||
|
expression,
|
||||||
|
"The special form '{}' is not allowed here".format(fn))
|
||||||
|
# `sfn` is a special operator. Get the build method and
|
||||||
|
# pattern-match the arguments.
|
||||||
|
build_method, pattern = _special_form_compilers[sfn]
|
||||||
|
try:
|
||||||
|
parse_tree = pattern.parse(expression[1:])
|
||||||
|
except NoParseError as e:
|
||||||
|
raise HyTypeError(
|
||||||
|
expression[min(e.state.pos + 1, len(expression) - 1)],
|
||||||
|
"parse error for special form '{}': {}".format(
|
||||||
|
expression[0],
|
||||||
|
e.msg.replace("<EOF>", "end of form")))
|
||||||
|
return Result() + build_method(
|
||||||
|
self, expression, expression[0], *parse_tree)
|
||||||
|
|
||||||
if fn.startswith("."):
|
if fn.startswith("."):
|
||||||
# (.split "test test") -> "test test".split()
|
# (.split "test test") -> "test test".split()
|
||||||
@ -1718,11 +1665,11 @@ class HyASTCompiler(object):
|
|||||||
expression, func=func.expr, args=args, keywords=keywords,
|
expression, func=func.expr, args=args, keywords=keywords,
|
||||||
starargs=oldpy_star, kwargs=oldpy_kw)
|
starargs=oldpy_star, kwargs=oldpy_kw)
|
||||||
|
|
||||||
@builds(HyInteger, HyFloat, HyComplex)
|
@builds_model(HyInteger, HyFloat, HyComplex)
|
||||||
def compile_numeric_literal(self, x, building):
|
def compile_numeric_literal(self, x):
|
||||||
f = {HyInteger: long_type,
|
f = {HyInteger: long_type,
|
||||||
HyFloat: float,
|
HyFloat: float,
|
||||||
HyComplex: complex}[building]
|
HyComplex: complex}[type(x)]
|
||||||
# Work around https://github.com/berkerpeksag/astor/issues/85 :
|
# Work around https://github.com/berkerpeksag/astor/issues/85 :
|
||||||
# astor can't generate Num nodes with NaN, so we have
|
# astor can't generate Num nodes with NaN, so we have
|
||||||
# to build an expression that evaluates to NaN.
|
# to build an expression that evaluates to NaN.
|
||||||
@ -1741,7 +1688,7 @@ class HyASTCompiler(object):
|
|||||||
return nan()
|
return nan()
|
||||||
return nn(f(x))
|
return nn(f(x))
|
||||||
|
|
||||||
@builds(HySymbol)
|
@builds_model(HySymbol)
|
||||||
def compile_symbol(self, symbol):
|
def compile_symbol(self, symbol):
|
||||||
if "." in symbol:
|
if "." in symbol:
|
||||||
glob, local = symbol.rsplit(".", 1)
|
glob, local = symbol.rsplit(".", 1)
|
||||||
@ -1772,7 +1719,7 @@ class HyASTCompiler(object):
|
|||||||
|
|
||||||
return asty.Name(symbol, id=ast_str(symbol), ctx=ast.Load())
|
return asty.Name(symbol, id=ast_str(symbol), ctx=ast.Load())
|
||||||
|
|
||||||
@builds(HyKeyword)
|
@builds_model(HyKeyword)
|
||||||
def compile_keyword(self, obj):
|
def compile_keyword(self, obj):
|
||||||
ret = Result()
|
ret = Result()
|
||||||
ret += asty.Call(
|
ret += asty.Call(
|
||||||
@ -1783,19 +1730,19 @@ class HyASTCompiler(object):
|
|||||||
ret.add_imports("hy", {"HyKeyword"})
|
ret.add_imports("hy", {"HyKeyword"})
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@builds(HyString, HyBytes)
|
@builds_model(HyString, HyBytes)
|
||||||
def compile_string(self, string, building):
|
def compile_string(self, string):
|
||||||
node = asty.Bytes if PY3 and building is HyBytes else asty.Str
|
node = asty.Bytes if PY3 and type(string) is HyBytes else asty.Str
|
||||||
f = bytes_type if building is HyBytes else str_type
|
f = bytes_type if type(string) is HyBytes else str_type
|
||||||
return node(string, s=f(string))
|
return node(string, s=f(string))
|
||||||
|
|
||||||
@builds(HyList, HySet)
|
@builds_model(HyList, HySet)
|
||||||
def compile_list(self, expression, building):
|
def compile_list(self, expression):
|
||||||
elts, ret, _ = self._compile_collect(expression)
|
elts, ret, _ = self._compile_collect(expression)
|
||||||
node = {HyList: asty.List, HySet: asty.Set}[building]
|
node = {HyList: asty.List, HySet: asty.Set}[type(expression)]
|
||||||
return ret + node(expression, elts=elts, ctx=ast.Load())
|
return ret + node(expression, elts=elts, ctx=ast.Load())
|
||||||
|
|
||||||
@builds(HyDict)
|
@builds_model(HyDict)
|
||||||
def compile_dict(self, m):
|
def compile_dict(self, m):
|
||||||
keyvalues, ret, _ = self._compile_collect(m, dict_display=True)
|
keyvalues, ret, _ = self._compile_collect(m, dict_display=True)
|
||||||
return ret + asty.Dict(m, keys=keyvalues[::2], values=keyvalues[1::2])
|
return ret + asty.Dict(m, keys=keyvalues[::2], values=keyvalues[1::2])
|
||||||
|
@ -37,7 +37,7 @@ class Completer(object):
|
|||||||
if not isinstance(namespace, dict):
|
if not isinstance(namespace, dict):
|
||||||
raise TypeError('namespace must be a dictionary')
|
raise TypeError('namespace must be a dictionary')
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
self.path = [hy.compiler._compile_table,
|
self.path = [hy.compiler._special_form_compilers,
|
||||||
builtins.__dict__,
|
builtins.__dict__,
|
||||||
hy.macros._hy_macros[None],
|
hy.macros._hy_macros[None],
|
||||||
namespace]
|
namespace]
|
||||||
|
@ -72,9 +72,7 @@
|
|||||||
|
|
||||||
;; TODO: move to hy.extra.reserved?
|
;; TODO: move to hy.extra.reserved?
|
||||||
(import hy)
|
(import hy)
|
||||||
(setv special-forms (list-comp k
|
(setv special-forms (list (.keys hy.compiler._special-form-compilers)))
|
||||||
[k (.keys hy.compiler._compile-table)]
|
|
||||||
(isinstance k hy._compat.string-types)))
|
|
||||||
|
|
||||||
|
|
||||||
(defn lambda-list [form]
|
(defn lambda-list [form]
|
||||||
|
@ -18,6 +18,6 @@
|
|||||||
hy.core.shadow.EXPORTS
|
hy.core.shadow.EXPORTS
|
||||||
(list (.keys (get hy.macros._hy_macros None)))
|
(list (.keys (get hy.macros._hy_macros None)))
|
||||||
keyword.kwlist
|
keyword.kwlist
|
||||||
(list-comp k [k (.keys hy.compiler.-compile-table)]
|
(list (.keys hy.compiler._special_form_compilers))
|
||||||
(isinstance k hy._compat.string-types))))))))
|
(list hy.compiler._bad_roots)))))))
|
||||||
_cache)
|
_cache)
|
||||||
|
@ -55,7 +55,7 @@ def test_compiler_yield_return():
|
|||||||
HyExpression([HySymbol("+"),
|
HyExpression([HySymbol("+"),
|
||||||
HyInteger(1),
|
HyInteger(1),
|
||||||
HyInteger(1)]))
|
HyInteger(1)]))
|
||||||
ret = compiler.HyASTCompiler('test').compile_atom("fn", e)
|
ret = compiler.HyASTCompiler('test').compile_atom(e)
|
||||||
|
|
||||||
assert len(ret.stmts) == 1
|
assert len(ret.stmts) == 1
|
||||||
stmt, = ret.stmts
|
stmt, = ret.stmts
|
||||||
|
Loading…
Reference in New Issue
Block a user