Skip to content

Commit

Permalink
slightly reduce the tile artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
procyontao committed Jun 21, 2024
1 parent 531ae7e commit 8789c95
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 52 deletions.
17 changes: 10 additions & 7 deletions models/unet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,13 @@ def predict_one(args,one_tomo,output_file=None):
with mrcfile.open(one_tomo,permissive=True) as mrcData:
real_data = mrcData.data.astype(np.float32)*-1
voxelsize = mrcData.voxel_size
real_data = normalize(real_data,percentile=args.normalize_percentile)
data=np.expand_dims(real_data,axis=-1)
reform_ins = reform3D(data)
data = reform_ins.pad_and_crop_new(args.cube_size,args.crop_size)
data = normalize(real_data,percentile=args.normalize_percentile)
#data=np.expand_dims(real_data,axis=-1)
#reform_ins = reform3D(data)
#data = reform_ins.pad_and_crop_new(args.cube_size,args.crop_size)
#print(data.shape)
reform_ins = reform3D(data,args.cube_size,args.crop_size,9)
data = reform_ins.pad_and_crop()
#to_predict_data_shape:(n,cropsize,cropsize,cropsize,1)
#imposing wedge to every cubes
#data=wedge_imposing(data)
Expand All @@ -105,11 +108,11 @@ def predict_one(args,one_tomo,output_file=None):
in_data = data[i*N:(i+1)*N]
# in_data_gen = get_gen_single(in_data,args.batch_size,shuffle=False)
# in_data_tf = tf.data.Dataset.from_generator(in_data_gen,output_types=tf.float32)
outData[i*N:(i+1)*N] = model.predict(in_data,verbose=0)
outData[i*N:(i+1)*N] = model.predict(in_data,verbose=0).squeeze()
outData = outData[0:num_patches]

outData=reform_ins.restore_from_cubes_new(outData.reshape(outData.shape[0:-1]), args.cube_size, args.crop_size)

