
Conversation

@kctezcan (Contributor) commented Sep 30, 2025

Description

This is an additional PR on top of a previous PR: #961

The previous one introduced a new function to embed cells for the targets. This PR instead reuses the existing embed_cells() function to embed the target tokens, in order to reduce duplicated code and avoid the two code paths drifting apart over time ("code rot").

I have tested both training and inference with this.
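For readers of the thread, a minimal sketch of the refactor in spirit; the function and argument names below are illustrative only and may differ from the actual code:

# Before this PR: a dedicated helper duplicated the cell-embedding logic for targets
# target_embeds = self.embed_target_cells(target_tokens)

# After this PR: the existing embed_cells() is reused for both source and target tokens
source_embeds = self.embed_cells(source_tokens)
target_embeds = self.embed_cells(target_tokens)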

Issue Number

Closes #941

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

time_win: tuple,
normalizer, # dataset
use_normalizer: str, # "source_normalizer" or "target_normalizer"
Contributor

Rename use_normalizer to channel_to_normalize. Even though the type and possible values are clearly documented, use_normalizer suggests a boolean value.
Another option is to rename normalizer to normaliser_dataset or normaliser_ds, so you can use normalizer instead of use_normalizer.
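For illustration, the two options would look roughly like this (the function name and surrounding signature are paraphrased from the snippet above, not the exact code):

import numpy as np

# Option 1: rename the selector so it no longer reads like a boolean flag
def get_tokens(
    times: np.ndarray,
    time_win: tuple,
    normalizer,                  # dataset
    channel_to_normalize: str,   # "source_normalizer" or "target_normalizer"
): ...

# Option 2: rename the dataset argument, freeing the name `normalizer` for the selector
def get_tokens(
    times: np.ndarray,
    time_win: tuple,
    normalizer_ds,               # dataset
    normalizer: str,             # "source_normalizer" or "target_normalizer"
): ...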

)
for stl_b in batch
]
)
Contributor

Use fewer lines; the current version looks more complex than it actually is. For example:

target_source_like_tokens_lens = torch.stack([
    torch.stack([
        torch.stack([
            s.target_source_like_tokens_lens[fstep]
            if len(s.target_source_like_tokens_lens[fstep]) > 0
            else torch.tensor([])
            for fstep in range(len(s.target_source_like_tokens_lens))
        ]) for s in stl_b
    ]) for stl_b in batch
])

If this was caused by ruff then just forget about this comment...

for ib, sb in enumerate(batch):
    for itype, s in enumerate(sb):
        for fstep in range(offsets.shape[0]):
            if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0:  # if not empty
Contributor

Replace if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0: with if any(target_source_like_tokens_lens[ib, itype, fstep]): for better efficiency.
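As a small, self-contained sanity check of the equivalence (assuming, as here, that the entries are non-negative token lengths):

import torch

lens = torch.tensor([0, 3, 0, 5])            # non-negative per-token lengths
empty = torch.zeros(4, dtype=torch.int64)

assert bool(lens.sum() != 0) == bool(lens.any()) == any(lens)      # all True
assert bool(empty.sum() != 0) == bool(empty.any()) == any(empty)   # all False

# any(...) / .any() can stop at the first nonzero element, whereas .sum()
# always reduces over the whole tensor. On CUDA tensors the built-in any()
# iterates element by element, so tensor.any() is usually preferable there.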

# batch sample list when non-empty
for fstep in range(len(self.target_source_like_tokens_cells)):
    if (
        torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
Contributor

Replace

if (
   torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
   > 0
):

with

if any(len(s) > 0 for s in self.target_source_like_tokens_cells[fstep]):

for slightly better efficiency.

Maybe you can also find a way to do the check in constant time, without computing len(s) for every element and without having to write multiple lines of code.
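One possible direction (purely a sketch; the attribute name and the construction site are hypothetical): record non-emptiness once when the per-fstep cell lists are built, so the later check becomes a constant-time lookup:

# wherever target_source_like_tokens_cells is populated:
self.fstep_has_cells = [
    any(len(s) > 0 for s in cells) for cells in self.target_source_like_tokens_cells
]

# later, the per-fstep check is O(1):
for fstep in range(len(self.target_source_like_tokens_cells)):
    if self.fstep_has_cells[fstep]:
        ...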

times: np.array,
time_win: tuple,
normalizer, # dataset
use_normalizer: str, # "source_normalizer" or "target_normalizer"
Contributor

Rename use_normalizer here as well, as suggested for tokeniser_forecast.py (see first comment).

tokens_target_det = tokens_target.detach() # explicitly detach as well
tokens_targets.append(tokens_target_det)

return_dict = {"preds_all": preds_all, "posteriors": posteriors}
Contributor

Initialize return_dict above the first if check on "encode_targets_latent".
Move the key accesses on return_dict to the end of the first if check on "encode_targets_latent".
Remove the second if check on "encode_targets_latent".
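A rough sketch of the suggested control flow (names and keys are paraphrased and may differ from the actual code):

return_dict = {"preds_all": preds_all, "posteriors": posteriors}

if encode_targets_latent:  # stands in for the actual config check
    # ... encode the targets, detach, append to tokens_targets ...
    return_dict["tokens_targets"] = tokens_targets  # key access moved into this branch

# the second check on "encode_targets_latent" is then no longer needed
return return_dict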

# # we don't append an empty tensor for the source
# tokens_all.append(torch.tensor([], dtype=self.dtype, device="cuda"))
# el
if source_tokens_lens.sum() != 0:
Contributor

Replace if source_tokens_lens.sum() != 0: with if source_tokens_lens.any(): for better efficiency

@github-project-automation github-project-automation bot moved this to In Progress in WeatherGen-dev Oct 15, 2025

Labels

None yet

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

encoding target variables in the latent space

2 participants