kgrabko commited on
Commit
0d647c2
·
verified ·
1 Parent(s): 66e364e

Create merge_70b_shards_v2.py

Browse files
prepared_sft_data/merge_70b_shards_v2.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # COPYRIGHT (C) 2025-2026 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
3
+ # PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
4
+ # ==============================================================================
5
+ import os
6
+ import torch
7
+ import glob
8
+ from safetensors.torch import load_file, save_file
9
+ from tqdm import tqdm
10
+
11
+ # --- ПУТИ ---
12
+ # Твой обученный чекпоинт (результат Full SFT)
13
+ SFT_CHECKPOINT_PATH = "/content/full_checkpoints_70b/jirack_70b_full_step_200.safetensors"
14
+ # Папка с твоими оригинальными 30 шардами
15
+ ORIGINAL_SHARDS_DIR = "/content/JiRack_BitNet_70B_Packed/checkpoints/checkpoint-220000"
16
+ # Куда сохранить результат
17
+ OUTPUT_DIR = "/content/JiRack_70B_SFT_Merged"
18
+
19
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
20
+
21
+ def merge_shards():
22
+ print(f"🚀 Загрузка SFT чекпоинта: {SFT_CHECKPOINT_PATH}")
23
+ sft_weights = load_file(SFT_CHECKPOINT_PATH, device="cpu")
24
+
25
+ # Получаем список всех шардов оригинальной модели
26
+ shard_files = sorted(glob.glob(f"{ORIGINAL_SHARDS_DIR}/*.safetensors"))
27
+
28
+ print(f"📦 Найдено шардов для обработки: {len(shard_files)}")
29
+
30
+ for shard_path in tqdm(shard_files, desc="Merging Shards"):
31
+ shard_name = os.path.basename(shard_path)
32
+ # Загружаем оригинальный шард
33
+ current_shard = load_file(shard_path, device="cpu")
34
+ updated_shard = {}
35
+ merge_count = 0
36
+
37
+ for key, weight in current_shard.items():
38
+ # Убираем префикс 'model.', если он есть в ключах чекпоинта, но нет в шардах (или наоборот)
39
+ # Мы ищем точное совпадение ключа в sft_weights
40
+
41
+ # Проверяем ключ как есть
42
+ if key in sft_weights:
43
+ updated_shard[key] = sft_weights[key]
44
+ merge_count += 1
45
+ # Проверяем с учетом возможной разницы в префиксах (model.layers... vs layers...)
46
+ elif key.replace("model.", "") in sft_weights:
47
+ updated_shard[key] = sft_weights[key.replace("model.", "")]
48
+ merge_count += 1
49
+ else:
50
+ # Если веса не обучались (не попали в SFT чекпоинт), оставляем оригинал
51
+ updated_shard[key] = weight
52
+
53
+ # Сохраняем обновленный шард
54
+ save_path = os.path.join(OUTPUT_DIR, shard_name)
55
+ save_file(updated_shard, save_path)
56
+ # print(f"✅ {shard_name}: обновлено {merge_count} тензоров")
57
+
58
+ print(f"\n✨ Мердж завершен! Готовая модель здесь: {OUTPUT_DIR}")
59
+
60
+ if __name__ == "__main__":
61
+ merge_shards()