diff --git a/hy/compiler.py b/hy/compiler.py index a4ba4e1..6e77e09 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -6,6 +6,8 @@ from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex, HyString, HyBytes, HySymbol, HyFloat, HyList, HySet, HyDict, HySequence, wrap_value) +from hy.model_patterns import FORM, SYM, sym, brackets, whole, notpexpr, dolike +from funcparserlib.parser import many, oneplus, maybe, NoParseError from hy.errors import HyCompileError, HyTypeError from hy.lex.parser import mangle @@ -88,6 +90,19 @@ def builds(*types, **kwargs): return _dec +def special(names, pattern): + pattern = whole(pattern) + def dec(fn): + for name in names if isinstance(names, list) else [names]: + if isinstance(name, tuple): + condition, name = name + if not condition: + continue + _compile_table[ast_str(name)] = (fn, pattern) + return fn + return dec + + def spoof_positions(obj): if not isinstance(obj, HyObject): return @@ -422,12 +437,23 @@ class HyASTCompiler(object): # type of atom, so call it. If it has an extra parameter, # pass in `atom_type`. atom_compiler = _compile_table[atom_type] - arity = hy.inspect.get_arity(atom_compiler) - # Compliation methods may mutate the atom, so copy it first. - atom = copy.copy(atom) - ret = (atom_compiler(self, atom, atom_type) - if arity == 3 - else atom_compiler(self, atom)) + if isinstance(atom_compiler, tuple): + # This build method has a pattern. + build_method, pattern = atom_compiler + try: + parse_tree = pattern.parse(atom[1:]) + except NoParseError as e: + raise HyTypeError(atom, + "parse error for special form '{}': {}'".format( + atom[0], str(e))) + ret = build_method(self, atom, atom[0], *parse_tree) + else: + arity = hy.inspect.get_arity(atom_compiler) + # Compliation methods may mutate the atom, so copy it first. + atom = copy.copy(atom) + ret = (atom_compiler(self, atom, atom_type) + if arity == 3 + else atom_compiler(self, atom)) if not isinstance(ret, Result): ret = Result() + ret return ret @@ -712,15 +738,14 @@ class HyASTCompiler(object): return imports, HyExpression([HySymbol(name), form]).replace(form), False - @builds("quote", "quasiquote") - @checkargs(exact=1) - def compile_quote(self, entries): - if entries[0] == "quote": + @special(["quote", "quasiquote"], [FORM]) + def compile_quote(self, expr, root, arg): + if root == "quote": # Never allow unquoting level = float("inf") else: level = 0 - imports, stmts, splice = self._render_quoted_form(entries[1], level) + imports, stmts, splice = self._render_quoted_form(arg, level) ret = self.compile(stmts) ret.add_imports("hy", imports) return ret @@ -730,57 +755,49 @@ class HyASTCompiler(object): raise HyTypeError(expr, "`%s' can't be used at the top-level" % expr[0]) - @builds("unpack-iterable") - @checkargs(exact=1) - def compile_unpack_iterable(self, expr): + @special("unpack-iterable", [FORM]) + def compile_unpack_iterable(self, expr, root, arg): if not PY3: raise HyTypeError(expr, "`unpack-iterable` isn't allowed here") - ret = self.compile(expr[1]) + ret = self.compile(arg) ret += asty.Starred(expr, value=ret.force_expr, ctx=ast.Load()) return ret @builds("unpack-mapping") - @checkargs(exact=1) def compile_unpack_mapping(self, expr): raise HyTypeError(expr, "`unpack-mapping` isn't allowed here") - @builds("exec*", iff=(not PY3)) + @special([(not PY3, "exec*")], [FORM, maybe(FORM), maybe(FORM)]) # Under Python 3, `exec` is a function rather than a statement type, so Hy # doesn't need a special form for it. - @checkargs(min=1, max=3) - def compile_exec(self, expr): - expr.pop(0) + def compile_exec(self, expr, root, body, globals_, locals_): return asty.Exec( expr, - body=self.compile(expr.pop(0)).force_expr, - globals=self.compile(expr.pop(0)).force_expr if expr else None, - locals=self.compile(expr.pop(0)).force_expr if expr else None) + body=self.compile(body).force_expr, + globals=self.compile(globals_).force_expr if globals_ is not None else None, + locals=self.compile(locals_).force_expr if locals_ is not None else None) - @builds("do") - def compile_do(self, expression): - expression.pop(0) - return self._compile_branch(expression) + @special("do", [many(FORM)]) + def compile_do(self, expr, root, body): + return self._compile_branch(body) - @builds("raise") - @checkargs(multiple=[0, 1, 3]) - def compile_raise_expression(self, expr): - expr.pop(0) + @special("raise", [maybe(FORM), maybe(sym(":from") + FORM)]) + def compile_raise_expression(self, expr, root, exc, cause): ret = Result() - if expr: - ret += self.compile(expr.pop(0)) + if exc is not None: + exc = self.compile(exc) + ret += exc + exc = exc.force_expr - cause = None - if len(expr) == 2 and expr[0] == HyKeyword("from"): + if cause is not None: if not PY3: - raise HyCompileError( - "raise from only supported in python 3") - expr.pop(0) - cause = self.compile(expr.pop(0)) - cause = cause.expr + raise HyTypeError(expr, "raise from only supported in python 3") + cause = self.compile(cause) + ret += cause + cause = cause.force_expr - # Use ret.expr to get a literal `None` ret += asty.Raise( - expr, type=ret.expr, exc=ret.expr, + expr, type=ret.expr, exc=exc, inst=None, tback=None, cause=cause) return ret @@ -954,17 +971,14 @@ class HyASTCompiler(object): return _type + asty.ExceptHandler( expr, type=_type.expr, name=name, body=body) - @builds("if*") - @checkargs(min=2, max=3) - def compile_if(self, expression): - expression.pop(0) - cond = self.compile(expression.pop(0)) - body = self.compile(expression.pop(0)) + @special("if*", [FORM, FORM, maybe(FORM)]) + def compile_if(self, expr, _, cond, body, orel_expr): + cond = self.compile(cond) + body = self.compile(body) - orel = Result() nested = root = False - if expression: - orel_expr = expression.pop(0) + orel = Result() + if orel_expr is not None: if isinstance(orel_expr, HyExpression) and isinstance(orel_expr[0], HySymbol) and orel_expr[0] == 'if*': # Nested ifs: don't waste temporaries @@ -982,11 +996,11 @@ class HyASTCompiler(object): branch = orel if branch is not None: if self.temp_if and branch.stmts: - name = asty.Name(expression, + name = asty.Name(expr, id=ast_str(self.temp_if), ctx=ast.Store()) - branch += asty.Assign(expression, + branch += asty.Assign(expr, targets=[name], value=body.force_expr) @@ -999,40 +1013,38 @@ class HyASTCompiler(object): # We have statements in our bodies # Get a temporary variable for the result storage var = self.temp_if or self.get_anon_var() - name = asty.Name(expression, + name = asty.Name(expr, id=ast_str(var), ctx=ast.Store()) # Store the result of the body - body += asty.Assign(expression, + body += asty.Assign(expr, targets=[name], value=body.force_expr) # and of the else clause if not nested or not orel.stmts or (not root and var != self.temp_if): - orel += asty.Assign(expression, + orel += asty.Assign(expr, targets=[name], value=orel.force_expr) # Then build the if - ret += ast.If(test=ret.force_expr, - body=body.stmts, - orelse=orel.stmts, - lineno=expression.start_line, - col_offset=expression.start_column) + ret += asty.If(expr, + test=ret.force_expr, + body=body.stmts, + orelse=orel.stmts) # And make our expression context our temp variable - expr_name = asty.Name(expression, id=ast_str(var), ctx=ast.Load()) + expr_name = asty.Name(expr, id=ast_str(var), ctx=ast.Load()) ret += Result(expr=expr_name, temp_variables=[expr_name, name]) else: # Just make that an if expression - ret += ast.IfExp(test=ret.force_expr, - body=body.force_expr, - orelse=orel.force_expr, - lineno=expression.start_line, - col_offset=expression.start_column) + ret += asty.IfExp(expr, + test=ret.force_expr, + body=body.force_expr, + orelse=orel.force_expr) if root: self.temp_if = None @@ -1060,22 +1072,10 @@ class HyASTCompiler(object): msg = self.compile(expr.pop(0)).force_expr return ret + asty.Assert(expr, test=e, msg=msg) - @builds("global") - @builds("nonlocal", iff=PY3) - @checkargs(min=1) - def compile_global_or_nonlocal(self, expr): - form = expr.pop(0) - names = [] - while len(expr) > 0: - identifier = expr.pop(0) - name = ast_str(identifier) - names.append(name) - if not isinstance(identifier, HySymbol): - raise HyTypeError( - identifier, - "({}) arguments must be Symbols".format(form)) - node = asty.Global if form == "global" else asty.Nonlocal - return node(expr, names=names) + @special(["global", (PY3, "nonlocal")], [oneplus(SYM)]) + def compile_global_or_nonlocal(self, expr, root, syms): + node = asty.Global if root == "global" else asty.Nonlocal + return node(expr, names=list(map(ast_str, syms))) @builds("yield") @checkargs(max=1) @@ -1163,13 +1163,10 @@ class HyASTCompiler(object): return rimports - @builds("get") - @checkargs(min=2) - def compile_index_expression(self, expr): - expr.pop(0) # index - - indices, ret, _ = self._compile_collect(expr[1:]) - ret += self.compile(expr[0]) + @special("get", [FORM, oneplus(FORM)]) + def compile_index_expression(self, expr, name, obj, indices): + indices, ret, _ = self._compile_collect(indices) + ret += self.compile(obj) for ix in indices: ret += asty.Subscript( @@ -1180,50 +1177,34 @@ class HyASTCompiler(object): return ret - @builds(".") - @checkargs(min=1) - def compile_attribute_access(self, expr): - expr.pop(0) # dot + @special(".", [FORM, many(SYM | brackets(FORM))]) + def compile_attribute_access(self, expr, name, invocant, keys): + ret = self.compile(invocant) - ret = self.compile(expr.pop(0)) - - for attr in expr: + for attr in keys: if isinstance(attr, HySymbol): ret += asty.Attribute(attr, value=ret.force_expr, attr=ast_str(attr), ctx=ast.Load()) - elif type(attr) == HyList: - if len(attr) != 1: - raise HyTypeError( - attr, - "The attribute access DSL only accepts HySymbols " - "and one-item lists, got {0}-item list instead".format( - len(attr))) + else: # attr is a HyList compiled_attr = self.compile(attr[0]) ret = compiled_attr + ret + asty.Subscript( attr, value=ret.force_expr, slice=ast.Index(value=compiled_attr.force_expr), ctx=ast.Load()) - else: - raise HyTypeError( - attr, - "The attribute access DSL only accepts HySymbols " - "and one-item lists, got {0} instead".format( - type(attr).__name__)) return ret - @builds("del") - def compile_del_expression(self, expr): - root = expr.pop(0) - if not expr: - return asty.Pass(root) + @special("del", [many(FORM)]) + def compile_del_expression(self, expr, name, args): + if not args: + return asty.Pass(expr) del_targets = [] ret = Result() - for target in expr: + for target in args: compiled_target = self.compile(target) ret += compiled_target del_targets.append(self._storeize(target, compiled_target, @@ -1231,53 +1212,38 @@ class HyASTCompiler(object): return ret + asty.Delete(expr, targets=del_targets) - @builds("cut") - @checkargs(min=1, max=4) - def compile_cut_expression(self, expr): - ret = Result() - nodes = [None] * 4 - for i, e in enumerate(expr[1:]): - ret += self.compile(e) - nodes[i] = ret.force_expr + @special("cut", [FORM, maybe(FORM), maybe(FORM), maybe(FORM)]) + def compile_cut_expression(self, expr, name, obj, lower, upper, step): + ret = [Result()] + def c(e): + ret[0] += self.compile(e) + return ret[0].force_expr - return ret + asty.Subscript( + s = asty.Subscript( expr, - value=nodes[0], - slice=ast.Slice(lower=nodes[1], upper=nodes[2], step=nodes[3]), + value=c(obj), + slice=ast.Slice(lower=c(lower), upper=c(upper), step=c(step)), ctx=ast.Load()) + return ret[0] + s - @builds("with-decorator") - @checkargs(min=1) - def compile_decorate_expression(self, expr): - expr.pop(0) # with-decorator - fn = self.compile(expr.pop()) + @special("with-decorator", [oneplus(FORM)]) + def compile_decorate_expression(self, expr, name, args): + decs, fn = args[:-1], self.compile(args[-1]) if not fn.stmts or not isinstance(fn.stmts[-1], _decoratables): - raise HyTypeError(expr, "Decorated a non-function") - decorators, ret, _ = self._compile_collect(expr) - fn.stmts[-1].decorator_list = decorators + fn.stmts[-1].decorator_list + raise HyTypeError(args[-1], "Decorated a non-function") + decs, ret, _ = self._compile_collect(decs) + fn.stmts[-1].decorator_list = decs + fn.stmts[-1].decorator_list return ret + fn - @builds("with*") - @builds("with/a*", iff=PY35) - @checkargs(min=2) - def compile_with_expression(self, expr): - root = expr.pop(0) + @special(["with*", (PY35, "with/a*")], + [brackets(FORM, maybe(FORM)), many(FORM)]) + def compile_with_expression(self, expr, root, args, body): + thing, ctx = (None, args[0]) if args[1] is None else args + if thing is not None: + thing = self._storeize(thing, self.compile(thing)) + ctx = self.compile(ctx) - args = expr.pop(0) - if not isinstance(args, HyList): - raise HyTypeError(expr, - "{0} expects a list, received `{1}'".format( - root, type(args).__name__)) - if len(args) not in (1, 2): - raise HyTypeError(expr, - "{0} needs [arg (expr)] or [(expr)]".format(root)) - - thing = None - if len(args) == 2: - thing = self._storeize(args[0], self.compile(args.pop(0))) - ctx = self.compile(args.pop(0)) - - body = self._compile_branch(expr) + body = self._compile_branch(body) # Store the result of the body in a tempvar var = self.get_anon_var() @@ -1381,16 +1347,14 @@ class HyASTCompiler(object): value=value.force_expr, generators=gen) - @builds("not", "~") - @checkargs(1) - def compile_unary_operator(self, expression): + @special(["not", "~"], [FORM]) + def compile_unary_operator(self, expr, root, arg): ops = {"not": ast.Not, "~": ast.Invert} - operator = expression.pop(0) - operand = self.compile(expression.pop(0)) + operand = self.compile(arg) operand += asty.UnaryOp( - expression, op=ops[operator](), operand=operand.force_expr) + expr, op=ops[root](), operand=operand.force_expr) return operand @@ -1442,18 +1406,17 @@ class HyASTCompiler(object): raise HyTypeError(entry, "unrecognized (require) syntax") return Result() - @builds("and", "or") - def compile_logical_or_and_and_operator(self, expression): + @special(["and", "or"], [many(FORM)]) + def compile_logical_or_and_and_operator(self, expr, operator, args): ops = {"and": (ast.And, "True"), "or": (ast.Or, "None")} - operator = expression.pop(0) opnode, default = ops[operator] - if len(expression) == 0: + if len(args) == 0: return asty.Name(operator, id=default, ctx=ast.Load()) - elif len(expression) == 1: - return self.compile(expression[0]) + elif len(args) == 1: + return self.compile(args[0]) ret = Result() - values = list(map(self.compile, expression)) + values = list(map(self.compile, args)) if any(value.stmts for value in values): # Compile it to an if...else sequence var = self.get_anon_var() @@ -1500,33 +1463,29 @@ class HyASTCompiler(object): "in": ast.In, "not-in": ast.NotIn} ops = {ast_str(k): v for k, v in ops.items()} - def _compile_compare_op_expression(self, expression): - inv = ast_str(expression.pop(0)) - ops = [self.ops[inv]() for _ in range(len(expression) - 1)] + def _compile_compare_op_expression(self, expr, root, args): + inv = ast_str(root) + ops = [self.ops[inv]() for _ in args[1:]] - e = expression[0] - exprs, ret, _ = self._compile_collect(expression) + exprs, ret, _ = self._compile_collect(args) return ret + asty.Compare( - e, left=exprs[0], ops=ops, comparators=exprs[1:]) + expr, left=exprs[0], ops=ops, comparators=exprs[1:]) - @builds("=", "is", "<", "<=", ">", ">=") - @checkargs(min=1) - def compile_compare_op_expression(self, expression): - if len(expression) == 2: - return (self.compile(expression[1]) + - asty.Name(expression, id="True", ctx=ast.Load())) - return self._compile_compare_op_expression(expression) + @special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)]) + def compile_compare_op_expression(self, expr, root, args): + if len(args) == 1: + return (self.compile(args[0]) + + asty.Name(expr, id="True", ctx=ast.Load())) + return self._compile_compare_op_expression(expr, root, args) - @builds("!=", "is-not") - @checkargs(min=2) - def compile_compare_op_expression_coll(self, expression): - return self._compile_compare_op_expression(expression) + @special(["!=", "is-not"], [FORM, oneplus(FORM)]) + def compile_compare_op_expression_coll(self, expr, root, arg1, args): + return self._compile_compare_op_expression(expr, root, [arg1] + args) - @builds("in", "not-in") - @checkargs(2) - def compile_compare_op_expression_binary(self, expression): - return self._compile_compare_op_expression(expression) + @special(["in", "not-in"], [FORM, FORM]) + def compile_compare_op_expression_binary(self, expr, root, needle, haystack): + return self._compile_compare_op_expression(expr, root, [needle, haystack]) def _compile_maths_expression(self, expr): ops = {"+": ast.Add, @@ -1733,22 +1692,14 @@ class HyASTCompiler(object): expression, func=func.expr, args=args, keywords=keywords, starargs=oldpy_star, kwargs=oldpy_kw) - @builds("setv") - def compile_def_expression(self, expression): - root = expression.pop(0) - if not expression: + @special("setv", [many(FORM + FORM)]) + def compile_def_expression(self, expr, root, pairs): + if len(pairs) == 0: return asty.Name(root, id='None', ctx=ast.Load()) - elif len(expression) == 2: - return self._compile_assign(expression[0], expression[1]) - elif len(expression) % 2 != 0: - raise HyTypeError(expression, - "`{}' needs an even number of arguments".format( - root)) - else: - result = Result() - for tgt, target in zip(expression[::2], expression[1::2]): - result += self._compile_assign(tgt, target) - return result + result = Result() + for pair in pairs: + result += self._compile_assign(*pair) + return result def _compile_assign(self, name, result): @@ -1781,63 +1732,40 @@ class HyASTCompiler(object): return result - @builds("for*") - @builds("for/a*", iff=PY35) - @checkargs(min=1) - def compile_for_expression(self, expression): - root = expression.pop(0) - - args = expression.pop(0) - if not isinstance(args, HyList): - raise HyTypeError(expression, - "`{0}` expects a list, received `{1}`".format( - root, type(args).__name__)) - - try: - target_name, iterable = args - except ValueError: - raise HyTypeError(expression, - "`for` requires two forms in the list") - + @special(["for*", (PY35, "for/a*")], + [brackets(FORM, FORM), many(notpexpr("else")), maybe(dolike("else"))]) + def compile_for_expression(self, expr, root, args, body, else_expr): + target_name, iterable = args target = self._storeize(target_name, self.compile(target_name)) ret = Result() orel = Result() - # (for* [] body (else …)) - if ends_with_else(expression): - else_expr = expression.pop() - for else_body in else_expr[1:]: + if else_expr is not None: + for else_body in else_expr: orel += self.compile(else_body) orel += orel.expr_as_stmt() ret += self.compile(iterable) - body = self._compile_branch(expression) + body = self._compile_branch(body) body += body.expr_as_stmt() node = asty.For if root == 'for*' else asty.AsyncFor - ret += node(expression, + ret += node(expr, target=target, iter=ret.force_expr, - body=body.stmts or [asty.Pass(expression)], + body=body.stmts or [asty.Pass(expr)], orelse=orel.stmts) ret.contains_yield = body.contains_yield return ret - @builds("while") - @checkargs(min=1) - def compile_while_expression(self, expr): - expr.pop(0) # "while" - cond = expr.pop(0) + @special(["while"], [FORM, many(notpexpr("else")), maybe(dolike("else"))]) + def compile_while_expression(self, expr, root, cond, body, else_expr): cond_compiled = self.compile(cond) - else_expr = None - if ends_with_else(expr): - else_expr = expr.pop() - if cond_compiled.stmts: # We need to ensure the statements for the condition are # executed on every iteration. Rewrite the loop to use a @@ -1853,16 +1781,16 @@ class HyASTCompiler(object): # changes its truth value, but use (not (not ...)) instead of # `bool` in case `bool` has been redefined. e(s('setv'), cond_var, e(s('not'), e(s('not'), cond))), - e(s('if*'), cond_var, e(s('do'), *expr)), - *([else_expr] if else_expr is not None else []))).replace(expr)) # noqa + e(s('if*'), cond_var, e(s('do'), *body)), + *([e(s('else'), *else_expr)] if else_expr is not None else []))).replace(expr)) # noqa orel = Result() if else_expr is not None: - for else_body in else_expr[1:]: + for else_body in else_expr: orel += self.compile(else_body) orel += orel.expr_as_stmt() - body = self._compile_branch(expr) + body = self._compile_branch(body) body += body.expr_as_stmt() ret = cond_compiled + asty.While( diff --git a/hy/model_patterns.py b/hy/model_patterns.py new file mode 100644 index 0000000..36b91ae --- /dev/null +++ b/hy/model_patterns.py @@ -0,0 +1,76 @@ +# Copyright 2018 the authors. +# This file is part of Hy, which is free software licensed under the Expat +# license. See the LICENSE. + +"Parser combinators for pattern-matching Hy model trees." + +from hy.models import HyExpression, HySymbol, HyKeyword, HyString, HyList +from funcparserlib.parser import ( + some, skip, many, finished, a, Parser, NoParseError, State) +from functools import reduce +from itertools import repeat +from operator import add +from math import isinf + +FORM = some(lambda _: True) +SYM = some(lambda x: isinstance(x, HySymbol)) +STR = some(lambda x: isinstance(x, HyString)) + +def sym(wanted): + "Parse and skip the given symbol or keyword." + if wanted.startswith(":"): + return skip(a(HyKeyword(wanted[1:]))) + return skip(some(lambda x: isinstance(x, HySymbol) and x == wanted)) + +def whole(parsers): + """Parse the parsers in the given list one after another, then + expect the end of the input.""" + if len(parsers) == 0: + return finished >> (lambda x: []) + if len(parsers) == 1: + return parsers[0] + finished >> (lambda x: x[:-1]) + return reduce(add, parsers) + skip(finished) + +def _grouped(group_type, parsers): return ( + some(lambda x: isinstance(x, group_type)) >> + (lambda x: group_type(whole(parsers).parse(x)).replace(x, recursive=False))) + +def brackets(*parsers): + "Parse the given parsers inside square brackets." + return _grouped(HyList, parsers) + +def pexpr(*parsers): + "Parse the given parsers inside a parenthesized expression." + return _grouped(HyExpression, parsers) + +def dolike(head): + "Parse a `do`-like form." + return pexpr(sym(head), many(FORM)) + +def notpexpr(*disallowed_heads): + """Parse any object other than a HyExpression beginning with a + HySymbol equal to one of the disallowed_heads.""" + return some(lambda x: not ( + isinstance(x, HyExpression) and + x and + isinstance(x[0], HySymbol) and + x[0] in disallowed_heads)) + +def times(lo, hi, parser): + """Parse `parser` several times (`lo` to `hi`) in a row. `hi` can be + float('inf'). The result is a list no matter the number of instances.""" + @Parser + def f(tokens, s): + result = [] + for _ in range(lo): + (v, s) = parser.run(tokens, s) + result.append(v) + end = s.max + try: + for _ in (repeat(1) if isinf(hi) else range(hi - lo)): + (v, s) = parser.run(tokens, s) + result.append(v) + except NoParseError as e: + end = e.state.max + return result, State(s.pos, end) + return f diff --git a/setup.py b/setup.py index f5ebb9a..f6e4c65 100755 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ class Install(install): "." + filename[:-len(".hy")]) install.run(self) -install_requires = ['rply>=0.7.5', 'astor', 'clint>=0.4'] +install_requires = ['rply>=0.7.5', 'astor', 'funcparserlib>=0.3.6', 'clint>=0.4'] if os.name == 'nt': install_requires.append('pyreadline>=2.1') diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index b1e7ef8..69dddfc 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -511,7 +511,6 @@ def test_compile_error(): """Ensure we get compile error in tricky cases""" with pytest.raises(HyTypeError) as excinfo: can_compile("(fn [] (in [1 2 3]))") - assert excinfo.value.message == "`in' needs 2 arguments, got 1" def test_for_compile_error(): diff --git a/tests/native_tests/language.hy b/tests/native_tests/language.hy index fbda979..a92628f 100644 --- a/tests/native_tests/language.hy +++ b/tests/native_tests/language.hy @@ -87,8 +87,8 @@ (assert (= b 2)) (setv y 0 x 1 y x) (assert (= y 1)) - (try (eval '(setv a 1 b)) - (except [e [TypeError]] (assert (in "`setv' needs an even number of arguments" (str e)))))) + (with [(pytest.raises HyTypeError)] + (eval '(setv a 1 b)))) (defn test-setv-returns-none [] @@ -280,7 +280,15 @@ y (range 2)] (+ 1 1) (else (setv flag (+ flag 2)))) - (assert (= flag 2))) + (assert (= flag 2)) + + (setv l []) + (defn f [] + (for [x [4 9 2]] + (.append l (* 10 x)) + (yield x))) + (for [_ (f)]) + (assert (= l [40 90 20]))) (defn test-while-loop []