Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions src/metatrain/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def forward(
The input systems are first converted into a batched representation containing:

- `element_indices_nodes` [n_atoms]: Atomic species of the central atoms
- `element_indices_neighbors` [n_edges]: Atomic species of neighboring atoms
- `element_indices_neighbors` [n_atoms, max_num_neighbors]: Atomic species of
neighboring atoms
- `edge_vectors` [n_atoms, max_num_neighbors, 3]: Cartesian edge vectors
between central atoms and their neighbors
- `padding_mask` [n_atoms, max_num_neighbors]: Mask indicating real vs padded
Expand Down Expand Up @@ -421,13 +422,25 @@ def forward(
cutoff_factors = cutoff_func(edge_distances, self.cutoff, self.cutoff_width)
cutoff_factors[~padding_mask] = 0.0

neighbors_index = (
neighbors_index * neighbors_index.shape[1] + reversed_neighbor_list
)

# At this point, we have `neighbors_index[~padding_mask] = 0`, which however
# creates too many of the same index which slows down backward enormously.
# (See see https://github.com/pytorch/pytorch/issues/41162)
# We therefore replace the padded indices with a sequence of unique indices.
neighbors_index[~padding_mask] = torch.arange(
int(torch.sum(~padding_mask)), device=device
)

# **Stage 1: Feature Computation via GNN Layers**
featurizer_inputs: Dict[str, torch.Tensor] = dict(
element_indices_nodes=element_indices_nodes,
element_indices_neighbors=element_indices_neighbors,
edge_vectors=edge_vectors,
neighbors_index=neighbors_index,
reversed_neighbor_list=reversed_neighbor_list,
# reversed_neighbor_list=reversed_neighbor_list,
padding_mask=padding_mask,
edge_distances=edge_distances,
cutoff_factors=cutoff_factors,
Expand Down Expand Up @@ -612,9 +625,14 @@ def _feedforward_featurization_impl(
# from atom `j` to atom `i` in on the GNN layer N+1 is a
# reversed message from atom `i` to atom `j` on the GNN layer N.
input_node_embeddings = output_node_embeddings
new_input_edge_embeddings = output_edge_embeddings[
inputs["neighbors_index"], inputs["reversed_neighbor_list"]
]
new_input_edge_embeddings = output_edge_embeddings.reshape(
output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1],
output_edge_embeddings.shape[2],
)[inputs["neighbors_index"]].reshape(
output_edge_embeddings.shape[0],
output_edge_embeddings.shape[1],
output_edge_embeddings.shape[2],
)
# input_messages = 0.5 * (output_edge_embeddings + new_input_messages)
concatenated = torch.cat(
[output_edge_embeddings, new_input_edge_embeddings], dim=-1
Expand Down Expand Up @@ -668,9 +686,14 @@ def _residual_featurization_impl(
# using a reversed neighbor list, so the new input message
# from atom `j` to atom `i` in on the GNN layer N+1 is a
# reversed message from atom `i` to atom `j` on the GNN layer N.
new_input_messages = output_edge_embeddings[
inputs["neighbors_index"], inputs["reversed_neighbor_list"]
]
new_input_messages = output_edge_embeddings.reshape(
output_edge_embeddings.shape[0] * output_edge_embeddings.shape[1],
output_edge_embeddings.shape[2],
)[inputs["neighbors_index"]].reshape(
output_edge_embeddings.shape[0],
output_edge_embeddings.shape[1],
output_edge_embeddings.shape[2],
)
input_edge_embeddings = 0.5 * (input_edge_embeddings + new_input_messages)
return node_features_list, edge_features_list

Expand Down