base: develop
Ktezcan/dev/iss941 encode targets sepfstep #1019
Conversation
…941_encode_targets_sepfstep
time_win: tuple,
normalizer,  # dataset
use_normalizer: str,  # "source_normalizer" or "target_normalizer"
Rename use_normalizer to channel_to_normalize. Even though the type and possible values are clearly documented, use_normalizer suggests a boolean value. Another option is to rename normalizer to normaliser_dataset or normaliser_ds, so you can use normalizer instead of use_normalizer.
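A minimal sketch of the first suggestion (the function name and body here are hypothetical; only the parameters come from the diff):

```python
import numpy as np

def tokenize_window(  # hypothetical name for the function in the diff
    times: np.ndarray,
    time_win: tuple,
    normalizer,  # dataset holding both normalizers
    channel_to_normalize: str,  # "source_normalizer" or "target_normalizer"
):
    # Look up the requested normalizer on the dataset by attribute name;
    # the renamed string parameter reads as a selector, not a boolean flag.
    norm = getattr(normalizer, channel_to_normalize)
    ...
```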
)
for stl_b in batch
]
)
Use fewer lines here; the current formatting makes this look more complex than it actually is:

target_source_like_tokens_lens = torch.stack([
    torch.stack([
        torch.stack([
            s.target_source_like_tokens_lens[fstep]
            if len(s.target_source_like_tokens_lens[fstep]) > 0
            else torch.tensor([])
            for fstep in range(len(s.target_source_like_tokens_lens))
        ])
        for s in stl_b
    ])
    for stl_b in batch
])

If this formatting was caused by ruff, then just forget about this comment...
for ib, sb in enumerate(batch):
    for itype, s in enumerate(sb):
        for fstep in range(offsets.shape[0]):
            if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0:  # if not empty
Replace
if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0:
with
if any(target_source_like_tokens_lens[ib, itype, fstep]):
for better efficiency.
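A hedged aside on this pattern (a minimal sketch; the tensor contents are invented): the two checks agree because lengths are non-negative, and the method form torch.Tensor.any() additionally avoids the Python-level iteration that the builtin any() performs.

```python
import torch

lens = torch.tensor([0, 0, 3, 0])  # hypothetical per-sample token lengths

# All three checks agree for non-negative lengths; with negative values,
# .sum() could be zero even though some entries are non-zero.
via_sum = bool(lens.sum() != 0)
via_builtin_any = any(lens)        # iterates element by element in Python
via_tensor_any = bool(lens.any())  # single reduction in C/ATen, usually fastest
assert via_sum == via_builtin_any == via_tensor_any
```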
# batch sample list when non-empty
for fstep in range(len(self.target_source_like_tokens_cells)):
    if (
        torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
Replace
if (
    torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
    > 0
):
with
if any(len(s) > 0 for s in self.target_source_like_tokens_cells[fstep]):
for slightly better efficiency. Ideally, find a way to do this check in constant time, without scanning every sample and without writing multiple lines of code (one option is sketched below).
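One way to get such a constant-time check (a hypothetical sketch with invented names, not the PR's actual class) is to track non-emptiness at append time:

```python
class TokenCells:
    """Toy container: per-forecast-step lists of per-sample cell lists."""

    def __init__(self, num_fsteps: int):
        self.cells = [[] for _ in range(num_fsteps)]
        self._non_empty = [False] * num_fsteps  # maintained incrementally

    def append(self, fstep: int, sample_cells: list) -> None:
        self.cells[fstep].append(sample_cells)
        if len(sample_cells) > 0:
            self._non_empty[fstep] = True

    def has_cells(self, fstep: int) -> bool:
        return self._non_empty[fstep]  # O(1), no per-sample scan
```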
times: np.array,
time_win: tuple,
normalizer,  # dataset
use_normalizer: str,  # "source_normalizer" or "target_normalizer"
Rename use_normalizer here as well, as you did in tokeniser_forecast.py (see the first comment).
tokens_target_det = tokens_target.detach()  # explicitly detach as well
tokens_targets.append(tokens_target_det)

return_dict = {"preds_all": preds_all, "posteriors": posteriors}
Initialize return_dict above the first if check on "encode_targets_latent". Move the key assignments on return_dict to the end of that first if block, and remove the second if check on "encode_targets_latent".
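A minimal sketch of the suggested control flow; the helper's name and signature, the dict membership test, and the "tokens_targets" key are assumptions for illustration, since only fragments of the surrounding code are visible in the diff:

```python
import torch

def forward_step(config, preds_all, posteriors, tokens_target, tokens_targets):
    # Initialize return_dict once, above the (now single) check.
    return_dict = {"preds_all": preds_all, "posteriors": posteriors}

    if "encode_targets_latent" in config:
        # Explicitly detach, as in the diff above.
        tokens_targets.append(tokens_target.detach())
        # Key assignment moved to the end of this block, which makes the
        # second check on "encode_targets_latent" unnecessary.
        return_dict["tokens_targets"] = tokens_targets

    return return_dict

# Toy usage with dummy values:
out = forward_step(
    config={"encode_targets_latent": True},
    preds_all=[],
    posteriors=[],
    tokens_target=torch.zeros(2, 3, requires_grad=True),
    tokens_targets=[],
)
```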
# # we don't append an empty tensor for the source
# tokens_all.append(torch.tensor([], dtype=self.dtype, device="cuda"))
# el
if source_tokens_lens.sum() != 0:
Replace
if source_tokens_lens.sum() != 0:
with
if source_tokens_lens.any():
for better efficiency (same reasoning as the any() suggestion above).
Description
This is an additional PR on top of a previous PR: #961
The previous one introduced a new function to embed cells for the targets. This PR uses the existing
embed_cells()
function to embed the target tokens instead. The purpose is to reduce duplicated code and prevent potential "code rot". I have tested both training and inference with this.
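To illustrate the deduplication idea (a toy sketch: embed_cells() is the only name taken from the PR; the class, dimensions, and data here are invented):

```python
import torch
import torch.nn as nn

class CellEmbedder(nn.Module):
    """Toy stand-in for the model; embed_cells() signature is assumed."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def embed_cells(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(tokens)

embedder = CellEmbedder(8, 16)
source_tokens = torch.randn(4, 8)
target_tokens = torch.randn(4, 8)

# One shared code path for both token streams, instead of a duplicated
# target-only embedding function that could drift out of sync.
source_embeds = embedder.embed_cells(source_tokens)
target_embeds = embedder.embed_cells(target_tokens)
```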
Issue Number
Closes #941
Checklist before asking for review
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60