From 38f9e8d811134ae19c1d87415c8b21b537c6eacc Mon Sep 17 00:00:00 2001 From: Srikanth Ronanki Date: Tue, 25 Apr 2017 19:18:18 +0900 Subject: [PATCH] Update data_utils.py --- data_utils.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/data_utils.py b/data_utils.py index 101a343..73a347f 100644 --- a/data_utils.py +++ b/data_utils.py @@ -1,4 +1,5 @@ import numpy as np +import random from random import sample ''' @@ -6,10 +7,14 @@ return tuple( (trainX, trainY), (testX,testY), (validX,validY) ) ''' -def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ): +def split_dataset(x, y, ratio = [0.7, 0.15, 0.15], shuffle=True): # number of examples data_len = len(x) lens = [ int(data_len*item) for item in ratio ] + + #shuffle data + if shuffle: + [x, y] = shuffle_data(x, y) trainX, trainY = x[:lens[0]], y[:lens[0]] testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]] @@ -17,7 +22,17 @@ def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ): return (trainX,trainY), (testX,testY), (validX,validY) +def shuffle_data(x, y): + # number of examples + data_len = len(x) + indices = [i for i in range(data_len)] + + random.seed(271638) + random.shuffle(indices) + return x[indices], y[indices] + + ''' generate batches from dataset yield (x_gen, y_gen)