|
|
""" |
|
|
Helion-OSC Sharded Model Loader |
|
|
Efficiently loads a checkpoint split into 116 safetensors shards (~2.8 GB each),
with utilities to verify, inspect, and load the shards.
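
Example usage (illustrative; replace the paths with your local checkpoint directory):

    loader = ShardedModelLoader("/path/to/helion-osc/inference")
    loader.verify_shards()
    weights = loader.load_sharded_weights(device="cpu", low_memory=True)

Command-line usage (script and checkpoint paths are placeholders):

    python sharded_model_loader.py /path/to/helion-osc/inference --action inspect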
|
|
""" |
|
|
|
|
|
import torch |
|
|
import json |
|
|
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, List |
|
|
import logging |
|
|
from tqdm import tqdm |
|
|
from safetensors.torch import load_file |
|
|
|
|
import psutil |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class ShardedModelLoader: |
|
|
""" |
|
|
Loader for sharded safetensors model files |
|
|
    Optimized for 116 shards of 2.8 GB each
|
|
""" |
|
|
|
|
|
def __init__(self, model_path: str): |
|
|
""" |
|
|
Initialize the sharded model loader |
|
|
|
|
|
Args: |
|
|
model_path: Path to the inference directory containing shards |
|
|
""" |
|
|
self.model_path = Path(model_path) |
|
|
self.config_path = self.model_path / "model_config.json" |
|
|
self.index_path = self.model_path / "model.safetensors.index.json" |
|
|
|
|
|
|
|
|
logger.info(f"Loading configuration from {self.config_path}") |
|
|
with open(self.config_path, 'r') as f: |
|
|
self.config = json.load(f) |
|
|
|
|
|
|
|
|
logger.info(f"Loading weight index from {self.index_path}") |
|
|
with open(self.index_path, 'r') as f: |
|
|
self.index = json.load(f) |
|
|
|
|
|
self.metadata = self.index.get("metadata", {}) |
|
|
self.weight_map = self.index.get("weight_map", {}) |
|
|
|
|
|
logger.info(f"Model: {self.metadata.get('model_type', 'unknown')}") |
|
|
logger.info(f"Total shards: {self.metadata.get('total_shards', 0)}") |
|
|
logger.info(f"Total size: {self.metadata.get('total_size', 0) / 1e9:.2f} GB") |
|
|
logger.info(f"Total parameters: {self.config['architectures_info']['total_parameters']}") |
|
|
logger.info(f"Active parameters: {self.config['architectures_info']['active_parameters']}") |
|
|
|
|
|
def get_shard_path(self, shard_name: str) -> Path: |
|
|
"""Get full path to a shard file""" |
|
|
return self.model_path / shard_name |
|
|
|
|
|
def get_available_memory(self) -> Dict[str, float]: |
|
|
"""Get available system memory""" |
|
|
memory = psutil.virtual_memory() |
|
|
result = { |
|
|
"ram_total_gb": memory.total / 1e9, |
|
|
"ram_available_gb": memory.available / 1e9, |
|
|
"ram_percent_used": memory.percent |
|
|
} |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
for i in range(torch.cuda.device_count()): |
|
|
gpu_mem = torch.cuda.get_device_properties(i).total_memory |
|
|
gpu_allocated = torch.cuda.memory_allocated(i) |
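                # Note: "available" below is device total minus memory allocated by this
                # process's tensors; CUDA cache reservations and other processes are not
                # accounted for, so this can overestimate what is actually free.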
|
|
result[f"gpu_{i}_total_gb"] = gpu_mem / 1e9 |
|
|
result[f"gpu_{i}_available_gb"] = (gpu_mem - gpu_allocated) / 1e9 |
|
|
|
|
|
return result |
|
|
|
|
|
def load_shard(self, shard_name: str, device: str = "cpu") -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Load a single shard file |
|
|
|
|
|
Args: |
|
|
shard_name: Name of the shard file |
|
|
device: Device to load tensors to |
|
|
|
|
|
Returns: |
|
|
Dictionary of weight tensors |
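
        Example (placeholder path; the shard file name is illustrative and should
        come from the index's weight_map):

            loader = ShardedModelLoader("/path/to/inference")
            shard = loader.load_shard("model-00001-of-00116.safetensors", device="cpu")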
|
|
""" |
|
|
shard_path = self.get_shard_path(shard_name) |
|
|
|
|
|
if not shard_path.exists(): |
|
|
raise FileNotFoundError(f"Shard not found: {shard_path}") |
|
|
|
|
|
logger.debug(f"Loading shard: {shard_name}") |
|
|
return load_file(str(shard_path), device=device) |
|
|
|
|
|
def load_sharded_weights( |
|
|
self, |
|
|
device: str = "cpu", |
|
|
low_memory: bool = False, |
|
|
show_progress: bool = True |
|
|
) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Load all sharded weights |
|
|
|
|
|
Args: |
|
|
device: Device to load weights to |
|
|
            low_memory: Free each shard's dict and empty the CUDA cache after merging it
|
|
show_progress: Show progress bar |
|
|
|
|
|
Returns: |
|
|
Dictionary of all model weights |
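
        Example (placeholder path; loads to CPU so no GPU is assumed):

            loader = ShardedModelLoader("/path/to/inference")
            weights = loader.load_sharded_weights(device="cpu", low_memory=True)
            print(f"{len(weights)} tensors loaded")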
|
|
""" |
|
|
logger.info("Loading sharded model weights...") |
|
|
|
|
|
|
|
|
mem_info = self.get_available_memory() |
|
|
logger.info(f"Available RAM: {mem_info['ram_available_gb']:.2f} GB") |
|
|
if "gpu_0_available_gb" in mem_info: |
|
|
logger.info(f"Available GPU 0: {mem_info['gpu_0_available_gb']:.2f} GB") |
|
|
|
|
|
|
|
|
shard_files = sorted(set(self.weight_map.values())) |
|
|
total_shards = len(shard_files) |
|
|
|
|
|
logger.info(f"Loading {total_shards} shard files...") |
|
|
|
|
|
all_weights = {} |
|
|
|
|
|
|
|
|
pbar = tqdm(shard_files, disable=not show_progress, desc="Loading shards") |
|
|
|
|
|
for shard_name in pbar: |
|
|
pbar.set_description(f"Loading {shard_name}") |
|
|
|
|
|
|
|
|
shard_weights = self.load_shard(shard_name, device=device) |
|
|
|
|
|
|
|
|
all_weights.update(shard_weights) |
|
|
|
|
|
|
|
|
if low_memory: |
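                # Release the per-shard dict and flush the CUDA caching allocator between
                # shards; the individual tensors remain referenced by all_weights.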
|
|
del shard_weights |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
logger.info(f"Loaded {len(all_weights)} weight tensors") |
|
|
return all_weights |
|
|
|
|
|
def get_layer_weights(self, layer_idx: int) -> List[str]: |
|
|
""" |
|
|
Get all weight keys for a specific layer |
|
|
|
|
|
Args: |
|
|
layer_idx: Layer index |
|
|
|
|
|
Returns: |
|
|
List of weight keys for that layer |
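
        Example (layer 0 shown; keys follow the "model.layers.{idx}." prefix used below):

            loader = ShardedModelLoader("/path/to/inference")
            layer0_keys = loader.get_layer_weights(0)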
|
|
""" |
|
|
prefix = f"model.layers.{layer_idx}." |
|
|
return [k for k in self.weight_map.keys() if k.startswith(prefix)] |
|
|
|
|
|
def get_shard_for_weight(self, weight_key: str) -> Optional[str]: |
|
|
""" |
|
|
Get shard file name for a specific weight |
|
|
|
|
|
Args: |
|
|
weight_key: Weight key/name |
|
|
|
|
|
Returns: |
|
|
Shard file name or None |
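
        Example (placeholder path; the weight key is illustrative and depends on the index):

            loader = ShardedModelLoader("/path/to/inference")
            shard_name = loader.get_shard_for_weight("model.embed_tokens.weight")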
|
|
""" |
|
|
return self.weight_map.get(weight_key) |
|
|
|
|
|
def verify_shards(self) -> Dict[str, bool]: |
|
|
""" |
|
|
Verify all shard files exist |
|
|
|
|
|
Returns: |
|
|
Dictionary mapping shard names to existence status |
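
        Example (placeholder path):

            loader = ShardedModelLoader("/path/to/inference")
            status = loader.verify_shards()
            missing = [name for name, ok in status.items() if not ok]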
|
|
""" |
|
|
logger.info("Verifying shard files...") |
|
|
|
|
|
shard_files = set(self.weight_map.values()) |
|
|
verification = {} |
|
|
|
|
|
for shard_name in tqdm(sorted(shard_files), desc="Verifying"): |
|
|
shard_path = self.get_shard_path(shard_name) |
|
|
verification[shard_name] = shard_path.exists() |
|
|
|
|
|
missing = [s for s, exists in verification.items() if not exists] |
|
|
|
|
|
if missing: |
|
|
logger.warning(f"Missing {len(missing)} shard files:") |
|
|
for shard in missing[:10]: |
|
|
logger.warning(f" - {shard}") |
|
|
if len(missing) > 10: |
|
|
logger.warning(f" ... and {len(missing) - 10} more") |
|
|
else: |
|
|
logger.info("✓ All shard files present") |
|
|
|
|
|
return verification |
|
|
|
|
|
def load_metadata(self) -> Dict: |
|
|
"""Load model metadata""" |
|
|
return { |
|
|
"config": self.config, |
|
|
"index": self.index, |
|
|
"total_shards": self.metadata.get("total_shards", 0), |
|
|
"total_size_gb": self.metadata.get("total_size", 0) / 1e9, |
|
|
"architecture": self.config.get("architectures_info", {}), |
|
|
"num_layers": self.config.get("num_hidden_layers", 0), |
|
|
"hidden_size": self.config.get("hidden_size", 0), |
|
|
"vocab_size": self.config.get("vocab_size", 0) |
|
|
} |
|
|
|
|
|
|
|
|
def load_full_model( |
|
|
model_path: str, |
|
|
device: str = "cuda" if torch.cuda.is_available() else "cpu", |
|
|
low_memory: bool = False |
|
|
): |
|
|
""" |
|
|
Convenience function to load the full model |
|
|
|
|
|
Args: |
|
|
model_path: Path to inference directory |
|
|
device: Device to load model to |
|
|
low_memory: Use low memory loading |
|
|
|
|
|
Returns: |
|
|
Loaded model weights and metadata |
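
    Example (placeholder path; CPU used so the example does not assume GPU capacity):

        weights, metadata = load_full_model("/path/to/inference", device="cpu")
        print(metadata["num_layers"], "layers,", len(weights), "tensors")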
|
|
""" |
|
|
loader = ShardedModelLoader(model_path) |
|
|
|
|
|
|
|
|
verification = loader.verify_shards() |
|
|
missing = sum(1 for exists in verification.values() if not exists) |
|
|
|
|
|
    if missing > 0:
        raise FileNotFoundError(
            f"Cannot load model: {missing} of {len(verification)} shard files are missing. "
            f"Please download the missing shards before loading."
        )
|
|
|
|
|
|
|
|
weights = loader.load_sharded_weights( |
|
|
device=device, |
|
|
low_memory=low_memory, |
|
|
show_progress=True |
|
|
) |
|
|
|
|
|
|
|
|
metadata = loader.load_metadata() |
|
|
|
|
|
return weights, metadata |
|
|
|
|
|
|
|
|
def inspect_model(model_path: str): |
|
|
""" |
|
|
Inspect model structure without loading weights |
|
|
|
|
|
Args: |
|
|
model_path: Path to inference directory |
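
    Example (placeholder path):

        inspect_model("/path/to/inference")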
|
|
""" |
|
|
loader = ShardedModelLoader(model_path) |
|
|
|
|
|
print("\n" + "="*80) |
|
|
print("HELION-OSC MODEL INSPECTION") |
|
|
print("="*80) |
|
|
|
|
|
metadata = loader.load_metadata() |
|
|
|
|
|
print(f"\nModel Type: {metadata['architecture'].get('model_description', 'N/A')}") |
|
|
print(f"Architecture: {metadata['architecture'].get('architecture_type', 'N/A')}") |
|
|
print(f"Total Parameters: {metadata['architecture'].get('total_parameters', 'N/A')}") |
|
|
print(f"Active Parameters: {metadata['architecture'].get('active_parameters', 'N/A')}") |
|
|
|
|
|
print(f"\nModel Configuration:") |
|
|
print(f" Layers: {metadata['num_layers']}") |
|
|
print(f" Hidden Size: {metadata['hidden_size']}") |
|
|
print(f" Vocabulary Size: {metadata['vocab_size']}") |
|
|
print(f" Attention Heads: {metadata['config'].get('num_attention_heads', 'N/A')}") |
|
|
print(f" KV Heads: {metadata['config'].get('num_key_value_heads', 'N/A')}") |
|
|
|
|
|
print(f"\nMoE Configuration:") |
|
|
arch = metadata['architecture'] |
|
|
print(f" Number of Experts: {arch.get('num_experts', 'N/A')}") |
|
|
print(f" Experts per Token: {arch.get('experts_per_token', 'N/A')}") |
|
|
print(f" Shared Experts: {arch.get('num_shared_experts', 'N/A')}") |
|
|
|
|
|
print(f"\nStorage Information:") |
|
|
print(f" Total Shards: {metadata['total_shards']}") |
|
|
print(f" Total Size: {metadata['total_size_gb']:.2f} GB") |
|
|
print(f" Shard Size: ~2.8 GB each") |
|
|
print(f" Format: safetensors") |
|
|
print(f" Precision: bfloat16") |
|
|
|
|
|
print(f"\nContext Length:") |
|
|
print(f" Max Position Embeddings: {metadata['config'].get('max_position_embeddings', 'N/A')}") |
|
|
print(f" RoPE Theta: {metadata['config'].get('rope_theta', 'N/A')}") |
|
|
|
|
|
print("\n" + "="*80) |
|
|
|
|
|
|
|
|
print("\nVerifying shard files...") |
|
|
verification = loader.verify_shards() |
|
|
present = sum(1 for exists in verification.values() if exists) |
|
|
total = len(verification) |
|
|
|
|
|
print(f"\nShard Status: {present}/{total} files present") |
|
|
|
|
|
if present == total: |
|
|
print("✓ All shard files are available") |
|
|
else: |
|
|
print(f"✗ Missing {total - present} shard files") |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""Main CLI interface""" |
|
|
import argparse |
|
|
|
|
|
parser = argparse.ArgumentParser(description="Helion-OSC Sharded Model Loader") |
|
|
parser.add_argument( |
|
|
"model_path", |
|
|
type=str, |
|
|
help="Path to inference directory" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--action", |
|
|
choices=["inspect", "verify", "load"], |
|
|
default="inspect", |
|
|
help="Action to perform" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--device", |
|
|
type=str, |
|
|
default="cuda" if torch.cuda.is_available() else "cpu", |
|
|
help="Device to load model to" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--low-memory", |
|
|
action="store_true", |
|
|
help="Use low memory mode" |
|
|
) |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if args.action == "inspect": |
|
|
inspect_model(args.model_path) |
|
|
|
|
|
elif args.action == "verify": |
|
|
loader = ShardedModelLoader(args.model_path) |
|
|
loader.verify_shards() |
|
|
|
|
|
elif args.action == "load": |
|
|
logger.info("Loading full model...") |
|
|
weights, metadata = load_full_model( |
|
|
args.model_path, |
|
|
device=args.device, |
|
|
low_memory=args.low_memory |
|
|
) |
|
|
logger.info(f"Successfully loaded {len(weights)} weight tensors") |
|
|
logger.info(f"Model ready on {args.device}") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |