diff --git a/docs/src/advanced-concepts/auxiliary-outputs.rst b/docs/src/advanced-concepts/auxiliary-outputs.rst index 4a9f60d25..98dee8d1f 100644 --- a/docs/src/advanced-concepts/auxiliary-outputs.rst +++ b/docs/src/advanced-concepts/auxiliary-outputs.rst @@ -86,4 +86,4 @@ features See the `feature output `_ -in ``metatensor.torch.atomistic``. +in ``metatomic.torch``. diff --git a/docs/src/advanced-concepts/output-naming.rst b/docs/src/advanced-concepts/output-naming.rst index 7162bc346..480a9e10c 100644 --- a/docs/src/advanced-concepts/output-naming.rst +++ b/docs/src/advanced-concepts/output-naming.rst @@ -2,7 +2,7 @@ Output naming ============= The name and format of the outputs in ``metatrain`` are based on -those of the ``_ package. An immediate example is given by the ``energy`` output. diff --git a/docs/src/dev-docs/new-architecture.rst b/docs/src/dev-docs/new-architecture.rst index 5d81b89b4..054ae603d 100644 --- a/docs/src/dev-docs/new-architecture.rst +++ b/docs/src/dev-docs/new-architecture.rst @@ -104,7 +104,7 @@ method. .. code-block:: python - from metatensor.torch.atomistic import MetatensorAtomisticModel, ModelMetadata + from metatomic.torch import AtomisticModel, ModelMetadata class ModelInterface: @@ -146,7 +146,7 @@ method. def export( self, metadata: Optional[ModelMetadata] = None - ) -> MetatensorAtomisticModel: + ) -> AtomisticModel: pass Note that the ``ModelInterface`` does not necessarily inherit from @@ -165,8 +165,8 @@ the ``architecture`` key should contain references about the general architectur The ``export()`` method is required to transform a trained model into a standalone file to be used in combination with molecular dynamic engines to run simulations. We provide a helper function :py:func:`metatrain.utils.export.export` to export a torch -model to an :py:class:`MetatensorAtomisticModel -`. +model to an :py:class:`AtomisticModel +`. Trainer class (``trainer.py``) ------------------------------ diff --git a/docs/src/dev-docs/utils/data/systems_to_ase.rst b/docs/src/dev-docs/utils/data/systems_to_ase.rst index 07b9360d6..913a77163 100644 --- a/docs/src/dev-docs/utils/data/systems_to_ase.rst +++ b/docs/src/dev-docs/utils/data/systems_to_ase.rst @@ -2,7 +2,7 @@ Converting Systems to ASE ######################### Some machine learning models might train on ``ase.Atoms`` objects. -This module provides a function to convert a ``metatensor.torch.atomistic.System`` +This module provides a function to convert a ``metatomic.torch.System`` object to an ``ase.Atoms`` object. .. automodule:: metatrain.utils.data.system_to_ase diff --git a/docs/src/dev-docs/utils/neighbor_lists.rst b/docs/src/dev-docs/utils/neighbor_lists.rst index b4833761c..0fba78244 100644 --- a/docs/src/dev-docs/utils/neighbor_lists.rst +++ b/docs/src/dev-docs/utils/neighbor_lists.rst @@ -1,7 +1,7 @@ Neighbor lists ============== -Utilities to attach neighbor lists to a ``metatensor.torch.atomistic.System`` object. +Utilities to attach neighbor lists to a ``metatomic.torch.System`` object. .. automodule:: metatrain.utils.neighbor_lists :members: diff --git a/docs/src/getting-started/checkpoints.rst b/docs/src/getting-started/checkpoints.rst index 4353f32d5..44039dc7c 100644 --- a/docs/src/getting-started/checkpoints.rst +++ b/docs/src/getting-started/checkpoints.rst @@ -67,7 +67,7 @@ The ``metadata.yaml`` file should have the following structure: You can also add additional keywords like additional references to the metadata file. 
The fields are the same for :class:`ModelMetadata -` class from metatensor. +` class from metatensor. Exporting remote models ----------------------- diff --git a/docs/src/getting-started/custom_dataset_conf.rst b/docs/src/getting-started/custom_dataset_conf.rst index d887c1cff..93465a229 100644 --- a/docs/src/getting-started/custom_dataset_conf.rst +++ b/docs/src/getting-started/custom_dataset_conf.rst @@ -91,7 +91,7 @@ Allows defining multiple target sections, each with a unique name. and ``stress`` are enabled by default. - Other target sections can also be defined, as long as they are prefixed by ``mtt::``. For example, ``mtt::free_energy``. In general, all targets that are not standard - outputs of ``metatensor.torch.atomistic`` (see + outputs of ``metatomic.torch`` (see https://docs.metatensor.org/latest/atomistic/outputs.html) should be prefixed by ``mtt::``. diff --git a/examples/ase/run_ase.py b/examples/ase/run_ase.py index 864f4ad80..7f34336e3 100644 --- a/examples/ase/run_ase.py +++ b/examples/ase/run_ase.py @@ -36,7 +36,7 @@ import matplotlib.pyplot as plt import numpy as np from ase.geometry.analysis import Analysis -from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator +from metatomic.torch.ase_calculator import MetatensorCalculator # %% diff --git a/examples/programmatic/disk_dataset/disk_dataset.py b/examples/programmatic/disk_dataset/disk_dataset.py index 99b7632a0..b182eeede 100644 --- a/examples/programmatic/disk_dataset/disk_dataset.py +++ b/examples/programmatic/disk_dataset/disk_dataset.py @@ -14,7 +14,7 @@ import ase.io import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import NeighborListOptions, systems_to_torch +from metatomic.torch import NeighborListOptions, systems_to_torch from metatrain.utils.data import DiskDatasetWriter from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists diff --git a/examples/programmatic/llpr/llpr.py b/examples/programmatic/llpr/llpr.py index bc1ec41b8..a3a58a7fb 100644 --- a/examples/programmatic/llpr/llpr.py +++ b/examples/programmatic/llpr/llpr.py @@ -111,8 +111,8 @@ # to compute prediction rigidity metrics, which are useful for uncertainty # quantification and model introspection. -from metatensor.torch.atomistic import ( # noqa: E402 - MetatensorAtomisticModel, +from metatomic.torch import ( # noqa: E402 + AtomisticModel, ModelMetadata, ) @@ -127,7 +127,7 @@ # calibration/validation dataset should be used. llpr_model.calibrate(dataloader) -exported_model = MetatensorAtomisticModel( +exported_model = AtomisticModel( llpr_model.eval(), ModelMetadata(), llpr_model.capabilities, @@ -140,7 +140,7 @@ # specific outputs from the model. In this case, we request the uncertainty in the # atomic energy predictions. 
-from metatensor.torch.atomistic import ModelEvaluationOptions, ModelOutput # noqa: E402 +from metatomic.torch import ModelEvaluationOptions, ModelOutput # noqa: E402 evaluation_options = ModelEvaluationOptions( diff --git a/examples/programmatic/llpr_forces/force_llpr.py b/examples/programmatic/llpr_forces/force_llpr.py index 793e10fc4..e23e54099 100644 --- a/examples/programmatic/llpr_forces/force_llpr.py +++ b/examples/programmatic/llpr_forces/force_llpr.py @@ -1,8 +1,8 @@ import matplotlib.pyplot as plt import numpy as np import torch -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelEvaluationOptions, ModelMetadata, ModelOutput, @@ -163,7 +163,7 @@ llpr_model.compute_inverse_covariance() llpr_model.calibrate(valid_dataloader) -exported_model = MetatensorAtomisticModel( +exported_model = AtomisticModel( llpr_model.eval(), ModelMetadata(), llpr_model.capabilities, diff --git a/examples/programmatic/use_architectures_outside/use_outside.py b/examples/programmatic/use_architectures_outside/use_outside.py index 1fcc86d0a..b31d38607 100644 --- a/examples/programmatic/use_architectures_outside/use_outside.py +++ b/examples/programmatic/use_architectures_outside/use_outside.py @@ -14,7 +14,7 @@ # import torch -from metatensor.torch.atomistic import ModelOutput +from metatomic.torch import ModelOutput from metatrain.experimental.nanopet import NanoPET from metatrain.utils.architectures import get_default_hypers diff --git a/examples/zbl/dimers.py b/examples/zbl/dimers.py index 04069a5a5..992ef5212 100644 --- a/examples/zbl/dimers.py +++ b/examples/zbl/dimers.py @@ -36,7 +36,7 @@ import matplotlib.pyplot as plt import numpy as np import torch -from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator +from metatomic.torch.ase_calculator import MetatensorCalculator # %% diff --git a/pyproject.toml b/pyproject.toml index 3610e52a9..642967235 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,11 +11,13 @@ authors = [{name = "metatrain developers"}] # Strict version pinning to avoid regression test failing on new versions dependencies = [ "ase", - "metatensor-learn==0.3.2", - "metatensor-operations==0.3.3", - "metatensor-torch==0.7.6", + "huggingface_hub", + "metatensor-learn >=0.3.2,<0.4", + "metatensor-operations >=0.3.3,<0.4", + "metatensor-torch >=0.7.6,<0.8", + "metatomic-torch >=0.1.2,<0.2", "jsonschema", - "omegaconf", + "omegaconf >= 2.3.0", "python-hostlist", "vesin", ] diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index b6b0df889..072c0b4a4 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -9,7 +9,7 @@ import numpy as np import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import MetatensorAtomisticModel +from metatomic.torch import AtomisticModel from omegaconf import DictConfig, OmegaConf from ..utils.data import ( @@ -167,7 +167,7 @@ def _concatenate_tensormaps( def _eval_targets( - model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule], + model: Union[AtomisticModel, torch.jit._script.RecursiveScriptModule], dataset: Union[Dataset, torch.utils.data.Subset], options: Dict[str, TargetInfo], return_predictions: bool, @@ -220,10 +220,10 @@ def _eval_targets( collate_fn=collate_fn, shuffle=False, ) - - # Initialize RMSE accumulator: - rmse_accumulator = RMSEAccumulator() - mae_accumulator = MAEAccumulator() + # Not initializing the accumulator + # Initialize RMSE 
accumulator: + # rmse_accumulator = RMSEAccumulator() + # mae_accumulator = MAEAccumulator() # If we're returning the predictions, we need to store them: if return_predictions: @@ -265,7 +265,7 @@ def _eval_targets( model, systems, options, - is_training=False, + is_training=False, check_consistency=check_consistency, ) @@ -279,8 +279,9 @@ def _eval_targets( batch_targets_per_atom = average_by_num_atoms( batch_targets, systems, per_structure_keys=[] ) - rmse_accumulator.update(batch_predictions_per_atom, batch_targets_per_atom) - mae_accumulator.update(batch_predictions_per_atom, batch_targets_per_atom) + # CHANGE: Do not calculate the loss because it currently does not support arbitrary loss functions + # rmse_accumulator.update(batch_predictions_per_atom, batch_targets_per_atom) + # mae_accumulator.update(batch_predictions_per_atom, batch_targets_per_atom) if return_predictions: all_predictions.append(batch_predictions) @@ -288,18 +289,19 @@ def _eval_targets( total_time += time_taken timings_per_atom.append(time_taken / sum(len(system) for system in systems)) + # CHANGE: Do not calculate the loss because it currently does not support arbitrary loss functions # Finalize the metrics - rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"]) - mae_values = mae_accumulator.finalize(not_per_atom=["positions_gradients"]) - metrics = {**rmse_values, **mae_values} - - # print the RMSEs with MetricLogger - metric_logger = MetricLogger( - log_obj=logger, - dataset_info=model.capabilities(), - initial_metrics=metrics, - ) - metric_logger.log(metrics) + # rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"]) + # mae_values = mae_accumulator.finalize(not_per_atom=["positions_gradients"]) + # metrics = {**rmse_values, **mae_values} + + # # print the RMSEs with MetricLogger + # metric_logger = MetricLogger( + # log_obj=logger, + # dataset_info=model.capabilities(), + # initial_metrics=metrics, + # ) + # metric_logger.log(metrics) # Log timings timings_per_atom = np.array(timings_per_atom) @@ -320,7 +322,7 @@ def _eval_targets( def eval_model( - model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule], + model: Union[AtomisticModel, torch.jit._script.RecursiveScriptModule], options: DictConfig, output: Union[Path, str] = "output.xyz", batch_size: int = 1, @@ -362,9 +364,9 @@ def eval_model( # and we calculate RMSEs eval_targets, eval_info_dict = read_targets(options["targets"]) else: - # in this case, we have no targets: we evaluate everything - # (but we don't/can't calculate RMSEs) - # TODO: allow the user to specify which outputs to evaluate + # in this case, we have no targets: we evaluate everything + # (but we don't/can't calculate RMSEs) + # TODO: allow the user to specify which outputs to evaluate eval_targets = {} eval_info_dict = {} do_strain_grad = all( diff --git a/src/metatrain/cli/export.py b/src/metatrain/cli/export.py index ca0446db0..4f10be740 100644 --- a/src/metatrain/cli/export.py +++ b/src/metatrain/cli/export.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Union import torch -from metatensor.torch.atomistic import ModelMetadata, is_atomistic_model +from metatomic.torch import ModelMetadata, is_atomistic_model from omegaconf import OmegaConf from ..utils.io import check_file_extension, load_model diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 8d9f9110f..1c0e59e63 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -43,7 +43,6 @@ from .export import 
_has_extensions from .formatter import CustomHelpFormatter - def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None: """Add `train_model` paramaters to an argparse (sub)-parser.""" @@ -509,9 +508,16 @@ def train_model( mts_atomistic_model.buffers(), ) ).device + # CHANGE: metatensor does not yet support saving noncontiguous tensors (TEST) + # try: + # mts_atomistic_model.module.additive_models[0].weights['mtt::dos'] = mt.make_contiguous(mts_atomistic_model.module.additive_models[0].weights['mtt::dos']) + # except: + # print ("Failed to make DOS additive model contiguous, the target probably does not exist") + + mts_atomistic_model.save(str(output_checked), collect_extensions=extensions_path) # the model is first saved and then reloaded 1) for good practice and 2) because - # MetatensorAtomisticModel only torchscripts (makes faster) during save() + # AtomisticModel only torchscripts (makes faster) during save() # Copy the exported model and the checkpoint also to the checkpoint directory checkpoint_path = Path(checkpoint_dir) diff --git a/src/metatrain/deprecated/pet/model.py b/src/metatrain/deprecated/pet/model.py index fc0f0ce38..454b0b037 100644 --- a/src/metatrain/deprecated/pet/model.py +++ b/src/metatrain/deprecated/pet/model.py @@ -3,8 +3,8 @@ import metatensor.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelCapabilities, ModelMetadata, ModelOutput, @@ -274,7 +274,7 @@ def load_checkpoint( def export( self, metadata: Optional[ModelMetadata] = None - ) -> MetatensorAtomisticModel: + ) -> AtomisticModel: dtype = next(self.parameters()).dtype if dtype not in self.__supported_dtypes__: raise ValueError(f"Unsupported dtype {self.dtype} for PET") @@ -313,4 +313,4 @@ def export( append_metadata_references(metadata, self.__default_metadata__) - return MetatensorAtomisticModel(self.eval(), metadata, capabilities) + return AtomisticModel(self.eval(), metadata, capabilities) diff --git a/src/metatrain/deprecated/pet/tests/test_exported.py b/src/metatrain/deprecated/pet/tests/test_exported.py index daabe8f78..c62258eea 100644 --- a/src/metatrain/deprecated/pet/tests/test_exported.py +++ b/src/metatrain/deprecated/pet/tests/test_exported.py @@ -1,6 +1,6 @@ import pytest import torch -from metatensor.torch.atomistic import ( +from metatomic.torch import ( ModelCapabilities, ModelEvaluationOptions, ModelMetadata, diff --git a/src/metatrain/deprecated/pet/tests/test_functionality.py b/src/metatrain/deprecated/pet/tests/test_functionality.py index ca62b4173..418be3e2f 100644 --- a/src/metatrain/deprecated/pet/tests/test_functionality.py +++ b/src/metatrain/deprecated/pet/tests/test_functionality.py @@ -5,8 +5,8 @@ import torch from jsonschema.exceptions import ValidationError from metatensor.torch import Labels -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelCapabilities, ModelEvaluationOptions, ModelMetadata, @@ -105,7 +105,7 @@ def test_prediction(): supported_devices=["cpu", "cuda"], ) - model = MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities) + model = AtomisticModel(model.eval(), ModelMetadata(), capabilities) model( [system], evaluation_options, @@ -157,7 +157,7 @@ def test_per_atom_predictions_functionality(): supported_devices=["cpu", "cuda"], ) - model = MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities) + 
model = AtomisticModel(model.eval(), ModelMetadata(), capabilities) model( [system], evaluation_options, @@ -219,7 +219,7 @@ def test_selected_atoms_functionality(): selected_atoms=selected_atoms, ) - model = MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities) + model = AtomisticModel(model.eval(), ModelMetadata(), capabilities) model( [system], evaluation_options, diff --git a/src/metatrain/deprecated/pet/tests/test_pet_compatibility.py b/src/metatrain/deprecated/pet/tests/test_pet_compatibility.py index 64f173902..f8eda85ba 100644 --- a/src/metatrain/deprecated/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/deprecated/pet/tests/test_pet_compatibility.py @@ -1,7 +1,7 @@ import pytest import torch -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelCapabilities, ModelEvaluationOptions, ModelMetadata, @@ -130,7 +130,7 @@ def test_predictions_compatibility(cutoff): outputs=capabilities.outputs, ) - model = MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities) + model = AtomisticModel(model.eval(), ModelMetadata(), capabilities) mtm_pet_prediction = ( model( [system], diff --git a/src/metatrain/deprecated/pet/utils/systems_to_batch_dict.py b/src/metatrain/deprecated/pet/utils/systems_to_batch_dict.py index 9c6f390a9..2ed6fb249 100644 --- a/src/metatrain/deprecated/pet/utils/systems_to_batch_dict.py +++ b/src/metatrain/deprecated/pet/utils/systems_to_batch_dict.py @@ -3,7 +3,7 @@ import pet_neighbors_convert # noqa: F401 import torch from metatensor.torch import Labels -from metatensor.torch.atomistic import NeighborListOptions, System +from metatomic.torch import NeighborListOptions, System def collate_graph_dicts( @@ -340,7 +340,7 @@ def systems_to_batch_dict( Converts a standard input data format of `metatrain` to a PyTorch Geometric `Batch` object, compatible with `PET` model. - :param systems: The list of systems in `metatensor.torch.atomistic.System` + :param systems: The list of systems in `metatomic.torch.System` format, that needs to be converted. :param options: A `NeighborListOptions` objects specifying the parameters for a neighbor list, which will be used during the convertation. 
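Taken together, the hunks above and below are a mechanical rename: everything that used to be imported from ``metatensor.torch.atomistic`` now comes from ``metatomic.torch`` (pinned as ``metatomic-torch >=0.1.2,<0.2`` in ``pyproject.toml``), and ``MetatensorAtomisticModel`` becomes ``AtomisticModel``. A minimal sketch of the mapping for downstream code, using only names touched by this patch:

# before this patch
# from metatensor.torch.atomistic import (
#     MetatensorAtomisticModel, ModelCapabilities, ModelMetadata, ModelOutput,
#     ModelEvaluationOptions, NeighborListOptions, System, systems_to_torch,
# )
# from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator

# after this patch
from metatomic.torch import (
    AtomisticModel,  # renamed from MetatensorAtomisticModel
    ModelCapabilities,
    ModelEvaluationOptions,
    ModelMetadata,
    ModelOutput,
    NeighborListOptions,
    System,
    systems_to_torch,
)
from metatomic.torch.ase_calculator import MetatensorCalculator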
diff --git a/src/metatrain/experimental/nanopet/model.py b/src/metatrain/experimental/nanopet/model.py index a8ca2eb34..8f771caa0 100644 --- a/src/metatrain/experimental/nanopet/model.py +++ b/src/metatrain/experimental/nanopet/model.py @@ -5,8 +5,8 @@ import metatensor.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelCapabilities, ModelMetadata, ModelOutput, @@ -234,7 +234,7 @@ def forward( selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: # Checks on systems (species) and outputs are done in the - # MetatensorAtomisticModel wrapper + # AtomisticModel wrapper device = systems[0].device @@ -578,7 +578,7 @@ def load_checkpoint( def export( self, metadata: Optional[ModelMetadata] = None - ) -> MetatensorAtomisticModel: + ) -> AtomisticModel: dtype = next(self.parameters()).dtype if dtype not in self.__supported_dtypes__: raise ValueError(f"unsupported dtype {dtype} for NanoPET") @@ -616,7 +616,7 @@ def export( append_metadata_references(metadata, self.__default_metadata__) - return MetatensorAtomisticModel(self.eval(), metadata, capabilities) + return AtomisticModel(self.eval(), metadata, capabilities) def _add_output(self, target_name: str, target_info: TargetInfo) -> None: # warn that, for Cartesian tensors, we assume that they are symmetric diff --git a/src/metatrain/experimental/nanopet/modules/structures.py b/src/metatrain/experimental/nanopet/modules/structures.py index 738cf95e2..942b552aa 100644 --- a/src/metatrain/experimental/nanopet/modules/structures.py +++ b/src/metatrain/experimental/nanopet/modules/structures.py @@ -1,7 +1,7 @@ from typing import List import torch -from metatensor.torch.atomistic import NeighborListOptions, System +from metatomic.torch import NeighborListOptions, System def concatenate_structures( diff --git a/src/metatrain/experimental/nanopet/tests/test_exported.py b/src/metatrain/experimental/nanopet/tests/test_exported.py index 6d6eea9d4..abac629ba 100644 --- a/src/metatrain/experimental/nanopet/tests/test_exported.py +++ b/src/metatrain/experimental/nanopet/tests/test_exported.py @@ -1,6 +1,6 @@ import pytest import torch -from metatensor.torch.atomistic import ModelEvaluationOptions, ModelMetadata, System +from metatomic.torch import ModelEvaluationOptions, ModelMetadata, System from metatrain.experimental.nanopet import NanoPET from metatrain.utils.data import DatasetInfo diff --git a/src/metatrain/experimental/nanopet/tests/test_functionality.py b/src/metatrain/experimental/nanopet/tests/test_functionality.py index 72f8ec2ee..bc27926c5 100644 --- a/src/metatrain/experimental/nanopet/tests/test_functionality.py +++ b/src/metatrain/experimental/nanopet/tests/test_functionality.py @@ -2,7 +2,7 @@ import pytest import torch from jsonschema.exceptions import ValidationError -from metatensor.torch.atomistic import ModelOutput, System +from metatomic.torch import ModelOutput, System from omegaconf import OmegaConf from metatrain.experimental.nanopet.model import NanoPET diff --git a/src/metatrain/experimental/nanopet/tests/test_regression.py b/src/metatrain/experimental/nanopet/tests/test_regression.py index 93ffa4205..a9c665874 100644 --- a/src/metatrain/experimental/nanopet/tests/test_regression.py +++ b/src/metatrain/experimental/nanopet/tests/test_regression.py @@ -2,7 +2,7 @@ import numpy as np import torch -from metatensor.torch.atomistic import ModelOutput +from metatomic.torch 
import ModelOutput from omegaconf import OmegaConf from metatrain.experimental.nanopet import NanoPET, Trainer diff --git a/src/metatrain/experimental/nanopet/tests/test_torchscript.py b/src/metatrain/experimental/nanopet/tests/test_torchscript.py index 04389db7d..0b1473a77 100644 --- a/src/metatrain/experimental/nanopet/tests/test_torchscript.py +++ b/src/metatrain/experimental/nanopet/tests/test_torchscript.py @@ -1,7 +1,7 @@ import copy import torch -from metatensor.torch.atomistic import System +from metatomic.torch import System from metatrain.experimental.nanopet import NanoPET from metatrain.utils.data import DatasetInfo diff --git a/src/metatrain/gap/model.py b/src/metatrain/gap/model.py index 85ce7b00e..e7d659707 100644 --- a/src/metatrain/gap/model.py +++ b/src/metatrain/gap/model.py @@ -10,8 +10,8 @@ from metatensor.torch import Labels as TorchLabels from metatensor.torch import TensorBlock as TorchTensorBlock from metatensor.torch import TensorMap as TorchTensorMap -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelCapabilities, ModelMetadata, ModelOutput, @@ -263,7 +263,7 @@ def forward( def export( self, metadata: Optional[ModelMetadata] = None - ) -> MetatensorAtomisticModel: + ) -> AtomisticModel: interaction_ranges = [self.hypers["soap"]["cutoff"]["radius"]] for additive_model in self.additive_models: if hasattr(additive_model, "cutoff_radius"): @@ -296,7 +296,7 @@ def export( append_metadata_references(metadata, self.__default_metadata__) - return MetatensorAtomisticModel(self.eval(), metadata, capabilities) + return AtomisticModel(self.eval(), metadata, capabilities) ######################################################################################## diff --git a/src/metatrain/gap/tests/test_exported.py b/src/metatrain/gap/tests/test_exported.py index 813427db4..267f9a4c4 100644 --- a/src/metatrain/gap/tests/test_exported.py +++ b/src/metatrain/gap/tests/test_exported.py @@ -1,5 +1,5 @@ import torch -from metatensor.torch.atomistic import ModelMetadata +from metatomic.torch import ModelMetadata from omegaconf import OmegaConf from metatrain.gap import GAP, Trainer diff --git a/src/metatrain/gap/tests/test_regression.py b/src/metatrain/gap/tests/test_regression.py index 13ddb8126..cf7aa70f9 100644 --- a/src/metatrain/gap/tests/test_regression.py +++ b/src/metatrain/gap/tests/test_regression.py @@ -97,11 +97,11 @@ def test_regression_train_and_invariance(): system.rotate(48, "y") original_output = gap( - [metatensor.torch.atomistic.systems_to_torch(original_system)], + [metatomic.torch.systems_to_torch(original_system)], {"mtt::U0": gap.outputs["mtt::U0"]}, ) rotated_output = gap( - [metatensor.torch.atomistic.systems_to_torch(system)], + [metatomic.torch.systems_to_torch(system)], {"mtt::U0": gap.outputs["mtt::U0"]}, ) @@ -182,11 +182,11 @@ def test_ethanol_regression_train_and_invariance(): system.rotate(48, "y") original_output = gap( - [metatensor.torch.atomistic.systems_to_torch(original_system)], + [metatomic.torch.systems_to_torch(original_system)], {"energy": gap.outputs["energy"]}, ) rotated_output = gap( - [metatensor.torch.atomistic.systems_to_torch(system)], + [metatomic.torch.systems_to_torch(system)], {"energy": gap.outputs["energy"]}, ) diff --git a/src/metatrain/pet/DOSutils.py b/src/metatrain/pet/DOSutils.py new file mode 100644 index 000000000..50a7acedc --- /dev/null +++ b/src/metatrain/pet/DOSutils.py @@ -0,0 +1,29 @@ +import torch + + +def 
get_dynamic_shift_agnostic_mse(predictions, targets, cutoff_mask, return_shift = False): + # dx is hardcoded for now + if predictions.shape[1] < targets.shape[1]: + smaller = predictions + bigger = targets + else: + smaller = targets + bigger = predictions + + bigger_unfolded = bigger.unfold(1, smaller.shape[1], 1) + smaller_expanded = smaller[:, None, :] + delta = smaller_expanded - bigger_unfolded + # Weibin's addition - assumes prediction is bigger than target + dynamic_delta = delta * cutoff_mask.unsqueeze(dim=1) + device = predictions.device + losses = torch.trapezoid(dynamic_delta * dynamic_delta, dx = 0.05, dim=2) + front_tail = torch.cumulative_trapezoid(predictions**2, dx = 0.05, dim = 1) + shape_difference = predictions.shape[1] - targets.shape[1] + additional_error = torch.hstack([torch.zeros(len(predictions), device = device).reshape(-1,1), front_tail[:,:shape_difference]]) + total_losses = losses + additional_error + final_loss, shift = torch.min(total_losses, dim=1) + result = torch.mean(final_loss) + if return_shift: + return result, shift + else: + return result \ No newline at end of file diff --git a/src/metatrain/pet/default-hypers.yaml b/src/metatrain/pet/default-hypers.yaml index 47224dde1..43765b600 100644 --- a/src/metatrain/pet/default-hypers.yaml +++ b/src/metatrain/pet/default-hypers.yaml @@ -18,6 +18,7 @@ architecture: smearing: 1.4 kspace_resolution: 1.33 interpolation_nodes: 5 + excess_targets: {} training: distributed: false @@ -43,3 +44,6 @@ architecture: weights: {} reduction: mean sliding_factor: null + use_permanent: false + integral_penalty: 2.0 + gradient_penalty: 1e-4 diff --git a/src/metatrain/pet/model.py b/src/metatrain/pet/model.py index f0273139f..10341a089 100644 --- a/src/metatrain/pet/model.py +++ b/src/metatrain/pet/model.py @@ -5,8 +5,8 @@ import metatensor.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelCapabilities, ModelMetadata, ModelOutput, @@ -61,6 +61,33 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: self.embedding = torch.nn.Embedding( len(self.atomic_types) + 1, self.hypers["d_pet"] ) + + # CHANGE: Redefined the target size and remove the mask from the targets + try: + for i in self.hypers["excess_targets"]: + # additional_output = int(self.hypers["excess_targets"][i]) + # target_size = dataset_info.targets[i].layout[0].values.shape[1] + prediction_size = int(self.hypers["excess_targets"][i]) + output_block = metatensor.torch.TensorBlock( + values= torch.empty(0, prediction_size).double(), + samples=metatensor.torch.Labels.empty('system'), + components=[], + # properties=metatensor.torch.Labels.single(), + properties=metatensor.torch.Labels.range("Energy", prediction_size) + ) + output_map = metatensor.torch.TensorMap( + keys = metatensor.torch.Labels.single(), + blocks = [output_block] + ) + dataset_info.targets[i].layout = output_map + except: + print ("Did not manage to change the target size") + try: + del dataset_info.targets["mtt::mask"] + except: + print ("Did not manage to delete mtt::mask from dataset_info.targets") + + gnn_layers = [] for layer_index in range(self.hypers["num_gnn_layers"]): transformer_layer = CartesianTransformer( @@ -170,7 +197,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: def restart(self, dataset_info: DatasetInfo) -> "PET": # merge old and new dataset info - merged_info = 
self.dataset_info.union(dataset_info) + merged_info = self.dataset_info#.union(dataset_info) #CHANGE: This is because the target shape != model prediction shape new_atomic_types = [ at for at in merged_info.atomic_types if at not in self.atomic_types ] @@ -195,18 +222,18 @@ def restart(self, dataset_info: DatasetInfo) -> "PET": self.dataset_info = merged_info # restart the composition and scaler models - self.additive_models[0].restart( - dataset_info=DatasetInfo( - length_unit=dataset_info.length_unit, - atomic_types=self.atomic_types, - targets={ - target_name: target_info - for target_name, target_info in dataset_info.targets.items() - if CompositionModel.is_valid_target(target_name, target_info) - }, - ), - ) - self.scaler.restart(dataset_info) + # self.additive_models[0].restart( + # dataset_info=DatasetInfo( + # length_unit=dataset_info.length_unit, + # atomic_types=self.atomic_types, + # targets={ + # target_name: target_info + # for target_name, target_info in dataset_info.targets.items() + # if CompositionModel.is_valid_target(target_name, target_info) + # }, + # ), + # ) + # self.scaler.restart(dataset_info) return self @@ -635,24 +662,25 @@ def forward( else: return_dict[output_name] = sum_over_atoms(atomic_property) - if not self.training: - # at evaluation, we also introduce the scaler and additive contributions - return_dict = self.scaler(return_dict) - for additive_model in self.additive_models: - outputs_for_additive_model: Dict[str, ModelOutput] = {} - for name, output in outputs.items(): - if name in additive_model.outputs: - outputs_for_additive_model[name] = output - additive_contributions = additive_model( - systems, - outputs_for_additive_model, - selected_atoms, - ) - for name in additive_contributions: - return_dict[name] = metatensor.torch.add( - return_dict[name], - additive_contributions[name], - ) + # CHANGE: Make sure that the additive models are not called + # if not self.training: + # # at evaluation, we also introduce the scaler and additive contributions + # return_dict = self.scaler(return_dict) + # for additive_model in self.additive_models: + # outputs_for_additive_model: Dict[str, ModelOutput] = {} + # for name, output in outputs.items(): + # if name in additive_model.outputs: + # outputs_for_additive_model[name] = output + # additive_contributions = additive_model( + # systems, + # outputs_for_additive_model, + # selected_atoms, + # ) + # for name in additive_contributions: + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name], + # ) return return_dict @@ -688,7 +716,7 @@ def load_checkpoint( def export( self, metadata: Optional[ModelMetadata] = None - ) -> MetatensorAtomisticModel: + ) -> AtomisticModel: dtype = next(self.parameters()).dtype if dtype not in self.__supported_dtypes__: raise ValueError(f"unsupported dtype {dtype} for PET") @@ -724,7 +752,7 @@ def export( append_metadata_references(metadata, self.__default_metadata__) - return MetatensorAtomisticModel(self.eval(), metadata, capabilities) + return AtomisticModel(self.eval(), metadata, capabilities) def _add_output(self, target_name: str, target_info: TargetInfo) -> None: # warn that, for Cartesian tensors, we assume that they are symmetric diff --git a/src/metatrain/pet/modules/structures.py b/src/metatrain/pet/modules/structures.py index 27f7a599f..d03168eb8 100644 --- a/src/metatrain/pet/modules/structures.py +++ b/src/metatrain/pet/modules/structures.py @@ -2,7 +2,7 @@ import torch from metatensor.torch import Labels, TensorBlock -from 
metatensor.torch.atomistic import NeighborListOptions, System +from metatomic.torch import NeighborListOptions, System from .nef import ( compute_reversed_neighbor_list, diff --git a/src/metatrain/pet/schema-hypers.json b/src/metatrain/pet/schema-hypers.json index 331cc9831..8df84b717 100644 --- a/src/metatrain/pet/schema-hypers.json +++ b/src/metatrain/pet/schema-hypers.json @@ -57,6 +57,15 @@ "type": "integer" } } + }, + "excess_targets": { + "type": "object", + "patternProperties": { + ".*": { + "type": "number" + } + }, + "additionalProperties": false } }, "additionalProperties": false @@ -255,6 +264,29 @@ } }, "additionalProperties": false + }, + "use_permanent": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "object", + "properties": { + "n_samples": { + "type": "number" + } + }, + "required": ["n_samples"], + "additionalProperties": false + } + ] + }, + "integral_penalty": { + "type": "number" + }, + "gradient_penalty": { + "type": "number" } }, "additionalProperties": false diff --git a/src/metatrain/pet/tests/test_autograd.py b/src/metatrain/pet/tests/test_autograd.py index 6e8570fe9..b36e204c4 100644 --- a/src/metatrain/pet/tests/test_autograd.py +++ b/src/metatrain/pet/tests/test_autograd.py @@ -1,5 +1,5 @@ import torch -from metatensor.torch.atomistic import ModelOutput, System +from metatomic.torch import ModelOutput, System from metatrain.pet import PET from metatrain.utils.data import DatasetInfo diff --git a/src/metatrain/pet/tests/test_exported.py b/src/metatrain/pet/tests/test_exported.py index 814810cfd..9761de3a2 100644 --- a/src/metatrain/pet/tests/test_exported.py +++ b/src/metatrain/pet/tests/test_exported.py @@ -1,6 +1,6 @@ import pytest import torch -from metatensor.torch.atomistic import ModelEvaluationOptions, ModelMetadata, System +from metatomic.torch import ModelEvaluationOptions, ModelMetadata, System from metatrain.pet import PET from metatrain.utils.data import DatasetInfo diff --git a/src/metatrain/pet/tests/test_functionality.py b/src/metatrain/pet/tests/test_functionality.py index d40ff6fd7..5693b0c51 100644 --- a/src/metatrain/pet/tests/test_functionality.py +++ b/src/metatrain/pet/tests/test_functionality.py @@ -2,7 +2,7 @@ import pytest import torch from jsonschema.exceptions import ValidationError -from metatensor.torch.atomistic import ModelOutput, System +from metatomic.torch import ModelOutput, System from omegaconf import OmegaConf from metatrain.pet import PET diff --git a/src/metatrain/pet/tests/test_long_range.py b/src/metatrain/pet/tests/test_long_range.py index 712324624..de1a9f821 100644 --- a/src/metatrain/pet/tests/test_long_range.py +++ b/src/metatrain/pet/tests/test_long_range.py @@ -6,7 +6,7 @@ import copy import torch -from metatensor.torch.atomistic import ModelOutput, System +from metatomic.torch import ModelOutput, System from omegaconf import OmegaConf from metatrain.pet import PET, Trainer diff --git a/src/metatrain/pet/tests/test_pet_compatibility.py b/src/metatrain/pet/tests/test_pet_compatibility.py index 1f9c1d593..eeba1be89 100644 --- a/src/metatrain/pet/tests/test_pet_compatibility.py +++ b/src/metatrain/pet/tests/test_pet_compatibility.py @@ -8,7 +8,7 @@ import metatensor.torch import torch -from metatensor.torch.atomistic import ModelOutput +from metatomic.torch import ModelOutput from metatrain.deprecated.pet import PET from metatrain.deprecated.pet.modules.hypers import Hypers diff --git a/src/metatrain/pet/tests/test_regression.py b/src/metatrain/pet/tests/test_regression.py index 837388730..f16740db4 100644 
--- a/src/metatrain/pet/tests/test_regression.py +++ b/src/metatrain/pet/tests/test_regression.py @@ -2,7 +2,7 @@ import numpy as np import torch -from metatensor.torch.atomistic import ModelOutput +from metatomic.torch import ModelOutput from omegaconf import OmegaConf from metatrain.pet import PET, Trainer diff --git a/src/metatrain/pet/tests/test_torchscript.py b/src/metatrain/pet/tests/test_torchscript.py index 5f7cbe382..ab17fdb8d 100644 --- a/src/metatrain/pet/tests/test_torchscript.py +++ b/src/metatrain/pet/tests/test_torchscript.py @@ -1,7 +1,7 @@ import copy import torch -from metatensor.torch.atomistic import System +from metatomic.torch import System from metatrain.pet import PET from metatrain.utils.data import DatasetInfo diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py index f6e97a14b..79cbe0fad 100644 --- a/src/metatrain/pet/trainer.py +++ b/src/metatrain/pet/trainer.py @@ -39,6 +39,8 @@ from .model import PET from .modules.finetuning import apply_finetuning_strategy +# CHANGE: IMPORT THE loss function FROM DOSutils +from .DOSutils import get_dynamic_shift_agnostic_mse def get_scheduler(optimizer, train_hypers): def func_lr_scheduler(epoch): @@ -138,22 +140,49 @@ def train( for additive_model in model.additive_models: additive_model.to(dtype=torch.float64) - logging.info("Calculating composition weights") - model.additive_models[0].train_model( # this is the composition model - train_datasets, - model.additive_models[1:], - self.hypers["fixed_composition_weights"], - ) + # CHANGE: Do not use additive and scaler models + # logging.info("Calculating composition weights") + # model.additive_models[0].train_model( # this is the composition model + # train_datasets, + # model.additive_models[1:], + # self.hypers["fixed_composition_weights"], + # ) - if self.hypers["scale_targets"]: - logging.info("Calculating scaling weights") - model.scaler.train_model( - train_datasets, model.additive_models, treat_as_additive=True - ) + # if self.hypers["scale_targets"]: + # logging.info("Calculating scaling weights") + # model.scaler.train_model( + # train_datasets, model.additive_models, treat_as_additive=True + # ) if is_distributed: model = DistributedDataParallel(model, device_ids=[device]) + # CHANGE: Include a permanent set that is included in every training batch + if self.hypers['use_permanent']: + n_samples = self.hypers['use_permanent']['n_samples'] + n_train = len(train_datasets[0]) + train_systems = [] + permanent_systems = [] + train_targets = {} + permanent_targets = {} + keys = ["mtt::dos", "mtt::mask"] # WARNING: Keys are hardcoded for DOS so that the mask remains identifiable + for key in keys: + train_targets[key] = [] + permanent_targets[key] = [] + for i in range(0, n_train, 1): + data_i = train_datasets[0][i] + if i < (n_train - n_samples): + train_systems.append(data_i.system) + for key in keys: + + train_targets[key].append(data_i[key]) + else: + permanent_systems.append(data_i.system) + for key in keys: + permanent_targets[key].append(data_i[key]) + train_datasets = [Dataset.from_dict({"system": train_systems, **train_targets})] + permanent_datasets = [Dataset.from_dict({"system": permanent_systems, **permanent_targets})] + logging.info("Setting up data loaders") if is_distributed: @@ -177,9 +206,14 @@ def train( ) for val_dataset in val_datasets ] + # CHANGE: Include the permanent dataset but that should not be distributed + if self.hypers['use_permanent']: # CHANGE: Include the permanent dataset + permanent_samplers = [None] * 
len(permanent_datasets) else: train_samplers = [None] * len(train_datasets) val_samplers = [None] * len(val_datasets) + if self.hypers['use_permanent']: # CHANGE: Include the permanent dataset + permanent_samplers = [None] * len(permanent_datasets) # Create dataloader for the training datasets: train_dataloaders = [] @@ -215,7 +249,28 @@ def train( ) val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + # CHANGE: Include the permanent dataset + if self.hypers['use_permanent']: + permanent_dataloaders = [] + for dataset, sampler in zip(permanent_datasets, permanent_samplers): + permanent_dataloaders.append( + DataLoader( + dataset=dataset, + batch_size=self.hypers["batch_size"], + sampler=sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + ) + permanent_dataloader = CombinedDataLoader(permanent_dataloaders, shuffle=False) + logging.info("Setting up Permanent Dataset") + train_targets = (model.module if is_distributed else model).dataset_info.targets + try: + del train_targets["mtt::mask"] # CHANGE: Removing the mask from the targets, the mask is not a target for the model to predict + except: + pass outputs_list = [] for target_name, target_info in train_targets.items(): outputs_list.append(target_name) @@ -280,6 +335,9 @@ def train( rotational_augmenter = RotationalAugmenter(train_targets) start_epoch = 0 if self.epoch is None else self.epoch + 1 + # CHANGE: Define the coefficients for the finite diference scheme + interval = 0.05 + t4 = (torch.tensor([1/4, -4/3, 3., -4. , 25/12]).to(device)/interval).unsqueeze(dim = (0)).unsqueeze(dim = (0)).float() # Train the model: if self.best_metric is None: @@ -289,16 +347,19 @@ def train( for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): if is_distributed: - sampler.set_epoch(epoch) - train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) - val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) - if self.hypers["log_mae"]: - train_mae_calculator = MAEAccumulator( - self.hypers["log_separate_blocks"] - ) - val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + # CHANGE: Not using the default Accumulators because they do not support custom loss functions yet + # train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + # val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + # if self.hypers["log_mae"]: + # train_mae_calculator = MAEAccumulator( + # self.hypers["log_separate_blocks"] + # ) + # val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) train_loss = 0.0 + train_count = 0.0 # CHANGE: Added to count the number of training samples for batch in train_dataloader: optimizer.zero_grad() @@ -309,20 +370,24 @@ def train( systems, targets = systems_and_targets_to_device( systems, targets, device ) - for additive_model in ( - model.module if is_distributed else model - ).additive_models: - targets = remove_additive( - systems, targets, additive_model, train_targets - ) - targets = remove_scale( - targets, (model.module if is_distributed else model).scaler - ) + # CHANGE: Remove additive and scaler models + # for additive_model in ( + # model.module if is_distributed else model + # ).additive_models: + # targets = remove_additive( + # systems, targets, additive_model, train_targets + # ) + # targets = remove_scale( + # targets, (model.module if is_distributed else model).scaler + # ) systems, targets = 
systems_and_targets_to_dtype(systems, targets, dtype) + # CHANGE: Extract relevant quantities from the targets + target_dos_batch, mask_batch = targets['mtt::dos'], targets['mtt::mask'] predictions = evaluate_model( model, systems, - {key: train_targets[key] for key in targets.keys()}, + # {key: train_targets[key] for key in targets.keys()}, + {key: train_targets[key] for key in train_targets.keys()}, # CHANGE: Use the train_targets keys instead as mask is not a target is_training=True, ) @@ -330,58 +395,171 @@ def train( predictions = average_by_num_atoms( predictions, systems, per_structure_targets ) - targets = average_by_num_atoms(targets, systems, per_structure_targets) - train_loss_batch = loss_fn(predictions, targets) - train_loss_batch.backward() + # targets = average_by_num_atoms(targets, systems, per_structure_targets) # By default the targets are already averaged by the number of atoms + + # CHANGE: DOS Training loop + dos_predictions = predictions['mtt::dos'][0].values + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + extra_targets = int(dos_predictions.shape[1] - dos_target.shape[1]) # The DOS predictions are longer than the targets, we need to align them + # Calculate DOS loss using dynamic shift agnostic MSE + dos_loss, discrete_shift = get_dynamic_shift_agnostic_mse(dos_predictions, dos_target, dos_mask, return_shift = True) + full_gradient = torch.nn.functional.conv1d(dos_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + dim_loss = dos_predictions.shape[1] - full_gradient.shape[1] # Dimensions lost due to the gradient convolution + # Obtain aligned targets (The subset of the predictions that corresponds best to the targets) + aligned_predictions = [] + adjusted_dos_mask = [] + # external_gradients_loss = [] + for index, prediction in enumerate(dos_predictions): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + dos_mask_i = torch.hstack( #Adjust the mask to account for the discrete shift + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_predictions.append(aligned_prediction) + adjusted_dos_mask.append(dos_mask_i) + aligned_predictions = torch.vstack(aligned_predictions) + adjusted_dos_mask = torch.vstack(adjusted_dos_mask) + # mean_external_gradient_loss = torch.mean(torch.tensor(external_gradients_loss)) + # We also compute the loss on the cumulative integral of the DOS, it improves the reliability of the fermi level of the final predicted DOS + int_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, dx = 0.05, dim = 1) + int_aligned_targets = torch.cumulative_trapezoid(dos_target, dx = 0.05, dim = 1) + int_error = (int_aligned_predictions - int_aligned_targets)**2 + int_error = int_error * dos_mask[:,1:].unsqueeze(dim=1) # only penalize the integral where the DOS is defined + int_MSE = torch.mean(torch.trapezoid(int_error, dx = 0.05, dim = 1)) * self.hypers['integral_penalty'] + # Keep a count of the number of training samples to calculate the MSE accurately + train_count += len(dos_target) + # Calculate the gradient loss, useful for making sure that the 
behaviour of the DOS outside the window is reasonable + # gradient_losses = torch.nn.functional.conv1d(aligned_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + + # dim_loss = dos_mask.shape[1] - gradient_losses.shape[1] + gradient_loss = torch.mean(torch.trapezoid(((full_gradient * (~adjusted_dos_mask[:, dim_loss:]))**2), # non-zero gradients outside the window are penalized + dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + total_loss = (dos_loss + gradient_loss + int_MSE) + total_loss.backward() torch.nn.utils.clip_grad_norm_( model.parameters(), self.hypers["grad_clip_norm"] ) + # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + # train_loss_batch = loss_fn(predictions, targets) + # train_loss_batch.backward() + # torch.nn.utils.clip_grad_norm_( + # model.parameters(), self.hypers["grad_clip_norm"] + # ) optimizer.step() - + total_loss = (total_loss * len(dos_target)).detach() # CHANGE: We need to multiply the loss by the number of samples in the batch to get the correct loss value if is_distributed: # sum the loss over all processes - torch.distributed.all_reduce(train_loss_batch) - train_loss += train_loss_batch.item() - - train_rmse_calculator.update(predictions, targets) - if self.hypers["log_mae"]: - train_mae_calculator.update(predictions, targets) - finalized_train_info = train_rmse_calculator.finalize( - not_per_atom=["positions_gradients"] + per_structure_targets, - is_distributed=is_distributed, - device=device, - ) - if self.hypers["log_mae"]: - finalized_train_info.update( - train_mae_calculator.finalize( - not_per_atom=["positions_gradients"] + per_structure_targets, - is_distributed=is_distributed, - device=device, - ) - ) - + torch.distributed.all_reduce(total_loss) + train_loss += total_loss.item() + # CHANGE: Do not use the default calculators + # train_rmse_calculator.update(predictions, targets) + # if self.hypers["log_mae"]: + # train_mae_calculator.update(predictions, targets) + # finalized_train_info = train_rmse_calculator.finalize( + # not_per_atom=["positions_gradients"] + per_structure_targets, + # is_distributed=is_distributed, + # device=device, + # ) + # if self.hypers["log_mae"]: + # finalized_train_info.update( + # train_mae_calculator.finalize( + # not_per_atom=["positions_gradients"] + per_structure_targets, + # is_distributed=is_distributed, + # device=device, + # ) + # ) + # CHANGE: Use permanent dataset + if self.hypers['use_permanent']: + for batch in permanent_dataloader: + optimizer.zero_grad() + systems, targets = batch + systems, targets = rotational_augmenter.apply_random_augmentations( + systems, targets + ) + systems, targets = systems_and_targets_to_device( + systems, targets, device + ) + systems, targets = systems_and_targets_to_dtype(systems, targets, dtype) + target_dos_batch, mask_batch = targets['mtt::dos'], targets['mtt::mask'] + predictions = evaluate_model( + model, + systems, + # {key: train_targets[key] for key in targets.keys()}, + {key: train_targets[key] for key in train_targets.keys()}, # CHANGE: Use the train_targets keys instead as mask is not a target + is_training=True, + ) + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + dos_predictions = predictions['mtt::dos'][0].values + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + dos_loss, discrete_shift = get_dynamic_shift_agnostic_mse(dos_predictions, dos_target, dos_mask, return_shift = True) + full_gradients = torch.nn.functional.conv1d(dos_predictions.unsqueeze(dim = 1), 
t4).squeeze(dim = 1) + dim_loss = dos_predictions.shape[1] - full_gradients.shape[1] # Dimensions lost due to the gradient convolution + aligned_predictions = [] + # external_gradients_loss = [] + adjusted_dos_mask = [] + for index, prediction in enumerate(dos_predictions): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + + dos_mask_i = torch.hstack( + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_predictions.append(aligned_prediction) + adjusted_dos_mask.append(dos_mask_i) + aligned_predictions = torch.vstack(aligned_predictions) + adjusted_dos_mask = torch.vstack(adjusted_dos_mask).bool() + # mean_external_gradient_loss = torch.mean(torch.tensor(external_gradients_loss)) + # Cumulative integral loss + int_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, dx = 0.05, dim = 1) + int_aligned_targets = torch.cumulative_trapezoid(dos_target, dx = 0.05, dim = 1) + int_error = (int_aligned_predictions - int_aligned_targets)**2 + int_error = int_error * dos_mask[:,1:].unsqueeze(dim=1) # only penalize the integral where the DOS is defined + int_MSE = torch.mean(torch.trapezoid(int_error, dx = 0.05, dim = 1)) * self.hypers['integral_penalty'] + # Gradient loss + # gradient_losses = torch.nn.functional.conv1d(aligned_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # dim_loss = dos_mask.shape[1] - gradient_losses.shape[1] + gradient_loss = torch.mean(torch.trapezoid(((full_gradients * (~adjusted_dos_mask[:, dim_loss:]))**2), + dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + total_loss = (dos_loss + gradient_loss + int_MSE) + total_loss.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + # Should always be not distributed so we should not need to accumulate + train_loss /= train_count val_loss = 0.0 + val_count = 0.0 # CHANGE: Added to count the number of validation samples + # val_predictions = [] for batch in val_dataloader: + # CHANGE: Updated validation loop to use the new loss function systems, targets = batch - systems = [system.to(device=device) for system in systems] - targets = { - key: value.to(device=device) for key, value in targets.items() - } - for additive_model in ( - model.module if is_distributed else model - ).additive_models: - targets = remove_additive( - systems, targets, additive_model, train_targets - ) - targets = remove_scale( - targets, (model.module if is_distributed else model).scaler + systems, targets = systems_and_targets_to_device( + systems, targets, device ) - systems = [system.to(dtype=dtype) for system in systems] - targets = {key: value.to(dtype=dtype) for key, value in targets.items()} + systems, targets = systems_and_targets_to_dtype(systems, targets, dtype) + target_dos_batch, mask_batch = targets['mtt::dos'], targets['mtt::mask'] predictions = evaluate_model( model, systems, - {key: train_targets[key] for key in targets.keys()}, +# {key: train_targets[key] for key in targets.keys()}, + {key: train_targets[key] for key in train_targets.keys()}, is_training=False, ) @@ -389,43 +567,96 @@ def train( 
predictions = average_by_num_atoms( predictions, systems, per_structure_targets ) - targets = average_by_num_atoms(targets, systems, per_structure_targets) - - val_loss_batch = loss_fn(predictions, targets) + # targets = average_by_num_atoms(targets, systems, per_structure_targets) + + dos_predictions = predictions['mtt::dos'][0].values + # val_predictions.append(dos_predictions.detach()) + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + dos_loss, discrete_shift = get_dynamic_shift_agnostic_mse(dos_predictions, dos_target, dos_mask, return_shift = True) + full_gradients = torch.nn.functional.conv1d(dos_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + dim_loss = dos_predictions.shape[1] - full_gradients.shape[1] # Dimensions lost due to the gradient convolution + # Obtain aligned targets (The subset of the predictions that corresponds best to the targets) + aligned_predictions = [] + # external_gradients_loss = [] + adjusted_dos_mask = [] + for index, prediction in enumerate(dos_predictions): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + + dos_mask_i = torch.hstack( #Adjust the mask to account for the discrete shift + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_predictions.append(aligned_prediction) + adjusted_dos_mask.append(dos_mask_i) + aligned_predictions = torch.vstack(aligned_predictions) + adjusted_dos_mask = torch.vstack(adjusted_dos_mask).bool() + # mean_external_gradient_loss = torch.mean(torch.tensor(external_gradients_loss)) + # Cumulative integral loss + int_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, dx = 0.05, dim = 1) + int_aligned_targets = torch.cumulative_trapezoid(dos_target, dx = 0.05, dim = 1) + int_error = (int_aligned_predictions - int_aligned_targets)**2 + int_error = int_error * dos_mask[:,1:].unsqueeze(dim=1) # only penalize the integral where the DOS is defined + int_MSE = torch.mean(torch.trapezoid(int_error, dx = 0.05, dim = 1)) * self.hypers['integral_penalty'] + val_count += len(dos_target) + # Gradient loss + # gradient_losses = torch.nn.functional.conv1d(dos_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # dim_loss = dos_predictions.shape[1] - gradient_losses.shape[1] + gradient_loss = torch.mean(torch.trapezoid(((full_gradients * (~adjusted_dos_mask[:, dim_loss:]))**2), + dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + + + # gradient_losses = torch.nn.functional.conv1d(aligned_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # gradient_loss = torch.mean(torch.trapezoid(((gradient_losses * (~dos_mask[:, dim_loss:]))**2), + # dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + total_loss = (dos_loss + gradient_loss + int_MSE) + val_loss_batch = (total_loss * len(dos_target)).detach()# CHANGE: We need to multiply the loss by the number of samples in the batch to get the correct loss value if is_distributed: # sum the loss over all processes torch.distributed.all_reduce(val_loss_batch) val_loss += val_loss_batch.item() - val_rmse_calculator.update(predictions, targets) - if self.hypers["log_mae"]: - 
val_mae_calculator.update(predictions, targets) - - finalized_val_info = val_rmse_calculator.finalize( - not_per_atom=["positions_gradients"] + per_structure_targets, - is_distributed=is_distributed, - device=device, - ) - if self.hypers["log_mae"]: - finalized_val_info.update( - val_mae_calculator.finalize( - not_per_atom=["positions_gradients"] + per_structure_targets, - is_distributed=is_distributed, - device=device, - ) - ) + val_loss /= val_count + # val_predictions = torch.vstack(val_predictions) + # CHANGE: Not using the default calculators + # val_rmse_calculator.update(predictions, targets) + # if self.hypers["log_mae"]: + # val_mae_calculator.update(predictions, targets) + # CHANGE: Not using the default calculators + # finalized_val_info = val_rmse_calculator.finalize( + # not_per_atom=["positions_gradients"] + per_structure_targets, + # is_distributed=is_distributed, + # device=device, + # ) + # if self.hypers["log_mae"]: + # finalized_val_info.update( + # val_mae_calculator.finalize( + # not_per_atom=["positions_gradients"] + per_structure_targets, + # is_distributed=is_distributed, + # device=device, + # ) + # ) # Now we log the information: - finalized_train_info = {"loss": train_loss, **finalized_train_info} + # Change: Remove information other than the loss + finalized_train_info = {"loss": train_loss} # , **finalized_train_info} finalized_val_info = { - "loss": val_loss, - **finalized_val_info, + "loss": val_loss, } + # **finalized_val_info, + # } if epoch == start_epoch: - scaler_scales = ( - model.module if is_distributed else model - ).scaler.get_scales_dict() + # CHANGE: Remove Scaler + # scaler_scales = ( + # model.module if is_distributed else model + # ).scaler.get_scales_dict() metric_logger = MetricLogger( log_obj=ROOT_LOGGER, dataset_info=( @@ -433,14 +664,14 @@ def train( ).dataset_info, initial_metrics=[finalized_train_info, finalized_val_info], names=["training", "validation"], - scales={ - key: ( - scaler_scales[key.split(" ")[0]] - if ("MAE" in key or "RMSE" in key) - else 1.0 - ) - for key in finalized_train_info.keys() - }, + # scales={ + # key: ( + # scaler_scales[key.split(" ")[0]] + # if ("MAE" in key or "RMSE" in key) + # else 1.0 + # ) + # for key in finalized_train_info.keys() + # }, ) if epoch % self.hypers["log_interval"] == 0: metric_logger.log( @@ -469,9 +700,11 @@ def train( pass # we don't clutter the log at every warm-up step old_lr = new_lr - val_metric = get_selected_metric( - finalized_val_info, self.hypers["best_model_metric"] - ) + # val_metric = get_selected_metric( + # finalized_val_info, self.hypers["best_model_metric"] + # ) + val_metric = val_loss # CHANGE: Use the validation loss as the metric + logging.info("Current Best Validation Metric: %s", self.best_metric) if val_metric < self.best_metric: self.best_metric = val_metric self.best_model_state_dict = copy.deepcopy( @@ -490,6 +723,7 @@ def train( (model.module if is_distributed else model), Path(checkpoint_dir) / f"model_{epoch}.ckpt", ) + # torch.save(val_predictions, Path(checkpoint_dir) / f"val_predictions_{epoch}.pt") # CHANGE: Save the validation predictions # prepare for the checkpoint that will be saved outside the function self.epoch = epoch diff --git a/src/metatrain/pet/trainer_mod.py b/src/metatrain/pet/trainer_mod.py new file mode 100644 index 000000000..929dea79f --- /dev/null +++ b/src/metatrain/pet/trainer_mod.py @@ -0,0 +1,994 @@ +import copy +import logging +from pathlib import Path +from typing import Any, Dict, List, Literal, Union + +import torch +from 
torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.additive import remove_additive +from metatrain.utils.augmentation import RotationalAugmenter +from metatrain.utils.data import ( + CombinedDataLoader, + Dataset, + _is_disk_dataset, + collate_fn, +) +from metatrain.utils.distributed.distributed_data_parallel import ( + DistributedDataParallel, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.evaluate_model import evaluate_model +from metatrain.utils.external_naming import to_external_name +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import TensorMapDictLoss +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import remove_scale +from metatrain.utils.transfer import ( + systems_and_targets_to_device, + systems_and_targets_to_dtype, +) + +from .model import PET +from .modules.finetuning import apply_finetuning_strategy + +# CHANGE: IMPORT THE loss function FROM DOSutils +from .DOSutils import get_dynamic_shift_agnostic_mse + +def get_scheduler(optimizer, train_hypers): + def func_lr_scheduler(epoch): + if epoch < train_hypers["num_epochs_warmup"]: + return epoch / train_hypers["num_epochs_warmup"] + delta = epoch - train_hypers["num_epochs_warmup"] + num_blocks = delta // train_hypers["scheduler_patience"] + return 0.5 ** (num_blocks) + + scheduler = LambdaLR(optimizer, func_lr_scheduler) + return scheduler + + +class Trainer: + def __init__(self, train_hypers): + self.hypers = train_hypers + self.optimizer_state_dict = None + self.scheduler_state_dict = None + self.epoch = None + self.best_metric = None + self.best_model_state_dict = None + self.best_optimizer_state_dict = None + + def train( + self, + model: PET, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ): + assert dtype in PET.__supported_dtypes__ + is_distributed = self.hypers["distributed"] + if is_distributed: + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + torch.distributed.init_process_group(backend="nccl") + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + " If you want to run distributed training with SOAP-BPNN, please " + "set `device` to cuda." 
+ ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + else: + device = devices[ + 0 + ] # only one device, as we don't support multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Calculate the neighbor lists in advance (in particular, this + # needs to happen before the additive models are trained, as they + # might need them): + logging.info("Calculating neighbor lists for the datasets") + requested_neighbor_lists = get_requested_neighbor_lists(model) + for dataset in train_datasets + val_datasets: + # If the dataset is a disk dataset, the NLs are already attached, we will + # just check the first system + if _is_disk_dataset(dataset): + system = dataset[0]["system"] + for options in requested_neighbor_lists: + if options not in system.known_neighbor_lists(): + raise ValueError( + "The requested neighbor lists are not attached to the " + f"system. Neighbor list {options} is missing from the " + "first system in the disk dataset. Make sure you save " + "the neighbor lists in the systems when saving the dataset." + ) + else: + for sample in dataset: + system = sample["system"] + # The following line attaches the neighbors lists to the system, + # and doesn't require to reassign the system to the dataset: + get_system_with_neighbor_lists(system, requested_neighbor_lists) + + # Apply fine-tuning strategy if provided + if self.hypers["finetune"]: + model = apply_finetuning_strategy(model, self.hypers["finetune"]) + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of the SOAP-BPNN are always in float64 (to avoid + # numerical errors in the composition weights, which can be very large). 
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + + # CHANGE: Do not use additive and scaler models + # logging.info("Calculating composition weights") + # model.additive_models[0].train_model( # this is the composition model + # train_datasets, + # model.additive_models[1:], + # self.hypers["fixed_composition_weights"], + # ) + + # if self.hypers["scale_targets"]: + # logging.info("Calculating scaling weights") + # model.scaler.train_model( + # train_datasets, model.additive_models, treat_as_additive=True + # ) + + if is_distributed: + model = DistributedDataParallel(model, device_ids=[device]) + + # CHANGE: Include a permanent set that is included in every training batch + if self.hypers['use_permanent']: + n_samples = self.hypers['use_permanent']['n_samples'] + n_train = len(train_datasets[0]) + train_systems = [] + permanent_systems = [] + train_targets = {} + permanent_targets = {} + keys = ["mtt::dos", "mtt::mask"] # WARNING: Keys are hardcoded for DOS so that the mask remains identifiable + for key in keys: + train_targets[key] = [] + permanent_targets[key] = [] + for i in range(0, n_train, 1): + data_i = train_datasets[0][i] + if i < (n_train - n_samples): + train_systems.append(data_i.system) + for key in keys: + + train_targets[key].append(data_i[key]) + else: + permanent_systems.append(data_i.system) + for key in keys: + permanent_targets[key].append(data_i[key]) + train_datasets = [Dataset.from_dict({"system": train_systems, **train_targets})] + permanent_datasets = [Dataset.from_dict({"system": permanent_systems, **permanent_targets})] + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + # CHANGE: Include the permanent dataset but that should not be distributed + if self.hypers['use_permanent']: # CHANGE: Include the permanent dataset + permanent_samplers = [None] * len(permanent_datasets) + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + if self.hypers['use_permanent']: # CHANGE: Include the permanent dataset + permanent_samplers = [None] * len(permanent_datasets) + + # Create dataloader for the training datasets: + train_dataloaders = [] + for dataset, sampler in zip(train_datasets, train_samplers): + train_dataloaders.append( + DataLoader( + dataset=dataset, + batch_size=self.hypers["batch_size"], + sampler=sampler, + shuffle=( + sampler is None + ), # the sampler takes care of this (if present) + drop_last=( + sampler is None + ), # the sampler takes care of this (if present) + collate_fn=collate_fn, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for dataset, sampler in zip(val_datasets, val_samplers): + val_dataloaders.append( + DataLoader( + dataset=dataset, + batch_size=self.hypers["batch_size"], + sampler=sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + # CHANGE: Include the permanent dataset + if self.hypers['use_permanent']: + permanent_dataloaders = [] + for dataset, sampler in zip(permanent_datasets, 
permanent_samplers): + permanent_dataloaders.append( + DataLoader( + dataset=dataset, + batch_size=self.hypers["batch_size"], + sampler=sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + ) + permanent_dataloader = CombinedDataLoader(permanent_dataloaders, shuffle=False) + logging.info("Setting up Permanent Dataset") + + train_targets = (model.module if is_distributed else model).dataset_info.targets + try: + del train_targets["mtt::mask"] # CHANGE: Removing the mask from the targets, the mask is not a target for the model to predict + except: + pass + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss weight dict: + loss_weights_dict = {} + for output_name in outputs_list: + loss_weights_dict[output_name] = ( + self.hypers["loss"]["weights"][ + to_external_name(output_name, train_targets) + ] + if to_external_name(output_name, train_targets) + in self.hypers["loss"]["weights"] + else 1.0 + ) + loss_weights_dict_external = { + to_external_name(key, train_targets): value + for key, value in loss_weights_dict.items() + } + loss_hypers = copy.deepcopy(self.hypers["loss"]) + loss_hypers["weights"] = loss_weights_dict + logging.info(f"Training with loss weights: {loss_weights_dict_external}") + + # Create a loss function: + loss_fn = TensorMapDictLoss( + **loss_hypers, + ) + + if self.hypers["weight_decay"] is not None: + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.hypers["learning_rate"], + weight_decay=self.hypers["weight_decay"], + ) + else: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + + if self.optimizer_state_dict is not None and not self.hypers["finetune"]: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not (model.module if is_distributed else model).has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + lr_scheduler = get_scheduler(optimizer, self.hypers) + + if self.scheduler_state_dict is not None and not self.hypers["finetune"]: + # same as the optimizer, try to load the scheduler state dict + if not (model.module if is_distributed else model).has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + old_lr = optimizer.param_groups[0]["lr"] + logging.info(f"Base learning rate: {self.hypers['learning_rate']}") + logging.info(f"Initial learning rate: {old_lr}") + + rotational_augmenter = RotationalAugmenter(train_targets) + + start_epoch = 0 if self.epoch is None else self.epoch + 1 + # CHANGE: Define the coefficients for the finite diference scheme + interval = 0.05 + t4 = (torch.tensor([1/4, -4/3, 3., -4. 
, 25/12]).to(device)/interval).unsqueeze(dim = (0)).unsqueeze(dim = (0)).float() + + # Train the model: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Starting training") + epoch = start_epoch + + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + # CHANGE: Not using the default Accumulators because they do not support custom loss functions yet + # train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + # val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + # if self.hypers["log_mae"]: + # train_mae_calculator = MAEAccumulator( + # self.hypers["log_separate_blocks"] + # ) + # val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + train_count = 0.0 # CHANGE: Added to count the number of training samples + for batch in train_dataloader: + optimizer.zero_grad() + + systems, targets = batch + systems, targets = rotational_augmenter.apply_random_augmentations( + systems, targets + ) + systems, targets = systems_and_targets_to_device( + systems, targets, device + ) + # CHANGE: Remove additive and scaler models + # for additive_model in ( + # model.module if is_distributed else model + # ).additive_models: + # targets = remove_additive( + # systems, targets, additive_model, train_targets + # ) + # targets = remove_scale( + # targets, (model.module if is_distributed else model).scaler + # ) + systems, targets = systems_and_targets_to_dtype(systems, targets, dtype) + # CHANGE: Extract relevant quantities from the targets + target_dos_batch, mask_batch = targets['mtt::dos'], targets['mtt::mask'] + predictions = evaluate_model( + model, + systems, + # {key: train_targets[key] for key in targets.keys()}, + {key: train_targets[key] for key in train_targets.keys()}, # CHANGE: Use the train_targets keys instead as mask is not a target + is_training=True, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + # targets = average_by_num_atoms(targets, systems, per_structure_targets) # By default the targets are already averaged by the number of atoms + + # CHANGE: DOS Training loop + + ## LLPR CALIB TRAINING LOSS + if self.hypers["llpr_calib"]: + + dos_pred_mean = predictions['mtt::aux::dos::ensemble'][0].values.mean(axis=2) + dos_pred_var = predictions['mtt::aux::dos::ensemble'][0].values.var(axis=2) + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + + extra_targets = int(dos_pred_mean.shape[1] - dos_target.shape[1]) # The DOS predictions are longer than the targets, we need to align them + # Obtain shifts from existing support function + _, discrete_shift = get_dynamic_shift_agnostic_mse(dos_pred_mean, dos_target, dos_mask, return_shift = True) + + # GRADIENT LOSS STUFF + # full_gradient = torch.nn.functional.conv1d(dos_pred_mean.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # dim_loss = dos_predictions.shape[1] - full_gradient.shape[1] # Dimensions lost due to the gradient convolution + + # Obtain aligned targets (The subset of the predictions that corresponds best to the targets) + aligned_pred_mean = [] + aligned_pred_var = [] + adjusted_dos_mask = [] + # external_gradients_loss = [] + for index, prediction in enumerate(dos_pred_mean): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + aligned_variance = 
dos_pred_var[index][discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + dos_mask_i = torch.hstack( #Adjust the mask to account for the discrete shift + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_pred_mean.append(aligned_prediction) + aligned_pred_var.append(aligned_variance) + adjusted_dos_mask.append(dos_mask_i) + aligned_pred_mean = torch.vstack(aligned_pred_mean) # check shape + aligned_pred_var = torch.vstack(aligned_pred_var) # check shape + adjusted_dos_mask = torch.vstack(adjusted_dos_mask) # check shape + + dos_loss = GaussianNLLLoss(aligned_pred_mean, dos_target, aligned_pred_var) #input, target, var + + # We also compute the loss on the cumulative integral of the DOS, it improves the reliability of the fermi level of the final predicted DOS + # int_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, dx = 0.05, dim = 1) + # int_aligned_targets = torch.cumulative_trapezoid(dos_target, dx = 0.05, dim = 1) + # int_error = (int_aligned_predictions - int_aligned_targets)**2 + # int_error = int_error * dos_mask[:,1:].unsqueeze(dim=1) # only penalize the integral where the DOS is defined + # int_MSE = torch.mean(torch.trapezoid(int_error, dx = 0.05, dim = 1)) * self.hypers['integral_penalty'] + + # Keep a count of the number of training samples to calculate the MSE accurately + train_count += len(dos_target) + + # Calculate the gradient loss, useful for making sure that the behaviour of the DOS outside the window is reasonable + # gradient_losses = torch.nn.functional.conv1d(aligned_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + + # dim_loss = dos_mask.shape[1] - gradient_losses.shape[1] + # gradient_loss = torch.mean(torch.trapezoid(((full_gradient * (~adjusted_dos_mask[:, dim_loss:]))**2), # non-zero gradients outside the window are penalized + # dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + total_loss = dos_loss # + gradient_loss + int_MSE) + total_loss.backward() + # torch.nn.utils.clip_grad_norm_( + # model.parameters(), self.hypers["grad_clip_norm"] + # ) + # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + # train_loss_batch = loss_fn(predictions, targets) + # train_loss_batch.backward() + # torch.nn.utils.clip_grad_norm_( + # model.parameters(), self.hypers["grad_clip_norm"] + # ) + optimizer.step() + total_loss = (total_loss * len(dos_target)).detach() ## check if this is still applicable + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(total_loss) + + train_loss += total_loss.item() + + ## ORIGINAL TRAIN_LOSS + else: + dos_predictions = predictions['mtt::dos'][0].values + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + extra_targets = int(dos_predictions.shape[1] - dos_target.shape[1]) # The DOS predictions are longer than the targets, we need to align them + # Calculate DOS loss using dynamic shift agnostic MSE + dos_loss, discrete_shift = get_dynamic_shift_agnostic_mse(dos_predictions, dos_target, dos_mask, return_shift = True) + full_gradient = torch.nn.functional.conv1d(dos_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + 
dim_loss = dos_predictions.shape[1] - full_gradient.shape[1] # Dimensions lost due to the gradient convolution + # Obtain aligned targets (The subset of the predictions that corresponds best to the targets) + aligned_predictions = [] + adjusted_dos_mask = [] + # external_gradients_loss = [] + for index, prediction in enumerate(dos_predictions): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + dos_mask_i = torch.hstack( #Adjust the mask to account for the discrete shift + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_predictions.append(aligned_prediction) + adjusted_dos_mask.append(dos_mask_i) + aligned_predictions = torch.vstack(aligned_predictions) + adjusted_dos_mask = torch.vstack(adjusted_dos_mask) + # mean_external_gradient_loss = torch.mean(torch.tensor(external_gradients_loss)) + # We also compute the loss on the cumulative integral of the DOS, it improves the reliability of the fermi level of the final predicted DOS + int_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, dx = 0.05, dim = 1) + int_aligned_targets = torch.cumulative_trapezoid(dos_target, dx = 0.05, dim = 1) + int_error = (int_aligned_predictions - int_aligned_targets)**2 + int_error = int_error * dos_mask[:,1:].unsqueeze(dim=1) # only penalize the integral where the DOS is defined + int_MSE = torch.mean(torch.trapezoid(int_error, dx = 0.05, dim = 1)) * self.hypers['integral_penalty'] + # Keep a count of the number of training samples to calculate the MSE accurately + train_count += len(dos_target) + # Calculate the gradient loss, useful for making sure that the behaviour of the DOS outside the window is reasonable + # gradient_losses = torch.nn.functional.conv1d(aligned_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + + # dim_loss = dos_mask.shape[1] - gradient_losses.shape[1] + gradient_loss = torch.mean(torch.trapezoid(((full_gradient * (~adjusted_dos_mask[:, dim_loss:]))**2), # non-zero gradients outside the window are penalized + dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + total_loss = (dos_loss + gradient_loss + int_MSE) + total_loss.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + # train_loss_batch = loss_fn(predictions, targets) + # train_loss_batch.backward() + # torch.nn.utils.clip_grad_norm_( + # model.parameters(), self.hypers["grad_clip_norm"] + # ) + optimizer.step() + total_loss = (total_loss * len(dos_target)).detach() # CHANGE: We need to multiply the loss by the number of samples in the batch to get the correct loss value + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(total_loss) + train_loss += total_loss.item() + + + # CHANGE: Do not use the default calculators + # train_rmse_calculator.update(predictions, targets) + # if self.hypers["log_mae"]: + # train_mae_calculator.update(predictions, targets) + + # finalized_train_info = train_rmse_calculator.finalize( + # not_per_atom=["positions_gradients"] + per_structure_targets, + # 
is_distributed=is_distributed, + # device=device, + # ) + # if self.hypers["log_mae"]: + # finalized_train_info.update( + # train_mae_calculator.finalize( + # not_per_atom=["positions_gradients"] + per_structure_targets, + # is_distributed=is_distributed, + # device=device, + # ) + # ) + + # CHANGE: Use permanent dataset + if self.hypers['use_permanent']: + + ## GET PREDS AND TARGETS ON PERMANENT SET + for batch in permanent_dataloader: + optimizer.zero_grad() + systems, targets = batch + systems, targets = rotational_augmenter.apply_random_augmentations( + systems, targets + ) + systems, targets = systems_and_targets_to_device( + systems, targets, device + ) + systems, targets = systems_and_targets_to_dtype(systems, targets, dtype) + target_dos_batch, mask_batch = targets['mtt::dos'], targets['mtt::mask'] + predictions = evaluate_model( + model, + systems, + # {key: train_targets[key] for key in targets.keys()}, + {key: train_targets[key] for key in train_targets.keys()}, # CHANGE: Use the train_targets keys instead as mask is not a target + is_training=True, + ) + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + + if self.hypers["llpr_calib"]: + + dos_pred_mean = predictions['mtt::aux::dos::ensemble'][0].values.mean(axis=2) + dos_pred_var = predictions['mtt::aux::dos::ensemble'][0].values.var(axis=2) + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + # Obtain shifts from existing support function + _, discrete_shift = get_dynamic_shift_agnostic_mse(dos_pred_mean, dos_target, dos_mask, return_shift = True) + + # GRADIENT LOSS STUFF + # full_gradient = torch.nn.functional.conv1d(dos_pred_mean.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # dim_loss = dos_predictions.shape[1] - full_gradient.shape[1] # Dimensions lost due to the gradient convolution + + # Obtain aligned targets (The subset of the predictions that corresponds best to the targets) + aligned_pred_mean = [] + aligned_pred_var = [] + adjusted_dos_mask = [] + # external_gradients_loss = [] + for index, prediction in enumerate(dos_pred_mean): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + aligned_variance = dos_pred_var[index][discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + dos_mask_i = torch.hstack( #Adjust the mask to account for the discrete shift + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_pred_mean.append(aligned_prediction) + aligned_pred_var.append(aligned_variance) + adjusted_dos_mask.append(dos_mask_i) + aligned_pred_mean = torch.vstack(aligned_pred_mean) # check shape + aligned_pred_var = torch.vstack(aligned_pred_var) # check shape + adjusted_dos_mask = torch.vstack(adjusted_dos_mask) # check shape + + dos_loss = GaussianNLLLoss(aligned_pred_mean, dos_target, aligned_pred_var) #input, target, var + + # We also compute the loss on the cumulative integral of the DOS, it improves the reliability of the fermi level of the final predicted DOS + # int_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, dx = 0.05, dim = 1) + # int_aligned_targets = 
torch.cumulative_trapezoid(dos_target, dx = 0.05, dim = 1) + # int_error = (int_aligned_predictions - int_aligned_targets)**2 + # int_error = int_error * dos_mask[:,1:].unsqueeze(dim=1) # only penalize the integral where the DOS is defined + # int_MSE = torch.mean(torch.trapezoid(int_error, dx = 0.05, dim = 1)) * self.hypers['integral_penalty'] + + # Calculate the gradient loss, useful for making sure that the behaviour of the DOS outside the window is reasonable + # gradient_losses = torch.nn.functional.conv1d(aligned_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + + # dim_loss = dos_mask.shape[1] - gradient_losses.shape[1] + # gradient_loss = torch.mean(torch.trapezoid(((full_gradient * (~adjusted_dos_mask[:, dim_loss:]))**2), # non-zero gradients outside the window are penalized + # dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + total_loss = dos_loss # + gradient_loss + int_MSE) + total_loss.backward() + # torch.nn.utils.clip_grad_norm_( + # model.parameters(), self.hypers["grad_clip_norm"] + # ) + # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + # train_loss_batch = loss_fn(predictions, targets) + # train_loss_batch.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) ## WHY? + optimizer.step() + + ## ORIGINAL PERMANENT DATASET LOSS + else: + dos_predictions = predictions['mtt::dos'][0].values + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + dos_loss, discrete_shift = get_dynamic_shift_agnostic_mse(dos_predictions, dos_target, dos_mask, return_shift = True) + full_gradients = torch.nn.functional.conv1d(dos_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + dim_loss = dos_predictions.shape[1] - full_gradients.shape[1] # Dimensions lost due to the gradient convolution + aligned_predictions = [] + # external_gradients_loss = [] + adjusted_dos_mask = [] + for index, prediction in enumerate(dos_predictions): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + + dos_mask_i = torch.hstack( + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_predictions.append(aligned_prediction) + adjusted_dos_mask.append(dos_mask_i) + aligned_predictions = torch.vstack(aligned_predictions) + adjusted_dos_mask = torch.vstack(adjusted_dos_mask).bool() + # mean_external_gradient_loss = torch.mean(torch.tensor(external_gradients_loss)) + # Cumulative integral loss + int_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, dx = 0.05, dim = 1) + int_aligned_targets = torch.cumulative_trapezoid(dos_target, dx = 0.05, dim = 1) + int_error = (int_aligned_predictions - int_aligned_targets)**2 + int_error = int_error * dos_mask[:,1:].unsqueeze(dim=1) # only penalize the integral where the DOS is defined + int_MSE = torch.mean(torch.trapezoid(int_error, dx = 0.05, dim = 1)) * self.hypers['integral_penalty'] + # Gradient loss + # gradient_losses = torch.nn.functional.conv1d(aligned_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # dim_loss = dos_mask.shape[1] - gradient_losses.shape[1] + gradient_loss = 
torch.mean(torch.trapezoid(((full_gradients * (~adjusted_dos_mask[:, dim_loss:]))**2), + dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + total_loss = (dos_loss + gradient_loss + int_MSE) + total_loss.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["grad_clip_norm"] + ) + optimizer.step() + # Should always be not distributed so we should not need to accumulate + + + train_loss /= train_count + + + + ### NOW VALIDATION SET ### + + val_loss = 0.0 + val_count = 0.0 # CHANGE: Added to count the number of validation samples + # val_predictions = [] + for batch in val_dataloader: + # CHANGE: Updated validation loop to use the new loss function + systems, targets = batch + systems, targets = systems_and_targets_to_device( + systems, targets, device + ) + systems, targets = systems_and_targets_to_dtype(systems, targets, dtype) + target_dos_batch, mask_batch = targets['mtt::dos'], targets['mtt::mask'] + predictions = evaluate_model( + model, + systems, +# {key: train_targets[key] for key in targets.keys()}, + {key: train_targets[key] for key in train_targets.keys()}, + is_training=False, + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + # targets = average_by_num_atoms(targets, systems, per_structure_targets) + + if self.hypers["llpr_calib"]: + + dos_pred_mean = predictions['mtt::aux::dos::ensemble'][0].values.mean(axis=2) + dos_pred_var = predictions['mtt::aux::dos::ensemble'][0].values.var(axis=2) + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + + # Obtain shifts from existing support function + _, discrete_shift = get_dynamic_shift_agnostic_mse(dos_pred_mean, dos_target, dos_mask, return_shift = True) + + # GRADIENT LOSS STUFF + # full_gradient = torch.nn.functional.conv1d(dos_pred_mean.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # dim_loss = dos_predictions.shape[1] - full_gradient.shape[1] # Dimensions lost due to the gradient convolution + + # Obtain aligned targets (The subset of the predictions that corresponds best to the targets) + aligned_pred_mean = [] + aligned_pred_var = [] + adjusted_dos_mask = [] + # external_gradients_loss = [] + for index, prediction in enumerate(dos_pred_mean): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + aligned_variance = dos_pred_var[index][discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + dos_mask_i = torch.hstack( #Adjust the mask to account for the discrete shift + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_pred_mean.append(aligned_prediction) + aligned_pred_var.append(aligned_variance) + adjusted_dos_mask.append(dos_mask_i) + aligned_pred_mean = torch.vstack(aligned_pred_mean) # check shape + aligned_pred_var = torch.vstack(aligned_pred_var) # check shape + adjusted_dos_mask = torch.vstack(adjusted_dos_mask) # check shape + + dos_loss = GaussianNLLLoss(aligned_pred_mean, dos_target, aligned_pred_var) #input, target, var + + total_loss = dos_loss + val_loss_batch = (total_loss * len(dos_target)).detach()# CHANGE: We need 
to multiply the loss by the number of samples in the batch to get the correct loss value + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + + val_loss += val_loss_batch.item() + + + else: + dos_predictions = predictions['mtt::dos'][0].values + # val_predictions.append(dos_predictions.detach()) + dos_target = target_dos_batch[0].values + dos_mask = (mask_batch[0].values).bool() + dos_loss, discrete_shift = get_dynamic_shift_agnostic_mse(dos_predictions, dos_target, dos_mask, return_shift = True) + full_gradients = torch.nn.functional.conv1d(dos_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + dim_loss = dos_predictions.shape[1] - full_gradients.shape[1] # Dimensions lost due to the gradient convolution + # Obtain aligned targets (The subset of the predictions that corresponds best to the targets) + aligned_predictions = [] + # external_gradients_loss = [] + adjusted_dos_mask = [] + for index, prediction in enumerate(dos_predictions): + aligned_prediction = prediction[discrete_shift[index]:discrete_shift[index] + dos_mask.shape[1]] + # external_gradients_i = external_gradient[index][(discrete_shift[index] + dos_mask.shape[1] - dim_loss):] + # external_gradient_loss_i = torch.trapezoid(external_gradients_i**2, dx = 0.05) * self.hypers['gradient_penalty'] + # external_gradients_loss.append(external_gradient_loss_i) + + dos_mask_i = torch.hstack( #Adjust the mask to account for the discrete shift + [ + (torch.ones(discrete_shift[index])).bool().to(device), + dos_mask[index], + (torch.zeros(int(extra_targets - discrete_shift[index]))).bool().to(device) + ] + ) + aligned_predictions.append(aligned_prediction) + adjusted_dos_mask.append(dos_mask_i) + aligned_predictions = torch.vstack(aligned_predictions) + adjusted_dos_mask = torch.vstack(adjusted_dos_mask).bool() + # mean_external_gradient_loss = torch.mean(torch.tensor(external_gradients_loss)) + # Cumulative integral loss + int_aligned_predictions = torch.cumulative_trapezoid(aligned_predictions, dx = 0.05, dim = 1) + int_aligned_targets = torch.cumulative_trapezoid(dos_target, dx = 0.05, dim = 1) + int_error = (int_aligned_predictions - int_aligned_targets)**2 + int_error = int_error * dos_mask[:,1:].unsqueeze(dim=1) # only penalize the integral where the DOS is defined + int_MSE = torch.mean(torch.trapezoid(int_error, dx = 0.05, dim = 1)) * self.hypers['integral_penalty'] + val_count += len(dos_target) + # Gradient loss + # gradient_losses = torch.nn.functional.conv1d(dos_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # dim_loss = dos_predictions.shape[1] - gradient_losses.shape[1] + gradient_loss = torch.mean(torch.trapezoid(((full_gradients * (~adjusted_dos_mask[:, dim_loss:]))**2), + dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + + + # gradient_losses = torch.nn.functional.conv1d(aligned_predictions.unsqueeze(dim = 1), t4).squeeze(dim = 1) + # gradient_loss = torch.mean(torch.trapezoid(((gradient_losses * (~dos_mask[:, dim_loss:]))**2), + # dx = 0.05, dim = 1)) * self.hypers['gradient_penalty'] + total_loss = (dos_loss + gradient_loss + int_MSE) + val_loss_batch = (total_loss * len(dos_target)).detach()# CHANGE: We need to multiply the loss by the number of samples in the batch to get the correct loss value + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + + val_loss /= val_count + # val_predictions = torch.vstack(val_predictions) + # CHANGE: Not using the default 
calculators + # val_rmse_calculator.update(predictions, targets) + # if self.hypers["log_mae"]: + # val_mae_calculator.update(predictions, targets) + # CHANGE: Not using the default calculators + # finalized_val_info = val_rmse_calculator.finalize( + # not_per_atom=["positions_gradients"] + per_structure_targets, + # is_distributed=is_distributed, + # device=device, + # ) + # if self.hypers["log_mae"]: + # finalized_val_info.update( + # val_mae_calculator.finalize( + # not_per_atom=["positions_gradients"] + per_structure_targets, + # is_distributed=is_distributed, + # device=device, + # ) + # ) + + # Now we log the information: + # Change: Remove information other than the loss + finalized_train_info = {"loss": train_loss} # , **finalized_train_info} + finalized_val_info = { + "loss": val_loss, + } + # **finalized_val_info, + # } + + if epoch == start_epoch: + # CHANGE: Remove Scaler + # scaler_scales = ( + # model.module if is_distributed else model + # ).scaler.get_scales_dict() + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=( + model.module if is_distributed else model + ).dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + # scales={ + # key: ( + # scaler_scales[key.split(" ")[0]] + # if ("MAE" in key or "RMSE" in key) + # else 1.0 + # ) + # for key in finalized_train_info.keys() + # }, + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + ) + + lr_scheduler.step() + new_lr = lr_scheduler.get_last_lr()[0] + if new_lr != old_lr: + if new_lr < 1e-7: + logging.info("Learning rate is too small, stopping training") + break + else: + if epoch >= self.hypers["num_epochs_warmup"]: + logging.info( + f"Changing learning rate from {old_lr} to {new_lr}" + ) + elif epoch == self.hypers["num_epochs_warmup"] - 1: + logging.info( + "Finished warm-up. 
" + f"Now training with learning rate {new_lr}" + ) + else: # epoch < self.hypers["num_epochs_warmup"] - 1: + pass # we don't clutter the log at every warm-up step + old_lr = new_lr + + # val_metric = get_selected_metric( + # finalized_val_info, self.hypers["best_model_metric"] + # ) + val_metric = val_loss # CHANGE: Use the validation loss as the metric + logging.info("Current Best Validation Metric: %s", self.best_metric) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy( + (model.module if is_distributed else model).state_dict() + ) + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + (model.module if is_distributed else model), + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + # torch.save(val_predictions, Path(checkpoint_dir) / f"val_predictions_{epoch}.pt") # CHANGE: Save the validation predictions + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + def save_checkpoint(self, model, path: Union[str, Path]): + checkpoint = { + "architecture_name": "pet", + "model_data": { + "model_hypers": model.hypers, + "dataset_info": model.dataset_info, + }, + "model_state_dict": model.state_dict(), + "train_hypers": self.hypers, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], # not used at the moment + train_hypers: Dict[str, Any], + ) -> "Trainer": + epoch = checkpoint["epoch"] + optimizer_state_dict = checkpoint["optimizer_state_dict"] + scheduler_state_dict = checkpoint["scheduler_state_dict"] + best_metric = checkpoint["best_metric"] + best_model_state_dict = checkpoint["best_model_state_dict"] + best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"] + + # Create the trainer + trainer = cls(train_hypers) + trainer.optimizer_state_dict = optimizer_state_dict + trainer.scheduler_state_dict = scheduler_state_dict + trainer.epoch = epoch + trainer.best_metric = best_metric + trainer.best_model_state_dict = best_model_state_dict + trainer.best_optimizer_state_dict = best_optimizer_state_dict + + return trainer diff --git a/src/metatrain/soap_bpnn/model.py b/src/metatrain/soap_bpnn/model.py index 4c24d47c9..bd0d568df 100644 --- a/src/metatrain/soap_bpnn/model.py +++ b/src/metatrain/soap_bpnn/model.py @@ -3,8 +3,8 @@ import metatensor.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelCapabilities, ModelMetadata, ModelOutput, @@ -666,7 +666,7 @@ def load_checkpoint( def export( self, metadata: Optional[ModelMetadata] = None - ) -> MetatensorAtomisticModel: + ) -> AtomisticModel: dtype = 
next(self.parameters()).dtype if dtype not in self.__supported_dtypes__: raise ValueError(f"unsupported dtype {self.dtype} for SoapBpnn") @@ -704,7 +704,7 @@ def export( append_metadata_references(metadata, self.__default_metadata__) - return MetatensorAtomisticModel(self.eval(), metadata, capabilities) + return AtomisticModel(self.eval(), metadata, capabilities) def _add_output(self, target_name: str, target: TargetInfo) -> None: # register bases of spherical tensors (TensorBasis) diff --git a/src/metatrain/soap_bpnn/tests/test_equivariance.py b/src/metatrain/soap_bpnn/tests/test_equivariance.py index f1bcedb74..0c0a15003 100644 --- a/src/metatrain/soap_bpnn/tests/test_equivariance.py +++ b/src/metatrain/soap_bpnn/tests/test_equivariance.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from metatensor.torch.atomistic import System, systems_to_torch +from metatomic.torch import System, systems_to_torch from metatrain.soap_bpnn import SoapBpnn from metatrain.utils.data import DatasetInfo diff --git a/src/metatrain/soap_bpnn/tests/test_exported.py b/src/metatrain/soap_bpnn/tests/test_exported.py index 8366b5614..ab46f9554 100644 --- a/src/metatrain/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/soap_bpnn/tests/test_exported.py @@ -1,6 +1,6 @@ import pytest import torch -from metatensor.torch.atomistic import ModelEvaluationOptions, ModelMetadata, System +from metatomic.torch import ModelEvaluationOptions, ModelMetadata, System from metatrain.soap_bpnn import SoapBpnn from metatrain.utils.data import DatasetInfo diff --git a/src/metatrain/soap_bpnn/tests/test_functionality.py b/src/metatrain/soap_bpnn/tests/test_functionality.py index 464c88415..cfe6138e1 100644 --- a/src/metatrain/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/soap_bpnn/tests/test_functionality.py @@ -4,7 +4,7 @@ import pytest import torch from jsonschema.exceptions import ValidationError -from metatensor.torch.atomistic import ModelOutput, System +from metatomic.torch import ModelOutput, System from omegaconf import OmegaConf from metatrain.soap_bpnn import SoapBpnn diff --git a/src/metatrain/soap_bpnn/tests/test_regression.py b/src/metatrain/soap_bpnn/tests/test_regression.py index 4c07d7356..84a95d3ae 100644 --- a/src/metatrain/soap_bpnn/tests/test_regression.py +++ b/src/metatrain/soap_bpnn/tests/test_regression.py @@ -2,7 +2,7 @@ import numpy as np import torch -from metatensor.torch.atomistic import ModelOutput +from metatomic.torch import ModelOutput from omegaconf import OmegaConf from metatrain.soap_bpnn import SoapBpnn, Trainer diff --git a/src/metatrain/soap_bpnn/tests/test_torchscript.py b/src/metatrain/soap_bpnn/tests/test_torchscript.py index 26ebb4ebf..c46f80182 100644 --- a/src/metatrain/soap_bpnn/tests/test_torchscript.py +++ b/src/metatrain/soap_bpnn/tests/test_torchscript.py @@ -2,7 +2,7 @@ import pytest import torch -from metatensor.torch.atomistic import System +from metatomic.torch import System from metatrain.soap_bpnn import SoapBpnn from metatrain.utils.data import DatasetInfo diff --git a/src/metatrain/utils/additive/composition.py b/src/metatrain/utils/additive/composition.py index 5e489421e..21ab5eb4c 100644 --- a/src/metatrain/utils/additive/composition.py +++ b/src/metatrain/utils/additive/composition.py @@ -4,7 +4,7 @@ import metatensor.torch import torch from metatensor.torch import Labels, LabelsEntry, TensorBlock, TensorMap -from metatensor.torch.atomistic import ModelOutput, System +from metatomic.torch import ModelOutput, System from ..data import Dataset, 
DatasetInfo, TargetInfo, get_all_targets, get_atomic_types from ..jsonschema import validate @@ -410,7 +410,7 @@ def forward( # Note: atomic types are not checked. At training time, the composition model # is initialized with the correct types. At inference time, the checks are - # performed by MetatensorAtomisticModel. + # performed by AtomisticModel. # create sample labels sample_values_list = [] diff --git a/src/metatrain/utils/additive/remove.py b/src/metatrain/utils/additive/remove.py index fa82659df..cd5ccfa02 100644 --- a/src/metatrain/utils/additive/remove.py +++ b/src/metatrain/utils/additive/remove.py @@ -4,7 +4,7 @@ import metatensor.torch import torch from metatensor.torch import TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System from ..data import TargetInfo from ..evaluate_model import evaluate_model diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py index 5638aef1d..ed9a91721 100644 --- a/src/metatrain/utils/additive/zbl.py +++ b/src/metatrain/utils/additive/zbl.py @@ -5,7 +5,7 @@ import torch from ase.data import covalent_radii from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ModelOutput, NeighborListOptions, System +from metatomic.torch import ModelOutput, NeighborListOptions, System from ..data import DatasetInfo, TargetInfo from ..jsonschema import validate diff --git a/src/metatrain/utils/augmentation.py b/src/metatrain/utils/augmentation.py index c8c6f856d..c6c84380a 100644 --- a/src/metatrain/utils/augmentation.py +++ b/src/metatrain/utils/augmentation.py @@ -4,7 +4,7 @@ import numpy as np import torch from metatensor.torch import TensorBlock, TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System from scipy.spatial.transform import Rotation from .data import TargetInfo diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 35307a3d6..513eed587 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -11,8 +11,8 @@ from metatensor.learn.data._namedtuple import namedtuple from metatensor.torch import TensorMap, load_buffer from metatensor.torch import save_buffer as mts_save_buffer -from metatensor.torch.atomistic import System, load_system -from metatensor.torch.atomistic import save as mta_save +from metatomic.torch import System, load_system +from metatomic.torch import save as mta_save from omegaconf import DictConfig from torch.utils.data import Subset @@ -360,7 +360,7 @@ class DiskDataset(torch.utils.data.Dataset): The dataset is stored in a zip file, where each sample is stored in a separate directory. The directory's name is the index of the sample (e.g. ``0/``), and the files in the directory are the system (``system.mta``) and the targets - (each named ``.mts``). These are ``metatensor.torch.atomistic.System`` + (each named ``.mts``). These are ``metatomic.torch.System`` and ``metatensor.torch.TensorMap`` objects, respectively. 
Such a dataset can be created conveniently using the :py:class:`DiskDatasetWriter` diff --git a/src/metatrain/utils/data/readers/ase.py b/src/metatrain/utils/data/readers/ase.py index 800f63130..805e61dcf 100644 --- a/src/metatrain/utils/data/readers/ase.py +++ b/src/metatrain/utils/data/readers/ase.py @@ -7,7 +7,7 @@ import torch from ase.stress import voigt_6_to_full_3x3_stress from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import System, systems_to_torch +from metatomic.torch import System, systems_to_torch from omegaconf import DictConfig from ..target_info import TargetInfo, get_energy_target_info, get_generic_target_info diff --git a/src/metatrain/utils/data/readers/metatensor.py b/src/metatrain/utils/data/readers/metatensor.py index d2e4f0556..43750be1f 100644 --- a/src/metatrain/utils/data/readers/metatensor.py +++ b/src/metatrain/utils/data/readers/metatensor.py @@ -3,7 +3,7 @@ import metatensor.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System from omegaconf import DictConfig from ..target_info import TargetInfo, get_energy_target_info, get_generic_target_info diff --git a/src/metatrain/utils/data/readers/readers.py b/src/metatrain/utils/data/readers/readers.py index 5a511c242..2671594f0 100644 --- a/src/metatrain/utils/data/readers/readers.py +++ b/src/metatrain/utils/data/readers/readers.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple from metatensor.torch import TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System from omegaconf import DictConfig from ..target_info import TargetInfo @@ -92,7 +92,7 @@ def read_targets( :raises ValueError: if the target name is not valid. Valid target names are those that either start with ``mtt::`` or those that are in the list of - standard outputs of ``metatensor.torch.atomistic`` (see + standard outputs of ``metatomic.torch`` (see https://docs.metatensor.org/latest/atomistic/outputs.html) """ target_dictionary = {} diff --git a/src/metatrain/utils/data/system_to_ase.py b/src/metatrain/utils/data/system_to_ase.py index 78b19f9b0..280de1add 100644 --- a/src/metatrain/utils/data/system_to_ase.py +++ b/src/metatrain/utils/data/system_to_ase.py @@ -1,9 +1,9 @@ import ase -from metatensor.torch.atomistic import System +from metatomic.torch import System def system_to_ase(system: System) -> ase.Atoms: - """Converts a ``metatensor.torch.atomistic.System`` to an ``ase.Atoms`` object. + """Converts a ``metatomic.torch.System`` to an ``ase.Atoms`` object. This will discard any neighbor lists attached to the ``System``. :param system: The system to convert. 
diff --git a/src/metatrain/utils/data/writers/__init__.py b/src/metatrain/utils/data/writers/__init__.py index 138172230..19f9a96f0 100644 --- a/src/metatrain/utils/data/writers/__init__.py +++ b/src/metatrain/utils/data/writers/__init__.py @@ -2,7 +2,7 @@ from typing import List, Optional from metatensor.torch import TensorMap -from metatensor.torch.atomistic import ModelCapabilities, System +from metatomic.torch import ModelCapabilities, System from .metatensor import write_mts from .xyz import write_xyz diff --git a/src/metatrain/utils/data/writers/metatensor.py b/src/metatrain/utils/data/writers/metatensor.py index d6b18503e..02d07fb17 100644 --- a/src/metatrain/utils/data/writers/metatensor.py +++ b/src/metatrain/utils/data/writers/metatensor.py @@ -3,7 +3,7 @@ import torch from metatensor.torch import TensorMap, save -from metatensor.torch.atomistic import ModelCapabilities, System +from metatomic.torch import ModelCapabilities, System # note that, although we don't use `systems` and `capabilities`, we need them to diff --git a/src/metatrain/utils/data/writers/xyz.py b/src/metatrain/utils/data/writers/xyz.py index 70808b3ab..684b5a1a9 100644 --- a/src/metatrain/utils/data/writers/xyz.py +++ b/src/metatrain/utils/data/writers/xyz.py @@ -5,7 +5,7 @@ import metatensor.torch import torch from metatensor.torch import Labels, TensorMap -from metatensor.torch.atomistic import ModelCapabilities, System +from metatomic.torch import ModelCapabilities, System from ...external_naming import to_external_name diff --git a/src/metatrain/utils/evaluate_model.py b/src/metatrain/utils/evaluate_model.py index ae1ce14bc..c8fcd05f6 100644 --- a/src/metatrain/utils/evaluate_model.py +++ b/src/metatrain/utils/evaluate_model.py @@ -3,8 +3,8 @@ import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelEvaluationOptions, ModelOutput, System, @@ -18,7 +18,7 @@ def evaluate_model( model: Union[ torch.nn.Module, - MetatensorAtomisticModel, + AtomisticModel, torch.jit._script.RecursiveScriptModule, ], systems: List[System], @@ -226,7 +226,7 @@ def _get_outputs( def _get_model_outputs( model: Union[ torch.nn.Module, - MetatensorAtomisticModel, + AtomisticModel, torch.jit._script.RecursiveScriptModule, ], systems: List[System], diff --git a/src/metatrain/utils/external_naming.py b/src/metatrain/utils/external_naming.py index 2fa040ec2..07bed5b69 100644 --- a/src/metatrain/utils/external_naming.py +++ b/src/metatrain/utils/external_naming.py @@ -1,6 +1,6 @@ from typing import Dict, Union -from metatensor.torch.atomistic import ModelOutput +from metatomic.torch import ModelOutput def to_external_name( diff --git a/src/metatrain/utils/io.py b/src/metatrain/utils/io.py index 3b691471b..583ae79dd 100644 --- a/src/metatrain/utils/io.py +++ b/src/metatrain/utils/io.py @@ -6,7 +6,7 @@ from urllib.request import urlretrieve import torch -from metatensor.torch.atomistic import check_atomistic_model, load_atomistic_model +from metatomic.torch import check_atomistic_model, load_atomistic_model from ..utils.architectures import find_all_architectures from .architectures import import_architecture @@ -43,9 +43,9 @@ def check_file_extension( def is_exported_file(path: str) -> bool: """ - Check if a saved model file has been exported to a ``MetatensorAtomisticModel``. + Check if a saved model file has been exported to a ``AtomisticModel``. 
- The functions uses :py:func:`metatensor.torch.atomistic.check_atomistic_model` to + The functions uses :py:func:`metatomic.torch.check_atomistic_model` to verify. :param path: model path @@ -54,7 +54,7 @@ def is_exported_file(path: str) -> bool: .. seealso:: - :py:func:`metatensor.torch.atomistic.is_atomistic_model` to verify if an already + :py:func:`metatomic.torch.is_atomistic_model` to verify if an already loaded model is exported. """ try: diff --git a/src/metatrain/utils/llpr.py b/src/metatrain/utils/llpr.py index ff25e13b3..ac645675d 100644 --- a/src/metatrain/utils/llpr.py +++ b/src/metatrain/utils/llpr.py @@ -1,10 +1,11 @@ -from typing import Callable, Dict, List, Optional +from collections import defaultdict +from typing import Callable, Dict, DefaultDict, List, Optional import metatensor.torch import numpy as np import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ( +from metatomic.torch import ( ModelCapabilities, ModelEvaluationOptions, ModelOutput, @@ -18,6 +19,7 @@ from .evaluate_model import evaluate_model from .per_atom import average_by_num_atoms +from metatrain.pet.DOSutils import get_dynamic_shift_agnostic_mse class LLPRUncertaintyModel(torch.nn.Module): """A wrapper that adds LLPR uncertainties to a model. @@ -33,26 +35,36 @@ class LLPRUncertaintyModel(torch.nn.Module): def __init__( self, model: torch.jit._script.RecursiveScriptModule, + num_subtargets: Optional[DefaultDict] = defaultdict(lambda: 1), + # TODO: read `num_targets` from capabilities instead of user input + dos: bool = False, # DOS-specific ) -> None: super().__init__() self.model = model - self.ll_feat_size = self.model.module.last_layer_feature_size + self.ll_feat_size = self.model.last_layer_feature_size + + # we need the capabilities of the model to be able to infer the capabilities + # of the LLPR model. 
Here, we do a trick: we call export on the model to to make + # it handle the conversion from dataset_info to capabilities + old_capabilities = self.model.export().capabilities() + dtype = getattr(torch, old_capabilities.dtype) # update capabilities: now we have additional outputs for the uncertainty - old_capabilities = self.model.capabilities() additional_capabilities = {} - self.uncertainty_multipliers = {} + self.outputs_list = [] for name, output in old_capabilities.outputs.items(): if is_auxiliary_output(name): continue # auxiliary output - uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty" + elif "mask" in name: + continue # DOS-specific + self.outputs_list.append(name) + uncertainty_name = _get_uncertainty_name(name) additional_capabilities[uncertainty_name] = ModelOutput( - quantity="", - unit=f"({output.unit})^2", - per_atom=True, + quantity=output.quantity, + unit=output.unit, + per_atom=output.per_atom, ) - self.uncertainty_multipliers[uncertainty_name] = 1.0 self.capabilities = ModelCapabilities( outputs={**old_capabilities.outputs, **additional_capabilities}, atomic_types=old_capabilities.atomic_types, @@ -62,38 +74,61 @@ def __init__( dtype=old_capabilities.dtype, ) - # register covariance and inverse covariance buffers - device = next(self.model.parameters()).device - dtype = getattr(torch, old_capabilities.dtype) - self.covariances = { - uncertainty_name: torch.zeros( - (self.ll_feat_size, self.ll_feat_size), - device=device, - dtype=dtype, + for name in self.outputs_list: + if "mask" in name: + continue # DOS-specific + uncertainty_name = _get_uncertainty_name(name) + self.register_buffer( + f"covariance_{uncertainty_name}", + torch.zeros( + (self.ll_feat_size, self.ll_feat_size), + dtype=dtype, + ), ) - for uncertainty_name in self.uncertainty_multipliers.keys() - } - self.inv_covariances = { - uncertainty_name: torch.zeros( - (self.ll_feat_size, self.ll_feat_size), - device=device, - dtype=dtype, + self.register_buffer( + f"inv_covariance_{uncertainty_name}", + torch.zeros( + (self.ll_feat_size, self.ll_feat_size), + dtype=dtype, + ), ) - for uncertainty_name in self.uncertainty_multipliers.keys() - } + + self.uncertainty_multipliers = {} + # TODO: read `num_targets` from capabilities instead of user input + self.num_subtargets = num_subtargets + self.uncertainty_multipliers[uncertainty_name] = torch.ones( + num_subtargets[name], + device=device, + dtype=dtype, + ) + + # DOS-specific + self.dos = dos + + device = next(self.model.parameters()).device + dtype = getattr(torch, old_capabilities.dtype) + + self.n_ens = defaultdict(lambda: 0) + self.llpr_ensemble_layers = torch.nn.ModuleDict() + self.ensemble_weights_computed = defaultdict(lambda: False) # flags self.covariance_computed = False self.inv_covariance_computed = False self.is_calibrated = False + self.is_recalibrated = False def forward( self, systems: List[System], outputs: Dict[str, ModelOutput], + is_recalibrating: bool = False, selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: + device = systems[0].positions.device + + # make sure matrices are on the same device as systems if list(self.covariances.values())[0].device != device: for name in self.covariances.keys(): self.covariances[name] = self.covariances[name].to(device=device) @@ -116,16 +151,27 @@ def forward( ) return self.model(systems, options, check_consistency=False) + # collect per-atom targets per_atom_all_targets = [output.per_atom for output in outputs.values()] + # impose either all per atom or all not per 
atom if not all(per_atom_all_targets) and any(per_atom_all_targets): raise ValueError( "All output uncertainties must be either be requested per " "atom or not per atom with LLPR." ) + + # ??? per_atom = per_atom_all_targets[0] + outputs_for_model: Dict[str, ModelOutput] = {} + + # collect last-layer features for uncertainty-requested outputs for name in outputs.keys(): + + if "mask" in name: + continue # DOS-specific + if name.endswith("_uncertainty"): base_name = name.replace("_uncertainty", "").replace("mtt::aux::", "") if base_name not in outputs and f"mtt::{base_name}" not in outputs: @@ -142,6 +188,8 @@ def forward( per_atom=per_atom, ) ) + + # collect actual outputs for name, output in outputs.items(): # remove uncertainties from the requested outputs for the # wrapped model @@ -151,6 +199,7 @@ def forward( continue outputs_for_model[name] = output + # initialize return_dict options = ModelEvaluationOptions( length_unit="", outputs=outputs_for_model, @@ -158,12 +207,17 @@ def forward( ) return_dict = self.model(systems, options, check_consistency=False) + # collect requested uncertainties requested_uncertainties: List[str] = [] for name in outputs.keys(): if name.startswith("mtt::aux::") and name.endswith("_uncertainty"): requested_uncertainties.append(name) - for name in requested_uncertainties: + for name, orig_name in zip(requested_uncertainties, outputs.keys()): + + if "mask" in name: + continue # DOS-specific + ll_features = return_dict[ name.replace("_uncertainty", "_last_layer_features") ] @@ -176,6 +230,7 @@ def forward( self.inv_covariances[name], ll_features.block().values, ).unsqueeze(1) + one_over_pr = TensorMap( keys=Labels( names=["_"], @@ -185,21 +240,50 @@ def forward( ), blocks=[ TensorBlock( - values=one_over_pr_values, + values=one_over_pr_values.expand(-1, self.num_subtargets[orig_name]), samples=ll_features.block().samples, components=ll_features.block().components, - properties=Labels( - names=["_"], - values=torch.tensor( - [[0]], device=ll_features.block().values.device - ), - ), + properties=Labels.range("properties", + self.num_subtargets[orig_name], + ).to(ll_features.block().values.device), + #( + # names=["_"], + # values=torch.tensor( + # [[0]], device=ll_features.block().values.device + # ), + #), + ) + ], + ) + + tsm_multipliers = TensorMap( + keys=Labels( + names=["_"], + values=torch.tensor( + [[0]], device=ll_features.block().values.device + ), + ), + blocks=[ + TensorBlock( + values=self.uncertainty_multipliers[name].expand(len(systems), -1), + samples=ll_features.block().samples, + components=ll_features.block().components, + properties=Labels.range("properties", + self.num_subtargets[orig_name], + ).to(ll_features.block().values.device), + #( + # names=["_"], + # values=torch.tensor( + # [[0]], device=ll_features.block().values.device + # ), + #), ) ], ) return_dict[name] = metatensor.torch.multiply( - one_over_pr, self.uncertainty_multipliers[name] + one_over_pr, + tsm_multipliers, ) # now deal with potential ensembles (see generate_ensemble method) @@ -209,22 +293,25 @@ def forward( requested_ensembles.append(name) for name in requested_ensembles: + + base_name = name.replace("aux::", "").replace("_ensemble", "") + if not self.ensemble_weights_computed[base_name]: + raise RuntimeError(f"Ensemble weights have not been computed for {name}! 
Aborting...") + + base_name = name.replace("_ensemble", "").replace("aux::", "") ll_features_name = name.replace("_ensemble", "_last_layer_features") if ll_features_name == "energy_last_layer_features": # special case for energy_ensemble ll_features_name = "mtt::aux::energy_last_layer_features" ll_features = return_dict[ll_features_name] - # get the ensemble weights (getattr not supported by torchscript) - ensemble_weights = torch.tensor(0.0) - for buffer_name, buffer in self.named_buffers(): - if buffer_name == name + "_weights": - ensemble_weights = buffer - # the ensemble weights should always be found (checks are performed - # in the generate_ensemble method and in the metatensor wrapper) - ensemble_values = torch.einsum( - "ij, jk -> ik", - ll_features.block().values, - ensemble_weights, + + self.llpr_ensemble_layers[base_name].to(ll_features.block().values.device) + ensemble_values = self.llpr_ensemble_layers[base_name](ll_features.block().values) + + ensemble_values = ensemble_values.reshape( + ensemble_values.shape[0], + self.n_ens[base_name], + -1, ) # since we know the exact mean of the ensemble from the model's prediction, @@ -234,16 +321,18 @@ def forward( # this also takes care of additive contributions that are not present in the # last layer, which can be composition, short-range models, a bias in the # last layer, etc. - original_name = ( - name.replace("_ensemble", "").replace("aux::", "") - if name.replace("_ensemble", "").replace("aux::", "") in outputs - else name.replace("_ensemble", "").replace("mtt::aux::", "") - ) - ensemble_values = ( - ensemble_values - - ensemble_values.mean(dim=1, keepdim=True) - + return_dict[original_name].block().values - ) + + if is_recalibrating == False: + original_name = ( + name.replace("_ensemble", "").replace("aux::", "") + if name.replace("_ensemble", "").replace("aux::", "") in outputs + else name.replace("_ensemble", "").replace("mtt::aux::", "") + ) + ensemble_values = ( + ensemble_values + - ensemble_values.mean(dim=1, keepdim=True) + + return_dict[original_name].block().values.unsqueeze(1) ## DOS specific + ) property_name = "energy" if name == "energy_ensemble" else "ensemble_member" ensemble = TensorMap( @@ -255,14 +344,15 @@ def forward( ), blocks=[ TensorBlock( - values=ensemble_values, + values=ensemble_values.reshape(ensemble_values.shape[0], -1), samples=ll_features.block().samples, components=ll_features.block().components, properties=Labels( - names=[property_name], - values=torch.arange( - ensemble_values.shape[1], device=ensemble_values.device - ).unsqueeze(1), + names=['ensemble_member', 'energy_channel'], # DOS specific + values=torch.cartesian_prod( + torch.arange(ensemble_values.shape[1], device=ensemble_values.device), # DOS specific, double-check! + torch.arange(ensemble_values.shape[2], device=ensemble_values.device), # DOS specific, double-check! + ) ), ) ], @@ -290,6 +380,10 @@ class in ``metatrain``. dtype = next(iter(self.covariances.values())).dtype for batch in train_loader: systems, targets = batch + + if self.dos: + del targets["mtt::mask"] # DOS-specific + n_atoms = torch.tensor( [len(system.positions) for system in systems], device=device ) @@ -491,48 +585,94 @@ def calibrate(self, valid_loader: DataLoader): This data loader should be generated from a dataset from the ``Dataset`` class in ``metatrain.utils.data``. 
""" + # calibrate the LLPR - # TODO: in the future, we might want to have one calibration factor per - # property for outputs with multiple properties + device = next(iter(self.covariances.values())).device dtype = next(iter(self.covariances.values())).dtype + all_predictions = {} # type: ignore all_targets = {} # type: ignore all_uncertainties = {} # type: ignore + + if self.dos: + all_masks = [] + all_shifts = [] + all_lens = [] + for batch in valid_loader: - systems, targets = batch + + systems, orig_targets = batch + + targets = orig_targets.copy() + if self.dos: + del targets["mtt::mask"] # DOS-specific + systems = [system.to(device=device, dtype=dtype) for system in systems] + + if self.dos: + lens = [len(system) for system in systems] + all_lens += lens + targets = { name: target.to(device=device, dtype=dtype) for name, target in targets.items() } + # evaluate the targets and their uncertainties, not per atom requested_outputs = {} for name in targets: + requested_outputs[name] = ModelOutput( quantity="", unit="", per_atom=False, ) + uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty" requested_outputs[uncertainty_name] = ModelOutput( quantity="", unit="", per_atom=False, ) + outputs = self.forward(systems, requested_outputs) + for name, target in targets.items(): + uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty" if name not in all_predictions: all_predictions[name] = [] all_targets[name] = [] all_uncertainties[uncertainty_name] = [] + all_predictions[name].append(outputs[name].block().values.detach()) all_targets[name].append(target.block().values) all_uncertainties[uncertainty_name].append( outputs[uncertainty_name].block().values.detach() ) + if self.dos and name == "mtt::dos": + + # accumulate masks + cur_mask = orig_targets["mtt::mask"].block().values.to(device=device, dtype=dtype) + all_masks.append(cur_mask) + print(target.block().values.device, target.block().values.dtype) + print(device) + # accumulate shifts + _, cur_shifts = get_dynamic_shift_agnostic_mse( + outputs[name].block().values.detach(), + target.block().values * torch.tensor(lens).unsqueeze(-1).to(device=device, dtype=dtype), + cur_mask, + return_shift=True) + + all_shifts.append(cur_shifts) + + if self.dos: + all_masks = torch.cat(all_masks, dim=0) + all_shifts = torch.cat(all_shifts, dim=0) + all_lens = torch.tensor(all_lens, device=device, dtype=dtype) + for name in all_predictions: all_predictions[name] = torch.cat(all_predictions[name], dim=0) all_targets[name] = torch.cat(all_targets[name], dim=0) @@ -541,19 +681,66 @@ def calibrate(self, valid_loader: DataLoader): all_uncertainties[uncertainty_name], dim=0 ) + # compute the uncertainty multiplier for name in all_predictions: - # compute the uncertainty multiplier - residuals = all_predictions[name] - all_targets[name] + uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty" uncertainties = all_uncertainties[uncertainty_name] - self.uncertainty_multipliers[uncertainty_name] = torch.mean( - residuals**2 / uncertainties - ).item() + + if name == "mtt::dos": + dos_predictions = all_predictions[name] + orig_dos_targets = all_targets[name] + + revised_dos_targets = torch.zeros(dos_predictions.shape, dtype=dtype, device=device) + revised_masks = torch.zeros(dos_predictions.shape, dtype=dtype, device=device) + + # broadcasting tensors + rows = torch.arange(dos_predictions.shape[0]).unsqueeze(1) + cols = all_shifts.unsqueeze(1) + torch.arange(orig_dos_targets.shape[1]).to(device=device) + + # revised DOS 
target via broadcasting + revised_dos_targets[rows, cols] = orig_dos_targets + + # get revised masks, padding the low E end with 1's + revised_masks[rows, cols] = all_masks + low_e_cols = torch.arange(dos_predictions.shape[1]).unsqueeze(0).expand(len(rows), -1).to(device=device) + low_e_mask = low_e_cols < all_shifts.unsqueeze(1) + revised_masks[low_e_mask] = 1 + mask_count = revised_masks.sum(dim=0) + + # raw residuals + resid = all_predictions[name] - (revised_dos_targets * all_lens.unsqueeze(-1)) + resid_masked_sum = ((resid ** 2) * revised_masks).sum(dim=0) + resid_masked_mean = resid_masked_sum / mask_count.clamp(min=1) + + uncer_masked_sum = (uncertainties * revised_masks).sum(dim=0) + uncer_masked_mean = uncer_masked_sum / mask_count.clamp(min=1) + + # true/pred ratios + ratios = resid_masked_mean / uncer_masked_mean + ratios = torch.where(mask_count > 0, ratios, torch.tensor(float('nan'))) + self.uncertainty_multipliers[uncertainty_name] = ratios + + # generic case with num_subtargets > 1 + elif self.num_subtargets[name] > 1: + residuals = all_predictions[name] - all_targets[name] + self.uncertainty_multipliers[uncertainty_name] = torch.mean( + residuals**2 / uncertainties, + axis=0, + ) + + else: + residuals = all_predictions[name] - all_targets[name] + self.uncertainty_multipliers[uncertainty_name] = torch.mean( + residuals**2 / uncertainties + ).item() self.is_calibrated = True def generate_ensemble( - self, weight_tensors: Dict[str, torch.Tensor], n_members: int + self, + weight_tensors: Dict[str, torch.Tensor], + n_ens: Dict[str, int], ) -> None: """Generate an ensemble of weights for the model. @@ -563,7 +750,7 @@ def generate_ensemble( :param weight_tensors: A dictionary with the weights for the ensemble. The keys should be the names of the weights in the model and the - values should be 1D PyTorch tensors. + values should be PyTorch tensors of (num_targets, num_weights) :param n_members: The number of members in the ensemble. 
""" # note: we could also allow n_members to be different for each output @@ -576,37 +763,74 @@ def generate_ensemble( for key in weight_tensors: if key not in self.capabilities.outputs.keys(): raise ValueError(f"Output '{key}' not supported by model") - if len(weight_tensors[key].shape) != 1: - raise ValueError("All weights must be 1D tensors") + print(weight_tensors[key].shape) + if len(weight_tensors[key].shape) != 2: # DOS specific + raise ValueError("All weights must be 2D tensors") # DOS specific # sampling; each member is sampled from a multivariate normal distribution # with mean given by the input weights and covariance given by the inverse # covariance matrix for name, weights in weight_tensors.items(): + uncertainty_name = "mtt::aux::" + name.replace("mtt::", "") + "_uncertainty" device = self.inv_covariances[uncertainty_name].device dtype = self.inv_covariances[uncertainty_name].dtype rng = np.random.default_rng() - ensemble_weights = rng.multivariate_normal( - weights.clone().detach().cpu().numpy(), - self.inv_covariances[uncertainty_name].clone().detach().cpu().numpy() - * self.uncertainty_multipliers[uncertainty_name], - size=n_members, - method="svd", - ).T - ensemble_weights = torch.tensor( - ensemble_weights, device=device, dtype=dtype - ) - ensemble_weights_name = ( - "mtt::aux::" + name.replace("mtt::", "") + "_ensemble_weights" - ) - if ensemble_weights_name == "mtt::aux::energy_ensemble_weights": - ensemble_weights_name = "energy_ensemble_weights" - self.register_buffer( - ensemble_weights_name, - ensemble_weights, + + if n_ens[name] < 0: + raise AssertionError(f"Invalid n_ens value for {name}! Double-check your input params. Aborting...") + elif n_ens[name] > 0 and n_ens[name] < 8: + raise AssertionError(f"n_ens for {name} too small. Aborting...") + + self.n_ens[name] = n_ens[name] + # DOS-specific + # loop through pred channels + ensemble_weights = [] + max_multiplier = -1, + for ii in range(weights.shape[0]): + if np.isnan(self.uncertainty_multipliers[uncertainty_name][ii].detach().cpu().numpy()): + print(f"multiplier is NaN for channel # {ii}! We resort to the max_multiplier value...") + cur_ensemble_weights = rng.multivariate_normal( + weights[ii].clone().detach().cpu().numpy(), + self.inv_covariances[uncertainty_name].clone().detach().cpu().numpy() + * max_multiplier, + size=n_ens[name], + method="svd", + ).T + else: + print("ens. 
generation for energy channel -- #", ii) + cur_ensemble_weights = rng.multivariate_normal( + weights[ii].clone().detach().cpu().numpy(), + self.inv_covariances[uncertainty_name].clone().detach().cpu().numpy() + * self.uncertainty_multipliers[uncertainty_name][ii].detach().cpu().numpy(), + size=n_ens[name], + method="svd", + ).T + if max_multiplier < self.uncertainty_multipliers[uncertainty_name][ii].detach().cpu().numpy(): + max_multiplier = self.uncertainty_multipliers[uncertainty_name][ii].detach().cpu().numpy() + + cur_ensemble_weights = torch.tensor( + cur_ensemble_weights, device=device, dtype=dtype + ) + ensemble_weights.append(cur_ensemble_weights) # DOS specific + ensemble_weights = torch.stack(ensemble_weights, axis=-1) # DOS specific, shape ll_Feat, n_ens, n_channel + print(ensemble_weights.shape) + ensemble_weights = ensemble_weights.reshape( + ensemble_weights.shape[0], + -1, + ) # DOS specific, shape ll_feat, n_ens*n_channel + print(ensemble_weights.shape) + # 1D Linear that goes from ll_feat_size to n_channel * n_ens + self.llpr_ensemble_layers[name] = torch.nn.Linear( + self.ll_feat_size, + weights.shape[0] * n_ens[name], + bias=False ) + with torch.no_grad(): + self.llpr_ensemble_layers[name].weight.copy_(ensemble_weights.T) + self.ensemble_weights_computed[name] = True + # add the ensembles to the capabilities old_outputs = self.capabilities.outputs new_outputs = {} @@ -627,3 +851,16 @@ def generate_ensemble( supported_devices=self.capabilities.supported_devices, dtype=self.capabilities.dtype, ) + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.model, name) + +def _get_uncertainty_name(name: str): + if name == "energy": + uncertainty_name = "energy_uncertainty" + else: + uncertainty_name = f"mtt::aux::{name.replace('mtt::', '')}_uncertainty" + return uncertainty_name diff --git a/src/metatrain/utils/logging.py b/src/metatrain/utils/logging.py index 037c3c8ef..0f38664c6 100644 --- a/src/metatrain/utils/logging.py +++ b/src/metatrain/utils/logging.py @@ -9,7 +9,7 @@ import numpy as np import torch -from metatensor.torch.atomistic import ModelCapabilities +from metatomic.torch import ModelCapabilities from .. 
import PACKAGE_ROOT, __version__ from .data import DatasetInfo diff --git a/src/metatrain/utils/long_range.py b/src/metatrain/utils/long_range.py index 835bd08ea..cb52830c5 100644 --- a/src/metatrain/utils/long_range.py +++ b/src/metatrain/utils/long_range.py @@ -1,7 +1,7 @@ from typing import List import torch -from metatensor.torch.atomistic import System +from metatomic.torch import System class LongRangeFeaturizer(torch.nn.Module): diff --git a/src/metatrain/utils/loss.py b/src/metatrain/utils/loss.py index 726e3bab4..ad1c991ac 100644 --- a/src/metatrain/utils/loss.py +++ b/src/metatrain/utils/loss.py @@ -2,12 +2,52 @@ import torch from metatensor.torch import TensorMap +from metatensor.torch import mean_over_samples, var_over_samples from omegaconf import DictConfig from torch.nn.modules.loss import _Loss from metatrain.utils.external_naming import to_internal_name +class LLPREnsCalibLoss: + ## template for an actual loss function, currently not in use for DOS + def __init__( + self, + reduction: str = "mean", + weight: float = 1.0, + ): + losses = {} + losses["values"] = torch.nn.GaussianNLLLoss(reduction=reduction) + self.losses = losses + self.weight = weight + + def __call__( + self, + ensemble_pred_tensor_map: TensorMap, + targets_tensor_map: TensorMap, + ) -> torch.Tensor: + + ens_pred_mean_tsm = mean_over_samples(ensemble_pred_tensor_map, "ensemble_member") + ens_pred_var_tsm = var_over_samples(ensemble_pred_tensor_map, "ensemble_member") + + loss = torch.zeros( + (), + dtype=ens_pred_mean_tsm.block(0).values.dtype, + device=ens_pred_mean_tsm.block(0).values.device, + ) + + for key in ens_pred_mean_tsm.keys: + block_mean = ens_pred_mean_tsm.block(key) + block_var = ens_pred_var_tsm.block(key) + block_target = targets_tensor_map.block(key) + values_mean = block_mean.values + values_var = block_var.values + values_target = block_target.values + loss += self.weight * self.losses["values"](values_mean, values_target, values_var) # input, target, var + + return loss + + class TensorMapLoss: """A loss function that operates on two ``metatensor.torch.TensorMap``. 
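The ``LLPREnsCalibLoss`` template added above reduces an ensemble prediction to its per-sample mean and variance and scores it against the targets with a Gaussian negative log-likelihood. A minimal sketch of the same computation on plain tensors (the shapes and the ``"mean"`` reduction are illustrative assumptions; the class itself operates on ``TensorMap`` objects via ``mean_over_samples``/``var_over_samples``) is:

.. code-block:: python

    import torch

    # toy data: 32 samples, 16 ensemble members (illustrative shapes)
    ensemble_pred = torch.randn(32, 16)
    targets = torch.randn(32, 1)

    # reduce the ensemble dimension, mirroring mean_over_samples / var_over_samples
    mean = ensemble_pred.mean(dim=1, keepdim=True)
    var = ensemble_pred.var(dim=1, keepdim=True)

    # Gaussian NLL with the same argument order used in the class above:
    # (input, target, var)
    nll = torch.nn.GaussianNLLLoss(reduction="mean")
    loss = nll(mean, targets, var)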
diff --git a/src/metatrain/utils/metadata.py b/src/metatrain/utils/metadata.py index 132a5a63a..064259513 100644 --- a/src/metatrain/utils/metadata.py +++ b/src/metatrain/utils/metadata.py @@ -1,6 +1,6 @@ import json -from metatensor.torch.atomistic import ModelMetadata +from metatomic.torch import ModelMetadata def append_metadata_references(self: ModelMetadata, other: ModelMetadata) -> None: diff --git a/src/metatrain/utils/neighbor_lists.py b/src/metatrain/utils/neighbor_lists.py index e35f1cd6a..e448d6c36 100644 --- a/src/metatrain/utils/neighbor_lists.py +++ b/src/metatrain/utils/neighbor_lists.py @@ -5,7 +5,7 @@ import torch import vesin from metatensor.torch import Labels, TensorBlock -from metatensor.torch.atomistic import ( +from metatomic.torch import ( NeighborListOptions, System, ) @@ -91,7 +91,7 @@ def _compute_single_neighbor_list( atoms: ase.Atoms, options: NeighborListOptions ) -> TensorBlock: # Computes a single neighbor list for an ASE atoms object - # (as in metatensor.torch.atomistic) + # (as in metatomic.torch) if np.all(atoms.pbc) or np.all(~atoms.pbc): nl_i, nl_j, nl_S, nl_D = vesin.ase_neighbor_list( diff --git a/src/metatrain/utils/per_atom.py b/src/metatrain/utils/per_atom.py index e2867519d..d64efb922 100644 --- a/src/metatrain/utils/per_atom.py +++ b/src/metatrain/utils/per_atom.py @@ -2,7 +2,7 @@ import torch from metatensor.torch import TensorBlock, TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System def average_by_num_atoms( diff --git a/src/metatrain/utils/scaler.py b/src/metatrain/utils/scaler.py index c5fccaed3..47ec585f1 100644 --- a/src/metatrain/utils/scaler.py +++ b/src/metatrain/utils/scaler.py @@ -4,7 +4,7 @@ import numpy as np import torch from metatensor.torch import TensorMap -from metatensor.torch.atomistic import ModelOutput +from metatomic.torch import ModelOutput from .additive import remove_additive from .data import Dataset, DatasetInfo, TargetInfo, get_all_targets diff --git a/src/metatrain/utils/testing/equivariance.py b/src/metatrain/utils/testing/equivariance.py index 511082c2d..02ea198ad 100644 --- a/src/metatrain/utils/testing/equivariance.py +++ b/src/metatrain/utils/testing/equivariance.py @@ -1,7 +1,7 @@ import numpy as np import spherical import torch -from metatensor.torch.atomistic import System +from metatomic.torch import System from scipy.spatial.transform import Rotation diff --git a/src/metatrain/utils/transfer.py b/src/metatrain/utils/transfer.py index be27b1f63..0f3024baa 100644 --- a/src/metatrain/utils/transfer.py +++ b/src/metatrain/utils/transfer.py @@ -2,7 +2,7 @@ import torch from metatensor.torch import TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System @torch.jit.script diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index 19efab40b..d98b48dac 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -11,7 +11,7 @@ import torch from jsonschema.exceptions import ValidationError from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import NeighborListOptions, systems_to_torch +from metatomic.torch import NeighborListOptions, systems_to_torch from omegaconf import OmegaConf from metatrain import RANDOM_SEED diff --git a/tests/utils/data/test_system_to_ase.py b/tests/utils/data/test_system_to_ase.py index 3d66f7e7c..823ca4927 100644 --- a/tests/utils/data/test_system_to_ase.py +++ b/tests/utils/data/test_system_to_ase.py @@ -1,5 +1,5 @@ import torch 
-from metatensor.torch.atomistic import System +from metatomic.torch import System from metatrain.utils.data import system_to_ase diff --git a/tests/utils/data/test_writers.py b/tests/utils/data/test_writers.py index c2b8642de..d0aac5dda 100644 --- a/tests/utils/data/test_writers.py +++ b/tests/utils/data/test_writers.py @@ -4,7 +4,7 @@ import pytest import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ModelCapabilities, ModelOutput, System +from metatomic.torch import ModelCapabilities, ModelOutput, System from metatrain.utils.data.readers.ase import read from metatrain.utils.data.writers import write_predictions, write_xyz diff --git a/tests/utils/test_additive.py b/tests/utils/test_additive.py index 1a3c614e5..1fbe5a885 100644 --- a/tests/utils/test_additive.py +++ b/tests/utils/test_additive.py @@ -5,7 +5,7 @@ import pytest import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ModelOutput, System +from metatomic.torch import ModelOutput, System from omegaconf import OmegaConf from metatrain.utils.additive import ZBL, CompositionModel, remove_additive diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py index c489e8f99..ff619d4bc 100644 --- a/tests/utils/test_io.py +++ b/tests/utils/test_io.py @@ -2,7 +2,7 @@ from pathlib import Path import pytest -from metatensor.torch.atomistic import MetatensorAtomisticModel +from metatomic.torch import AtomisticModel from metatrain.soap_bpnn.model import SoapBpnn from metatrain.utils.io import check_file_extension, is_exported_file, load_model @@ -64,7 +64,7 @@ def test_load_model_checkpoint(path): ) def test_load_model_exported(path): model = load_model(path) - assert type(model) is MetatensorAtomisticModel + assert type(model) is AtomisticModel @pytest.mark.parametrize("suffix", [".yml", ".yaml"]) diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index 00c722e2b..1dc2a56d3 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -1,6 +1,6 @@ import torch -from metatensor.torch.atomistic import ( - MetatensorAtomisticModel, +from metatomic.torch import ( + AtomisticModel, ModelEvaluationOptions, ModelMetadata, ModelOutput, @@ -60,7 +60,7 @@ def test_llpr(tmpdir): llpr_model.compute_covariance(dataloader) llpr_model.compute_inverse_covariance() - exported_model = MetatensorAtomisticModel( + exported_model = AtomisticModel( llpr_model.eval(), ModelMetadata(), llpr_model.capabilities, @@ -106,7 +106,7 @@ def test_llpr(tmpdir): llpr_model.generate_ensemble({"energy": weights}, n_ensemble_members) assert "energy_ensemble" in llpr_model.capabilities.outputs - exported_model = MetatensorAtomisticModel( + exported_model = AtomisticModel( llpr_model.eval(), ModelMetadata(), llpr_model.capabilities, @@ -196,7 +196,7 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir): ) llpr_model.compute_inverse_covariance() - exported_model = MetatensorAtomisticModel( + exported_model = AtomisticModel( llpr_model.eval(), ModelMetadata(), llpr_model.capabilities, @@ -242,7 +242,7 @@ def test_llpr_covariance_as_pseudo_hessian(tmpdir): llpr_model.generate_ensemble({"energy": weights}, n_ensemble_members) assert "energy_ensemble" in llpr_model.capabilities.outputs - exported_model = MetatensorAtomisticModel( + exported_model = AtomisticModel( llpr_model.eval(), ModelMetadata(), llpr_model.capabilities, diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 318cfdd62..f883cefe8 100644 --- 
a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -8,7 +8,7 @@ import pytest import wandb -from metatensor.torch.atomistic import ModelCapabilities, ModelOutput +from metatomic.torch import ModelCapabilities, ModelOutput from metatrain import PACKAGE_ROOT from metatrain.utils.logging import ( diff --git a/tests/utils/test_long_range.py b/tests/utils/test_long_range.py index ab5f64b43..634dd0362 100644 --- a/tests/utils/test_long_range.py +++ b/tests/utils/test_long_range.py @@ -1,6 +1,6 @@ import pytest import torch -from metatensor.torch.atomistic import systems_to_torch +from metatomic.torch import systems_to_torch from metatrain.experimental.nanopet import NanoPET from metatrain.soap_bpnn import SoapBpnn diff --git a/tests/utils/test_metadata.py b/tests/utils/test_metadata.py index a8ca3d9c8..3fb4c1d8d 100644 --- a/tests/utils/test_metadata.py +++ b/tests/utils/test_metadata.py @@ -1,6 +1,6 @@ import json -from metatensor.torch.atomistic import ModelMetadata +from metatomic.torch import ModelMetadata from metatrain.utils.metadata import append_metadata_references diff --git a/tests/utils/test_neighbor_list.py b/tests/utils/test_neighbor_list.py index d346dca00..20020c552 100644 --- a/tests/utils/test_neighbor_list.py +++ b/tests/utils/test_neighbor_list.py @@ -1,6 +1,6 @@ from pathlib import Path -from metatensor.torch.atomistic import NeighborListOptions +from metatomic.torch import NeighborListOptions from metatrain.utils.data.readers.ase import read_systems from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists diff --git a/tests/utils/test_output_gradient.py b/tests/utils/test_output_gradient.py index d68ab8464..b573ef599 100644 --- a/tests/utils/test_output_gradient.py +++ b/tests/utils/test_output_gradient.py @@ -1,7 +1,7 @@ import metatensor.torch import pytest import torch -from metatensor.torch.atomistic import System +from metatomic.torch import System from metatrain.soap_bpnn import __model__ from metatrain.utils.data import DatasetInfo, read_systems @@ -106,7 +106,7 @@ def test_virial(is_training): for system in systems ] systems = [ - metatensor.torch.atomistic.System( + metatomic.torch.System( positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, @@ -137,7 +137,7 @@ def test_virial(is_training): for system in systems ] systems = [ - metatensor.torch.atomistic.System( + metatomic.torch.System( positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, @@ -190,7 +190,7 @@ def test_both(is_training): for system in systems ] systems = [ - metatensor.torch.atomistic.System( + metatomic.torch.System( positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, @@ -219,7 +219,7 @@ def test_both(is_training): for system in systems ] systems = [ - metatensor.torch.atomistic.System( + metatomic.torch.System( positions=system.positions @ strain, cell=system.cell @ strain, types=system.types, diff --git a/tests/utils/test_per_atom.py b/tests/utils/test_per_atom.py index 0f84180e7..46f20d295 100644 --- a/tests/utils/test_per_atom.py +++ b/tests/utils/test_per_atom.py @@ -1,6 +1,6 @@ import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System from metatrain.utils.per_atom import average_by_num_atoms, divide_by_num_atoms diff --git a/tests/utils/test_scaler.py b/tests/utils/test_scaler.py index c76b9a3df..c74a8f87c 100644 --- a/tests/utils/test_scaler.py +++ 
b/tests/utils/test_scaler.py @@ -2,7 +2,7 @@ import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System from omegaconf import OmegaConf from metatrain.utils.data import Dataset, DatasetInfo diff --git a/tests/utils/test_transfer.py b/tests/utils/test_transfer.py index 283842779..8b7d0c67e 100644 --- a/tests/utils/test_transfer.py +++ b/tests/utils/test_transfer.py @@ -1,7 +1,7 @@ import metatensor.torch import torch from metatensor.torch import Labels, TensorMap -from metatensor.torch.atomistic import System +from metatomic.torch import System from metatrain.utils.transfer import ( systems_and_targets_to_device,
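Taken together, every hunk in this patch replaces ``metatensor.torch.atomistic`` with ``metatomic.torch`` (and ``MetatensorAtomisticModel`` with ``AtomisticModel``). For downstream code that has to run on both sides of the rename during the transition, a hedged sketch of an import shim is shown below; it is not part of this patch, and whether such backwards compatibility is worth keeping is a project decision.

.. code-block:: python

    try:
        # new layout, as used throughout this patch
        from metatomic.torch import AtomisticModel, ModelOutput, System
    except ImportError:
        # old layout, pre-rename
        from metatensor.torch.atomistic import (  # type: ignore[no-redef]
            MetatensorAtomisticModel as AtomisticModel,
            ModelOutput,
            System,
        )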