hy/hy/compiler.py

402 lines
13 KiB
Python
Raw Normal View History

2013-03-03 22:26:17 +01:00
# Copyright (c) 2012 Paul Tagliamonte <paultag@debian.org>
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
2013-03-05 02:40:23 +01:00
from hy.errors import HyError
2013-03-05 01:12:57 +01:00
2013-03-05 02:40:23 +01:00
from hy.models.expression import HyExpression
2013-03-06 00:39:34 +01:00
from hy.models.integer import HyInteger
2013-03-05 02:40:23 +01:00
from hy.models.string import HyString
2013-03-06 04:08:53 +01:00
from hy.models.symbol import HySymbol
from hy.models.list import HyList
2013-03-09 06:55:27 +01:00
from hy.models.dict import HyDict
2013-03-05 02:40:23 +01:00
import ast
class HyCompileError(HyError):
    """Raised when the compiler is handed something it cannot translate."""
# Maps a dispatch key to the function that compiles it.  In this file the
# keys are either model types (``list``, ``HyExpression``, ...) or strings
# naming a special form (``"if"``, ``"def"``, ...).
_compile_table = {}


def builds(_type):
    """Return a decorator that registers a handler under ``_type``.

    The decorated function is stored in ``_compile_table[_type]`` and handed
    back unchanged, so it keeps its own name and docstring.  Because the
    function itself is returned, ``@builds`` can be stacked to register one
    handler under several keys.
    """
    def _dec(fn):
        _compile_table[_type] = fn
        return fn
    return _dec
