diff --git a/hy/compiler.py b/hy/compiler.py index 1a46405..35ac1f8 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -33,7 +33,7 @@ from hy.models.list import HyList from hy.models.dict import HyDict from hy.models.keyword import HyKeyword -from hy.util import flatten_literal_list, str_type +from hy.util import flatten_literal_list, str_type, temporary_attribute_value from collections import defaultdict import codecs @@ -129,6 +129,9 @@ class HyASTCompiler(object): self.anon_fn_count = 0 self.imports = defaultdict(list) + def being_returnable(self, v): + return temporary_attribute_value(self, "returnable", v) + def compile(self, tree): try: for _type in _compile_table: @@ -951,37 +954,35 @@ class HyASTCompiler(object): @builds("foreach") @checkargs(min=1) def compile_for_expression(self, expression): - ret_status = self.returnable - self.returnable = False + with self.being_returnable(False): + expression.pop(0) # for + name, iterable = expression.pop(0) + target = self._storeize(self.compile_symbol(name)) - expression.pop(0) # for - name, iterable = expression.pop(0) - target = self._storeize(self.compile_symbol(name)) + orelse = [] + # (foreach [] body (else …)) + if expression and expression[-1][0] == HySymbol("else"): + else_expr = expression.pop() + if len(else_expr) > 2: + raise HyTypeError( + else_expr, + "`else' statement in `foreach' is too long") + elif len(else_expr) == 2: + orelse = self._code_branch( + self.compile(else_expr[1]), + else_expr[1].start_line, + else_expr[1].start_column) - orelse = [] - # (foreach [] body (else …)) - if expression and expression[-1][0] == HySymbol("else"): - else_expr = expression.pop() - if len(else_expr) > 2: - raise HyTypeError(else_expr, - "`else' statement in `foreach' is too long") - elif len(else_expr) == 2: - orelse = self._code_branch( - self.compile(else_expr[1]), - else_expr[1].start_line, - else_expr[1].start_column) + ret = ast.For(lineno=expression.start_line, + col_offset=expression.start_column, + target=target, + iter=self.compile(iterable), + body=self._code_branch( + [self.compile(x) for x in expression], + expression.start_line, + expression.start_column), + orelse=orelse) - ret = ast.For(lineno=expression.start_line, - col_offset=expression.start_column, - target=target, - iter=self.compile(iterable), - body=self._code_branch( - [self.compile(x) for x in expression], - expression.start_line, - expression.start_column), - orelse=orelse) - - self.returnable = ret_status return ret @builds("while") @@ -1012,47 +1013,44 @@ class HyASTCompiler(object): def compile_fn_expression(self, expression): expression.pop(0) # fn - ret_status = self.returnable - self.anon_fn_count += 1 name = "_hy_anon_fn_%d" % (self.anon_fn_count) sig = expression.pop(0) body = [] if expression != []: - self.returnable = True - tailop = self.compile(expression.pop(-1)) - self.returnable = False - for el in expression: - body.append(self.compile(el)) + with self.being_returnable(True): + tailop = self.compile(expression.pop(-1)) + with self.being_returnable(False): + for el in expression: + body.append(self.compile(el)) body.append(tailop) - self.returnable = True - body = self._code_branch(body, - expression.start_line, - expression.start_column) + with self.being_returnable(True): + body = self._code_branch(body, + expression.start_line, + expression.start_column) - ret = ast.FunctionDef( - name=name, - lineno=expression.start_line, - col_offset=expression.start_column, - args=ast.arguments( - args=[ - ast.Name( - arg=ast_str(x), id=ast_str(x), - ctx=ast.Param(), - lineno=x.start_line, - col_offset=x.start_column) - for x in sig], - vararg=None, - kwarg=None, - kwonlyargs=[], - kw_defaults=[], - defaults=[]), - body=body, - decorator_list=[]) + ret = ast.FunctionDef( + name=name, + lineno=expression.start_line, + col_offset=expression.start_column, + args=ast.arguments( + args=[ + ast.Name( + arg=ast_str(x), id=ast_str(x), + ctx=ast.Param(), + lineno=x.start_line, + col_offset=x.start_column) + for x in sig], + vararg=None, + kwarg=None, + kwonlyargs=[], + kw_defaults=[], + defaults=[]), + body=body, + decorator_list=[]) - self.returnable = ret_status return ret @builds(HyInteger) diff --git a/hy/util.py b/hy/util.py index a600b5e..6b26543 100644 --- a/hy/util.py +++ b/hy/util.py @@ -1,4 +1,5 @@ # Copyright (c) 2013 Paul Tagliamonte +# Copyright (c) 2013 Julien Danjou # # Permission is hereby granted, free of charge, to any person obtaining a # copy of this software and associated documentation files (the "Software"), @@ -18,6 +19,7 @@ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +import contextlib import sys @@ -27,6 +29,15 @@ else: str_type = unicode +@contextlib.contextmanager +def temporary_attribute_value(obj, attribute, value): + """Temporarily switch an object attribute value to another value.""" + original_value = getattr(obj, attribute) + setattr(obj, attribute, value) + yield + setattr(obj, attribute, original_value) + + def flatten_literal_list(entry): for e in entry: if type(e) == list: diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..1804fa0 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,13 @@ +from hy import util + + +def test_temporary_attribute_value(): + class O(object): + def __init__(self): + self.foobar = 0 + + o = O() + + with util.temporary_attribute_value(o, "foobar", 42): + assert o.foobar == 42 + assert o.foobar == 0