Skip to content

Commit

Permalink
Fixing Python 3.6 deepcopy
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Apr 12, 2020
1 parent b64d6d9 commit 5f718c4
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 18 deletions.
21 changes: 19 additions & 2 deletions tap/tap.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from argparse import ArgumentParser
from collections import OrderedDict
from copy import deepcopy
from itertools import cycle
import json
from pprint import pformat
import sys
Expand All @@ -23,9 +22,12 @@
boolean_type,
TupleTypeEnforcer,
PythonObjectEncoder,
as_python_object
as_python_object,
fix_py36_copy
)


# Constants
EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple()

SUPPORTED_DEFAULT_BASE_TYPES = {str, int, float, bool}
Expand Down Expand Up @@ -486,3 +488,18 @@ def __str__(self) -> str:
:return: A formatted string representation of the dictionary of all arguments.
"""
return pformat(self.as_dict())

@fix_py36_copy
def __deepcopy__(self, memo: Dict[int, Any] = None) -> TapType:
"""Deepcopy the Tap object."""
copied = type(self).__new__(type(self))

if memo is None:
memo = {}

memo[id(self)] = copied

for (k, v) in self.__dict__.items():
copied.__dict__[k] = deepcopy(v, memo)

return copied
58 changes: 46 additions & 12 deletions tap/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from argparse import ArgumentParser, ArgumentTypeError
from base64 import b64encode, b64decode
from collections import OrderedDict
import copy
from functools import wraps
import inspect
from io import StringIO
from json import JSONEncoder
import os
import pickle
import re
import subprocess
import sys
import tokenize
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union
from typing_extensions import Literal
from typing_inspect import get_args


NO_CHANGES_STATUS = """nothing to commit, working tree clean"""
PRIMITIVES = (str, int, float, bool)

Expand Down Expand Up @@ -247,16 +251,16 @@ def __call__(self, arg: str) -> Any:
return arg


class Tuple:
class MockTuple:
"""Mock of a tuple needed to prevent JSON encoding tuples as lists."""
def __init__(self, _tuple: tuple) -> None:
self.tuple = _tuple


def nested_replace_type(obj: Any, find_type: type, replace_type: type) -> Any:
def _nested_replace_type(obj: Any, find_type: type, replace_type: type) -> Any:
"""Replaces any instance (including instances within lists, tuple, dict) of find_type with an instance of replace_type.
Note: Tuples, lists, and dictionaries are NOT modified in place, they are replaced.
Note: Tuples, lists, and dicts are NOT modified in place.
Note: Does NOT do a nested search through objects besides tuples, lists, and dicts (e.g. sets).
:param obj: The object to modify by replacing find_type instances with replace_type instances.
Expand All @@ -265,14 +269,16 @@ def nested_replace_type(obj: Any, find_type: type, replace_type: type) -> Any:
:return: A version of obj with all instances of find_type replaced by replace_type
"""
if isinstance(obj, tuple):
obj = tuple(nested_replace_type(item, find_type, replace_type) for item in obj)
obj = tuple(_nested_replace_type(item, find_type, replace_type) for item in obj)

if isinstance(obj, list):
obj = [nested_replace_type(item, find_type, replace_type) for item in obj]
elif isinstance(obj, list):
obj = [_nested_replace_type(item, find_type, replace_type) for item in obj]

if isinstance(obj, dict):
obj = {nested_replace_type(key, find_type, replace_type): nested_replace_type(value, find_type, replace_type)
for key, value in obj.items()}
elif isinstance(obj, dict):
obj = {
_nested_replace_type(key, find_type, replace_type): _nested_replace_type(value, find_type, replace_type)
for key, value in obj.items()
}

if isinstance(obj, find_type):
obj = replace_type(obj)
Expand All @@ -286,7 +292,7 @@ class PythonObjectEncoder(JSONEncoder):
See: https://stackoverflow.com/a/36252257
"""
def iterencode(self, o: Any, _one_shot: bool = False) -> Iterator[str]:
o = nested_replace_type(o, tuple, Tuple)
o = _nested_replace_type(o, tuple, MockTuple)
return super(PythonObjectEncoder, self).iterencode(o, _one_shot)

def default(self, obj: Any) -> Any:
Expand All @@ -295,7 +301,7 @@ def default(self, obj: Any) -> Any:
'_type': 'set',
'_value': list(obj)
}
elif isinstance(obj, Tuple):
elif isinstance(obj, MockTuple):
return {
'_type': 'tuple',
'_value': list(obj.tuple)
Expand Down Expand Up @@ -327,3 +333,31 @@ def as_python_object(dct: Any) -> Any:
raise ValueError(f'Special type "{_type}" not supported for JSON loading.')

return dct


def fix_py36_copy(func: Callable) -> Callable:
"""Decorator that fixes functions using Python 3.6 deepcopy of ArgumentParsers.
Based on https://stackoverflow.com/questions/6279305/typeerror-cannot-deepcopy-this-pattern-object
"""
if sys.version_info[:2] != (3, 6):
return func

re_type = type(re.compile(''))

@wraps(func)
def wrapper(*args, **kwargs):
has_prev_val = re_type in copy._deepcopy_dispatch
prev_val = copy._deepcopy_dispatch.get(re_type, None)
copy._deepcopy_dispatch[type(re.compile(''))] = lambda r, _: r

result = func(*args, **kwargs)

if has_prev_val:
copy._deepcopy_dispatch[re_type] = prev_val
else:
del copy._deepcopy_dispatch[re_type]

return result

return wrapper
8 changes: 4 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
type_to_str,
get_literals,
TupleTypeEnforcer,
nested_replace_type,
_nested_replace_type,
PythonObjectEncoder,
as_python_object
)
Expand Down Expand Up @@ -320,19 +320,19 @@ def test_tuple_type_enforcer_infinite(self):
class NestedReplaceTypeTests(TestCase):
def test_nested_replace_type_notype(self):
obj = ['123', 4, 5, ('hello', 4.4)]
replaced_obj = nested_replace_type(obj, bool, int)
replaced_obj = _nested_replace_type(obj, bool, int)
self.assertEqual(obj, replaced_obj)

def test_nested_replace_type_unnested(self):
obj = ['123', 4, 5, ('hello', 4.4), True, False, 'hi there']
replaced_obj = nested_replace_type(obj, tuple, list)
replaced_obj = _nested_replace_type(obj, tuple, list)
correct_obj = ['123', 4, 5, ['hello', 4.4], True, False, 'hi there']
self.assertNotEqual(obj, replaced_obj)
self.assertEqual(correct_obj, replaced_obj)

def test_nested_replace_type_nested(self):
obj = ['123', [4, (1, 2, (3, 4))], 5, ('hello', (4,), 4.4), {'1': [2, 3, [{'2': (3, 10)}, ' hi ']]}]
replaced_obj = nested_replace_type(obj, tuple, list)
replaced_obj = _nested_replace_type(obj, tuple, list)
correct_obj = ['123', [4, [1, 2, [3, 4]]], 5, ['hello', [4], 4.4], {'1': [2, 3, [{'2': [3, 10]}, ' hi ']]}]
self.assertNotEqual(obj, replaced_obj)
self.assertEqual(correct_obj, replaced_obj)
Expand Down

0 comments on commit 5f718c4

Please sign in to comment.