8 changes: 5 additions & 3 deletions README.md
@@ -54,14 +54,16 @@ This codebase is built based on MosaicML's amazing [Composer package](https://gi
## Install Requirements
**Step 1**: To get started with this repository, follow these installation steps. Before proceeding, make sure you have [PyTorch](https://pytorch.org/get-started/previous-versions/) and [Flash Attention](https://github.com/Dao-AILab/flash-attention) installed. You can do this via pip with the following commands:
```
# pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
# pip install flash-attn==1.0.3.post
pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 --index-url https://download.pytorch.org/whl/cu121
pip install "flash-attn==2.3.2"
```
Please note that Flash Attention version 2 is not currently supported and may require manual modifications to the model file.
Update: the `flash-attn` version and the corresponding interface in the model file have been updated, so the code is now compatible with Flash Attention 2.
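
If you want to sanity-check the environment before continuing, the short snippet below (a sketch of mine, not part of the repository) confirms that the Flash Attention 2 interface imported by `composer_llama.py` is available; the expected versions in the comments assume the pinned wheels above.
```python
# Optional sanity check (not from the repo): make sure the Flash Attention 2
# interface that the model code imports is available in this environment.
import torch
import flash_attn
from flash_attn import flash_attn_func, flash_attn_varlen_func

print(torch.__version__, torch.version.cuda)  # expect 2.1.0 / 12.1 with the pins above
print(flash_attn.__version__)                 # expect 2.3.2
```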

**Step 2**: Then install the rest of the required packages:
```
cd llmshearing
pip install -r requirement.txt
```

36 changes: 26 additions & 10 deletions llmshearing/models/composer_llama.py
@@ -787,6 +787,7 @@ def flash_attn_fn(
    try:
        from flash_attn import bert_padding  # type: ignore
        from flash_attn import flash_attn_interface  # type: ignore
        from flash_attn import flash_attn_func, flash_attn_varlen_func  # for flash-attn-2
    except ImportError as e:
        raise e

@@ -815,18 +816,33 @@ def flash_attn_fn(

    dropout_p = dropout_p if training else 0.0

    # output_unpad = flash_attn_interface.flash_attn_unpadded_func(
    #     query_unpad,
    #     key_unpad,
    #     value_unpad,
    #     cu_seqlens_q,
    #     cu_seqlens_k,
    #     max_seqlen_q,
    #     max_seqlen_k,
    #     dropout_p,
    #     softmax_scale=softmax_scale,
    #     causal=is_causal,
    #     return_attn_probs=needs_weights)

    # the flash-attn-2 interface
    output_unpad = flash_attn_varlen_func(
        q=query_unpad,
        k=key_unpad,
        v=value_unpad,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=is_causal,
        return_attn_probs=needs_weights,
    )

    if head_z is not None:
        output_unpad = output_unpad * head_z  # 1 * h * 1
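
For reference, here is a self-contained sketch of the `flash_attn_varlen_func` call that this hunk switches to. It is an illustration under assumed shapes (two equal-length sequences, hypothetical head counts), not the repository's code, and it needs a CUDA GPU with fp16/bf16 tensors.
```python
# Standalone illustration of the flash-attn 2 varlen interface (assumed shapes,
# requires a CUDA GPU); q/k/v are packed as (total_tokens, n_heads, head_dim).
import torch
from flash_attn import flash_attn_varlen_func

batch, seqlen, n_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch * seqlen, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths, int32: [0, 128, 256] for two length-128 sequences.
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen,
    max_seqlen_k=seqlen,
    dropout_p=0.0,
    causal=True,
)
print(out.shape)  # (batch * seqlen, n_heads, head_dim)
```
In the function above, the packed tensors and `cu_seqlens` come from `bert_padding.unpad_input`, which is why the new call maps one-to-one onto the arguments of the old `flash_attn_unpadded_func`.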
12 changes: 11 additions & 1 deletion llmshearing/models/metrics.py
@@ -23,13 +23,23 @@ def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:

        target = target.view(-1)
        logits = logits.view(target.shape[0], -1)
        # losses = self.loss_fn(logits, target)

        total_items = (target != self.ignore_index).sum()
        if total_items.item() == 0:
            return  # 👈 skip update to avoid NaN

        losses = self.loss_fn(logits, target)
        self.total_items += total_items  # type: ignore (third-party)

        # accumulate loss over all batches
        self.sum_loss += losses.to(torch.float32)

    # override base class, to avoid zero division
    def compute(self) -> torch.Tensor:
        if self.total_items == 0:
            return torch.tensor(0.0, dtype=torch.float32, device=self.sum_loss.device)
        return self.sum_loss / self.total_items


class DomainCount(Metric):
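To make the motivation for this guard concrete, the following minimal reproduction (my own sketch, not the repository's metric class) shows how a batch in which every target equals `ignore_index` produces a 0/0 division and therefore a NaN metric, which the early `return` in `update()` and the overridden `compute()` now avoid.
```python
# Minimal reproduction of the 0/0 failure mode this patch guards against.
import torch
import torch.nn.functional as F

ignore_index = -100
logits = torch.randn(4, 32000)                     # (tokens, vocab)
target = torch.full((4,), ignore_index)            # every position is masked out

# With reduction="sum" the masked tokens contribute nothing, so the sum is 0.0 ...
sum_loss = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction="sum")
total_items = (target != ignore_index).sum()

print(sum_loss / total_items)                      # tensor(nan): 0 / 0

# ... so the patched metric skips update() for such batches and returns 0.0
# from compute() when nothing has been accumulated.
safe = sum_loss / total_items if total_items.item() > 0 else torch.tensor(0.0)
print(safe)                                        # tensor(0.)
```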
7 changes: 6 additions & 1 deletion llmshearing/utils/test_composer_hf_eq.py
@@ -25,7 +25,12 @@ def construct_example_cfg(model_size, path=None, add_l0_module=False):
        cfg = om.create({"name": "mosaic_llama_65b", "path": path, "init_device": "cpu", "d_model": 8192, "n_heads": 64, "n_layers": 80, "intermediate_size": 22016})

    # add default values
    # cfg = om.merge(cfg, om.create({"max_seq_len": 4096, "vocab_size": 32000, "init_std": 0.02, "attn_pdrop": 0.0, "resid_pdrop": 0.0, "emb_pdrop": 0.0, "attn_impl": "flash", "rms_norm_eps": 1e-5}))
    # order reversed. Use cfg to override default, instead of the opposite
    cfg = om.merge(
        om.create({"max_seq_len": 4096, "vocab_size": 32000, "init_std": 0.02, "attn_pdrop": 0.0, "resid_pdrop": 0.0, "emb_pdrop": 0.0, "attn_impl": "flash", "rms_norm_eps": 1e-5}),
        cfg
    )
    if add_l0_module:
        cfg["l0_module"] = {"start_sparsity": 0, "target_sparsity": 0.6, "pruning_modules": ["head", "head_layer", "mlp", "intermediate", "hidden"], "lagrangian_warmup_steps": "320ba"}
    return cfg
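As a quick illustration of why the argument order matters here (the values below are made up for the example): `OmegaConf.merge` gives precedence to later arguments, so the defaults must come first for `cfg` to win.
```python
# Toy example of OmegaConf merge precedence (hypothetical values).
from omegaconf import OmegaConf as om

defaults = om.create({"max_seq_len": 4096, "rms_norm_eps": 1e-5})
cfg = om.create({"max_seq_len": 2048})

print(om.merge(cfg, defaults).max_seq_len)  # 4096 -- defaults clobber the model config
print(om.merge(defaults, cfg).max_seq_len)  # 2048 -- model config overrides the defaults
```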