|
24 | 24 | except ImportError:
|
25 | 25 | causal_conv1d_fn, causal_conv1d_cuda = None, None
|
26 | 26 |
|
27 |
| -from src.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd |
28 |
| -from src.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd |
29 |
| -from src.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db |
30 |
| -from src.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable |
31 |
| -from src.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref |
32 |
| -from src.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd |
33 |
| -from src.ops.triton.ssd_state_passing import state_passing, state_passing_ref |
34 |
| -from src.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates |
35 |
| -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb |
36 |
| -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable |
37 |
| -from src.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref |
38 |
| -from src.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev |
39 |
| -from src.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd |
40 |
| -from src.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd |
| 27 | +from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd |
| 28 | +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd |
| 29 | +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db |
| 30 | +from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable |
| 31 | +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref |
| 32 | +from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd |
| 33 | +from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref |
| 34 | +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates |
| 35 | +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb |
| 36 | +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable |
| 37 | +from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref |
| 38 | +from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev |
| 39 | +from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd |
| 40 | +from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd |
41 | 41 |
|
42 | 42 | TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
43 | 43 |
|
@@ -651,7 +651,7 @@ def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus
|
651 | 651 | Return:
|
652 | 652 | out: (batch, seqlen, nheads, headdim)
|
653 | 653 | """
|
654 |
| - from src.ops.selective_scan_interface import selective_scan_fn |
| 654 | + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn |
655 | 655 |
|
656 | 656 | batch, seqlen, nheads, headdim = x.shape
|
657 | 657 | _, _, ngroups, dstate = B.shape
|
|
0 commit comments