class HyASTCompiler(object):
    """Translate Hy model objects into Python ``ast`` nodes.

    Dispatch goes through the module-level ``_compile_table``, which the
    ``builds`` decorator fills with handlers keyed either by model type
    (e.g. ``HyExpression``) or by the name of a special form (e.g. ``"if"``).

    Handlers read their input by index and never modify it, so the same
    tree can be handed to the compiler more than once.
    """

    # Special-form name -> ast comparison operator class.
    _COMPARE_OPS = {"=": ast.Eq, "!=": ast.NotEq,
                    "<": ast.Lt, "<=": ast.LtE,
                    ">": ast.Gt, ">=": ast.GtE,
                    "is": ast.Is, "is_not": ast.IsNot,
                    "in": ast.In, "not_in": ast.NotIn}

    # Special-form name -> ast binary operator class.
    # Not implemented yet: Mod, Pow, LShift, RShift, BitOr, BitXor, BitAnd,
    # FloorDiv.  XXX
    _MATHS_OPS = {"+": ast.Add,
                  "/": ast.Div,
                  "*": ast.Mult,
                  "-": ast.Sub}

    def __init__(self):
        # True while a function body is being compiled; it makes
        # _mangle_branch turn a trailing expression into a ``return``.
        self.returnable = False
        # Number of ``fn`` forms compiled so far; used to name them.
        self.anon_fn_count = 0

    def compile(self, tree):
        """Compile ``tree`` with the handler registered for its exact type.

        Raises:
            HyCompileError: no handler is registered for ``type(tree)``.
        """
        handler = _compile_table.get(type(tree))
        if handler is None:
            raise HyCompileError("Unknown type - `%s'" % (str(type(tree))))
        return handler(self, tree)

    def _mangle_branch(self, tree):
        """Turn a list of compiled nodes into a list of statements.

        Nodes that are not ``ast.stmt`` instances are wrapped in ``ast.Expr``.
        When ``self.returnable`` is set and the last node is not a statement,
        that node is wrapped in ``ast.Return`` instead.  ``tree`` itself is
        left untouched; a new list is returned.
        """
        nodes = list(tree)  # work on a copy, never on the caller's list
        tail = []
        if self.returnable and nodes and not isinstance(nodes[-1], ast.stmt):
            last = nodes.pop()
            tail.append(ast.Return(value=last,
                                   lineno=last.lineno,
                                   col_offset=last.col_offset))

        body = [node if isinstance(node, ast.stmt)
                else ast.Expr(value=node,
                              lineno=node.lineno,
                              col_offset=node.col_offset)
                for node in nodes]
        return body + tail

    def _code_branch(self, branch):
        """Mangle ``branch`` whether it is one node or a list of nodes."""
        if isinstance(branch, list):
            return self._mangle_branch(branch)
        return self._mangle_branch([branch])

    def _build_arguments(self, sig):
        """Build the ``ast.arguments`` node for a parameter list ``sig``.

        Only plain positional parameters are produced.  Each parameter is an
        ``ast.Name`` that carries both ``id`` and ``arg``.
        NOTE(review): presumably both attributes are set so one node shape
        serves Python 2 and Python 3 alike -- confirm before changing.
        """
        params = [ast.Name(arg=str(x), id=str(x),
                           ctx=ast.Param(),
                           lineno=x.start_line,
                           col_offset=x.start_column)
                  for x in sig]
        return ast.arguments(args=params,
                             vararg=None,
                             kwarg=None,
                             kwonlyargs=[],
                             kw_defaults=[],
                             defaults=[])

    @builds(list)
    def compile_raw_list(self, entries):
        """Compile every entry of a plain Python list, in order."""
        return [self.compile(x) for x in entries]

    @builds("do")
    def compile_do_expression(self, expr):
        """``(do a b ...)``: compile each form after the head into a list."""
        return [self.compile(x) for x in expr[1:]]

    @builds("if")
    def compile_if_expression(self, expr):
        """``(if test then [else])`` -> ``ast.If``.

        Forms after the optional else branch are ignored.
        """
        test = self.compile(expr[1])
        body = self._code_branch(self.compile(expr[2]))
        orel = []
        if len(expr) > 3:
            orel = self._code_branch(self.compile(expr[3]))
        return ast.If(test=test,
                      body=body,
                      orelse=orel,
                      lineno=expr.start_line,
                      col_offset=expr.start_column)

    @builds("assert")
    def compile_assert_expression(self, expr):
        """``(assert test)`` -> ``ast.Assert`` without a message."""
        e = expr[1]
        return ast.Assert(test=self.compile(e),
                          msg=None,
                          lineno=e.start_line,
                          col_offset=e.start_column)

    @builds("lambda")
    def compile_lambda_expression(self, expr):
        """``(lambda [args] body)`` -> ``ast.Lambda``.

        Anything after ``body`` is ignored.
        """
        sig = expr[1]
        body = expr[2]
        return ast.Lambda(
            lineno=expr.start_line,
            col_offset=expr.start_column,
            args=self._build_arguments(sig),
            body=self.compile(body))

    @builds("import")
    def compile_import_expression(self, expr):
        """``(import a b ...)`` -> one ``ast.Import`` naming every module."""
        return ast.Import(
            lineno=expr.start_line,
            col_offset=expr.start_column,
            names=[ast.alias(name=str(x), asname=None) for x in expr[1:]])

    @builds("import_from")
    def compile_import_from_expression(self, expr):
        """``(import_from module a b ...)`` -> absolute ``ast.ImportFrom``."""
        module = str(expr[1])
        return ast.ImportFrom(
            lineno=expr.start_line,
            col_offset=expr.start_column,
            module=module,
            names=[ast.alias(name=str(x), asname=None) for x in expr[2:]],
            level=0)

    @builds("get")
    def compile_index_expression(self, expr):
        """``(get target index)`` -> ``ast.Subscript`` in load context."""
        val = self.compile(expr[1])  # target
        sli = self.compile(expr[2])  # slice
        return ast.Subscript(
            lineno=expr.start_line,
            col_offset=expr.start_column,
            value=val,
            slice=ast.Index(value=sli),
            ctx=ast.Load())

    @builds("decorate_with")
    def compile_decorate_expression(self, expr):
        """``(decorate_with d1 d2 ... function)``.

        The last form must compile to an ``ast.FunctionDef``; the forms
        before it become its decorator list.

        Raises:
            TypeError: the last form does not compile to a function
                definition.
        """
        args = expr[1:]
        fn = self.compile(args[-1])
        if not isinstance(fn, ast.FunctionDef):
            raise TypeError("Decorated a non-function")
        fn.decorator_list = [self.compile(x) for x in args[:-1]]
        return fn

    @builds("kwapply")
    def compile_kwapply_expression(self, expr):
        """``(kwapply (call ...) {key value ...})``.

        Compiles the call and sets its keyword arguments from the mapping.

        Raises:
            TypeError: the first form does not compile to an ``ast.Call``.
        """
        call = self.compile(expr[1])
        kwargs = expr[2]
        if not isinstance(call, ast.Call):
            raise TypeError("kwapplying a non-call")
        call.keywords = [ast.keyword(arg=str(x),
                                     value=self.compile(kwargs[x]))
                         for x in kwargs]
        return call

    @builds("=")
    @builds("!=")
    @builds("<")
    @builds("<=")
    @builds(">")
    @builds(">=")
    @builds("is")
    @builds("in")
    @builds("is_not")
    @builds("not_in")
    def compile_compare_op_expression(self, expression):
        """``(op a b c ...)`` -> chained ``ast.Compare`` (``a op b op c``)."""
        op = self._COMPARE_OPS[expression[0]]
        e = expression[1]
        rest = expression[2:]
        return ast.Compare(left=self.compile(e),
                           ops=[op() for _ in rest],  # one per comparator
                           comparators=[self.compile(x) for x in rest],
                           lineno=e.start_line,
                           col_offset=e.start_column)

    @builds("+")
    @builds("-")
    @builds("/")
    @builds("*")
    def compile_maths_expression(self, expression):
        """``(op a b c ...)`` -> left-nested ``ast.BinOp`` (``(a op b) op c``).

        Raises:
            HyCompileError: fewer than two operands were given, so there is
                no binary operation to build.
        """
        inv = expression[0]
        op = self._MATHS_OPS[inv]
        left = self.compile(expression[1])
        rest = expression[2:]
        if not rest:
            raise HyCompileError(
                "`%s' needs at least two operands" % (str(inv)))

        for child in rest:
            left = ast.BinOp(left=left,
                             op=op(),
                             right=self.compile(child),
                             lineno=child.start_line,
                             col_offset=child.start_column)
        return left

    @builds(HyExpression)
    def compile_expression(self, expression):
        """Compile an expression as a special form or a function call.

        When the head is a ``HyString`` registered in ``_compile_table`` the
        matching special-form handler takes over; otherwise the head is
        compiled as the callee and the remaining forms as positional
        arguments.
        """
        fn = expression[0]
        if isinstance(fn, HyString):
            if fn in _compile_table:
                return _compile_table[fn](self, expression)

        return ast.Call(func=self.compile(fn),
                        args=[self.compile(x) for x in expression[1:]],
                        keywords=[],
                        starargs=None,
                        kwargs=None,
                        lineno=expression.start_line,
                        col_offset=expression.start_column)

    @builds("def")
    def compile_def_expression(self, expression):
        """``(def name value)`` -> named function or ``ast.Assign``."""
        name = expression[1]
        what = self.compile(expression[2])

        if isinstance(what, ast.FunctionDef):
            # A FunctionDef can simply be renamed, which is cleaner than
            # emitting ``name = _hy_anon_fn_N`` after the definition.
            what.name = str(name)
            return what

        target = self.compile(name)
        target.ctx = ast.Store()

        return ast.Assign(
            lineno=expression.start_line,
            col_offset=expression.start_column,
            targets=[target], value=what)

    @builds("for")
    def compile_for_expression(self, expression):
        """``(for [name iterable] body ...)`` -> ``ast.For``.

        The loop body is compiled with ``self.returnable`` switched off; the
        previous value is restored even if compilation fails.
        """
        ret_status = self.returnable
        self.returnable = False
        try:
            # TODO: support several bindings, e.g.
            #   (for [x [1 2 3 4]
            #         y [a b c d]] ...)
            name, iterable = expression[1]
            target = self.compile_symbol(name)
            target.ctx = ast.Store()

            return ast.For(lineno=expression.start_line,
                           col_offset=expression.start_column,
                           target=target,
                           iter=self.compile(iterable),
                           body=self._mangle_branch([
                               self.compile(x) for x in expression[2:]]),
                           orelse=[])
        finally:
            self.returnable = ret_status

    @builds(HyList)
    def compile_list(self, expr):
        """Compile a ``HyList`` into an ``ast.List`` in load context."""
        return ast.List(
            elts=[self.compile(x) for x in expr],
            ctx=ast.Load(),
            lineno=expr.start_line,
            col_offset=expr.start_column)

    @builds("fn")
    def compile_fn_expression(self, expression):
        """``(fn [args] body ...)`` -> ``ast.FunctionDef``.

        The function gets a generated ``_hy_anon_fn_<n>`` name.  Its body is
        compiled with ``self.returnable`` switched on, so a trailing
        expression becomes a ``return``; the previous value is restored even
        if compilation fails.
        """
        ret_status = self.returnable
        self.returnable = True
        try:
            self.anon_fn_count += 1
            name = "_hy_anon_fn_%d" % (self.anon_fn_count)
            sig = expression[1]

            return ast.FunctionDef(
                name=name,
                lineno=expression.start_line,
                col_offset=expression.start_column,
                args=self._build_arguments(sig),
                body=self._code_branch([
                    self.compile(x) for x in expression[2:]]),
                decorator_list=[])
        finally:
            self.returnable = ret_status

    @builds(HyInteger)
    def compile_number(self, number):
        """Compile a ``HyInteger`` into an ``ast.Num`` holding a plain int."""
        return ast.Num(n=int(number),  # See HyInteger above.
                       lineno=number.start_line,
                       col_offset=number.start_column)

    @builds(HySymbol)
    def compile_symbol(self, symbol):
        """Compile a symbol into ``ast.Name`` or, if dotted, ``ast.Attribute``.

        A dotted symbol is split at its last dot; the part before it is
        compiled recursively, so ``a.b.c`` yields nested attribute lookups.
        """
        if "." in symbol:
            glob, local = symbol.rsplit(".", 1)
            glob = HySymbol(glob)
            # NOTE(review): presumably copies the source position of
            # ``symbol`` onto the new object -- confirm against the model's
            # ``replace`` implementation.
            glob.replace(symbol)

            return ast.Attribute(
                lineno=symbol.start_line,
                col_offset=symbol.start_column,
                value=self.compile_symbol(glob),
                attr=str(local),
                ctx=ast.Load()
            )

        return ast.Name(id=str(symbol), ctx=ast.Load(),
                        lineno=symbol.start_line,
                        col_offset=symbol.start_column)

    @builds(HyString)
    def compile_string(self, string):
        """Compile a ``HyString`` into an ``ast.Str``."""
        return ast.Str(s=str(string), lineno=string.start_line,
                       col_offset=string.start_column)

    @builds(HyDict)
    def compile_dict(self, m):
        """Compile a ``HyDict`` into an ``ast.Dict``, entry by entry."""
        keys = []
        vals = []
        for entry in m:
            keys.append(self.compile(entry))
            vals.append(self.compile(m[entry]))

        return ast.Dict(
            lineno=m.start_line,
            col_offset=m.start_column,
            keys=keys,
            values=vals)
2013-03-05 02:40:23 +01:00
def hy_compile(tree):
    """Compile a HyObject tree into a Python AST tree.

    The tree is compiled with a fresh ``HyASTCompiler`` and the resulting
    nodes, mangled into statements, become the body of an ``ast.Module``.
    """
    hy_compiler = HyASTCompiler()
    statements = hy_compiler._mangle_branch(hy_compiler.compile(tree))
    return ast.Module(body=statements)