Skip to content

Commit

Permalink
add more tests (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Aug 28, 2024
1 parent 56d1df9 commit 5965ab9
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 80 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[run]
source = depyf
# omit patched file from pytorch
omit = depyf/explain/patched*

[report]
include = depyf/*
2 changes: 2 additions & 0 deletions .github/workflows/test_decompile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ jobs:
run: |
pytest --cov=depyf tests/test.py
coverage run --append python_coverage.py
coverage run --append tests/test_code_owner.py
coverage run --append tests/test_ensure.py
python tests/assert.py
- name: Upload results to Codecov
Expand Down
11 changes: 0 additions & 11 deletions depyf/code_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,17 +420,6 @@ def visit_FunctionDef(self, node):
# return self.generic_visit(node)


def structure_hash(source_code: str) -> str:
"""Compute the hash of code structure, ignore the function name difference.
This is because PyTorch dynamically generates function names.
"""
tree = ast.parse(source_code)
tree = IdentifierReplacer().visit(tree)
modified_code = astor.to_source(tree)
hash_value = hashlib.md5(modified_code.encode()).hexdigest()
return hash_value


def fix_irregular_code(
old_bytecode: CodeType,
src_code: str,
Expand Down
3 changes: 2 additions & 1 deletion depyf/decompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,8 @@ def cleanup_instructions(code, instructions: List[Instruction]):

def __init__(self, code: Union[CodeType, Callable]):
if callable(code):
code = code.__code__
from depyf.utils import get_code_owner
code = get_code_owner(code).__code__
self.code = code
instructions = list(convert_instruction(_)
for _ in dis.get_instructions(code))
Expand Down
10 changes: 0 additions & 10 deletions depyf/explain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,6 @@ def _extract_artifacts(original_code: CodeType, module):
result = DynamoOptimizationResult(original_code, None, module)
return result


def _collect_compiled_subgraphs(result: DynamoOptimizationResult):
compiled_subgraphs = {
entry.compiled_subgraph_proxy.name: entry.compiled_subgraph for entry in result.compiled_code_entries}
for entry in result.compiled_code_entries:
for func in entry.referenced_global_functions.values():
ans = _collect_compiled_subgraphs(func)
compiled_subgraphs.update(ans)
return compiled_subgraphs

def dump_src(original_code: CodeType, module):
from depyf.explain.global_variables import data
assert data["is_inside_prepare_debug"], "`dump_src` must be used inside `depyf.prepare_debug`."
Expand Down
14 changes: 5 additions & 9 deletions depyf/explain/enable_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,14 @@ def __call__(self, code, new_code):
import dill
# code object, especially `new_code` constructed by Dynamo, may not be able to be dumped using `marshal`.
# see https://github.com/pytorch/pytorch/issues/116013 for more details.
try:
with contextlib.suppress(Exception):
dill.dump(code, open(filename + ".original_bytecode", "wb"))
except:
pass
try:

with contextlib.suppress(Exception):
dill.dump(new_code, open(filename + ".transformed_bytecode", "wb"))
except:
pass
try:

with contextlib.suppress(Exception):
dill.dump(decompiled_and_compiled_back_code, open(filename + ".decompiled_and_compiled_back_bytecode", "wb"))
except:
pass

# this fix is used for PyTorch prior to PR https://github.com/pytorch/pytorch/pull/114487
from torch._dynamo.utils import orig_code_map
Expand Down
2 changes: 1 addition & 1 deletion depyf/explain/patched___call__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
def patched___call__(self, code, check_fn):
from depyf.explain.global_variables import data
from depyf.explain.utils import get_code_owner
from depyf.utils import get_code_owner
import torch
unpatched___call__ = data["unpatched___call__"]
optimized_functions = data["optimized_functions"]
Expand Down
3 changes: 2 additions & 1 deletion depyf/explain/patched_lazy_format_graph_code.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
from depyf.explain.utils import get_current_compiled_fn_name, get_code_owner, write_code_to_file_template
from depyf.explain.utils import get_current_compiled_fn_name, write_code_to_file_template
from depyf.utils import get_code_owner
func_name = get_current_compiled_fn_name()
file_name = name if name != func_name else "Captured Graph"
file_name = func_name + " " + file_name
Expand Down
47 changes: 1 addition & 46 deletions depyf/explain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,6 @@
from dataclasses import dataclass
import contextlib

import depyf
from depyf.decompiler import DecompilationError
from depyf.utils import get_function_signature


def decompile_ensure(fn, overwite_fn_name=None):
try:
decompiled_source_code = depyf.Decompiler(
fn).decompile(overwite_fn_name=overwite_fn_name)
except DecompilationError as e:
header = get_function_signature(fn, overwite_fn_name=overwite_fn_name)
decompiled_source_code = header + " 'Failed to decompile.'\n"
return decompiled_source_code


class CodeProxy:
instances: Dict[str, "CodeProxy"] = {}
used_instances: Set[str] = set()
Expand All @@ -49,6 +34,7 @@ def consume_new_name(name: str):

@staticmethod
def decompile_with_name(code: CodeType, name: str, skip_decompile=False):
from depyf.utils import decompile_ensure
if hasattr(code, "__code__"):
code = code.__code__
if code.co_name.startswith("transformed_code_") or code.co_name.startswith("__transformed_code_"):
Expand Down Expand Up @@ -320,37 +306,6 @@ def write_code_to_file_template(src, path_template):
return new_filepath


def get_code_owner(fn):
"""A callable object `fn` might have a __code__ attribute, which is a code object.
However, `fn` might not be the owner of the code object. Only the code owner can change the code object.
This function returns the owner of the code object.
An example:
class A:
def func(self):
return 1
a = A()
`a.func.__code__` is read-only. `A.func.__code__` is writable.
We can change the code object via `a.func.__func__.__code__`.
"""
import functools
while True:
if hasattr(fn, "__func__"):
# deal with bounded function
fn = fn.__func__
elif hasattr(fn, "__wrapped__"):
# deal with lru_cache or other decorators
fn = fn.__wrapped__
elif isinstance(fn, functools.partial):
# deal with partial function
fn = fn.func
elif hasattr(fn, "__call__") and hasattr(fn.__call__, "__func__"):
# deal with callable object
fn = fn.__call__.__func__
else:
break
return fn


def get_current_compiled_fn_name():
import torch
from torch._dynamo.bytecode_transformation import _unique_id_counter
Expand Down
45 changes: 45 additions & 0 deletions depyf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,48 @@ def safe_create_directory(path):
except OSError as e:
if not os.path.isdir(path):
raise



def get_code_owner(fn):
"""A callable object `fn` might have a __code__ attribute, which is a code object.
However, `fn` might not be the owner of the code object. Only the code owner can change the code object.
This function returns the owner of the code object.
An example:
class A:
def func(self):
return 1
a = A()
`a.func.__code__` is read-only. `A.func.__code__` is writable.
We can change the code object via `a.func.__func__.__code__`.
"""
import functools
while True:
if hasattr(fn, "__func__"):
# deal with bounded function
fn = fn.__func__
elif hasattr(fn, "__wrapped__"):
# deal with lru_cache or other decorators
fn = fn.__wrapped__
elif isinstance(fn, functools.partial):
# deal with partial function
fn = fn.func
elif hasattr(fn, "__call__") and hasattr(fn.__call__, "__func__"):
# deal with callable object
fn = fn.__call__.__func__
else:
break
return fn



def decompile_ensure(fn: CodeType, overwite_fn_name=None):
import depyf
from depyf.decompiler import DecompilationError
try:
decompiled_source_code = depyf.Decompiler(
fn).decompile(overwite_fn_name=overwite_fn_name)
except DecompilationError as e:
header = get_function_signature(fn, overwite_fn_name=overwite_fn_name)
decompiled_source_code = header + " 'Failed to decompile.'\n"
return decompiled_source_code
16 changes: 16 additions & 0 deletions tests/test_code_owner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from functools import partial, lru_cache

def f(a, b):
return a + b

class A:
def __call__(self, a, b):
return a + b

import depyf

print(depyf.decompile(partial(f, 1)))

print(depyf.decompile(lru_cache(None)(f)))

print(depyf.decompile(A()))
11 changes: 11 additions & 0 deletions tests/test_ensure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from depyf.utils import decompile_ensure

import asyncio

def f(a, b):
try:
return a + b
finally:
return a - b

print(decompile_ensure(f.__code__))
2 changes: 1 addition & 1 deletion tests/test_pytorch/test_simple_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ def fn():
return x.grad

import depyf
with depyf.prepare_debug("./simple_output"):
with depyf.prepare_debug("./simple_output", log_bytecode=True, clean_wild_fx_code=False):
fn()

0 comments on commit 5965ab9

Please sign in to comment.