Commit 8ee6efe

Merge pull request #785 from I2PC/asr_deepCenterAssign
Deep Center new programs
2 parents 1ff15df + 05d26f7 commit 8ee6efe

File tree

4 files changed: +985 -0 lines changed
Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
#!/usr/bin/env python3

import math
import numpy as np
from operator import itemgetter
import os
import sys
import xmippLib
from time import time
from scipy.ndimage import shift

if __name__ == "__main__":

    from xmippPyModules.deepLearningToolkitUtils.utils import checkIf_tf_keras_installed

    checkIf_tf_keras_installed()
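    # Positional arguments, in the order parsed below: experimental images
    # metadata (.xmd), output model prefix, sigma of the random-shift
    # augmentation (px), number of epochs, batch size, GPU id ('-1' for CPU),
    # number of ensemble models, learning rate, early-stopping patience, and
    # whether to start from a pretrained model (plus its path if 'yes').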
    fnXmdExp = sys.argv[1]
    fnModel = sys.argv[2]
    sigma = float(sys.argv[3])
    numEpochs = int(sys.argv[4])
    batch_size = int(sys.argv[5])
    gpuId = sys.argv[6]
    numModels = int(sys.argv[7])
    learning_rate = float(sys.argv[8])
    patience = int(sys.argv[9])
    pretrained = sys.argv[10]
    if pretrained == 'yes':
        fnPreModel = sys.argv[11]

    if not gpuId.startswith('-1'):
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuId

    from keras.callbacks import ModelCheckpoint
    from keras.models import Model
    from keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Flatten, Dense, concatenate, Activation
    from keras.optimizers import *
    import keras
    from keras.models import load_model
    import tensorflow as tf

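    # Keras/TensorFlow are imported only after CUDA_VISIBLE_DEVICES is set, so
    # the GPU selection above takes effect.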
    class DataGenerator(keras.utils.all_utils.Sequence):
        """Generates data for fnImgs"""

        def __init__(self, fnImgs, labels, sigma, batch_size, dim, readInMemory):
            """Initialization"""
            self.fnImgs = fnImgs
            self.labels = labels
            self.sigma = sigma
            self.batch_size = batch_size
            if self.batch_size > len(self.fnImgs):
                self.batch_size = len(self.fnImgs)
            self.dim = dim
            self.readInMemory = readInMemory
            self.on_epoch_end()

            # Read all data in memory
            if self.readInMemory:
                self.Xexp = np.zeros((len(self.labels), self.dim, self.dim, 1), dtype=np.float64)
                for i in range(len(self.labels)):
                    Iexp = np.reshape(xmippLib.Image(self.fnImgs[i]).getData(), (self.dim, self.dim, 1))
                    self.Xexp[i, ] = (Iexp - np.mean(Iexp)) / np.std(Iexp)

        def __len__(self):
            """Denotes the number of batches per epoch"""
            num_batches = int(np.floor((len(self.labels)) / self.batch_size))
            return num_batches

        def __getitem__(self, index):
            """Generate one batch of data"""
            # Generate indexes of the batch
            indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
            # Find list of IDs
            list_IDs_temp = []
            for i in range(int(self.batch_size)):
                list_IDs_temp.append(indexes[i])
            # Generate data
            Xexp, y = self.__data_generation(list_IDs_temp)

            return Xexp, y

        def on_epoch_end(self):
            """Updates indexes after each epoch"""
            self.indexes = [i for i in range(len(self.labels))]
            np.random.shuffle(self.indexes)

        def __data_generation(self, list_IDs_temp):
            """Generates data containing batch_size samples"""
            yvalues = np.array(itemgetter(*list_IDs_temp)(self.labels))

            # Functions to handle the data
            def get_image(fn_image):
                """Normalize image"""
                img = np.reshape(xmippLib.Image(fn_image).getData(), (self.dim, self.dim, 1))
                return (img - np.mean(img)) / np.std(img)

            def shift_image(img, shiftx, shifty):
                """Shifts image in X and Y"""
                return shift(img, (shiftx, shifty, 0), order=1, mode='wrap')

            if self.readInMemory:
                Iexp = list(itemgetter(*list_IDs_temp)(self.Xexp))
            else:
                fnIexp = list(itemgetter(*list_IDs_temp)(self.fnImgs))
                Iexp = list(map(get_image, fnIexp))
            # Data augmentation: Gaussian plus uniform random offsets, scaled by sigma
            rX = self.sigma * np.random.normal(0, 1, size=self.batch_size)
            rY = self.sigma * np.random.normal(0, 1, size=self.batch_size)
            rX = rX + self.sigma * np.random.uniform(-1, 1, size=self.batch_size)
            rY = rY + self.sigma * np.random.uniform(-1, 1, size=self.batch_size)
            # Shift each image a random number of pixels in each direction
            Xexp = np.array(list((map(shift_image, Iexp, rX, rY))))
            y = yvalues + np.vstack((rX, rY)).T
            return Xexp, y
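    # Note on the augmentation above: each image is shifted by a random
    # (rX, rY) and the same offset is added to its label, so the regression
    # target remains the particle's total displacement from the center.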
    def constructModel(Xdim):
        """CNN architecture"""
        inputLayer = Input(shape=(Xdim, Xdim, 1), name="input")

        L = Conv2D(8, (3, 3), padding='same')(inputLayer)
        L = BatchNormalization()(L)
        L = Activation(activation='relu')(L)
        L = MaxPooling2D()(L)

        L = Conv2D(16, (3, 3), padding='same')(L)
        L = BatchNormalization()(L)
        L = Activation(activation='relu')(L)
        L = MaxPooling2D()(L)

        L = Conv2D(32, (3, 3), padding='same')(L)
        L = BatchNormalization()(L)
        L = Activation(activation='relu')(L)
        L = MaxPooling2D()(L)

        L = Conv2D(64, (3, 3), padding='same')(L)
        L = BatchNormalization()(L)
        L = Activation(activation='relu')(L)
        L = MaxPooling2D()(L)

        L = Flatten()(L)

        L = Dense(2, name="output", activation="linear")(L)

        return Model(inputLayer, L)

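    # The architecture: four Conv(3x3)+BatchNorm+ReLU+MaxPool blocks with
    # 8, 16, 32 and 64 filters, so an Xdim x Xdim input is reduced to
    # (Xdim/16) x (Xdim/16) x 64 features before the linear two-unit head
    # that regresses (shiftX, shiftY).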
    def get_labels(fnImages):
        """Returns dimensions, image filenames and shift values from the input metadata"""
        Xdim, _, _, _, _ = xmippLib.MetaDataInfo(fnImages)
        mdExp = xmippLib.MetaData(fnImages)
        fnImgs = mdExp.getColumnValues(xmippLib.MDL_IMAGE)
        shiftX = mdExp.getColumnValues(xmippLib.MDL_SHIFT_X)
        shiftY = mdExp.getColumnValues(xmippLib.MDL_SHIFT_Y)
        labels = []
        for x, y in zip(shiftX, shiftY):
            labels.append(np.array((x, y)))
        return Xdim, fnImgs, labels

    Xdim, fnImgs, labels = get_labels(fnXmdExp)
    start_time = time()

    # Train-Validation sets
    if numModels == 1:
        lenTrain = int(len(fnImgs) * 0.8)
        lenVal = len(fnImgs) - lenTrain
    else:
        lenTrain = int(len(fnImgs) / 3)
        print('lenTrain', lenTrain, flush=True)
        lenVal = int(len(fnImgs) / 12)
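    # With a single model a standard 80/20 train/validation split is used; for
    # an ensemble, each member trains on an independent random third of the
    # data with one twelfth held out for validation, so the models see
    # different subsets.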
    for index in range(numModels):
        random_sample = np.random.choice(range(0, len(fnImgs)), size=lenTrain + lenVal, replace=False)
        if pretrained == 'yes':
            model = load_model(fnPreModel, compile=False)
        else:
            model = constructModel(Xdim)
        adam_opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        model.summary()

        # Compile with the Adam instance above so the user-supplied learning
        # rate takes effect
        model.compile(loss='mean_absolute_error', optimizer=adam_opt)
        save_best_model = ModelCheckpoint(fnModel + str(index) + ".h5", monitor='val_loss',
                                          save_best_only=True)
        patienceCallBack = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience)

        training_generator = DataGenerator([fnImgs[i] for i in random_sample[0:lenTrain]],
                                           [labels[i] for i in random_sample[0:lenTrain]],
                                           sigma, batch_size, Xdim, readInMemory=False)
        validation_generator = DataGenerator([fnImgs[i] for i in random_sample[lenTrain:lenTrain + lenVal]],
                                             [labels[i] for i in random_sample[lenTrain:lenTrain + lenVal]],
                                             sigma, batch_size, Xdim, readInMemory=False)

        history = model.fit_generator(generator=training_generator, epochs=numEpochs,
                                      validation_data=validation_generator, callbacks=[save_best_model, patienceCallBack])

    elapsed_time = time() - start_time
    print("Time in training model: %0.10f seconds." % elapsed_time)
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
#!/usr/bin/env python3

import numpy as np
import os
import sys
import xmippLib
from time import time

maxSize = 32

if __name__ == "__main__":
    from xmippPyModules.deepLearningToolkitUtils.utils import checkIf_tf_keras_installed

    checkIf_tf_keras_installed()
    fnXmdExp = sys.argv[1]
    gpuId = sys.argv[2]
    outputDir = sys.argv[3]
    fnXmdImages = sys.argv[4]
    fnModel = sys.argv[5]
    numModels = int(sys.argv[6])
    tolerance = int(sys.argv[7])
    maxModels = int(sys.argv[8])
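    # tolerance is the maximum allowed L1 distance (in pixels) between a
    # model's prediction and the ensemble average; maxModels is the maximum
    # number of outlier predictions that may be discarded per particle.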
    if not gpuId.startswith('-1'):
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuId

    import keras
    from keras.models import load_model

    class DataGenerator(keras.utils.all_utils.Sequence):
        """Generates data for fnImgs"""

        def __init__(self, fnImgs, maxSize, dim, readInMemory):
            """Initialization"""
            self.fnImgs = fnImgs
            self.maxSize = maxSize
            self.dim = dim
            self.readInMemory = readInMemory
            self.on_epoch_end()

            # Read all data in memory
            if self.readInMemory:
                self.Xexp = np.zeros((len(self.fnImgs), self.dim, self.dim, 1), dtype=np.float64)
                for i in range(len(self.fnImgs)):
                    Iexp = np.reshape(xmippLib.Image(self.fnImgs[i]).getData(), (self.dim, self.dim, 1))
                    self.Xexp[i, ] = (Iexp - np.mean(Iexp)) / np.std(Iexp)

        def __len__(self):
            """Denotes the number of batches per prediction"""
            num = len(self.fnImgs) // maxSize
            if len(self.fnImgs) % maxSize > 0:
                num = num + 1
            return num

        def __getitem__(self, index):
            """Generate one batch of data"""
            # Generate indexes of the batch
            indexes = self.indexes[index * maxSize:(index + 1) * maxSize]
            # Find list of IDs
            list_IDs_temp = []
            for i in range(len(indexes)):
                list_IDs_temp.append(indexes[i])

            # Generate data
            Xexp = self.__data_generation(list_IDs_temp)

            return Xexp

        def on_epoch_end(self):
            self.indexes = [i for i in range(len(self.fnImgs))]
        def getNumberOfBlocks(self):
            """Number of batches needed to cover all images"""
            self.st = len(self.fnImgs) // maxSize
            if len(self.fnImgs) % maxSize > 0:
                self.st = self.st + 1
            return self.st
        def __data_generation(self, list_IDs_temp):
            """Generates data containing batch_size samples"""
            # Initialization
            Xexp = np.zeros((len(list_IDs_temp), self.dim, self.dim, 1), dtype=np.float64)
            # Generate data
            for i, ID in enumerate(list_IDs_temp):
                # Read image
                if self.readInMemory:
                    Xexp[i, ] = self.Xexp[ID]
                else:
                    Iexp = np.reshape(xmippLib.Image(self.fnImgs[ID]).getData(), (self.dim, self.dim, 1))
                    Xexp[i, ] = (Iexp - np.mean(Iexp)) / np.std(Iexp)
            return Xexp
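    # Unlike the training generator, this one does not shuffle, so predictions
    # stay aligned with the metadata rows; images get the same normalization
    # as in training.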
    def produce_output(mdExp, Y, distance, fnImg):
        ID = 0
        for objId in mdExp:
            # Set predictions in mdExp
            shiftX, shiftY = Y[ID]
            mdExp.setValue(xmippLib.MDL_SHIFT_X, float(shiftX), objId)
            mdExp.setValue(xmippLib.MDL_SHIFT_Y, float(shiftY), objId)
            mdExp.setValue(xmippLib.MDL_IMAGE, fnImg[ID], objId)
            if distance[ID] > tolerance:
                mdExp.setValue(xmippLib.MDL_ENABLED, -1, objId)
            ID += 1
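    # Particles whose consensus residual exceeds the tolerance are kept in the
    # metadata but flagged with MDL_ENABLED = -1 so downstream steps can skip
    # them.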
    def average_of_shifts(predshift):
        """Consensus tool"""
        # Calculates the average shift for each particle
        av_shift = np.average(predshift, axis=0)
        distancesxy = np.abs(av_shift - predshift)
        # Minimum number of model predictions that must remain
        minModels = np.shape(predshift)[0] - maxModels
        # Calculates the L1 norm of the distances
        distances = np.sum(distancesxy, axis=1)
        max_distance = np.max(distances)
        # Model with the maximum distance to the average
        max_dif_model = np.argmax(distances)
        while (np.shape(predshift)[0] > minModels) and (max_distance > tolerance):
            # Deletes the prediction of max_dif_model and recalculates averages
            predshift = np.delete(predshift, max_dif_model, axis=0)
            av_shift = np.average(predshift, axis=0)
            distancesxy = np.abs(av_shift - predshift)
            distances = np.sum(distancesxy, axis=1)
            max_distance = np.max(distances)
            max_dif_model = np.argmax(distances)
        return np.append(av_shift, max_distance)
    def compute_shift_averages(predshift):
        """Calls consensus tool"""
        averages_mdistance = np.array(list(map(average_of_shifts, predshift)))
        average = averages_mdistance[:, 0:2]
        mdistance = averages_mdistance[:, 2]
        return average, mdistance
135+
Xdim, _, _, _, _ = xmippLib.MetaDataInfo(fnXmdExp)
136+
137+
mdExp = xmippLib.MetaData(fnXmdExp)
138+
fnImgs = mdExp.getColumnValues(xmippLib.MDL_IMAGE)
139+
140+
mdExpImages = xmippLib.MetaData(fnXmdImages)
141+
fnImages = mdExpImages.getColumnValues(xmippLib.MDL_IMAGE)
142+
143+
start_time = time()
144+
145+
predictions = np.zeros((len(fnImgs), numModels, 2))
146+
ShiftManager = DataGenerator(fnImgs, maxSize, Xdim, readInMemory=False)
147+
for index in range(numModels):
148+
ShiftModel = load_model(fnModel + str(index) + ".h5", compile=False)
149+
ShiftModel.compile(loss="mean_squared_error", optimizer='adam')
150+
predictions[:, index, :] = ShiftModel.predict_generator(ShiftManager, ShiftManager.getNumberOfBlocks())
151+
Y, distance = compute_shift_averages(predictions)
152+
produce_output(mdExp, Y, distance, fnImages)
153+
154+
mdExp.write(os.path.join(outputDir, "predict_results.xmd"))
155+
156+
    elapsed_time = time() - start_time
    print("Time in prediction: %0.10f seconds." % elapsed_time)

0 commit comments