-
Notifications
You must be signed in to change notification settings - Fork 22
Efficient Block-Diagonal Matrix Operations for Wigner D Computation #250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,87 @@ | |
import torch.nn as nn | ||
from torch.nn import Linear | ||
import os | ||
import torch.nn.functional as F | ||
from collections import defaultdict | ||
|
||
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=False) | ||
_idx_data = torch.load(os.path.join(os.path.dirname(__file__), "z_rot_indices_lmax12.pt"), weights_only=False) | ||
|
||
|
||
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=False) | ||
def build_z_rot_multi(angle_stack, mask, freq, reversed_inds, offsets, sizes): | ||
""" | ||
angle_stack: (3*N, ) # Input with alpha, beta, gamma stacked together | ||
l_max: int | ||
|
||
Returns: (Xa, Xb, Xc) # Each is of shape (N, D_total, D_total) | ||
""" | ||
N_all = angle_stack.shape[0] | ||
N = N_all // 3 | ||
|
||
D_total = sizes.sum().item() | ||
|
||
# Step 1: Vectorized computation of sine and cosine values | ||
angle_expand = angle_stack[None, :, None] # (1, 3N, 1) | ||
freq_expand = freq[:, None, :] # (L, 1, Mmax) | ||
sin_val = torch.sin(freq_expand * angle_expand) # (L, 3N, Mmax) | ||
cos_val = torch.cos(freq_expand * angle_expand) # (L, 3N, Mmax) | ||
|
||
# Step 2: Construct the block-diagonal matrix | ||
M_total = angle_stack.new_zeros((N_all, D_total, D_total)) | ||
idx_l, idx_row = torch.where(mask) # (K,), (K,) | ||
idx_col_diag = idx_row | ||
idx_col_anti = reversed_inds[idx_l, idx_row] | ||
global_row = offsets[idx_l] + idx_row # (K,) | ||
global_col_diag = offsets[idx_l] + idx_col_diag | ||
global_col_anti = offsets[idx_l] + idx_col_anti | ||
|
||
# Assign values to the diagonal | ||
M_total[:, global_row, global_col_diag] = cos_val[idx_l, :, idx_row].transpose(0,1) | ||
# Assign values to non-overlapping anti-diagonals | ||
overlap_mask = (global_row == global_col_anti) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does it require a mask here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The 'overlap_mask' is used to avoid assigning sine values to matrix entries that are already filled with cosine values on the diagonal. It ensures that sine terms are only written to true off-diagonal (anti-diagonal) positions to prevent overwriting or conflict. |
||
M_total[:, global_row[~overlap_mask], global_col_anti[~overlap_mask]] = sin_val[idx_l[~overlap_mask], :, idx_row[~overlap_mask]].transpose(0,1) | ||
|
||
# Step 3: Split into three components corresponding to alpha, beta, gamma | ||
Xa = M_total[:N] | ||
Xb = M_total[N:2*N] | ||
Xc = M_total[2*N:] | ||
|
||
return Xa, Xb, Xc | ||
|
||
|
||
def batch_wigner_D(l_max, alpha, beta, gamma, _Jd): | ||
""" | ||
Compute Wigner D matrices for all L (from 0 to l_max) in a single batch. | ||
Returns a tensor of shape [N, D, D], where D = sum(2l+1 for l in 0..l_max). | ||
""" | ||
device = alpha.device | ||
N = alpha.shape[0] | ||
idx_data = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in _idx_data.items()} | ||
|
||
# Load static data | ||
sizes = idx_data["sizes"][:l_max+1] | ||
offsets = idx_data["offsets"][:l_max+1] | ||
Kai-Qi marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How is the offset defined? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The following is the code for generating the 'z_rot_indices_lmax12.pt' file: import torch l_max_all = 12 mask = torch.zeros(L, Mmax, dtype=torch.bool) for i, l in enumerate(l_list): torch.save({ |
||
mask = idx_data["mask"][:l_max+1] | ||
freq = idx_data["freq"][:l_max+1] | ||
reversed_inds = idx_data["reversed_inds"][:l_max+1] | ||
|
||
# Precompute block structure information | ||
dims = [2*l + 1 for l in range(l_max + 1)] | ||
offsets = torch.cumsum(torch.tensor([0] + dims[:-1], device='cuda'), 0) | ||
Kai-Qi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
D_total = sum(dims) | ||
|
||
# Construct block-diagonal J matrix | ||
J_full_small = torch.zeros(D_total, D_total, device='cuda') | ||
Kai-Qi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for l in range(l_max + 1): | ||
start = offsets[l] | ||
J_full_small[start:start+2*l+1, start:start+2*l+1] = _Jd[l] | ||
|
||
J_full = J_full_small.unsqueeze(0).expand(N, -1, -1) | ||
angle_stack = torch.cat([alpha, beta, gamma], dim=0) | ||
Xa, Xb, Xc = build_z_rot_multi(angle_stack, mask, freq, reversed_inds, offsets, sizes) | ||
|
||
return Xa @ J_full @ Xb @ J_full @ Xc | ||
|
||
|
||
def wigner_D(l, alpha, beta, gamma): | ||
if not l < len(_Jd): | ||
|
@@ -54,7 +131,7 @@ def __init__( | |
extra_m0_outsize: int = 0, | ||
): | ||
super(SO2_Linear, self).__init__() | ||
|
||
|
||
self.irreps_in = irreps_in.simplify() | ||
self.irreps_out = (Irreps(f"{extra_m0_outsize}x0e") + irreps_out).simplify() | ||
|
@@ -105,19 +182,52 @@ def __init__( | |
self.radial_emb = RadialFunction([latent_dim]+radial_channels+[self.m_in_index[-1]]) | ||
self.front = front | ||
|
||
self.l_max = max(l for (_, (l, _)), _ in zip(self.irreps_in, self.irreps_in.slices()) if l > 0) | ||
self.dims = {l: 2*l + 1 for l in range(self.l_max + 1)} | ||
self.offsets = {} | ||
offset = 0 | ||
for l in range(self.l_max + 1): | ||
self.offsets[l] = offset | ||
offset += self.dims[l] | ||
|
||
|
||
def forward(self, x, R, latents=None): | ||
n, _ = x.shape | ||
|
||
if self.radial_emb: | ||
weights = self.radial_emb(latents) | ||
|
||
x_ = torch.zeros(n, self.irreps_in.dim, dtype=x.dtype, device=x.device) | ||
for (mul, (l,p)), slice in zip(self.irreps_in, self.irreps_in.slices()): | ||
if l > 0: | ||
angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N)) | ||
# The roataion matrix is SO3 rotation, therefore Irreps(l,1), is used here. | ||
rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0])) | ||
x_[:, slice] = torch.einsum('nji,nmj->nmi', rot_mat_L, x[:, slice].reshape(n,mul,-1)).reshape(n,-1) | ||
x_ = torch.zeros_like(x) | ||
angle = xyz_to_angles(R[:, [1,2,0]]) | ||
|
||
# Compute Wigner D matrices for all l at once | ||
wigner_D_all = batch_wigner_D(self.l_max, angle[0], angle[1], torch.zeros_like(angle[0]), _Jd) | ||
|
||
# 1. group irreps by l | ||
groups = defaultdict(list) | ||
for (mul, (l, p)), slice_info in zip(self.irreps_in, self.irreps_in.slices()): | ||
groups[l].append((mul, slice_info)) | ||
|
||
# 2. Batch process all mul for each l | ||
for l, group in groups.items(): | ||
if l == 0 or not group: | ||
continue | ||
|
||
# Batch combination | ||
muls, slices = zip(*group) | ||
x_parts = [x[:, sl].reshape(n, mul, 2*l+1) for mul, sl in group] | ||
x_combined = torch.cat(x_parts, dim=1) # [n, total_mul, 2l+1] | ||
|
||
start = self.offsets[l] | ||
rot_mat = wigner_D_all[:, start:start+self.dims[l], start:start+self.dims[l]] | ||
|
||
# Batch feature rotation (n, total_mul, 2l+1) | ||
transformed = torch.bmm(x_combined, rot_mat) # (n, total_mul, 2l+1) | ||
|
||
# Split back into each slice in the original order | ||
for part, slice_info, mul in zip(transformed.split(muls, dim=1), slices, muls): | ||
x_[:, slice_info] = part.reshape(n, -1) | ||
|
||
|
||
out = torch.zeros(n, self.irreps_out.dim, dtype=x.dtype, device=x.device) | ||
for m in range(self.irreps_out.lmax+1): | ||
|
@@ -139,12 +249,14 @@ def forward(self, x, R, latents=None): | |
|
||
out.contiguous() | ||
|
||
for (mul, (l,p)), slice in zip(self.irreps_out, self.irreps_out.slices()): | ||
|
||
for (mul, (l, p)), slice_in in zip(self.irreps_out, self.irreps_out.slices()): | ||
if l > 0: | ||
angle = xyz_to_angles(R[:,[1,2,0]]) # (tensor(N), tensor(N)) | ||
# The roataion matrix is SO3 rotation, therefore Irreps(l,1), is used here. | ||
rot_mat_L = wigner_D(l, angle[0], angle[1], torch.zeros_like(angle[0])) | ||
out[:, slice] = torch.einsum('nij,nmj->nmi', rot_mat_L, out[:, slice].reshape(n,mul,-1)).reshape(n,-1) | ||
start = self.offsets[l] | ||
rot_mat = wigner_D_all[:, start:start+self.dims[l], start:start+self.dims[l]] | ||
x_slice = out[:, slice_in].reshape(n, mul, -1) | ||
rotated = torch.einsum('nij,nmj->nmi', rot_mat, x_slice) | ||
out[:, slice_in] = rotated.reshape(n, -1) | ||
|
||
return out | ||
|
||
|
@@ -225,4 +337,4 @@ def __init__(self, channels_list): | |
|
||
|
||
def forward(self, inputs): | ||
return self.net(inputs) | ||
return self.net(inputs) |
Uh oh!
There was an error while loading. Please reload this page.