Update tensorify pass to specialize symfloats we didn't tensorify away (pytorch#138868)

As discussed with @ezyang offline, one way to de-risk the `specialize_float=False` rollout is to specialize all backed symfloats that we fail to tensorify away. This diff does two things:

1) It fixes a bug where item_memo gets dropped (due to incorrect epoch invalidation).
2) It updates the tensorify pass to do the backup specialization (a brief sketch of the observable effect follows below).
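
For a concrete feel of (2), here is a minimal sketch that mirrors the new `test_unspecialized_float_fallback_symint_specialization` test added below (it needs an Inductor-capable build; the helper name `demo` is purely illustrative):

```python
import math

import torch
from torch._dynamo.testing import CompileCounterWithBackend


@torch._dynamo.config.patch(specialize_float=False)
def demo() -> int:
    cnt = CompileCounterWithBackend("inductor")

    @torch._dynamo.optimize(cnt)
    def fn(x, y):
        # x only reaches math.floor, so it cannot be tensorified away and the
        # pass falls back to specializing it.
        return math.floor(x**2) * y

    y = torch.arange(3)
    fn(2.0, y)
    fn(3.0, y)
    # Each distinct float value recompiles once it has been specialized.
    return cnt.frame_count  # expected: 2, matching the new test below


print(demo())
```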

This pass was originally part of the [PR](pytorch#137782) that flips `specialize_float=False`, but we learned that its blast radius is simply too large. We've pivoted to a more milestone-driven approach where we learn from the failures of the aforementioned PR and cherry-pick fixes into main first. After this PR lands, our strategy is as follows:

1) Integrate turning off `specialize_float` only in the automatic dynamic pass.
2) Put up a canary diff that only turns off `specialize_float` in `backend=eager` mode to sniff out symfloat-related bugs in Dynamo due to code paths we previously never exercised.
3) Put up a canary diff that only turns off `specialize_float` in `backend=aot_eager` mode to sniff out symfloat-related bugs in AOTAutograd due to code paths we previously never exercised.

Pull Request resolved: pytorch#138868
Approved by: https://github.com/ezyang
bobrenjc93 authored and pytorchmergebot committed Nov 1, 2024
1 parent c8a648d commit 094d288
Showing 5 changed files with 150 additions and 50 deletions.
36 changes: 36 additions & 0 deletions test/inductor/test_torchinductor_dynamic_shapes.py
@@ -983,6 +983,42 @@ def fn(x, y):
self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0))
self.assertEqual(cnt.frame_count, 1)

@torch._dynamo.config.patch(specialize_float=False)
def test_unspecialized_float_fallback_specialization(self):
def fn(x, y, z):
return (
torch.tensor(z),
torch.exp(torch.tensor(z)) * (x * y),
x.size(0),
math.sqrt(x.size(0)),
math.floor(math.sqrt(x.size(0))),
math.floor(math.sqrt(x.numel())),
math.floor(math.sqrt(x.dim())),
math.floor(math.sqrt(z)),
)

cnt = CompileCounterWithBackend("inductor")
fn_opt = torch._dynamo.optimize(cnt)(fn)
x = torch.arange(3)
z = 1.3

self.assertEqual(fn(x, 2.0, z), fn_opt(x, 2.0, z))
self.assertEqual(fn(x, 3.0, z), fn_opt(x, 3.0, z))
self.assertEqual(cnt.frame_count, 1)

@torch._dynamo.config.patch(specialize_float=False)
def test_unspecialized_float_fallback_symint_specialization(self):
def fn(x, y):
return math.floor(x**2) * y

cnt = CompileCounterWithBackend("inductor")
fn_opt = torch._dynamo.optimize(cnt)(fn)
y = torch.arange(3)

self.assertEqual(fn(2.0, y), fn_opt(2.0, y))
self.assertEqual(fn(3.0, y), fn_opt(3.0, y))
self.assertEqual(cnt.frame_count, 2)

def test_sort_dynamic_shape_with_check(self, device):
if TEST_WITH_ROCM or torch.device(device).type != GPU_TYPE:

24 changes: 20 additions & 4 deletions torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
@@ -12,7 +12,7 @@
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import lazy_format_graph_code
from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
@@ -130,9 +130,25 @@ def aot_dispatch_base_graph(
mod_when_exporting_non_strict, assigned_buffers
)

saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only(
torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared
)
# TODO: Refactor the following code so detach() persists item_memo
def detach_and_copy_item_memo(t):
detached_t = t.detach()
if hasattr(t, "item_memo"):
detached_t.item_memo = t.item_memo
return detached_t

fake_mode = detect_fake_mode()
if fake_mode:
saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only(
torch.Tensor,
detach_and_copy_item_memo,
updated_flat_args_subclasses_desugared,
)
else:
saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only(
torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared
)

fw_module = _create_graph(
fn_to_trace,
updated_flat_args_subclasses_desugared,
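
A hedged illustration of why the plain `detach()` mapping dropped `item_memo` in the first place, using a regular tensor and a plain Python attribute as a stand-in for the FakeTensor descriptor:

```python
import torch

t = torch.tensor(1.3)
t.item_memo = "stand-in for the memoized SymFloat"  # hypothetical plain attribute

# detach() returns a brand-new tensor object, so Python-level attributes set on
# the original are not carried over to the result.
print(hasattr(t.detach(), "item_memo"))  # False: the memo is dropped


def detach_and_copy_item_memo(t):
    # Mirrors the helper added in this hunk: preserve the memo across detach().
    detached = t.detach()
    if hasattr(t, "item_memo"):
        detached.item_memo = t.item_memo
    return detached


print(hasattr(detach_and_copy_item_memo(t), "item_memo"))  # True
```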
35 changes: 22 additions & 13 deletions torch/_subclasses/fake_tensor.py
@@ -439,6 +439,7 @@ def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
value,
source=item_source,
dynamic_dim=DimDynamic.DYNAMIC,
symbolic_context=symbolic_context,
)
# NB: reusing item_memo here ensures that we invalidate on
# mutation
@@ -520,18 +521,18 @@ class FakeTensorConfig:
debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1"


# This memorizes the unbacked SymInt representing quantities like the number
# of nonzero elements in this tensor. There is one instance of the descriptor
# per particular quantity to memoize.
# This memorizes unbacked SymInt or SymFloats representing quantities like the
# number of nonzero elements in this tensor or learning rate. There is one
# instance of the descriptor per particular quantity to memoize.
#
# Memoization is helpful if you do something like x[mask] and y[mask];
# mask.nonzero() gets repeatedly called and should give a consistent unbacked
# SymInt. It needs to be invalidated in the same way constant is.
# SymInt. It needs to be invalidated in the same way constant is.
#
# Making this a descriptor may seem overly fancy, but actually it's the most
# convenient way to make sure we have access to FakeTensor during access,
# which is required for testing version counter and epoch validity
class SymIntMemoDescriptor:
# convenient way to ensure access to FakeTensor during access, which is
# required for testing version counter and epoch validity.
class SymNumberMemoDescriptor:
_name: str

# By default, SymInts in this memo are invalidated across versions/epochs.
Expand Down Expand Up @@ -562,9 +563,14 @@ def _memo_epoch(self, obj: FakeTensor) -> str:

def __get__(
self, obj: FakeTensor, objtype: Optional[Type[FakeTensor]] = None
) -> Optional[torch.SymInt]:
) -> Optional[Union[torch.SymInt, torch.SymFloat]]:
if (r := getattr(obj, self._memo(obj))) is None:
return None

# If backed, it's ok to preserve memo since we know it won't renumber.
if r.node.hint is not None:
return r

# Version counter based tracking isn't 100% sound but it's close
# enough
if (
@@ -577,7 +583,9 @@ def __get__(
return None
return r

def __set__(self, obj: FakeTensor, value: Optional[torch.SymInt]) -> None:
def __set__(
self, obj: FakeTensor, value: Optional[Union[torch.SymInt, torch.SymFloat]]
) -> None:
if value is None:
setattr(obj, self._memo(obj), None)
setattr(obj, self._memo_vc(obj), None)
@@ -606,14 +614,14 @@ class FakeTensor(Tensor):
# TODO: Generalize this as needed, e.g., into a trie of memos, if
# you do something like x[0].item() (x[0] is fresh each time, so
# memo mechanism here won't work)
nonzero_memo = SymIntMemoDescriptor()
item_memo = SymIntMemoDescriptor()
unique_memo = SymIntMemoDescriptor()
nonzero_memo = SymNumberMemoDescriptor()
item_memo = SymNumberMemoDescriptor()
unique_memo = SymNumberMemoDescriptor()

# We expect nested_int_memo to be None when an offsets is a graph
# intermediate, or an input that has never been associated with a
# nested int.
nested_int_memo = SymIntMemoDescriptor(is_nested_int=True)
nested_int_memo = SymNumberMemoDescriptor(is_nested_int=True)

# Indicates to our torch_dispatch dispatching infra that
# this is an "infra" mode with lower dispatching precedence.
@@ -891,6 +899,7 @@ def get_nested_int(
self.nested_int_memo = self.fake_mode.create_symbolic_nested_int(
nt_tensor_id=None
)
assert isinstance(self.nested_int_memo, torch.SymInt)
return self.nested_int_memo * coeff

# Similar to FunctionalTensor.tolist
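
A simplified, self-contained sketch (not the real `SymNumberMemoDescriptor`; the toy classes are made up for illustration) of the invalidation rule the `__get__` change encodes: a memo whose cached symbol is backed (has a hint) survives version bumps because it will never be renumbered, while an unbacked memo is still dropped when the tensor is mutated:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSym:
    hint: Optional[float] = None  # backed symbols carry a concrete hint


class MemoDescriptor:
    def __set_name__(self, owner, name):
        self._name = name

    def __set__(self, obj, value):
        obj.__dict__[self._name] = value
        obj.__dict__[self._name + "_vc"] = obj.version  # snapshot the version counter

    def __get__(self, obj, objtype=None):
        value = obj.__dict__.get(self._name)
        if value is None:
            return None
        if value.hint is not None:
            return value  # backed: keep the memo, it will never be renumbered
        if obj.__dict__.get(self._name + "_vc") != obj.version:
            return None  # unbacked and the tensor was mutated: invalidate
        return value


class ToyTensor:
    item_memo = MemoDescriptor()

    def __init__(self) -> None:
        self.version = 0

    def mutate(self) -> None:
        self.version += 1


t = ToyTensor()
t.item_memo = FakeSym(hint=1.3)   # backed symbol
t.mutate()
print(t.item_memo is not None)    # True: backed memo survives mutation

t.item_memo = FakeSym(hint=None)  # unbacked symbol
t.mutate()
print(t.item_memo)                # None: unbacked memo invalidated by the version bump
```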
14 changes: 13 additions & 1 deletion torch/fx/experimental/symbolic_shapes.py
@@ -457,6 +457,11 @@ def rebind_unbacked(
u1,
)
continue

# We only care about rebinding unbacked things
if u1.node.hint is not None:
continue

raw_u1 = u1.node.expr
# Simplify SymBool binding
if (
@@ -4067,6 +4072,7 @@ def create_unspecified_symbol(
source: Source,
dynamic_dim: DimDynamic = DimDynamic.DUCK,
constraint_dim: DimConstraint = None, # NB: includes None
symbolic_context: Optional[StatelessSymbolicContext] = None,
) -> sympy.Expr:
"""
Create a symbol with an unspecified value
@@ -4086,7 +4092,7 @@
constraint_dim,
positive=None,
do_not_specialize_zero_one=True,
symbolic_context=None,
symbolic_context=symbolic_context,
)

@record_shapeenv_event()
@@ -6103,6 +6109,12 @@ def _evaluate_expr(

# TODO: split conjunctions and evaluate them separately

if isinstance(
orig_expr,
(sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
):
return orig_expr

# Don't track this one
@functools.lru_cache(None)
def compute_concrete_val() -> sympy.Basic:
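
A small illustration of the new early return in `_evaluate_expr`: sympy sometimes evaluates an expression straight to `BooleanTrue`/`BooleanFalse`, and such an atom has no free symbols left to guard on, so it can be returned as-is:

```python
import sympy
from sympy.logic.boolalg import BooleanFalse, BooleanTrue

# Eq on two concrete integers evaluates immediately to a boolean atom.
expr = sympy.Eq(sympy.Integer(4) % 2, sympy.Integer(0))
print(type(expr))                                     # <class 'sympy.logic.boolalg.BooleanTrue'>
print(isinstance(expr, (BooleanTrue, BooleanFalse)))  # True: nothing left to evaluate
print(expr.free_symbols)                              # set(): no symbols to guard on
```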
91 changes: 59 additions & 32 deletions torch/fx/passes/_tensorify_python_scalars.py
@@ -1,21 +1,26 @@
from __future__ import annotations

import logging
from typing import List, Union
from typing import Any, List, Union

from sympy import Integer, Number, Symbol
from sympy.logic.boolalg import BooleanAtom

import torch
import torch.fx as fx
from torch._prims_common import get_computation_dtype
from torch._subclasses import fake_tensor # noqa: TCH001
from torch._utils_internal import JustKnobsConfig
from torch.fx._utils import lazy_format_graph_code
from torch.fx.experimental.symbolic_shapes import ShapeEnv # noqa: TCH001
from torch.fx.experimental.symbolic_shapes import guard_scalar, ShapeEnv # noqa: TCH001
from torch.fx.graph_module import GraphModule # noqa: TCH001

# TODO: refactor
from torch.fx.passes.runtime_assert import _get_sym_val
from torch.fx.proxy import MetaProxy
from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
from torch.utils._sympy.reference import TensorReferenceAnalysis
from torch.utils._sympy.symbol import symbol_is_type, SymT


__all__: List[str] = []
@@ -107,6 +112,7 @@ def tensorify_python_scalars(
placeholders = set()
for node in graph.nodes:
if node.op != "placeholder":
first_non_placeholder = node
break
else:
placeholders.add(node)
@@ -116,10 +122,6 @@
def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
# sympy_interp() with hash consing, and special handling for
# generating constants correctly
from sympy import Integer, Number, Symbol
from sympy.logic.boolalg import BooleanAtom

from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp

# hash cons
if isinstance(expr, Symbol) and expr not in expr_to_tensor_proxy:
@@ -176,25 +178,24 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
nodes[i + 1] if node not in placeholders else first_non_placeholder
):
# Look for tensor.item() calls on placeholders
if unbacked_bindings := node.meta.get("unbacked_bindings"):
for s in unbacked_bindings.keys():
if (
node is not None
and node.op == "call_function"
and node.target is torch.ops.aten._local_scalar_dense.default
):
dtype = node.args[0].meta["val"].dtype
if dtype != torch.float64:
continue

assert isinstance(node.args[0], fx.Node), node.args[0]

expr_to_tensor_proxy[s] = MetaProxy(
node.args[0], tracer=tracer, fake_mode=fake_mode
)
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
if (
node is not None
and node.op == "call_function"
and node.target is torch.ops.aten._local_scalar_dense.default
):
dtype = node.args[0].meta["val"].dtype
if dtype != torch.float64:
continue

assert isinstance(node.args[0], fx.Node), node.args[0]

s = node.meta["val"].node.expr
expr_to_tensor_proxy[s] = MetaProxy(
node.args[0], tracer=tracer, fake_mode=fake_mode
)
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)

elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance(
@@ -206,7 +207,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:

# Look for functions to convert
if node.op == "call_function" and node.target in SUPPORTED_OPS:
args = []
args: List[Any] = []
transform = False
compute_dtype = get_computation_dtype(node.meta["val"].dtype)

@@ -229,8 +230,10 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
)

args.append(proxy)
else:
elif isinstance(a, fx.Node):
args.append(MetaProxy(a, tracer=tracer, fake_mode=fake_mode))
else:
args.append(a)

if transform:
replacement_proxy = node.target(*args)
@@ -244,13 +247,37 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
)

node.replace_all_uses_with(replacement_proxy.node)

graph.erase_node(node)

# DCE symbols (which are guaranteed to be pure) only
for proxy in reversed(expr_to_sym_proxy.values()):
if len(proxy.node.users) == 0 and proxy.node.op != "placeholder":
graph.erase_node(proxy.node)
# Now do one more pass that specializes all symfloats we didn't manage
# to tensorify away.
for node in reversed(graph.nodes):
if node.op == "output" or node.op == "placeholder":
continue

with graph.inserting_before(node):
if len(node.users) == 0 and not node.is_impure():
graph.erase_node(node)
continue

if isinstance(
(val := node.meta.get("val")),
(torch.SymFloat, torch.SymInt, torch.SymBool),
):
if len(val.node.expr.free_symbols) > 0 and all(
symbol_is_type(s, SymT.FLOAT) for s in val.node.expr.free_symbols
):
# If all symbols are backed symfloats, we can just specialize the whole node
# and get more precise guards. eg.
#
# zf = a.item()
# zf2 = zf // 2
# op(.. zf2 ..)
#
# It's better to guard on zf // 2 == 2.0 than zf == 5.0

node.replace_all_uses_with(guard_scalar(val))
graph.erase_node(node)

graph_code_log.debug(
"%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True)
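
A plain-Python illustration of the comment in the hunk above about why specializing the derived expression gives a less restrictive guard than specializing the raw symfloat (the sample values are made up for illustration):

```python
# Guarding on the raw input value: only zf == 5.0 reuses the compiled graph.
def raw_guard(zf: float) -> bool:
    return zf == 5.0


# Guarding on the expression actually used downstream: any zf whose floor
# division by 2 equals 2.0 reuses the graph, i.e. every value in [4.0, 6.0).
def derived_guard(zf: float) -> bool:
    return zf // 2 == 2.0


for zf in (4.0, 4.5, 5.0, 5.9, 6.0):
    print(zf, raw_guard(zf), derived_guard(zf))
# Only 5.0 passes the raw guard, while 4.0 through 5.9 all pass the derived
# guard -- fewer recompiles for the same compiled graph.
```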
