Introduce a context manager for compiler.returnable

Signed-off-by: Julien Danjou <julien@danjou.info>
This commit is contained in:
Julien Danjou 2013-04-21 00:17:30 +02:00
parent adecd32897
commit 7066d53b02
3 changed files with 82 additions and 60 deletions

View File

@ -33,7 +33,7 @@ from hy.models.list import HyList
from hy.models.dict import HyDict
from hy.models.keyword import HyKeyword
from hy.util import flatten_literal_list, str_type
from hy.util import flatten_literal_list, str_type, temporary_attribute_value
from collections import defaultdict
import codecs
@ -129,6 +129,9 @@ class HyASTCompiler(object):
self.anon_fn_count = 0
self.imports = defaultdict(list)
def being_returnable(self, v):
return temporary_attribute_value(self, "returnable", v)
def compile(self, tree):
try:
for _type in _compile_table:
@ -951,37 +954,35 @@ class HyASTCompiler(object):
@builds("foreach")
@checkargs(min=1)
def compile_for_expression(self, expression):
ret_status = self.returnable
self.returnable = False
with self.being_returnable(False):
expression.pop(0) # for
name, iterable = expression.pop(0)
target = self._storeize(self.compile_symbol(name))
expression.pop(0) # for
name, iterable = expression.pop(0)
target = self._storeize(self.compile_symbol(name))
orelse = []
# (foreach [] body (else …))
if expression and expression[-1][0] == HySymbol("else"):
else_expr = expression.pop()
if len(else_expr) > 2:
raise HyTypeError(
else_expr,
"`else' statement in `foreach' is too long")
elif len(else_expr) == 2:
orelse = self._code_branch(
self.compile(else_expr[1]),
else_expr[1].start_line,
else_expr[1].start_column)
orelse = []
# (foreach [] body (else …))
if expression and expression[-1][0] == HySymbol("else"):
else_expr = expression.pop()
if len(else_expr) > 2:
raise HyTypeError(else_expr,
"`else' statement in `foreach' is too long")
elif len(else_expr) == 2:
orelse = self._code_branch(
self.compile(else_expr[1]),
else_expr[1].start_line,
else_expr[1].start_column)
ret = ast.For(lineno=expression.start_line,
col_offset=expression.start_column,
target=target,
iter=self.compile(iterable),
body=self._code_branch(
[self.compile(x) for x in expression],
expression.start_line,
expression.start_column),
orelse=orelse)
ret = ast.For(lineno=expression.start_line,
col_offset=expression.start_column,
target=target,
iter=self.compile(iterable),
body=self._code_branch(
[self.compile(x) for x in expression],
expression.start_line,
expression.start_column),
orelse=orelse)
self.returnable = ret_status
return ret
@builds("while")
@ -1012,47 +1013,44 @@ class HyASTCompiler(object):
def compile_fn_expression(self, expression):
expression.pop(0) # fn
ret_status = self.returnable
self.anon_fn_count += 1
name = "_hy_anon_fn_%d" % (self.anon_fn_count)
sig = expression.pop(0)
body = []
if expression != []:
self.returnable = True
tailop = self.compile(expression.pop(-1))
self.returnable = False
for el in expression:
body.append(self.compile(el))
with self.being_returnable(True):
tailop = self.compile(expression.pop(-1))
with self.being_returnable(False):
for el in expression:
body.append(self.compile(el))
body.append(tailop)
self.returnable = True
body = self._code_branch(body,
expression.start_line,
expression.start_column)
with self.being_returnable(True):
body = self._code_branch(body,
expression.start_line,
expression.start_column)
ret = ast.FunctionDef(
name=name,
lineno=expression.start_line,
col_offset=expression.start_column,
args=ast.arguments(
args=[
ast.Name(
arg=ast_str(x), id=ast_str(x),
ctx=ast.Param(),
lineno=x.start_line,
col_offset=x.start_column)
for x in sig],
vararg=None,
kwarg=None,
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=body,
decorator_list=[])
ret = ast.FunctionDef(
name=name,
lineno=expression.start_line,
col_offset=expression.start_column,
args=ast.arguments(
args=[
ast.Name(
arg=ast_str(x), id=ast_str(x),
ctx=ast.Param(),
lineno=x.start_line,
col_offset=x.start_column)
for x in sig],
vararg=None,
kwarg=None,
kwonlyargs=[],
kw_defaults=[],
defaults=[]),
body=body,
decorator_list=[])
self.returnable = ret_status
return ret
@builds(HyInteger)

View File

@ -1,4 +1,5 @@
# Copyright (c) 2013 Paul Tagliamonte <paultag@debian.org>
# Copyright (c) 2013 Julien Danjou <julien@danjou.info>
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
@ -18,6 +19,7 @@
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import contextlib
import sys
@ -27,6 +29,15 @@ else:
str_type = unicode
@contextlib.contextmanager
def temporary_attribute_value(obj, attribute, value):
"""Temporarily switch an object attribute value to another value."""
original_value = getattr(obj, attribute)
setattr(obj, attribute, value)
yield
setattr(obj, attribute, original_value)
def flatten_literal_list(entry):
for e in entry:
if type(e) == list:

13
tests/test_util.py Normal file
View File

@ -0,0 +1,13 @@
from hy import util
def test_temporary_attribute_value():
class O(object):
def __init__(self):
self.foobar = 0
o = O()
with util.temporary_attribute_value(o, "foobar", 42):
assert o.foobar == 42
assert o.foobar == 0