# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://huggingface.co/chandar-lab/AMPLIFY_120M/blob/main/amplify.py

import torch
from torch import nn
import transformer_engine.pytorch
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput


class AMPLIFYConfig(PretrainedConfig):
    model_type = "AMPLIFY"

    # All config parameters must have a default value.
    def __init__(
self,
hidden_size: int = 960,
num_hidden_layers: int = 32,
num_attention_heads: int = 15,
intermediate_size: int = 3840,
dropout_prob: float = 0,
embedding_init_range: float = 0.02,
decoder_init_range: float = 0.02,
rms_norm: bool = True,
norm_eps: float = 1e-05,
hidden_act: str = "SwiGLU",
layer_norm_after_embedding: bool = False,
layer_norm_before_last_layer: bool = True,
vocab_size: int = 27,
ffn_bias: bool = False,
att_bias: bool = False,
pad_token_id: int = 0,
max_length: int = 2048,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout_prob = dropout_prob
self.embedding_init_range = embedding_init_range
self.decoder_init_range = decoder_init_range
self.rms_norm = rms_norm
self.norm_eps = norm_eps
self.hidden_act = hidden_act
self.layer_norm_after_embedding = layer_norm_after_embedding
self.layer_norm_before_last_layer = layer_norm_before_last_layer
self.vocab_size = vocab_size
self.ffn_bias = ffn_bias
self.att_bias = att_bias
self.pad_token_id = pad_token_id
self.max_length = max_length
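
# Note (added, illustrative): the defaults above appear to match the AMPLIFY_350M
# checkpoint this file accompanies; individual fields can be overridden at construction
# time, e.g. AMPLIFYConfig(dropout_prob=0.1, max_length=4096) (example override values
# only, not an official configuration).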


class AMPLIFYPreTrainedModel(PreTrainedModel):
    config_class = AMPLIFYConfig

    def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data.uniform_(
-self.config.decoder_init_range, self.config.decoder_init_range
)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.uniform_(
-self.config.embedding_init_range, self.config.embedding_init_range
)


class AMPLIFY(AMPLIFYPreTrainedModel):
    """The main model class.

    Args:
        config (amplify.model.amplify.AMPLIFYConfig): model configuration.
    """

    def __init__(self, config: AMPLIFYConfig, **kwargs):
super().__init__(config)
self.config = config
self.encoder = nn.Embedding(
config.vocab_size,
config.hidden_size,
padding_idx=config.pad_token_id,
dtype=config.torch_dtype,
)
if config.layer_norm_after_embedding:
self.layer_norm_1 = (
transformer_engine.pytorch.RMSNorm(
config.hidden_size, config.norm_eps, params_dtype=config.torch_dtype
)
if config.rms_norm
else transformer_engine.pytorch.LayerNorm(
config.hidden_size, config.norm_eps, params_dtype=config.torch_dtype
)
)
if config.hidden_act.lower() == "swiglu":
# To keep the number of parameters and the amount of computation constant, we reduce the
# number of hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and
# make it a multiple of 8 to avoid RuntimeError due to misaligned operand
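            # With the default intermediate_size=3840: int(2 * 3840 / 3) = 2560, which is
            # already a multiple of 8, so the TE layers get ffn_hidden_size=2560.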
            multiple_of = 8
            intermediate_size = int(2 * config.intermediate_size / 3)
            intermediate_size = multiple_of * (
                (intermediate_size + multiple_of - 1) // multiple_of
            )
        else:
            # Non-SwiGLU activations keep the configured FFN width unchanged; without this
            # branch, intermediate_size would be undefined below.
            intermediate_size = config.intermediate_size
self.transformer_encoder = nn.ModuleList()
for layer_num in range(config.num_hidden_layers):
self.transformer_encoder.append(
transformer_engine.pytorch.TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=intermediate_size,
num_attention_heads=config.num_attention_heads,
layernorm_epsilon=config.norm_eps,
hidden_dropout=config.dropout_prob,
attention_dropout=config.dropout_prob,
apply_residual_connection_post_layernorm=False,
layer_type="encoder",
self_attn_mask_type="padding",
normalization="RMSNorm" if config.rms_norm else "LayerNorm",
fuse_qkv_params=True,
qkv_weight_interleaved=True,
output_layernorm=False,
bias=False,
activation=config.hidden_act.lower(),
attn_input_format="bshd",
layer_number=layer_num + 1,
name="encoder_block",
window_size=(-1, -1),
rotary_pos_interleaved=True,
seq_length=config.max_length,
params_dtype=config.torch_dtype,
)
)
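
        # Precompute the rotary position-embedding frequency table once for up to
        # max_length positions; forward() slices it to the actual sequence length.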
self.freqs_cis = RotaryPositionEmbedding(
config.hidden_size // config.num_attention_heads, interleaved=True
)(config.max_length)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
self,
input_ids,
attention_mask=None,
output_hidden_states=False,
output_attentions=False,
labels=None,
**kwargs,
):
# Initialize
hidden_states = []
# Attention mask
if attention_mask is not None and attention_mask.dtype is torch.int64:
# TE expects a boolean attention mask, where "True" indicates a token to be masked.
attention_mask = ~attention_mask.to(bool)
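            # e.g. a tokenizer mask of [1, 1, 1, 0, 0] (1 = real token) becomes
            # [False, False, False, True, True] (True = padding to be masked out).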
# RoPE
self.freqs_cis = self.freqs_cis.to(input_ids.device, non_blocking=True)
freqs_cis = self.freqs_cis[: input_ids.shape[1]]
# Embedding
x = self.encoder(input_ids)
if self.config.layer_norm_after_embedding:
x = self.layer_norm_1(x)
# Transformer encoder
for layer in self.transformer_encoder:
x = layer(x, attention_mask, rotary_pos_emb=freqs_cis)
if output_hidden_states:
hidden_states.append(x)
if output_attentions:
raise ValueError("output_attentions is not supported for TE")
return BaseModelOutput(
last_hidden_state=x,
hidden_states=hidden_states,
attentions=None,
)


class AMPLIFYForMaskedLM(AMPLIFYPreTrainedModel):
    def __init__(self, config: AMPLIFYConfig, **kwargs):
super().__init__(config)
self.amplify = AMPLIFY(config, **kwargs)
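        # Vocabulary projection head: when layer_norm_before_last_layer is set, the final
        # RMSNorm/LayerNorm is fused into the output projection via TE's LayerNormLinear.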
if config.layer_norm_before_last_layer:
self.decoder = transformer_engine.pytorch.LayerNormLinear(
config.hidden_size,
config.vocab_size,
config.norm_eps,
params_dtype=config.torch_dtype,
normalization="RMSNorm" if config.rms_norm else "LayerNorm",
)
else:
self.decoder = transformer_engine.pytorch.Linear(
config.hidden_size, config.vocab_size, params_dtype=config.torch_dtype
            )

    def forward(
self,
input_ids,
attention_mask=None,
output_hidden_states=False,
output_attentions=False,
labels=None,
**kwargs,
):
outputs = self.amplify(
input_ids,
attention_mask,
output_hidden_states,
output_attentions,
labels,
**kwargs,
)
# Classification head with layer norm
logits = self.decoder(outputs.last_hidden_state)
if labels is not None:
loss = nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)), labels.view(-1)
)
else:
loss = None
# Return logits or the output of the last hidden layer
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
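

# ---------------------------------------------------------------------------------------
# Illustrative usage sketch (added; not part of the original upload). It builds a randomly
# initialized model from the default config and runs one forward pass with a Hugging
# Face-style attention mask (1 = token, 0 = padding), which forward() converts to the
# boolean mask TE expects. Assumes a CUDA device and a Transformer Engine build compatible
# with the layer arguments above, since TE modules place their parameters on the GPU.
if __name__ == "__main__":
    if torch.cuda.is_available():
        config = AMPLIFYConfig()
        model = AMPLIFYForMaskedLM(config).cuda().eval()
        input_ids = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
        attention_mask = torch.ones_like(input_ids)  # no padding in this toy batch
        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask)
        print(out.logits.shape)  # expected: torch.Size([2, 16, 27])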