diff --git a/hy/compiler.py b/hy/compiler.py index ff38181..9a130f0 100755 --- a/hy/compiler.py +++ b/hy/compiler.py @@ -2228,9 +2228,16 @@ def hy_compile(tree, module_name, root=ast.Module, get_expr=False): if not get_expr: result += result.expr_as_stmt() + module_docstring = None + if (PY37 and result.stmts and + isinstance(result.stmts[0], ast.Expr) and + isinstance(result.stmts[0].value, ast.Str)): + module_docstring = result.stmts.pop(0).value.s + body = compiler.imports_as_stmts(tree) + result.stmts - ret = root(body=body) + ret = root(body=body, docstring=( + None if module_docstring is None else module_docstring)) if get_expr: expr = ast.Expression(body=expr) diff --git a/tests/compilers/test_ast.py b/tests/compilers/test_ast.py index af87f8c..2aea3a9 100644 --- a/tests/compilers/test_ast.py +++ b/tests/compilers/test_ast.py @@ -53,7 +53,7 @@ def cant_compile(expr): def s(x): - return can_compile(x).body[0].value.s + return can_compile('"module docstring" ' + x).body[-1].value.s def test_ast_bad_type(): @@ -476,13 +476,12 @@ def test_ast_unicode_strings(): def _compile_string(s): hy_s = HyString(s) - hy_s.start_line = hy_s.end_line = 0 - hy_s.start_column = hy_s.end_column = 0 - code = hy_compile(hy_s, "__main__") + code = hy_compile([hy_s], "__main__") + # We put hy_s in a list so it isn't interpreted as a docstring. - # code == ast.Module(body=[ast.Expr(value=ast.Str(s=xxx))]) - return code.body[0].value.s + # code == ast.Module(body=[ast.Expr(value=ast.List(elts=[ast.Str(s=xxx)]))]) + return code.body[0].value.elts[0].s assert _compile_string("test") == "test" assert _compile_string("\u03b1\u03b2") == "\u03b1\u03b2"