Merge pull request #1683 from brandonwillard/fix-py27-failed-import-modules

Fix `sys.modules` for failed imports in Python 2.7
This commit is contained in:
Kodi Arfer 2018-10-16 16:05:40 -04:00 committed by GitHub
commit 4132adb9fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 12 deletions

View File

@ -293,7 +293,8 @@ else:
mod_type == imp.PKG_DIRECTORY and mod_type == imp.PKG_DIRECTORY and
os.path.isfile(pkg_path)): os.path.isfile(pkg_path)):
if fullname in sys.modules: was_in_sys = fullname in sys.modules
if was_in_sys:
mod = sys.modules[fullname] mod = sys.modules[fullname]
else: else:
mod = sys.modules.setdefault( mod = sys.modules.setdefault(
@ -311,7 +312,15 @@ else:
mod.__name__ = fullname mod.__name__ = fullname
try:
self.exec_module(mod, fullname=fullname) self.exec_module(mod, fullname=fullname)
except Exception:
# Follow Python 2.7 logic and only remove a new, bad
# module; otherwise, leave the old--and presumably
# good--module in there.
if not was_in_sys:
del sys.modules[fullname]
raise
if mod is None: if mod is None:
self._reopen() self._reopen()

View File

@ -5,7 +5,6 @@
import os import os
import sys import sys
import ast import ast
import imp
import tempfile import tempfile
import runpy import runpy
import importlib import importlib
@ -15,12 +14,16 @@ from fractions import Fraction
import pytest import pytest
import hy import hy
from hy._compat import bytes_type
from hy.errors import HyTypeError from hy.errors import HyTypeError
from hy.lex import LexException from hy.lex import LexException
from hy.compiler import hy_compile from hy.compiler import hy_compile
from hy.importer import hy_parse, HyLoader, cache_from_source from hy.importer import hy_parse, HyLoader, cache_from_source
try:
from importlib import reload
except ImportError:
from imp import reload
def test_basics(): def test_basics():
"Make sure the basics of the importer work" "Make sure the basics of the importer work"
@ -85,6 +88,15 @@ def test_import_error_reporting():
assert _import_error_test() is not None assert _import_error_test() is not None
def test_import_error_cleanup():
"Failed initial imports should not leave dead modules in `sys.modules`."
with pytest.raises(hy.errors.HyMacroExpansionError):
importlib.import_module('tests.resources.fails')
assert 'tests.resources.fails' not in sys.modules
@pytest.mark.skipif(sys.dont_write_bytecode, @pytest.mark.skipif(sys.dont_write_bytecode,
reason="Bytecode generation is suppressed") reason="Bytecode generation is suppressed")
def test_import_autocompiles(): def test_import_autocompiles():
@ -127,7 +139,13 @@ def test_eval():
def test_reload(): def test_reload():
"""Copied from CPython's `test_import.py`""" """Generate a test module, confirm that it imports properly (and puts the
module in `sys.modules`), then modify the module so that it produces an
error when reloaded. Next, fix the error, reload, and check that the
module is updated and working fine. Rinse, repeat.
This test is adapted from CPython's `test_import.py`.
"""
def unlink(filename): def unlink(filename):
os.unlink(source) os.unlink(source)
@ -160,7 +178,7 @@ def test_reload():
f.write("(setv b (// 20 0))") f.write("(setv b (// 20 0))")
with pytest.raises(ZeroDivisionError): with pytest.raises(ZeroDivisionError):
imp.reload(mod) reload(mod)
# But we still expect the module to be in sys.modules. # But we still expect the module to be in sys.modules.
mod = sys.modules.get(TESTFN) mod = sys.modules.get(TESTFN)
@ -178,7 +196,7 @@ def test_reload():
f.write("(setv a 11)") f.write("(setv a 11)")
f.write("(setv b (// 20 1))") f.write("(setv b (// 20 1))")
imp.reload(mod) reload(mod)
mod = sys.modules.get(TESTFN) mod = sys.modules.get(TESTFN)
assert mod is not None assert mod is not None
@ -186,15 +204,17 @@ def test_reload():
assert mod.a == 11 assert mod.a == 11
assert mod.b == 20 assert mod.b == 20
# Now cause a LexException # Now cause a `LexException`, and confirm that the good module and its
# contents stick around.
unlink(source) unlink(source)
with open(source, "w") as f: with open(source, "w") as f:
# Missing paren...
f.write("(setv a 11") f.write("(setv a 11")
f.write("(setv b (// 20 1))") f.write("(setv b (// 20 1))")
with pytest.raises(LexException): with pytest.raises(LexException):
imp.reload(mod) reload(mod)
mod = sys.modules.get(TESTFN) mod = sys.modules.get(TESTFN)
assert mod is not None assert mod is not None
@ -209,7 +229,7 @@ def test_reload():
f.write("(setv a 12)") f.write("(setv a 12)")
f.write("(setv b (// 10 1))") f.write("(setv b (// 10 1))")
imp.reload(mod) reload(mod)
mod = sys.modules.get(TESTFN) mod = sys.modules.get(TESTFN)
assert mod is not None assert mod is not None
@ -219,8 +239,9 @@ def test_reload():
finally: finally:
del sys.path[0] del sys.path[0]
unlink(source) if TESTFN in sys.modules:
del sys.modules[TESTFN] del sys.modules[TESTFN]
unlink(source)
def test_circular(): def test_circular():

5
tests/resources/fails.hy Normal file
View File

@ -0,0 +1,5 @@
"This module produces an error when imported."
(defmacro a-macro [x]
(+ x 1))
(print (a-macro 'blah))