Skip to content

Commit

Permalink
add api reference doc (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Sep 16, 2024
1 parent 2e70b60 commit a65f043
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 60 deletions.
40 changes: 37 additions & 3 deletions depyf/decompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,41 @@ def __hash__(self):
def __eq__(self, other):
return hash(self) == hash(other)

def decompile(code: Union[CodeType, Callable]):
"""Decompile a code object or a function."""
def decompile(code: Union[CodeType, Callable]) -> str:
"""Decompile any callable or code object into Python source code.
It is especially useful for some dynamically generated code, like ``torch.compile``,
or ``dataclasses``.
Example usage:
.. code-block:: python
from dataclasses import dataclass
@dataclass
class Data:
x: int
y: float
import depyf
print(depyf.decompile(Data.__init__))
print(depyf.decompile(Data.__eq__))
Output:
.. code-block:: python
def __init__(self, x, y):
self.x = x
self.y = y
return None
def __eq__(self, other):
if other.__class__ is self.__class__:
return (self.x, self.y) == (other.x, other.y)
return NotImplemented
The output source code is semantically equivalent to the function, but not syntactically the same. It verbosely adds many details that are hidden in the Python code. For example, the above output code of ``__init__`` explicitly returns ``None``, which is typically ignored.
Another detail is that the output code of ``__eq__`` returns ``NotImplemented`` instead of raising ``NotImplemented`` exception when the types are different. At the first glance, it seems to be a bug. However, it is actually the correct behavior. The ``__eq__`` method should return ``NotImplemented`` when the types are different, so that the other object can try to compare with the current object. See `the Python documentation <https://docs.python.org/3/library/numbers.html#implementing-the-arithmetic-operations>`_ for more details.
"""
return Decompiler(code).decompile()

53 changes: 49 additions & 4 deletions depyf/explain/enable_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,53 @@ def enable_bytecode_hook(hook):
@contextlib.contextmanager
def prepare_debug(dump_src_dir, clean_wild_fx_code=True, log_bytecode=False):
"""
Args:
dump_src_dir: the directory to dump the source code.
clean_wild_fx_code: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally.
log_bytecode: whether to log bytecode (original bytecode, transformed bytecode from Dynamo, and decompiled_and_compiled_back_code).
A context manager to dump debugging information for torch.compile.
It should wrap the code that actually triggers the compilation, rather than
the code that applies ``torch.compile``.
Example:
.. code-block:: python
import torch
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
# main()
# surround the code you want to run inside `with depyf.prepare_debug`
import depyf
with depyf.prepare_debug("./dump_src_dir"):
main()
After running the code, you will find the dumped information in the directory ``dump_src_dir``. The details are organized into the following:
- ``full_code_for_xxx.py`` for each function using torch.compile
- ``__transformed_code_for_xxx.py`` for Python code associated with each graph.
- ``__transformed_code_for_xxx.py.xxx_bytecode`` for Python bytecode, dumped code object, can be loaded via ``dill.load(open("/path/to/file", "wb"))``. Note that the load function might import some modules like transformers. Make sure you have these modules installed.
- ``__compiled_fn_xxx.py`` for each computation graph and its optimization:
- ``Captured Graph``: a plain forward computation graph
- ``Joint Graph``: joint forward-backward graph from AOTAutograd
- ``Forward Graph``: forward graph from AOTAutograd
- ``Backward Graph``: backward graph from AOTAutograd
- ``kernel xxx``: compiled CPU/GPU kernel wrapper from Inductor.
Arguments:
- ``dump_src_dir``: the directory to dump the source code.
- ``clean_wild_fx_code``: whether to clean the wild fx code that are not recognized for parts of compiled functions. They are usually used by PyTorch internally.
- ``log_bytecode``: whether to log bytecode (original bytecode, transformed bytecode from Dynamo, and decompiled_and_compiled_back_code).
"""

if not isinstance(dump_src_dir, str):
raise RuntimeError('''You are using an obsolete usage style`depyf.prepare_debug(func=function, dump_src_dir="/path")`. Please use `depyf.prepare_debug(dump_src_dir="/path")` instead, which will automatically capture all compiled functions.''')

Expand Down Expand Up @@ -185,6 +227,9 @@ def prepare_debug(dump_src_dir, clean_wild_fx_code=True, log_bytecode=False):

@contextlib.contextmanager
def debug():
"""
A context manager to debug the compiled code. Essentially, it sets a breakpoint to pause the program and allows you to check the full source code in files with prefix ``full_code_for_`` in the ``dump_src_dir`` argument of :func:`depyf.prepare_debug`, and set breakpoints in their separate ``__transformed_code_`` files according to the function name. Then continue your debugging.
"""
from .global_variables import data
if data["is_inside_prepare_debug"]:
raise RuntimeError("You cannot use `depyf.debug` inside `depyf.prepare_debug`.")
Expand Down
50 changes: 50 additions & 0 deletions depyf/explain/enhance_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,52 @@ def pytorch_bytecode_src_hook(code: types.CodeType, new_code: types.CodeType):


def install():
"""
Install the bytecode hook for PyTorch, integrate into PyTorch's logging system.
Example:
.. code-block:: python
import torch
import depyf
depyf.install()
# anything with torch.compile
@torch.compile
def f(a, b):
return a + b
f(torch.tensor(1), torch.tensor(2))
Turn on bytecode log by ``export TORCH_LOGS="+bytecode"``, and execute the script.
We will see the decompiled source code in the log:
.. code-block:: text
ORIGINAL BYTECODE f test.py line 5
7 0 LOAD_FAST 0 (a)
2 LOAD_FAST 1 (b)
4 BINARY_ADD
6 RETURN_VALUE
MODIFIED BYTECODE f test.py line 5
5 0 LOAD_GLOBAL 0 (__compiled_fn_1)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 1
10 RETURN_VALUE
possible source code:
def f(a, b):
__temp_2, = __compiled_fn_1(a, b)
return __temp_2
If you find the decompiled code is wrong,please submit an issue at https://github.com/thuml/depyf/issues.
To uninstall the hook, use :func:`depyf.uninstall()`.
"""
import torch
global _handle
if _handle is not None:
Expand All @@ -37,6 +83,10 @@ def install():


def uninstall():
"""
Uninstall the bytecode hook for PyTorch.
Should be called after :func:`depyf.install()`.
"""
global _handle
if _handle is None:
return
Expand Down
47 changes: 0 additions & 47 deletions docs/advanced.rst

This file was deleted.

25 changes: 25 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
API Reference
=============

Understand and debug ``torch.compile``
--------------------------------------

.. warning::

It is recommended to read the :doc:`walk_through` to have a basic understanding of how ``torch.compile`` works, before using the following functions.

.. autofunction:: depyf.prepare_debug

.. autofunction:: depyf.debug

Decompile general Python Bytecode/Function
-------------------------------------------

.. autofunction:: depyf.decompile

Enhance PyTorch Logging
-----------------------

.. autofunction:: depyf.install

.. autofunction:: depyf.uninstall
9 changes: 4 additions & 5 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ You can also check `the advanced usages <./index.html>`_ and `frequently asked q
If you'd like to contribute (which we highly appreciate), please read the `developer documentation <./dev_doc.html>`_ section.

.. toctree::
:maxdepth: 1
:hidden:
:maxdepth: 2

api_reference
walk_through
advanced
faq
dev_doc
opt_tutorial
dev_doc
faq
5 changes: 4 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
sphinx-rtd-theme
sphinx-rtd-theme
astor
dill
torch

0 comments on commit a65f043

Please sign in to comment.