eunhwanpark-motiftech committed on
Commit cfa11bc · verified · 1 Parent(s): 1d63261

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +2 -47
modeling_motif.py CHANGED
@@ -409,35 +409,6 @@ class MotifAttention(nn.Module):
         self.num_key_value_heads //= 2
         self.n_rep = self.num_heads // self.num_key_value_heads

-        ##mix attn
-
-        self.mix_attn = config.mix_attn
-
-        if self.mix_attn:
-
-            self.cq, self.ck = 6, 11
-            self.ch = 2
-
-            self.key_query_conv = nn.Conv2d(
-                in_channels=self.num_heads*2,
-                out_channels=self.num_heads*2,
-                kernel_size=(self.cq, self.ck),
-                padding="same",
-                groups=self.num_heads*2
-            )
-
-            self.head_conv = nn.Conv1d(
-                in_channels=self.num_heads,
-                out_channels=self.num_heads,
-                kernel_size=1,
-                padding=0,
-                groups=self.num_heads // self.ch
-            )
-
-            self.group_norm = nn.GroupNorm(num_groups=self.num_heads, num_channels=self.num_heads)
-
-
-
         # re-init projections
         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
@@ -516,12 +487,6 @@ class MotifAttention(nn.Module):
         attention_mask = torch.triu(
             torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
             1 + offset)
-        ##attn weights conv2d, softmax and add attention_mask
-        if self.mix_attn:
-            ## condition mask==0, value : 0
-            attn_weights = attn_weights.masked_fill( attention_mask == 0, 0)
-            attn_weights = self.key_query_conv(attn_weights)
-            attn_weights = attn_weights[:, :, :kv_seq_len, :kv_seq_len]

         ###add attn
         attn_weights = attn_weights + attention_mask
@@ -536,11 +501,6 @@ class MotifAttention(nn.Module):
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
         attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
         attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
-        ##head_conv
-        if self.mix_attn:
-            attn_weights = attn_weights.view(bsz, self.num_heads, -1).contiguous()
-            attn_weights = self.head_conv(attn_weights)
-            attn_weights = attn_weights.view(bsz, self.num_heads, q_len, -1).contiguous()

         ##shape : bsz, #heads, seq, head_dim
         attn_output = torch.matmul(attn_weights, value_states)
@@ -552,9 +512,7 @@ class MotifAttention(nn.Module):
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim * 2):
             raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                              f" {attn_output.size()}")
-        if self.mix_attn:
-            attn_output = self.group_norm(attn_output)
-
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -889,10 +847,7 @@ class MotifDecoderLayer(nn.Module):
             logger.warning_once(
                 f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                 "unexpected results may be encountered.")
-        if not config.mix_attn:
-            self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-        else:
-            self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
+        self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
         self.mlp = MotifMLP(config)

         RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
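Note on the removed path: the mix-attention variant deleted here added three modules on top of the differential attention — a depthwise Conv2d over the per-head query-key score maps, a grouped 1x1 Conv1d that mixes heads in pairs (ch = 2), and a per-head GroupNorm on the attention output. The sketch below is a minimal standalone reconstruction from the deleted lines, not the model's actual forward pass: the tensor sizes and the lambda value are made-up placeholders, and the masking/softmax handling is simplified.

import torch
import torch.nn as nn

# Placeholder sizes for illustration only; the real values come from the model config.
bsz, num_heads, q_len, head_dim = 2, 8, 16, 32
kv_len = q_len

# The three modules removed by this commit (cq=6, ck=11, ch=2 in the deleted code).
key_query_conv = nn.Conv2d(
    in_channels=num_heads * 2,   # two score maps per head (differential attention)
    out_channels=num_heads * 2,
    kernel_size=(6, 11),         # (cq, ck)
    padding="same",
    groups=num_heads * 2,        # depthwise: each score map gets its own filter
)
head_conv = nn.Conv1d(
    in_channels=num_heads,
    out_channels=num_heads,
    kernel_size=1,
    groups=num_heads // 2,       # ch = 2: heads are mixed within groups of two
)
group_norm = nn.GroupNorm(num_groups=num_heads, num_channels=num_heads)

# 1) Depthwise 2D convolution over the raw query-key score maps.
scores = torch.randn(bsz, num_heads * 2, q_len, kv_len)
scores = key_query_conv(scores)[:, :, :kv_len, :kv_len]

# 2) Differential combination of the two maps per head (0.3 stands in for lambda_full),
#    followed by the grouped 1x1 convolution across the flattened per-head maps.
weights = scores.view(bsz, num_heads, 2, q_len, kv_len)
weights = weights[:, :, 0] - 0.3 * weights[:, :, 1]
weights = head_conv(weights.reshape(bsz, num_heads, -1)).view(bsz, num_heads, q_len, kv_len)

# 3) Per-head GroupNorm on the attention output (channel dimension = num_heads).
value_states = torch.randn(bsz, num_heads, kv_len, head_dim * 2)
attn_output = torch.matmul(weights.softmax(dim=-1), value_states)
attn_output = group_norm(attn_output)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])

The final hunk drops the branch that forced the "eager" attention class whenever mix_attn was set; with the convolutional path gone, MotifDecoderLayer always selects the attention class from config._attn_implementation.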