From a9763b34cf358e95d3f88f4818c95568d181f5ad Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 29 Sep 2018 20:46:14 -0500 Subject: [PATCH] Fix `sys.modules` for failed imports in Python 2.7 Newly imported modules with compile and/or run-time errors were not being removed from `sys.modules`. This commit modifies the Python 2.7 loader so that it follows Python's failed-initial-import logic and removes the module from `sys.modules`. --- hy/importer.py | 15 ++++++++++--- tests/importer/test_importer.py | 39 +++++++++++++++++++++++++-------- tests/resources/fails.hy | 5 +++++ 3 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 tests/resources/fails.hy diff --git a/hy/importer.py b/hy/importer.py index 1bc8573..18bab3f 100644 --- a/hy/importer.py +++ b/hy/importer.py @@ -293,7 +293,8 @@ else: mod_type == imp.PKG_DIRECTORY and os.path.isfile(pkg_path)): - if fullname in sys.modules: + was_in_sys = fullname in sys.modules + if was_in_sys: mod = sys.modules[fullname] else: mod = sys.modules.setdefault( @@ -311,7 +312,15 @@ else: mod.__name__ = fullname - self.exec_module(mod, fullname=fullname) + try: + self.exec_module(mod, fullname=fullname) + except Exception: + # Follow Python 2.7 logic and only remove a new, bad + # module; otherwise, leave the old--and presumably + # good--module in there. + if not was_in_sys: + del sys.modules[fullname] + raise if mod is None: self._reopen() @@ -385,7 +394,7 @@ else: self.code = self.byte_compile_hy(fullname) if self.code is None: - super(HyLoader, self).get_code(fullname=fullname) + super(HyLoader, self).get_code(fullname=fullname) return self.code diff --git a/tests/importer/test_importer.py b/tests/importer/test_importer.py index 900da03..224670e 100644 --- a/tests/importer/test_importer.py +++ b/tests/importer/test_importer.py @@ -5,7 +5,6 @@ import os import sys import ast -import imp import tempfile import runpy import importlib @@ -15,12 +14,16 @@ from fractions import Fraction import pytest import hy -from hy._compat import bytes_type from hy.errors import HyTypeError from hy.lex import LexException from hy.compiler import hy_compile from hy.importer import hy_parse, HyLoader, cache_from_source +try: + from importlib import reload +except ImportError: + from imp import reload + def test_basics(): "Make sure the basics of the importer work" @@ -85,6 +88,15 @@ def test_import_error_reporting(): assert _import_error_test() is not None +def test_import_error_cleanup(): + "Failed initial imports should not leave dead modules in `sys.modules`." + + with pytest.raises(hy.errors.HyMacroExpansionError): + importlib.import_module('tests.resources.fails') + + assert 'tests.resources.fails' not in sys.modules + + @pytest.mark.skipif(sys.dont_write_bytecode, reason="Bytecode generation is suppressed") def test_import_autocompiles(): @@ -127,7 +139,13 @@ def test_eval(): def test_reload(): - """Copied from CPython's `test_import.py`""" + """Generate a test module, confirm that it imports properly (and puts the + module in `sys.modules`), then modify the module so that it produces an + error when reloaded. Next, fix the error, reload, and check that the + module is updated and working fine. Rinse, repeat. + + This test is adapted from CPython's `test_import.py`. + """ def unlink(filename): os.unlink(source) @@ -160,7 +178,7 @@ def test_reload(): f.write("(setv b (// 20 0))") with pytest.raises(ZeroDivisionError): - imp.reload(mod) + reload(mod) # But we still expect the module to be in sys.modules. mod = sys.modules.get(TESTFN) @@ -178,7 +196,7 @@ def test_reload(): f.write("(setv a 11)") f.write("(setv b (// 20 1))") - imp.reload(mod) + reload(mod) mod = sys.modules.get(TESTFN) assert mod is not None @@ -186,15 +204,17 @@ def test_reload(): assert mod.a == 11 assert mod.b == 20 - # Now cause a LexException + # Now cause a `LexException`, and confirm that the good module and its + # contents stick around. unlink(source) with open(source, "w") as f: + # Missing paren... f.write("(setv a 11") f.write("(setv b (// 20 1))") with pytest.raises(LexException): - imp.reload(mod) + reload(mod) mod = sys.modules.get(TESTFN) assert mod is not None @@ -209,7 +229,7 @@ def test_reload(): f.write("(setv a 12)") f.write("(setv b (// 10 1))") - imp.reload(mod) + reload(mod) mod = sys.modules.get(TESTFN) assert mod is not None @@ -219,8 +239,9 @@ def test_reload(): finally: del sys.path[0] + if TESTFN in sys.modules: + del sys.modules[TESTFN] unlink(source) - del sys.modules[TESTFN] def test_circular(): diff --git a/tests/resources/fails.hy b/tests/resources/fails.hy new file mode 100644 index 0000000..516fb8e --- /dev/null +++ b/tests/resources/fails.hy @@ -0,0 +1,5 @@ +"This module produces an error when imported." +(defmacro a-macro [x] + (+ x 1)) + +(print (a-macro 'blah))