|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +import math |
| 4 | +import numpy as np |
| 5 | +from operator import itemgetter |
| 6 | +import os |
| 7 | +import sys |
| 8 | +import xmippLib |
| 9 | +from time import time |
| 10 | +from scipy.ndimage import shift |
| 11 | + |
| 12 | +if __name__ == "__main__": |
| 13 | + |
    from xmippPyModules.deepLearningToolkitUtils.utils import checkIf_tf_keras_installed

    checkIf_tf_keras_installed()
    # Positional command-line arguments (presumably passed by the Xmipp protocol wrapper)
    fnXmdExp = sys.argv[1]            # experimental images metadata file (.xmd)
    fnModel = sys.argv[2]             # output model filename prefix (index + ".h5" is appended)
    sigma = float(sys.argv[3])        # scale of the random-shift data augmentation
    numEpochs = int(sys.argv[4])      # maximum training epochs
    batch_size = int(sys.argv[5])     # samples per training batch
    gpuId = sys.argv[6]               # CUDA device id; values starting with '-1' skip GPU selection
    numModels = int(sys.argv[7])      # number of models trained (ensemble size)
    learning_rate = float(sys.argv[8])
    patience = int(sys.argv[9])       # early-stopping patience, in epochs
    pretrained = sys.argv[10]         # 'yes' to start from a previously trained model
    if pretrained == 'yes':
        fnPreModel = sys.argv[11]     # path of the pretrained model to load

    # Select the GPU via environment variables BEFORE TensorFlow is imported,
    # otherwise the setting has no effect.
    if not gpuId.startswith('-1'):
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuId

    # Keras/TensorFlow imports are deliberately deferred until after the
    # CUDA_VISIBLE_DEVICES setup above.
    from keras.callbacks import ModelCheckpoint
    from keras.models import Model
    from keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Flatten, Dense, concatenate, Activation
    from keras.optimizers import *
    import keras
    from keras.models import load_model
    import tensorflow as tf
