Skip to content

Commit

Permalink
Fix for transpose operators
Browse files Browse the repository at this point in the history
  • Loading branch information
cgcgcg committed Dec 12, 2023
1 parent 34a0596 commit 0d31f86
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 15 deletions.
11 changes: 6 additions & 5 deletions base/PyNucleus_base/LinearOperator_{SCALAR}.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -651,14 +651,12 @@ cdef class {SCALAR_label}Transpose_Linear_Operator({SCALAR_label}LinearOperator)
cdef INDEX_t matvec(self,
{SCALAR}_t[::1] x,
{SCALAR}_t[::1] y) except -1:
self.A.matvecTrans(x, y)
return 0
return self.A.matvecTrans(x, y)

cdef INDEX_t matvec_no_overwrite(self,
{SCALAR}_t[::1] x,
{SCALAR}_t[::1] y) except -1:
self.A.matvecTrans_no_overwrite(x, y)
return 0
return self.A.matvecTrans_no_overwrite(x, y)

def isSparse(self):
return self.A.isSparse()
Expand All @@ -677,7 +675,10 @@ cdef class {SCALAR_label}Transpose_Linear_Operator({SCALAR_label}LinearOperator)
return Bcsr

def toarray(self):
return self.A.transpose().toarray()
try:
return self.A.transpose().toarray()
except AttributeError:
return np.ascontiguousarray(self.A.toarray().T)

def get_diagonal(self):
return np.array(self.A.diagonal, copy=False)
Expand Down
92 changes: 83 additions & 9 deletions base/PyNucleus_base/linear_operators.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1213,26 +1213,57 @@ cdef class sumMultiplyOperator(LinearOperator):
cdef:
INDEX_t i
LinearOperator op
int ret
op = self.ops[0]
op.matvec(x, y)
ret = op.matvec(x, y)
scaleScalar(y, self.coeffs[0])
for i in range(1, self.coeffs.shape[0]):
op = self.ops[i]
op.matvec(x, self.z)
ret = min(ret, op.matvec(x, self.z))
assign3(y, y, 1.0, self.z, self.coeffs[i])
return 0
return ret

cdef INDEX_t matvec_no_overwrite(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
cdef:
INDEX_t i
LinearOperator op = 0
int ret = 0
for i in range(self.coeffs.shape[0]):
op = self.ops[i]
ret = min(op.matvec(x, self.z), ret)
assign3(y, y, 1.0, self.z, self.coeffs[i])
return ret

cdef INDEX_t matvecTrans(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
cdef:
INDEX_t i
LinearOperator op
int ret
op = self.ops[0]
ret = op.matvecTrans(x, y)
scaleScalar(y, self.coeffs[0])
for i in range(1, self.coeffs.shape[0]):
op = self.ops[i]
ret = min(ret, op.matvecTrans(x, self.z))
assign3(y, y, 1.0, self.z, self.coeffs[i])
return ret

cdef INDEX_t matvecTrans_no_overwrite(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
cdef:
INDEX_t i
LinearOperator op
int ret = 0
for i in range(self.coeffs.shape[0]):
op = self.ops[i]
op.matvec(x, self.z)
ret = min(ret, op.matvecTrans(x, self.z))
assign3(y, y, 1.0, self.z, self.coeffs[i])
return 0
return ret

def toarray(self):
return sum([c*op.toarray() for c, op in zip(self.coeffs, self.ops)])
Expand Down Expand Up @@ -1436,8 +1467,34 @@ cdef class multiIntervalInterpolationOperator(LinearOperator):
interpolationOperator op
assert self.selected != -1
op = self.ops[self.selected]
op.matvec(x, y)
return 0
return op.matvec(x, y)

cdef INDEX_t matvec_no_overwrite(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
cdef:
interpolationOperator op
assert self.selected != -1
op = self.ops[self.selected]
return op.matvec_no_overwrite(x, y)

cdef INDEX_t matvecTrans(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
cdef:
interpolationOperator op
assert self.selected != -1
op = self.ops[self.selected]
return op.matvecTrans(x, y)

cdef INDEX_t matvecTrans_no_overwrite(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
cdef:
interpolationOperator op
assert self.selected != -1
op = self.ops[self.selected]
return op.matvecTrans_no_overwrite(x, y)

def toarray(self):
assert self.selected != -1
Expand Down Expand Up @@ -1521,8 +1578,25 @@ cdef class delayedConstructionOperator(LinearOperator):
REAL_t[::1] x,
REAL_t[::1] y) except -1:
self.assure_constructed()
self.A.matvec(x, y)
return 0
return self.A.matvec(x, y)

cdef INDEX_t matvec_no_overwrite(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
self.assure_constructed()
return self.A.matvec_no_overwrite(x, y)

cdef INDEX_t matvecTrans(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
self.assure_constructed()
return self.A.matvecTrans(x, y)

cdef INDEX_t matvecTrans_no_overwrite(self,
REAL_t[::1] x,
REAL_t[::1] y) except -1:
self.assure_constructed()
return self.A.matvecTrans_no_overwrite(x, y)

def toarray(self):
self.assure_constructed()
Expand Down
2 changes: 1 addition & 1 deletion base/PyNucleus_base/utilsFem.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def diff(self, d):
result[p.label] = (p.value, d[p.label])
elif isinstance(p.value, (int, INDEX, REAL, float)):
if not np.allclose(p.value, d[p.label],
rtol=rTol, atol=aTol):
rtol=rTol, atol=aTol) and not (np.isnan(p.value) and np.isnan(d[p.label])):
print(p.label, p.value, d[p.label], rTol, aTol, p.rTol, p.aTol)
result[p.label] = (p.value, d[p.label])
else:
Expand Down

0 comments on commit 0d31f86

Please sign in to comment.