micro_sam/models/peft_sam.py (84 changes: 80 additions & 4 deletions)
@@ -390,6 +390,21 @@ def forward(self, x):
return x


class SimpleLoRALayer(nn.Module):
"""Simple LoRA layer for individual linear projections in the decoder."""
def __init__(self, original_layer: nn.Module, rank: int):
super().__init__()
self.original_layer = original_layer
self.lora_A = nn.Linear(original_layer.in_features, rank, bias=False)
self.lora_B = nn.Linear(rank, original_layer.out_features, bias=False)

nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)

def forward(self, x):
return self.original_layer(x) + self.lora_B(self.lora_A(x))


class PEFT_Sam(nn.Module):
"""Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.

@@ -401,6 +416,7 @@ class PEFT_Sam(nn.Module):
For reference, the total number of blocks for 'vit_b' is 12, for 'vit_l' is 24 and for 'vit_h' is 32.
By default, applies the PEFT method to all attention layers.
quantize: Whether to quantize the model for lower precision training. By default, does not quantize the model.
decoder_lora: Whether to apply LoRA to the mask decoder. By default, does not apply LoRA to the decoder.
module_kwargs: The additional arguments for the respective PEFT modules.
"""

@@ -411,16 +427,16 @@ def __init__(
peft_module: nn.Module = LoRASurgery,
attention_layers_to_update: Optional[List[int]] = None,
quantize: bool = False,
decoder_lora: bool = False,
**module_kwargs
):
super().__init__()

if issubclass(peft_module, Union[LoRASurgery, FacTSurgery]) and (not rank or rank <= 0):
raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")

assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
"Invalid PEFT module"
)
if issubclass(peft_module, Union[LoRASurgery, FacTSurgery]) and (not rank or rank <= 0):
raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")
if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
else: # Applies PEFT to the image encoder by default
@@ -475,6 +491,9 @@ def __init__(
):
raise ValueError("The chosen layer(s) to apply PEFT method is not a valid transformer block id.")

# Filter out decoder-specific kwargs from module_kwargs
encoder_kwargs = {k: v for k, v in module_kwargs.items() if k != 'decoder_lora'}

for t_layer_i, blk in enumerate(model.image_encoder.blocks):

# If we only want specific layers with PEFT instead of all
@@ -484,10 +503,67 @@ def __init__(
if issubclass(self.peft_module, SelectiveSurgery):
self.peft_blocks.append(self.peft_module(block=blk))
else:
self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))
self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **encoder_kwargs))

self.peft_blocks = nn.ModuleList(self.peft_blocks)

# Apply LoRA to the mask decoder if requested
if decoder_lora:
if not rank or rank <= 0:
raise RuntimeError("The decoder LoRA method cannot run without a valid rank choice.")

for param in model.mask_decoder.parameters():
param.requires_grad = False

self._apply_lora_to_decoder(model.mask_decoder, rank, module_kwargs.get('update_matrices', ["q", "v"]))

self.sam = model

def _apply_lora_to_decoder(self, decoder, rank, update_matrices):
"""Apply LoRA to the mask decoder transformer layers."""
for layer in decoder.transformer.layers:
# Self-attention
if "q" in update_matrices:
layer.self_attn.q_proj = SimpleLoRALayer(layer.self_attn.q_proj, rank)
if "v" in update_matrices:
layer.self_attn.v_proj = SimpleLoRALayer(layer.self_attn.v_proj, rank)
if "k" in update_matrices:
layer.self_attn.k_proj = SimpleLoRALayer(layer.self_attn.k_proj, rank)

# Cross-attention token to image
if "q" in update_matrices:
layer.cross_attn_token_to_image.q_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.q_proj, rank)
if "v" in update_matrices:
layer.cross_attn_token_to_image.v_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.v_proj, rank)
if "k" in update_matrices:
layer.cross_attn_token_to_image.k_proj = SimpleLoRALayer(layer.cross_attn_token_to_image.k_proj, rank)

# Cross-attention image to token
if "q" in update_matrices:
layer.cross_attn_image_to_token.q_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.q_proj, rank)
if "v" in update_matrices:
layer.cross_attn_image_to_token.v_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.v_proj, rank)
if "k" in update_matrices:
layer.cross_attn_image_to_token.k_proj = SimpleLoRALayer(layer.cross_attn_image_to_token.k_proj, rank)

# MLP layers
if "mlp" in update_matrices:
layer.mlp.lin1 = SimpleLoRALayer(layer.mlp.lin1, rank)
layer.mlp.lin2 = SimpleLoRALayer(layer.mlp.lin2, rank)

# Final attention layer
if "q" in update_matrices:
decoder.transformer.final_attn_token_to_image.q_proj = SimpleLoRALayer(
decoder.transformer.final_attn_token_to_image.q_proj, rank
)
if "v" in update_matrices:
decoder.transformer.final_attn_token_to_image.v_proj = SimpleLoRALayer(
decoder.transformer.final_attn_token_to_image.v_proj, rank
)
if "k" in update_matrices:
decoder.transformer.final_attn_token_to_image.k_proj = SimpleLoRALayer(
decoder.transformer.final_attn_token_to_image.k_proj, rank
)

def forward(self, batched_input, multimask_output):
return self.sam(batched_input, multimask_output)
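
Taken together, the new flag can be exercised with a short usage sketch like the one below (not part of the PR). The checkpoint path and the sam_model_registry loading step are illustrative assumptions following the existing segment-anything API; PEFT_Sam, rank, decoder_lora, and the ["q", "v"] default for update_matrices come from the diff above.

# Sketch: enable decoder LoRA on a base SAM model (checkpoint path is a placeholder).
from segment_anything import sam_model_registry
from micro_sam.models.peft_sam import PEFT_Sam

sam = sam_model_registry["vit_b"](checkpoint="/path/to/sam_vit_b.pth")

# LoRA is injected into the image-encoder attention blocks as before; with
# decoder_lora=True the mask decoder is frozen and its q/v projections are
# wrapped in SimpleLoRALayer. Because lora_B is zero-initialized, the wrapped
# projections initially reproduce the original outputs exactly.
peft_sam = PEFT_Sam(model=sam, rank=4, decoder_lora=True)

# Only the injected low-rank matrices remain trainable inside the decoder.
decoder_trainable = [
    name for name, param in peft_sam.sam.mask_decoder.named_parameters() if param.requires_grad
]
print(decoder_trainable[:4])  # expect names containing 'lora_A' / 'lora_B'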
micro_sam/training/util.py (5 changes: 3 additions & 2 deletions)
@@ -135,9 +135,10 @@ def get_trainable_sam_model(

# we would want to "freeze" all the components in the model if passed a list of parts
for l_item in freeze:
# in case PEFT is switched on, we cannot freeze the image encoder
# in case PEFT is switched on, we cannot freeze the image encoder unless decoder_lora is enabled
if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")
if not peft_kwargs.get('decoder_lora', False):
raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")

if name.startswith(f"{l_item}"):
param.requires_grad = False
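
With the relaxed check, the image encoder can now be frozen while only the decoder LoRA adapters are finetuned. A hedged call sketch follows (not part of the PR); the model_type keyword and the exact return value are assumptions based on the current get_trainable_sam_model signature, while freeze, peft_kwargs, rank, and decoder_lora appear in the diff.

from micro_sam.training.util import get_trainable_sam_model

# Freezing the image encoder is only permitted here because decoder_lora is enabled;
# without it, this combination raises:
# ValueError: You cannot use PEFT & freeze the image encoder at the same time.
model = get_trainable_sam_model(
    model_type="vit_b",  # assumed keyword argument, not shown in the diff
    freeze=["image_encoder"],
    peft_kwargs={"rank": 4, "decoder_lora": True},
)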