Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 127 additions & 15 deletions dptb/nn/tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,87 @@
import torch.nn as nn
from torch.nn import Linear
import os
import torch.nn.functional as F
from collections import defaultdict

# Static lookup data shipped next to this module, loaded once at import time.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects. That is acceptable only because both files are distributed with
# the package; never point these loads at user-supplied paths.
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=False)
_idx_data = torch.load(os.path.join(os.path.dirname(__file__), "z_rot_indices_lmax12.pt"), weights_only=False)


def build_z_rot_multi(angle_stack, mask, freq, reversed_inds, offsets, sizes):
    """Build block-diagonal z-rotation matrices for three stacked angle sets.

    For every valid ``(l, row)`` entry selected by ``mask`` the diagonal
    element of block ``l`` is set to ``cos(freq * angle)`` and the mirrored
    (anti-diagonal) element given by ``reversed_inds`` is set to
    ``sin(freq * angle)``.

    Args:
        angle_stack: 1-D tensor of length ``3 * N`` holding the three angle
            sets (alpha, beta, gamma) concatenated in that order.
        mask: bool tensor ``(L, Mmax)``; True where ``(l, row)`` is a real
            entry of block ``l`` rather than padding.
        freq: tensor ``(L, Mmax)`` of per-row frequencies.
        reversed_inds: long tensor ``(L, Mmax)`` giving, for each row of a
            block, the column of its anti-diagonal partner.
        offsets: long tensor ``(L,)`` with the first row/column of each block
            inside the full matrix.
        sizes: long tensor ``(L,)`` with the size of each block; their sum is
            the full matrix dimension ``D_total``.

    Returns:
        Tuple ``(Xa, Xb, Xc)``, each of shape ``(N, D_total, D_total)``.

    Raises:
        ValueError: if the length of ``angle_stack`` is not a multiple of 3.
    """
    n_all = angle_stack.shape[0]
    if n_all % 3 != 0:
        raise ValueError(
            f"angle_stack must stack three equally sized angle sets, got length {n_all}"
        )
    n = n_all // 3
    d_total = int(sizes.sum().item())

    # Step 1: evaluate sin/cos for every (block, sample, row) in one shot.
    phase = freq[:, None, :] * angle_stack[None, :, None]  # (L, 3N, Mmax)
    sin_val = torch.sin(phase)
    cos_val = torch.cos(phase)

    # Step 2: scatter the values into the block-diagonal matrix.
    m_total = angle_stack.new_zeros((n_all, d_total, d_total))
    idx_l, idx_row = torch.where(mask)  # (K,), (K,)
    block_start = offsets[idx_l]
    global_row = block_start + idx_row
    global_col_anti = block_start + reversed_inds[idx_l, idx_row]

    # Diagonal entries: cos(freq * angle). The cast keeps the indexed write
    # valid when freq and angle_stack have different floating dtypes.
    m_total[:, global_row, global_row] = (
        cos_val[idx_l, :, idx_row].transpose(0, 1).to(m_total.dtype)
    )

    # Anti-diagonal entries: sin(freq * angle). The centre row of each block
    # is its own anti-diagonal partner, so it already holds the cosine value
    # written above; skip those positions so the sine never overwrites it.
    off_centre = global_row != global_col_anti
    sel_l = idx_l[off_centre]
    sel_row = idx_row[off_centre]
    m_total[:, global_row[off_centre], global_col_anti[off_centre]] = (
        sin_val[sel_l, :, sel_row].transpose(0, 1).to(m_total.dtype)
    )

    # Step 3: split back into the alpha / beta / gamma components.
    return m_total[:n], m_total[n:2 * n], m_total[2 * n:]


def batch_wigner_D(l_max, alpha, beta, gamma, _Jd):
    """Compute Wigner D matrices for all l in ``0..l_max`` in a single batch.

    The result is the block-diagonal product ``Xa @ J @ Xb @ J @ Xc`` where
    ``Xa``, ``Xb``, ``Xc`` are the z-rotation matrices built from ``alpha``,
    ``beta`` and ``gamma`` by ``build_z_rot_multi`` and ``J`` is the
    block-diagonal matrix assembled from ``_Jd[0..l_max]``.

    Block ``l`` has size ``2l + 1`` and starts at row/column
    ``sum(2k + 1 for k < l) == l**2``, i.e. the offsets are the exclusive
    cumulative sum of the block sizes.

    Args:
        l_max: highest angular momentum to include.
        alpha, beta, gamma: 1-D angle tensors of equal length ``N``; the
            output lives on ``alpha``'s device and uses ``alpha``'s dtype.
        _Jd: indexable collection where ``_Jd[l]`` is the
            ``(2l+1, 2l+1)`` J matrix for degree ``l`` (the parameter shadows
            the module-level ``_Jd``).

    Returns:
        Tensor of shape ``[N, D, D]`` with ``D = sum(2l+1 for l in 0..l_max)``.

    Raises:
        ValueError: if ``l_max`` exceeds what ``_Jd`` or the precomputed
            index tables in the module-level ``_idx_data`` provide.
    """
    # Follow the inputs instead of hard-coding 'cuda' / the default dtype so
    # the function works on CPU, on any GPU index and with float64 angles.
    device, dtype = alpha.device, alpha.dtype
    n = alpha.shape[0]
    n_l = l_max + 1

    if n_l > len(_Jd) or n_l > len(_idx_data["mask"]):
        raise ValueError(
            f"l_max={l_max} exceeds the precomputed J matrices / index tables"
        )

    # Static index tables; only the entries actually used are sliced and
    # moved to the target device.
    # NOTE(review): table layout is assumed to match the generator script
    # shared in the PR discussion (one padded row per l) -- confirm.
    sizes = _idx_data["sizes"][:n_l].to(device)
    mask = _idx_data["mask"][:n_l].to(device)
    freq = _idx_data["freq"][:n_l].to(device=device, dtype=dtype)
    reversed_inds = _idx_data["reversed_inds"][:n_l].to(device)

    # Block structure: sizes 2l+1 and start offsets l**2 (exclusive cumsum).
    dims = [2 * l + 1 for l in range(n_l)]
    starts = [l * l for l in range(n_l)]
    d_total = sum(dims)
    offsets = torch.tensor(starts, dtype=torch.long, device=device)

    # Block-diagonal J matrix shared by every sample in the batch.
    j_block = torch.zeros(d_total, d_total, dtype=dtype, device=device)
    for l, (start, dim) in enumerate(zip(starts, dims)):
        stop = start + dim
        j_block[start:stop, start:stop] = _Jd[l].to(device=device, dtype=dtype)

    # expand() creates a broadcast view; no per-sample copy of J is made.
    j_full = j_block.unsqueeze(0).expand(n, -1, -1)
    angle_stack = torch.cat([alpha, beta, gamma], dim=0)
    xa, xb, xc = build_z_rot_multi(angle_stack, mask, freq, reversed_inds, offsets, sizes)

    return xa @ j_full @ xb @ j_full @ xc


