""" Train DavidBeans V2: Wormhole Routing Architecture =================================================== ┌─────────────────┐ │ BEANS V2 │ "I learn where to look..." │ (Wormhole ViT)│ │ 🌀 → 🌀 → 🌀 │ Learned sparse routing └────────┬────────┘ │ ▼ ┌─────────────────┐ │ DAVID │ "I know the crystals..." │ (Classifier) │ │ 💎 → 💎 → 💎 │ Multi-scale projection └────────┬────────┘ │ ▼ [Prediction] Key findings from wormhole experiments: 1. When routing IS the task, routing learns structure 2. Auxiliary losses can be gamed - removed in V2 3. Gradient flow through router is critical - verified 4. Cross-contrastive aligns patch↔scale features Author: AbstractPhil Date: November 29, 2025 """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR from tqdm.auto import tqdm import time import math from pathlib import Path from typing import Dict, Optional, Tuple, List, Union from dataclasses import dataclass, field import json from datetime import datetime import os import shutil from google.colab import userdata os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN') HF_TOKEN = userdata.get('HF_TOKEN') try: from google.colab import userdata HF_TOKEN = userdata.get('HF_TOKEN') os.environ['HF_TOKEN'] = HF_TOKEN except: pass # Import both model versions from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig from geofractal.model.david_beans.model_v2 import DavidBeansV2, DavidBeansV2Config # HuggingFace Hub integration try: from huggingface_hub import HfApi, create_repo, upload_folder HF_HUB_AVAILABLE = True except ImportError: HF_HUB_AVAILABLE = False print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub") # Safetensors support try: from safetensors.torch import save_file as save_safetensors SAFETENSORS_AVAILABLE = True except ImportError: SAFETENSORS_AVAILABLE = False # TensorBoard support try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_AVAILABLE = True except ImportError: TENSORBOARD_AVAILABLE = False print(" [!] tensorboard not installed. Run: pip install tensorboard") import numpy as np # ============================================================================ # TRAINING CONFIGURATION V2 # ============================================================================ @dataclass class TrainingConfigV2: """Training configuration for DavidBeans V2 with wormhole routing.""" # Run identification run_name: str = "default" run_number: Optional[int] = None # Model version model_version: int = 2 # 1 = original, 2 = wormhole # Data dataset: str = "cifar100" image_size: int = 32 batch_size: int = 128 num_workers: int = 4 # Training schedule epochs: int = 200 warmup_epochs: int = 10 # Optimizer learning_rate: float = 3e-4 weight_decay: float = 0.05 betas: Tuple[float, float] = (0.9, 0.999) # Learning rate schedule scheduler: str = "cosine" min_lr: float = 1e-6 # Loss weights (based on experimental findings) ce_weight: float = 1.0 contrast_weight: float = 0.5 # NOTE: No auxiliary routing loss - routing learns from task pressure # Regularization gradient_clip: float = 1.0 label_smoothing: float = 0.1 # Augmentation use_augmentation: bool = True mixup_alpha: float = 0.2 cutmix_alpha: float = 1.0 # Checkpointing save_interval: int = 10 output_dir: str = "./checkpoints" resume_from: Optional[str] = None # TensorBoard use_tensorboard: bool = True log_interval: int = 50 log_routing: bool = True # Log routing patterns # HuggingFace Hub push_to_hub: bool = False hub_repo_id: str = "AbstractPhil/geovit-david-beans" hub_private: bool = False # Device device: str = "cuda" if torch.cuda.is_available() else "cpu" def to_dict(self) -> Dict: return {k: v for k, v in self.__dict__.items()} # ============================================================================ # ROUTING METRICS # ============================================================================ class RoutingMetrics: """Track and analyze wormhole routing patterns.""" def __init__(self): self.reset() def reset(self): self.route_entropies = [] self.route_diversities = [] self.grad_norms = {'query': [], 'key': []} @torch.no_grad() def compute_route_entropy(self, soft_routes: torch.Tensor) -> float: """Compute average entropy of routing distributions.""" # soft_routes: [B, P, K] or [B, T, K] # Higher entropy = more diverse routing eps = 1e-8 entropy = -(soft_routes * (soft_routes + eps).log()).sum(dim=-1) return entropy.mean().item() @torch.no_grad() def compute_route_diversity(self, routes: torch.Tensor, num_positions: int) -> float: """Compute how many unique destinations are used.""" # routes: [B, P, K] indices unique_per_sample = [] for b in range(routes.shape[0]): unique = routes[b].unique().numel() unique_per_sample.append(unique / num_positions) return sum(unique_per_sample) / len(unique_per_sample) def update_from_routing_info(self, routing_info: List[Dict], model: nn.Module): """Extract metrics from routing info returned by V2 model.""" if not routing_info: return for layer_info in routing_info: # Attention routing if layer_info.get('attention'): attn = layer_info['attention'] if attn.get('weights') is not None: entropy = self.compute_route_entropy(attn['weights']) self.route_entropies.append(entropy) if attn.get('routes') is not None: P = attn['routes'].shape[1] diversity = self.compute_route_diversity(attn['routes'], P) self.route_diversities.append(diversity) # Expert routing if layer_info.get('expert'): exp = layer_info['expert'] if exp.get('weights') is not None: entropy = self.compute_route_entropy(exp['weights']) self.route_entropies.append(entropy) def update_grad_norms(self, model: nn.Module): """Track gradient norms through router projections.""" for name, param in model.named_parameters(): if param.grad is not None: if 'query_proj' in name and 'weight' in name: self.grad_norms['query'].append(param.grad.norm().item()) elif 'key_proj' in name and 'weight' in name: self.grad_norms['key'].append(param.grad.norm().item()) def get_summary(self) -> Dict[str, float]: """Get summary statistics.""" summary = {} if self.route_entropies: summary['route_entropy'] = sum(self.route_entropies) / len(self.route_entropies) if self.route_diversities: summary['route_diversity'] = sum(self.route_diversities) / len(self.route_diversities) if self.grad_norms['query']: summary['grad_query'] = sum(self.grad_norms['query']) / len(self.grad_norms['query']) if self.grad_norms['key']: summary['grad_key'] = sum(self.grad_norms['key']) / len(self.grad_norms['key']) return summary # ============================================================================ # DATA LOADING (unchanged from V1) # ============================================================================ def get_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]: """Get train and test dataloaders.""" try: import torchvision import torchvision.transforms as T if config.dataset == "cifar10": mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) num_classes = 10 DatasetClass = torchvision.datasets.CIFAR10 elif config.dataset == "cifar100": mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) num_classes = 100 DatasetClass = torchvision.datasets.CIFAR100 else: raise ValueError(f"Unknown dataset: {config.dataset}") if config.use_augmentation: train_transform = T.Compose([ T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.AutoAugment(T.AutoAugmentPolicy.CIFAR10), T.ToTensor(), T.Normalize(mean, std) ]) else: train_transform = T.Compose([ T.ToTensor(), T.Normalize(mean, std) ]) test_transform = T.Compose([ T.ToTensor(), T.Normalize(mean, std) ]) train_dataset = DatasetClass( root='./data', train=True, download=True, transform=train_transform ) test_dataset = DatasetClass( root='./data', train=False, download=True, transform=test_transform ) train_loader = DataLoader( train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True, persistent_workers=config.num_workers > 0, drop_last=True ) test_loader = DataLoader( test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True, persistent_workers=config.num_workers > 0 ) return train_loader, test_loader, num_classes except ImportError: print(" [!] torchvision not available, using synthetic data") return get_synthetic_dataloaders(config) def get_synthetic_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]: """Fallback synthetic data for testing.""" class SyntheticDataset(torch.utils.data.Dataset): def __init__(self, size: int, image_size: int, num_classes: int): self.size = size self.image_size = image_size self.num_classes = num_classes def __len__(self): return self.size def __getitem__(self, idx): x = torch.randn(3, self.image_size, self.image_size) y = idx % self.num_classes return x, y num_classes = 100 if config.dataset == "cifar100" else 10 train_dataset = SyntheticDataset(5000, config.image_size, num_classes) test_dataset = SyntheticDataset(1000, config.image_size, num_classes) train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False) return train_loader, test_loader, num_classes # ============================================================================ # MIXUP / CUTMIX AUGMENTATION # ============================================================================ def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2): """Mixup augmentation.""" if alpha > 0: lam = torch.distributions.Beta(alpha, alpha).sample().item() else: lam = 1.0 batch_size = x.size(0) index = torch.randperm(batch_size, device=x.device) mixed_x = lam * x + (1 - lam) * x[index] y_a, y_b = y, y[index] return mixed_x, y_a, y_b, lam def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0): """CutMix augmentation.""" if alpha > 0: lam = torch.distributions.Beta(alpha, alpha).sample().item() else: lam = 1.0 batch_size = x.size(0) index = torch.randperm(batch_size, device=x.device) _, _, H, W = x.shape cut_ratio = math.sqrt(1 - lam) cut_h = int(H * cut_ratio) cut_w = int(W * cut_ratio) cx = torch.randint(0, H, (1,)).item() cy = torch.randint(0, W, (1,)).item() x1 = max(0, cx - cut_h // 2) x2 = min(H, cx + cut_h // 2) y1 = max(0, cy - cut_w // 2) y2 = min(W, cy + cut_w // 2) mixed_x = x.clone() mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2] lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W) y_a, y_b = y, y[index] return mixed_x, y_a, y_b, lam # ============================================================================ # METRICS TRACKER # ============================================================================ class MetricsTracker: """Track training metrics with EMA smoothing.""" def __init__(self, ema_decay: float = 0.9): self.ema_decay = ema_decay self.metrics = {} self.ema_metrics = {} self.history = {} def update(self, **kwargs): for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() if k not in self.metrics: self.metrics[k] = [] self.ema_metrics[k] = v self.history[k] = [] self.metrics[k].append(v) self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v def get_ema(self, key: str) -> float: return self.ema_metrics.get(key, 0.0) def get_epoch_mean(self, key: str) -> float: values = self.metrics.get(key, []) return sum(values) / len(values) if values else 0.0 def end_epoch(self): for k, v in self.metrics.items(): if v: self.history[k].append(sum(v) / len(v)) self.metrics = {k: [] for k in self.metrics} def get_history(self) -> Dict: return self.history # ============================================================================ # CHECKPOINT UTILITIES # ============================================================================ def find_latest_checkpoint(output_dir: Path) -> Optional[Path]: """Find the most recent checkpoint in output directory.""" checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt")) if not checkpoints: best_model = output_dir / "best_model.pt" if best_model.exists(): return best_model return None def get_epoch(p): try: return int(p.stem.split("_")[-1]) except: return 0 checkpoints.sort(key=get_epoch, reverse=True) return checkpoints[0] def get_next_run_number(base_dir: Path) -> int: """Get the next run number by scanning existing run directories.""" if not base_dir.exists(): return 1 max_num = 0 for d in base_dir.iterdir(): if d.is_dir() and d.name.startswith("run_"): try: num = int(d.name.split("_")[1]) max_num = max(max_num, num) except (IndexError, ValueError): continue return max_num + 1 def generate_run_dir_name(run_number: int, run_name: str, version: int = 2) -> str: """Generate a run directory name.""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower()) safe_name = "_".join(filter(None, safe_name.split("_"))) return f"run_{run_number:03d}_v{version}_{safe_name}_{timestamp}" def find_latest_run_dir(base_dir: Path) -> Optional[Path]: """Find the most recent run directory.""" if not base_dir.exists(): return None run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")] if not run_dirs: return None run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True) return run_dirs[0] def load_checkpoint( checkpoint_path: Path, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, device: str = "cuda" ) -> Tuple[int, float]: """Load checkpoint and return (start_epoch, best_acc).""" print(f"\n📂 Loading checkpoint: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location=device) model.load_state_dict(checkpoint['model_state_dict']) print(f" ✓ Loaded model weights") if optimizer is not None and 'optimizer_state_dict' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) print(f" ✓ Loaded optimizer state") epoch = checkpoint.get('epoch', 0) best_acc = checkpoint.get('best_acc', 0.0) print(f" ✓ Resuming from epoch {epoch + 1}, best_acc={best_acc:.2f}%") return epoch + 1, best_acc # ============================================================================ # HUGGINGFACE HUB INTEGRATION # ============================================================================ def generate_run_readme( model_config: Union[DavidBeansConfig, DavidBeansV2Config], train_config: TrainingConfigV2, best_acc: float, run_dir_name: str ) -> str: """Generate README for a specific run.""" scales_str = ", ".join([str(s) for s in model_config.scales]) # V2 specific info if isinstance(model_config, DavidBeansV2Config): routing_info = f""" ## Wormhole Routing (V2) | Parameter | Value | |-----------|-------| | Mode | {model_config.wormhole_mode} | | Wormholes/Position | {model_config.num_wormholes} | | Temperature | {model_config.wormhole_temperature} | | Tiles | {model_config.num_tiles} | | Tile Wormholes | {model_config.tile_wormholes} | """ else: routing_info = """ ## Routing (V1) | Parameter | Value | |-----------|-------| | k_neighbors | {model_config.k_neighbors} | | Cantor Weight | {model_config.cantor_weight} | """ return f"""# Run: {run_dir_name} ## Results - **Best Accuracy**: {best_acc:.2f}% - **Dataset**: {train_config.dataset} - **Epochs**: {train_config.epochs} - **Model Version**: V{train_config.model_version} ## Model Config | Parameter | Value | |-----------|-------| | Dim | {model_config.dim} | | Layers | {model_config.num_layers} | | Heads | {model_config.num_heads} | | Scales | [{scales_str}] | {routing_info} ## Training Config | Parameter | Value | |-----------|-------| | Learning Rate | {train_config.learning_rate} | | Weight Decay | {train_config.weight_decay} | | Batch Size | {train_config.batch_size} | | CE Weight | {train_config.ce_weight} | | Contrast Weight | {train_config.contrast_weight} | ## Key Findings Applied - Routing learns from task pressure (no auxiliary routing losses) - Gradients verified to flow through router - Cross-contrastive aligns patch↔scale features """ def prepare_run_for_hub( model: nn.Module, model_config: Union[DavidBeansConfig, DavidBeansV2Config], train_config: TrainingConfigV2, best_acc: float, output_dir: Path, run_dir_name: str, training_history: Optional[Dict] = None ) -> Path: """Prepare run files for upload to HuggingFace Hub.""" hub_dir = output_dir / "hub_upload" run_hub_dir = hub_dir / "weights" / run_dir_name run_hub_dir.mkdir(parents=True, exist_ok=True) # Save best model weights state_dict = {k: v.clone() for k, v in model.state_dict().items()} if SAFETENSORS_AVAILABLE: try: save_safetensors(state_dict, run_hub_dir / "best.safetensors") print(f" ✓ Saved best.safetensors") except Exception as e: print(f" [!] Safetensors failed ({e}), using pytorch format") torch.save(state_dict, run_hub_dir / "best.pt") else: torch.save(state_dict, run_hub_dir / "best.pt") # Save model config config_dict = { "architecture": f"DavidBeans_V{train_config.model_version}", "model_type": "david_beans_v2" if train_config.model_version == 2 else "david_beans", **model_config.__dict__ } with open(run_hub_dir / "config.json", "w") as f: json.dump(config_dict, f, indent=2, default=str) # Save training config with open(run_hub_dir / "training_config.json", "w") as f: json.dump(train_config.to_dict(), f, indent=2, default=str) # Generate README run_readme = generate_run_readme(model_config, train_config, best_acc, run_dir_name) with open(run_hub_dir / "README.md", "w") as f: f.write(run_readme) # Save training history if training_history: with open(run_hub_dir / "training_history.json", "w") as f: json.dump(training_history, f, indent=2) # Copy TensorBoard logs tb_dir = output_dir / "tensorboard" if tb_dir.exists(): hub_tb_dir = run_hub_dir / "tensorboard" if hub_tb_dir.exists(): shutil.rmtree(hub_tb_dir) shutil.copytree(tb_dir, hub_tb_dir) return hub_dir def push_run_to_hub( hub_dir: Path, repo_id: str, run_dir_name: str, private: bool = False, commit_message: Optional[str] = None ) -> str: """Push run files to HuggingFace Hub.""" if not HF_HUB_AVAILABLE: raise RuntimeError("huggingface_hub not installed") api = HfApi() try: create_repo(repo_id, private=private, exist_ok=True) except Exception as e: print(f" [!] Repo creation note: {e}") run_upload_dir = hub_dir / "weights" / run_dir_name if commit_message is None: commit_message = f"Update {run_dir_name} - {datetime.now().strftime('%Y-%m-%d %H:%M')}" url = upload_folder( folder_path=str(run_upload_dir), repo_id=repo_id, path_in_repo=f"weights/{run_dir_name}", commit_message=commit_message ) return url # ============================================================================ # TRAINING LOOP V2 # ============================================================================ def train_epoch_v2( model: nn.Module, train_loader: DataLoader, optimizer: torch.optim.Optimizer, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], config: TrainingConfigV2, epoch: int, tracker: MetricsTracker, routing_metrics: RoutingMetrics, writer: Optional['SummaryWriter'] = None ) -> Dict[str, float]: """Train for one epoch with V2 routing metrics.""" model.train() device = config.device is_v2 = config.model_version == 2 total_loss = 0.0 total_correct = 0 total_samples = 0 global_step = epoch * len(train_loader) routing_metrics.reset() pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True) for batch_idx, (images, targets) in enumerate(pbar): images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) # Apply mixup/cutmix use_mixup = config.use_augmentation and config.mixup_alpha > 0 use_cutmix = config.use_augmentation and config.cutmix_alpha > 0 mixed = False if use_mixup or use_cutmix: r = torch.rand(1).item() if r < 0.5: pass elif r < 0.75 and use_mixup: images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha) mixed = True elif use_cutmix: images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha) mixed = True # Forward pass if is_v2: result = model( images, targets=targets, return_loss=True, return_routing=(batch_idx % 10 == 0) # Sample routing every 10 batches ) else: result = model(images, targets=targets, return_loss=True) losses = result['losses'] # Handle mixup CE loss if mixed: logits = result['logits'] ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \ (1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing) losses['ce'] = ce_loss # Compute total loss (NO auxiliary routing loss - key finding!) loss = ( config.ce_weight * losses['ce'] + config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device)) ) # Add scale CE losses for scale in model.config.scales: scale_ce = losses.get(f'ce_{scale}', 0.0) if isinstance(scale_ce, torch.Tensor): loss = loss + 0.1 * scale_ce # Backward pass optimizer.zero_grad() loss.backward() # Track routing gradient norms (verify gradients flow!) if is_v2: routing_metrics.update_grad_norms(model) if config.gradient_clip > 0: grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip) else: grad_norm = 0.0 optimizer.step() if scheduler is not None and config.scheduler == "onecycle": scheduler.step() # Update routing metrics from forward pass if is_v2 and result.get('routing'): routing_metrics.update_from_routing_info(result['routing'], model) # Compute accuracy with torch.no_grad(): logits = result['logits'] preds = logits.argmax(dim=-1) if mixed: correct = (lam * (preds == targets_a).float() + (1 - lam) * (preds == targets_b).float()).sum() else: correct = (preds == targets).sum() total_correct += correct.item() total_samples += targets.size(0) total_loss += loss.item() # Track metrics def to_float(v): return v.item() if isinstance(v, torch.Tensor) else float(v) contrast_loss = to_float(losses.get('contrast', 0.0)) current_lr = optimizer.param_groups[0]['lr'] tracker.update( loss=loss.item(), ce=losses['ce'].item(), contrast=contrast_loss, lr=current_lr ) # TensorBoard logging if writer is not None and (batch_idx + 1) % config.log_interval == 0: step = global_step + batch_idx writer.add_scalar('train/loss_total', loss.item(), step) writer.add_scalar('train/loss_ce', losses['ce'].item(), step) writer.add_scalar('train/loss_contrast', contrast_loss, step) writer.add_scalar('train/learning_rate', current_lr, step) writer.add_scalar('train/grad_norm', to_float(grad_norm), step) # Log routing metrics for V2 if is_v2 and config.log_routing: routing_summary = routing_metrics.get_summary() for k, v in routing_summary.items(): writer.add_scalar(f'routing/{k}', v, step) # Progress bar routing_summary = routing_metrics.get_summary() postfix = { 'loss': f"{tracker.get_ema('loss'):.3f}", 'acc': f"{100.0 * total_correct / total_samples:.1f}%", } if is_v2 and 'grad_query' in routing_summary: postfix['∇q'] = f"{routing_summary['grad_query']:.2f}" if 'route_entropy' in routing_summary: postfix['H'] = f"{routing_summary['route_entropy']:.2f}" pbar.set_postfix(postfix) if scheduler is not None and config.scheduler == "cosine": scheduler.step() return { 'loss': total_loss / len(train_loader), 'acc': 100.0 * total_correct / total_samples, **routing_metrics.get_summary() } @torch.no_grad() def evaluate_v2( model: nn.Module, test_loader: DataLoader, config: TrainingConfigV2 ) -> Dict[str, float]: """Evaluate on test set.""" model.eval() device = config.device total_loss = 0.0 total_correct = 0 total_samples = 0 scale_correct = {s: 0 for s in model.config.scales} for images, targets in test_loader: images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) result = model(images, targets=targets, return_loss=True) logits = result['logits'] losses = result['losses'] loss = losses['total'] preds = logits.argmax(dim=-1) total_loss += loss.item() * targets.size(0) total_correct += (preds == targets).sum().item() total_samples += targets.size(0) for i, scale in enumerate(model.config.scales): scale_logits = result['scale_logits'][i] scale_preds = scale_logits.argmax(dim=-1) scale_correct[scale] += (scale_preds == targets).sum().item() metrics = { 'loss': total_loss / total_samples, 'acc': 100.0 * total_correct / total_samples } for scale, correct in scale_correct.items(): metrics[f'acc_{scale}'] = 100.0 * correct / total_samples return metrics # ============================================================================ # MAIN TRAINING FUNCTION V2 # ============================================================================ def train_david_beans_v2( model_config: Optional[Union[DavidBeansConfig, DavidBeansV2Config]] = None, train_config: Optional[TrainingConfigV2] = None ): """Main training function for DavidBeans V1 or V2.""" print("=" * 70) print(" DAVID-BEANS V2 TRAINING: Wormhole Routing") print("=" * 70) print() print(" 🌀 WORMHOLES: Learned sparse routing") print(" 💎 CRYSTALS: Multi-scale projection") print() print(" Key insight: When routing IS the task, routing learns structure") print() print("=" * 70) if train_config is None: train_config = TrainingConfigV2() base_output_dir = Path(train_config.output_dir) base_output_dir.mkdir(parents=True, exist_ok=True) # ========================================================================= # FIXED: Proper checkpoint resolution # ========================================================================= checkpoint_path = None run_dir = None run_dir_name = None if train_config.resume_from: resume_path = Path(train_config.resume_from) # Case 1: Direct absolute/relative file path if resume_path.is_file(): checkpoint_path = resume_path run_dir = checkpoint_path.parent run_dir_name = run_dir.name print(f"\n📂 Found checkpoint file: {checkpoint_path.name}") # Case 2: Directory path - find best/latest checkpoint inside elif resume_path.is_dir(): checkpoint_path = find_latest_checkpoint(resume_path) if checkpoint_path: run_dir = resume_path run_dir_name = resume_path.name print(f"\n📂 Found checkpoint in dir: {checkpoint_path.name}") # Case 3: Try as path relative to base_output_dir else: # Try as subdirectory name possible_dir = base_output_dir / train_config.resume_from if possible_dir.is_dir(): checkpoint_path = find_latest_checkpoint(possible_dir) if checkpoint_path: run_dir = possible_dir run_dir_name = possible_dir.name print(f"\n📂 Found checkpoint in: {run_dir_name}") # Try as relative file path if checkpoint_path is None: possible_file = base_output_dir / train_config.resume_from if possible_file.is_file(): checkpoint_path = possible_file run_dir = checkpoint_path.parent run_dir_name = run_dir.name print(f"\n📂 Found checkpoint: {checkpoint_path.name}") # Report if not found if checkpoint_path is None: print(f"\n [!] Could not find checkpoint: {train_config.resume_from}") print(f" [!] Checked:") print(f" - As file: {resume_path}") print(f" - As dir: {resume_path}") print(f" - Under {base_output_dir}") print(f" [!] Starting fresh run instead") else: print(f" ✓ Will resume from: {checkpoint_path}") # Create new run directory if not resuming if run_dir is None: run_number = train_config.run_number or get_next_run_number(base_output_dir) run_dir_name = generate_run_dir_name(run_number, train_config.run_name, train_config.model_version) run_dir = base_output_dir / run_dir_name run_dir.mkdir(parents=True, exist_ok=True) print(f"\n📁 New run: {run_dir_name}") else: print(f"\n📁 Resuming run: {run_dir_name}") output_dir = run_dir # ========================================================================= # Model config - load from checkpoint if resuming, else use provided/default # ========================================================================= if checkpoint_path and checkpoint_path.exists() and model_config is None: # Try to load config from checkpoint try: ckpt = torch.load(checkpoint_path, map_location='cpu') if 'model_config' in ckpt: saved_config = ckpt['model_config'] print(f" ✓ Loading model config from checkpoint") if train_config.model_version == 2: model_config = DavidBeansV2Config(**saved_config) else: model_config = DavidBeansConfig(**saved_config) except Exception as e: print(f" [!] Could not load config from checkpoint: {e}") # Create default config if still None if model_config is None: if train_config.model_version == 2: model_config = DavidBeansV2Config( image_size=train_config.image_size, patch_size=4, dim=512, num_layers=4, num_heads=8, num_wormholes=8, wormhole_temperature=0.1, wormhole_mode="hybrid", num_tiles=16, tile_wormholes=4, scales=[64, 128, 256, 384, 512], num_classes=100, contrast_weight=train_config.contrast_weight, dropout=0.1 ) else: model_config = DavidBeansConfig( image_size=train_config.image_size, patch_size=4, dim=512, num_layers=4, num_heads=8, num_experts=5, k_neighbors=16, cantor_weight=0.3, scales=[64, 128, 256, 384, 512], num_classes=100, dropout=0.1 ) device = train_config.device print(f"\nDevice: {device}") print(f"Model version: V{train_config.model_version}") # Data print("\nLoading data...") train_loader, test_loader, num_classes = get_dataloaders(train_config) print(f" Dataset: {train_config.dataset}") print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") print(f" Classes: {num_classes}") model_config.num_classes = num_classes # Model print("\nBuilding model...") if train_config.model_version == 2: model = DavidBeansV2(model_config) else: model = DavidBeans(model_config) model = model.to(device) print(f"\n{model}") num_params = sum(p.numel() for p in model.parameters()) print(f"\nParameters: {num_params:,}") # Optimizer print("\nSetting up optimizer...") decay_params = [] no_decay_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue if 'bias' in name or 'norm' in name or 'embedding' in name: no_decay_params.append(param) else: decay_params.append(param) optimizer = AdamW([ {'params': decay_params, 'weight_decay': train_config.weight_decay}, {'params': no_decay_params, 'weight_decay': 0.0} ], lr=train_config.learning_rate, betas=train_config.betas) if train_config.scheduler == "cosine": scheduler = CosineAnnealingLR( optimizer, T_max=train_config.epochs - train_config.warmup_epochs, eta_min=train_config.min_lr ) elif train_config.scheduler == "onecycle": scheduler = OneCycleLR( optimizer, max_lr=train_config.learning_rate, epochs=train_config.epochs, steps_per_epoch=len(train_loader), pct_start=train_config.warmup_epochs / train_config.epochs ) else: scheduler = None print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})") print(f" Scheduler: {train_config.scheduler}") tracker = MetricsTracker() routing_metrics = RoutingMetrics() best_acc = 0.0 start_epoch = 0 # ========================================================================= # Load checkpoint weights and optimizer state # ========================================================================= if checkpoint_path and checkpoint_path.exists(): start_epoch, best_acc = load_checkpoint(checkpoint_path, model, optimizer, device) # Advance scheduler to correct position if scheduler is not None and train_config.scheduler == "cosine": for _ in range(start_epoch): scheduler.step() print(f" ✓ Advanced scheduler to epoch {start_epoch}") # TensorBoard writer = None if train_config.use_tensorboard and TENSORBOARD_AVAILABLE: tb_dir = output_dir / "tensorboard" tb_dir.mkdir(parents=True, exist_ok=True) writer = SummaryWriter(log_dir=str(tb_dir)) print(f" TensorBoard: {tb_dir}") # Save configs with open(output_dir / "config.json", "w") as f: json.dump({**model_config.__dict__, "architecture": f"DavidBeans_V{train_config.model_version}"}, f, indent=2, default=str) with open(output_dir / "training_config.json", "w") as f: json.dump(train_config.to_dict(), f, indent=2, default=str) # Training loop print("\n" + "=" * 70) print(" TRAINING") print("=" * 70) for epoch in range(start_epoch, train_config.epochs): epoch_start = time.time() # Warmup if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine": warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs for param_group in optimizer.param_groups: param_group['lr'] = warmup_lr train_metrics = train_epoch_v2( model, train_loader, optimizer, scheduler, train_config, epoch, tracker, routing_metrics, writer ) test_metrics = evaluate_v2(model, test_loader, train_config) epoch_time = time.time() - epoch_start # TensorBoard if writer is not None: writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch) writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch) writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch) writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch) for scale in model.config.scales: writer.add_scalar(f'scales/acc_{scale}', test_metrics[f'acc_{scale}'], epoch) # Print summary - show ALL scales scale_accs = " | ".join([f"{s}:{test_metrics[f'acc_{s}']:.1f}%" for s in model.config.scales]) star = "★" if test_metrics['acc'] > best_acc else "" routing_info = "" if train_config.model_version == 2 and 'grad_query' in train_metrics: routing_info = f" | ∇q:{train_metrics.get('grad_query', 0):.2f}" print(f" → Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | " f"[{scale_accs}]{routing_info} | {epoch_time:.0f}s {star}") # Save best model if test_metrics['acc'] > best_acc: best_acc = test_metrics['acc'] torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'best_acc': best_acc, 'model_config': model_config.__dict__, 'train_config': train_config.to_dict() }, output_dir / "best_model.pt") # Periodic checkpoint if (epoch + 1) % train_config.save_interval == 0: torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'best_acc': best_acc, 'model_config': model_config.__dict__, 'train_config': train_config.to_dict() }, output_dir / f"checkpoint_epoch_{epoch + 1}.pt") # Upload to hub if train_config.push_to_hub and HF_HUB_AVAILABLE: try: hub_dir = prepare_run_for_hub( model=model, model_config=model_config, train_config=train_config, best_acc=best_acc, output_dir=output_dir, run_dir_name=run_dir_name, training_history=tracker.get_history() ) push_run_to_hub( hub_dir=hub_dir, repo_id=train_config.hub_repo_id, run_dir_name=run_dir_name, commit_message=f"Epoch {epoch + 1} - {best_acc:.2f}% acc" ) print(f" 📤 Uploaded to hub") except Exception as e: print(f" [!] Hub upload failed: {e}") tracker.end_epoch() # Final summary print("\n" + "=" * 70) print(" TRAINING COMPLETE") print("=" * 70) print(f"\n Best Test Accuracy: {best_acc:.2f}%") print(f" Model saved to: {output_dir / 'best_model.pt'}") if writer is not None: writer.close() return model, best_acc # ============================================================================ # PRESETS # ============================================================================ def train_cifar100_v2_wormhole( run_name: str = "wormhole_base", push_to_hub: bool = False, resume: bool = False ): """CIFAR-100 with V2 wormhole routing.""" model_config = DavidBeansV2Config( image_size=32, patch_size=2, dim=512, num_layers=4, num_heads=16, # Wormhole routing parameters num_wormholes=16, wormhole_temperature=0.1, wormhole_mode="hybrid", # Tessellation parameters num_tiles=16, tile_wormholes=4, # Crystal head scales=[64, 128, 256, 512, 1024], num_classes=100, contrast_temperature=0.07, contrast_weight=0.5, dropout=0.1 ) train_config = TrainingConfigV2( run_name=run_name, model_version=2, dataset="cifar100", epochs=300, batch_size=512, learning_rate=3e-4, weight_decay=0.05, warmup_epochs=15, # Loss weights (no auxiliary routing loss!) ce_weight=1.0, contrast_weight=0.5, # Augmentation label_smoothing=0.1, mixup_alpha=0.2, cutmix_alpha=1.0, # Output output_dir="./checkpoints/cifar100_v2", resume_from=None, #"./checkpoints/cifar100_v2/run_002_v2_16patch_4tilewormholes_d768_4layer_20251130_045437/best_model.pt", # Hub push_to_hub=push_to_hub, hub_repo_id="AbstractPhil/geovit-david-beans", # Routing logging log_routing=True ) return train_david_beans_v2(model_config, train_config) def train_cifar100_v1_baseline( run_name: str = "v1_baseline", push_to_hub: bool = False, resume: bool = False ): """CIFAR-100 with V1 (fixed Cantor routing) for comparison.""" model_config = DavidBeansConfig( image_size=32, patch_size=4, dim=512, num_layers=4, num_heads=8, num_experts=5, k_neighbors=16, cantor_weight=0.3, scales=[64, 128, 256, 384, 512], num_classes=100, dropout=0.1 ) train_config = TrainingConfigV2( run_name=run_name, model_version=1, dataset="cifar100", epochs=200, batch_size=128, learning_rate=3e-4, weight_decay=0.05, warmup_epochs=10, ce_weight=1.0, contrast_weight=0.5, label_smoothing=0.1, mixup_alpha=0.2, cutmix_alpha=1.0, output_dir="./checkpoints/cifar100_v1", resume_from="latest" if resume else None, push_to_hub=push_to_hub, hub_repo_id="AbstractPhil/geovit-david-beans", log_routing=False ) return train_david_beans_v2(model_config, train_config) # ============================================================================ # MAIN # ============================================================================ if __name__ == "__main__": # ===================================================== # CONFIGURATION # ===================================================== PRESET = "v2_wormhole" # "v1_baseline", "v2_wormhole", "test" RESUME = False RUN_NAME = "5scale_2x2patch_4tilewormholes_d512_4layer" PUSH_TO_HUB = True # ===================================================== # RUN # ===================================================== if PRESET == "test": print("🧪 Quick test...") model_config = DavidBeansV2Config( image_size=32, patch_size=4, dim=128, num_layers=2, num_heads=4, num_wormholes=4, num_tiles=8, scales=[32, 64, 128], num_classes=10 ) train_config = TrainingConfigV2( run_name="test", model_version=2, epochs=2, batch_size=32, use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0 ) model, acc = train_david_beans_v2(model_config, train_config) elif PRESET == "v1_baseline": print("🫘💎 Training DavidBeans V1 (Cantor routing)...") model, acc = train_cifar100_v1_baseline( run_name=RUN_NAME, push_to_hub=PUSH_TO_HUB, resume=RESUME ) elif PRESET == "v2_wormhole": print("💎 Training DavidBeans V2 (Wormhole routing)...") model, acc = train_cifar100_v2_wormhole( run_name=RUN_NAME, push_to_hub=PUSH_TO_HUB, resume=RESUME ) else: raise ValueError(f"Unknown preset: {PRESET}") print(f"\n🎉 Done! Best accuracy: {acc:.2f}%")