Replies: 4 comments 4 replies
-
Hi @tvercaut. We haven't written any convenience API for the sparse solvers so far; we mostly use them under the hood for our optimizers. I'll try to find time to write a short example in the next few days. You don't really need any of the linearization stuff, but you need to construct the right structures to pass to the solver. Also, the comment about `reset` is there because our optimizers call it internally.
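In the meantime, for concreteness: my understanding is that the "right structures" are essentially a CSR description of `A` (row pointers, column indices, and a batch of values), which the examples further down build with scipy. A minimal sketch of those pieces (the names here are only illustrative):

```python
import numpy as np
import scipy.sparse

# Sketch only: the sparse solvers consume a CSR description of A.
A_csr = scipy.sparse.random(100, 100, density=0.1, format="csr", dtype=np.float64)

row_ptr = A_csr.indptr    # shape (num_rows + 1,): start offset of each row
col_ind = A_csr.indices   # shape (nnz,): column index of each stored entry
values = A_csr.data       # shape (nnz,): batched as (batch_size, nnz) in the examples below
```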
-
One quick comment, just in case you haven't seen this yet. Our solvers work such that, given `A` and `b`, they solve the normal-equations system `A^T A x = A^T b` (optionally with damping) rather than `Ax = b` directly. The underlying function that solves the linear system can be imported directly; the examples below show the relevant imports.
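To make that convention concrete, here is a dense reference of the computation (my own sketch, no damping):

```python
import torch


def dense_normal_equations_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Dense reference: solve A^T A x = A^T b for a batch of problems (no damping).

    A: (batch, num_rows, num_cols), b: (batch, num_rows) -> x: (batch, num_cols).
    """
    AtA = A.transpose(-2, -1) @ A
    Atb = (A.transpose(-2, -1) @ b.unsqueeze(-1)).squeeze(-1)
    return torch.linalg.solve(AtA, Atb)
```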
-
I've also asked @maurimo, who developed all the sparse solvers and is also the author of BaSpaCho, to offer some guidance. But, in general, if your goal is to solve a linear system without it being part of an optimization problem, you don't really need to go through the full optimization/linearization machinery. Ideally, a lot of this would be wrapped up in a higher-level API, but we are bandwidth-limited at the moment. We actively welcome community contributions, so we would be more than happy to guide you if you want to take a stab at this.

```python
import numpy as np
import torch
from scipy.sparse import csr_matrix
from theseus.extlib.baspacho_solver import SymbolicDecomposition
from theseus.optimizer.autograd import BaspachoSolveFunction
from theseus.optimizer.linear_system import SparseStructure
from theseus.utils import random_sparse_binary_matrix
device = "cuda"
rng = torch.Generator(device=device)
num_rows = 100
num_cols = 100
fill = 0.1
batch_size = 16
A_skel = random_sparse_binary_matrix(
num_rows, num_cols, fill, min_entries_per_col=1, rng=rng
)
A_val = torch.rand(
(batch_size, A_skel.nnz), dtype=torch.double, device=device, generator=rng
)
b = torch.randn(
(batch_size, num_rows), dtype=torch.double, device=device, generator=rng
)
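# SparseStructure records only the CSR sparsity pattern of A (column indices + row
# pointers); the numeric values live separately in A_val so they can differ per batch element.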
structure = SparseStructure(
A_skel.indices,
A_skel.indptr,
num_rows,
num_cols,
dtype=np.float64,
)
# convert the CSR pattern to GPU tensors for the accelerated A^T A computation
A_row_ptr = torch.tensor(structure.row_ptr, dtype=torch.int64).to(device)
A_col_ind = torch.tensor(structure.col_ind, dtype=torch.int64).to(device)
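# BaSpaCho works on a block (parameter-wise) structure: var_dims lists the size of each
# parameter block (the sizes must sum to num_cols) and var_start_cols their starting columns.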
var_dims = [10, 20, 10, 20, 10, 20, 10]
var_start_cols = np.cumsum([0, *var_dims[:-1]])
# compute block-structure of AtA.
At_mock = structure.mock_csc_transpose()
num_vars = len(var_start_cols)
to_blocks = csr_matrix(
(
np.ones(num_cols),
np.arange(num_cols),
[*var_start_cols, num_cols],
),
(num_vars, num_cols),
)
block_At_mock = to_blocks @ At_mock
block_AtA_mock = (block_At_mock @ block_At_mock.T).tocsr()
block_AtA_mock.sort_indices()
param_size = torch.tensor(var_dims, dtype=torch.int64)
block_struct_ptrs = torch.tensor(block_AtA_mock.indptr, dtype=torch.int64)
block_struct_inds = torch.tensor(block_AtA_mock.indices, dtype=torch.int64)
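# The symbolic decomposition depends only on the block sparsity pattern of AtA, so it can
# be built once and reused while the numeric values (A_val, b) change.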
symbolic_decomposition = SymbolicDecomposition(
param_size, block_struct_ptrs, block_struct_inds, device
)
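# (alpha, beta) parameterize the damping applied to AtA before solving; zeros should give
# an undamped solve.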
alpha = torch.rand(batch_size, device=device, dtype=torch.double, generator=rng)
beta = torch.rand(batch_size, device=device, dtype=torch.double, generator=rng)
x = BaspachoSolveFunction.apply(
A_val,
b,
structure,
A_row_ptr,
A_col_ind,
symbolic_decomposition,
(alpha, beta),
)
print(x)
```
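To sanity-check the result, I densified `A` and compared against a dense solve. The damping convention used below, `AtA_damped = AtA + alpha * diag(AtA) + beta * I`, is my assumption about what `(alpha, beta)` mean and is worth double-checking against test_baspacho_sparse_backward.py:

```python
# Sanity check against a dense solve; reuses the tensors defined above.
# Assumed damping convention: AtA_damped = AtA + alpha * diag(AtA) + beta * I.
A_dense = torch.zeros(batch_size, num_rows, num_cols, dtype=torch.double, device=device)
row_of_entry = torch.repeat_interleave(
    torch.arange(num_rows, device=device), A_row_ptr[1:] - A_row_ptr[:-1]
)
A_dense[:, row_of_entry, A_col_ind] = A_val

AtA = A_dense.transpose(-2, -1) @ A_dense
Atb = (A_dense.transpose(-2, -1) @ b.unsqueeze(-1)).squeeze(-1)
eye = torch.eye(num_cols, dtype=torch.double, device=device)
AtA_damped = AtA + alpha.view(-1, 1, 1) * AtA * eye + beta.view(-1, 1, 1) * eye
x_ref = torch.linalg.solve(AtA_damped, Atb)
print((x - x_ref).abs().max())  # should be ~0 if the assumed convention is right
```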
-
@tvercaut Here is an example for (the non-differentiable) `CusolverLUSolver`:

```python
import numpy as np
# explicit submodule imports (a bare `import scipy` may not expose scipy.io / scipy.sparse.linalg)
import scipy.io
import scipy.sparse
import scipy.sparse.linalg
import torch
from theseus.extlib.cusolver_lu_solver import CusolverLUSolver
from theseus.utils import Timer
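# raefsky4 / cfd2 are matrices from the SuiteSparse Matrix Collection; adjust the path
# below to wherever the .mtx files were downloaded.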
matrix = "raefsky4" # "cfd2"
A_np_coo = scipy.io.mmread(f"sparse/{matrix}/{matrix}.mtx")
A_np_csr = scipy.sparse.csr_matrix(A_np_coo)
b_np = np.random.randn(A_np_coo.shape[1])
batch_size = 8
timer_cpu = Timer("cpu")
with timer_cpu:
for _ in range(batch_size):
scipy.sparse.linalg.spsolve(A_np_csr, b_np)
print(timer_cpu.elapsed_time)
A_row_ptr = torch.tensor(A_np_csr.indptr).cuda()
A_col_ind = torch.tensor(A_np_csr.indices).cuda()
A_val = torch.tensor(A_np_csr.data).cuda().repeat(batch_size, 1)
A_num_rows = A_row_ptr.size(0) - 1
A_num_cols = A_num_rows
b = torch.tensor(b_np).cuda().repeat(batch_size, 1)
timer_gpu = Timer("cuda")
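# CusolverLUSolver.solve overwrites its argument in place, so solve on a copy of b
# rather than on b itself.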
x = b.clone()
with timer_gpu:
slv = CusolverLUSolver(batch_size, A_num_cols, A_row_ptr, A_col_ind)
slv.factor(A_val)
slv.solve(x)
print(timer_gpu.elapsed_time)
b_sol = A_np_csr @ x[0].cpu().numpy()
print(np.linalg.norm(b_np - b_sol))
```
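If your starting point is a PyTorch sparse CSR tensor rather than an `.mtx` file (as in the original question), the same ingredients can be pulled straight out of the tensor. A sketch with placeholder inputs (`A_torch`, `b_torch`); the `int32` casts mirror the scipy indices used above and may need adjusting:

```python
import torch
from theseus.extlib.cusolver_lu_solver import CusolverLUSolver

# Placeholder inputs: a diagonally dominant (hence nonsingular) sparse CSR matrix
# and a dense right-hand side, both on the GPU.
n = 64
A_dense = torch.rand(n, n, dtype=torch.double, device="cuda")
A_dense[A_dense < 0.7] = 0.0
A_dense += n * torch.eye(n, dtype=torch.double, device="cuda")
A_torch = A_dense.to_sparse_csr()
b_torch = torch.rand(n, dtype=torch.double, device="cuda")

batch_size = 1
A_row_ptr = A_torch.crow_indices().to(torch.int32)  # scipy's indptr above is int32
A_col_ind = A_torch.col_indices().to(torch.int32)
A_val = A_torch.values().repeat(batch_size, 1)
b = b_torch.repeat(batch_size, 1)

slv = CusolverLUSolver(batch_size, n, A_row_ptr, A_col_ind)
slv.factor(A_val)
x = b.clone()
slv.solve(x)
print(torch.linalg.norm(A_dense @ x[0] - b_torch))  # residual check
```

Note that LU does not exploit symmetry, so for a matrix that is known to be SPD the Cholesky-based paths (CHOLMOD or BaSpaCho) are likely a better fit.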
-
PyTorch provides some out-of-the-box support for sparse matrices:
https://pytorch.org/docs/stable/sparse.html
However, as discussed for example in pytorch/pytorch#69538, there is limited support for linear algebra operations on these sparse tensors.
I read in the Theseus README that Theseus provides sparse linear solvers (CHOLMOD, LU, BaSpaCho) with GPU support. However, their usage is not clear to me. I couldn't find a simple tutorial for them, but maybe I missed it.
Suppose I have a sparse CSR PyTorch tensor `A` (which I know is SPD) and a dense PyTorch vector `b`, both on the GPU. Is there a simple way of using, say, BaSpaCho to solve for `Ax = b`?
Looking at test_baspacho_sparse_backward.py, I tried things along the lines of the below, but it didn't work directly and it feels more complicated than it ought to be.
Are there any convenience wrappers that I missed and would allow to switch easily between sparse solvers?