
Refactor GroupNorm and log unmatched state_dict keys #989

Open · wants to merge 8 commits into base: main
13 changes: 12 additions & 1 deletion physicsnemo/launch/utils/checkpoint.py
@@ -386,7 +386,18 @@ def load_checkpoint(
model.load(file_name)
else:
file_to_load = _cache_if_needed(file_name)
model.load_state_dict(torch.load(file_to_load, map_location=device))
missing_keys, unexpected_keys = model.load_state_dict(
torch.load(file_to_load, map_location=device)
)
if missing_keys:
checkpoint_logging.warning(
f"Missing keys when loading {name}: {missing_keys}"
)
if unexpected_keys:
checkpoint_logging.warning(
f"Unexpected keys when loading {name}: {unexpected_keys}"
)

checkpoint_logging.success(
f"Loaded model state dictionary {file_name} to device {device}"
)
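
PyTorch's `load_state_dict` returns a named tuple of `missing_keys` and `unexpected_keys`, which the added lines unpack and log. A minimal standalone sketch of that return value (plain torch, not the PhysicsNeMo checkpoint API; `strict=False` is used here so the mismatch is only reported, not raised):

import torch

model = torch.nn.Linear(4, 4)
# State dict with "bias" missing and one stray entry.
state = {"weight": torch.zeros(4, 4), "extra": torch.zeros(1)}
missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
print(missing_keys)     # ['bias']
print(unexpected_keys)  # ['extra']
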
1 change: 1 addition & 0 deletions physicsnemo/models/diffusion/__init__.py
@@ -20,6 +20,7 @@
Conv2d,
FourierEmbedding,
GroupNorm,
get_group_norm,
Linear,
PositionalEmbedding,
UNetBlock,
4 changes: 2 additions & 2 deletions physicsnemo/models/diffusion/dhariwal_unet.py
@@ -23,10 +23,10 @@

from physicsnemo.models.diffusion import (
Conv2d,
GroupNorm,
Linear,
PositionalEmbedding,
UNetBlock,
get_group_norm,
)
from physicsnemo.models.diffusion.utils import _recursive_property
from physicsnemo.models.meta import ModelMetaData
@@ -265,7 +265,7 @@ def __init__(
attention=(res in attn_resolutions),
**block_kwargs,
)
self.out_norm = GroupNorm(num_channels=cout)
self.out_norm = get_group_norm(num_channels=cout)
self.out_conv = Conv2d(
in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
)
149 changes: 83 additions & 66 deletions physicsnemo/models/diffusion/layers.py
@@ -322,6 +322,72 @@ def forward(self, x):
return x


def get_group_norm(
num_channels: int,
num_groups: int = 32,
min_channels_per_group: int = 4,
eps: float = 1e-5,
use_apex_gn: bool = False,
act: str = None,
amp_mode: bool = False,
):
"""
Utility function to get the GroupNorm layer, either from apex or from torch.

Parameters
----------
num_channels : int
Number of channels in the input tensor.
num_groups : int, optional
Desired number of groups to divide the input channels, by default 32.
This might be adjusted based on `min_channels_per_group`.
min_channels_per_group : int, optional
Minimum number of channels allowed per group; `num_groups` is reduced if needed
to satisfy this constraint. By default 4.
eps : float, optional
A small number added to the variance to prevent division by zero, by default
1e-5.
use_apex_gn : bool, optional
A boolean flag indicating whether to use Apex GroupNorm for NHWC layout.
Must be set to False on CPU. Defaults to False.
act : str, optional
The activation function to fuse with GroupNorm, if any. Defaults to None.
amp_mode : bool, optional
A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False.

Notes
-----
If `num_channels` is not divisible by `num_groups`, the actual number of groups
might be adjusted to satisfy the `min_channels_per_group` condition.
"""

num_groups = min(
num_groups,
(num_channels + min_channels_per_group - 1) // min_channels_per_group,
)
if num_channels % num_groups != 0:
raise ValueError(
"num_channels must be divisible by num_groups or min_channels_per_group"
)

if use_apex_gn and not _is_apex_available:
raise ValueError("'apex' is not installed, set `use_apex_gn=False`")

act = act.lower() if act else act
if use_apex_gn:
return ApexGroupNorm(
num_groups=num_groups,
num_channels=num_channels,
eps=eps,
affine=True,
act=act,
)
else:
return GroupNorm(
num_groups=num_groups,
num_channels=num_channels,
eps=eps,
act=act,
amp_mode=amp_mode,
)
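
A hypothetical usage sketch of the new helper (the activation name "silu" is assumed here to be among those supported by `GroupNorm.get_activation_function`):

# Torch-based GroupNorm with a fused activation.
norm0 = get_group_norm(num_channels=128, act="silu")

# Plain GroupNorm without activation, as in the out_norm layers.
out_norm = get_group_norm(num_channels=64)

# Apex-backed GroupNorm; requires apex to be installed and CUDA inputs at forward time.
apex_norm = get_group_norm(num_channels=64, use_apex_gn=True, act="silu")

Note that the call sites updated below no longer pass `fused_act=True`; with this helper, supplying `act` is what enables the fused activation.
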
Collaborator commented on lines +374 to +388:

@juliusberner if I am not mistaken, the ApexGroupNorm and GroupNorm do not have the same functionality, because of the amp_mode argument.

Problem

For GroupNorm:

  1. when amp_mode=False, the parameters of the GroupNorm layer are cast to the dtype of the input tensor during the forward pass.
  2. when amp_mode=True, no casting is performed manually, since torch.autocast is assumed to be enabled.

However, since ApexGroupNorm does not have an amp_mode argument, it will not be able to do the casting as in case 1 above (case 2 is not a concern). From the user's perspective, that means the code below:

x_bf16  # bf16 input
model_fp32 = MyModel(..., amp_mode=False)  # fp32 model
y = model_fp32(x_bf16)

will work if model_fp32 uses GroupNorm but will fail if it uses ApexGroupNorm.
IMO, this is unexpected behavior and should be avoided.

Solution

An easy solution could be to subclass ApexGroupNorm with something like:

[...]
_ApexGroupNorm = getattr(apex_gn_module, "GroupNorm")

class ApexGroupNorm(_ApexGroupNorm):
    def __init__(..., amp_mode=False):
        super().__init__(...)
        self.amp_mode = amp_mode
    
    def forward(self, x):

        # Do the casting the same way as in `GroupNorm`
        weight, bias = self.weight, self.bias
        _validate_amp(self.amp_mode)
        if not self.amp_mode:
            if weight.dtype != x.dtype:
                weight = self.weight.to(x.dtype)
            if bias.dtype != x.dtype:
                bias = self.bias.to(x.dtype)

       # Call forward from parent class
       super().forward(x)        

Maybe @jialusui1102 can chime in on whether this ApexGN subclass is a viable solution or not.

Contributor (author) replied:

Good catch! I think this check would have also failed before this MR. One question is whether we want to support the example that you provided? Typically, one requires the user to call model_fp32.to(dtype) before feeding inputs with a different dtype.

Collaborator @CharlelieLrt replied, Jul 23, 2025:

> I think this check would have also failed before this MR

Yes, very possible. But let's try to fix it here.

> Typically, one requires the user to call model_fp32.to(dtype) before feeding inputs with a different dtype

I thought so as well, but then what is the purpose of the amp_mode argument? The only thing it does is this manual casting... @jialusui1102, I think you introduced the amp_mode argument; any thoughts on this?
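
To make the behavior under discussion concrete, below is a minimal, self-contained sketch (a hypothetical `CastingGroupNorm`, not the PhysicsNeMo implementation) of the `amp_mode=False` casting that the torch-based `GroupNorm` performs and that a bare `ApexGroupNorm` would skip:

import torch


class CastingGroupNorm(torch.nn.Module):
    """Minimal sketch of the amp_mode=False parameter casting discussed above."""

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, amp_mode: bool = False):
        super().__init__()
        self.num_groups, self.eps, self.amp_mode = num_groups, eps, amp_mode
        self.weight = torch.nn.Parameter(torch.ones(num_channels))
        self.bias = torch.nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        weight, bias = self.weight, self.bias
        if not self.amp_mode and weight.dtype != x.dtype:
            # Cast the affine parameters to the input dtype, as GroupNorm does.
            weight, bias = weight.to(x.dtype), bias.to(x.dtype)
        return torch.nn.functional.group_norm(
            x, self.num_groups, weight=weight, bias=bias, eps=self.eps
        )


x_bf16 = torch.randn(2, 32, 8, 8, dtype=torch.bfloat16)  # bf16 input, fp32 module
y = CastingGroupNorm(num_groups=8, num_channels=32)(x_bf16)  # no dtype-mismatch error
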



class GroupNorm(torch.nn.Module):
"""
A custom Group Normalization layer implementation.
@@ -333,22 +399,13 @@ class GroupNorm(torch.nn.Module):

Parameters
----------
num_groups : int
Desired number of groups to divide the input channels.
num_channels : int
Number of channels in the input tensor.
num_groups : int, optional
Desired number of groups to divide the input channels, by default 32.
This might be adjusted based on the `min_channels_per_group`.
min_channels_per_group : int, optional
Minimum channels required per group. This ensures that no group has fewer
channels than this number. By default 4.
eps : float, optional
A small number added to the variance to prevent division by zero, by default
1e-5.
use_apex_gn : bool, optional
A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout.
Need to set this as False on cpu. Defaults to False.
fused_act : bool, optional
Whether to fuse the activation function with GroupNorm. Defaults to False.
act : str, optional
The activation function to use when fusing activation with GroupNorm. Defaults to None.
amp_mode : bool, optional
@@ -361,56 +418,22 @@

def __init__(
self,
num_groups: int,
num_channels: int,
num_groups: int = 32,
min_channels_per_group: int = 4,
eps: float = 1e-5,
use_apex_gn: bool = False,
fused_act: bool = False,
act: str = None,
amp_mode: bool = False,
):
if fused_act and act is None:
raise ValueError("'act' must be specified when 'fused_act' is set to True.")

super().__init__()
self.num_groups = min(
num_groups,
(num_channels + min_channels_per_group - 1) // min_channels_per_group,
)
if num_channels % self.num_groups != 0:
raise ValueError(
"num_channels must be divisible by num_groups or min_channels_per_group"
)
self.num_groups = num_groups
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(num_channels))
self.bias = torch.nn.Parameter(torch.zeros(num_channels))
if use_apex_gn and not _is_apex_available:
raise ValueError("'apex' is not installed, set `use_apex_gn=False`")
self.use_apex_gn = use_apex_gn
self.fused_act = fused_act
self.act = act.lower() if act else act
self.act_fn = None
self.amp_mode = amp_mode
if self.use_apex_gn:
if self.act:
self.gn = ApexGroupNorm(
num_groups=self.num_groups,
num_channels=num_channels,
eps=self.eps,
affine=True,
act=self.act,
)

else:
self.gn = ApexGroupNorm(
num_groups=self.num_groups,
num_channels=num_channels,
eps=self.eps,
affine=True,
)
if self.fused_act:
if self.act is not None:
self.act_fn = self.get_activation_function()
self.amp_mode = amp_mode

def forward(self, x):
if (not x.is_cuda) and self.use_apex_gn:
@@ -420,14 +443,12 @@ def forward(self, x):
weight, bias = self.weight, self.bias
_validate_amp(self.amp_mode)
if not self.amp_mode:
if not self.use_apex_gn:
if weight.dtype != x.dtype:
weight = self.weight.to(x.dtype)
if bias.dtype != x.dtype:
bias = self.bias.to(x.dtype)
if self.use_apex_gn:
x = self.gn(x)
elif self.training:
if weight.dtype != x.dtype:
weight = self.weight.to(x.dtype)
if bias.dtype != x.dtype:
bias = self.bias.to(x.dtype)

if self.training:
# Use default torch implementation of GroupNorm for training
# This does not support channels last memory format
x = torch.nn.functional.group_norm(
Expand All @@ -437,8 +458,6 @@ def forward(self, x):
bias=bias,
eps=self.eps,
)
if self.fused_act:
x = self.act_fn(x)
else:
# Use custom GroupNorm implementation that supports channels last
# memory layout for inference
@@ -454,8 +473,8 @@ def forward(self, x):
bias = rearrange(bias, "c -> 1 c 1 1")
x = x * weight + bias

if self.fused_act:
x = self.act_fn(x)
if self.act_fn is not None:
x = self.act_fn(x)
return x

def get_activation_function(self):
@@ -731,11 +750,10 @@ def __init__(
self.adaptive_scale = adaptive_scale
self.profile_mode = profile_mode
self.amp_mode = amp_mode
self.norm0 = GroupNorm(
self.norm0 = get_group_norm(
num_channels=in_channels,
eps=eps,
use_apex_gn=use_apex_gn,
fused_act=True,
act=act,
amp_mode=amp_mode,
)
Expand All @@ -757,19 +775,18 @@ def __init__(
**init,
)
if self.adaptive_scale:
self.norm1 = GroupNorm(
self.norm1 = get_group_norm(
num_channels=out_channels,
eps=eps,
use_apex_gn=use_apex_gn,
amp_mode=amp_mode,
)
else:
self.norm1 = GroupNorm(
self.norm1 = get_group_norm(
num_channels=out_channels,
eps=eps,
use_apex_gn=use_apex_gn,
act=act,
fused_act=True,
amp_mode=amp_mode,
)
self.conv1 = Conv2d(
8 changes: 5 additions & 3 deletions physicsnemo/models/diffusion/song_unet.py
@@ -28,10 +28,10 @@
from physicsnemo.models.diffusion import (
Conv2d,
FourierEmbedding,
GroupNorm,
Linear,
PositionalEmbedding,
UNetBlock,
get_group_norm,
)
from physicsnemo.models.diffusion.utils import _recursive_property
from physicsnemo.models.meta import ModelMetaData
@@ -487,7 +487,7 @@ def __init__(
resample_filter=resample_filter,
amp_mode=amp_mode,
)
self.dec[f"{res}x{res}_aux_norm"] = GroupNorm(
self.dec[f"{res}x{res}_aux_norm"] = get_group_norm(
num_channels=cout,
eps=1e-6,
use_apex_gn=use_apex_gn,
@@ -861,7 +861,9 @@ def __init__(
if self.gridtype == "learnable":
self.pos_embd = self._get_positional_embedding()
else:
self.register_buffer("pos_embd", self._get_positional_embedding().float())
self.register_buffer(
"pos_embd", self._get_positional_embedding().float(), persistent=False
)
self.lead_time_mode = lead_time_mode
if self.lead_time_mode:
self.lead_time_channels = lead_time_channels
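
Since the buffer is now registered with `persistent=False`, `pos_embd` no longer appears in the model's state dict, which is presumably why this PR also logs missing/unexpected keys in `load_checkpoint`. A minimal sketch of the underlying torch behavior (hypothetical `TinyModel`, with `strict=False` so the stray key is reported rather than raised):

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffers are excluded from state_dict().
        self.register_buffer("pos_embd", torch.zeros(4), persistent=False)


model = TinyModel()
print(list(model.state_dict().keys()))  # [] -- "pos_embd" is not saved

# An older checkpoint that still contains "pos_embd" now shows up as an unexpected key.
result = model.load_state_dict({"pos_embd": torch.zeros(4)}, strict=False)
print(result.unexpected_keys)  # ['pos_embd']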