Clean up the decorators used in the compiler

This commit is contained in:
Kodi Arfer 2018-04-22 12:48:37 -07:00
parent 98fbdcfc50
commit 210086c7ca
5 changed files with 73 additions and 128 deletions

View File

@ -19,7 +19,6 @@ from hy._compat import (
raise_empty) raise_empty)
from hy.macros import require, macroexpand, tag_macroexpand from hy.macros import require, macroexpand, tag_macroexpand
import hy.importer import hy.importer
import hy.inspect
import traceback import traceback
import importlib import importlib
@ -70,7 +69,8 @@ def ast_str(x, piecewise=False):
return x if PY3 else x.encode('UTF8') return x if PY3 else x.encode('UTF8')
_compile_table = {} _special_form_compilers = {}
_model_compilers = {}
_decoratables = (ast.FunctionDef, ast.ClassDef) _decoratables = (ast.FunctionDef, ast.ClassDef)
if PY35: if PY35:
_decoratables += (ast.AsyncFunctionDef,) _decoratables += (ast.AsyncFunctionDef,)
@ -81,23 +81,9 @@ _bad_roots = tuple(ast_str(x) for x in (
"unquote", "unquote-splice", "unpack-mapping", "except")) "unquote", "unquote-splice", "unpack-mapping", "except"))
def builds(*types, **kwargs):
# A decorator that adds the decorated method to _compile_table for
# compiling `types`, but only if kwargs['iff'] (if provided) is
# true.
if not kwargs.get('iff', True):
return lambda fn: fn
def _dec(fn):
for t in types:
if isinstance(t, string_types):
t = ast_str(t)
_compile_table[t] = fn
return fn
return _dec
def special(names, pattern): def special(names, pattern):
"""Declare special operators. The decorated method and the given pattern
is assigned to _special_form_compilers for each of the listed names."""
pattern = whole(pattern) pattern = whole(pattern)
def dec(fn): def dec(fn):
for name in names if isinstance(names, list) else [names]: for name in names if isinstance(names, list) else [names]:
@ -105,11 +91,20 @@ def special(names, pattern):
condition, name = name condition, name = name
if not condition: if not condition:
continue continue
_compile_table[ast_str(name)] = (fn, pattern) _special_form_compilers[ast_str(name)] = (fn, pattern)
return fn return fn
return dec return dec
def builds_model(*model_types):
"Assign the decorated method to _model_compilers for the given types."
def _dec(fn):
for t in model_types:
_model_compilers[t] = fn
return fn
return _dec
def spoof_positions(obj): def spoof_positions(obj):
if not isinstance(obj, HyObject): if not isinstance(obj, HyObject):
return return
@ -338,44 +333,6 @@ def _nargs(n):
return "%d argument%s" % (n, ("" if n == 1 else "s")) return "%d argument%s" % (n, ("" if n == 1 else "s"))
def checkargs(exact=None, min=None, max=None, even=None, multiple=None):
def _dec(fn):
def checker(self, expression):
if exact is not None and (len(expression) - 1) != exact:
_raise_wrong_args_number(
expression, "`%%s' needs %s, got %%d" % _nargs(exact))
if min is not None and (len(expression) - 1) < min:
_raise_wrong_args_number(
expression,
"`%%s' needs at least %s, got %%d." % _nargs(min))
if max is not None and (len(expression) - 1) > max:
_raise_wrong_args_number(
expression,
"`%%s' needs at most %s, got %%d" % _nargs(max))
is_even = not((len(expression) - 1) % 2)
if even is not None and is_even != even:
even_str = "even" if even else "odd"
_raise_wrong_args_number(
expression,
"`%%s' needs an %s number of arguments, got %%d"
% (even_str))
if multiple is not None:
if not (len(expression) - 1) in multiple:
choices = ", ".join([str(val) for val in multiple[:-1]])
choices += " or %s" % multiple[-1]
_raise_wrong_args_number(
expression,
"`%%s' needs %s arguments, got %%d" % choices)
return fn(self, expression)
return checker
return _dec
def is_unpack(kind, x): def is_unpack(kind, x):
return (isinstance(x, HyExpression) return (isinstance(x, HyExpression)
and len(x) > 0 and len(x) > 0
@ -436,47 +393,21 @@ class HyASTCompiler(object):
self.imports = defaultdict(set) self.imports = defaultdict(set)
return ret.stmts return ret.stmts
def compile_atom(self, atom_type, atom): def compile_atom(self, atom):
if isinstance(atom_type, string_types):
atom_type = ast_str(atom_type)
if atom_type in _bad_roots:
raise HyTypeError(atom, "The special form '{}' "
"is not allowed here".format(atom_type))
if atom_type in _compile_table:
# _compile_table[atom_type] is a method for compiling this
# type of atom, so call it. If it has an extra parameter,
# pass in `atom_type`.
atom_compiler = _compile_table[atom_type]
if isinstance(atom_compiler, tuple):
# This build method has a pattern.
build_method, pattern = atom_compiler
try:
parse_tree = pattern.parse(atom[1:])
except NoParseError as e:
raise HyTypeError(atom,
"parse error for special form '{}': {}'".format(
atom[0], str(e)))
ret = build_method(self, atom, atom[0], *parse_tree)
else:
arity = hy.inspect.get_arity(atom_compiler)
# Compliation methods may mutate the atom, so copy it first.
atom = copy.copy(atom)
ret = (atom_compiler(self, atom, atom_type)
if arity == 3
else atom_compiler(self, atom))
if not isinstance(ret, Result):
ret = Result() + ret
return ret
if not isinstance(atom, HyObject): if not isinstance(atom, HyObject):
atom = wrap_value(atom) atom = wrap_value(atom)
if isinstance(atom, HyObject): if not isinstance(atom, HyObject):
return
spoof_positions(atom) spoof_positions(atom)
return self.compile_atom(type(atom), atom) if type(atom) not in _model_compilers:
return
# Compilation methods may mutate the atom, so copy it first.
atom = copy.copy(atom)
return Result() + _model_compilers[type(atom)](self, atom)
def compile(self, tree): def compile(self, tree):
try: try:
_type = type(tree) ret = self.compile_atom(tree)
ret = self.compile_atom(_type, tree)
if ret: if ret:
self.update_imports(ret) self.update_imports(ret)
return ret return ret
@ -1633,13 +1564,7 @@ class HyASTCompiler(object):
if ast_str(root) == "eval_and_compile" if ast_str(root) == "eval_and_compile"
else Result()) else Result())
@checkargs(1) @builds_model(HyExpression)
def _compile_keyword_call(self, expression):
expression.append(expression.pop(0))
expression.insert(0, HySymbol("get"))
return self.compile(expression)
@builds(HyExpression)
def compile_expression(self, expression): def compile_expression(self, expression):
# Perform macro expansions # Perform macro expansions
expression = macroexpand(expression, self) expression = macroexpand(expression, self)
@ -1648,23 +1573,45 @@ class HyASTCompiler(object):
return self.compile(expression) return self.compile(expression)
if expression == []: if expression == []:
return self.compile_list(expression, HyList) return self.compile_atom(HyList().replace(expression))
fn = expression[0] fn = expression[0]
func = None func = None
if isinstance(fn, HyKeyword): if isinstance(fn, HyKeyword):
return self._compile_keyword_call(expression) if len(expression) > 2:
raise HyTypeError(
expression, "keyword calls take only 1 argument")
expression.append(expression.pop(0))
expression.insert(0, HySymbol("get"))
return self.compile(expression)
if isinstance(fn, HySymbol): if isinstance(fn, HySymbol):
# First check if `fn` is a special form, unless it has an
# First check if `fn` is a special operator, unless it has an
# `unpack-iterable` in it, since Python's operators (`+`, # `unpack-iterable` in it, since Python's operators (`+`,
# etc.) can't unpack. An exception to this exception is that # etc.) can't unpack. An exception to this exception is that
# tuple literals (`,`) can unpack. # tuple literals (`,`) can unpack.
if fn == "," or not ( sfn = ast_str(fn)
any(is_unpack("iterable", x) for x in expression[1:])): if (sfn in _special_form_compilers or sfn in _bad_roots) and (
ret = self.compile_atom(fn, expression) sfn == mangle(",") or
if ret: not any(is_unpack("iterable", x) for x in expression[1:])):
return ret if sfn in _bad_roots:
raise HyTypeError(
expression,
"The special form '{}' is not allowed here".format(fn))
# `sfn` is a special operator. Get the build method and
# pattern-match the arguments.
build_method, pattern = _special_form_compilers[sfn]
try:
parse_tree = pattern.parse(expression[1:])
except NoParseError as e:
raise HyTypeError(
expression[min(e.state.pos + 1, len(expression) - 1)],
"parse error for special form '{}': {}".format(
expression[0],
e.msg.replace("<EOF>", "end of form")))
return Result() + build_method(
self, expression, expression[0], *parse_tree)
if fn.startswith("."): if fn.startswith("."):
# (.split "test test") -> "test test".split() # (.split "test test") -> "test test".split()
@ -1718,11 +1665,11 @@ class HyASTCompiler(object):
expression, func=func.expr, args=args, keywords=keywords, expression, func=func.expr, args=args, keywords=keywords,
starargs=oldpy_star, kwargs=oldpy_kw) starargs=oldpy_star, kwargs=oldpy_kw)
@builds(HyInteger, HyFloat, HyComplex) @builds_model(HyInteger, HyFloat, HyComplex)
def compile_numeric_literal(self, x, building): def compile_numeric_literal(self, x):
f = {HyInteger: long_type, f = {HyInteger: long_type,
HyFloat: float, HyFloat: float,
HyComplex: complex}[building] HyComplex: complex}[type(x)]
# Work around https://github.com/berkerpeksag/astor/issues/85 : # Work around https://github.com/berkerpeksag/astor/issues/85 :
# astor can't generate Num nodes with NaN, so we have # astor can't generate Num nodes with NaN, so we have
# to build an expression that evaluates to NaN. # to build an expression that evaluates to NaN.
@ -1741,7 +1688,7 @@ class HyASTCompiler(object):
return nan() return nan()
return nn(f(x)) return nn(f(x))
@builds(HySymbol) @builds_model(HySymbol)
def compile_symbol(self, symbol): def compile_symbol(self, symbol):
if "." in symbol: if "." in symbol:
glob, local = symbol.rsplit(".", 1) glob, local = symbol.rsplit(".", 1)
@ -1772,7 +1719,7 @@ class HyASTCompiler(object):
return asty.Name(symbol, id=ast_str(symbol), ctx=ast.Load()) return asty.Name(symbol, id=ast_str(symbol), ctx=ast.Load())
@builds(HyKeyword) @builds_model(HyKeyword)
def compile_keyword(self, obj): def compile_keyword(self, obj):
ret = Result() ret = Result()
ret += asty.Call( ret += asty.Call(
@ -1783,19 +1730,19 @@ class HyASTCompiler(object):
ret.add_imports("hy", {"HyKeyword"}) ret.add_imports("hy", {"HyKeyword"})
return ret return ret
@builds(HyString, HyBytes) @builds_model(HyString, HyBytes)
def compile_string(self, string, building): def compile_string(self, string):
node = asty.Bytes if PY3 and building is HyBytes else asty.Str node = asty.Bytes if PY3 and type(string) is HyBytes else asty.Str
f = bytes_type if building is HyBytes else str_type f = bytes_type if type(string) is HyBytes else str_type
return node(string, s=f(string)) return node(string, s=f(string))
@builds(HyList, HySet) @builds_model(HyList, HySet)
def compile_list(self, expression, building): def compile_list(self, expression):
elts, ret, _ = self._compile_collect(expression) elts, ret, _ = self._compile_collect(expression)
node = {HyList: asty.List, HySet: asty.Set}[building] node = {HyList: asty.List, HySet: asty.Set}[type(expression)]
return ret + node(expression, elts=elts, ctx=ast.Load()) return ret + node(expression, elts=elts, ctx=ast.Load())
@builds(HyDict) @builds_model(HyDict)
def compile_dict(self, m): def compile_dict(self, m):
keyvalues, ret, _ = self._compile_collect(m, dict_display=True) keyvalues, ret, _ = self._compile_collect(m, dict_display=True)
return ret + asty.Dict(m, keys=keyvalues[::2], values=keyvalues[1::2]) return ret + asty.Dict(m, keys=keyvalues[::2], values=keyvalues[1::2])

View File

@ -37,7 +37,7 @@ class Completer(object):
if not isinstance(namespace, dict): if not isinstance(namespace, dict):
raise TypeError('namespace must be a dictionary') raise TypeError('namespace must be a dictionary')
self.namespace = namespace self.namespace = namespace
self.path = [hy.compiler._compile_table, self.path = [hy.compiler._special_form_compilers,
builtins.__dict__, builtins.__dict__,
hy.macros._hy_macros[None], hy.macros._hy_macros[None],
namespace] namespace]

View File

@ -72,9 +72,7 @@
;; TODO: move to hy.extra.reserved? ;; TODO: move to hy.extra.reserved?
(import hy) (import hy)
(setv special-forms (list-comp k (setv special-forms (list (.keys hy.compiler._special-form-compilers)))
[k (.keys hy.compiler._compile-table)]
(isinstance k hy._compat.string-types)))
(defn lambda-list [form] (defn lambda-list [form]

View File

@ -18,6 +18,6 @@
hy.core.shadow.EXPORTS hy.core.shadow.EXPORTS
(list (.keys (get hy.macros._hy_macros None))) (list (.keys (get hy.macros._hy_macros None)))
keyword.kwlist keyword.kwlist
(list-comp k [k (.keys hy.compiler.-compile-table)] (list (.keys hy.compiler._special_form_compilers))
(isinstance k hy._compat.string-types)))))))) (list hy.compiler._bad_roots)))))))
_cache) _cache)

View File

@ -55,7 +55,7 @@ def test_compiler_yield_return():
HyExpression([HySymbol("+"), HyExpression([HySymbol("+"),
HyInteger(1), HyInteger(1),
HyInteger(1)])) HyInteger(1)]))
ret = compiler.HyASTCompiler('test').compile_atom("fn", e) ret = compiler.HyASTCompiler('test').compile_atom(e)
assert len(ret.stmts) == 1 assert len(ret.stmts) == 1
stmt, = ret.stmts stmt, = ret.stmts