|
|
""" |
|
|
Train DavidBeans V2: Wormhole Routing Architecture |
|
|
=================================================== |
|
|
|
|
|
    ┌─────────────────┐
    │    BEANS V2     │   "I learn where to look..."
    │  (Wormhole ViT) │
    │  🌀 → 🌀 → 🌀   │   Learned sparse routing
    └────────┬────────┘
             │
             ▼
    ┌─────────────────┐
    │      DAVID      │   "I know the crystals..."
    │  (Classifier)   │
    │  💎 → 💎 → 💎   │   Multi-scale projection
    └────────┬────────┘
             │
             ▼
        [Prediction]
|
|
|
|
|
Key findings from wormhole experiments: |
|
|
1. When routing IS the task, routing learns structure |
|
|
2. Auxiliary losses can be gamed - removed in V2 |
|
|
3. Gradient flow through router is critical - verified |
|
|
4. Cross-contrastive aligns patch↔scale features
|
|
|
|
|
Author: AbstractPhil |
|
|
Date: November 29, 2025 |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR |
|
|
from tqdm.auto import tqdm |
|
|
import time |
|
|
import math |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, Tuple, List, Union |
|
|
from dataclasses import dataclass, field |
|
|
import json |
|
|
from datetime import datetime |
|
|
import os |
|
|
import shutil |
|
|
|
|
|
# Pull the HF token from Colab secrets when available; skip silently elsewhere
# (e.g. local runs) so the script still imports outside Colab.
HF_TOKEN = None
try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
    if HF_TOKEN:
        os.environ['HF_TOKEN'] = HF_TOKEN
except Exception:
    pass
|
|
|
|
|
|
|
|
from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig |
|
|
from geofractal.model.david_beans.model_v2 import DavidBeansV2, DavidBeansV2Config |
|
|
|
|
|
|
|
|
try: |
|
|
from huggingface_hub import HfApi, create_repo, upload_folder |
|
|
HF_HUB_AVAILABLE = True |
|
|
except ImportError: |
|
|
HF_HUB_AVAILABLE = False |
|
|
print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub") |
|
|
|
|
|
|
|
|
try: |
|
|
from safetensors.torch import save_file as save_safetensors |
|
|
SAFETENSORS_AVAILABLE = True |
|
|
except ImportError: |
|
|
SAFETENSORS_AVAILABLE = False |
|
|
|
|
|
|
|
|
try: |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
TENSORBOARD_AVAILABLE = True |
|
|
except ImportError: |
|
|
TENSORBOARD_AVAILABLE = False |
|
|
print(" [!] tensorboard not installed. Run: pip install tensorboard") |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class TrainingConfigV2: |
|
|
"""Training configuration for DavidBeans V2 with wormhole routing.""" |
|
|
|
|
|
|
|
|
run_name: str = "default" |
|
|
run_number: Optional[int] = None |
|
|
|
|
|
|
|
|
model_version: int = 2 |
|
|
|
|
|
|
|
|
dataset: str = "cifar100" |
|
|
image_size: int = 32 |
|
|
batch_size: int = 128 |
|
|
num_workers: int = 4 |
|
|
|
|
|
|
|
|
epochs: int = 200 |
|
|
warmup_epochs: int = 10 |
|
|
|
|
|
|
|
|
learning_rate: float = 3e-4 |
|
|
weight_decay: float = 0.05 |
|
|
betas: Tuple[float, float] = (0.9, 0.999) |
|
|
|
|
|
|
|
|
scheduler: str = "cosine" |
|
|
min_lr: float = 1e-6 |
|
|
|
|
|
|
|
|
ce_weight: float = 1.0 |
|
|
contrast_weight: float = 0.5 |
|
|
|
|
|
|
|
|
|
|
|
gradient_clip: float = 1.0 |
|
|
label_smoothing: float = 0.1 |
|
|
|
|
|
|
|
|
use_augmentation: bool = True |
|
|
mixup_alpha: float = 0.2 |
|
|
cutmix_alpha: float = 1.0 |
|
|
|
|
|
|
|
|
save_interval: int = 10 |
|
|
output_dir: str = "./checkpoints" |
|
|
resume_from: Optional[str] = None |
|
|
|
|
|
|
|
|
use_tensorboard: bool = True |
|
|
log_interval: int = 50 |
|
|
log_routing: bool = True |
|
|
|
|
|
|
|
|
push_to_hub: bool = False |
|
|
hub_repo_id: str = "AbstractPhil/geovit-david-beans" |
|
|
hub_private: bool = False |
|
|
|
|
|
|
|
|
device: str = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
def to_dict(self) -> Dict: |
|
|
return {k: v for k, v in self.__dict__.items()} |
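
# A minimal usage sketch (illustrative only): override just the fields you need;
# every other field keeps the dataclass default defined above.
#
#   smoke_cfg = TrainingConfigV2(
#       run_name="smoke_test",
#       epochs=2,
#       batch_size=32,
#       use_augmentation=False,
#       push_to_hub=False,
#   )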
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RoutingMetrics: |
|
|
"""Track and analyze wormhole routing patterns.""" |
|
|
|
|
|
def __init__(self): |
|
|
self.reset() |
|
|
|
|
|
def reset(self): |
|
|
self.route_entropies = [] |
|
|
self.route_diversities = [] |
|
|
self.grad_norms = {'query': [], 'key': []} |
|
|
|
|
|
@torch.no_grad() |
|
|
def compute_route_entropy(self, soft_routes: torch.Tensor) -> float: |
|
|
"""Compute average entropy of routing distributions.""" |
|
|
|
|
|
|
|
|
eps = 1e-8 |
|
|
entropy = -(soft_routes * (soft_routes + eps).log()).sum(dim=-1) |
|
|
return entropy.mean().item() |
|
|
|
|
|
@torch.no_grad() |
|
|
def compute_route_diversity(self, routes: torch.Tensor, num_positions: int) -> float: |
|
|
"""Compute how many unique destinations are used.""" |
|
|
|
|
|
unique_per_sample = [] |
|
|
for b in range(routes.shape[0]): |
|
|
unique = routes[b].unique().numel() |
|
|
unique_per_sample.append(unique / num_positions) |
|
|
return sum(unique_per_sample) / len(unique_per_sample) |
|
|
|
|
|
def update_from_routing_info(self, routing_info: List[Dict], model: nn.Module): |
|
|
"""Extract metrics from routing info returned by V2 model.""" |
|
|
if not routing_info: |
|
|
return |
|
|
|
|
|
for layer_info in routing_info: |
|
|
|
|
|
if layer_info.get('attention'): |
|
|
attn = layer_info['attention'] |
|
|
if attn.get('weights') is not None: |
|
|
entropy = self.compute_route_entropy(attn['weights']) |
|
|
self.route_entropies.append(entropy) |
|
|
if attn.get('routes') is not None: |
|
|
P = attn['routes'].shape[1] |
|
|
diversity = self.compute_route_diversity(attn['routes'], P) |
|
|
self.route_diversities.append(diversity) |
|
|
|
|
|
|
|
|
if layer_info.get('expert'): |
|
|
exp = layer_info['expert'] |
|
|
if exp.get('weights') is not None: |
|
|
entropy = self.compute_route_entropy(exp['weights']) |
|
|
self.route_entropies.append(entropy) |
|
|
|
|
|
def update_grad_norms(self, model: nn.Module): |
|
|
"""Track gradient norms through router projections.""" |
|
|
for name, param in model.named_parameters(): |
|
|
if param.grad is not None: |
|
|
if 'query_proj' in name and 'weight' in name: |
|
|
self.grad_norms['query'].append(param.grad.norm().item()) |
|
|
elif 'key_proj' in name and 'weight' in name: |
|
|
self.grad_norms['key'].append(param.grad.norm().item()) |
|
|
|
|
|
def get_summary(self) -> Dict[str, float]: |
|
|
"""Get summary statistics.""" |
|
|
summary = {} |
|
|
|
|
|
if self.route_entropies: |
|
|
summary['route_entropy'] = sum(self.route_entropies) / len(self.route_entropies) |
|
|
if self.route_diversities: |
|
|
summary['route_diversity'] = sum(self.route_diversities) / len(self.route_diversities) |
|
|
if self.grad_norms['query']: |
|
|
summary['grad_query'] = sum(self.grad_norms['query']) / len(self.grad_norms['query']) |
|
|
if self.grad_norms['key']: |
|
|
summary['grad_key'] = sum(self.grad_norms['key']) / len(self.grad_norms['key']) |
|
|
|
|
|
return summary |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]: |
|
|
"""Get train and test dataloaders.""" |
|
|
|
|
|
try: |
|
|
import torchvision |
|
|
import torchvision.transforms as T |
|
|
|
|
|
if config.dataset == "cifar10": |
|
|
mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) |
|
|
num_classes = 10 |
|
|
DatasetClass = torchvision.datasets.CIFAR10 |
|
|
elif config.dataset == "cifar100": |
|
|
mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) |
|
|
num_classes = 100 |
|
|
DatasetClass = torchvision.datasets.CIFAR100 |
|
|
else: |
|
|
raise ValueError(f"Unknown dataset: {config.dataset}") |
|
|
|
|
|
if config.use_augmentation: |
|
|
train_transform = T.Compose([ |
|
|
T.RandomCrop(32, padding=4), |
|
|
T.RandomHorizontalFlip(), |
|
|
T.AutoAugment(T.AutoAugmentPolicy.CIFAR10), |
|
|
T.ToTensor(), |
|
|
T.Normalize(mean, std) |
|
|
]) |
|
|
else: |
|
|
train_transform = T.Compose([ |
|
|
T.ToTensor(), |
|
|
T.Normalize(mean, std) |
|
|
]) |
|
|
|
|
|
test_transform = T.Compose([ |
|
|
T.ToTensor(), |
|
|
T.Normalize(mean, std) |
|
|
]) |
|
|
|
|
|
train_dataset = DatasetClass( |
|
|
root='./data', train=True, download=True, transform=train_transform |
|
|
) |
|
|
test_dataset = DatasetClass( |
|
|
root='./data', train=False, download=True, transform=test_transform |
|
|
) |
|
|
|
|
|
train_loader = DataLoader( |
|
|
train_dataset, |
|
|
batch_size=config.batch_size, |
|
|
shuffle=True, |
|
|
num_workers=config.num_workers, |
|
|
pin_memory=True, |
|
|
persistent_workers=config.num_workers > 0, |
|
|
drop_last=True |
|
|
) |
|
|
test_loader = DataLoader( |
|
|
test_dataset, |
|
|
batch_size=config.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=config.num_workers, |
|
|
pin_memory=True, |
|
|
persistent_workers=config.num_workers > 0 |
|
|
) |
|
|
|
|
|
return train_loader, test_loader, num_classes |
|
|
|
|
|
except ImportError: |
|
|
print(" [!] torchvision not available, using synthetic data") |
|
|
return get_synthetic_dataloaders(config) |
|
|
|
|
|
|
|
|
def get_synthetic_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]: |
|
|
"""Fallback synthetic data for testing.""" |
|
|
|
|
|
class SyntheticDataset(torch.utils.data.Dataset): |
|
|
def __init__(self, size: int, image_size: int, num_classes: int): |
|
|
self.size = size |
|
|
self.image_size = image_size |
|
|
self.num_classes = num_classes |
|
|
|
|
|
def __len__(self): |
|
|
return self.size |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
x = torch.randn(3, self.image_size, self.image_size) |
|
|
y = idx % self.num_classes |
|
|
return x, y |
|
|
|
|
|
num_classes = 100 if config.dataset == "cifar100" else 10 |
|
|
train_dataset = SyntheticDataset(5000, config.image_size, num_classes) |
|
|
test_dataset = SyntheticDataset(1000, config.image_size, num_classes) |
|
|
|
|
|
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True) |
|
|
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False) |
|
|
|
|
|
return train_loader, test_loader, num_classes |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2): |
|
|
"""Mixup augmentation.""" |
|
|
if alpha > 0: |
|
|
lam = torch.distributions.Beta(alpha, alpha).sample().item() |
|
|
else: |
|
|
lam = 1.0 |
|
|
|
|
|
batch_size = x.size(0) |
|
|
index = torch.randperm(batch_size, device=x.device) |
|
|
|
|
|
mixed_x = lam * x + (1 - lam) * x[index] |
|
|
y_a, y_b = y, y[index] |
|
|
|
|
|
return mixed_x, y_a, y_b, lam |
|
|
|
|
|
|
|
|
def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0): |
|
|
"""CutMix augmentation.""" |
|
|
if alpha > 0: |
|
|
lam = torch.distributions.Beta(alpha, alpha).sample().item() |
|
|
else: |
|
|
lam = 1.0 |
|
|
|
|
|
batch_size = x.size(0) |
|
|
index = torch.randperm(batch_size, device=x.device) |
|
|
|
|
|
_, _, H, W = x.shape |
|
|
|
|
|
cut_ratio = math.sqrt(1 - lam) |
|
|
cut_h = int(H * cut_ratio) |
|
|
cut_w = int(W * cut_ratio) |
|
|
|
|
|
cx = torch.randint(0, H, (1,)).item() |
|
|
cy = torch.randint(0, W, (1,)).item() |
|
|
|
|
|
x1 = max(0, cx - cut_h // 2) |
|
|
x2 = min(H, cx + cut_h // 2) |
|
|
y1 = max(0, cy - cut_w // 2) |
|
|
y2 = min(W, cy + cut_w // 2) |
|
|
|
|
|
mixed_x = x.clone() |
|
|
mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2] |
|
|
|
|
|
lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W) |
|
|
|
|
|
y_a, y_b = y, y[index] |
|
|
|
|
|
return mixed_x, y_a, y_b, lam |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MetricsTracker: |
|
|
"""Track training metrics with EMA smoothing.""" |
|
|
|
|
|
def __init__(self, ema_decay: float = 0.9): |
|
|
self.ema_decay = ema_decay |
|
|
self.metrics = {} |
|
|
self.ema_metrics = {} |
|
|
self.history = {} |
|
|
|
|
|
def update(self, **kwargs): |
|
|
for k, v in kwargs.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
v = v.item() |
|
|
|
|
|
if k not in self.metrics: |
|
|
self.metrics[k] = [] |
|
|
self.ema_metrics[k] = v |
|
|
self.history[k] = [] |
|
|
|
|
|
self.metrics[k].append(v) |
|
|
self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v |
|
|
|
|
|
def get_ema(self, key: str) -> float: |
|
|
return self.ema_metrics.get(key, 0.0) |
|
|
|
|
|
def get_epoch_mean(self, key: str) -> float: |
|
|
values = self.metrics.get(key, []) |
|
|
return sum(values) / len(values) if values else 0.0 |
|
|
|
|
|
def end_epoch(self): |
|
|
for k, v in self.metrics.items(): |
|
|
if v: |
|
|
self.history[k].append(sum(v) / len(v)) |
|
|
self.metrics = {k: [] for k in self.metrics} |
|
|
|
|
|
def get_history(self) -> Dict: |
|
|
return self.history |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_latest_checkpoint(output_dir: Path) -> Optional[Path]: |
|
|
"""Find the most recent checkpoint in output directory.""" |
|
|
checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt")) |
|
|
|
|
|
if not checkpoints: |
|
|
best_model = output_dir / "best_model.pt" |
|
|
if best_model.exists(): |
|
|
return best_model |
|
|
return None |
|
|
|
|
|
    def get_epoch(p):
        try:
            return int(p.stem.split("_")[-1])
        except (IndexError, ValueError):
            return 0
|
|
|
|
|
checkpoints.sort(key=get_epoch, reverse=True) |
|
|
return checkpoints[0] |
|
|
|
|
|
|
|
|
def get_next_run_number(base_dir: Path) -> int: |
|
|
"""Get the next run number by scanning existing run directories.""" |
|
|
if not base_dir.exists(): |
|
|
return 1 |
|
|
|
|
|
max_num = 0 |
|
|
for d in base_dir.iterdir(): |
|
|
if d.is_dir() and d.name.startswith("run_"): |
|
|
try: |
|
|
num = int(d.name.split("_")[1]) |
|
|
max_num = max(max_num, num) |
|
|
except (IndexError, ValueError): |
|
|
continue |
|
|
|
|
|
return max_num + 1 |
|
|
|
|
|
|
|
|
def generate_run_dir_name(run_number: int, run_name: str, version: int = 2) -> str: |
|
|
"""Generate a run directory name.""" |
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
|
safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower()) |
|
|
safe_name = "_".join(filter(None, safe_name.split("_"))) |
|
|
return f"run_{run_number:03d}_v{version}_{safe_name}_{timestamp}" |
|
|
|
|
|
|
|
|
def find_latest_run_dir(base_dir: Path) -> Optional[Path]: |
|
|
"""Find the most recent run directory.""" |
|
|
if not base_dir.exists(): |
|
|
return None |
|
|
|
|
|
run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")] |
|
|
|
|
|
if not run_dirs: |
|
|
return None |
|
|
|
|
|
run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True) |
|
|
return run_dirs[0] |
|
|
|
|
|
|
|
|
def load_checkpoint( |
|
|
checkpoint_path: Path, |
|
|
model: nn.Module, |
|
|
optimizer: Optional[torch.optim.Optimizer] = None, |
|
|
device: str = "cuda" |
|
|
) -> Tuple[int, float]: |
|
|
"""Load checkpoint and return (start_epoch, best_acc).""" |
|
|
print(f"\nπ Loading checkpoint: {checkpoint_path}") |
|
|
checkpoint = torch.load(checkpoint_path, map_location=device) |
|
|
|
|
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
|
print(f" β Loaded model weights") |
|
|
|
|
|
if optimizer is not None and 'optimizer_state_dict' in checkpoint: |
|
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
|
print(f" β Loaded optimizer state") |
|
|
|
|
|
epoch = checkpoint.get('epoch', 0) |
|
|
best_acc = checkpoint.get('best_acc', 0.0) |
|
|
|
|
|
print(f" β Resuming from epoch {epoch + 1}, best_acc={best_acc:.2f}%") |
|
|
|
|
|
return epoch + 1, best_acc |
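
# Usage sketch (hypothetical run directory shown): resolve the newest checkpoint
# in a run folder and restore model/optimizer state from it.
#
#   ckpt = find_latest_checkpoint(Path("./checkpoints/cifar100_v2/run_001_..."))
#   if ckpt is not None:
#       start_epoch, best_acc = load_checkpoint(ckpt, model, optimizer, device="cuda")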
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_run_readme( |
|
|
model_config: Union[DavidBeansConfig, DavidBeansV2Config], |
|
|
train_config: TrainingConfigV2, |
|
|
best_acc: float, |
|
|
run_dir_name: str |
|
|
) -> str: |
|
|
"""Generate README for a specific run.""" |
|
|
|
|
|
scales_str = ", ".join([str(s) for s in model_config.scales]) |
|
|
|
|
|
|
|
|
if isinstance(model_config, DavidBeansV2Config): |
|
|
routing_info = f""" |
|
|
## Wormhole Routing (V2) |
|
|
| Parameter | Value | |
|
|
|-----------|-------| |
|
|
| Mode | {model_config.wormhole_mode} | |
|
|
| Wormholes/Position | {model_config.num_wormholes} | |
|
|
| Temperature | {model_config.wormhole_temperature} | |
|
|
| Tiles | {model_config.num_tiles} | |
|
|
| Tile Wormholes | {model_config.tile_wormholes} | |
|
|
""" |
|
|
else: |
|
|
routing_info = """ |
|
|
## Routing (V1) |
|
|
| Parameter | Value | |
|
|
|-----------|-------| |
|
|
| k_neighbors | {model_config.k_neighbors} | |
|
|
| Cantor Weight | {model_config.cantor_weight} | |
|
|
""" |
|
|
|
|
|
return f"""# Run: {run_dir_name} |
|
|
|
|
|
## Results |
|
|
- **Best Accuracy**: {best_acc:.2f}% |
|
|
- **Dataset**: {train_config.dataset} |
|
|
- **Epochs**: {train_config.epochs} |
|
|
- **Model Version**: V{train_config.model_version} |
|
|
|
|
|
## Model Config |
|
|
| Parameter | Value | |
|
|
|-----------|-------| |
|
|
| Dim | {model_config.dim} | |
|
|
| Layers | {model_config.num_layers} | |
|
|
| Heads | {model_config.num_heads} | |
|
|
| Scales | [{scales_str}] | |
|
|
{routing_info} |
|
|
|
|
|
## Training Config |
|
|
| Parameter | Value | |
|
|
|-----------|-------| |
|
|
| Learning Rate | {train_config.learning_rate} | |
|
|
| Weight Decay | {train_config.weight_decay} | |
|
|
| Batch Size | {train_config.batch_size} | |
|
|
| CE Weight | {train_config.ce_weight} | |
|
|
| Contrast Weight | {train_config.contrast_weight} | |
|
|
|
|
|
## Key Findings Applied |
|
|
- Routing learns from task pressure (no auxiliary routing losses) |
|
|
- Gradients verified to flow through router |
|
|
- Cross-contrastive aligns patch↔scale features
|
|
""" |
|
|
|
|
|
|
|
|
def prepare_run_for_hub( |
|
|
model: nn.Module, |
|
|
model_config: Union[DavidBeansConfig, DavidBeansV2Config], |
|
|
train_config: TrainingConfigV2, |
|
|
best_acc: float, |
|
|
output_dir: Path, |
|
|
run_dir_name: str, |
|
|
training_history: Optional[Dict] = None |
|
|
) -> Path: |
|
|
"""Prepare run files for upload to HuggingFace Hub.""" |
|
|
|
|
|
hub_dir = output_dir / "hub_upload" |
|
|
run_hub_dir = hub_dir / "weights" / run_dir_name |
|
|
run_hub_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
state_dict = {k: v.clone() for k, v in model.state_dict().items()} |
|
|
|
|
|
if SAFETENSORS_AVAILABLE: |
|
|
try: |
|
|
save_safetensors(state_dict, run_hub_dir / "best.safetensors") |
|
|
print(f" β Saved best.safetensors") |
|
|
except Exception as e: |
|
|
print(f" [!] Safetensors failed ({e}), using pytorch format") |
|
|
torch.save(state_dict, run_hub_dir / "best.pt") |
|
|
else: |
|
|
torch.save(state_dict, run_hub_dir / "best.pt") |
|
|
|
|
|
|
|
|
config_dict = { |
|
|
"architecture": f"DavidBeans_V{train_config.model_version}", |
|
|
"model_type": "david_beans_v2" if train_config.model_version == 2 else "david_beans", |
|
|
**model_config.__dict__ |
|
|
} |
|
|
with open(run_hub_dir / "config.json", "w") as f: |
|
|
json.dump(config_dict, f, indent=2, default=str) |
|
|
|
|
|
|
|
|
with open(run_hub_dir / "training_config.json", "w") as f: |
|
|
json.dump(train_config.to_dict(), f, indent=2, default=str) |
|
|
|
|
|
|
|
|
run_readme = generate_run_readme(model_config, train_config, best_acc, run_dir_name) |
|
|
with open(run_hub_dir / "README.md", "w") as f: |
|
|
f.write(run_readme) |
|
|
|
|
|
|
|
|
if training_history: |
|
|
with open(run_hub_dir / "training_history.json", "w") as f: |
|
|
json.dump(training_history, f, indent=2) |
|
|
|
|
|
|
|
|
tb_dir = output_dir / "tensorboard" |
|
|
if tb_dir.exists(): |
|
|
hub_tb_dir = run_hub_dir / "tensorboard" |
|
|
if hub_tb_dir.exists(): |
|
|
shutil.rmtree(hub_tb_dir) |
|
|
shutil.copytree(tb_dir, hub_tb_dir) |
|
|
|
|
|
return hub_dir |
|
|
|
|
|
|
|
|
def push_run_to_hub( |
|
|
hub_dir: Path, |
|
|
repo_id: str, |
|
|
run_dir_name: str, |
|
|
private: bool = False, |
|
|
commit_message: Optional[str] = None |
|
|
) -> str: |
|
|
"""Push run files to HuggingFace Hub.""" |
|
|
|
|
|
if not HF_HUB_AVAILABLE: |
|
|
raise RuntimeError("huggingface_hub not installed") |
|
|
|
|
|
api = HfApi() |
|
|
|
|
|
try: |
|
|
create_repo(repo_id, private=private, exist_ok=True) |
|
|
except Exception as e: |
|
|
print(f" [!] Repo creation note: {e}") |
|
|
|
|
|
run_upload_dir = hub_dir / "weights" / run_dir_name |
|
|
|
|
|
if commit_message is None: |
|
|
commit_message = f"Update {run_dir_name} - {datetime.now().strftime('%Y-%m-%d %H:%M')}" |
|
|
|
|
|
url = upload_folder( |
|
|
folder_path=str(run_upload_dir), |
|
|
repo_id=repo_id, |
|
|
path_in_repo=f"weights/{run_dir_name}", |
|
|
commit_message=commit_message |
|
|
) |
|
|
|
|
|
return url |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch_v2( |
|
|
model: nn.Module, |
|
|
train_loader: DataLoader, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], |
|
|
config: TrainingConfigV2, |
|
|
epoch: int, |
|
|
tracker: MetricsTracker, |
|
|
routing_metrics: RoutingMetrics, |
|
|
writer: Optional['SummaryWriter'] = None |
|
|
) -> Dict[str, float]: |
|
|
"""Train for one epoch with V2 routing metrics.""" |
|
|
|
|
|
model.train() |
|
|
device = config.device |
|
|
is_v2 = config.model_version == 2 |
|
|
|
|
|
total_loss = 0.0 |
|
|
total_correct = 0 |
|
|
total_samples = 0 |
|
|
global_step = epoch * len(train_loader) |
|
|
|
|
|
routing_metrics.reset() |
|
|
|
|
|
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True) |
|
|
|
|
|
for batch_idx, (images, targets) in enumerate(pbar): |
|
|
images = images.to(device, non_blocking=True) |
|
|
targets = targets.to(device, non_blocking=True) |
|
|
|
|
|
|
|
|
use_mixup = config.use_augmentation and config.mixup_alpha > 0 |
|
|
use_cutmix = config.use_augmentation and config.cutmix_alpha > 0 |
|
|
|
|
|
mixed = False |
|
|
if use_mixup or use_cutmix: |
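            # One uniform draw picks the branch: ~50% keep the clean batch,
            # ~25% mixup, ~25% cutmix (when both augmentations are enabled).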
|
|
r = torch.rand(1).item() |
|
|
if r < 0.5: |
|
|
pass |
|
|
elif r < 0.75 and use_mixup: |
|
|
images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha) |
|
|
mixed = True |
|
|
elif use_cutmix: |
|
|
images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha) |
|
|
mixed = True |
|
|
|
|
|
|
|
|
if is_v2: |
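            # The V2 forward can also return routing tensors; sample them only
            # every 10th batch to keep logging overhead low.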
|
|
result = model( |
|
|
images, |
|
|
targets=targets, |
|
|
return_loss=True, |
|
|
return_routing=(batch_idx % 10 == 0) |
|
|
) |
|
|
else: |
|
|
result = model(images, targets=targets, return_loss=True) |
|
|
|
|
|
losses = result['losses'] |
|
|
|
|
|
|
|
|
if mixed: |
|
|
logits = result['logits'] |
|
|
ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \ |
|
|
(1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing) |
|
|
losses['ce'] = ce_loss |
|
|
|
|
|
|
|
|
loss = ( |
|
|
config.ce_weight * losses['ce'] + |
|
|
config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device)) |
|
|
) |
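        # Add any per-scale CE terms the model returned, each down-weighted by
        # 0.1, alongside the fused prediction loss.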
|
|
|
|
|
|
|
|
for scale in model.config.scales: |
|
|
scale_ce = losses.get(f'ce_{scale}', 0.0) |
|
|
if isinstance(scale_ce, torch.Tensor): |
|
|
loss = loss + 0.1 * scale_ce |
|
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
loss.backward() |
|
|
|
|
|
|
|
|
if is_v2: |
|
|
routing_metrics.update_grad_norms(model) |
|
|
|
|
|
if config.gradient_clip > 0: |
|
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip) |
|
|
else: |
|
|
grad_norm = 0.0 |
|
|
|
|
|
optimizer.step() |
|
|
|
|
|
if scheduler is not None and config.scheduler == "onecycle": |
|
|
scheduler.step() |
|
|
|
|
|
|
|
|
if is_v2 and result.get('routing'): |
|
|
routing_metrics.update_from_routing_info(result['routing'], model) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
logits = result['logits'] |
|
|
preds = logits.argmax(dim=-1) |
|
|
|
|
|
if mixed: |
|
|
correct = (lam * (preds == targets_a).float() + |
|
|
(1 - lam) * (preds == targets_b).float()).sum() |
|
|
else: |
|
|
correct = (preds == targets).sum() |
|
|
|
|
|
total_correct += correct.item() |
|
|
total_samples += targets.size(0) |
|
|
total_loss += loss.item() |
|
|
|
|
|
|
|
|
def to_float(v): |
|
|
return v.item() if isinstance(v, torch.Tensor) else float(v) |
|
|
|
|
|
contrast_loss = to_float(losses.get('contrast', 0.0)) |
|
|
current_lr = optimizer.param_groups[0]['lr'] |
|
|
|
|
|
tracker.update( |
|
|
loss=loss.item(), |
|
|
ce=losses['ce'].item(), |
|
|
contrast=contrast_loss, |
|
|
lr=current_lr |
|
|
) |
|
|
|
|
|
|
|
|
if writer is not None and (batch_idx + 1) % config.log_interval == 0: |
|
|
step = global_step + batch_idx |
|
|
writer.add_scalar('train/loss_total', loss.item(), step) |
|
|
writer.add_scalar('train/loss_ce', losses['ce'].item(), step) |
|
|
writer.add_scalar('train/loss_contrast', contrast_loss, step) |
|
|
writer.add_scalar('train/learning_rate', current_lr, step) |
|
|
writer.add_scalar('train/grad_norm', to_float(grad_norm), step) |
|
|
|
|
|
|
|
|
if is_v2 and config.log_routing: |
|
|
routing_summary = routing_metrics.get_summary() |
|
|
for k, v in routing_summary.items(): |
|
|
writer.add_scalar(f'routing/{k}', v, step) |
|
|
|
|
|
|
|
|
routing_summary = routing_metrics.get_summary() |
|
|
postfix = { |
|
|
'loss': f"{tracker.get_ema('loss'):.3f}", |
|
|
'acc': f"{100.0 * total_correct / total_samples:.1f}%", |
|
|
} |
|
|
if is_v2 and 'grad_query' in routing_summary: |
|
|
            postfix['∇q'] = f"{routing_summary['grad_query']:.2f}"
|
|
if 'route_entropy' in routing_summary: |
|
|
postfix['H'] = f"{routing_summary['route_entropy']:.2f}" |
|
|
|
|
|
pbar.set_postfix(postfix) |
|
|
|
|
|
if scheduler is not None and config.scheduler == "cosine": |
|
|
scheduler.step() |
|
|
|
|
|
return { |
|
|
'loss': total_loss / len(train_loader), |
|
|
'acc': 100.0 * total_correct / total_samples, |
|
|
**routing_metrics.get_summary() |
|
|
} |
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def evaluate_v2( |
|
|
model: nn.Module, |
|
|
test_loader: DataLoader, |
|
|
config: TrainingConfigV2 |
|
|
) -> Dict[str, float]: |
|
|
"""Evaluate on test set.""" |
|
|
|
|
|
model.eval() |
|
|
device = config.device |
|
|
|
|
|
total_loss = 0.0 |
|
|
total_correct = 0 |
|
|
total_samples = 0 |
|
|
scale_correct = {s: 0 for s in model.config.scales} |
|
|
|
|
|
for images, targets in test_loader: |
|
|
images = images.to(device, non_blocking=True) |
|
|
targets = targets.to(device, non_blocking=True) |
|
|
|
|
|
result = model(images, targets=targets, return_loss=True) |
|
|
|
|
|
logits = result['logits'] |
|
|
losses = result['losses'] |
|
|
|
|
|
loss = losses['total'] |
|
|
preds = logits.argmax(dim=-1) |
|
|
|
|
|
total_loss += loss.item() * targets.size(0) |
|
|
total_correct += (preds == targets).sum().item() |
|
|
total_samples += targets.size(0) |
|
|
|
|
|
for i, scale in enumerate(model.config.scales): |
|
|
scale_logits = result['scale_logits'][i] |
|
|
scale_preds = scale_logits.argmax(dim=-1) |
|
|
scale_correct[scale] += (scale_preds == targets).sum().item() |
|
|
|
|
|
metrics = { |
|
|
'loss': total_loss / total_samples, |
|
|
'acc': 100.0 * total_correct / total_samples |
|
|
} |
|
|
|
|
|
for scale, correct in scale_correct.items(): |
|
|
metrics[f'acc_{scale}'] = 100.0 * correct / total_samples |
|
|
|
|
|
return metrics |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_david_beans_v2( |
|
|
model_config: Optional[Union[DavidBeansConfig, DavidBeansV2Config]] = None, |
|
|
train_config: Optional[TrainingConfigV2] = None |
|
|
): |
|
|
"""Main training function for DavidBeans V1 or V2.""" |
|
|
|
|
|
print("=" * 70) |
|
|
print(" DAVID-BEANS V2 TRAINING: Wormhole Routing") |
|
|
print("=" * 70) |
|
|
print() |
|
|
print(" π WORMHOLES: Learned sparse routing") |
|
|
print(" π CRYSTALS: Multi-scale projection") |
|
|
print() |
|
|
print(" Key insight: When routing IS the task, routing learns structure") |
|
|
print() |
|
|
print("=" * 70) |
|
|
|
|
|
if train_config is None: |
|
|
train_config = TrainingConfigV2() |
|
|
|
|
|
base_output_dir = Path(train_config.output_dir) |
|
|
base_output_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
checkpoint_path = None |
|
|
run_dir = None |
|
|
run_dir_name = None |
|
|
|
|
|
if train_config.resume_from: |
|
|
resume_path = Path(train_config.resume_from) |
|
|
|
|
|
|
|
|
if resume_path.is_file(): |
|
|
checkpoint_path = resume_path |
|
|
run_dir = checkpoint_path.parent |
|
|
run_dir_name = run_dir.name |
|
|
print(f"\nπ Found checkpoint file: {checkpoint_path.name}") |
|
|
|
|
|
|
|
|
elif resume_path.is_dir(): |
|
|
checkpoint_path = find_latest_checkpoint(resume_path) |
|
|
if checkpoint_path: |
|
|
run_dir = resume_path |
|
|
run_dir_name = resume_path.name |
|
|
print(f"\nπ Found checkpoint in dir: {checkpoint_path.name}") |
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
possible_dir = base_output_dir / train_config.resume_from |
|
|
if possible_dir.is_dir(): |
|
|
checkpoint_path = find_latest_checkpoint(possible_dir) |
|
|
if checkpoint_path: |
|
|
run_dir = possible_dir |
|
|
run_dir_name = possible_dir.name |
|
|
print(f"\nπ Found checkpoint in: {run_dir_name}") |
|
|
|
|
|
|
|
|
if checkpoint_path is None: |
|
|
possible_file = base_output_dir / train_config.resume_from |
|
|
if possible_file.is_file(): |
|
|
checkpoint_path = possible_file |
|
|
run_dir = checkpoint_path.parent |
|
|
run_dir_name = run_dir.name |
|
|
print(f"\nπ Found checkpoint: {checkpoint_path.name}") |
|
|
|
|
|
|
|
|
if checkpoint_path is None: |
|
|
print(f"\n [!] Could not find checkpoint: {train_config.resume_from}") |
|
|
print(f" [!] Checked:") |
|
|
print(f" - As file: {resume_path}") |
|
|
print(f" - As dir: {resume_path}") |
|
|
print(f" - Under {base_output_dir}") |
|
|
print(f" [!] Starting fresh run instead") |
|
|
else: |
|
|
print(f" β Will resume from: {checkpoint_path}") |
|
|
|
|
|
|
|
|
if run_dir is None: |
|
|
run_number = train_config.run_number or get_next_run_number(base_output_dir) |
|
|
run_dir_name = generate_run_dir_name(run_number, train_config.run_name, train_config.model_version) |
|
|
run_dir = base_output_dir / run_dir_name |
|
|
run_dir.mkdir(parents=True, exist_ok=True) |
|
|
print(f"\nπ New run: {run_dir_name}") |
|
|
else: |
|
|
print(f"\nπ Resuming run: {run_dir_name}") |
|
|
|
|
|
output_dir = run_dir |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if checkpoint_path and checkpoint_path.exists() and model_config is None: |
|
|
|
|
|
try: |
|
|
ckpt = torch.load(checkpoint_path, map_location='cpu') |
|
|
if 'model_config' in ckpt: |
|
|
saved_config = ckpt['model_config'] |
|
|
print(f" β Loading model config from checkpoint") |
|
|
if train_config.model_version == 2: |
|
|
model_config = DavidBeansV2Config(**saved_config) |
|
|
else: |
|
|
model_config = DavidBeansConfig(**saved_config) |
|
|
except Exception as e: |
|
|
print(f" [!] Could not load config from checkpoint: {e}") |
|
|
|
|
|
|
|
|
if model_config is None: |
|
|
if train_config.model_version == 2: |
|
|
model_config = DavidBeansV2Config( |
|
|
image_size=train_config.image_size, |
|
|
patch_size=4, |
|
|
dim=512, |
|
|
num_layers=4, |
|
|
num_heads=8, |
|
|
num_wormholes=8, |
|
|
wormhole_temperature=0.1, |
|
|
wormhole_mode="hybrid", |
|
|
num_tiles=16, |
|
|
tile_wormholes=4, |
|
|
scales=[64, 128, 256, 384, 512], |
|
|
num_classes=100, |
|
|
contrast_weight=train_config.contrast_weight, |
|
|
dropout=0.1 |
|
|
) |
|
|
else: |
|
|
model_config = DavidBeansConfig( |
|
|
image_size=train_config.image_size, |
|
|
patch_size=4, |
|
|
dim=512, |
|
|
num_layers=4, |
|
|
num_heads=8, |
|
|
num_experts=5, |
|
|
k_neighbors=16, |
|
|
cantor_weight=0.3, |
|
|
scales=[64, 128, 256, 384, 512], |
|
|
num_classes=100, |
|
|
dropout=0.1 |
|
|
) |
|
|
|
|
|
device = train_config.device |
|
|
print(f"\nDevice: {device}") |
|
|
print(f"Model version: V{train_config.model_version}") |
|
|
|
|
|
|
|
|
print("\nLoading data...") |
|
|
train_loader, test_loader, num_classes = get_dataloaders(train_config) |
|
|
print(f" Dataset: {train_config.dataset}") |
|
|
print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") |
|
|
print(f" Classes: {num_classes}") |
|
|
|
|
|
model_config.num_classes = num_classes |
|
|
|
|
|
|
|
|
print("\nBuilding model...") |
|
|
if train_config.model_version == 2: |
|
|
model = DavidBeansV2(model_config) |
|
|
else: |
|
|
model = DavidBeans(model_config) |
|
|
|
|
|
model = model.to(device) |
|
|
print(f"\n{model}") |
|
|
|
|
|
num_params = sum(p.numel() for p in model.parameters()) |
|
|
print(f"\nParameters: {num_params:,}") |
|
|
|
|
|
|
|
|
print("\nSetting up optimizer...") |
|
|
|
|
|
decay_params = [] |
|
|
no_decay_params = [] |
|
|
|
|
|
for name, param in model.named_parameters(): |
|
|
if not param.requires_grad: |
|
|
continue |
|
|
if 'bias' in name or 'norm' in name or 'embedding' in name: |
|
|
no_decay_params.append(param) |
|
|
else: |
|
|
decay_params.append(param) |
|
|
|
|
|
optimizer = AdamW([ |
|
|
{'params': decay_params, 'weight_decay': train_config.weight_decay}, |
|
|
{'params': no_decay_params, 'weight_decay': 0.0} |
|
|
], lr=train_config.learning_rate, betas=train_config.betas) |
|
|
|
|
|
if train_config.scheduler == "cosine": |
|
|
scheduler = CosineAnnealingLR( |
|
|
optimizer, |
|
|
T_max=train_config.epochs - train_config.warmup_epochs, |
|
|
eta_min=train_config.min_lr |
|
|
) |
|
|
elif train_config.scheduler == "onecycle": |
|
|
scheduler = OneCycleLR( |
|
|
optimizer, |
|
|
max_lr=train_config.learning_rate, |
|
|
epochs=train_config.epochs, |
|
|
steps_per_epoch=len(train_loader), |
|
|
pct_start=train_config.warmup_epochs / train_config.epochs |
|
|
) |
|
|
else: |
|
|
scheduler = None |
|
|
|
|
|
print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})") |
|
|
print(f" Scheduler: {train_config.scheduler}") |
|
|
|
|
|
tracker = MetricsTracker() |
|
|
routing_metrics = RoutingMetrics() |
|
|
best_acc = 0.0 |
|
|
start_epoch = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if checkpoint_path and checkpoint_path.exists(): |
|
|
start_epoch, best_acc = load_checkpoint(checkpoint_path, model, optimizer, device) |
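        # The checkpoint stores no scheduler state, so replay the missed steps
        # to line the cosine LR schedule up with the resumed epoch.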
|
|
|
|
|
|
|
|
if scheduler is not None and train_config.scheduler == "cosine": |
|
|
for _ in range(start_epoch): |
|
|
scheduler.step() |
|
|
print(f" β Advanced scheduler to epoch {start_epoch}") |
|
|
|
|
|
|
|
|
writer = None |
|
|
if train_config.use_tensorboard and TENSORBOARD_AVAILABLE: |
|
|
tb_dir = output_dir / "tensorboard" |
|
|
tb_dir.mkdir(parents=True, exist_ok=True) |
|
|
writer = SummaryWriter(log_dir=str(tb_dir)) |
|
|
print(f" TensorBoard: {tb_dir}") |
|
|
|
|
|
|
|
|
with open(output_dir / "config.json", "w") as f: |
|
|
json.dump({**model_config.__dict__, "architecture": f"DavidBeans_V{train_config.model_version}"}, |
|
|
f, indent=2, default=str) |
|
|
with open(output_dir / "training_config.json", "w") as f: |
|
|
json.dump(train_config.to_dict(), f, indent=2, default=str) |
|
|
|
|
|
|
|
|
print("\n" + "=" * 70) |
|
|
print(" TRAINING") |
|
|
print("=" * 70) |
|
|
|
|
|
for epoch in range(start_epoch, train_config.epochs): |
|
|
epoch_start = time.time() |
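        # Linear LR warmup over the first warmup_epochs (cosine preset only);
        # CosineAnnealingLR then takes over via its per-epoch step.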
|
|
|
|
|
|
|
|
if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine": |
|
|
warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs |
|
|
for param_group in optimizer.param_groups: |
|
|
param_group['lr'] = warmup_lr |
|
|
|
|
|
train_metrics = train_epoch_v2( |
|
|
model, train_loader, optimizer, scheduler, |
|
|
train_config, epoch, tracker, routing_metrics, writer |
|
|
) |
|
|
|
|
|
test_metrics = evaluate_v2(model, test_loader, train_config) |
|
|
|
|
|
epoch_time = time.time() - epoch_start |
|
|
|
|
|
|
|
|
if writer is not None: |
|
|
writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch) |
|
|
writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch) |
|
|
writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch) |
|
|
writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch) |
|
|
|
|
|
for scale in model.config.scales: |
|
|
writer.add_scalar(f'scales/acc_{scale}', test_metrics[f'acc_{scale}'], epoch) |
|
|
|
|
|
|
|
|
scale_accs = " | ".join([f"{s}:{test_metrics[f'acc_{s}']:.1f}%" for s in model.config.scales]) |
|
|
star = "β
" if test_metrics['acc'] > best_acc else "" |
|
|
|
|
|
routing_info = "" |
|
|
if train_config.model_version == 2 and 'grad_query' in train_metrics: |
|
|
routing_info = f" | βq:{train_metrics.get('grad_query', 0):.2f}" |
|
|
|
|
|
print(f" β Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | " |
|
|
f"[{scale_accs}]{routing_info} | {epoch_time:.0f}s {star}") |
|
|
|
|
|
|
|
|
if test_metrics['acc'] > best_acc: |
|
|
best_acc = test_metrics['acc'] |
|
|
torch.save({ |
|
|
'epoch': epoch, |
|
|
'model_state_dict': model.state_dict(), |
|
|
'optimizer_state_dict': optimizer.state_dict(), |
|
|
'best_acc': best_acc, |
|
|
'model_config': model_config.__dict__, |
|
|
'train_config': train_config.to_dict() |
|
|
}, output_dir / "best_model.pt") |
|
|
|
|
|
|
|
|
if (epoch + 1) % train_config.save_interval == 0: |
|
|
torch.save({ |
|
|
'epoch': epoch, |
|
|
'model_state_dict': model.state_dict(), |
|
|
'optimizer_state_dict': optimizer.state_dict(), |
|
|
'best_acc': best_acc, |
|
|
'model_config': model_config.__dict__, |
|
|
'train_config': train_config.to_dict() |
|
|
}, output_dir / f"checkpoint_epoch_{epoch + 1}.pt") |
|
|
|
|
|
|
|
|
if train_config.push_to_hub and HF_HUB_AVAILABLE: |
|
|
try: |
|
|
hub_dir = prepare_run_for_hub( |
|
|
model=model, |
|
|
model_config=model_config, |
|
|
train_config=train_config, |
|
|
best_acc=best_acc, |
|
|
output_dir=output_dir, |
|
|
run_dir_name=run_dir_name, |
|
|
training_history=tracker.get_history() |
|
|
) |
|
|
push_run_to_hub( |
|
|
hub_dir=hub_dir, |
|
|
repo_id=train_config.hub_repo_id, |
|
|
run_dir_name=run_dir_name, |
|
|
commit_message=f"Epoch {epoch + 1} - {best_acc:.2f}% acc" |
|
|
) |
|
|
print(f" π€ Uploaded to hub") |
|
|
except Exception as e: |
|
|
print(f" [!] Hub upload failed: {e}") |
|
|
|
|
|
tracker.end_epoch() |
|
|
|
|
|
|
|
|
print("\n" + "=" * 70) |
|
|
print(" TRAINING COMPLETE") |
|
|
print("=" * 70) |
|
|
print(f"\n Best Test Accuracy: {best_acc:.2f}%") |
|
|
print(f" Model saved to: {output_dir / 'best_model.pt'}") |
|
|
|
|
|
if writer is not None: |
|
|
writer.close() |
|
|
|
|
|
return model, best_acc |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_cifar100_v2_wormhole( |
|
|
run_name: str = "wormhole_base", |
|
|
push_to_hub: bool = False, |
|
|
resume: bool = False |
|
|
): |
|
|
"""CIFAR-100 with V2 wormhole routing.""" |
|
|
|
|
|
model_config = DavidBeansV2Config( |
|
|
image_size=32, |
|
|
patch_size=2, |
|
|
dim=512, |
|
|
num_layers=4, |
|
|
num_heads=16, |
|
|
|
|
|
num_wormholes=16, |
|
|
wormhole_temperature=0.1, |
|
|
wormhole_mode="hybrid", |
|
|
|
|
|
num_tiles=16, |
|
|
tile_wormholes=4, |
|
|
|
|
|
scales=[64, 128, 256, 512, 1024], |
|
|
num_classes=100, |
|
|
contrast_temperature=0.07, |
|
|
contrast_weight=0.5, |
|
|
dropout=0.1 |
|
|
) |
|
|
|
|
|
train_config = TrainingConfigV2( |
|
|
run_name=run_name, |
|
|
model_version=2, |
|
|
dataset="cifar100", |
|
|
epochs=300, |
|
|
batch_size=512, |
|
|
learning_rate=3e-4, |
|
|
weight_decay=0.05, |
|
|
warmup_epochs=15, |
|
|
|
|
|
ce_weight=1.0, |
|
|
contrast_weight=0.5, |
|
|
|
|
|
label_smoothing=0.1, |
|
|
mixup_alpha=0.2, |
|
|
cutmix_alpha=1.0, |
|
|
|
|
|
output_dir="./checkpoints/cifar100_v2", |
|
|
resume_from=None, |
|
|
|
|
|
push_to_hub=push_to_hub, |
|
|
hub_repo_id="AbstractPhil/geovit-david-beans", |
|
|
|
|
|
log_routing=True |
|
|
) |
|
|
|
|
|
return train_david_beans_v2(model_config, train_config) |
|
|
|
|
|
|
|
|
def train_cifar100_v1_baseline( |
|
|
run_name: str = "v1_baseline", |
|
|
push_to_hub: bool = False, |
|
|
resume: bool = False |
|
|
): |
|
|
"""CIFAR-100 with V1 (fixed Cantor routing) for comparison.""" |
|
|
|
|
|
model_config = DavidBeansConfig( |
|
|
image_size=32, |
|
|
patch_size=4, |
|
|
dim=512, |
|
|
num_layers=4, |
|
|
num_heads=8, |
|
|
num_experts=5, |
|
|
k_neighbors=16, |
|
|
cantor_weight=0.3, |
|
|
scales=[64, 128, 256, 384, 512], |
|
|
num_classes=100, |
|
|
dropout=0.1 |
|
|
) |
|
|
|
|
|
train_config = TrainingConfigV2( |
|
|
run_name=run_name, |
|
|
model_version=1, |
|
|
dataset="cifar100", |
|
|
epochs=200, |
|
|
batch_size=128, |
|
|
learning_rate=3e-4, |
|
|
weight_decay=0.05, |
|
|
warmup_epochs=10, |
|
|
ce_weight=1.0, |
|
|
contrast_weight=0.5, |
|
|
label_smoothing=0.1, |
|
|
mixup_alpha=0.2, |
|
|
cutmix_alpha=1.0, |
|
|
output_dir="./checkpoints/cifar100_v1", |
|
|
resume_from="latest" if resume else None, |
|
|
push_to_hub=push_to_hub, |
|
|
hub_repo_id="AbstractPhil/geovit-david-beans", |
|
|
log_routing=False |
|
|
) |
|
|
|
|
|
return train_david_beans_v2(model_config, train_config) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PRESET = "v2_wormhole" |
|
|
RESUME = False |
|
|
RUN_NAME = "5scale_2x2patch_4tilewormholes_d512_4layer" |
|
|
PUSH_TO_HUB = True |
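    # Presets: "test" runs a tiny 2-epoch sanity check, "v1_baseline" trains the
    # fixed Cantor-routing model, "v2_wormhole" trains the learned-routing model.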
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if PRESET == "test": |
|
|
print("π§ͺ Quick test...") |
|
|
model_config = DavidBeansV2Config( |
|
|
image_size=32, patch_size=4, dim=128, num_layers=2, |
|
|
num_heads=4, num_wormholes=4, num_tiles=8, |
|
|
scales=[32, 64, 128], num_classes=10 |
|
|
) |
|
|
train_config = TrainingConfigV2( |
|
|
run_name="test", model_version=2, |
|
|
epochs=2, batch_size=32, |
|
|
use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0 |
|
|
) |
|
|
model, acc = train_david_beans_v2(model_config, train_config) |
|
|
|
|
|
elif PRESET == "v1_baseline": |
|
|
print("π«π Training DavidBeans V1 (Cantor routing)...") |
|
|
model, acc = train_cifar100_v1_baseline( |
|
|
run_name=RUN_NAME, |
|
|
push_to_hub=PUSH_TO_HUB, |
|
|
resume=RESUME |
|
|
) |
|
|
|
|
|
elif PRESET == "v2_wormhole": |
|
|
print("π Training DavidBeans V2 (Wormhole routing)...") |
|
|
model, acc = train_cifar100_v2_wormhole( |
|
|
run_name=RUN_NAME, |
|
|
push_to_hub=PUSH_TO_HUB, |
|
|
resume=RESUME |
|
|
) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unknown preset: {PRESET}") |
|
|
|
|
|
print(f"\nπ Done! Best accuracy: {acc:.2f}%") |