def wigner_D(l, alpha, beta, gamma):
if not l < len(_Jd):
Expand Down Expand Up @@ -54,7 +131,7 @@ def __init__(
extra_m0_outsize: int = 0,
):
super(SO2_Linear, self).__init__()


self.irreps_in = irreps_in.simplify()
self.irreps_out = (Irreps(f"{extra_m0_outsize}x0e") + irreps_out).simplify()
Expand Down Expand Up @@ -105,19 +182,52 @@ def __init__(
self.radial_emb = RadialFunction([latent_dim]+radial_channels+[self.m_in_index[-1]])
self.front = front

self.l_max = max(l for (_, (l, _)), _ in zip(self.irreps_in, self.irreps_in.slices()) if l > 0)
self.dims = {l: 2*l + 1 for l in range(self.l_max + 1)}
self.offsets = {}
offset = 0
for l in range(self.l_max + 1):
self.offsets[l] = offset
offset += self.dims[l]


def forward(self, x, R, latents=None):
n, _ = x.shape

if self.radial_emb:
weights = self.radial_emb(latents)

x_ = torch.zeros(n, self.irreps_in.dim, dtype=x.dtype, device=x.device)
for (mul, (l,p)), slice in zip(self.irreps_in, self.irreps_in.slices()):
if l > 0:
angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N))
# The rotation matrix is an SO(3) rotation, therefore Irreps(l, 1) is used here.
rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0]))
x_[:, slice] = torch.einsum('nji,nmj->nmi', rot_mat_L, x[:, slice].reshape(n,mul,-1)).reshape(n,-1)
x_ = torch.zeros_like(x)
angle = xyz_to_angles(R[:, [1,2,0]])

# Compute Wigner D matrices for all l at once
wigner_D_all = batch_wigner_D(self.l_max, angle[0], angle[1], torch.zeros_like(angle[0]), _Jd)

# 1. group irreps by l
groups = defaultdict(list)
for (mul, (l, p)), slice_info in zip(self.irreps_in, self.irreps_in.slices()):
groups[l].append((mul, slice_info))

# 2. Batch process all mul for each l
for l, group in groups.items():
if l == 0 or not group:
continue

# Batch combination
muls, slices = zip(*group)
x_parts = [x[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group]
x_combined = torch.cat(x_parts, dim=1) # [n, total_mul, 2l+1]

start = self.offsets[l]
rot_mat = wigner_D_all[:, start:start+self.dims[l], start:start+self.dims[l]]

# Batch feature rotation (n, total_mul, 2l+1)
transformed = torch.bmm(x_combined, rot_mat) # (n, total_mul, 2l+1)

# Split back into each slice in the original order
for part, slice_info, mul in zip(transformed.split(muls, dim=1), slices, muls):
x_[:, slice_info] = part.reshape(n, -1)


out = torch.zeros(n, self.irreps_out.dim, dtype=x.dtype, device=x.device)
for m in range(self.irreps_out.lmax+1):
Expand All @@ -139,12 +249,14 @@ def forward(self, x, R, latents=None):

out.contiguous()

for (mul, (l,p)), slice in zip(self.irreps_out, self.irreps_out.slices()):

for (mul, (l, p)), slice_in in zip(self.irreps_out, self.irreps_out.slices()):
if l > 0:
angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N))
# The rotation matrix is an SO(3) rotation, therefore Irreps(l, 1) is used here.
rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0]))
out[:, slice] = torch.einsum('nij,nmj->nmi', rot_mat_L, out[:, slice].reshape(n,mul,-1)).reshape(n,-1)
start = self.offsets[l]
rot_mat = wigner_D_all[:, start:start+self.dims[l], start:start+self.dims[l]]
x_slice = out[:, slice_in].reshape(n, mul, -1)
rotated = torch.einsum('nij,nmj->nmi', rot_mat, x_slice)
out[:, slice_in] = rotated.reshape(n, -1)

return out

Expand Down Expand Up @@ -225,4 +337,4 @@ def __init__(self, channels_list):


def forward(self, inputs):
    """Pass ``inputs`` through ``self.net`` and return its output."""
    return self.net(inputs)
Binary file added dptb/nn/z_rot_indices_lmax12.pt
Binary file not shown.