From 72e22e86c358b6e831b55fd4d37c63b305c71d9b Mon Sep 17 00:00:00 2001 From: chenxiao Date: Sat, 21 Jun 2025 11:53:58 +0800 Subject: [PATCH 1/2] Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) --- src/diffusers/models/transformers/transformer_cosmos.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 6c312b7a5a3f..2ffb4ae41b33 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -186,9 +186,9 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) # 4. Prepare for GQA - query_idx = torch.tensor(query.size(3), device=query.device) - key_idx = torch.tensor(key.size(3), device=key.device) - value_idx = torch.tensor(value.size(3), device=value.device) + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) key = key.repeat_interleave(query_idx // key_idx, dim=3) value = value.repeat_interleave(query_idx // value_idx, dim=3) From fff9e31eb7a389f70e078c3efdcfdeb3ad609966 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 10 Jul 2025 22:23:04 +0200 Subject: [PATCH 2/2] up --- .../models/transformers/transformer_cosmos.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 2ffb4ae41b33..59569a7a14d0 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -186,9 +186,15 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) # 4. Prepare for GQA - query_idx = query.size(3) - key_idx = key.size(3) - value_idx = value.size(3) + if torch.onnx.is_in_onnx_export(): + query_idx = torch.tensor(query.size(3), device=query.device) + key_idx = torch.tensor(key.size(3), device=key.device) + value_idx = torch.tensor(value.size(3), device=value.device) + + else: + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) key = key.repeat_interleave(query_idx // key_idx, dim=3) value = value.repeat_interleave(query_idx // value_idx, dim=3)