diff --git a/hy/compiler.py b/hy/compiler.py index 6da7215..5e847ab 100644 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -1672,14 +1672,68 @@ class HyASTCompiler(object): @checkargs(min=2) def compile_logical_or_and_and_operator(self, expression): ops = {"and": ast.And, - "or": ast.Or} + "or" : ast.Or} operator = expression.pop(0) - values, ret, _ = self._compile_collect(expression) + ret = Result() + values = list(map(self.compile, expression)) + has_stmt = any(value.stmts for value in values) + root_line, root_column = operator.start_line, operator.start_column + if has_stmt: + # Compile it to an if...else sequence + var = self.get_anon_var() + name = ast.Name(id=var, + ctx=ast.Store(), + lineno=root_line, + col_offset=root_column) + expr_name = ast.Name(id=var, + ctx=ast.Load(), + lineno=root_line, + col_offset=root_column) - ret += ast.BoolOp(op=ops[operator](), - lineno=operator.start_line, - col_offset=operator.start_column, - values=values) + def make_assign(value, node=None): + if node is None: + line, column = root_line, root_column + else: + line, column = node.lineno, node.col_offset + return ast.Assign(targets=[ast.Name(id=var, + ctx=ast.Store(), + lineno=line, + col_offset=column)], + value=value, + lineno=line, + col_offset=column) + root = [] + current = root + for i, value in enumerate(values): + if value.stmts: + node = value.stmts[0] + current.extend(value.stmts) + else: + node = value.expr + current.append(make_assign(value.force_expr, value.force_expr)) + if i == len(values)-1: + # Skip a redundant 'if'. + break + if operator == "and": + cond = expr_name + elif operator == "or": + cond = ast.UnaryOp(op=ast.Not(), + operand=expr_name, + lineno=node.lineno, + col_offset=node.col_offset) + current.append(ast.If(test=cond, + body=[], + lineno=node.lineno, + col_offset=node.col_offset, + orelse=[])) + current = current[-1].body + ret = sum(root, ret) + ret += Result(expr=expr_name, temp_variables=[expr_name, name]) + else: + ret += ast.BoolOp(op=ops[operator](), + lineno=root_line, + col_offset=root_column, + values=[value.force_expr for value in values]) return ret @builds("=") diff --git a/tests/native_tests/language.hy b/tests/native_tests/language.hy index 580cccf..069c1c1 100644 --- a/tests/native_tests/language.hy +++ b/tests/native_tests/language.hy @@ -767,7 +767,11 @@ (let [[and123 (and 1 2 3)] [and-false (and 1 False 3)]] (assert (= and123 3)) - (assert (= and-false False)))) + (assert (= and-false False))) + ; short circuiting + (setv a 1) + (and 0 (setv a 2)) + (assert (= a 1))) (defn test-or [] @@ -777,7 +781,11 @@ [or-none-true (or False False)]] (assert (= or-all-true 1)) (assert (= or-some-true "hello")) - (assert (= or-none-true False)))) + (assert (= or-none-true False))) + ; short circuiting + (setv a 1) + (or 1 (setv a 2)) + (assert (= a 1))) (defn test-if-return-branching []