Skip to content

Make LLPR module checkpoint-friendly #554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 11 additions & 24 deletions examples/programmatic/llpr/llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,17 @@

import torch

from metatrain.utils.io import load_model


# %%
#
# Models can be loaded using the :func:`metatrain.utils.io.load_model` function from
# the. For already exported models The function requires the path to the exported model
# and, for many models, also the path to the respective extensions directory. Both are
# produced during the training process.


model = load_model("model.pt", extensions_directory="extensions/")

# %%
#
# In metatrain, a Dataset is composed of a list of systems and a dictionary of targets.
# The following lines illustrate how to read systems and targets from xyz files, and
# how to create a Dataset object from them.

from metatrain.utils.data import Dataset, read_systems, read_targets # noqa: E402
from metatrain.utils.neighbor_lists import ( # noqa: E402
get_requested_neighbor_lists,
Expand All @@ -74,7 +66,13 @@
}
targets, _ = read_targets(target_config)

requested_neighbor_lists = get_requested_neighbor_lists(model)

from metatrain.utils.llpr import LLPRUncertaintyModel # noqa: E402


llpr_model = LLPRUncertaintyModel("model.ckpt")

requested_neighbor_lists = get_requested_neighbor_lists(llpr_model)
qm9_systems = [
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in qm9_systems
Expand Down Expand Up @@ -111,27 +109,16 @@
# to compute prediction rigidity metrics, which are useful for uncertainty
# quantification and model introspection.

from metatensor.torch.atomistic import ( # noqa: E402
MetatensorAtomisticModel,
ModelMetadata,
)

from metatrain.utils.llpr import LLPRUncertaintyModel # noqa: E402


llpr_model = LLPRUncertaintyModel(model)
llpr_model.compute_covariance(dataloader)
llpr_model.compute_inverse_covariance(regularizer=1e-4)

# calibrate on the same dataset for simplicity. In reality, a separate
# calibration/validation dataset should be used.
print("Calibrate")
llpr_model.calibrate(dataloader)

exported_model = MetatensorAtomisticModel(
llpr_model.eval(),
ModelMetadata(),
llpr_model.capabilities,
)
exported_model = llpr_model.export()

# %%
#
Expand All @@ -148,7 +135,7 @@
outputs={
# request the uncertainty in the atomic energy predictions
"energy": ModelOutput(per_atom=True), # needed to request the uncertainties
"mtt::aux::energy_uncertainty": ModelOutput(per_atom=True),
"energy_uncertainty": ModelOutput(per_atom=True),
# `per_atom=False` would return the total uncertainty for the system,
# or (the inverse of) the TPR (total prediction rigidity)
# you also can request other outputs from the model here, for example:
Expand All @@ -158,7 +145,7 @@
)

outputs = exported_model([ethanol_system], evaluation_options, check_consistency=False)
lpr = outputs["mtt::aux::energy_uncertainty"].block().values.detach().cpu().numpy()
lpr = outputs["energy_uncertainty"].block().values.detach().cpu().numpy()

# %%
#
Expand Down
26 changes: 14 additions & 12 deletions examples/programmatic/llpr_forces/force_llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
ModelEvaluationOptions,
ModelMetadata,
ModelOutput,
load_atomistic_model,
)

from metatrain.utils.data import Dataset, collate_fn, read_systems, read_targets
from metatrain.utils.llpr import LLPRUncertaintyModel
from metatrain.utils.loss import TensorMapDictLoss
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)


model = load_atomistic_model("model.pt", extensions_directory="extensions/")
model = model.to("cuda")
dtype = torch.float64 # matching the model that was trained
device = "cuda" if torch.cuda.is_available() else "cpu"

train_systems = read_systems("train.xyz")
train_target_config = {
Expand Down Expand Up @@ -105,7 +107,10 @@
}
test_targets, target_info = read_targets(test_target_config)

requested_neighbor_lists = model.requested_neighbor_lists()
llpr_model = LLPRUncertaintyModel("model.ckpt")
llpr_model.to(device)

requested_neighbor_lists = get_requested_neighbor_lists(llpr_model)
train_systems = [
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in train_systems
Expand Down Expand Up @@ -148,8 +153,6 @@
}
loss_fn = TensorMapDictLoss(loss_weight_dict)

llpr_model = LLPRUncertaintyModel(model)

print("Last layer parameters:")
parameters = []
for name, param in llpr_model.named_parameters():
Expand All @@ -173,7 +176,7 @@
length_unit="angstrom",
outputs={
"mtt::aux::energy_last_layer_features": ModelOutput(per_atom=False),
"mtt::aux::energy_uncertainty": ModelOutput(per_atom=False),
"energy_uncertainty": ModelOutput(per_atom=False),
"energy": ModelOutput(per_atom=False),
},
selected_atoms=None,
Expand All @@ -183,12 +186,11 @@
force_uncertainties = []

for batch in test_dataloader:
dtype = getattr(torch, model.capabilities().dtype)
systems, targets = batch
systems = [system.to("cuda", dtype) for system in systems]
systems = [system.to(device, dtype) for system in systems]
for system in systems:
system.positions.requires_grad = True
targets = {name: tmap.to("cuda", dtype) for name, tmap in targets.items()}
targets = {name: tmap.to(device, dtype) for name, tmap in targets.items()}

outputs = exported_model(systems, evaluation_options, check_consistency=True)
energy = outputs["energy"].block().values
Expand Down Expand Up @@ -221,7 +223,7 @@
force_uncertainty = torch.einsum(
"if, fg, ig -> i",
ll_feature_grads,
exported_model.module.inv_covariances["mtt::aux::energy_uncertainty"],
exported_model.module._get_inv_covariance("energy_uncertainty"),
ll_feature_grads,
)
force_uncertainties.append(force_uncertainty.detach().clone().cpu().numpy())
Expand Down
Loading
Loading