
Commit 8c540b5

Merge branch '25-fix-multi-gpu-online-evaluation' into 'main'

Resolve "Fix multi-GPU online evaluation"
Closes #25
See merge request optlm/protein_tune_rl!23

2 parents: ab9995b + 33fad7c

File tree: 5 files changed (+97, −45 lines)


protein_tune_rl/collator/dro_collator.py

Lines changed: 2 additions & 0 deletions
@@ -169,6 +169,7 @@ def __call__(self, batch):
 
         if self.eval:
             return {
+                "__row_idx__": batch["__row_idx__"],
                 "input_ids": tokenized_masked_prompts_with_completions,
                 "prompts": tokenized_masked_prompts,
                 "labels": input_mask,
@@ -178,6 +179,7 @@ def __call__(self, batch):
             }
 
         return {
+            "__row_idx__": batch["__row_idx__"],
             "input_ids": tokenized_masked_prompts_with_completions,
             "prompts": tokenized_masked_prompts,
             "labels": input_mask,

protein_tune_rl/dataset/dro_dataset.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@ def __init__(self, data_directory, chain, region, reward):
     def __getitem__(self, idx):
 
         return {
+            "__row_idx__": int(idx),
             "prompts": self.data[self.chain].iloc[idx],
             "completions": self.data[self.region].iloc[idx],
             "rewards": float(self.data[self.reward].iloc[idx]),
@@ -23,6 +24,7 @@ def __init__(self, data_directory, chain, region):
     def __getitem__(self, idx):
 
         return {
+            "__row_idx__": int(idx),
             "prompts": self.data[self.chain].iloc[idx],
             "completions": self.data[self.region].iloc[idx],
             "LC": self.data.LC.iloc[idx],

protein_tune_rl/dataset/infilling_dataset.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ def __len__(self):
     def __getitem__(self, idx):
 
         return {
+            "__row_idx__": int(idx),
             "prompts": self.data[self.chain].iloc[idx],
             "region": self.data[self.region].iloc[idx],
             "LC": self.data.LC.iloc[idx],

protein_tune_rl/protein_evaluator/iglm_evaluator.py

Lines changed: 18 additions & 0 deletions
@@ -246,8 +246,11 @@ def run_with_ground_truth(self, output_dir=None):
             'generated_sequences': [],
             'heavy_chains': [],
             'light_chains': [],
+            '__row_idx__': [],
         }
 
+        self._log_dataset_info()
+
         for batch_number, batch in enumerate(iter(self.dataloader)):
             self.policy.eval()
 
@@ -272,6 +275,19 @@ def run_with_ground_truth(self, output_dir=None):
         eval_df = self._create_evaluation_dataframe(results)
         return gather_dataframes(eval_df, device=self.device)
 
+    def _log_dataset_info(self):
+        dataloader = self.dataloader
+        ddp_enabled = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if ddp_enabled else 1
+        sampler = getattr(dataloader, "sampler", None)
+        samples_per_rank = len(sampler) if sampler else len(dataloader.dataset)
+        batches_per_rank = len(dataloader)
+        logger.info(
+            f"Eval: world_size={world_size}, batch_size=1, "
+            f"per_rank={samples_per_rank} samples/{batches_per_rank} batches, "
+            f"global_batches_per_epoch={batches_per_rank * world_size}"
+        )
+
     def _generate_sequences_if_needed(self, tokenized_batch):
         """Generate sequences if any metric requires generated sequences."""
         if not any(self.metric_use_generated):
@@ -405,6 +421,7 @@ def _collect_sample_results(
             + "[MASK]"
             + tokenized_batch["seq_post_mask"][0]
         )
+        results['__row_idx__'].append(int(tokenized_batch["__row_idx__"][0]))
 
     def _create_evaluation_dataframe(self, results):
         """Create DataFrame from collected results."""
@@ -413,6 +430,7 @@ def _create_evaluation_dataframe(self, results):
         eval_df['HC'] = results['heavy_chains']
         eval_df['LC'] = results['light_chains']
         eval_df['prompts'] = results['prompts']
+        eval_df['__row_idx__'] = results['__row_idx__']
 
         for idx, metric in enumerate(self.config['metric']):
             eval_df[str(metric['name'])] = [
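
The numbers _log_dataset_info reports follow from how DistributedSampler splits the eval set: with drop_last=False each rank receives ceil(N / world_size) samples, and with the evaluator's batch size of 1 that is also its batch count. A back-of-the-envelope sketch with made-up sizes, just to show the arithmetic:

import math

N, world_size = 1000, 4                        # hypothetical eval-set size and number of ranks
samples_per_rank = math.ceil(N / world_size)   # 250 (short ranks are padded, not truncated)
batches_per_rank = samples_per_rank            # batch_size=1 during evaluation
global_batches = batches_per_rank * world_size # 1000, the "global_batches_per_epoch" value logged
print(samples_per_rank, batches_per_rank, global_batches)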

protein_tune_rl/protein_trainer/dro_trainer.py

Lines changed: 74 additions & 45 deletions
@@ -161,6 +161,7 @@ def run_evaluation(self, output_dir, current_step):
         eval_df = self.evaluator.run_with_ground_truth()
 
         if dist.get_rank() == 0 and eval_df is not None:
+            eval_df = eval_df.sort_values("__row_idx__").reset_index(drop=True)
             eval_df.to_csv(
                 f"{output_dir}/evaluation_results_step_{current_step}.csv",
                 index=False,
@@ -173,67 +174,95 @@ def run(self, output_dir):
         """Run the DRO Trainer for the specified number of optimization steps."""
         log_df = pd.DataFrame()
 
-        logger.info(
-            f"Breaking down the training dataset into {len(self.dataloader)} batches."
-        )
+        self._log_dataset_info()
 
         current_step = 0
         while current_step < self.total_optimization_steps:
             for batch_number, batch in enumerate(iter(self.dataloader)):
-                self.value.train()
-                self.policy.train()
+                current_step = self._train_step(batch, current_step, batch_number)
+                self._log_step(log_df, output_dir, current_step, batch_number)
+                dist.barrier()
 
-                self.policy_optimizer.zero_grad()
-                self.value_optimizer.zero_grad()
+                if self._should_checkpoint(current_step):
+                    self._maybe_save_models(output_dir, current_step)
+                    self._maybe_run_evaluation(output_dir, current_step)
 
-                tokenized_batch = self.collator(batch)
+                if current_step >= self.total_optimization_steps:
+                    break
 
-                policy_loss, value_loss = self.model_optimizer.calculate_loss(
-                    tokenized_batch
-                )
+        self._final_save(output_dir)
+        return log_df
 
-                value_loss.backward()
-                policy_loss.backward()
+    def _log_dataset_info(self):
+        dl = self.dataloader
+        world = (
+            dist.get_world_size()
+            if dist.is_available() and dist.is_initialized()
+            else 1
+        )
+        sampler = getattr(dl, "sampler", None)
 
-                self.policy_optimizer.step()
-                self.value_optimizer.step()
+        per_rank_samples = len(sampler) if sampler is not None else len(dl.dataset)
+        per_rank_batches = len(dl)
 
-                current_step += 1
+        logger.info(
+            f"Per-rank: {per_rank_samples} samples → {per_rank_batches} batches "
+            f"(batch size={dl.batch_size}, drop_last={dl.drop_last}); "
+            f"Global: world_size={world}, effective batch size={dl.batch_size * world}, "
+            f"batches/epoch={per_rank_batches * world}."
+        )
 
-                logger.info(
-                    f"Step {current_step}, Batch {batch_number + 1}: "
-                    f"Policy Loss: {policy_loss.item():.4f}, "
-                    f"Value Loss: {value_loss.item():.4f}"
-                )
+    def _train_step(self, batch, current_step, batch_number):
+        """Perform a single training step on the provided batch."""
+        self.value.train()
+        self.policy.train()
 
-                if dist.get_rank() == 0:
-                    step_log_df = pd.DataFrame.from_dict(
-                        {
-                            "step": [current_step],
-                            "policy_loss": [policy_loss.item()],
-                            "value_loss": [value_loss.item()],
-                        }
-                    )
+        self.policy_optimizer.zero_grad()
+        self.value_optimizer.zero_grad()
 
-                    log_df = pd.concat([log_df, step_log_df])
-                    log_df.to_csv(f"{output_dir}/dro_trainer_log.csv", index=False)
-                dist.barrier()
+        tokenized_batch = self.collator(batch)
 
-                if (current_step % self.check_point_freq == 0) and (current_step > 0):
+        policy_loss, value_loss = self.model_optimizer.calculate_loss(tokenized_batch)
 
-                    if self.config["trainer"].get("save_models", True):
-                        if dist.get_rank() == 0:
-                            self.save_models(output_dir, current_step)
-                        dist.barrier()
+        value_loss.backward()
+        policy_loss.backward()
 
-                    # Run online evaluation if configured
-                    if self.config["trainer"].get("evaluate_during_training", False):
-                        self.run_evaluation(output_dir, current_step)
+        self.policy_optimizer.step()
+        self.value_optimizer.step()
 
-                if current_step >= self.total_optimization_steps:
-                    break
+        logger.info(
+            f"Step {current_step + 1}, Batch {batch_number + 1}: Policy Loss: {policy_loss.item():.4f}, Value Loss: {value_loss.item():.4f}"
+        )
 
-        # Final save after training completes
-        self.policy.module.save(output_dir / "models/final")
+        self._last_policy_loss = policy_loss
+        self._last_value_loss = value_loss
 
-        return log_df
+        return current_step + 1
+
+    def _log_step(self, log_df, output_dir, current_step, batch_number):
+        if dist.get_rank() == 0:
+            step_log_df = pd.DataFrame.from_dict(
+                {
+                    "step": [current_step],
+                    "policy_loss": [self._last_policy_loss.item()],
+                    "value_loss": [self._last_value_loss.item()],
+                }
+            )
+            log_df = pd.concat([log_df, step_log_df])
+            log_df.to_csv(f"{output_dir}/dro_trainer_log.csv", index=False)
+
+    def _should_checkpoint(self, current_step):
+        return (current_step % self.check_point_freq == 0) and (current_step > 0)
+
+    def _maybe_save_models(self, output_dir, current_step):
+        if self.config["trainer"].get("save_models", True):
+            if dist.get_rank() == 0:
+                self.save_models(output_dir, current_step)
+            dist.barrier()
+
+    def _maybe_run_evaluation(self, output_dir, current_step):
+        if self.config["trainer"].get("evaluate_during_training", False):
+            self.run_evaluation(output_dir, current_step)
+
+    def _final_save(self, output_dir):
+        self.policy.module.save(output_dir / "models/final")
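
The sort_values("__row_idx__") added to run_evaluation is what puts the multi-GPU results back into dataset order: each rank contributes rows in its own shard order, so after gather_dataframes merges them on rank 0 the frame is shuffled rank by rank. A minimal pandas sketch with invented per-rank results:

import pandas as pd

# Hypothetical per-rank evaluation frames, already carrying the __row_idx__ column.
rank0 = pd.DataFrame({"__row_idx__": [0, 2, 4], "score": [0.10, 0.30, 0.50]})
rank1 = pd.DataFrame({"__row_idx__": [1, 3, 5], "score": [0.20, 0.40, 0.60]})

# Stand-in for what the gather step hands back to rank 0.
eval_df = pd.concat([rank0, rank1])

# The reordering done in run_evaluation before writing the CSV.
eval_df = eval_df.sort_values("__row_idx__").reset_index(drop=True)
print(eval_df)  # rows 0..5 back in original dataset order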