| 41 | + |
| 42 | + |
| 43 | + class DataGenerator(keras.utils.all_utils.Sequence): |
| 44 | + """Generates data for fnImgs""" |
| 45 | + |
| 46 | + def __init__(self, fnImgs, labels, sigma, batch_size, dim, readInMemory): |
| 47 | + """Initialization""" |
| 48 | + self.fnImgs = fnImgs |
| 49 | + self.labels = labels |
| 50 | + self.sigma = sigma |
| 51 | + self.batch_size = batch_size |
| 52 | + if self.batch_size > len(self.fnImgs): |
| 53 | + self.batch_size = len(self.fnImgs) |
| 54 | + self.dim = dim |
| 55 | + self.readInMemory = readInMemory |
| 56 | + self.on_epoch_end() |
| 57 | + |
| 58 | + # Read all data in memory |
| 59 | + if self.readInMemory: |
| 60 | + self.Xexp = np.zeros((len(self.labels), self.dim, self.dim, 1), dtype=np.float64) |
| 61 | + for i in range(len(self.labels)): |
| 62 | + Iexp = np.reshape(xmippLib.Image(self.fnImgs[i]).getData(), (self.dim, self.dim, 1)) |
| 63 | + self.Xexp[i,] = (Iexp - np.mean(Iexp)) / np.std(Iexp) |
| 64 | + |
| 65 | + def __len__(self): |
| 66 | + """Denotes the number of batches per epoch""" |
| 67 | + num_batches = int(np.floor((len(self.labels)) / self.batch_size)) |
| 68 | + return num_batches |
| 69 | + |
| 70 | + def __getitem__(self, index): |
| 71 | + """Generate one batch of data""" |
| 72 | + # Generate indexes of the batch |
| 73 | + indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size] |
| 74 | + # Find list of IDs |
| 75 | + list_IDs_temp = [] |
| 76 | + for i in range(int(self.batch_size)): |
| 77 | + list_IDs_temp.append(indexes[i]) |
| 78 | + # Generate data |
| 79 | + Xexp, y = self.__data_generation(list_IDs_temp) |
| 80 | + |
| 81 | + return Xexp, y |
| 82 | + |
| 83 | + def on_epoch_end(self): |
| 84 | + """Updates indexes after each epoch""" |
| 85 | + self.indexes = [i for i in range(len(self.labels))] |
| 86 | + np.random.shuffle(self.indexes) |
| 87 | + |
| 88 | + def __data_generation(self, list_IDs_temp): |
| 89 | + """Generates data containing batch_size samples""" |
| 90 | + yvalues = np.array(itemgetter(*list_IDs_temp)(self.labels)) |
| 91 | + |
| 92 | + # Functions to handle the data |
| 93 | + def get_image(fn_image): |
| 94 | + """Normalize image""" |
| 95 | + img = np.reshape(xmippLib.Image(fn_image).getData(), (self.dim, self.dim, 1)) |
| 96 | + return (img - np.mean(img)) / np.std(img) |
| 97 | + |
| 98 | + def shift_image(img, shiftx, shifty): |
| 99 | + """Shifts image in X and Y""" |
| 100 | + return shift(img, (shiftx, shifty, 0), order=1, mode='wrap') |
| 101 | + |
| 102 | + if self.readInMemory: |
| 103 | + Iexp = list(itemgetter(*list_IDs_temp)(self.Xexp)) |
| 104 | + else: |
| 105 | + fnIexp = list(itemgetter(*list_IDs_temp)(self.fnImgs)) |
| 106 | + Iexp = list(map(get_image, fnIexp)) |
| 107 | + # Data augmentation |
| 108 | + rX = self.sigma * np.random.normal(0, 1, size=self.batch_size) |
| 109 | + rY = self.sigma * np.random.normal(0, 1, size=self.batch_size) |
| 110 | + rX = rX + self.sigma * np.random.uniform(-1, 1, size=self.batch_size) |
| 111 | + rY = rY + self.sigma * np.random.uniform(-1, 1, size=self.batch_size) |
| 112 | + # Shift image a random amount of px in each direction |
| 113 | + Xexp = np.array(list((map(shift_image, Iexp, rX, rY)))) |
| 114 | + y = yvalues + np.vstack((rX, rY)).T |
| 115 | + return Xexp, y |
| 116 | + |
| 117 | + def constructModel(Xdim): |
| 118 | + """CNN architecture""" |
| 119 | + inputLayer = Input(shape=(Xdim, Xdim, 1), name="input") |
| 120 | + |
| 121 | + L = Conv2D(8, (3, 3), padding='same')(inputLayer) |
| 122 | + L = BatchNormalization()(L) |
| 123 | + L = Activation(activation='relu')(L) |
| 124 | + L = MaxPooling2D()(L) |
| 125 | + |
| 126 | + L = Conv2D(16, (3, 3), padding='same')(L) |
| 127 | + L = BatchNormalization()(L) |
| 128 | + L = Activation(activation='relu')(L) |
| 129 | + L = MaxPooling2D()(L) |
| 130 | + |
| 131 | + L = Conv2D(32, (3, 3), padding='same')(L) |
| 132 | + L = BatchNormalization()(L) |
| 133 | + L = Activation(activation='relu')(L) |
| 134 | + L = MaxPooling2D()(L) |
| 135 | + |
| 136 | + L = Conv2D(64, (3, 3), padding='same')(L) |
| 137 | + L = BatchNormalization()(L) |
| 138 | + L = Activation(activation='relu')(L) |
| 139 | + L = MaxPooling2D()(L) |
| 140 | + |
| 141 | + L = Flatten()(L) |
| 142 | + |
| 143 | + L = Dense(2, name="output", activation="linear")(L) |
| 144 | + |
| 145 | + return Model(inputLayer, L) |
| 146 | + |
| 147 | + |
| 148 | + def get_labels(fnImages): |
| 149 | + """Returns dimensions, images and shifts values from images files""" |
| 150 | + Xdim, _, _, _, _ = xmippLib.MetaDataInfo(fnImages) |
| 151 | + mdExp = xmippLib.MetaData(fnImages) |
| 152 | + fnImgs = mdExp.getColumnValues(xmippLib.MDL_IMAGE) |
| 153 | + shiftX = mdExp.getColumnValues(xmippLib.MDL_SHIFT_X) |
| 154 | + shiftY = mdExp.getColumnValues(xmippLib.MDL_SHIFT_Y) |
| 155 | + labels = [] |
| 156 | + for x, y in zip(shiftX, shiftY): |
| 157 | + labels.append(np.array((x, y))) |
| 158 | + return Xdim, fnImgs, labels |
| 159 | + |
| 160 | + Xdim, fnImgs, labels = get_labels(fnXmdExp) |
| 161 | + start_time = time() |
| 162 | + |
| 163 | + # Train-Validation sets |
| 164 | + if numModels == 1: |
| 165 | + lenTrain = int(len(fnImgs)*0.8) |
| 166 | + lenVal = len(fnImgs)-lenTrain |
| 167 | + else: |
| 168 | + lenTrain = int(len(fnImgs) / 3) |
| 169 | + print('lenTrain', lenTrain, flush=True) |
| 170 | + lenVal = int(len(fnImgs) / 12) |
| 171 | + |
| 172 | + for index in range(numModels): |
| 173 | + random_sample = np.random.choice(range(0, len(fnImgs)), size=lenTrain+lenVal, replace=False) |
| 174 | + if pretrained == 'yes': |
| 175 | + model = load_model(fnPreModel, compile=False) |
| 176 | + else: |
| 177 | + model = constructModel(Xdim) |
| 178 | + adam_opt = tf.keras.optimizers.Adam(lr=learning_rate) |
| 179 | + model.summary() |
| 180 | + |
| 181 | + model.compile(loss='mean_absolute_error', optimizer='adam') |
| 182 | + |
| 183 | + save_best_model = ModelCheckpoint(fnModel + str(index) + ".h5", monitor='val_loss', |
| 184 | + save_best_only=True) |
| 185 | + patienceCallBack = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience) |
| 186 | + |
| 187 | + training_generator = DataGenerator([fnImgs[i] for i in random_sample[0:lenTrain]], |
| 188 | + [labels[i] for i in random_sample[0:lenTrain]], |
| 189 | + sigma, batch_size, Xdim, readInMemory=False) |
| 190 | + validation_generator = DataGenerator([fnImgs[i] for i in random_sample[lenTrain:lenTrain + lenVal]], |
| 191 | + [labels[i] for i in random_sample[lenTrain:lenTrain + lenVal]], |
| 192 | + sigma, batch_size, Xdim, readInMemory=False) |
| 193 | + |
| 194 | + history = model.fit_generator(generator=training_generator, epochs=numEpochs, |
| 195 | + validation_data=validation_generator, callbacks=[save_best_model, patienceCallBack]) |
| 196 | + |
| 197 | + elapsed_time = time() - start_time |
| 198 | + print("Time in training model: %0.10f seconds." % elapsed_time) |
0 commit comments