eunhwanpark-motiftech committed on
Commit bd7180c · verified · 1 Parent(s): 6d0fba5

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +0 -257
modeling_motif.py CHANGED
@@ -464,263 +464,6 @@ class MorehMoeFusedMLP(nn.Module):
         return output
 
 
-class MoEGate(nn.Module):
-
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.top_k = config.num_experts_per_tok
-        self.n_routed_experts = config.n_routed_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.scoring_func = config.scoring_func
-        self.seq_aux = config.seq_aux
-        self.topk_method = config.topk_method
-        self.n_group = config.n_group
-        self.topk_group = config.topk_group
-
-        # topk selection algorithm
-        self.norm_topk_prob = config.norm_topk_prob
-        self.gating_dim = config.hidden_size
-        self.weight = nn.Parameter(
-            torch.empty((self.n_routed_experts, self.gating_dim)))
-        if self.topk_method == "noaux_tc":
-            self.e_score_correction_bias = nn.Parameter(
-                torch.empty((self.n_routed_experts)))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        import torch.nn.init as init
-
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-
-    def forward(self, hidden_states):
-        bsz, seq_len, h = hidden_states.shape
-        ### compute gating score
-        hidden_states = hidden_states.view(-1, h)
-        logits = F.linear(hidden_states.type(torch.float32),
-                          self.weight.type(torch.float32), None)
-        if self.scoring_func == "sigmoid":
-            scores = logits.sigmoid()
-        else:
-            raise NotImplementedError(
-                f"insupportable scoring function for MoE gating: {self.scoring_func}"
-            )
-
-        ### select top-k experts
-        if self.topk_method == "greedy":
-            topk_weight, topk_idx = torch.topk(scores,
-                                               k=self.top_k,
-                                               dim=-1,
-                                               sorted=False)
-        elif self.topk_method == "group_limited_greedy":
-            group_scores = (scores.view(bsz * seq_len, self.n_group,
-                                        -1).max(dim=-1).values)  # [n, n_group]
-            group_idx = torch.topk(group_scores,
-                                   k=self.topk_group,
-                                   dim=-1,
-                                   sorted=False)[1]  # [n, top_k_group]
-            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-            score_mask = (group_mask.unsqueeze(-1).expand(
-                bsz * seq_len, self.n_group,
-                self.n_routed_experts // self.n_group).reshape(
-                    bsz * seq_len, -1))  # [n, e]
-            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-            topk_weight, topk_idx = torch.topk(tmp_scores,
-                                               k=self.top_k,
-                                               dim=-1,
-                                               sorted=False)
-        elif self.topk_method == "noaux_tc":
-            ### will be used. ###
-            scores_for_choice = scores.view(
-                bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
-            group_scores = (scores_for_choice.view(
-                bsz * seq_len, self.n_group,
-                -1).topk(2, dim=-1)[0].sum(dim=-1))  # [n, n_group]
-            group_idx = torch.topk(group_scores,
-                                   k=self.topk_group,
-                                   dim=-1,
-                                   sorted=False)[1]  # [n, top_k_group]
-            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-            score_mask = (group_mask.unsqueeze(-1).expand(
-                bsz * seq_len, self.n_group,
-                self.n_routed_experts // self.n_group).reshape(
-                    bsz * seq_len, -1))  # [n, e]
-            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
-                                                       0.0)  # [n, e]
-            _, topk_idx = torch.topk(tmp_scores,
-                                     k=self.top_k,
-                                     dim=-1,
-                                     sorted=False)
-            topk_weight = scores.gather(1, topk_idx)
-        else:
-            raise NotImplementedError(
-                f"insupportable TopK function for MoE gating: {self.topk_method}"
-            )
-
-        ### norm gate to sum 1
-        if self.top_k > 1 and self.norm_topk_prob:
-            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
-            topk_weight = topk_weight / denominator
-        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor
-
-        return topk_idx, topk_weight
-
-
-class MotifMoE(nn.Module):
-    """
-    A mixed expert module containing shared experts.
-    """
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.num_experts_per_tok = config.num_experts_per_tok
-        self.use_moreh_moe = config.use_moreh_moe
-        self.use_fused_mlp = config.use_fused_mlp
-
-        if hasattr(config, "ep_size") and config.ep_size > 1:
-            assert config.ep_size == dist.get_world_size()
-            assert not config.use_moreh_moe
-            self.ep_size = config.ep_size
-            self.experts_per_rank = config.n_routed_experts // config.ep_size
-            self.ep_rank = dist.get_rank()
-            self.experts = nn.ModuleList([
-                (DeepseekV3MLP(config,
-                               intermediate_size=config.moe_intermediate_size)
-                 if i >= self.ep_rank * self.experts_per_rank and i <
-                 (self.ep_rank + 1) * self.experts_per_rank else None)
-                for i in range(config.n_routed_experts)
-            ])
-        else:
-            self.ep_size = 1
-            self.experts_per_rank = config.n_routed_experts
-            self.ep_rank = 0
-            if self.use_moreh_moe:
-                if not self.use_fused_mlp:
-                    self.experts = MorehMoeMLP(
-                        ffn_dim=config.moe_intermediate_size,
-                        hidden_dim=config.hidden_size,
-                        hidden_act_moe=config.hidden_act_moe,
-                        num_experts=config.n_routed_experts,
-                        device=None)
-                else:
-                    ## group expert.
-                    self.experts = MorehMoeFusedMLP(
-                        ffn_dim=config.moe_intermediate_size,
-                        hidden_dim=config.hidden_size,
-                        hidden_act_moe=config.hidden_act_moe,
-                        num_experts=config.n_routed_experts,
-                        num_groups=config.n_group,
-                        device=None,
-                        continual_training=config.continual_training,
-                    )
-            else:
-                self.experts = nn.ModuleList([
-                    DeepseekV3MLP(
-                        config, intermediate_size=config.moe_intermediate_size)
-                    for i in range(config.n_routed_experts)
-                ])
-
-        self.gate = MoEGate(config)
-
-    def forward(self, hidden_states):
-        identity = hidden_states
-        orig_shape = hidden_states.shape
-        topk_idx, topk_weight = self.gate(hidden_states)
-        if self.use_moreh_moe:
-            y = self.experts(hidden_states, topk_idx.view(*orig_shape[:-1], -1),
-                             topk_weight.view(*orig_shape[:-1], -1))
-            y = y.type(hidden_states.dtype)
-        else:
-            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-            flat_topk_idx = topk_idx.view(-1)
-            if self.training:
-                hidden_states = hidden_states.repeat_interleave(
-                    self.num_experts_per_tok, dim=0)
-                y = torch.empty_like(hidden_states)
-                for i, expert in enumerate(self.experts):
-                    y[flat_topk_idx == i] = expert(
-                        hidden_states[flat_topk_idx == i])
-                y = (y.view(*topk_weight.shape, -1) *
-                     topk_weight.unsqueeze(-1)).sum(dim=1)
-                y = y.type(hidden_states.dtype)
-                y = y.view(*orig_shape)
-                # y = AddAuxiliaryLoss.apply(y, aux_loss)
-            else:
-                y = self.moe_infer(hidden_states, topk_idx,
-                                   topk_weight).view(*orig_shape)
-        return y, identity
-
-    @torch.no_grad()
-    def moe_infer(self, x, topk_ids, topk_weight):
-        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
-        cnts.scatter_(1, topk_ids, 1)
-        tokens_per_expert = cnts.sum(dim=0)
-        idxs = topk_ids.view(-1).argsort()
-        sorted_tokens = x[idxs // topk_ids.shape[1]]
-        sorted_tokens_shape = sorted_tokens.shape
-        if self.ep_size > 1:
-            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
-                                                        -1).sum(dim=1)
-            tokens_per_expert_group = tokens_per_expert.new_empty(
-                tokens_per_expert.shape[0])
-            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
-            output_splits = (tokens_per_expert_group.view(
-                self.ep_size, -1).sum(1).cpu().numpy().tolist())
-            gathered_tokens = sorted_tokens.new_empty(
-                tokens_per_expert_group.sum(dim=0).cpu().item(),
-                sorted_tokens.shape[1])
-            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
-            dist.all_to_all(
-                list(gathered_tokens.split(output_splits)),
-                list(sorted_tokens.split(input_split_sizes)),
-            )
-            tokens_per_expert_post_gather = tokens_per_expert_group.view(
-                self.ep_size, self.experts_per_rank).sum(dim=0)
-            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],),
-                                    dtype=np.int32)
-            s = 0
-            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
-                gatherd_idxs[s:s + k] = i % self.experts_per_rank
-                s += k
-            gatherd_idxs = gatherd_idxs.argsort()
-            sorted_tokens = gathered_tokens[gatherd_idxs]
-            tokens_per_expert = tokens_per_expert_post_gather
-        tokens_per_expert = tokens_per_expert.cpu().numpy()
-
-        outputs = []
-        start_idx = 0
-        for i, num_tokens in enumerate(tokens_per_expert):
-            end_idx = start_idx + num_tokens
-            if num_tokens == 0:
-                continue
-            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
-            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
-            expert_out = expert(tokens_for_this_expert)
-            outputs.append(expert_out)
-            start_idx = end_idx
-
-        outs = torch.cat(outputs,
-                         dim=0) if len(outputs) else sorted_tokens.new_empty(0)
-        if self.ep_size > 1:
-            new_x = torch.empty_like(outs)
-            new_x[gatherd_idxs] = outs
-            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
-            dist.all_to_all(
-                list(gathered_tokens.split(input_split_sizes)),
-                list(new_x.split(output_splits)),
-            )
-            outs = gathered_tokens
-
-        new_x = torch.empty_like(outs)
-        new_x[idxs] = outs
-        final_out = (new_x.view(
-            *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
-                topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
-        return final_out
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
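For reference, the `noaux_tc` branch of the removed `MoEGate` is DeepSeek-V3-style group-limited routing: scores are bias-corrected for selection only, expert groups are ranked by the sum of their two best corrected scores, the top `topk_group` groups survive, and the final top-k is taken inside those groups while the returned weights come from the uncorrected scores. A minimal standalone sketch of that selection; the toy dimensions at the bottom (4 tokens, 8 experts in 4 groups) are purely hypothetical:

```python
import torch

def noaux_tc_topk(scores, bias, n_group, topk_group, top_k):
    """Group-limited top-k selection, mirroring the removed noaux_tc branch.

    scores: [n_tokens, n_experts] sigmoid gate scores
    bias:   [n_experts] e_score_correction_bias (used for selection only)
    """
    n_tokens, n_experts = scores.shape
    scores_for_choice = scores + bias.unsqueeze(0)               # [n, e]
    # Rank each group by the sum of its two best corrected scores.
    group_scores = (scores_for_choice.view(n_tokens, n_group, -1)
                    .topk(2, dim=-1)[0].sum(dim=-1))             # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False)[1]                      # [n, topk_group]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)                         # [n, n_group]
    # Expand the group mask back to a per-expert mask.
    score_mask = (group_mask.unsqueeze(-1)
                  .expand(n_tokens, n_group, n_experts // n_group)
                  .reshape(n_tokens, -1))                        # [n, e]
    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
    _, topk_idx = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)
    # Weights come from the uncorrected scores, as in the removed code.
    topk_weight = scores.gather(1, topk_idx)
    return topk_idx, topk_weight

scores = torch.rand(4, 8).sigmoid()   # hypothetical gate scores
bias = torch.zeros(8)
idx, w = noaux_tc_topk(scores, bias, n_group=4, topk_group=2, top_k=2)
```

Using the bias only to pick experts, never to weight them, is what lets it act as a pure load-balancing correction, in the spirit of DeepSeek-V3's auxiliary-loss-free balancing.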
 
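The removed training path in `MotifMoE.forward` dispatches by duplicating each token `top_k` times with `repeat_interleave`, running every expert on its boolean-masked slice, and summing the weighted outputs. A self-contained sketch of that path, assuming a plain `nn.ModuleList` of experts (the `nn.Linear` stand-ins below are hypothetical):

```python
import torch
import torch.nn as nn

def moe_train_forward(experts, hidden, topk_idx, topk_weight):
    """Training-path dispatch, mirroring the removed MotifMoE.forward.

    hidden:      [n_tokens, h]      flattened hidden states
    topk_idx:    [n_tokens, top_k]  routed expert ids
    topk_weight: [n_tokens, top_k]  routing weights
    """
    top_k = topk_idx.shape[1]
    flat_idx = topk_idx.view(-1)                          # [n * top_k]
    # One copy of each token per routed expert, in token-major order.
    x = hidden.repeat_interleave(top_k, dim=0)            # [n * top_k, h]
    y = torch.empty_like(x)
    for i, expert in enumerate(experts):
        mask = flat_idx == i
        if mask.any():
            y[mask] = expert(x[mask])
    # Weighted sum over the top_k expert outputs of each token.
    return (y.view(*topk_weight.shape, -1) *
            topk_weight.unsqueeze(-1)).sum(dim=1).type(hidden.dtype)

experts = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
x = torch.randn(5, 16)
idx = torch.randint(0, 8, (5, 2))
w = torch.softmax(torch.rand(5, 2), dim=-1)
out = moe_train_forward(experts, x, idx, w)               # [5, 16]
```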
 
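The removed `moe_infer` avoids per-expert boolean masks at inference: it arg-sorts the flattened expert ids so each expert sees one contiguous slice, then applies the inverse permutation and the weighted combine. A single-rank (`ep_size == 1`) sketch of that logic; the expert-parallel branch wraps `dist.all_to_all` exchanges around the same expert loop:

```python
import torch

@torch.no_grad()
def moe_infer_single_rank(experts, x, topk_ids, topk_weight):
    """Sort-based inference dispatch, mirroring the removed moe_infer."""
    n_tokens, top_k = topk_ids.shape
    # Histogram of tokens per expert via one-hot scatter.
    cnts = topk_ids.new_zeros((n_tokens, len(experts)))
    cnts.scatter_(1, topk_ids, 1)
    tokens_per_expert = cnts.sum(dim=0).cpu().numpy()
    # Sort token copies by expert id; idxs // top_k recovers the source token.
    idxs = topk_ids.view(-1).argsort()
    sorted_tokens = x[idxs // top_k]

    outputs, start = [], 0
    for i, num in enumerate(tokens_per_expert):
        end = start + num
        if num:
            # Each expert runs once on its contiguous slice.
            outputs.append(experts[i](sorted_tokens[start:end]))
        start = end
    outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)

    # Invert the sort, then weight and sum the top_k outputs per token.
    new_x = torch.empty_like(outs)
    new_x[idxs] = outs
    return (new_x.view(n_tokens, top_k, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(-1))
            .sum(dim=1)
            .type(new_x.dtype))
```

Sorting once up front is also what makes the expert-parallel path workable: contiguous per-expert slices map directly onto the `all_to_all` split sizes.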