Update modeling_motif.py
modeling_motif.py  CHANGED  (+2 -13)
@@ -839,7 +839,7 @@ MOTIF_ATTENTION_CLASSES = {
 
 class MotifDecoderLayer(nn.Module):
 
-    def __init__(self, config: MotifConfig,
+    def __init__(self, config: MotifConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         if config.use_moreh_attention:
@@ -853,10 +853,6 @@ class MotifDecoderLayer(nn.Module):
         else:
             self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
         self.mlp = MotifMLP(config)
-        ### moe
-        self.moe = None
-        if moe_layer:
-            self.moe = MotifMoE(config)
 
         RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -927,13 +923,7 @@ class MotifDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states) * self.post_attention_layernorm_alpha
 
-
-            hidden_states, identity = self.moe(hidden_states)
-            ## add output of shared expert and output of small moe experts.
-            ## hidden state must be zero tensor (for first forward)
-            hidden_states += self.mlp(identity)
-        else:
-            hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states)
 
         hidden_states = residual + hidden_states
 
@@ -1114,7 +1104,6 @@ class MotifModel(MotifPreTrainedModel):
 
         num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
 
-        logger.info(f'current_moe layer { moe_layer }')
         self.layers = nn.ModuleList([
             MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)
         ])
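In short, the commit drops the per-layer MoE branch: MotifDecoderLayer no longer takes a moe_layer flag or builds a MotifMoE, and the feed-forward step after attention always goes through the dense MotifMLP. For orientation, below is a minimal, self-contained sketch of that simplified attention-then-MLP residual pattern. It is not the actual Motif code: ToyDecoderLayer, the nn.MultiheadAttention / nn.LayerNorm placeholders, and the chosen sizes are illustrative stand-ins for the real MotifAttention classes, MotifMLP, and the Moreh/Motif RMSNorm modules, and the post_attention_layernorm_alpha scaling from the real forward pass is omitted.

import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    """Illustrative stand-in for the simplified MotifDecoderLayer:
    after this commit every layer uses the dense MLP path only,
    with no conditional MoE branch."""

    def __init__(self, hidden_size: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Placeholders for the real self-attention and MotifMLP modules.
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        # The real model uses RMSNorm variants; LayerNorm is used here only as a placeholder.
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention block with residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        hidden_states = residual + hidden_states

        # Feed-forward block: unconditionally the dense MLP
        # (the removed code previously routed through self.moe on MoE layers).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

# Usage example with arbitrary sizes.
layer = ToyDecoderLayer(hidden_size=64, layer_idx=0)
out = layer(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 64])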