Merge pull request #1683 from brandonwillard/fix-py27-failed-import-modules
Fix `sys.modules` for failed imports in Python 2.7
This commit is contained in:
commit
4132adb9fe
@ -293,7 +293,8 @@ else:
|
||||
mod_type == imp.PKG_DIRECTORY and
|
||||
os.path.isfile(pkg_path)):
|
||||
|
||||
if fullname in sys.modules:
|
||||
was_in_sys = fullname in sys.modules
|
||||
if was_in_sys:
|
||||
mod = sys.modules[fullname]
|
||||
else:
|
||||
mod = sys.modules.setdefault(
|
||||
@ -311,7 +312,15 @@ else:
|
||||
|
||||
mod.__name__ = fullname
|
||||
|
||||
self.exec_module(mod, fullname=fullname)
|
||||
try:
|
||||
self.exec_module(mod, fullname=fullname)
|
||||
except Exception:
|
||||
# Follow Python 2.7 logic and only remove a new, bad
|
||||
# module; otherwise, leave the old--and presumably
|
||||
# good--module in there.
|
||||
if not was_in_sys:
|
||||
del sys.modules[fullname]
|
||||
raise
|
||||
|
||||
if mod is None:
|
||||
self._reopen()
|
||||
@ -385,7 +394,7 @@ else:
|
||||
self.code = self.byte_compile_hy(fullname)
|
||||
|
||||
if self.code is None:
|
||||
super(HyLoader, self).get_code(fullname=fullname)
|
||||
super(HyLoader, self).get_code(fullname=fullname)
|
||||
|
||||
return self.code
|
||||
|
||||
|
@ -5,7 +5,6 @@
|
||||
import os
|
||||
import sys
|
||||
import ast
|
||||
import imp
|
||||
import tempfile
|
||||
import runpy
|
||||
import importlib
|
||||
@ -15,12 +14,16 @@ from fractions import Fraction
|
||||
import pytest
|
||||
|
||||
import hy
|
||||
from hy._compat import bytes_type
|
||||
from hy.errors import HyTypeError
|
||||
from hy.lex import LexException
|
||||
from hy.compiler import hy_compile
|
||||
from hy.importer import hy_parse, HyLoader, cache_from_source
|
||||
|
||||
try:
|
||||
from importlib import reload
|
||||
except ImportError:
|
||||
from imp import reload
|
||||
|
||||
|
||||
def test_basics():
|
||||
"Make sure the basics of the importer work"
|
||||
@ -85,6 +88,15 @@ def test_import_error_reporting():
|
||||
assert _import_error_test() is not None
|
||||
|
||||
|
||||
def test_import_error_cleanup():
|
||||
"Failed initial imports should not leave dead modules in `sys.modules`."
|
||||
|
||||
with pytest.raises(hy.errors.HyMacroExpansionError):
|
||||
importlib.import_module('tests.resources.fails')
|
||||
|
||||
assert 'tests.resources.fails' not in sys.modules
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.dont_write_bytecode,
|
||||
reason="Bytecode generation is suppressed")
|
||||
def test_import_autocompiles():
|
||||
@ -127,7 +139,13 @@ def test_eval():
|
||||
|
||||
|
||||
def test_reload():
|
||||
"""Copied from CPython's `test_import.py`"""
|
||||
"""Generate a test module, confirm that it imports properly (and puts the
|
||||
module in `sys.modules`), then modify the module so that it produces an
|
||||
error when reloaded. Next, fix the error, reload, and check that the
|
||||
module is updated and working fine. Rinse, repeat.
|
||||
|
||||
This test is adapted from CPython's `test_import.py`.
|
||||
"""
|
||||
|
||||
def unlink(filename):
|
||||
os.unlink(source)
|
||||
@ -160,7 +178,7 @@ def test_reload():
|
||||
f.write("(setv b (// 20 0))")
|
||||
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
imp.reload(mod)
|
||||
reload(mod)
|
||||
|
||||
# But we still expect the module to be in sys.modules.
|
||||
mod = sys.modules.get(TESTFN)
|
||||
@ -178,7 +196,7 @@ def test_reload():
|
||||
f.write("(setv a 11)")
|
||||
f.write("(setv b (// 20 1))")
|
||||
|
||||
imp.reload(mod)
|
||||
reload(mod)
|
||||
|
||||
mod = sys.modules.get(TESTFN)
|
||||
assert mod is not None
|
||||
@ -186,15 +204,17 @@ def test_reload():
|
||||
assert mod.a == 11
|
||||
assert mod.b == 20
|
||||
|
||||
# Now cause a LexException
|
||||
# Now cause a `LexException`, and confirm that the good module and its
|
||||
# contents stick around.
|
||||
unlink(source)
|
||||
|
||||
with open(source, "w") as f:
|
||||
# Missing paren...
|
||||
f.write("(setv a 11")
|
||||
f.write("(setv b (// 20 1))")
|
||||
|
||||
with pytest.raises(LexException):
|
||||
imp.reload(mod)
|
||||
reload(mod)
|
||||
|
||||
mod = sys.modules.get(TESTFN)
|
||||
assert mod is not None
|
||||
@ -209,7 +229,7 @@ def test_reload():
|
||||
f.write("(setv a 12)")
|
||||
f.write("(setv b (// 10 1))")
|
||||
|
||||
imp.reload(mod)
|
||||
reload(mod)
|
||||
|
||||
mod = sys.modules.get(TESTFN)
|
||||
assert mod is not None
|
||||
@ -219,8 +239,9 @@ def test_reload():
|
||||
|
||||
finally:
|
||||
del sys.path[0]
|
||||
if TESTFN in sys.modules:
|
||||
del sys.modules[TESTFN]
|
||||
unlink(source)
|
||||
del sys.modules[TESTFN]
|
||||
|
||||
|
||||
def test_circular():
|
||||
|
5
tests/resources/fails.hy
Normal file
5
tests/resources/fails.hy
Normal file
@ -0,0 +1,5 @@
|
||||
"This module produces an error when imported."
|
||||
(defmacro a-macro [x]
|
||||
(+ x 1))
|
||||
|
||||
(print (a-macro 'blah))
|
Loading…
Reference in New Issue
Block a user