Clean up the decorators used in the compiler
This commit is contained in:
parent
98fbdcfc50
commit
210086c7ca
189
hy/compiler.py
189
hy/compiler.py
@ -19,7 +19,6 @@ from hy._compat import (
|
||||
raise_empty)
|
||||
from hy.macros import require, macroexpand, tag_macroexpand
|
||||
import hy.importer
|
||||
import hy.inspect
|
||||
|
||||
import traceback
|
||||
import importlib
|
||||
@ -70,7 +69,8 @@ def ast_str(x, piecewise=False):
|
||||
return x if PY3 else x.encode('UTF8')
|
||||
|
||||
|
||||
_compile_table = {}
|
||||
_special_form_compilers = {}
|
||||
_model_compilers = {}
|
||||
_decoratables = (ast.FunctionDef, ast.ClassDef)
|
||||
if PY35:
|
||||
_decoratables += (ast.AsyncFunctionDef,)
|
||||
@ -81,23 +81,9 @@ _bad_roots = tuple(ast_str(x) for x in (
|
||||
"unquote", "unquote-splice", "unpack-mapping", "except"))
|
||||
|
||||
|
||||
def builds(*types, **kwargs):
|
||||
# A decorator that adds the decorated method to _compile_table for
|
||||
# compiling `types`, but only if kwargs['iff'] (if provided) is
|
||||
# true.
|
||||
if not kwargs.get('iff', True):
|
||||
return lambda fn: fn
|
||||
|
||||
def _dec(fn):
|
||||
for t in types:
|
||||
if isinstance(t, string_types):
|
||||
t = ast_str(t)
|
||||
_compile_table[t] = fn
|
||||
return fn
|
||||
return _dec
|
||||
|
||||
|
||||
def special(names, pattern):
|
||||
"""Declare special operators. The decorated method and the given pattern
|
||||
is assigned to _special_form_compilers for each of the listed names."""
|
||||
pattern = whole(pattern)
|
||||
def dec(fn):
|
||||
for name in names if isinstance(names, list) else [names]:
|
||||
@ -105,11 +91,20 @@ def special(names, pattern):
|
||||
condition, name = name
|
||||
if not condition:
|
||||
continue
|
||||
_compile_table[ast_str(name)] = (fn, pattern)
|
||||
_special_form_compilers[ast_str(name)] = (fn, pattern)
|
||||
return fn
|
||||
return dec
|
||||
|
||||
|
||||
def builds_model(*model_types):
|
||||
"Assign the decorated method to _model_compilers for the given types."
|
||||
def _dec(fn):
|
||||
for t in model_types:
|
||||
_model_compilers[t] = fn
|
||||
return fn
|
||||
return _dec
|
||||
|
||||
|
||||
def spoof_positions(obj):
|
||||
if not isinstance(obj, HyObject):
|
||||
return
|
||||
@ -338,44 +333,6 @@ def _nargs(n):
|
||||
return "%d argument%s" % (n, ("" if n == 1 else "s"))
|
||||
|
||||
|
||||
def checkargs(exact=None, min=None, max=None, even=None, multiple=None):
|
||||
def _dec(fn):
|
||||
def checker(self, expression):
|
||||
if exact is not None and (len(expression) - 1) != exact:
|
||||
_raise_wrong_args_number(
|
||||
expression, "`%%s' needs %s, got %%d" % _nargs(exact))
|
||||
if min is not None and (len(expression) - 1) < min:
|
||||
_raise_wrong_args_number(
|
||||
expression,
|
||||
"`%%s' needs at least %s, got %%d." % _nargs(min))
|
||||
|
||||
if max is not None and (len(expression) - 1) > max:
|
||||
_raise_wrong_args_number(
|
||||
expression,
|
||||
"`%%s' needs at most %s, got %%d" % _nargs(max))
|
||||
|
||||
is_even = not((len(expression) - 1) % 2)
|
||||
if even is not None and is_even != even:
|
||||
even_str = "even" if even else "odd"
|
||||
_raise_wrong_args_number(
|
||||
expression,
|
||||
"`%%s' needs an %s number of arguments, got %%d"
|
||||
% (even_str))
|
||||
|
||||
if multiple is not None:
|
||||
if not (len(expression) - 1) in multiple:
|
||||
choices = ", ".join([str(val) for val in multiple[:-1]])
|
||||
choices += " or %s" % multiple[-1]
|
||||
_raise_wrong_args_number(
|
||||
expression,
|
||||
"`%%s' needs %s arguments, got %%d" % choices)
|
||||
|
||||
return fn(self, expression)
|
||||
|
||||
return checker
|
||||
return _dec
|
||||
|
||||
|
||||
def is_unpack(kind, x):
|
||||
return (isinstance(x, HyExpression)
|
||||
and len(x) > 0
|
||||
@ -436,47 +393,21 @@ class HyASTCompiler(object):
|
||||
self.imports = defaultdict(set)
|
||||
return ret.stmts
|
||||
|
||||
def compile_atom(self, atom_type, atom):
|
||||
if isinstance(atom_type, string_types):
|
||||
atom_type = ast_str(atom_type)
|
||||
if atom_type in _bad_roots:
|
||||
raise HyTypeError(atom, "The special form '{}' "
|
||||
"is not allowed here".format(atom_type))
|
||||
if atom_type in _compile_table:
|
||||
# _compile_table[atom_type] is a method for compiling this
|
||||
# type of atom, so call it. If it has an extra parameter,
|
||||
# pass in `atom_type`.
|
||||
atom_compiler = _compile_table[atom_type]
|
||||
if isinstance(atom_compiler, tuple):
|
||||
# This build method has a pattern.
|
||||
build_method, pattern = atom_compiler
|
||||
try:
|
||||
parse_tree = pattern.parse(atom[1:])
|
||||
except NoParseError as e:
|
||||
raise HyTypeError(atom,
|
||||
"parse error for special form '{}': {}'".format(
|
||||
atom[0], str(e)))
|
||||
ret = build_method(self, atom, atom[0], *parse_tree)
|
||||
else:
|
||||
arity = hy.inspect.get_arity(atom_compiler)
|
||||
# Compliation methods may mutate the atom, so copy it first.
|
||||
atom = copy.copy(atom)
|
||||
ret = (atom_compiler(self, atom, atom_type)
|
||||
if arity == 3
|
||||
else atom_compiler(self, atom))
|
||||
if not isinstance(ret, Result):
|
||||
ret = Result() + ret
|
||||
return ret
|
||||
def compile_atom(self, atom):
|
||||
if not isinstance(atom, HyObject):
|
||||
atom = wrap_value(atom)
|
||||
if isinstance(atom, HyObject):
|
||||
spoof_positions(atom)
|
||||
return self.compile_atom(type(atom), atom)
|
||||
if not isinstance(atom, HyObject):
|
||||
return
|
||||
spoof_positions(atom)
|
||||
if type(atom) not in _model_compilers:
|
||||
return
|
||||
# Compilation methods may mutate the atom, so copy it first.
|
||||
atom = copy.copy(atom)
|
||||
return Result() + _model_compilers[type(atom)](self, atom)
|
||||
|
||||
def compile(self, tree):
|
||||
try:
|
||||
_type = type(tree)
|
||||
ret = self.compile_atom(_type, tree)
|
||||
ret = self.compile_atom(tree)
|
||||
if ret:
|
||||
self.update_imports(ret)
|
||||
return ret
|
||||
@ -1633,13 +1564,7 @@ class HyASTCompiler(object):
|
||||
if ast_str(root) == "eval_and_compile"
|
||||
else Result())
|
||||
|
||||
@checkargs(1)
|
||||
def _compile_keyword_call(self, expression):
|
||||
expression.append(expression.pop(0))
|
||||
expression.insert(0, HySymbol("get"))
|
||||
return self.compile(expression)
|
||||
|
||||
@builds(HyExpression)
|
||||
@builds_model(HyExpression)
|
||||
def compile_expression(self, expression):
|
||||
# Perform macro expansions
|
||||
expression = macroexpand(expression, self)
|
||||
@ -1648,23 +1573,45 @@ class HyASTCompiler(object):
|
||||
return self.compile(expression)
|
||||
|
||||
if expression == []:
|
||||
return self.compile_list(expression, HyList)
|
||||
return self.compile_atom(HyList().replace(expression))
|
||||
|
||||
fn = expression[0]
|
||||
func = None
|
||||
if isinstance(fn, HyKeyword):
|
||||
return self._compile_keyword_call(expression)
|
||||
if len(expression) > 2:
|
||||
raise HyTypeError(
|
||||
expression, "keyword calls take only 1 argument")
|
||||
expression.append(expression.pop(0))
|
||||
expression.insert(0, HySymbol("get"))
|
||||
return self.compile(expression)
|
||||
|
||||
if isinstance(fn, HySymbol):
|
||||
# First check if `fn` is a special form, unless it has an
|
||||
|
||||
# First check if `fn` is a special operator, unless it has an
|
||||
# `unpack-iterable` in it, since Python's operators (`+`,
|
||||
# etc.) can't unpack. An exception to this exception is that
|
||||
# tuple literals (`,`) can unpack.
|
||||
if fn == "," or not (
|
||||
any(is_unpack("iterable", x) for x in expression[1:])):
|
||||
ret = self.compile_atom(fn, expression)
|
||||
if ret:
|
||||
return ret
|
||||
sfn = ast_str(fn)
|
||||
if (sfn in _special_form_compilers or sfn in _bad_roots) and (
|
||||
sfn == mangle(",") or
|
||||
not any(is_unpack("iterable", x) for x in expression[1:])):
|
||||
if sfn in _bad_roots:
|
||||
raise HyTypeError(
|
||||
expression,
|
||||
"The special form '{}' is not allowed here".format(fn))
|
||||
# `sfn` is a special operator. Get the build method and
|
||||
# pattern-match the arguments.
|
||||
build_method, pattern = _special_form_compilers[sfn]
|
||||
try:
|
||||
parse_tree = pattern.parse(expression[1:])
|
||||
except NoParseError as e:
|
||||
raise HyTypeError(
|
||||
expression[min(e.state.pos + 1, len(expression) - 1)],
|
||||
"parse error for special form '{}': {}".format(
|
||||
expression[0],
|
||||
e.msg.replace("<EOF>", "end of form")))
|
||||
return Result() + build_method(
|
||||
self, expression, expression[0], *parse_tree)
|
||||
|
||||
if fn.startswith("."):
|
||||
# (.split "test test") -> "test test".split()
|
||||
@ -1718,11 +1665,11 @@ class HyASTCompiler(object):
|
||||
expression, func=func.expr, args=args, keywords=keywords,
|
||||
starargs=oldpy_star, kwargs=oldpy_kw)
|
||||
|
||||
@builds(HyInteger, HyFloat, HyComplex)
|
||||
def compile_numeric_literal(self, x, building):
|
||||
@builds_model(HyInteger, HyFloat, HyComplex)
|
||||
def compile_numeric_literal(self, x):
|
||||
f = {HyInteger: long_type,
|
||||
HyFloat: float,
|
||||
HyComplex: complex}[building]
|
||||
HyComplex: complex}[type(x)]
|
||||
# Work around https://github.com/berkerpeksag/astor/issues/85 :
|
||||
# astor can't generate Num nodes with NaN, so we have
|
||||
# to build an expression that evaluates to NaN.
|
||||
@ -1741,7 +1688,7 @@ class HyASTCompiler(object):
|
||||
return nan()
|
||||
return nn(f(x))
|
||||
|
||||
@builds(HySymbol)
|
||||
@builds_model(HySymbol)
|
||||
def compile_symbol(self, symbol):
|
||||
if "." in symbol:
|
||||
glob, local = symbol.rsplit(".", 1)
|
||||
@ -1772,7 +1719,7 @@ class HyASTCompiler(object):
|
||||
|
||||
return asty.Name(symbol, id=ast_str(symbol), ctx=ast.Load())
|
||||
|
||||
@builds(HyKeyword)
|
||||
@builds_model(HyKeyword)
|
||||
def compile_keyword(self, obj):
|
||||
ret = Result()
|
||||
ret += asty.Call(
|
||||
@ -1783,19 +1730,19 @@ class HyASTCompiler(object):
|
||||
ret.add_imports("hy", {"HyKeyword"})
|
||||
return ret
|
||||
|
||||
@builds(HyString, HyBytes)
|
||||
def compile_string(self, string, building):
|
||||
node = asty.Bytes if PY3 and building is HyBytes else asty.Str
|
||||
f = bytes_type if building is HyBytes else str_type
|
||||
@builds_model(HyString, HyBytes)
|
||||
def compile_string(self, string):
|
||||
node = asty.Bytes if PY3 and type(string) is HyBytes else asty.Str
|
||||
f = bytes_type if type(string) is HyBytes else str_type
|
||||
return node(string, s=f(string))
|
||||
|
||||
@builds(HyList, HySet)
|
||||
def compile_list(self, expression, building):
|
||||
@builds_model(HyList, HySet)
|
||||
def compile_list(self, expression):
|
||||
elts, ret, _ = self._compile_collect(expression)
|
||||
node = {HyList: asty.List, HySet: asty.Set}[building]
|
||||
node = {HyList: asty.List, HySet: asty.Set}[type(expression)]
|
||||
return ret + node(expression, elts=elts, ctx=ast.Load())
|
||||
|
||||
@builds(HyDict)
|
||||
@builds_model(HyDict)
|
||||
def compile_dict(self, m):
|
||||
keyvalues, ret, _ = self._compile_collect(m, dict_display=True)
|
||||
return ret + asty.Dict(m, keys=keyvalues[::2], values=keyvalues[1::2])
|
||||
|
@ -37,7 +37,7 @@ class Completer(object):
|
||||
if not isinstance(namespace, dict):
|
||||
raise TypeError('namespace must be a dictionary')
|
||||
self.namespace = namespace
|
||||
self.path = [hy.compiler._compile_table,
|
||||
self.path = [hy.compiler._special_form_compilers,
|
||||
builtins.__dict__,
|
||||
hy.macros._hy_macros[None],
|
||||
namespace]
|
||||
|
@ -72,9 +72,7 @@
|
||||
|
||||
;; TODO: move to hy.extra.reserved?
|
||||
(import hy)
|
||||
(setv special-forms (list-comp k
|
||||
[k (.keys hy.compiler._compile-table)]
|
||||
(isinstance k hy._compat.string-types)))
|
||||
(setv special-forms (list (.keys hy.compiler._special-form-compilers)))
|
||||
|
||||
|
||||
(defn lambda-list [form]
|
||||
|
@ -18,6 +18,6 @@
|
||||
hy.core.shadow.EXPORTS
|
||||
(list (.keys (get hy.macros._hy_macros None)))
|
||||
keyword.kwlist
|
||||
(list-comp k [k (.keys hy.compiler.-compile-table)]
|
||||
(isinstance k hy._compat.string-types))))))))
|
||||
(list (.keys hy.compiler._special_form_compilers))
|
||||
(list hy.compiler._bad_roots)))))))
|
||||
_cache)
|
||||
|
@ -55,7 +55,7 @@ def test_compiler_yield_return():
|
||||
HyExpression([HySymbol("+"),
|
||||
HyInteger(1),
|
||||
HyInteger(1)]))
|
||||
ret = compiler.HyASTCompiler('test').compile_atom("fn", e)
|
||||
ret = compiler.HyASTCompiler('test').compile_atom(e)
|
||||
|
||||
assert len(ret.stmts) == 1
|
||||
stmt, = ret.stmts
|
||||
|
Loading…
x
Reference in New Issue
Block a user