Update modeling_mmMamba.py
modeling_mmMamba.py  CHANGED  (+4 -2)
@@ -421,7 +421,7 @@ class MHA_LM(nn.Module):
             ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(
-                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                        q, kv, seqlen_offset=seqlen_offset[:bsz,...], max_seqlen=rotary_max_seqlen
                     )
                 if inference_params is None:
                     k, v = kv.unbind(dim=-3)
@@ -550,7 +550,9 @@ class Mamba2_LM(nn.Module):
         conv_state, ssm_state = None, None
         if inference_params is not None:
             conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
-
+            conv_state = conv_state[:batch, ...]
+            ssm_state = ssm_state[:batch, ...]
+
         if use_cache and inference_params.seqlen_offset==0:
             vkq, new_conv_states = causal_conv1d_fn(
                 vkq.transpose(1, 2),
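
Both hunks apply the same fix: inference caches (the per-sequence rotary seqlen_offset tensor in MHA_LM, and the conv_state / ssm_state buffers in Mamba2_LM) are allocated once for a maximum batch size, so before use they are sliced down to the batch actually being decoded. Below is a minimal sketch of that pattern, not code from modeling_mmMamba.py: the shapes, the names max_batch_size, d_conv, d_inner, d_state, and the placeholder state update inside step() are all illustrative assumptions.

# Sketch of slicing pre-allocated inference caches to the live batch size.
import torch

max_batch_size = 8          # cache allocated for the largest expected batch
d_conv, d_inner = 4, 64     # hypothetical conv width / channel count
d_state = 16                # hypothetical SSM state size

# Pre-allocated per-layer caches, as an inference-params object might hold them.
conv_state = torch.zeros(max_batch_size, d_inner, d_conv)
ssm_state = torch.zeros(max_batch_size, d_inner, d_state)
seqlen_offset = torch.zeros(max_batch_size, dtype=torch.long)

def step(x: torch.Tensor) -> torch.Tensor:
    """One decoding step for a (possibly smaller) live batch x of shape (B, d_inner)."""
    batch = x.shape[0]

    # Slice every cached tensor to the live batch, mirroring the diff's
    # conv_state[:batch, ...], ssm_state[:batch, ...] and seqlen_offset[:bsz, ...].
    conv = conv_state[:batch, ...]
    ssm = ssm_state[:batch, ...]
    offset = seqlen_offset[:batch, ...]

    # Stand-in update: shift the conv window and write the new token in place.
    # Because basic slicing returns a view, this also updates the full-size cache.
    conv[:, :, :-1] = conv[:, :, 1:].clone()
    conv[:, :, -1] = x
    ssm.add_(x.unsqueeze(-1) * 0.01)    # placeholder state update
    seqlen_offset[:batch] = offset + 1  # advance positions for the live rows only

    return conv.sum(dim=-1) + ssm.sum(dim=-1)  # placeholder output

out = step(torch.randn(3, d_inner))  # decode a batch of 3 against a cache sized for 8
print(out.shape, int(seqlen_offset[0]), int(seqlen_offset[5]))  # rows 3..7 stay untouched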