Update modeling_mmMamba_embedding.py
Browse files
modeling_mmMamba_embedding.py
CHANGED
|
@@ -410,7 +410,7 @@ class MHA_LM(nn.Module):
|
|
| 410 |
):
|
| 411 |
if self.rotary_emb_dim > 0:
|
| 412 |
q, kv = self.rotary_emb(
|
| 413 |
-
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
|
| 414 |
)
|
| 415 |
if inference_params is None:
|
| 416 |
k, v = kv.unbind(dim=-3)
|
|
@@ -538,7 +538,9 @@ class Mamba2_LM(nn.Module):
|
|
| 538 |
conv_state, ssm_state = None, None
|
| 539 |
if inference_params is not None:
|
| 540 |
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
| 541 |
-
|
|
|
|
|
|
|
| 542 |
if use_cache and inference_params.seqlen_offset==0:
|
| 543 |
vkq, new_conv_states = causal_conv1d_fn(
|
| 544 |
vkq.transpose(1, 2),
|
|
|
|
| 410 |
):
|
| 411 |
if self.rotary_emb_dim > 0:
|
| 412 |
q, kv = self.rotary_emb(
|
| 413 |
+
q, kv, seqlen_offset=seqlen_offset[:bsz,...], max_seqlen=rotary_max_seqlen
|
| 414 |
)
|
| 415 |
if inference_params is None:
|
| 416 |
k, v = kv.unbind(dim=-3)
|
|
|
|
| 538 |
conv_state, ssm_state = None, None
|
| 539 |
if inference_params is not None:
|
| 540 |
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
| 541 |
+
conv_state = conv_state[:batch, ...]
|
| 542 |
+
ssm_state = ssm_state[:batch, ...]
|
| 543 |
+
|
| 544 |
if use_cache and inference_params.seqlen_offset==0:
|
| 545 |
vkq, new_conv_states = causal_conv1d_fn(
|
| 546 |
vkq.transpose(1, 2),
|