geovit-david-beans / trainer_v2_wormhole_routing.py
AbstractPhil's picture
Create trainer_v2_wormhole_routing.py
6640107 verified
"""
Train DavidBeans V2: Wormhole Routing Architecture
===================================================
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ BEANS V2 β”‚ "I learn where to look..."
β”‚ (Wormhole ViT)β”‚
β”‚ πŸŒ€ β†’ πŸŒ€ β†’ πŸŒ€ β”‚ Learned sparse routing
β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ DAVID β”‚ "I know the crystals..."
β”‚ (Classifier) β”‚
β”‚ πŸ’Ž β†’ πŸ’Ž β†’ πŸ’Ž β”‚ Multi-scale projection
β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
[Prediction]
Key findings from wormhole experiments:
1. When routing IS the task, routing learns structure
2. Auxiliary losses can be gamed - removed in V2
3. Gradient flow through router is critical - verified
4. Cross-contrastive aligns patch↔scale features
Author: AbstractPhil
Date: November 29, 2025
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from tqdm.auto import tqdm
import time
import math
from pathlib import Path
from typing import Dict, Optional, Tuple, List, Union
from dataclasses import dataclass, field
import json
from datetime import datetime
import os
import shutil
from google.colab import userdata
os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')
HF_TOKEN = userdata.get('HF_TOKEN')
try:
from google.colab import userdata
HF_TOKEN = userdata.get('HF_TOKEN')
os.environ['HF_TOKEN'] = HF_TOKEN
except:
pass
# Import both model versions
from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig
from geofractal.model.david_beans.model_v2 import DavidBeansV2, DavidBeansV2Config
# HuggingFace Hub integration
try:
from huggingface_hub import HfApi, create_repo, upload_folder
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub")
# Safetensors support
try:
from safetensors.torch import save_file as save_safetensors
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
# TensorBoard support
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
print(" [!] tensorboard not installed. Run: pip install tensorboard")
import numpy as np
# ============================================================================
# TRAINING CONFIGURATION V2
# ============================================================================
@dataclass
class TrainingConfigV2:
"""Training configuration for DavidBeans V2 with wormhole routing."""
# Run identification
run_name: str = "default"
run_number: Optional[int] = None
# Model version
model_version: int = 2 # 1 = original, 2 = wormhole
# Data
dataset: str = "cifar100"
image_size: int = 32
batch_size: int = 128
num_workers: int = 4
# Training schedule
epochs: int = 200
warmup_epochs: int = 10
# Optimizer
learning_rate: float = 3e-4
weight_decay: float = 0.05
betas: Tuple[float, float] = (0.9, 0.999)
# Learning rate schedule
scheduler: str = "cosine"
min_lr: float = 1e-6
# Loss weights (based on experimental findings)
ce_weight: float = 1.0
contrast_weight: float = 0.5
# NOTE: No auxiliary routing loss - routing learns from task pressure
# Regularization
gradient_clip: float = 1.0
label_smoothing: float = 0.1
# Augmentation
use_augmentation: bool = True
mixup_alpha: float = 0.2
cutmix_alpha: float = 1.0
# Checkpointing
save_interval: int = 10
output_dir: str = "./checkpoints"
resume_from: Optional[str] = None
# TensorBoard
use_tensorboard: bool = True
log_interval: int = 50
log_routing: bool = True # Log routing patterns
# HuggingFace Hub
push_to_hub: bool = False
hub_repo_id: str = "AbstractPhil/geovit-david-beans"
hub_private: bool = False
# Device
device: str = "cuda" if torch.cuda.is_available() else "cpu"
def to_dict(self) -> Dict:
return {k: v for k, v in self.__dict__.items()}
# ============================================================================
# ROUTING METRICS
# ============================================================================
class RoutingMetrics:
"""Track and analyze wormhole routing patterns."""
def __init__(self):
self.reset()
def reset(self):
self.route_entropies = []
self.route_diversities = []
self.grad_norms = {'query': [], 'key': []}
@torch.no_grad()
def compute_route_entropy(self, soft_routes: torch.Tensor) -> float:
"""Compute average entropy of routing distributions."""
# soft_routes: [B, P, K] or [B, T, K]
# Higher entropy = more diverse routing
eps = 1e-8
entropy = -(soft_routes * (soft_routes + eps).log()).sum(dim=-1)
return entropy.mean().item()
@torch.no_grad()
def compute_route_diversity(self, routes: torch.Tensor, num_positions: int) -> float:
"""Compute how many unique destinations are used."""
# routes: [B, P, K] indices
unique_per_sample = []
for b in range(routes.shape[0]):
unique = routes[b].unique().numel()
unique_per_sample.append(unique / num_positions)
return sum(unique_per_sample) / len(unique_per_sample)
def update_from_routing_info(self, routing_info: List[Dict], model: nn.Module):
"""Extract metrics from routing info returned by V2 model."""
if not routing_info:
return
for layer_info in routing_info:
# Attention routing
if layer_info.get('attention'):
attn = layer_info['attention']
if attn.get('weights') is not None:
entropy = self.compute_route_entropy(attn['weights'])
self.route_entropies.append(entropy)
if attn.get('routes') is not None:
P = attn['routes'].shape[1]
diversity = self.compute_route_diversity(attn['routes'], P)
self.route_diversities.append(diversity)
# Expert routing
if layer_info.get('expert'):
exp = layer_info['expert']
if exp.get('weights') is not None:
entropy = self.compute_route_entropy(exp['weights'])
self.route_entropies.append(entropy)
def update_grad_norms(self, model: nn.Module):
"""Track gradient norms through router projections."""
for name, param in model.named_parameters():
if param.grad is not None:
if 'query_proj' in name and 'weight' in name:
self.grad_norms['query'].append(param.grad.norm().item())
elif 'key_proj' in name and 'weight' in name:
self.grad_norms['key'].append(param.grad.norm().item())
def get_summary(self) -> Dict[str, float]:
"""Get summary statistics."""
summary = {}
if self.route_entropies:
summary['route_entropy'] = sum(self.route_entropies) / len(self.route_entropies)
if self.route_diversities:
summary['route_diversity'] = sum(self.route_diversities) / len(self.route_diversities)
if self.grad_norms['query']:
summary['grad_query'] = sum(self.grad_norms['query']) / len(self.grad_norms['query'])
if self.grad_norms['key']:
summary['grad_key'] = sum(self.grad_norms['key']) / len(self.grad_norms['key'])
return summary
# ============================================================================
# DATA LOADING (unchanged from V1)
# ============================================================================
def get_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
"""Get train and test dataloaders."""
try:
import torchvision
import torchvision.transforms as T
if config.dataset == "cifar10":
mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
num_classes = 10
DatasetClass = torchvision.datasets.CIFAR10
elif config.dataset == "cifar100":
mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
num_classes = 100
DatasetClass = torchvision.datasets.CIFAR100
else:
raise ValueError(f"Unknown dataset: {config.dataset}")
if config.use_augmentation:
train_transform = T.Compose([
T.RandomCrop(32, padding=4),
T.RandomHorizontalFlip(),
T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
T.ToTensor(),
T.Normalize(mean, std)
])
else:
train_transform = T.Compose([
T.ToTensor(),
T.Normalize(mean, std)
])
test_transform = T.Compose([
T.ToTensor(),
T.Normalize(mean, std)
])
train_dataset = DatasetClass(
root='./data', train=True, download=True, transform=train_transform
)
test_dataset = DatasetClass(
root='./data', train=False, download=True, transform=test_transform
)
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=config.num_workers,
pin_memory=True,
persistent_workers=config.num_workers > 0,
drop_last=True
)
test_loader = DataLoader(
test_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
persistent_workers=config.num_workers > 0
)
return train_loader, test_loader, num_classes
except ImportError:
print(" [!] torchvision not available, using synthetic data")
return get_synthetic_dataloaders(config)
def get_synthetic_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
"""Fallback synthetic data for testing."""
class SyntheticDataset(torch.utils.data.Dataset):
def __init__(self, size: int, image_size: int, num_classes: int):
self.size = size
self.image_size = image_size
self.num_classes = num_classes
def __len__(self):
return self.size
def __getitem__(self, idx):
x = torch.randn(3, self.image_size, self.image_size)
y = idx % self.num_classes
return x, y
num_classes = 100 if config.dataset == "cifar100" else 10
train_dataset = SyntheticDataset(5000, config.image_size, num_classes)
test_dataset = SyntheticDataset(1000, config.image_size, num_classes)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
return train_loader, test_loader, num_classes
# ============================================================================
# MIXUP / CUTMIX AUGMENTATION
# ============================================================================
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
"""Mixup augmentation."""
if alpha > 0:
lam = torch.distributions.Beta(alpha, alpha).sample().item()
else:
lam = 1.0
batch_size = x.size(0)
index = torch.randperm(batch_size, device=x.device)
mixed_x = lam * x + (1 - lam) * x[index]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
"""CutMix augmentation."""
if alpha > 0:
lam = torch.distributions.Beta(alpha, alpha).sample().item()
else:
lam = 1.0
batch_size = x.size(0)
index = torch.randperm(batch_size, device=x.device)
_, _, H, W = x.shape
cut_ratio = math.sqrt(1 - lam)
cut_h = int(H * cut_ratio)
cut_w = int(W * cut_ratio)
cx = torch.randint(0, H, (1,)).item()
cy = torch.randint(0, W, (1,)).item()
x1 = max(0, cx - cut_h // 2)
x2 = min(H, cx + cut_h // 2)
y1 = max(0, cy - cut_w // 2)
y2 = min(W, cy + cut_w // 2)
mixed_x = x.clone()
mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2]
lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W)
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
# ============================================================================
# METRICS TRACKER
# ============================================================================
class MetricsTracker:
"""Track training metrics with EMA smoothing."""
def __init__(self, ema_decay: float = 0.9):
self.ema_decay = ema_decay
self.metrics = {}
self.ema_metrics = {}
self.history = {}
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if k not in self.metrics:
self.metrics[k] = []
self.ema_metrics[k] = v
self.history[k] = []
self.metrics[k].append(v)
self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v
def get_ema(self, key: str) -> float:
return self.ema_metrics.get(key, 0.0)
def get_epoch_mean(self, key: str) -> float:
values = self.metrics.get(key, [])
return sum(values) / len(values) if values else 0.0
def end_epoch(self):
for k, v in self.metrics.items():
if v:
self.history[k].append(sum(v) / len(v))
self.metrics = {k: [] for k in self.metrics}
def get_history(self) -> Dict:
return self.history
# ============================================================================
# CHECKPOINT UTILITIES
# ============================================================================
def find_latest_checkpoint(output_dir: Path) -> Optional[Path]:
"""Find the most recent checkpoint in output directory."""
checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt"))
if not checkpoints:
best_model = output_dir / "best_model.pt"
if best_model.exists():
return best_model
return None
def get_epoch(p):
try:
return int(p.stem.split("_")[-1])
except:
return 0
checkpoints.sort(key=get_epoch, reverse=True)
return checkpoints[0]
def get_next_run_number(base_dir: Path) -> int:
"""Get the next run number by scanning existing run directories."""
if not base_dir.exists():
return 1
max_num = 0
for d in base_dir.iterdir():
if d.is_dir() and d.name.startswith("run_"):
try:
num = int(d.name.split("_")[1])
max_num = max(max_num, num)
except (IndexError, ValueError):
continue
return max_num + 1
def generate_run_dir_name(run_number: int, run_name: str, version: int = 2) -> str:
"""Generate a run directory name."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower())
safe_name = "_".join(filter(None, safe_name.split("_")))
return f"run_{run_number:03d}_v{version}_{safe_name}_{timestamp}"
def find_latest_run_dir(base_dir: Path) -> Optional[Path]:
"""Find the most recent run directory."""
if not base_dir.exists():
return None
run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")]
if not run_dirs:
return None
run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True)
return run_dirs[0]
def load_checkpoint(
checkpoint_path: Path,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
device: str = "cuda"
) -> Tuple[int, float]:
"""Load checkpoint and return (start_epoch, best_acc)."""
print(f"\nπŸ“‚ Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f" βœ“ Loaded model weights")
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f" βœ“ Loaded optimizer state")
epoch = checkpoint.get('epoch', 0)
best_acc = checkpoint.get('best_acc', 0.0)
print(f" βœ“ Resuming from epoch {epoch + 1}, best_acc={best_acc:.2f}%")
return epoch + 1, best_acc
# ============================================================================
# HUGGINGFACE HUB INTEGRATION
# ============================================================================
def generate_run_readme(
model_config: Union[DavidBeansConfig, DavidBeansV2Config],
train_config: TrainingConfigV2,
best_acc: float,
run_dir_name: str
) -> str:
"""Generate README for a specific run."""
scales_str = ", ".join([str(s) for s in model_config.scales])
# V2 specific info
if isinstance(model_config, DavidBeansV2Config):
routing_info = f"""
## Wormhole Routing (V2)
| Parameter | Value |
|-----------|-------|
| Mode | {model_config.wormhole_mode} |
| Wormholes/Position | {model_config.num_wormholes} |
| Temperature | {model_config.wormhole_temperature} |
| Tiles | {model_config.num_tiles} |
| Tile Wormholes | {model_config.tile_wormholes} |
"""
else:
routing_info = """
## Routing (V1)
| Parameter | Value |
|-----------|-------|
| k_neighbors | {model_config.k_neighbors} |
| Cantor Weight | {model_config.cantor_weight} |
"""
return f"""# Run: {run_dir_name}
## Results
- **Best Accuracy**: {best_acc:.2f}%
- **Dataset**: {train_config.dataset}
- **Epochs**: {train_config.epochs}
- **Model Version**: V{train_config.model_version}
## Model Config
| Parameter | Value |
|-----------|-------|
| Dim | {model_config.dim} |
| Layers | {model_config.num_layers} |
| Heads | {model_config.num_heads} |
| Scales | [{scales_str}] |
{routing_info}
## Training Config
| Parameter | Value |
|-----------|-------|
| Learning Rate | {train_config.learning_rate} |
| Weight Decay | {train_config.weight_decay} |
| Batch Size | {train_config.batch_size} |
| CE Weight | {train_config.ce_weight} |
| Contrast Weight | {train_config.contrast_weight} |
## Key Findings Applied
- Routing learns from task pressure (no auxiliary routing losses)
- Gradients verified to flow through router
- Cross-contrastive aligns patch↔scale features
"""
def prepare_run_for_hub(
model: nn.Module,
model_config: Union[DavidBeansConfig, DavidBeansV2Config],
train_config: TrainingConfigV2,
best_acc: float,
output_dir: Path,
run_dir_name: str,
training_history: Optional[Dict] = None
) -> Path:
"""Prepare run files for upload to HuggingFace Hub."""
hub_dir = output_dir / "hub_upload"
run_hub_dir = hub_dir / "weights" / run_dir_name
run_hub_dir.mkdir(parents=True, exist_ok=True)
# Save best model weights
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
if SAFETENSORS_AVAILABLE:
try:
save_safetensors(state_dict, run_hub_dir / "best.safetensors")
print(f" βœ“ Saved best.safetensors")
except Exception as e:
print(f" [!] Safetensors failed ({e}), using pytorch format")
torch.save(state_dict, run_hub_dir / "best.pt")
else:
torch.save(state_dict, run_hub_dir / "best.pt")
# Save model config
config_dict = {
"architecture": f"DavidBeans_V{train_config.model_version}",
"model_type": "david_beans_v2" if train_config.model_version == 2 else "david_beans",
**model_config.__dict__
}
with open(run_hub_dir / "config.json", "w") as f:
json.dump(config_dict, f, indent=2, default=str)
# Save training config
with open(run_hub_dir / "training_config.json", "w") as f:
json.dump(train_config.to_dict(), f, indent=2, default=str)
# Generate README
run_readme = generate_run_readme(model_config, train_config, best_acc, run_dir_name)
with open(run_hub_dir / "README.md", "w") as f:
f.write(run_readme)
# Save training history
if training_history:
with open(run_hub_dir / "training_history.json", "w") as f:
json.dump(training_history, f, indent=2)
# Copy TensorBoard logs
tb_dir = output_dir / "tensorboard"
if tb_dir.exists():
hub_tb_dir = run_hub_dir / "tensorboard"
if hub_tb_dir.exists():
shutil.rmtree(hub_tb_dir)
shutil.copytree(tb_dir, hub_tb_dir)
return hub_dir
def push_run_to_hub(
hub_dir: Path,
repo_id: str,
run_dir_name: str,
private: bool = False,
commit_message: Optional[str] = None
) -> str:
"""Push run files to HuggingFace Hub."""
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub not installed")
api = HfApi()
try:
create_repo(repo_id, private=private, exist_ok=True)
except Exception as e:
print(f" [!] Repo creation note: {e}")
run_upload_dir = hub_dir / "weights" / run_dir_name
if commit_message is None:
commit_message = f"Update {run_dir_name} - {datetime.now().strftime('%Y-%m-%d %H:%M')}"
url = upload_folder(
folder_path=str(run_upload_dir),
repo_id=repo_id,
path_in_repo=f"weights/{run_dir_name}",
commit_message=commit_message
)
return url
# ============================================================================
# TRAINING LOOP V2
# ============================================================================
def train_epoch_v2(
model: nn.Module,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
config: TrainingConfigV2,
epoch: int,
tracker: MetricsTracker,
routing_metrics: RoutingMetrics,
writer: Optional['SummaryWriter'] = None
) -> Dict[str, float]:
"""Train for one epoch with V2 routing metrics."""
model.train()
device = config.device
is_v2 = config.model_version == 2
total_loss = 0.0
total_correct = 0
total_samples = 0
global_step = epoch * len(train_loader)
routing_metrics.reset()
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)
for batch_idx, (images, targets) in enumerate(pbar):
images = images.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
# Apply mixup/cutmix
use_mixup = config.use_augmentation and config.mixup_alpha > 0
use_cutmix = config.use_augmentation and config.cutmix_alpha > 0
mixed = False
if use_mixup or use_cutmix:
r = torch.rand(1).item()
if r < 0.5:
pass
elif r < 0.75 and use_mixup:
images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha)
mixed = True
elif use_cutmix:
images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha)
mixed = True
# Forward pass
if is_v2:
result = model(
images,
targets=targets,
return_loss=True,
return_routing=(batch_idx % 10 == 0) # Sample routing every 10 batches
)
else:
result = model(images, targets=targets, return_loss=True)
losses = result['losses']
# Handle mixup CE loss
if mixed:
logits = result['logits']
ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \
(1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing)
losses['ce'] = ce_loss
# Compute total loss (NO auxiliary routing loss - key finding!)
loss = (
config.ce_weight * losses['ce'] +
config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device))
)
# Add scale CE losses
for scale in model.config.scales:
scale_ce = losses.get(f'ce_{scale}', 0.0)
if isinstance(scale_ce, torch.Tensor):
loss = loss + 0.1 * scale_ce
# Backward pass
optimizer.zero_grad()
loss.backward()
# Track routing gradient norms (verify gradients flow!)
if is_v2:
routing_metrics.update_grad_norms(model)
if config.gradient_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
else:
grad_norm = 0.0
optimizer.step()
if scheduler is not None and config.scheduler == "onecycle":
scheduler.step()
# Update routing metrics from forward pass
if is_v2 and result.get('routing'):
routing_metrics.update_from_routing_info(result['routing'], model)
# Compute accuracy
with torch.no_grad():
logits = result['logits']
preds = logits.argmax(dim=-1)
if mixed:
correct = (lam * (preds == targets_a).float() +
(1 - lam) * (preds == targets_b).float()).sum()
else:
correct = (preds == targets).sum()
total_correct += correct.item()
total_samples += targets.size(0)
total_loss += loss.item()
# Track metrics
def to_float(v):
return v.item() if isinstance(v, torch.Tensor) else float(v)
contrast_loss = to_float(losses.get('contrast', 0.0))
current_lr = optimizer.param_groups[0]['lr']
tracker.update(
loss=loss.item(),
ce=losses['ce'].item(),
contrast=contrast_loss,
lr=current_lr
)
# TensorBoard logging
if writer is not None and (batch_idx + 1) % config.log_interval == 0:
step = global_step + batch_idx
writer.add_scalar('train/loss_total', loss.item(), step)
writer.add_scalar('train/loss_ce', losses['ce'].item(), step)
writer.add_scalar('train/loss_contrast', contrast_loss, step)
writer.add_scalar('train/learning_rate', current_lr, step)
writer.add_scalar('train/grad_norm', to_float(grad_norm), step)
# Log routing metrics for V2
if is_v2 and config.log_routing:
routing_summary = routing_metrics.get_summary()
for k, v in routing_summary.items():
writer.add_scalar(f'routing/{k}', v, step)
# Progress bar
routing_summary = routing_metrics.get_summary()
postfix = {
'loss': f"{tracker.get_ema('loss'):.3f}",
'acc': f"{100.0 * total_correct / total_samples:.1f}%",
}
if is_v2 and 'grad_query' in routing_summary:
postfix['βˆ‡q'] = f"{routing_summary['grad_query']:.2f}"
if 'route_entropy' in routing_summary:
postfix['H'] = f"{routing_summary['route_entropy']:.2f}"
pbar.set_postfix(postfix)
if scheduler is not None and config.scheduler == "cosine":
scheduler.step()
return {
'loss': total_loss / len(train_loader),
'acc': 100.0 * total_correct / total_samples,
**routing_metrics.get_summary()
}
@torch.no_grad()
def evaluate_v2(
model: nn.Module,
test_loader: DataLoader,
config: TrainingConfigV2
) -> Dict[str, float]:
"""Evaluate on test set."""
model.eval()
device = config.device
total_loss = 0.0
total_correct = 0
total_samples = 0
scale_correct = {s: 0 for s in model.config.scales}
for images, targets in test_loader:
images = images.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
result = model(images, targets=targets, return_loss=True)
logits = result['logits']
losses = result['losses']
loss = losses['total']
preds = logits.argmax(dim=-1)
total_loss += loss.item() * targets.size(0)
total_correct += (preds == targets).sum().item()
total_samples += targets.size(0)
for i, scale in enumerate(model.config.scales):
scale_logits = result['scale_logits'][i]
scale_preds = scale_logits.argmax(dim=-1)
scale_correct[scale] += (scale_preds == targets).sum().item()
metrics = {
'loss': total_loss / total_samples,
'acc': 100.0 * total_correct / total_samples
}
for scale, correct in scale_correct.items():
metrics[f'acc_{scale}'] = 100.0 * correct / total_samples
return metrics
# ============================================================================
# MAIN TRAINING FUNCTION V2
# ============================================================================
def train_david_beans_v2(
model_config: Optional[Union[DavidBeansConfig, DavidBeansV2Config]] = None,
train_config: Optional[TrainingConfigV2] = None
):
"""Main training function for DavidBeans V1 or V2."""
print("=" * 70)
print(" DAVID-BEANS V2 TRAINING: Wormhole Routing")
print("=" * 70)
print()
print(" πŸŒ€ WORMHOLES: Learned sparse routing")
print(" πŸ’Ž CRYSTALS: Multi-scale projection")
print()
print(" Key insight: When routing IS the task, routing learns structure")
print()
print("=" * 70)
if train_config is None:
train_config = TrainingConfigV2()
base_output_dir = Path(train_config.output_dir)
base_output_dir.mkdir(parents=True, exist_ok=True)
# =========================================================================
# FIXED: Proper checkpoint resolution
# =========================================================================
checkpoint_path = None
run_dir = None
run_dir_name = None
if train_config.resume_from:
resume_path = Path(train_config.resume_from)
# Case 1: Direct absolute/relative file path
if resume_path.is_file():
checkpoint_path = resume_path
run_dir = checkpoint_path.parent
run_dir_name = run_dir.name
print(f"\nπŸ“‚ Found checkpoint file: {checkpoint_path.name}")
# Case 2: Directory path - find best/latest checkpoint inside
elif resume_path.is_dir():
checkpoint_path = find_latest_checkpoint(resume_path)
if checkpoint_path:
run_dir = resume_path
run_dir_name = resume_path.name
print(f"\nπŸ“‚ Found checkpoint in dir: {checkpoint_path.name}")
# Case 3: Try as path relative to base_output_dir
else:
# Try as subdirectory name
possible_dir = base_output_dir / train_config.resume_from
if possible_dir.is_dir():
checkpoint_path = find_latest_checkpoint(possible_dir)
if checkpoint_path:
run_dir = possible_dir
run_dir_name = possible_dir.name
print(f"\nπŸ“‚ Found checkpoint in: {run_dir_name}")
# Try as relative file path
if checkpoint_path is None:
possible_file = base_output_dir / train_config.resume_from
if possible_file.is_file():
checkpoint_path = possible_file
run_dir = checkpoint_path.parent
run_dir_name = run_dir.name
print(f"\nπŸ“‚ Found checkpoint: {checkpoint_path.name}")
# Report if not found
if checkpoint_path is None:
print(f"\n [!] Could not find checkpoint: {train_config.resume_from}")
print(f" [!] Checked:")
print(f" - As file: {resume_path}")
print(f" - As dir: {resume_path}")
print(f" - Under {base_output_dir}")
print(f" [!] Starting fresh run instead")
else:
print(f" βœ“ Will resume from: {checkpoint_path}")
# Create new run directory if not resuming
if run_dir is None:
run_number = train_config.run_number or get_next_run_number(base_output_dir)
run_dir_name = generate_run_dir_name(run_number, train_config.run_name, train_config.model_version)
run_dir = base_output_dir / run_dir_name
run_dir.mkdir(parents=True, exist_ok=True)
print(f"\nπŸ“ New run: {run_dir_name}")
else:
print(f"\nπŸ“ Resuming run: {run_dir_name}")
output_dir = run_dir
# =========================================================================
# Model config - load from checkpoint if resuming, else use provided/default
# =========================================================================
if checkpoint_path and checkpoint_path.exists() and model_config is None:
# Try to load config from checkpoint
try:
ckpt = torch.load(checkpoint_path, map_location='cpu')
if 'model_config' in ckpt:
saved_config = ckpt['model_config']
print(f" βœ“ Loading model config from checkpoint")
if train_config.model_version == 2:
model_config = DavidBeansV2Config(**saved_config)
else:
model_config = DavidBeansConfig(**saved_config)
except Exception as e:
print(f" [!] Could not load config from checkpoint: {e}")
# Create default config if still None
if model_config is None:
if train_config.model_version == 2:
model_config = DavidBeansV2Config(
image_size=train_config.image_size,
patch_size=4,
dim=512,
num_layers=4,
num_heads=8,
num_wormholes=8,
wormhole_temperature=0.1,
wormhole_mode="hybrid",
num_tiles=16,
tile_wormholes=4,
scales=[64, 128, 256, 384, 512],
num_classes=100,
contrast_weight=train_config.contrast_weight,
dropout=0.1
)
else:
model_config = DavidBeansConfig(
image_size=train_config.image_size,
patch_size=4,
dim=512,
num_layers=4,
num_heads=8,
num_experts=5,
k_neighbors=16,
cantor_weight=0.3,
scales=[64, 128, 256, 384, 512],
num_classes=100,
dropout=0.1
)
device = train_config.device
print(f"\nDevice: {device}")
print(f"Model version: V{train_config.model_version}")
# Data
print("\nLoading data...")
train_loader, test_loader, num_classes = get_dataloaders(train_config)
print(f" Dataset: {train_config.dataset}")
print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
print(f" Classes: {num_classes}")
model_config.num_classes = num_classes
# Model
print("\nBuilding model...")
if train_config.model_version == 2:
model = DavidBeansV2(model_config)
else:
model = DavidBeans(model_config)
model = model.to(device)
print(f"\n{model}")
num_params = sum(p.numel() for p in model.parameters())
print(f"\nParameters: {num_params:,}")
# Optimizer
print("\nSetting up optimizer...")
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if 'bias' in name or 'norm' in name or 'embedding' in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer = AdamW([
{'params': decay_params, 'weight_decay': train_config.weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
], lr=train_config.learning_rate, betas=train_config.betas)
if train_config.scheduler == "cosine":
scheduler = CosineAnnealingLR(
optimizer,
T_max=train_config.epochs - train_config.warmup_epochs,
eta_min=train_config.min_lr
)
elif train_config.scheduler == "onecycle":
scheduler = OneCycleLR(
optimizer,
max_lr=train_config.learning_rate,
epochs=train_config.epochs,
steps_per_epoch=len(train_loader),
pct_start=train_config.warmup_epochs / train_config.epochs
)
else:
scheduler = None
print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})")
print(f" Scheduler: {train_config.scheduler}")
tracker = MetricsTracker()
routing_metrics = RoutingMetrics()
best_acc = 0.0
start_epoch = 0
# =========================================================================
# Load checkpoint weights and optimizer state
# =========================================================================
if checkpoint_path and checkpoint_path.exists():
start_epoch, best_acc = load_checkpoint(checkpoint_path, model, optimizer, device)
# Advance scheduler to correct position
if scheduler is not None and train_config.scheduler == "cosine":
for _ in range(start_epoch):
scheduler.step()
print(f" βœ“ Advanced scheduler to epoch {start_epoch}")
# TensorBoard
writer = None
if train_config.use_tensorboard and TENSORBOARD_AVAILABLE:
tb_dir = output_dir / "tensorboard"
tb_dir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(log_dir=str(tb_dir))
print(f" TensorBoard: {tb_dir}")
# Save configs
with open(output_dir / "config.json", "w") as f:
json.dump({**model_config.__dict__, "architecture": f"DavidBeans_V{train_config.model_version}"},
f, indent=2, default=str)
with open(output_dir / "training_config.json", "w") as f:
json.dump(train_config.to_dict(), f, indent=2, default=str)
# Training loop
print("\n" + "=" * 70)
print(" TRAINING")
print("=" * 70)
for epoch in range(start_epoch, train_config.epochs):
epoch_start = time.time()
# Warmup
if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine":
warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs
for param_group in optimizer.param_groups:
param_group['lr'] = warmup_lr
train_metrics = train_epoch_v2(
model, train_loader, optimizer, scheduler,
train_config, epoch, tracker, routing_metrics, writer
)
test_metrics = evaluate_v2(model, test_loader, train_config)
epoch_time = time.time() - epoch_start
# TensorBoard
if writer is not None:
writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch)
writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch)
writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch)
for scale in model.config.scales:
writer.add_scalar(f'scales/acc_{scale}', test_metrics[f'acc_{scale}'], epoch)
# Print summary - show ALL scales
scale_accs = " | ".join([f"{s}:{test_metrics[f'acc_{s}']:.1f}%" for s in model.config.scales])
star = "β˜…" if test_metrics['acc'] > best_acc else ""
routing_info = ""
if train_config.model_version == 2 and 'grad_query' in train_metrics:
routing_info = f" | βˆ‡q:{train_metrics.get('grad_query', 0):.2f}"
print(f" β†’ Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | "
f"[{scale_accs}]{routing_info} | {epoch_time:.0f}s {star}")
# Save best model
if test_metrics['acc'] > best_acc:
best_acc = test_metrics['acc']
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_acc': best_acc,
'model_config': model_config.__dict__,
'train_config': train_config.to_dict()
}, output_dir / "best_model.pt")
# Periodic checkpoint
if (epoch + 1) % train_config.save_interval == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_acc': best_acc,
'model_config': model_config.__dict__,
'train_config': train_config.to_dict()
}, output_dir / f"checkpoint_epoch_{epoch + 1}.pt")
# Upload to hub
if train_config.push_to_hub and HF_HUB_AVAILABLE:
try:
hub_dir = prepare_run_for_hub(
model=model,
model_config=model_config,
train_config=train_config,
best_acc=best_acc,
output_dir=output_dir,
run_dir_name=run_dir_name,
training_history=tracker.get_history()
)
push_run_to_hub(
hub_dir=hub_dir,
repo_id=train_config.hub_repo_id,
run_dir_name=run_dir_name,
commit_message=f"Epoch {epoch + 1} - {best_acc:.2f}% acc"
)
print(f" πŸ“€ Uploaded to hub")
except Exception as e:
print(f" [!] Hub upload failed: {e}")
tracker.end_epoch()
# Final summary
print("\n" + "=" * 70)
print(" TRAINING COMPLETE")
print("=" * 70)
print(f"\n Best Test Accuracy: {best_acc:.2f}%")
print(f" Model saved to: {output_dir / 'best_model.pt'}")
if writer is not None:
writer.close()
return model, best_acc
# ============================================================================
# PRESETS
# ============================================================================
def train_cifar100_v2_wormhole(
run_name: str = "wormhole_base",
push_to_hub: bool = False,
resume: bool = False
):
"""CIFAR-100 with V2 wormhole routing."""
model_config = DavidBeansV2Config(
image_size=32,
patch_size=2,
dim=512,
num_layers=4,
num_heads=16,
# Wormhole routing parameters
num_wormholes=16,
wormhole_temperature=0.1,
wormhole_mode="hybrid",
# Tessellation parameters
num_tiles=16,
tile_wormholes=4,
# Crystal head
scales=[64, 128, 256, 512, 1024],
num_classes=100,
contrast_temperature=0.07,
contrast_weight=0.5,
dropout=0.1
)
train_config = TrainingConfigV2(
run_name=run_name,
model_version=2,
dataset="cifar100",
epochs=300,
batch_size=512,
learning_rate=3e-4,
weight_decay=0.05,
warmup_epochs=15,
# Loss weights (no auxiliary routing loss!)
ce_weight=1.0,
contrast_weight=0.5,
# Augmentation
label_smoothing=0.1,
mixup_alpha=0.2,
cutmix_alpha=1.0,
# Output
output_dir="./checkpoints/cifar100_v2",
resume_from=None, #"./checkpoints/cifar100_v2/run_002_v2_16patch_4tilewormholes_d768_4layer_20251130_045437/best_model.pt",
# Hub
push_to_hub=push_to_hub,
hub_repo_id="AbstractPhil/geovit-david-beans",
# Routing logging
log_routing=True
)
return train_david_beans_v2(model_config, train_config)
def train_cifar100_v1_baseline(
run_name: str = "v1_baseline",
push_to_hub: bool = False,
resume: bool = False
):
"""CIFAR-100 with V1 (fixed Cantor routing) for comparison."""
model_config = DavidBeansConfig(
image_size=32,
patch_size=4,
dim=512,
num_layers=4,
num_heads=8,
num_experts=5,
k_neighbors=16,
cantor_weight=0.3,
scales=[64, 128, 256, 384, 512],
num_classes=100,
dropout=0.1
)
train_config = TrainingConfigV2(
run_name=run_name,
model_version=1,
dataset="cifar100",
epochs=200,
batch_size=128,
learning_rate=3e-4,
weight_decay=0.05,
warmup_epochs=10,
ce_weight=1.0,
contrast_weight=0.5,
label_smoothing=0.1,
mixup_alpha=0.2,
cutmix_alpha=1.0,
output_dir="./checkpoints/cifar100_v1",
resume_from="latest" if resume else None,
push_to_hub=push_to_hub,
hub_repo_id="AbstractPhil/geovit-david-beans",
log_routing=False
)
return train_david_beans_v2(model_config, train_config)
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
# =====================================================
# CONFIGURATION
# =====================================================
PRESET = "v2_wormhole" # "v1_baseline", "v2_wormhole", "test"
RESUME = False
RUN_NAME = "5scale_2x2patch_4tilewormholes_d512_4layer"
PUSH_TO_HUB = True
# =====================================================
# RUN
# =====================================================
if PRESET == "test":
print("πŸ§ͺ Quick test...")
model_config = DavidBeansV2Config(
image_size=32, patch_size=4, dim=128, num_layers=2,
num_heads=4, num_wormholes=4, num_tiles=8,
scales=[32, 64, 128], num_classes=10
)
train_config = TrainingConfigV2(
run_name="test", model_version=2,
epochs=2, batch_size=32,
use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0
)
model, acc = train_david_beans_v2(model_config, train_config)
elif PRESET == "v1_baseline":
print("πŸ«˜πŸ’Ž Training DavidBeans V1 (Cantor routing)...")
model, acc = train_cifar100_v1_baseline(
run_name=RUN_NAME,
push_to_hub=PUSH_TO_HUB,
resume=RESUME
)
elif PRESET == "v2_wormhole":
print("πŸ’Ž Training DavidBeans V2 (Wormhole routing)...")
model, acc = train_cifar100_v2_wormhole(
run_name=RUN_NAME,
push_to_hub=PUSH_TO_HUB,
resume=RESUME
)
else:
raise ValueError(f"Unknown preset: {PRESET}")
print(f"\nπŸŽ‰ Done! Best accuracy: {acc:.2f}%")