Update modeling_motif.py
modeling_motif.py  CHANGED  (+0 -257)
@@ -464,263 +464,6 @@ class MorehMoeFusedMLP(nn.Module):
         return output
 
 
-class MoEGate(nn.Module):
-
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.top_k = config.num_experts_per_tok
-        self.n_routed_experts = config.n_routed_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.scoring_func = config.scoring_func
-        self.seq_aux = config.seq_aux
-        self.topk_method = config.topk_method
-        self.n_group = config.n_group
-        self.topk_group = config.topk_group
-
-        # topk selection algorithm
-        self.norm_topk_prob = config.norm_topk_prob
-        self.gating_dim = config.hidden_size
-        self.weight = nn.Parameter(
-            torch.empty((self.n_routed_experts, self.gating_dim)))
-        if self.topk_method == "noaux_tc":
-            self.e_score_correction_bias = nn.Parameter(
-                torch.empty((self.n_routed_experts)))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        import torch.nn.init as init
-
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-
-    def forward(self, hidden_states):
-        bsz, seq_len, h = hidden_states.shape
-        ### compute gating score
-        hidden_states = hidden_states.view(-1, h)
-        logits = F.linear(hidden_states.type(torch.float32),
-                          self.weight.type(torch.float32), None)
-        if self.scoring_func == "sigmoid":
-            scores = logits.sigmoid()
-        else:
-            raise NotImplementedError(
-                f"insupportable scoring function for MoE gating: {self.scoring_func}"
-            )
-
-        ### select top-k experts
-        if self.topk_method == "greedy":
-            topk_weight, topk_idx = torch.topk(scores,
-                                               k=self.top_k,
-                                               dim=-1,
-                                               sorted=False)
-        elif self.topk_method == "group_limited_greedy":
-            group_scores = (scores.view(bsz * seq_len, self.n_group,
-                                        -1).max(dim=-1).values)  # [n, n_group]
-            group_idx = torch.topk(group_scores,
-                                   k=self.topk_group,
-                                   dim=-1,
-                                   sorted=False)[1]  # [n, top_k_group]
-            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-            score_mask = (group_mask.unsqueeze(-1).expand(
-                bsz * seq_len, self.n_group,
-                self.n_routed_experts // self.n_group).reshape(
-                    bsz * seq_len, -1))  # [n, e]
-            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-            topk_weight, topk_idx = torch.topk(tmp_scores,
-                                               k=self.top_k,
-                                               dim=-1,
-                                               sorted=False)
-        elif self.topk_method == "noaux_tc":
-            ### will be used. ###
-            scores_for_choice = scores.view(
-                bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
-            group_scores = (scores_for_choice.view(
-                bsz * seq_len, self.n_group,
-                -1).topk(2, dim=-1)[0].sum(dim=-1))  # [n, n_group]
-            group_idx = torch.topk(group_scores,
-                                   k=self.topk_group,
-                                   dim=-1,
-                                   sorted=False)[1]  # [n, top_k_group]
-            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-            score_mask = (group_mask.unsqueeze(-1).expand(
-                bsz * seq_len, self.n_group,
-                self.n_routed_experts // self.n_group).reshape(
-                    bsz * seq_len, -1))  # [n, e]
-            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
-                                                       0.0)  # [n, e]
-            _, topk_idx = torch.topk(tmp_scores,
-                                     k=self.top_k,
-                                     dim=-1,
-                                     sorted=False)
-            topk_weight = scores.gather(1, topk_idx)
-        else:
-            raise NotImplementedError(
-                f"insupportable TopK function for MoE gating: {self.topk_method}"
-            )
-
-        ### norm gate to sum 1
-        if self.top_k > 1 and self.norm_topk_prob:
-            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
-            topk_weight = topk_weight / denominator
-        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor
-
-        return topk_idx, topk_weight
-
-
-class MotifMoE(nn.Module):
-    """
-    A mixed expert module containing shared experts.
-    """
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.num_experts_per_tok = config.num_experts_per_tok
-        self.use_moreh_moe = config.use_moreh_moe
-        self.use_fused_mlp = config.use_fused_mlp
-
-        if hasattr(config, "ep_size") and config.ep_size > 1:
-            assert config.ep_size == dist.get_world_size()
-            assert not config.use_moreh_moe
-            self.ep_size = config.ep_size
-            self.experts_per_rank = config.n_routed_experts // config.ep_size
-            self.ep_rank = dist.get_rank()
-            self.experts = nn.ModuleList([
-                (DeepseekV3MLP(config,
-                               intermediate_size=config.moe_intermediate_size)
-                 if i >= self.ep_rank * self.experts_per_rank and i <
-                 (self.ep_rank + 1) * self.experts_per_rank else None)
-                for i in range(config.n_routed_experts)
-            ])
-        else:
-            self.ep_size = 1
-            self.experts_per_rank = config.n_routed_experts
-            self.ep_rank = 0
-            if self.use_moreh_moe:
-                if not self.use_fused_mlp:
-                    self.experts = MorehMoeMLP(
-                        ffn_dim=config.moe_intermediate_size,
-                        hidden_dim=config.hidden_size,
-                        hidden_act_moe=config.hidden_act_moe,
-                        num_experts=config.n_routed_experts,
-                        device=None)
-                else:
-                    ## group expert.
-                    self.experts = MorehMoeFusedMLP(
-                        ffn_dim=config.moe_intermediate_size,
-                        hidden_dim=config.hidden_size,
-                        hidden_act_moe=config.hidden_act_moe,
-                        num_experts=config.n_routed_experts,
-                        num_groups=config.n_group,
-                        device=None,
-                        continual_training=config.continual_training,
-                    )
-            else:
-                self.experts = nn.ModuleList([
-                    DeepseekV3MLP(
-                        config, intermediate_size=config.moe_intermediate_size)
-                    for i in range(config.n_routed_experts)
-                ])
-
-        self.gate = MoEGate(config)
-
-    def forward(self, hidden_states):
-        identity = hidden_states
-        orig_shape = hidden_states.shape
-        topk_idx, topk_weight = self.gate(hidden_states)
-        if self.use_moreh_moe:
-            y = self.experts(hidden_states, topk_idx.view(*orig_shape[:-1], -1),
-                             topk_weight.view(*orig_shape[:-1], -1))
-            y = y.type(hidden_states.dtype)
-        else:
-            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-            flat_topk_idx = topk_idx.view(-1)
-            if self.training:
-                hidden_states = hidden_states.repeat_interleave(
-                    self.num_experts_per_tok, dim=0)
-                y = torch.empty_like(hidden_states)
-                for i, expert in enumerate(self.experts):
-                    y[flat_topk_idx == i] = expert(
-                        hidden_states[flat_topk_idx == i])
-                y = (y.view(*topk_weight.shape, -1) *
-                     topk_weight.unsqueeze(-1)).sum(dim=1)
-                y = y.type(hidden_states.dtype)
-                y = y.view(*orig_shape)
-                # y = AddAuxiliaryLoss.apply(y, aux_loss)
-            else:
-                y = self.moe_infer(hidden_states, topk_idx,
-                                   topk_weight).view(*orig_shape)
-        return y, identity
-
-    @torch.no_grad()
-    def moe_infer(self, x, topk_ids, topk_weight):
-        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
-        cnts.scatter_(1, topk_ids, 1)
-        tokens_per_expert = cnts.sum(dim=0)
-        idxs = topk_ids.view(-1).argsort()
-        sorted_tokens = x[idxs // topk_ids.shape[1]]
-        sorted_tokens_shape = sorted_tokens.shape
-        if self.ep_size > 1:
-            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
-                                                        -1).sum(dim=1)
-            tokens_per_expert_group = tokens_per_expert.new_empty(
-                tokens_per_expert.shape[0])
-            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
-            output_splits = (tokens_per_expert_group.view(
-                self.ep_size, -1).sum(1).cpu().numpy().tolist())
-            gathered_tokens = sorted_tokens.new_empty(
-                tokens_per_expert_group.sum(dim=0).cpu().item(),
-                sorted_tokens.shape[1])
-            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
-            dist.all_to_all(
-                list(gathered_tokens.split(output_splits)),
-                list(sorted_tokens.split(input_split_sizes)),
-            )
-            tokens_per_expert_post_gather = tokens_per_expert_group.view(
-                self.ep_size, self.experts_per_rank).sum(dim=0)
-            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],),
-                                    dtype=np.int32)
-            s = 0
-            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
-                gatherd_idxs[s:s + k] = i % self.experts_per_rank
-                s += k
-            gatherd_idxs = gatherd_idxs.argsort()
-            sorted_tokens = gathered_tokens[gatherd_idxs]
-            tokens_per_expert = tokens_per_expert_post_gather
-        tokens_per_expert = tokens_per_expert.cpu().numpy()
-
-        outputs = []
-        start_idx = 0
-        for i, num_tokens in enumerate(tokens_per_expert):
-            end_idx = start_idx + num_tokens
-            if num_tokens == 0:
-                continue
-            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
-            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
-            expert_out = expert(tokens_for_this_expert)
-            outputs.append(expert_out)
-            start_idx = end_idx
-
-        outs = torch.cat(outputs,
-                         dim=0) if len(outputs) else sorted_tokens.new_empty(0)
-        if self.ep_size > 1:
-            new_x = torch.empty_like(outs)
-            new_x[gatherd_idxs] = outs
-            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
-            dist.all_to_all(
-                list(gathered_tokens.split(input_split_sizes)),
-                list(new_x.split(output_splits)),
-            )
-            outs = gathered_tokens
-
-        new_x = torch.empty_like(outs)
-        new_x[idxs] = outs
-        final_out = (new_x.view(
-            *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
-                topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
-        return final_out
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: