-
Notifications
You must be signed in to change notification settings - Fork 22
Efficient Block-Diagonal Matrix Operations for Wigner D Computation #250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR refactors Wigner D matrix computation to use batched block-diagonal operations instead of per-l
loops, significantly improving performance for large l_max
or batch sizes.
- Generate all small rotation matrices in one go and assemble into block-diagonal tensors
- Perform a single batched matrix multiplication sequence for the full Wigner D
- Group irreps by quantum number and rotate features in bulk rather than per-
l
Comments suppressed due to low confidence (3)
dptb/nn/tensor_product.py:17
- The docstring for
build_z_rot_multi
mentions anl_max
parameter which is not present; update the parameter list and descriptions to match the actual signature.
l_max: int
dptb/nn/tensor_product.py:55
- Add unit tests comparing
batch_wigner_D
against the originalwigner_D
for smalll_max
and random angles to ensure the batched implementation matches legacy outputs.
def batch_wigner_D(l_max, alpha, beta, gamma, _Jd):
dptb/nn/tensor_product.py:257
- This loop is rotating data in
out
, which is still zero; it should be reshaping and rotatingx_
(the input features) rather thanout
to preserve the computed values.
x_slice = out[:, slice_in].reshape(n, mul, -1)
# Assign values to the diagonal | ||
M_total[:, global_row, global_col_diag] = cos_val[idx_l, :, idx_row].transpose(0,1) | ||
# Assign values to non-overlapping anti-diagonals | ||
overlap_mask = (global_row == global_col_anti) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does it require a mask here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The 'overlap_mask' is used to avoid assigning sine values to matrix entries that are already filled with cosine values on the diagonal. It ensures that sine terms are only written to true off-diagonal (anti-diagonal) positions to prevent overwriting or conflict.
|
||
# Load static data | ||
sizes = idx_data["sizes"][:l_max+1] | ||
offsets = idx_data["offsets"][:l_max+1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is the offset defined?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The following is the code for generating the 'z_rot_indices_lmax12.pt' file:
import torch
l_max_all = 12
l_list = torch.arange(0, l_max_all + 1)
sizes = 2 * l_list + 1
offsets = torch.cat([torch.tensor([0]), torch.cumsum(sizes, 0)[:-1]])
L = len(l_list)
Mmax = sizes.max().item()
mask = torch.zeros(L, Mmax, dtype=torch.bool)
freq = torch.zeros(L, Mmax)
inds = torch.zeros(L, Mmax, dtype=torch.long)
reversed_inds = torch.zeros(L, Mmax, dtype=torch.long)
for i, l in enumerate(l_list):
sz = sizes[i]
mask[i, :sz] = True
freq[i, :sz] = torch.arange(l, -l-1, -1)
inds[i, :sz] = torch.arange(0, sz)
reversed_inds[i, :sz] = torch.arange(2 * l, -1, -1)
torch.save({
"mask": mask,
"freq": freq,
"inds": inds,
"reversed_inds": reversed_inds,
"l_list": l_list,
"sizes": sizes,
"offsets": offsets,
"Mmax": Mmax,
"l_max_all": l_max_all,
}, "/root/DeePTB/dptb/nn/z_rot_indices_lmax12.pt")
better handle m=0 and m_index from e3nn.o3 import xyz_to_angles, Irreps
import math
import torch
import torch.nn as nn
from torch.nn import Linear
import os
from collections import defaultdict
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=False)
def wigner_D(l, alpha, beta, gamma):
if not l < len(_Jd):
raise NotImplementedError(
f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"
)
alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma)
J = _Jd[l].to(dtype=alpha.dtype, device=alpha.device)
Xa = _z_rot_mat(alpha, l)
Xb = _z_rot_mat(beta, l)
Xc = _z_rot_mat(gamma, l)
return Xa @ J @ Xb @ J @ Xc
def _z_rot_mat(angle, l):
shape, device, dtype = angle.shape, angle.device, angle.dtype
M = angle.new_zeros((*shape, 2 * l + 1, 2 * l + 1))
inds = torch.arange(0, 2 * l + 1, 1, device=device)
reversed_inds = torch.arange(2 * l, -1, -1, device=device)
frequencies = torch.arange(l, -l - 1, -1, dtype=dtype, device=device)
M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
return M
class SO2_Linear(torch.nn.Module):
"""
SO(2) Conv: Perform an SO(2) convolution on features corresponding to +- m
Args:
m (int): Order of the spherical harmonic coefficients
sphere_channels (int): Number of spherical channels
m_output_channels (int): Number of output channels used during the SO(2) conv
lmax_list (list:int): List of degrees (l) for each resolution
mmax_list (list:int): List of orders (m) for each resolution
"""
def __init__(
self,
irreps_in,
irreps_out,
radial_emb: bool = False,
latent_dim: int = None,
radial_channels: list = None,
extra_m0_outsize: int = 0,
):
super(SO2_Linear, self).__init__()
self.irreps_in = irreps_in.simplify()
self.irreps_out = (Irreps(f"{extra_m0_outsize}x0e") + irreps_out).simplify()
self.radial_emb = radial_emb
self.latent_dim = latent_dim
self.m_linear = nn.ModuleList()
self.fc_m0 = Linear(self.irreps_in.num_irreps, self.irreps_out.num_irreps, bias=True)
for m in range(1, self.irreps_out.lmax + 1):
self.m_linear.append(SO2_m_Linear(m, self.irreps_in, self.irreps_out))
# generate m mask
self.m_in_mask = torch.zeros(self.irreps_in.lmax+1, self.irreps_in.dim, dtype=torch.bool)
self.m_out_mask = torch.zeros(self.irreps_in.lmax+1, self.irreps_out.dim, dtype=torch.bool)
if self.irreps_in.dim <= self.irreps_out.dim:
front = True
self.m_in_num = [0] * (self.irreps_in.lmax+1)
else:
front = False
self.m_in_num = [0] * (self.irreps_out.lmax+1)
offset = 0
for mul, (l, p) in self.irreps_in:
start_id = offset + torch.LongTensor(list(range(mul))) * (2 * l + 1)
for m in range(l+1):
self.m_in_mask[m, start_id+l+m] = True
self.m_in_mask[m, start_id+l-m] = True
if front:
self.m_in_num[m] += mul
offset += mul * (2 * l + 1)
# assert sum(self.m_in_num) == self.irreps_in.dim
offset = 0
for mul, (l, p) in self.irreps_out:
start_id = offset + torch.LongTensor(list(range(mul))) * (2 * l + 1)
for m in range(l+1):
if m <= self.irreps_in.lmax:
self.m_out_mask[m, start_id+l+m] = True
self.m_out_mask[m, start_id+l-m] = True
if not front:
self.m_in_num[m] += mul
offset += mul * (2 * l + 1)
self.m_in_index = [0] + list(torch.cumsum(torch.tensor(self.m_in_num), dim=0))
if radial_emb:
self.radial_emb = RadialFunction([latent_dim]+radial_channels+[self.m_in_index[-1]])
self.front = front
def forward(self, x, R, latents=None):
n, _ = x.shape
if self.radial_emb:
weights = self.radial_emb(latents)
# ======================================================================
# ====== Improved: group input irreps by angular quantum number ========
# ======================================================================
x_ = torch.zeros(n, self.irreps_in.dim, dtype=x.dtype, device=x.device)
R_transformed = R[:, [1, 2, 0]]
alpha, beta = xyz_to_angles(R_transformed)
gamma = torch.zeros_like(alpha)
# Group irreducible representations by angular quantum number
groups = defaultdict(list)
for (mul, (l, p)), slice_info in zip(self.irreps_in, self.irreps_in.slices()):
groups[l].append((mul, slice_info))
if l == 0:
x_[:, slice_info] = x[:, slice_info]
# Process each l group
for l, group in groups.items():
if l == 0 or not group:
continue
muls, slices = zip(*group)
x_parts = [x[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group]
x_combined = torch.cat(x_parts, dim=1) # [n, total_mul, 2l+1]
rot_mat = wigner_D(l, alpha, beta, gamma) # [n, 2l+1, 2l+1]
transformed = torch.bmm(x_combined, rot_mat) # [n, total_mul, 2l+1]
for part, slice_info in zip(transformed.split(muls, dim=1), slices):
x_[:, slice_info] = part.reshape(n, -1)
# ======================================================================
out = torch.zeros(n, self.irreps_out.dim, dtype=x.dtype, device=x.device)
for m in range(self.irreps_out.lmax+1):
radial_weight = weights[:, self.m_in_index[m]:self.m_in_index[m+1]].unsqueeze(1) if self.radial_emb else 1.
if m == 0:
if self.front and self.radial_emb:
out[:, self.m_out_mask[m]] += self.fc_m0(x_[:, self.m_in_mask[m]] * radial_weight.squeeze(1))
elif self.radial_emb:
out[:, self.m_out_mask[m]] += self.fc_m0(x_[:, self.m_in_mask[m]]) * radial_weight.squeeze(1)
else:
out[:, self.m_out_mask[m]] += self.fc_m0(x_[:, self.m_in_mask[m]])
else:
# 1. Prepare input data. .contiguous() is for in-place ops and performance.
# Shape becomes (n, 2, C_in_m / 2)
x_m_in = x_[:, self.m_in_mask[m]].reshape(n, -1, 2).transpose(1, 2).contiguous()
if self.front and self.radial_emb:
# Apply weight before linear layer. Use in-place mul_ to save memory.
x_m_in.mul_(radial_weight)
linear_output = self.m_linear[m - 1](x_m_in)
elif self.radial_emb:
# Apply weight after linear layer. Use in-place mul_ to save memory.
linear_output = self.m_linear[m - 1](x_m_in)
linear_output.mul_(radial_weight)
else:
# No radial embedding, just pass through the linear layer.
linear_output = self.m_linear[m - 1](x_m_in)
# 2. Reshape output and add to the result tensor.
# .contiguous() is necessary before .reshape() after a .transpose().
final_addition = linear_output.transpose(1, 2).contiguous().reshape(n, -1)
out[:, self.m_out_mask[m]] += final_addition
# ======================================================================
# ====== Improved: group input irreps by angular quantum number ========
# ======================================================================
out_groups = defaultdict(list)
for (mul, (l, p)), slice_info in zip(self.irreps_out, self.irreps_out.slices()):
out_groups[l].append((mul, slice_info))
for l, group in out_groups.items():
if l == 0 or not group:
continue
muls, slices = zip(*group)
out_parts = [out[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group]
out_combined = torch.cat(out_parts, dim=1)
rot_mat = wigner_D(l, alpha, beta, gamma)
transformed = torch.bmm(rot_mat, out_combined.transpose(1,2)).transpose(1,2) # [n, total_mul, 2l+1]
for part, slice_info in zip(transformed.split(muls, dim=1), slices):
out[:, slice_info] = part.reshape(n, -1)
out.contiguous()
return out
class SO2_m_Linear(torch.nn.Module):
"""
SO(2) Conv: Perform an SO(2) convolution on features corresponding to +- m
Args:
m (int): Order of the spherical harmonic coefficients
sphere_channels (int): Number of spherical channels
m_output_channels (int): Number of output channels used during the SO(2) conv
lmax_list (list:int): List of degrees (l) for each resolution
mmax_list (list:int): List of orders (m) for each resolution
"""
def __init__(
self,
m,
irreps_in,
irreps_out,
):
super(SO2_m_Linear, self).__init__()
self.m = m
self.irreps_in = irreps_in
self.irreps_out = irreps_out
assert self.irreps_in.lmax >= m
assert self.irreps_out.lmax >= m
self.num_in_channel = 0
for mul, (l, p) in self.irreps_in:
if l >= m:
self.num_in_channel += mul
self.num_out_channel = 0
for mul, (l, p) in self.irreps_out:
if l >= m:
self.num_out_channel += mul
self.fc = Linear(self.num_in_channel,
2 * self.num_out_channel,
bias=False)
self.fc.weight.data.mul_(1 / math.sqrt(2))
def forward(self, x_m):
# x_m ~ [N, 2, n_irreps_m]
x_m = self.fc(x_m)
x_r = x_m.narrow(2, 0, self.fc.out_features // 2)
x_i = x_m.narrow(2, self.fc.out_features // 2, self.fc.out_features // 2)
x_m_r = x_r.narrow(1, 0, 1) - x_i.narrow(1, 1, 1) #x_r[:, 0] - x_i[:, 1]
x_m_i = x_r.narrow(1, 1, 1) + x_i.narrow(1, 0, 1) #x_r[:, 1] + x_i[:, 0]
x_out = torch.cat((x_m_r, x_m_i), dim=1)
return x_out
class RadialFunction(nn.Module):
'''
Contruct a radial function (linear layers + layer normalization + SiLU) given a list of channels
'''
def __init__(self, channels_list):
super().__init__()
modules = []
input_channels = channels_list[0]
for i in range(len(channels_list)):
if i == 0:
continue
modules.append(nn.Linear(input_channels, channels_list[i], bias=True))
input_channels = channels_list[i]
if i == len(channels_list) - 1:
break
modules.append(nn.LayerNorm(channels_list[i]))
modules.append(torch.nn.SiLU())
self.net = nn.Sequential(*modules)
def forward(self, inputs):
return self.net(inputs) |
Optimize Wigner D Computation Using Block-Diagonal Batch Operations
Summary
This PR improves the efficiency of Wigner D matrix computation by eliminating per-
l
for-loops and replacing them with batch block-diagonal matrix operations. All rotation matrices and related terms are assembled and multiplied in a single step. This approach reduces redundant computation and significantly speeds up the process, especially for large batches or highl_max
.Key Changes
l
values are generated at once.self.irreps_in
andself.irreps_out
by quantum numberl
.torch.cat
andtorch.bmm
to batch multiple irreps of the samel
together for rotation.