Skip to content

Commit

Permalink
Hide PyTorch trace compilation warnings
Browse files Browse the repository at this point in the history
The test execution shows warnings about traces being potentially incorrect because the Python3 control flow is not completely recorded.
This includes conditions on the shape of the integration domain tensor.
Since the only arguments of the compiled integration function are the integrand and integration domain,
and the dimensionality of this integration domain is constant,
we can ignore the warnings.

After this change,
the two `get_jit_compiled_integrate` functions hide PyTorch trace compilation warnings with `warnings.filterwarnings`.
  • Loading branch information
FHof committed Oct 14, 2023
1 parent c6baf39 commit fb32846
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
22 changes: 12 additions & 10 deletions torchquad/integration/grid_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,16 @@ def compiled_integrate(fn, integration_domain):
# Torch requires explicit tracing with example inputs.
def do_compile(example_integrand):
import torch
import warnings

# The PyTorch trace compilation warnings contain many false
# positives, so we hide all trace compiler warnings
def trace_without_warnings(*args, **kwargs):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=torch.jit.TracerWarning
)
return torch.jit.trace(*args, **kwargs)

# Define traceable first and third steps
def step1(integration_domain):
Expand All @@ -229,7 +239,7 @@ def step3(function_values, hs, integration_domain):
)

# Trace the first step
step1 = torch.jit.trace(step1, (integration_domain,))
step1 = trace_without_warnings(step1, (integration_domain,))

# Get example input for the third step
grid_points, hs, n_per_dim = step1(integration_domain)
Expand All @@ -241,15 +251,7 @@ def step3(function_values, hs, integration_domain):
)

# Trace the third step
# Avoid the warnings about a .grad attribute access of a
# non-leaf Tensor
if hs.requires_grad:
hs = hs.detach()
hs.requires_grad = True
if function_values.requires_grad:
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(
step3 = trace_without_warnings(
step3, (function_values, hs, integration_domain)
)

Expand Down
23 changes: 16 additions & 7 deletions torchquad/integration/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,16 @@ def compiled_integrate(fn, integration_domain):
# Torch requires explicit tracing with example inputs.
def do_compile(example_integrand):
import torch
import warnings

# The PyTorch trace compilation warnings contain many false
# positives, so we hide all trace compiler warnings
def trace_without_warnings(*args, **kwargs):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=torch.jit.TracerWarning
)
return torch.jit.trace(*args, **kwargs)

# Define traceable first and third steps
def step1(integration_domain):
Expand All @@ -206,7 +216,9 @@ def step1(integration_domain):
step3 = self.calculate_result

# Trace the first step (which is non-deterministic)
step1 = torch.jit.trace(step1, (integration_domain,), check_trace=False)
step1 = trace_without_warnings(
step1, (integration_domain,), check_trace=False
)

# Get example input for the third step
sample_points = step1(integration_domain)
Expand All @@ -215,12 +227,9 @@ def step1(integration_domain):
)

# Trace the third step
if function_values.requires_grad:
# Avoid the warning about a .grad attribute access of a
# non-leaf Tensor
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(step3, (function_values, integration_domain))
step3 = trace_without_warnings(
step3, (function_values, integration_domain)
)

# Define a compiled integrate function
def compiled_integrate(fn, integration_domain):
Expand Down

0 comments on commit fb32846

Please sign in to comment.