few changes to make the device/datatype work in the composition model #502

Open
wants to merge 6 commits into base: example_polarizability_mol
54 changes: 28 additions & 26 deletions src/metatrain/experimental/nanopet/trainer.py
@@ -237,11 +237,12 @@ def train(
optimizer.load_state_dict(self.optimizer_state_dict)

# Create a scheduler:
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
factor=self.hypers["scheduler_factor"],
patience=self.hypers["scheduler_patience"],
)
def lr_lambda(step):
if step < 1000:
return step / 1000 # Linear warm-up
return 1.0 # Keep lr constant after warm-up

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
if self.scheduler_state_dict is not None:
# same as the optimizer, try to load the scheduler state dict
if not (model.module if is_distributed else model).has_new_targets:
@@ -313,6 +314,7 @@ def train(
train_loss_batch.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()

if is_distributed:
# sum the loss over all processes
@@ -425,27 +427,27 @@ def train(
rank=rank,
)

lr_scheduler.step(val_loss)
new_lr = lr_scheduler.get_last_lr()[0]
if new_lr != old_lr:
if new_lr < 1e-7:
logger.info("Learning rate is too small, stopping training")
break
else:
logger.info(f"Changing learning rate from {old_lr} to {new_lr}")
old_lr = new_lr
# load best model and optimizer state dict, re-initialize scheduler
(model.module if is_distributed else model).load_state_dict(
self.best_model_state_dict
)
optimizer.load_state_dict(self.best_optimizer_state_dict)
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
factor=self.hypers["scheduler_factor"],
patience=self.hypers["scheduler_patience"],
)
# lr_scheduler.step(val_loss)
# new_lr = lr_scheduler.get_last_lr()[0]
# if new_lr != old_lr:
# if new_lr < 1e-7:
# logger.info("Learning rate is too small, stopping training")
# break
# else:
# logger.info(f"Changing learning rate from {old_lr} to {new_lr}")
# old_lr = new_lr
# # load best model and optimizer state dict, re-initialize scheduler
# (model.module if is_distributed else model).load_state_dict(
# self.best_model_state_dict
# )
# optimizer.load_state_dict(self.best_optimizer_state_dict)
# for param_group in optimizer.param_groups:
# param_group["lr"] = new_lr
# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
# optimizer,
# factor=self.hypers["scheduler_factor"],
# patience=self.hypers["scheduler_patience"],
# )

val_metric = get_selected_metric(
finalized_val_info, self.hypers["best_model_metric"]
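
For context, the new schedule can be exercised on its own as below. This is a minimal sketch only: the linear model, SGD optimizer, and random data are placeholders, while lr_lambda and the per-step lr_scheduler.step() call mirror the diff above.

import torch

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


def lr_lambda(step):
    if step < 1000:
        return step / 1000  # linear warm-up towards the base learning rate
    return 1.0  # keep the learning rate constant after warm-up


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step in range(2000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    lr_scheduler.step()  # stepped once per batch, as in the trainer

print(lr_scheduler.get_last_lr())  # [0.001] once the warm-up is over

Note that, unlike ReduceLROnPlateau, this LambdaLR schedule is stepped once per optimizer step rather than once per validation epoch, which is why the step() call moved into the batch loop above.
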
21 changes: 14 additions & 7 deletions src/metatrain/utils/additive/composition.py
@@ -133,10 +133,10 @@ def train_model(
dtype=self.dummy_buffer.dtype,
).reshape(-1, 1)
self.weights[target_key] = TensorMap(
keys=Labels.single(),
keys=Labels.single().to(device),
blocks=[
TensorBlock(
values=weights_tensor,
values=weights_tensor.to(device),
samples=Labels(
names=["center_type"],
values=torch.tensor(
@@ -145,9 +145,11 @@
),
components=self.dataset_info.targets[target_key]
.layout.block()
.to(device)
.components,
properties=self.dataset_info.targets[target_key]
.layout.block()
.to(device)
.properties,
)
],
@@ -273,17 +275,19 @@ def train_model(
if self.dataset_info.targets[target_key].per_atom:
# hack: metatensor.join doesn't work on single blocks;
# create TensorMaps, join, and then extract the joined block

joined_blocks = metatensor.torch.join(
[
TensorMap(
keys=Labels.single(),
blocks=[b],
keys=Labels.single().to(device),
blocks=[b.to(device)],
)
for b in block_list
],
axis="samples",
remove_tensor_name=True,
).block()

# This code doesn't work because mean_over_samples_block
# actually does a sum...
# weights_tensor = (
@@ -300,7 +304,10 @@
# .values
# )
weights_tensor = torch.empty(
len(self.atomic_types), len(metadata_block.properties)
len(self.atomic_types),
len(metadata_block.properties),
dtype=dtype,
device=device,
)
for i_type, atomic_type in enumerate(self.atomic_types):
mask = (
@@ -320,7 +327,7 @@
weights_tensor = weights_tensor.unsqueeze(1)
weight_blocks.append(
TensorBlock(
values=weights_tensor,
values=weights_tensor.to(device),
samples=Labels(
["center_type"],
values=torch.tensor(
@@ -331,7 +338,7 @@
c.to(device) for c in metadata_block.components
],
properties=metadata_block.properties.to(device),
)
).to(device)
)
self.weights[target_key] = TensorMap(
keys=self.dataset_info.targets[target_key].layout.keys.to(device),
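
The recurring fix in this file is to construct every Labels, TensorBlock, and TensorMap directly on the target device and dtype. Below is a minimal sketch of that pattern with the metatensor.torch API; the atomic types and the "energy" property name are placeholders, not taken from the code above.

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64

atomic_types = [1, 6, 8]  # placeholder atomic types
weights_tensor = torch.zeros(len(atomic_types), 1, dtype=dtype, device=device)

weights = TensorMap(
    keys=Labels.single().to(device),
    blocks=[
        TensorBlock(
            values=weights_tensor,  # already on the right device and dtype
            samples=Labels(
                names=["center_type"],
                values=torch.tensor(
                    atomic_types, dtype=torch.int32, device=device
                ).reshape(-1, 1),
            ),
            components=[],
            properties=Labels.range("energy", 1).to(device),
        )
    ],
)

Building each piece on the device up front avoids the device mismatches that otherwise surface when the TensorMap is assembled from keys, samples, and values living on different devices.
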
49 changes: 28 additions & 21 deletions src/metatrain/utils/data/readers/metatensor.py
@@ -7,9 +7,9 @@
from metatensor.torch.atomistic import System
from omegaconf import DictConfig

from .split import split_structurewise
from ..target_info import TargetInfo, get_energy_target_info, get_generic_target_info


logger = logging.getLogger(__name__)


@@ -47,22 +47,26 @@ def read_energy(target: DictConfig) -> Tuple[TensorMap, TargetInfo]:
# the actual metadata in the tensor maps
_check_tensor_map_metadata(tensor_map, target_info.layout)

selections = [
Labels(
names=["system"],
values=torch.tensor([[int(i)]]),
)
for i in torch.unique(
torch.concatenate(
[block.samples.column("system") for block in tensor_map.blocks()]
)
)
]
tensor_maps = metatensor.torch.split(tensor_map, "samples", selections)
# selections = [
# Labels(
# names=["system"],
# values=torch.tensor([[int(i)]]),
# )
# for i in torch.unique(
# torch.concatenate(
# [block.samples.column("system") for block in tensor_map.blocks()]
# )
# )
# ]

# TODO: replace with fast metatensor.torch.split once available
tensor_maps = split_structurewise(tensor_map)

return tensor_maps, target_info


def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
print("Now I am reading generic")
tensor_map = _wrapped_metatensor_read(target["read_from"])

for block in tensor_map.blocks():
@@ -76,14 +80,17 @@ def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:
# actual properties of the tensor maps
target_info.layout = _empty_tensor_map_like(tensor_map)

selections = [
Labels(
names=["system"],
values=torch.tensor([[int(i)]]),
)
for i in torch.unique(tensor_map.block(0).samples.column("system"))
]
tensor_maps = metatensor.torch.split(tensor_map, "samples", selections)
# selections = [
# Labels(
# names=["system"],
# values=torch.tensor([[int(i)]]),
# )
# for i in torch.unique(tensor_map.block(0).samples.column("system"))
# ]

# TODO: replace with fast metatensor.torch.split once available
tensor_maps = split_structurewise(tensor_map)

return tensor_maps, target_info


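
A small usage sketch of the new reader path: build a toy single-block TensorMap with a sorted "system" sample column and split it per system. The import path follows the new module added in this PR; all names and values below are placeholders.

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap

from metatrain.utils.data.readers.split import split_structurewise

samples = Labels(
    names=["system", "atom"],
    values=torch.tensor([[0, 0], [0, 1], [1, 0]], dtype=torch.int32),
)
block = TensorBlock(
    values=torch.randn(3, 1, dtype=torch.float64),
    samples=samples,
    components=[],
    properties=Labels.range("energy", 1),
)
tensor_map = TensorMap(keys=Labels.single(), blocks=[block])

tensor_maps = split_structurewise(tensor_map)
assert len(tensor_maps) == 2  # one TensorMap per system
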
60 changes: 60 additions & 0 deletions src/metatrain/utils/data/readers/split.py
@@ -0,0 +1,60 @@
from typing import List

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap

# Strategy: torch.split(tensor, [len_1, len_2, ...]) splits a tensor into chunks of
# the given lengths, and torch.unique_consecutive(x, return_counts=True) returns
# exactly those lengths as counts when applied to the sorted "system" sample column.
#
# We assume the TensorMap is "dense" (every atom of every system has a value) and
# that its samples are already sorted by system.


def split_structurewise(tensormap: TensorMap) -> List[TensorMap]:
    """
    Split a TensorMap into one TensorMap per structure (system).

    Assumes a dense TensorMap whose samples are sorted by system and whose
    blocks all share the per-system sample counts of the first block.
    """
    sample_values = tensormap.block(0).samples.values
    _, counts = torch.unique_consecutive(sample_values[:, 0], return_counts=True)

    # collect the per-block metadata (sample names, components, properties)
    sample_names_block = {
        str(key): block.samples.names for key, block in tensormap.items()
    }
    components_block_wise = {
        str(key): block.components for key, block in tensormap.items()
    }
    properties_block_wise = {
        str(key): block.properties for key, block in tensormap.items()
    }

    # split the sample labels and the values of every block system-wise
    split_samples = torch.split(sample_values, counts.tolist())
    split_values = {
        str(key): torch.split(block.values, counts.tolist())
        for key, block in tensormap.items()
    }

    tensor_maps = []
    for i, sample in enumerate(split_samples):
        blocks = []
        for key in split_values:
            blocks.append(
                TensorBlock(
                    values=split_values[key][i],
                    samples=Labels(sample_names_block[key], sample),
                    components=components_block_wise[key],
                    properties=properties_block_wise[key],
                )
            )
        tensor_maps.append(TensorMap(keys=tensormap.keys, blocks=blocks))

    # return one TensorMap per structure
    return tensor_maps
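
For reference, the plain-torch mechanics the helper relies on, shown on a toy tensor (nothing here comes from the PR itself):

import torch

system_column = torch.tensor([0, 0, 0, 1, 1, 2])  # sorted "system" sample column
values = torch.arange(12.0).reshape(6, 2)  # one row per atom, two properties

# unique_consecutive gives the number of samples per system ...
_, counts = torch.unique_consecutive(system_column, return_counts=True)
# ... and torch.split chunks the values row-wise with exactly those lengths
chunks = torch.split(values, counts.tolist())

assert [c.shape[0] for c in chunks] == [3, 2, 1]
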
