From a8d6afb511a69687bbb2b7e88a3cf67917e1697e Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Tue, 22 Oct 2024 18:14:52 -0700
Subject: [PATCH] Disabling amp context when invoking compiler (#138659)

Disabling amp context when invoking compiler (#138624)

Fix for https://github.com/pytorch/pytorch/issues/133974

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138624
Approved by: https://github.com/bdhirsh, https://github.com/drisspg

(cherry picked from commit 5942b2985000e0c69ec955b6c88dee8b5d7e67fd)

Co-authored-by: eellison
---
 test/inductor/test_cpu_repro.py     | 41 +++++++++++++++++
 .../jit_compile_runtime_wrappers.py | 46 +++++++++----------
 2 files changed, 63 insertions(+), 24 deletions(-)

diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index f90b9da1f11a34..34d908d7eba05e 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -3941,6 +3941,47 @@ def forward(self, x):
         x = torch.randn(1, 4, 2, 2)
         self.common(fn, (x,))
 
+    @parametrize("is_inference", (True, False))
+    def test_disabled_amp(self, is_inference):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.all_head_size = 12 * 64
+                self.dense = nn.Linear(self.all_head_size, self.all_head_size)
+
+            def forward(self, q, k, v):
+                context_layer = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=None, dropout_p=0.2
+                )
+                context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+                new_context_layer_shape = context_layer.size()[:-2] + (
+                    self.all_head_size,
+                )
+                context_layer = context_layer.view(new_context_layer_shape)
+                return self.dense(context_layer)
+
+        mod = M().to(torch.bfloat16).eval()
+
+        q = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        k = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        v = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        inputs = (
+            q,
+            k,
+            v,
+        )
+        compiler_mode = torch.compile(mod)
+        from torch.nn.attention import sdpa_kernel, SDPBackend
+
+        context = contextlib.nullcontext if not is_inference else torch.no_grad
+        with config.patch(
+            {"fallback_random": True}
+        ), torch.cpu.amp.autocast(), context(), sdpa_kernel(SDPBackend.MATH):
+            torch.manual_seed(0)
+            eager = mod(*inputs)
+            torch.manual_seed(0)
+            self.assertEqual(compiler_mode(*inputs), eager)
+
     @requires_vectorization
     def test_vec_indirect_load_cse_cache(self):
         # https://github.com/pytorch/pytorch/issues/123502
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index 5dc236f314b079..b86fbad6a288e6 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -555,7 +555,9 @@ def aot_dispatch_autograd(
             ),
         )
 
-        with track_graph_compiling(aot_config, "forward"):
+        # AMP is already traced out in joint graph. we do not wish to reapply it accidentally
+        # in the compiler.
+        with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
             # flat_args at this point might still be subclasses-
             # make sure to pass the unwrapped fake tensors into the compiler!
             adjusted_flat_args = joint_inputs[0]
@@ -620,7 +622,7 @@ def aot_dispatch_autograd(
         # NB: It's important to compile backwards ahead of time, as this may
         # add extra guards which we need to apply to the Dynamo cache at
         # forwards
-        with track_graph_compiling(aot_config, "backward"):
+        with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
             placeholder_list = fx_placeholder_vals(bw_module)
 
             forward_saved_for_backwards_strides = None
@@ -672,28 +674,24 @@ def aot_dispatch_autograd(
 
             compiled_bw_func = None
             if num_symints_saved_for_bw > 0:
-                context = torch._C._DisableAutocast if disable_amp else nullcontext
-                with context():
-                    try:
-                        compiled_bw_func = aot_config.bw_compiler(
-                            bw_module, placeholder_list
-                        )
-                    except Exception as e:
-                        exc = e
-                        trace_structured(
-                            "artifact",
-                            metadata_fn=lambda: {
-                                "name": "eager_compile_backwards_failure",
-                                "encoding": "string",
-                            },
-                            payload_fn=lambda: "\n".join(
-                                traceback.format_exception(exc)
-                            ),
-                        )
-                        log.warning(
-                            "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
-                            exc_info=True,
-                        )
+                try:
+                    compiled_bw_func = aot_config.bw_compiler(
+                        bw_module, placeholder_list
+                    )
+                except Exception as e:
+                    exc = e
+                    trace_structured(
+                        "artifact",
+                        metadata_fn=lambda: {
+                            "name": "eager_compile_backwards_failure",
+                            "encoding": "string",
+                        },
+                        payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
+                    )
+                    log.warning(
+                        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
+                        exc_info=True,
+                    )
             # Compiled autograd will run the bw_module in the backward pass,
             # so recompilation need happen anyway if the backward pass is ever
             # called.
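
The change applies the same guard at the forward and backward compile sites: the backend compiler is invoked under torch._C._DisableAutocast(), because AMP has already been traced into the joint graph and an autocast context that is still active when compilation is triggered should not be applied a second time while the compiler itself runs ops. A minimal sketch of that pattern, assuming a generic backend callable; the names compile_without_ambient_autocast, backend_compile, gm, and example_inputs are illustrative and not part of this patch:

    import torch

    def compile_without_ambient_autocast(backend_compile, gm, example_inputs):
        # AMP decisions are already reflected in the traced graph `gm`, so
        # disable any autocast context the caller may still have active before
        # handing the graph to the backend compiler (illustrative wrapper, not
        # a PyTorch API).
        with torch._C._DisableAutocast():
            return backend_compile(gm, example_inputs)

Because the guard is now unconditional at both compile sites, the previous conditional wrapper around the eager backward compile (torch._C._DisableAutocast if disable_amp else nullcontext) becomes redundant, which is why the last hunk removes it.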