Commit a8d6afb

Disabling amp context when invoking compiler (pytorch#138659)
Disabling amp context when invoking compiler (pytorch#138624)

Fix for pytorch#133974

Pull Request resolved: pytorch#138624
Approved by: https://github.com/bdhirsh, https://github.com/drisspg

(cherry picked from commit 5942b29)

Co-authored-by: eellison <[email protected]>
pytorchbot and eellison authored Oct 23, 2024
1 parent f31b8bb commit a8d6afb
Showing 2 changed files with 63 additions and 24 deletions.
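
For orientation, a minimal sketch (not part of the commit) of the pattern this change applies: the backend compiler invocation is wrapped in torch._C._DisableAutocast() so that an ambient autocast context that is active at compile time does not re-apply AMP to a graph in which autocast has already been traced out. The helper name compile_without_ambient_amp and its arguments are illustrative assumptions, not PyTorch APIs.

    import torch

    def compile_without_ambient_amp(backend_compile_fn, graph_module, example_inputs):
        # Autocast casts were already baked into the traced joint graph, so invoking
        # the backend compiler under an active autocast context could change what the
        # compiler sees. Disable autocast for the duration of the compile call.
        with torch._C._DisableAutocast():
            return backend_compile_fn(graph_module, example_inputs)

    # Usage sketch: even if the user compiles under autocast, the backend call itself
    # runs with autocast disabled.
    # with torch.cpu.amp.autocast():
    #     compiled_fn = compile_without_ambient_amp(my_backend, gm, inputs)
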
test/inductor/test_cpu_repro.py (41 additions, 0 deletions)

@@ -3941,6 +3941,47 @@ def forward(self, x):
         x = torch.randn(1, 4, 2, 2)
         self.common(fn, (x,))
 
+    @parametrize("is_inference", (True, False))
+    def test_disabled_amp(self, is_inference):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.all_head_size = 12 * 64
+                self.dense = nn.Linear(self.all_head_size, self.all_head_size)
+
+            def forward(self, q, k, v):
+                context_layer = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=None, dropout_p=0.2
+                )
+                context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+                new_context_layer_shape = context_layer.size()[:-2] + (
+                    self.all_head_size,
+                )
+                context_layer = context_layer.view(new_context_layer_shape)
+                return self.dense(context_layer)
+
+        mod = M().to(torch.bfloat16).eval()
+
+        q = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        k = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        v = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        inputs = (
+            q,
+            k,
+            v,
+        )
+        compiler_mode = torch.compile(mod)
+        from torch.nn.attention import sdpa_kernel, SDPBackend
+
+        context = contextlib.nullcontext if not is_inference else torch.no_grad
+        with config.patch(
+            {"fallback_random": True}
+        ), torch.cpu.amp.autocast(), context(), sdpa_kernel(SDPBackend.MATH):
+            torch.manual_seed(0)
+            eager = mod(*inputs)
+            torch.manual_seed(0)
+            self.assertEqual(compiler_mode(*inputs), eager)
+
     @requires_vectorization
     def test_vec_indirect_load_cse_cache(self):
         # https://github.com/pytorch/pytorch/issues/123502
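As a usage note (an assumption, not from the PR): the new test can be exercised on its own from a PyTorch source checkout, for example via pytest's Python API.

    # Hedged local-run sketch; assumes a PyTorch source checkout with pytest installed.
    import pytest

    # -k matches both parametrizations (is_inference True/False) by substring.
    pytest.main(["test/inductor/test_cpu_repro.py", "-k", "test_disabled_amp"])
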
torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py (22 additions, 24 deletions)

@@ -555,7 +555,9 @@ def aot_dispatch_autograd(
             ),
         )
 
-    with track_graph_compiling(aot_config, "forward"):
+    # AMP is already traced out in joint graph. we do not wish to reapply it accidentally
+    # in the compiler.
+    with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
         # flat_args at this point might still be subclasses-
         # make sure to pass the unwrapped fake tensors into the compiler!
         adjusted_flat_args = joint_inputs[0]
@@ -620,7 +622,7 @@ def aot_dispatch_autograd(
     # NB: It's important to compile backwards ahead of time, as this may
     # add extra guards which we need to apply to the Dynamo cache at
     # forwards
-    with track_graph_compiling(aot_config, "backward"):
+    with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
         placeholder_list = fx_placeholder_vals(bw_module)
 
         forward_saved_for_backwards_strides = None
@@ -672,28 +674,24 @@ def aot_dispatch_autograd(
 
         compiled_bw_func = None
         if num_symints_saved_for_bw > 0:
-            context = torch._C._DisableAutocast if disable_amp else nullcontext
-            with context():
-                try:
-                    compiled_bw_func = aot_config.bw_compiler(
-                        bw_module, placeholder_list
-                    )
-                except Exception as e:
-                    exc = e
-                    trace_structured(
-                        "artifact",
-                        metadata_fn=lambda: {
-                            "name": "eager_compile_backwards_failure",
-                            "encoding": "string",
-                        },
-                        payload_fn=lambda: "\n".join(
-                            traceback.format_exception(exc)
-                        ),
-                    )
-                    log.warning(
-                        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
-                        exc_info=True,
-                    )
+            try:
+                compiled_bw_func = aot_config.bw_compiler(
+                    bw_module, placeholder_list
+                )
+            except Exception as e:
+                exc = e
+                trace_structured(
+                    "artifact",
+                    metadata_fn=lambda: {
+                        "name": "eager_compile_backwards_failure",
+                        "encoding": "string",
+                    },
+                    payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
+                )
+                log.warning(
+                    "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
+                    exc_info=True,
+                )
         # Compiled autograd will run the bw_module in the backward pass,
         # so recompilation need happen anyway if the backward pass is ever
         # called.
