From 274f5e9002b3511c7b20e097b24623f9b3eb70d6 Mon Sep 17 00:00:00 2001 From: David Schaefer Date: Wed, 9 Aug 2017 10:00:21 +0200 Subject: [PATCH] Fix copy behaviour of HyComplex --- hy/models.py | 19 +++++++++---------- tests/test_models.py | 17 ++++++++++++++++- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/hy/models.py b/hy/models.py index 77c580f..574db49 100644 --- a/hy/models.py +++ b/hy/models.py @@ -171,18 +171,17 @@ class HyComplex(HyObject, complex): complex(foo) was called, given HyComplex(foo). """ - def __new__(cls, num, *args, **kwargs): - value = super(HyComplex, cls).__new__(cls, strip_digit_separators(num)) - if isinstance(num, string_types): - p1, _, p2 = num.lstrip("+-").replace("-", "+").partition("+") + def __new__(cls, real, imag=0, *args, **kwargs): + if isinstance(real, string_types): + value = super(HyComplex, cls).__new__( + cls, strip_digit_separators(real) + ) + p1, _, p2 = real.lstrip("+-").replace("-", "+").partition("+") + check_inf_nan_cap(p1, value.imag if "j" in p1 else value.real) if p2: - check_inf_nan_cap(p1, value.real) check_inf_nan_cap(p2, value.imag) - elif "j" in p1: - check_inf_nan_cap(p1, value.imag) - else: - check_inf_nan_cap(p1, value.real) - return value + return value + return super(HyComplex, cls).__new__(cls, real, imag) _wrappers[complex] = HyComplex diff --git a/tests/test_models.py b/tests/test_models.py index 59ed809..291caa7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,9 +2,10 @@ # This file is part of Hy, which is free software licensed under the Expat # license. See the LICENSE. +import copy from hy._compat import long_type, str_type from hy.models import (wrap_value, replace_hy_obj, HyString, HyInteger, HyList, - HyDict, HySet, HyExpression, HyCons) + HyDict, HySet, HyExpression, HyCons, HyComplex, HyFloat) def test_wrap_long_type(): @@ -125,3 +126,17 @@ def test_cons_replacing(): assert True is False except IndexError: pass + + +def test_number_model_copy(): + i = HyInteger(42) + assert (i == copy.copy(i)) + assert (i == copy.deepcopy(i)) + + f = HyFloat(42.) + assert (f == copy.copy(f)) + assert (f == copy.deepcopy(f)) + + c = HyComplex(42j) + assert (c == copy.copy(c)) + assert (c == copy.deepcopy(c))