import torch
import warnings

class BiCGSTAB():
    """
    A PyTorch implementation of BiCGSTAB (BCGSTAB), a stabilized version
    of the BiCG method, first published by Van der Vorst,
    for solving the system ``Ax = b``.

    Example:
        solver = BiCGSTAB(Ax_gen)
        solver.solve(b, x=initial_x, tol=1e-10, atol=1e-16)
    """
    def __init__(self, Ax_gen, device='cuda'):
        """
        Ax_gen: A function that takes a 1-D tensor x and outputs Ax.
        Note: This structure is followed because it may not be
        computationally efficient to compute A explicitly.
        """
        self.Ax_gen = Ax_gen
        self.device = device

    def init_params(self, b, x=None, nsteps=None, tol=1e-10, atol=1e-16):
        """
        b: The R.H.S. of the system, a 1-D tensor
        x: Initial guess for the solution (defaults to zeros)
        nsteps: Maximum number of iterations (defaults to the size of b)
        tol: Relative tolerance; converged if ||r||^2 < tol * ||b||^2
        atol: Absolute tolerance; converged if ||r||^2 < atol
        """
        self.b = b.clone().detach()
        self.x = torch.zeros(b.shape[0], device=self.device) if x is None else x
        self.residual_tol = tol * torch.vdot(self.b, self.b).real
        self.atol = torch.tensor(atol, device=self.device)
        self.nsteps = b.shape[0] if nsteps is None else nsteps
        self.status, self.r = self.check_convergence(self.x)
        self.rho = torch.tensor(1, device=self.device)
        self.alpha = torch.tensor(1, device=self.device)
        self.omega = torch.tensor(1, device=self.device)
        self.v = torch.zeros(b.shape[0], device=self.device)
        self.p = torch.zeros(b.shape[0], device=self.device)
        self.r_hat = self.r.clone().detach()

    def check_convergence(self, x):
        r = self.b - self.Ax_gen(x)  # residual r = b - Ax
        rdotr = torch.vdot(r, r).real
        if rdotr < self.residual_tol or rdotr < self.atol:
            return True, r
        else:
            return False, r

    def step(self):
        rho = torch.dot(self.r, self.r_hat)                       # rho_i <- <r_hat, r_{i-1}>
        beta = (rho/self.rho)*(self.alpha/self.omega)             # beta <- (rho_i/rho_{i-1}) * (alpha/omega_{i-1})
        self.rho = rho                                            # rho_{i-1} <- rho_i
        self.p = self.r + beta*(self.p - self.omega*self.v)       # p_i <- r_{i-1} + beta*(p_{i-1} - omega_{i-1}*v_{i-1})
        self.v = self.Ax_gen(self.p)                              # v_i <- A p_i
        self.alpha = self.rho/torch.dot(self.r_hat, self.v)       # alpha <- rho_i / <r_hat, v_i>
        s = self.r - self.alpha*self.v                            # s <- r_{i-1} - alpha*v_i
        t = self.Ax_gen(s)                                        # t <- A s
        self.omega = torch.dot(t, s)/torch.dot(t, t)              # omega_i <- <t, s>/<t, t>
        self.x = self.x + self.alpha*self.p + self.omega*s        # x_i <- x_{i-1} + alpha*p_i + omega_i*s
        status, res = self.check_convergence(self.x)
        if status:
            return True
        else:
            self.r = s - self.omega*t                             # r_i <- s - omega_i*t
            return False

    def solve(self, *args, **kwargs):
        """
        Method to find the solution.
        Returns the final value of x.
        """
        self.init_params(*args, **kwargs)
        if self.status:
            return self.x
        while self.nsteps:
            s = self.step()
            if s:
                return self.x
            if self.rho == 0:
                break
            self.nsteps -= 1
        warnings.warn('Convergence has failed :(')
        return self.x
import torch
from BiCGSTAB import BiCGSTAB
A = torch.randn(3,3, device='cuda')
x = torch.randn(3, device='cuda')
# Starting point
x_int = x + 0.01*torch.randn(3, device='cuda')
b = torch.matmul(A, x)
Ax_gen = lambda x: torch.matmul(A, x)
solver = BiCGSTAB(Ax_gen)
print("Original Solution:",x)
print("BiCGSTAB Solution:", solver.solve(b, nsteps=11, x=x_int, tol=1e-3))
Depending on the application, it might not be possible to compute A explicitly, even though the matrix-vector product with A can still be computed. Hence, Ax_gen is a function that returns the matrix-vector product Ax for a given x.
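For instance, here is a minimal matrix-free sketch; the tridiagonal Laplacian-style operator below is a hypothetical illustration, not part of the gist:

import torch
from BiCGSTAB import BiCGSTAB

# Hypothetical matrix-free operator: a 1-D Laplacian-style tridiagonal
# stencil applied without ever materializing the (n, n) matrix A.
def Ax_gen(x):
    Ax = 2.0 * x
    Ax[:-1] -= x[1:]   # superdiagonal contribution
    Ax[1:] -= x[:-1]   # subdiagonal contribution
    return Ax

b = torch.randn(100, device='cuda')
solver = BiCGSTAB(Ax_gen)
x = solver.solve(b, tol=1e-10)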
@bridgesign
Thank you for your answer. When I run the example you gave, I get an error as follows. Any ideas what I am doing wrong?
line 94, in <module>
Ax_gen = lambda x: torch.matmul(A, x)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument vec in method wrapper__mv)
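This error typically means that A (and hence b and x_int) was created on the CPU while the solver defaults to device='cuda'. A minimal sketch of two possible fixes, assuming the tensor names from the example above:

# Option 1: move the problem data to the GPU (assumes CUDA is available)
A = A.to('cuda')
b = b.to('cuda')
x_int = x_int.to('cuda')
solver = BiCGSTAB(lambda v: torch.matmul(A, v))  # default device='cuda'

# Option 2: keep everything on the CPU instead
solver = BiCGSTAB(lambda v: torch.matmul(A, v), device='cpu')
print(solver.solve(b, x=x_int, tol=1e-3))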
@Nailemre Updated the comment. BiCGSTAB faces convergence issues and hence requires good initialization to work.
Thanks for sharing this. As per the description (step 1 here), it looks like
r = self.Ax_gen(x) - self.b (see https://gist.github.com/bridgesign/f421f69ad4a3858430e5e235bccde8c6#file-bicgstab-py-L50)
should be replaced by
r = self.b - self.Ax_gen(x)
@tvercaut Thanks! Updated the gist.
Thanks for your reply. I am trying to solve an Ax=b system where A has shape (4, 8) and b has shape (4, 1). I started with x of shape (8, 1), but it crashes in the solver,
here: self.p = self.r + beta * (self.p - self.omega * self.v),
because the size of x is 8 but the size of p is 4.
The error is:
Ax_gen = lambda x: torch.matmul(A, x)
RuntimeError: size mismatch, got 4, 4x8,4
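BiCGSTAB assumes a square operator (and 1-D tensors for b and x), so a (4, 8) system cannot be passed in directly. One possible workaround, sketched here under the assumption that the underdetermined system is consistent, is to solve the square (4, 4) system A Aᵀ y = b and recover the minimum-norm solution x = Aᵀ y:

import torch
from BiCGSTAB import BiCGSTAB

A = torch.randn(4, 8, device='cuda')
x_true = torch.randn(8, device='cuda')
b = torch.matmul(A, x_true)            # 1-D tensor of shape (4,)

# Square (4, 4) operator: y -> A A^T y
AAt_gen = lambda y: torch.matmul(A, torch.matmul(A.T, y))
solver = BiCGSTAB(AAt_gen)
y = solver.solve(b, tol=1e-10)
x = torch.matmul(A.T, y)               # minimum-norm solution of Ax = b
print(torch.matmul(A, x) - b)          # residual, should be near zero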
If helpful, on my side I eventually decided to go for a different implementation, as shown here:
https://github.com/cai4cai/torchsparsegradutils/blob/main/torchsparsegradutils/utils/bicgstab.py
How can I run this code? Can you give an example?
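A minimal CPU-only sketch, assuming the gist file above is saved locally as BiCGSTAB.py:

import torch
from BiCGSTAB import BiCGSTAB

A = torch.randn(3, 3)
x_true = torch.randn(3)
b = torch.matmul(A, x_true)

# Use device='cpu' if no GPU is available
solver = BiCGSTAB(lambda v: torch.matmul(A, v), device='cpu')
x = solver.solve(b, x=x_true + 0.01*torch.randn(3), tol=1e-8)
print("True x:  ", x_true)
print("Solved x:", x)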