diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index bf181ff5d..cbe9fde1a 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -390,6 +390,21 @@ def forward(self, x):
         return x
 
 
+class SimpleLoRALayer(nn.Module):
+    """Simple LoRA layer for individual linear projections in the decoder."""
+    def __init__(self, original_layer: nn.Module, rank: int):
+        super().__init__()
+        self.original_layer = original_layer
+        self.lora_A = nn.Linear(original_layer.in_features, rank, bias=False)
+        self.lora_B = nn.Linear(rank, original_layer.out_features, bias=False)
+
+        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_B.weight)
+
+    def forward(self, x):
+        return self.original_layer(x) + self.lora_B(self.lora_A(x))
+
+
 class PEFT_Sam(nn.Module):
     """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
 
@@ -401,6 +416,7 @@ class PEFT_Sam(nn.Module):
             For reference, the total number of blocks for 'vit_b' is 12, for 'vit_l' is 24 and for 'vit_h' is 32.
             By default, applies the PEFT method to all attention layers.
         quantize: Whether to quantize the model for lower precision training. By default, does not quantize the model.
+        decoder_lora: Whether to apply LoRA to the mask decoder. By default, does not apply LoRA to the decoder.
         module_kwargs: The additional arguments for the respective PEFT modules.
     """
 
@@ -411,16 +427,16 @@ def __init__(
         peft_module: nn.Module = LoRASurgery,
         attention_layers_to_update: Optional[List[int]] = None,
         quantize: bool = False,
+        decoder_lora: bool = False,
         **module_kwargs
     ):
         super().__init__()
 
-        if issubclass(peft_module, Union[LoRASurgery, FacTSurgery]) and (not rank or rank <= 0):
-            raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")
-
         assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
             "Invalid PEFT module"
         )
+        if issubclass(peft_module, Union[LoRASurgery, FacTSurgery]) and (not rank or rank <= 0):
+            raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")
 
         if attention_layers_to_update:
             self.peft_layers = attention_layers_to_update
         else:   # Applies PEFT to the image encoder by default
@@ -475,6 +491,9 @@ def __init__(
         ):
             raise ValueError("The chosen layer(s) to apply PEFT method is not a valid transformer block id.")
 
+        # Filter out decoder-specific kwargs so they are not forwarded to the encoder PEFT modules.
+        encoder_kwargs = {k: v for k, v in module_kwargs.items() if k != 'decoder_lora'}
+
         for t_layer_i, blk in enumerate(model.image_encoder.blocks):
 
             # If we only want specific layers with PEFT instead of all
@@ -484,10 +503,67 @@ def __init__(
             if issubclass(self.peft_module, SelectiveSurgery):
                 self.peft_blocks.append(self.peft_module(block=blk))
             else:
-                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))
+                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **encoder_kwargs))
 
         self.peft_blocks = nn.ModuleList(self.peft_blocks)
 
+        # Apply LoRA to the mask decoder if requested.
+        if decoder_lora:
+            if not rank or rank <= 0:
+                raise RuntimeError("The decoder LoRA method cannot run without a valid rank choice.")
+
+            for param in model.mask_decoder.parameters():
+                param.requires_grad = False
+
+            self._apply_lora_to_decoder(model.mask_decoder, rank, module_kwargs.get('update_matrices', ["q", "v"]))
+
         self.sam = model
 
+    def _apply_lora_to_decoder(self, decoder, rank, update_matrices):
+        """Apply LoRA to the mask decoder transformer layers."""
+        for layer in decoder.transformer.layers:
+            # Self-attention
+            if "q" in update_matrices:
+                layer.self_attn.q_proj = SimpleLoRALayer(layer.self_attn.q_proj, rank)
+            if "v" in update_matrices:
+                layer.self_attn.v_proj = SimpleLoRALayer(layer.self_attn.v_proj, rank)
+            if "k" in update_matrices:
+                layer.self_attn.k_proj = SimpleLoRALayer(layer.self_attn.k_proj, rank)
+
+            # Cross-attention: token to image
+            if "q" in update_matrices:
+                layer.cross_attn_token_to_image.q_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.q_proj, rank)
+            if "v" in update_matrices:
+                layer.cross_attn_token_to_image.v_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.v_proj, rank)
+            if "k" in update_matrices:
+                layer.cross_attn_token_to_image.k_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.k_proj, rank)
+
+            # Cross-attention: image to token
+            if "q" in update_matrices:
+                layer.cross_attn_image_to_token.q_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.q_proj, rank)
+            if "v" in update_matrices:
+                layer.cross_attn_image_to_token.v_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.v_proj, rank)
+            if "k" in update_matrices:
+                layer.cross_attn_image_to_token.k_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.k_proj, rank)
+
+            # MLP layers
+            if "mlp" in update_matrices:
+                layer.mlp.lin1 = SimpleLoRALayer(layer.mlp.lin1, rank)
+                layer.mlp.lin2 = SimpleLoRALayer(layer.mlp.lin2, rank)
+
+        # Final token-to-image attention layer
+        if "q" in update_matrices:
+            decoder.transformer.final_attn_token_to_image.q_proj = SimpleLoRALayer(
+                decoder.transformer.final_attn_token_to_image.q_proj, rank
+            )
+        if "v" in update_matrices:
+            decoder.transformer.final_attn_token_to_image.v_proj = SimpleLoRALayer(
+                decoder.transformer.final_attn_token_to_image.v_proj, rank
+            )
+        if "k" in update_matrices:
+            decoder.transformer.final_attn_token_to_image.k_proj = SimpleLoRALayer(
+                decoder.transformer.final_attn_token_to_image.k_proj, rank
+            )
+
     def forward(self, batched_input, multimask_output):
         return self.sam(batched_input, multimask_output)
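Note on the `SimpleLoRALayer` design: because `lora_B` is zero-initialized, every wrapped projection computes exactly the same output as the frozen original at the start of training; the low-rank path `lora_B(lora_A(x))` only contributes once gradients update it. Below is a minimal standalone sketch of that property, assuming only `torch`; the class is copied from the patch above, and the feature size of 256 is an arbitrary illustration.

```python
import math

import torch
import torch.nn as nn


class SimpleLoRALayer(nn.Module):
    """Copy of the layer added in the patch above."""
    def __init__(self, original_layer: nn.Module, rank: int):
        super().__init__()
        self.original_layer = original_layer
        self.lora_A = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, original_layer.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        return self.original_layer(x) + self.lora_B(self.lora_A(x))


proj = nn.Linear(256, 256)
for p in proj.parameters():  # mirror the patch: the base projection is frozen
    p.requires_grad = False

wrapped = SimpleLoRALayer(proj, rank=4)
x = torch.randn(2, 256)

# lora_B starts at zero, so the wrapper is initially a no-op on the output ...
assert torch.allclose(proj(x), wrapped(x))
# ... and only the two low-rank matrices are trainable.
print([n for n, p in wrapped.named_parameters() if p.requires_grad])
# ['lora_A.weight', 'lora_B.weight']
```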
layers.""" + for layer in decoder.transformer.layers: + # Self-attention + if "q" in update_matrices: + layer.self_attn.q_proj = SimpleLoRALayer(layer.self_attn.q_proj, rank) + if "v" in update_matrices: + layer.self_attn.v_proj = SimpleLoRALayer(layer.self_attn.v_proj, rank) + if "k" in update_matrices: + layer.self_attn.k_proj = SimpleLoRALayer(layer.self_attn.k_proj, rank) + + # Cross-attention token to image + if "q" in update_matrices: + layer.cross_attn_token_to_image.q_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.q_proj, rank) + if "v" in update_matrices: + layer.cross_attn_token_to_image.v_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.v_proj, rank) + if "k" in update_matrices: + layer.cross_attn_token_to_image.k_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.k_proj, rank) + + # Cross-attention image to token + if "q" in update_matrices: + layer.cross_attn_image_to_token.q_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.q_proj, rank) + if "v" in update_matrices: + layer.cross_attn_image_to_token.v_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.v_proj, rank) + if "k" in update_matrices: + layer.cross_attn_image_to_token.k_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.k_proj, rank) + + # MLP layers + if "mlp" in update_matrices: + layer.mlp.lin1 = SimpleLoRALayer(layer.mlp.lin1, rank) + layer.mlp.lin2 = SimpleLoRALayer(layer.mlp.lin2, rank) + + # Final attention layer + if "q" in update_matrices: + decoder.transformer.final_attn_token_to_image.q_proj = SimpleLoRALayer( + decoder.transformer.final_attn_token_to_image.q_proj, rank + ) + if "v" in update_matrices: + decoder.transformer.final_attn_token_to_image.v_proj = SimpleLoRALayer( + decoder.transformer.final_attn_token_to_image.v_proj, rank + ) + if "k" in update_matrices: + decoder.transformer.final_attn_token_to_image.k_proj = SimpleLoRALayer( + decoder.transformer.final_attn_token_to_image.k_proj, rank + ) + def forward(self, batched_input, multimask_output): return self.sam(batched_input, multimask_output) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 01f503eb1..cfadb5387 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -135,9 +135,10 @@ def get_trainable_sam_model( # we would want to "freeze" all the components in the model if passed a list of parts for l_item in freeze: - # in case PEFT is switched on, we cannot freeze the image encoder + # in case PEFT is switched on, we cannot freeze the image encoder unless decoder_lora is enabled if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"): - raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.") + if not peft_kwargs.get('decoder_lora', False): + raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.") if name.startswith(f"{l_item}"): param.requires_grad = False