Update modeling_motif.py

modeling_motif.py  CHANGED  (+2 -47)
@@ -409,35 +409,6 @@ class MotifAttention(nn.Module):
         self.num_key_value_heads //= 2
         self.n_rep = self.num_heads // self.num_key_value_heads

-        ##mix attn
-
-        self.mix_attn = config.mix_attn
-
-        if self.mix_attn:
-
-            self.cq, self.ck = 6, 11
-            self.ch = 2
-
-            self.key_query_conv = nn.Conv2d(
-                in_channels=self.num_heads*2,
-                out_channels=self.num_heads*2,
-                kernel_size=(self.cq, self.ck),
-                padding="same",
-                groups=self.num_heads*2
-            )
-
-            self.head_conv = nn.Conv1d(
-                in_channels=self.num_heads,
-                out_channels=self.num_heads,
-                kernel_size=1,
-                padding=0,
-                groups=self.num_heads // self.ch
-            )
-
-            self.group_norm = nn.GroupNorm(num_groups=self.num_heads, num_channels=self.num_heads)
-
-
-
         # re-init projections
         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
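For reference, the block removed in this hunk constructed the optional "mix attn" modules: a depthwise Conv2d over each head's attention score map, a 1x1 grouped Conv1d that mixes heads in small groups, and a per-head GroupNorm. Below is a minimal standalone sketch of those constructors; `num_heads` and the tensor sizes are made-up illustration values, not the repository's config.

```python
import torch
import torch.nn as nn

num_heads = 8        # illustrative only; the model reads this from its config
cq, ck, ch = 6, 11, 2

# Depthwise conv over the 2*num_heads score maps (differential attention keeps two
# maps per head); "same" padding preserves the (q_len, kv_len) spatial size.
key_query_conv = nn.Conv2d(
    in_channels=num_heads * 2,
    out_channels=num_heads * 2,
    kernel_size=(cq, ck),
    padding="same",
    groups=num_heads * 2,
)

# 1x1 grouped conv over the head dimension: with groups = num_heads // ch,
# each group holds ch channels, so heads are mixed within groups of two.
head_conv = nn.Conv1d(
    in_channels=num_heads,
    out_channels=num_heads,
    kernel_size=1,
    padding=0,
    groups=num_heads // ch,
)

# One group per channel: each head's output is normalized independently.
group_norm = nn.GroupNorm(num_groups=num_heads, num_channels=num_heads)

# Example application of the score-map convolution (bsz, 2*heads, q_len, kv_len).
scores = torch.randn(1, num_heads * 2, 32, 32)
print(key_query_conv(scores).shape)  # torch.Size([1, 16, 32, 32])
```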
@@ -516,12 +487,6 @@ class MotifAttention(nn.Module):
         attention_mask = torch.triu(
             torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
             1 + offset)
-        ##attn weights conv2d, softmax and add attention_mask
-        if self.mix_attn:
-            ## condition mask==0, value : 0
-            attn_weights = attn_weights.masked_fill( attention_mask == 0, 0)
-            attn_weights = self.key_query_conv(attn_weights)
-            attn_weights = attn_weights[:, :, :kv_seq_len, :kv_seq_len]

         ###add attn
         attn_weights = attn_weights + attention_mask
@@ -536,11 +501,6 @@ class MotifAttention(nn.Module):
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
         attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
         attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
-        ##head_conv
-        if self.mix_attn:
-            attn_weights = attn_weights.view(bsz, self.num_heads, -1).contiguous()
-            attn_weights = self.head_conv(attn_weights)
-            attn_weights = attn_weights.view(bsz, self.num_heads, q_len, -1).contiguous()

         ##shape : bsz, #heads, seq, head_dim
         attn_output = torch.matmul(attn_weights, value_states)
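The context lines around this hunk are the differential-attention combination that the commit keeps: the 2*num_heads score maps are viewed as pairs and the second map of each pair is subtracted with weight `lambda_full`. A toy re-enactment of those two lines, with random scores and a fixed scalar `lambda_full` assumed purely for illustration (in the model it is derived from the learned lambda parameters plus `lambda_init`):

```python
import torch

bsz, num_heads, q_len, kv_len = 1, 4, 8, 8   # illustrative sizes
lambda_full = 0.5                            # stands in for lambda_1 - lambda_2 + lambda_init

attn_weights = torch.softmax(torch.randn(bsz, num_heads * 2, q_len, kv_len), dim=-1)

# Pair up the two score maps per head and take their weighted difference,
# mirroring the retained code path above.
attn_weights = attn_weights.view(bsz, num_heads, 2, q_len, -1)
attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
print(attn_weights.shape)  # torch.Size([1, 4, 8, 8])
```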
@@ -552,9 +512,7 @@ class MotifAttention(nn.Module):
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim * 2):
             raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                              f" {attn_output.size()}")
-
-        attn_output = self.group_norm(attn_output)
-
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

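This hunk drops the `group_norm` call on the attention output. For intuition, `nn.GroupNorm` with as many groups as channels normalizes each head's slice independently over the remaining dimensions; the sketch below shows that removed behaviour in isolation, with made-up shapes:

```python
import torch
import torch.nn as nn

num_heads, q_len, head_dim = 4, 8, 16   # illustrative sizes

# One group per channel, as in the removed
# nn.GroupNorm(num_groups=self.num_heads, num_channels=self.num_heads):
# each head is normalized over its (q_len, 2 * head_dim) slice.
group_norm = nn.GroupNorm(num_groups=num_heads, num_channels=num_heads)

attn_output = torch.randn(1, num_heads, q_len, head_dim * 2)  # (bsz, heads, seq, 2*head_dim)
print(group_norm(attn_output).shape)  # torch.Size([1, 4, 8, 32])
```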
@@ -889,10 +847,7 @@ class MotifDecoderLayer(nn.Module):
             logger.warning_once(
                 f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                 "unexpected results may be encountered.")
-
-            self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-        else:
-            self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
+        self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
         self.mlp = MotifMLP(config)

         RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
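The final hunk removes the eager fallback in the decoder layer: after this change the attention class is always looked up from `MOTIF_ATTENTION_CLASSES` by `config._attn_implementation`. A minimal sketch of that dispatch pattern, using placeholder classes rather than the real implementations:

```python
# Placeholder classes; the real table maps implementation names to the model's
# attention classes (eager, SDPA, flash-attention, ...).
class EagerAttention: ...
class SdpaAttention: ...

MOTIF_ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "sdpa": SdpaAttention,
}

def build_attention(attn_implementation: str):
    # The configured key is used directly; there is no longer a separate
    # fallback branch that forces "eager".
    return MOTIF_ATTENTION_CLASSES[attn_implementation]()

print(type(build_attention("sdpa")).__name__)  # SdpaAttention
```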