Skip to content

Commit 11f939b

Browse files
committed
Save log every step
1 parent 0dfa5db commit 11f939b

File tree

2 files changed: 8 additions and 7 deletions.

protein_tune_rl/protein_trainer/dpo_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,14 @@ def run(self, output_dir):
156156
batch = self.collator(raw_batch)
157157
# Perform one training step
158158
current_step = self._train_step(batch, current_step, batch_idx)
159+
# Log the step results
160+
log_df = self._log_step(log_df, output_dir, current_step)
159161

160162
# Checkpointing and evaluation
161163
if self._should_checkpoint(current_step, self.check_point_freq):
162164
# Persist the accumulated step log to CSV (rank 0 only)
163-
log_df = self._log_step(log_df, output_dir, current_step)
165+
if dist.get_rank() == 0:
166+
log_df.to_csv(f"{output_dir}/dpo_trainer_log.csv", index=False)
164167
dist.barrier()
165168
# Save model checkpoints
166169
if dist.get_rank() == 0 and self.config["trainer"].get(
@@ -217,5 +220,4 @@ def _log_step(self, log_df, output_dir, current_step):
217220
"avg_margin": self._last_avg_margin,
218221
}
219222
log_df = pd.concat([log_df, pd.DataFrame([step_data])], ignore_index=True)
220-
log_df.to_csv(f"{output_dir}/dpo_trainer_log.csv", index=False)
221223
return log_df

protein_tune_rl/protein_trainer/dro_trainer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ def run(self, output_dir):
181181

182182
# Perform one training step
183183
current_step = self._train_step(batch, current_step, batch_number)
184+
# Log the step results
185+
log_df = self._log_step(log_df, output_dir, current_step, batch_number)
184186

185187
if self._should_checkpoint(current_step, self.check_point_freq):
186-
# Log the step results
187-
log_df = self._log_step(
188-
log_df, output_dir, current_step, batch_number
189-
)
188+
if dist.get_rank() == 0:
189+
log_df.to_csv(f"{output_dir}/dro_trainer_log.csv", index=False)
190190
dist.barrier()
191191
self._maybe_save_models(output_dir, current_step)
192192
self._maybe_run_evaluation(output_dir, current_step)
@@ -234,7 +234,6 @@ def _log_step(self, log_df, output_dir, current_step, batch_number):
234234
}
235235
)
236236
log_df = pd.concat([log_df, step_log_df])
237-
log_df.to_csv(f"{output_dir}/dro_trainer_log.csv", index=False)
238237
return log_df
239238

240239
def _maybe_save_models(self, output_dir, current_step):

0 commit comments

Comments
 (0)