diff --git a/tests/importer/test_importer.py b/tests/importer/test_importer.py index c41f86b..d9c3565 100644 --- a/tests/importer/test_importer.py +++ b/tests/importer/test_importer.py @@ -3,10 +3,12 @@ # license. See the LICENSE. import os +import sys import ast +import imp import tempfile -import importlib import runpy +import importlib from fractions import Fraction @@ -14,6 +16,7 @@ import pytest import hy from hy.errors import HyTypeError +from hy.lex import LexException from hy.compiler import hy_compile from hy.importer import hy_parse, HyLoader, cache_from_source @@ -98,3 +101,100 @@ def test_eval(): '(if True "this is if true" "this is if false")') == "this is if true" assert eval_str('(lfor num (range 100) :if (= (% num 2) 1) (pow num 2))') == [ pow(num, 2) for num in range(100) if num % 2 == 1] + + +def test_reload(): + """Copied from CPython's `test_import.py`""" + + def unlink(filename): + os.unlink(source) + bytecode = cache_from_source(source) + if os.path.isfile(bytecode): + os.unlink(bytecode) + + TESTFN = 'testfn' + source = TESTFN + os.extsep + "hy" + with open(source, "w") as f: + f.write("(setv a 1)") + f.write("(setv b 2)") + + sys.path.insert(0, os.curdir) + try: + mod = importlib.import_module(TESTFN) + assert TESTFN in sys.modules + assert mod.a == 1 + assert mod.b == 2 + + # On WinXP, just replacing the .py file wasn't enough to + # convince reload() to reparse it. Maybe the timestamp didn't + # move enough. We force it to get reparsed by removing the + # compiled file too. + unlink(source) + + # Now damage the module. + with open(source, "w") as f: + f.write("(setv a 10)") + f.write("(setv b (// 20 0))") + + with pytest.raises(ZeroDivisionError): + imp.reload(mod) + + # But we still expect the module to be in sys.modules. + mod = sys.modules.get(TESTFN) + assert mod is not None + + # We should have replaced a w/ 10, but the old b value should + # stick. + assert mod.a == 10 + assert mod.b == 2 + + # Now fix the issue and reload the module. + unlink(source) + + with open(source, "w") as f: + f.write("(setv a 11)") + f.write("(setv b (// 20 1))") + + imp.reload(mod) + + mod = sys.modules.get(TESTFN) + assert mod is not None + + assert mod.a == 11 + assert mod.b == 20 + + # Now cause a LexException + unlink(source) + + with open(source, "w") as f: + f.write("(setv a 11") + f.write("(setv b (// 20 1))") + + with pytest.raises(LexException): + imp.reload(mod) + + mod = sys.modules.get(TESTFN) + assert mod is not None + + assert mod.a == 11 + assert mod.b == 20 + + # Fix it and retry + unlink(source) + + with open(source, "w") as f: + f.write("(setv a 12)") + f.write("(setv b (// 10 1))") + + imp.reload(mod) + + mod = sys.modules.get(TESTFN) + assert mod is not None + + assert mod.a == 12 + assert mod.b == 10 + + finally: + del sys.path[0] + unlink(source) + del sys.modules[TESTFN]