Update modeling_motif.py
modeling_motif.py  CHANGED  (+12 -48)
@@ -328,23 +328,10 @@ class MotifMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

-        if config.wesar_weights:
-            self.gate_up_proj_alpha = nn.Parameter(torch.tensor(1) * config.gate_up_proj_alpha)
-            self.down_proj_alpha = nn.Parameter(torch.tensor(1) * config.down_proj_alpha)
-        else:
-            self.gate_up_proj_alpha = 1
-            self.down_proj_alpha = 1
-        if config.muP:
-            self.down_proj.__do_scale_tager__ = True
-            self.gate_proj.__do_scale_tager_mu_dim_model__ = True
-            self.up_proj.__do_scale_tager_mu_dim_model__ = True
-            self.down_proj.__do_scale_tager_mu_ffn__ = True
-
-
     def forward(self, hidden_state):
-        hidden_state = hidden_state
+        hidden_state = hidden_state
         #hidden_state = self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))*
-        return self.
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
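The fixed `forward` is the standard gated-MLP (SwiGLU-style) computation, `down_proj(act(gate_proj(x)) * up_proj(x))`, with the wesar alpha scalars dropped. A minimal standalone sketch of an equivalent module; the constructor arguments and the SiLU activation are illustrative stand-ins for the Motif config and `ACT2FN[config.hidden_act]`, not the repository's API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMLPSketch(nn.Module):
    """Minimal gated MLP matching the patched MotifMLP.forward computation."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = F.silu  # stand-in for ACT2FN[config.hidden_act]

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # down_proj(act(gate_proj(x)) * up_proj(x)), as in the new return line
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


x = torch.randn(2, 5, 64)                          # (batch, seq, hidden)
mlp = GatedMLPSketch(hidden_size=64, intermediate_size=256)
print(mlp(x).shape)                                # torch.Size([2, 5, 64])
```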
@@ -470,13 +457,6 @@ class MotifAttention(nn.Module):
                                                max_position_embeddings=self.max_position_embeddings,
                                                base=self.rope_theta)

-        for param in ["q_proj_alpha", "k_proj_alpha", "v_proj_alpha", "o_proj_alpha"]:
-            setattr(
-                self, param,
-                nn.Parameter(torch.tensor(getattr(config, param, 1.0), dtype=torch.float))
-                if config.wesar_weights else 1.0)
-
-
     def forward(
         self,
         hidden_states: torch.Tensor,
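The block removed here registered one scalar per attention projection (q/k/v/o) via `setattr`, learnable as an `nn.Parameter` when `config.wesar_weights` was set and the constant `1.0` otherwise. A sketch of that pattern in isolation, using a hypothetical minimal config object in place of the real Motif config:

```python
import torch
import torch.nn as nn


class TinyConfig:
    """Hypothetical stand-in for the relevant Motif config fields."""
    wesar_weights = True
    q_proj_alpha = 1.0
    k_proj_alpha = 1.0
    v_proj_alpha = 1.0
    o_proj_alpha = 1.0


class AlphaHolder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Same pattern as the removed MotifAttention code: one scalar per projection,
        # learnable only when wesar_weights is enabled.
        for param in ["q_proj_alpha", "k_proj_alpha", "v_proj_alpha", "o_proj_alpha"]:
            setattr(
                self, param,
                nn.Parameter(torch.tensor(getattr(config, param, 1.0), dtype=torch.float))
                if config.wesar_weights else 1.0)


holder = AlphaHolder(TinyConfig())
print(type(holder.q_proj_alpha))  # <class 'torch.nn.parameter.Parameter'>
```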
@@ -490,9 +470,9 @@ class MotifAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         ## bsz, seq, n_heads, head_dim

@@ -685,9 +665,9 @@ class MotifFlashAttention2(MotifAttention):
     ):
         bsz, q_len, _ = hidden_states.size()

-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
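After the projections, this path splits the last dimension into heads with `view` followed by `transpose`; the factor of 2 on the head counts comes from the model's projection layout as shown in the context lines above. A shape-only sketch of that reshape with toy sizes chosen purely for illustration:

```python
import torch

bsz, q_len = 2, 5
num_heads, head_dim = 4, 8                         # toy sizes, not the model's real config
hidden = torch.randn(bsz, q_len, 2 * num_heads * head_dim)

# view to (bsz, q_len, 2 * num_heads, head_dim), then move heads before the sequence axis
query_states = hidden.view(bsz, q_len, 2 * num_heads, head_dim).transpose(1, 2)
print(query_states.shape)                          # torch.Size([2, 8, 5, 8])
```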
@@ -798,7 +778,7 @@ class MotifFlashAttention2(MotifAttention):
                              f" {attn_output.size()}")

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)

         return attn_output, None, past_key_value

@@ -919,15 +899,6 @@ class MotifDecoderLayer(nn.Module):
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        if config.wesar_weights and config.use_norm_alpha:
-            self.input_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
-        else:
-            self.input_layernorm_alpha = 1
-
-        if config.wesar_weights and config.use_norm_alpha :
-            self.post_attention_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
-        else:
-            self.post_attention_layernorm_alpha = 1

     def forward(
         self,
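Both layer norms kept here are RMSNorm modules built from `config.hidden_size` and `config.rms_norm_eps`; with the alpha branches gone, they are used without extra scalar multipliers. For reference, a standard RMSNorm formulation as a sketch, not necessarily the repository's exact implementation (real implementations often also cast to float32 internally):

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Standard RMSNorm: scale x by 1/rms(x), then by a learned per-feature weight."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


norm = RMSNormSketch(64, eps=1e-6)
print(norm(torch.randn(2, 5, 64)).shape)  # torch.Size([2, 5, 64])
```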
@@ -965,7 +936,7 @@ class MotifDecoderLayer(nn.Module):

         residual = hidden_states

-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -982,7 +953,7 @@ class MotifDecoderLayer(nn.Module):

         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states

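Both decoder sub-blocks follow the usual pre-norm residual pattern: save the residual, normalize, apply the sub-layer, add the residual back, now without any alpha scaling. A compact sketch of that control flow; the LayerNorm and the identity sub-layer below are placeholders for the model's RMSNorm and its attention/MLP modules:

```python
import torch

def pre_norm_block(hidden_states, norm, sublayer):
    """residual -> norm -> sublayer -> residual add, as in MotifDecoderLayer.forward."""
    residual = hidden_states
    hidden_states = norm(hidden_states)
    hidden_states = sublayer(hidden_states)
    return residual + hidden_states


x = torch.randn(2, 5, 64)
norm = torch.nn.LayerNorm(64)               # stand-in for the model's RMSNorm
out = pre_norm_block(x, norm, lambda h: h)  # identity placeholder for attention/MLP
print(out.shape)                            # torch.Size([2, 5, 64])
```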
@@ -1199,14 +1170,7 @@ class MotifModel(MotifPreTrainedModel):
         self.post_init()

         self.scale_emb = 1
-
-        # Reparameterization
-        if config.wesar_weights :
-            logger.info(f'config.wesar_weights {config.wesar_weights}')
-            self.norm_alpha = nn.Parameter(torch.tensor(1).float())
-            self.scale_emb = 10
-        else:
-            self.norm_alpha = 1
+        self.norm_alpha = 1

     def get_input_embeddings(self):
         return self.embed_tokens
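With the wesar branch removed, MotifModel always uses `scale_emb = 1` and `norm_alpha = 1` (the removed code set `scale_emb = 10` and made `norm_alpha` learnable when `config.wesar_weights` was enabled). The diff does not show where these scalars are applied, so the multiplication points in the toy illustration below are assumptions about the pattern they neutralize:

```python
import torch
import torch.nn as nn

# The values MotifModel now always uses.
scale_emb = 1
norm_alpha = 1

embed_tokens = nn.Embedding(100, 64)
input_ids = torch.tensor([[1, 2, 3]])

hidden_states = embed_tokens(input_ids) * scale_emb  # no-op with scale_emb == 1
hidden_states = hidden_states * norm_alpha           # likewise a no-op
print(hidden_states.shape)                           # torch.Size([1, 3, 64])
```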