diff --git a/tap/tap.py b/tap/tap.py index ce9b632..4d76355 100644 --- a/tap/tap.py +++ b/tap/tap.py @@ -1,7 +1,6 @@ from argparse import ArgumentParser from collections import OrderedDict from copy import deepcopy -from itertools import cycle import json from pprint import pformat import sys @@ -23,9 +22,12 @@ boolean_type, TupleTypeEnforcer, PythonObjectEncoder, - as_python_object + as_python_object, + fix_py36_copy ) + +# Constants EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple() SUPPORTED_DEFAULT_BASE_TYPES = {str, int, float, bool} @@ -486,3 +488,18 @@ def __str__(self) -> str: :return: A formatted string representation of the dictionary of all arguments. """ return pformat(self.as_dict()) + + @fix_py36_copy + def __deepcopy__(self, memo: Dict[int, Any] = None) -> TapType: + """Deepcopy the Tap object.""" + copied = type(self).__new__(type(self)) + + if memo is None: + memo = {} + + memo[id(self)] = copied + + for (k, v) in self.__dict__.items(): + copied.__dict__[k] = deepcopy(v, memo) + + return copied diff --git a/tap/utils.py b/tap/utils.py index f8cdfbf..4f75a5c 100644 --- a/tap/utils.py +++ b/tap/utils.py @@ -1,6 +1,8 @@ from argparse import ArgumentParser, ArgumentTypeError from base64 import b64encode, b64decode from collections import OrderedDict +import copy +from functools import wraps import inspect from io import StringIO from json import JSONEncoder @@ -8,11 +10,13 @@ import pickle import re import subprocess +import sys import tokenize -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union from typing_extensions import Literal from typing_inspect import get_args + NO_CHANGES_STATUS = """nothing to commit, working tree clean""" PRIMITIVES = (str, int, float, bool) @@ -247,16 +251,16 @@ def __call__(self, arg: str) -> Any: return arg -class Tuple: +class MockTuple: """Mock of a tuple needed to prevent JSON encoding tuples as lists.""" def __init__(self, _tuple: tuple) -> None: self.tuple = _tuple -def nested_replace_type(obj: Any, find_type: type, replace_type: type) -> Any: +def _nested_replace_type(obj: Any, find_type: type, replace_type: type) -> Any: """Replaces any instance (including instances within lists, tuple, dict) of find_type with an instance of replace_type. - Note: Tuples, lists, and dictionaries are NOT modified in place, they are replaced. + Note: Tuples, lists, and dicts are NOT modified in place. Note: Does NOT do a nested search through objects besides tuples, lists, and dicts (e.g. sets). :param obj: The object to modify by replacing find_type instances with replace_type instances. @@ -265,14 +269,16 @@ def nested_replace_type(obj: Any, find_type: type, replace_type: type) -> Any: :return: A version of obj with all instances of find_type replaced by replace_type """ if isinstance(obj, tuple): - obj = tuple(nested_replace_type(item, find_type, replace_type) for item in obj) + obj = tuple(_nested_replace_type(item, find_type, replace_type) for item in obj) - if isinstance(obj, list): - obj = [nested_replace_type(item, find_type, replace_type) for item in obj] + elif isinstance(obj, list): + obj = [_nested_replace_type(item, find_type, replace_type) for item in obj] - if isinstance(obj, dict): - obj = {nested_replace_type(key, find_type, replace_type): nested_replace_type(value, find_type, replace_type) - for key, value in obj.items()} + elif isinstance(obj, dict): + obj = { + _nested_replace_type(key, find_type, replace_type): _nested_replace_type(value, find_type, replace_type) + for key, value in obj.items() + } if isinstance(obj, find_type): obj = replace_type(obj) @@ -286,7 +292,7 @@ class PythonObjectEncoder(JSONEncoder): See: https://stackoverflow.com/a/36252257 """ def iterencode(self, o: Any, _one_shot: bool = False) -> Iterator[str]: - o = nested_replace_type(o, tuple, Tuple) + o = _nested_replace_type(o, tuple, MockTuple) return super(PythonObjectEncoder, self).iterencode(o, _one_shot) def default(self, obj: Any) -> Any: @@ -295,7 +301,7 @@ def default(self, obj: Any) -> Any: '_type': 'set', '_value': list(obj) } - elif isinstance(obj, Tuple): + elif isinstance(obj, MockTuple): return { '_type': 'tuple', '_value': list(obj.tuple) @@ -327,3 +333,31 @@ def as_python_object(dct: Any) -> Any: raise ValueError(f'Special type "{_type}" not supported for JSON loading.') return dct + + +def fix_py36_copy(func: Callable) -> Callable: + """Decorator that fixes functions using Python 3.6 deepcopy of ArgumentParsers. + + Based on https://stackoverflow.com/questions/6279305/typeerror-cannot-deepcopy-this-pattern-object + """ + if sys.version_info[:2] != (3, 6): + return func + + re_type = type(re.compile('')) + + @wraps(func) + def wrapper(*args, **kwargs): + has_prev_val = re_type in copy._deepcopy_dispatch + prev_val = copy._deepcopy_dispatch.get(re_type, None) + copy._deepcopy_dispatch[type(re.compile(''))] = lambda r, _: r + + result = func(*args, **kwargs) + + if has_prev_val: + copy._deepcopy_dispatch[re_type] = prev_val + else: + del copy._deepcopy_dispatch[re_type] + + return result + + return wrapper diff --git a/tests/test_utils.py b/tests/test_utils.py index d05f343..061b355 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,7 +17,7 @@ type_to_str, get_literals, TupleTypeEnforcer, - nested_replace_type, + _nested_replace_type, PythonObjectEncoder, as_python_object ) @@ -320,19 +320,19 @@ def test_tuple_type_enforcer_infinite(self): class NestedReplaceTypeTests(TestCase): def test_nested_replace_type_notype(self): obj = ['123', 4, 5, ('hello', 4.4)] - replaced_obj = nested_replace_type(obj, bool, int) + replaced_obj = _nested_replace_type(obj, bool, int) self.assertEqual(obj, replaced_obj) def test_nested_replace_type_unnested(self): obj = ['123', 4, 5, ('hello', 4.4), True, False, 'hi there'] - replaced_obj = nested_replace_type(obj, tuple, list) + replaced_obj = _nested_replace_type(obj, tuple, list) correct_obj = ['123', 4, 5, ['hello', 4.4], True, False, 'hi there'] self.assertNotEqual(obj, replaced_obj) self.assertEqual(correct_obj, replaced_obj) def test_nested_replace_type_nested(self): obj = ['123', [4, (1, 2, (3, 4))], 5, ('hello', (4,), 4.4), {'1': [2, 3, [{'2': (3, 10)}, ' hi ']]}] - replaced_obj = nested_replace_type(obj, tuple, list) + replaced_obj = _nested_replace_type(obj, tuple, list) correct_obj = ['123', [4, [1, 2, [3, 4]]], 5, ['hello', [4], 4.4], {'1': [2, 3, [{'2': [3, 10]}, ' hi ']]}] self.assertNotEqual(obj, replaced_obj) self.assertEqual(correct_obj, replaced_obj)