Introduce a context manager for compiler.returnable

Signed-off-by: Julien Danjou <julien@danjou.info>
This commit is contained in:
Julien Danjou 2013-04-21 00:17:30 +02:00
parent adecd32897
commit 7066d53b02
3 changed files with 82 additions and 60 deletions

View File

@ -33,7 +33,7 @@ from hy.models.list import HyList
from hy.models.dict import HyDict from hy.models.dict import HyDict
from hy.models.keyword import HyKeyword from hy.models.keyword import HyKeyword
from hy.util import flatten_literal_list, str_type from hy.util import flatten_literal_list, str_type, temporary_attribute_value
from collections import defaultdict from collections import defaultdict
import codecs import codecs
@ -129,6 +129,9 @@ class HyASTCompiler(object):
self.anon_fn_count = 0 self.anon_fn_count = 0
self.imports = defaultdict(list) self.imports = defaultdict(list)
def being_returnable(self, v):
return temporary_attribute_value(self, "returnable", v)
def compile(self, tree): def compile(self, tree):
try: try:
for _type in _compile_table: for _type in _compile_table:
@ -951,37 +954,35 @@ class HyASTCompiler(object):
@builds("foreach") @builds("foreach")
@checkargs(min=1) @checkargs(min=1)
def compile_for_expression(self, expression): def compile_for_expression(self, expression):
ret_status = self.returnable with self.being_returnable(False):
self.returnable = False expression.pop(0) # for
name, iterable = expression.pop(0)
target = self._storeize(self.compile_symbol(name))
expression.pop(0) # for orelse = []
name, iterable = expression.pop(0) # (foreach [] body (else …))
target = self._storeize(self.compile_symbol(name)) if expression and expression[-1][0] == HySymbol("else"):
else_expr = expression.pop()
if len(else_expr) > 2:
raise HyTypeError(
else_expr,
"`else' statement in `foreach' is too long")
elif len(else_expr) == 2:
orelse = self._code_branch(
self.compile(else_expr[1]),
else_expr[1].start_line,
else_expr[1].start_column)
orelse = [] ret = ast.For(lineno=expression.start_line,
# (foreach [] body (else …)) col_offset=expression.start_column,
if expression and expression[-1][0] == HySymbol("else"): target=target,
else_expr = expression.pop() iter=self.compile(iterable),
if len(else_expr) > 2: body=self._code_branch(
raise HyTypeError(else_expr, [self.compile(x) for x in expression],
"`else' statement in `foreach' is too long") expression.start_line,
elif len(else_expr) == 2: expression.start_column),
orelse = self._code_branch( orelse=orelse)
self.compile(else_expr[1]),
else_expr[1].start_line,
else_expr[1].start_column)
ret = ast.For(lineno=expression.start_line,
col_offset=expression.start_column,
target=target,
iter=self.compile(iterable),
body=self._code_branch(
[self.compile(x) for x in expression],
expression.start_line,
expression.start_column),
orelse=orelse)
self.returnable = ret_status
return ret return ret
@builds("while") @builds("while")
@ -1012,47 +1013,44 @@ class HyASTCompiler(object):
def compile_fn_expression(self, expression): def compile_fn_expression(self, expression):
expression.pop(0) # fn expression.pop(0) # fn
ret_status = self.returnable
self.anon_fn_count += 1 self.anon_fn_count += 1
name = "_hy_anon_fn_%d" % (self.anon_fn_count) name = "_hy_anon_fn_%d" % (self.anon_fn_count)
sig = expression.pop(0) sig = expression.pop(0)
body = [] body = []
if expression != []: if expression != []:
self.returnable = True with self.being_returnable(True):
tailop = self.compile(expression.pop(-1)) tailop = self.compile(expression.pop(-1))
self.returnable = False with self.being_returnable(False):
for el in expression: for el in expression:
body.append(self.compile(el)) body.append(self.compile(el))
body.append(tailop) body.append(tailop)
self.returnable = True with self.being_returnable(True):
body = self._code_branch(body, body = self._code_branch(body,
expression.start_line, expression.start_line,
expression.start_column) expression.start_column)
ret = ast.FunctionDef( ret = ast.FunctionDef(
name=name, name=name,
lineno=expression.start_line, lineno=expression.start_line,
col_offset=expression.start_column, col_offset=expression.start_column,
args=ast.arguments( args=ast.arguments(
args=[ args=[
ast.Name( ast.Name(
arg=ast_str(x), id=ast_str(x), arg=ast_str(x), id=ast_str(x),
ctx=ast.Param(), ctx=ast.Param(),
lineno=x.start_line, lineno=x.start_line,
col_offset=x.start_column) col_offset=x.start_column)
for x in sig], for x in sig],
vararg=None, vararg=None,
kwarg=None, kwarg=None,
kwonlyargs=[], kwonlyargs=[],
kw_defaults=[], kw_defaults=[],
defaults=[]), defaults=[]),
body=body, body=body,
decorator_list=[]) decorator_list=[])
self.returnable = ret_status
return ret return ret
@builds(HyInteger) @builds(HyInteger)

View File

@ -1,4 +1,5 @@
# Copyright (c) 2013 Paul Tagliamonte <paultag@debian.org> # Copyright (c) 2013 Paul Tagliamonte <paultag@debian.org>
# Copyright (c) 2013 Julien Danjou <julien@danjou.info>
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
@ -18,6 +19,7 @@
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
import contextlib
import sys import sys
@ -27,6 +29,15 @@ else:
str_type = unicode str_type = unicode
@contextlib.contextmanager
def temporary_attribute_value(obj, attribute, value):
"""Temporarily switch an object attribute value to another value."""
original_value = getattr(obj, attribute)
setattr(obj, attribute, value)
yield
setattr(obj, attribute, original_value)
def flatten_literal_list(entry): def flatten_literal_list(entry):
for e in entry: for e in entry:
if type(e) == list: if type(e) == list:

13
tests/test_util.py Normal file
View File

@ -0,0 +1,13 @@
from hy import util
def test_temporary_attribute_value():
class O(object):
def __init__(self):
self.foobar = 0
o = O()
with util.temporary_attribute_value(o, "foobar", 42):
assert o.foobar == 42
assert o.foobar == 0