Update modeling_motif.py
Browse files- modeling_motif.py +6 -5
modeling_motif.py
CHANGED
|
@@ -615,11 +615,12 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 615 |
cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
|
| 616 |
if use_cache else position_embeddings)
|
| 617 |
|
| 618 |
-
query_states, key_states = apply_rotary_pos_emb(
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
|
|
|
| 623 |
|
| 624 |
if past_key_value is not None:
|
| 625 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
|
|
|
| 615 |
cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
|
| 616 |
if use_cache else position_embeddings)
|
| 617 |
|
| 618 |
+
query_states, key_states = apply_rotary_pos_emb(
|
| 619 |
+
query_states,
|
| 620 |
+
key_states,
|
| 621 |
+
cos,
|
| 622 |
+
sin
|
| 623 |
+
)
|
| 624 |
|
| 625 |
if past_key_value is not None:
|
| 626 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|