Skip to content

Commit 33aeaa0

Browse files
committed
Fix imports
1 parent 60dadf2 commit 33aeaa0

File tree

5 files changed

+20
-20
lines changed

5 files changed

+20
-20
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
![Mamba](assets/selection.png "Selective State Space")
44
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
55
> Albert Gu*, Tri Dao*\
6-
> Paper: https://arxiv.org/abs/2312.00752
7-
> **Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\
6+
> Paper: https://arxiv.org/abs/2312.00752\
7+
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\
88
> Tri Dao*, Albert Gu*\
99
> Paper: https://arxiv.org/abs/2405.21060
1010

mamba_ssm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.0.0"
1+
__version__ = "2.0.1"
22

33
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
44
from mamba_ssm.modules.mamba_simple import Mamba

mamba_ssm/distributed/tensor_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from einops import rearrange
1313

14-
from src.distributed.distributed_utils import (
14+
from mamba_ssm.distributed.distributed_utils import (
1515
all_gather_raw,
1616
all_reduce,
1717
all_reduce_raw,

mamba_ssm/ops/triton/ssd_chunk_scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from einops import rearrange, repeat
1616

17-
from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
17+
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
1818

1919
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
2020

mamba_ssm/ops/triton/ssd_combined.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,20 @@
2424
except ImportError:
2525
causal_conv1d_fn, causal_conv1d_cuda = None, None
2626

27-
from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
28-
from src.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
29-
from src.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
30-
from src.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
31-
from src.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
32-
from src.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
33-
from src.ops.triton.ssd_state_passing import state_passing, state_passing_ref
34-
from src.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
35-
from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
36-
from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
37-
from src.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
38-
from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
39-
from src.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
40-
from src.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd
27+
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
28+
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
29+
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
30+
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
31+
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
32+
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
33+
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
34+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
35+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
36+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
37+
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
38+
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
39+
from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
40+
from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd
4141

4242
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
4343

@@ -651,7 +651,7 @@ def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus
651651
Return:
652652
out: (batch, seqlen, nheads, headdim)
653653
"""
654-
from src.ops.selective_scan_interface import selective_scan_fn
654+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
655655

656656
batch, seqlen, nheads, headdim = x.shape
657657
_, _, ngroups, dstate = B.shape

0 commit comments

Comments
 (0)