#outData=reform_ins.restore_from_cubes_new(outData.reshape(outData.shape[0:-1]), args.cube_size, args.crop_size)
outData=reform_ins.restore(outData)
outData = normalize(outData,percentile=args.normalize_percentile)
with mrcfile.new(output_file, overwrite=True) as output_mrc:
output_mrc.set_data(-outData)
Expand Down
2 changes: 1 addition & 1 deletion util/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def wiener1d(angpix, voltage, cs, defocus, snrfalloff, deconvstrength, highpassn
wiener = ctf/(ctf*ctf+1/snr)
return ctf, wiener

def tom_deconv_tomo(vol_file, out_file,angpix, voltage, cs, defocus, snrfalloff, deconvstrength, highpassnyquist, phaseflipped, phaseshift, ncpu=8):
def tom_deconv_tomo(vol_file, out_file, angpix, voltage, cs, defocus, snrfalloff, deconvstrength, highpassnyquist, phaseflipped, phaseshift, ncpu=8):
with mrcfile.open(vol_file, permissive=True) as f:
header_in = f.header
vol = f.data
Expand Down
143 changes: 99 additions & 44 deletions util/toTile.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,130 @@
import numpy as np

class reform3D:
def __init__(self,data3D):
self._sp = data3D.shape
def __init__(self,data3D, cubesize, cropsize, edge_depth):
self._sp = np.array(data3D.shape)
self._orig_data = data3D
self.cubesize = cubesize
self.cropsize = cropsize
self.edge_depth = edge_depth
self._sidelen = np.ceil((self._sp + edge_depth * 2)/self.cubesize).astype(int)
#self._sidelen = np.ceil((1.*self._sp)/self.cubesize).astype(int)

def pad_and_crop_new(self, cubesize=32, cropsize = 64):
self._cropsize = cropsize
sp = np.array(self._sp)
self._sidelen = sp//cubesize+1
padi = int((cropsize - cubesize)/2)
padsize = (self._sidelen*cubesize + padi - sp).astype(int)
data = np.pad(self._orig_data,((padi,padsize[0]),(padi,padsize[1]),(padi,padsize[2]),(0,0)),'symmetric')
def pad_and_crop(self):

#----------------------------|---------------------------
#| |
#| ---------------|------|--------edge---------------
#| | ------------|------|--------image_edge----------
#| | | | |
#| | | | |
#| | | | |
#| | | | |
#| -----cube------- |
#| | | |
#| | | |
#|----------crop--------------
#| | |
#| | |
#| | |
#| | |
pad_left = int((self.cropsize - self.cubesize)/2 + self.edge_depth)

# pad_right + pad_left + shape = sidelen * cube_zize + (crop_size-cube_size)
# pad_right + pad_left + shape >= (self._sp + edge_depth * 2) + (crop_size-cube_size)
pad_right = (self._sidelen * self.cubesize + (self.cropsize-self.cubesize) - pad_left - self._sp).astype(int)

data = np.pad(self._orig_data,((pad_left,pad_right[0]),(pad_left,pad_right[1]),(pad_left,pad_right[2])),'symmetric')
outdata=[]

for i in range(self._sidelen[0]):
for j in range(self._sidelen[1]):
for k in range(self._sidelen[2]):
cube = data[i*cubesize:i*cubesize+cropsize,
j*cubesize:j*cubesize+cropsize,
k*cubesize:k*cubesize+cropsize]
cube = data[i*self.cubesize:i*self.cubesize+self.cropsize,
j*self.cubesize:j*self.cubesize+self.cropsize,
k*self.cubesize:k*self.cubesize+self.cropsize]
outdata.append(cube)
outdata=np.array(outdata)
return outdata

def mask(self, x_len, y_len, z_len):
# need to consider should partisioned to len+1 so that left and right can add to one
p = 2*self.edge_depth#(self.cropsize - self.cubesize)
assert x_len > 2*p
assert y_len > 2*p
assert z_len > 2*p

array_x = np.minimum(np.arange(x_len+1), p) / p
array_x = array_x * np.flip(array_x)
array_x = array_x[np.newaxis,np.newaxis,:]

array_y = np.minimum(np.arange(y_len+1), p) / p
array_y = array_y * np.flip(array_y)
array_y = array_y[np.newaxis,:,np.newaxis]

array_z = np.minimum(np.arange(z_len+1), p) / p
array_z = array_z * np.flip(array_z)
array_z = array_z[:,np.newaxis,np.newaxis]

def pad_and_crop(self,cropsize=(64,64,64)):
self._cropsize = cropsize
sp = np.array(self._sp)
padsize = (sp//64+1)*64-sp
data = np.pad(self._orig_data,((0,padsize[0]),(0,padsize[1]),(0,padsize[2]),(0,0)),'edge')
self._sidelen = (padsize+sp)//64
out = array_x * array_y * array_z
return out[:x_len,:y_len,:z_len]

outdata=[]
for i in range(self._sidelen[0]):
for j in range(self._sidelen[1]):
for k in range(self._sidelen[2]):
cube = data[i*cropsize[0]:(i+1)*cropsize[0],j*cropsize[0]:(j+1)*cropsize[0],k*cropsize[0]:(k+1)*cropsize[0]]
outdata.append(cube)
outdata=np.array(outdata)
return outdata

def restore_from_cubes(self,cubes):
if len(cubes.shape)==5 and cubes.shape[-1]==1:
cubes = cubes.reshape(cubes.shape[0:-1])
new = np.zeros((self._sidelen[0]*64,self._sidelen[1]*64,self._sidelen[2]*64))
def restore(self,cubes):

start = (self.cropsize-self.cubesize)//2-self.edge_depth
end = (self.cropsize-self.cubesize)//2+self.cubesize+self.edge_depth
cubes = cubes[:,start:end,start:end,start:end]

restored = np.zeros((self._sidelen[0]*self.cubesize+self.edge_depth*2,
self._sidelen[1]*self.cubesize+self.edge_depth*2,
self._sidelen[2]*self.cubesize+self.edge_depth*2))
print("size restored", restored.shape)
mask_cube = self.mask(self.cubesize+self.edge_depth*2,self.cubesize+self.edge_depth*2,self.cubesize+self.edge_depth*2)
for i in range(self._sidelen[0]):
for j in range(self._sidelen[1]):
for k in range(self._sidelen[2]):
new[i*self._cropsize[0]:(i+1)*self._cropsize[0],j*self._cropsize[0]:(j+1)*self._cropsize[0],k*self._cropsize[0]:(k+1)*self._cropsize[0]] \
= cubes[i*self._sidelen[1]*self._sidelen[2]+j*self._sidelen[1]+k]
return new[0:self._sp[0],0:self._sp[1],0:self._sp[2]]
restored[i*self.cubesize:(i+1)*self.cubesize+self.edge_depth*2,
j*self.cubesize:(j+1)*self.cubesize+self.edge_depth*2,
k*self.cubesize:(k+1)*self.cubesize+self.edge_depth*2] \
+= cubes[i*self._sidelen[1]*self._sidelen[2]+j*self._sidelen[2]+k]\
*mask_cube


p =self.edge_depth*2 #int((self.cropsize-self.cubesize)/2+self.edge_depth)
restored = restored[p:p+self._sp[0],p:p+self._sp[1],p:p+self._sp[2]]
return restored

def restore_from_cubes_new(self,cubes, cubesize = 32, cropsize = 64):
if len(cubes.shape)==5 and cubes.shape[-1]==1:
cubes = cubes.reshape(cubes.shape[0:-1])
def mask_old(self):
from functools import reduce
c = self.cropsize
p = (self.cropsize - self.cubesize)
mask = np.ones((c, c, c))
f = lambda x: min(x, p)/p
for i in range(c):
for j in range(c):
for k in range(c):
d = [i, c-i, j, c-j, k, c-k]
d = map(f,d)
d = reduce(lambda a,b: a*b, d)
mask[i,j,k] = d
return mask
def restore_from_cubes(self,cubes):

new = np.zeros((self._sidelen[0]*cubesize,self._sidelen[1]*cubesize,self._sidelen[2]*cubesize))
start=int((cropsize-cubesize)/2)
end=int((cropsize+cubesize)/2)
new = np.zeros((self._sidelen[0]*self.cubesize,
self._sidelen[1]*self.cubesize,
self._sidelen[2]*self.cubesize))
start=int((self.cropsize-self.cubesize)/2)
end=int((self.cropsize+self.cubesize)/2)

for i in range(self._sidelen[0]):
for j in range(self._sidelen[1]):
for k in range(self._sidelen[2]):
new[i*cubesize:(i+1)*cubesize,j*cubesize:(j+1)*cubesize,k*cubesize:(k+1)*cubesize] \
= cubes[i*self._sidelen[1]*self._sidelen[2]+j*self._sidelen[2]+k][start:end,start:end,start:end]
new[i*self.cubesize:(i+1)*self.cubesize,
j*self.cubesize:(j+1)*self.cubesize,
k*self.cubesize:(k+1)*self.cubesize] \
= cubes[i*self._sidelen[1]*self._sidelen[2]+j*self._sidelen[2]+k][start:end,start:end,start:end]
return new[0:self._sp[0],0:self._sp[1],0:self._sp[2]]


def pad4times(self,time=4):
sp = np.array(self._orig_data.shape)
sp = np.expand_dims(sp,axis=0)
Expand All @@ -86,4 +141,4 @@ def cropback(self,padded):
return padded[:orig_sp[0][0]][:orig_sp[0][1]][:orig_sp[0][2]]

if __name__ == '__main__':
pass
pass

0 comments on commit 8789c95

Please sign in to comment.