Skip to content

Commit 014e12b

Browse files
committed
clean tsne module
1 parent 9f2b497 commit 014e12b

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

src/main.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ def main():
9595
"StudioGAN dose not support the evalutation protocol that uses the test dataset on imagenet, tiny imagenet, and custom datasets"
9696

9797
if train_config['distributed_data_parallel']:
98-
msg = "StudioGAN does not support image visualization, k_nearest_neighbor, interpolation, and frequency_analysis with DDP. " +\
98+
msg = "StudioGAN does not support image visualization, k_nearest_neighbor, interpolation, frequency, and tsne analysis with DDP. " +\
9999
"Please change DDP with a single GPU training or DataParallel instead."
100-
assert train_config['image_visualization'] + train_config['k_nearest_neighbor'] + \
101-
train_config['interpolation'] + train_config['frequency_analysis'] + train_config['tsne_analysis'] == 0, msg
100+
assert train_config['image_visualization'] + train_config['k_nearest_neighbor'] + train_config['interpolation'] +\
101+
train_config['frequency_analysis'] + train_config['tsne_analysis'] == 0, msg
102102

103103
hdf5_path_train = make_hdf5(model_config['data_processing'], train_config, mode="train") \
104104
if train_config['load_all_data_in_memory'] else None

src/utils/misc.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -418,15 +418,19 @@ def plot_tsne_scatter_plot(df, tsne_results, flag, run_name, logger):
418418

419419
df['tsne-2d-one'] = tsne_results[:,0]
420420
df['tsne-2d-two'] = tsne_results[:,1]
421+
# x="tsne-2d-one", y="tsne-2d-two",
421422
plt.figure(figsize=(16,10))
422423
sns.scatterplot(
423424
x="tsne-2d-one", y="tsne-2d-two",
424425
hue="labels",
425426
palette=sns.color_palette("hls", 10),
426427
data=df,
427428
legend="full",
428-
alpha=0.3
429-
)
429+
alpha=0.5
430+
).legend(fontsize = 15, loc ='upper right')
431+
plt.title("TSNE result of {flag} images".format(flag=flag), fontsize=25)
432+
plt.xlabel('', fontsize=7)
433+
plt.ylabel('', fontsize=7)
430434
plt.savefig(save_path)
431435
logger.info("Save image to {}".format(save_path))
432436

0 commit comments

Comments
 (0)