From 542477882be015c2e41a3bb82494c5fd0deca982 Mon Sep 17 00:00:00 2001 From: Dan Barzilay Date: Wed, 4 Oct 2023 00:30:59 +0300 Subject: [PATCH 1/3] grid_integrator: Allow passing custom arguments to integrand function. This is to avoid invoking a lambda wrapper function just expose an already existing functionality. --- torchquad/integration/grid_integrator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchquad/integration/grid_integrator.py b/torchquad/integration/grid_integrator.py index 2243399d..1e5c74b6 100644 --- a/torchquad/integration/grid_integrator.py +++ b/torchquad/integration/grid_integrator.py @@ -28,7 +28,7 @@ def f(integration_domain, N, requires_grad=False, backend=None): def _weights(self, N, dim, backend, requires_grad=False): return None - def integrate(self, fn, dim, N, integration_domain, backend): + def integrate(self, fn, dim, N, integration_domain, backend, args=None): """Integrate the passed function on the passed domain using a Composite Newton Cotes rule. The argument meanings are explained in the sub-classes. @@ -47,7 +47,7 @@ def integrate(self, fn, dim, N, integration_domain, backend): logger.debug("Evaluating integrand on the grid.") function_values, num_points = self.evaluate_integrand( - fn, grid_points, weights=self._weights(n_per_dim, dim, backend) + fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args ) self._nr_of_fevals = num_points From 05f4f2e133678e0280adffd39d44dc2e0c07cb14 Mon Sep 17 00:00:00 2001 From: Dan Barzilay Date: Wed, 4 Oct 2023 00:37:12 +0300 Subject: [PATCH 2/3] Make sure to pass args in all calls to evaluate_integral --- torchquad/integration/grid_integrator.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchquad/integration/grid_integrator.py b/torchquad/integration/grid_integrator.py index 1e5c74b6..c5d3967f 100644 --- a/torchquad/integration/grid_integrator.py +++ b/torchquad/integration/grid_integrator.py @@ -139,7 +139,7 @@ def _adjust_N(dim, N): return N def get_jit_compiled_integrate( - self, dim, N=None, integration_domain=None, backend=None + self, dim, N=None, integration_domain=None, backend=None, args=None ): """Create an integrate function where the performance-relevant steps except the integrand evaluation are JIT compiled. Use this method only if the integrand cannot be compiled. @@ -197,7 +197,7 @@ def get_jit_compiled_integrate( def compiled_integrate(fn, integration_domain): grid_points, hs, n_per_dim = jit_calculate_grid(N, integration_domain) function_values, _ = self.evaluate_integrand( - fn, grid_points, weights=self._weights(n_per_dim, dim, backend) + fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args ) return jit_calculate_result( function_values, dim, int(n_per_dim), hs, integration_domain @@ -238,6 +238,7 @@ def step3(function_values, hs, integration_domain): example_integrand, grid_points, weights=self._weights(n_per_dim, dim, backend), + args=args, ) # Trace the third step @@ -257,7 +258,7 @@ def step3(function_values, hs, integration_domain): def compiled_integrate(fn, integration_domain): grid_points, hs, _ = step1(integration_domain) function_values, _ = self.evaluate_integrand( - fn, grid_points, weights=self._weights(n_per_dim, dim, backend) + fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args ) result = step3(function_values, hs, integration_domain) return result From 5242965bb2d7f9487a24c3ed714bb521472a471a Mon Sep 17 00:00:00 2001 From: Dan Barzilay Date: Wed, 4 Oct 2023 00:40:17 +0300 Subject: [PATCH 3/3] Include new args in comment --- torchquad/integration/grid_integrator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchquad/integration/grid_integrator.py b/torchquad/integration/grid_integrator.py index c5d3967f..72350371 100644 --- a/torchquad/integration/grid_integrator.py +++ b/torchquad/integration/grid_integrator.py @@ -151,6 +151,7 @@ def get_jit_compiled_integrate( N (int, optional): Total number of sample points to use for the integration. See the integrate method documentation for more details. integration_domain (list or backend tensor, optional): Integration domain, e.g. [[-1,1],[0,1]]. Defaults to [-1,1]^dim. It can also determine the numerical backend. backend (string, optional): Numerical backend. Defaults to integration_domain's backend if it is a tensor and otherwise to the backend from the latest call to set_up_backend or "torch" for backwards compatibility. + args (list or tuple, optional): Any arguments required by the function. Defaults to None. Returns: function(fn, integration_domain): JIT compiled integrate function where all parameters except the integrand and domain are fixed