""" BEATRIX FLOW-MATCHING - CIFAR-10 (T5 Text Encoder) =================================================== SD 1.5 VAE + Flan-T5-Large text encoder Dual tower collectives: vision towers + text towers Text prompts for CIFAR-10 classes: "a photo of an airplane" "a photo of an automobile" etc. Requirements: pip install transformers diffusers torchvision tqdm pip install git+https://github.com/AbstractEyes/geofractal Currently running like a turtle, will optimize tomorrow. apache 2.0 license """ from __future__ import annotations import math from dataclasses import dataclass from typing import Dict, Tuple, Optional, List from pathlib import Path import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.utils.data import DataLoader, Dataset from torchvision import datasets, transforms from torchvision.utils import make_grid, save_image from huggingface_hub import HfApi, upload_file, create_repo import json from tqdm import tqdm # ============================================================================= # GEOFRACTAL IMPORTS # ============================================================================= from geofractal.router.wide_router import WideRouter from geofractal.router.prefab.agatha.beatrix_tension_oscillator import ( BeatrixOscillator, ScheduleType, ) from geofractal.router.prefab.geometric_tower_builder import ( TowerConfig, FusionType, ConfigurableCollective, build_tower_collective, preset_pos_neg_pairs, ) from geofractal.router.prefab.geometric_conv_tower_builder import ( ConvTowerConfig, ConvTowerCollective, build_conv_collective, preset_conv_pos_neg, ) # ============================================================================= # CIFAR-10 CLASS PROMPTS # ============================================================================= CIFAR10_PROMPTS = [ "a photo of an airplane", "a photo of an automobile", "a photo of a bird", "a photo of a cat", "a photo of a deer", "a photo of a dog", "a photo of a frog", "a photo of a horse", "a photo of a ship", "a photo of a truck", ] # ============================================================================= # SD 1.5 VAE # ============================================================================= class SD15VAE(nn.Module): def __init__(self, freeze: bool = True): super().__init__() from diffusers import AutoencoderKL self.vae = AutoencoderKL.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float32, ) if freeze: self.vae.eval() for p in self.vae.parameters(): p.requires_grad = False self.scale_factor = 0.18215 @torch.no_grad() def encode(self, x: Tensor) -> Tensor: return self.vae.encode(x).latent_dist.sample() * self.scale_factor @torch.no_grad() def decode(self, z: Tensor) -> Tensor: return self.vae.decode(z / self.scale_factor).sample # ============================================================================= # FLAN-T5-LARGE TEXT ENCODER # ============================================================================= class T5TextEncoder(nn.Module): """Flan-T5 encoder with bottleneck projection.""" def __init__( self, model_name: str = "google/flan-t5-xl", freeze: bool = True, max_length: int = 77, bottleneck_dim: int = 256, ): super().__init__() from transformers import T5EncoderModel, T5Tokenizer self.tokenizer = T5Tokenizer.from_pretrained(model_name) self.encoder = T5EncoderModel.from_pretrained(model_name) self.max_length = max_length self.raw_dim = self.encoder.config.d_model # 2048 for XL self.output_dim = bottleneck_dim # Bottleneck 
# =============================================================================
# FLAN-T5 TEXT ENCODER
# =============================================================================

class T5TextEncoder(nn.Module):
    """Flan-T5 encoder with bottleneck projection."""

    def __init__(
        self,
        model_name: str = "google/flan-t5-xl",
        freeze: bool = True,
        max_length: int = 77,
        bottleneck_dim: int = 256,
    ):
        super().__init__()
        from transformers import T5EncoderModel, T5Tokenizer
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.max_length = max_length
        self.raw_dim = self.encoder.config.d_model  # 2048 for XL
        self.output_dim = bottleneck_dim

        # Bottleneck projection
        self.bottleneck = nn.Sequential(
            nn.Linear(self.raw_dim, bottleneck_dim),
            nn.GELU(),
            nn.Linear(bottleneck_dim, bottleneck_dim),
        )

        if freeze:
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad = False
        # NOTE: forward() runs under no_grad, so this bottleneck never trains
        # here; the trainable text bottleneck lives in BeatrixFlowT5. The
        # cache build uses encode_raw(), which skips it entirely.

    @torch.no_grad()
    def forward(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Encode text prompts with bottleneck.

        Returns:
            sequence: [B, L, bottleneck_dim] - compressed sequence embeddings
            pooled:   [B, bottleneck_dim]    - compressed mean-pooled embedding
        """
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        sequence_raw = outputs.last_hidden_state  # [B, L, raw_dim]

        # Apply bottleneck
        sequence = self.bottleneck(sequence_raw)  # [B, L, bottleneck_dim]

        # Mean pool over non-padding tokens
        mask_expanded = attention_mask.unsqueeze(-1).float()
        pooled = (sequence * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)

        return sequence, pooled

    @torch.no_grad()
    def encode_raw(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Encode text prompts WITHOUT bottleneck (for caching raw embeddings).

        Returns:
            sequence: [B, L, raw_dim] - raw T5 embeddings
            pooled:   [B, raw_dim]    - raw mean-pooled embedding
        """
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        sequence = outputs.last_hidden_state  # [B, L, raw_dim]

        # Mean pool over non-padding tokens
        mask_expanded = attention_mask.unsqueeze(-1).float()
        pooled = (sequence * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)

        return sequence, pooled
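# -----------------------------------------------------------------------------
# Pooling sketch (illustrative only): the masked mean pool above averages only
# the non-padding token embeddings. A toy check, no T5 weights needed:
# -----------------------------------------------------------------------------

def _masked_mean_pool_demo():
    seq = torch.randn(1, 4, 8)            # [B=1, L=4, dim=8]
    mask = torch.tensor([[1, 1, 0, 0]])   # two real tokens, two pads
    m = mask.unsqueeze(-1).float()
    pooled = (seq * m).sum(dim=1) / m.sum(dim=1)  # [1, 8]
    assert torch.allclose(pooled, seq[:, :2].mean(dim=1))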
""" T5_MODEL = "google/flan-t5-xl" # Change this to use different T5 variant def __init__( self, train: bool = True, image_size: int = 256, cache_dir: str = "./cache", device: str = "cuda", ): self.train = train # Include T5 model name in cache path t5_suffix = self.T5_MODEL.replace("/", "_") self.cache_path = Path(cache_dir) / f"cifar10_{t5_suffix}_{'train' if train else 'val'}_{image_size}.pt" if self.cache_path.exists(): print(f"Loading cache: {self.cache_path}") cache = torch.load(self.cache_path, weights_only=False) self.latents = cache['latents'] self.labels = cache['labels'] self.text_sequence = cache['text_sequence'] # [10, L, dim] self.text_pooled = cache['text_pooled'] # [10, dim] self.text_dim = cache.get('text_dim', self.text_pooled.shape[-1]) else: print(f"Building cache for {'train' if train else 'val'} set...") self._build_cache(image_size, device) def _build_cache(self, image_size: int, device: str): # Load encoders print(" Loading VAE...") vae = SD15VAE(freeze=True).to(device) print(f" Loading T5 ({self.T5_MODEL})...") t5 = T5TextEncoder(model_name=self.T5_MODEL, freeze=True).to(device) # Encode class prompts - save RAW embeddings (bottleneck is in model) print(f" Encoding text prompts (T5 raw_dim={t5.raw_dim})...") text_seq, text_pool = t5.encode_raw(CIFAR10_PROMPTS, device) self.text_sequence = text_seq.cpu() # [10, L, raw_dim] self.text_pooled = text_pool.cpu() # [10, raw_dim] self.text_dim = t5.raw_dim # Store raw dim for bottleneck sizing # Encode images transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) dataset = datasets.CIFAR10('./data', train=self.train, download=True, transform=transform) loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True) all_latents, all_labels = [], [] print(" Encoding images...") with torch.no_grad(): for images, labels in tqdm(loader, desc=" Caching", leave=False): images = images.to(device) all_latents.append(vae.encode(images).cpu()) all_labels.append(labels) self.latents = torch.cat(all_latents, dim=0) self.labels = torch.cat(all_labels, dim=0) del vae, t5 torch.cuda.empty_cache() # Save self.cache_path.parent.mkdir(parents=True, exist_ok=True) torch.save({ 'latents': self.latents, 'labels': self.labels, 'text_sequence': self.text_sequence, 'text_pooled': self.text_pooled, 'text_dim': self.text_dim, }, self.cache_path) print(f" Saved: {self.cache_path}") def __len__(self): return len(self.labels) def __getitem__(self, idx): label = self.labels[idx] return ( self.latents[idx], self.text_sequence[label], # [L, raw_dim] self.text_pooled[label], # [raw_dim] label, ) # ============================================================================= # SINUSOIDAL EMBEDDING # ============================================================================= class SinusoidalEmbed(nn.Module): def __init__(self, dim: int): super().__init__() self.dim = dim def forward(self, t: Tensor) -> Tensor: half = self.dim // 2 freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half) args = t.unsqueeze(-1) * freqs return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # ============================================================================= # CONFIG # ============================================================================= @dataclass class FlowConfig: image_size: int = 256 num_classes: int = 10 latent_channels: int = 4 latent_size: int = 32 # T5 dimensions text_raw_dim: int = 2048 # 
# =============================================================================
# CONFIG
# =============================================================================

@dataclass
class FlowConfig:
    image_size: int = 256
    num_classes: int = 10
    latent_channels: int = 4
    latent_size: int = 32

    # T5 dimensions
    text_raw_dim: int = 2048   # Raw T5-XL output, overridden by dataset
    text_seq_len: int = 77
    bottleneck_dim: int = 256  # Compressed text dim

    # Tower collective (transformer-based)
    tower_dim: int = 256
    tower_depth: int = 2
    num_heads: int = 8
    geometric_types: Tuple[str, ...] = ('cantor', 'beatrix', 'helix', 'simplex')

    # Conv tower types (convolutional)
    conv_types: Tuple[str, ...] = ('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite')
    conv_spatial_size: int = 8  # Spatial size for conv towers

    # Oscillator
    manifold_dim: int = 1024   # Projected manifold (smaller than latent)
    num_tower_pairs: int = 16  # 32 towers / 2
    osc_steps: int = 50        # For sampling only
    fingerprint_dim: int = 64

    # Flow
    num_flow_steps: int = 50
    sigma_min: float = 0.001

    # Training
    batch_size: int = 64
    lr: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 100
    cache_dir: str = "./cache"
    device: str = "cuda"
    output_dir: str = "./beatrix_cifar_t5"

    @property
    def latent_flat_dim(self) -> int:
        """Full flattened latent size: 4 × 32 × 32 = 4096"""
        return self.latent_channels * self.latent_size * self.latent_size
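# -----------------------------------------------------------------------------
# Config arithmetic (illustrative only): each type in geometric_types and
# conv_types yields a pos/neg tower pair, mirrored across the vision and text
# modalities. The defaults above give (4 + 4) * 2 * 2 = 32 towers; main()'s
# lightweight 2 + 2 types give 16.
# -----------------------------------------------------------------------------

def _tower_count_demo(cfg: FlowConfig) -> int:
    per_modality = (len(cfg.geometric_types) + len(cfg.conv_types)) * 2  # pos/neg
    return per_modality * 2  # vision + text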
# =============================================================================
# BEATRIX FLOW MODEL (Vision + Text Towers)
# =============================================================================

class BeatrixFlowT5(WideRouter):
    """
    Flow model with dual tower collectives per modality:

    Vision side:
    - Geometric towers (transformer): cantor, beatrix, helix, simplex (pos/neg)
    - Conv towers: wide_resnet, frequency, bottleneck, squeeze_excite (pos/neg)

    Text side (mirrored):
    - Geometric towers (transformer): cantor, beatrix, helix, simplex (pos/neg)
    - Conv towers: wide_resnet, frequency, bottleneck, squeeze_excite (pos/neg)

    All towers output opinions that combine for velocity prediction.
    """

    def __init__(self, cfg: FlowConfig):
        super().__init__(name='beatrix_flow_t5', strict=False, auto_discover=False)
        self.objects['cfg'] = cfg

        # =================================================================
        # TEXT BOTTLENECK (trainable)
        # =================================================================
        self.attach('text_bottleneck_seq', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))
        self.attach('text_bottleneck_pool', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))

        # =================================================================
        # VISION GEOMETRIC TOWERS (pos/neg pairs)
        # =================================================================
        vision_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))
        vision_geo_collective = build_tower_collective(
            configs=vision_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='vision_geo',
        )
        self.attach('vision_geo', vision_geo_collective)

        # =================================================================
        # VISION CONV TOWERS (pos/neg pairs)
        # =================================================================
        vision_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))
        vision_conv_collective = build_conv_collective(
            configs=vision_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='vision_conv',
        )
        self.attach('vision_conv', vision_conv_collective)

        # =================================================================
        # TEXT GEOMETRIC TOWERS (pos/neg pairs) - MIRRORED
        # =================================================================
        text_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))
        text_geo_collective = build_tower_collective(
            configs=text_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='text_geo',
        )
        self.attach('text_geo', text_geo_collective)

        # =================================================================
        # TEXT CONV TOWERS (pos/neg pairs) - MIRRORED
        # =================================================================
        text_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))
        text_conv_collective = build_conv_collective(
            configs=text_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='text_conv',
        )
        self.attach('text_conv', text_conv_collective)

        # =================================================================
        # PROJECTIONS
        # =================================================================
        # Latent patchifier
        patch_size = 4
        num_patches = (cfg.latent_size // patch_size) ** 2
        patch_dim = cfg.latent_channels * patch_size * patch_size
        self.attach('patch_proj', nn.Linear(patch_dim, cfg.tower_dim))
        self.patch_pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.tower_dim) * 0.02)
        self.objects['patch_size'] = patch_size
        self.objects['num_patches'] = num_patches

        # Text already at bottleneck_dim (256) = tower_dim, no extra projection needed

        # =================================================================
        # OSCILLATOR (for sampling)
        # =================================================================
        # Total towers: (4 geo + 4 conv) × pos/neg × 2 modalities = 32 towers
        num_geo_towers = len(vision_geo_configs)
        num_conv_towers = len(vision_conv_configs)
        total_towers = (num_geo_towers + num_conv_towers) * 2  # × 2 for vision + text

        oscillator = BeatrixOscillator(
            name='oscillator',
            manifold_dim=cfg.manifold_dim,
            tower_dim=cfg.tower_dim,
            num_tower_pairs=total_towers // 2,
            num_theta_probes=4,
            fingerprint_dim=cfg.fingerprint_dim,
            kappa_schedule=ScheduleType.TESLA_369,
            use_intrinsic_tension=True,
        )
        self.attach('oscillator', oscillator)

        # =================================================================
        # CONDITIONING
        # =================================================================
        # Time embedding
        time_embed = nn.Sequential(
            SinusoidalEmbed(256),
            nn.Linear(256, cfg.tower_dim),
            nn.GELU(),
            nn.Linear(cfg.tower_dim, cfg.tower_dim),
        )
        self.attach('time_embed', time_embed)

        # Bottlenecked text -> reference anchor
        self.attach('text_to_ref', nn.Sequential(
            nn.Linear(cfg.bottleneck_dim, cfg.manifold_dim),
            nn.GELU(),
            nn.Linear(cfg.manifold_dim, cfg.manifold_dim),
        ))

        # Time modulation for reference
        self.attach('time_to_ref', nn.Linear(cfg.tower_dim, cfg.manifold_dim))

        # =================================================================
        # LATENT PROJECTION (4096 <-> manifold_dim)
        # =================================================================
        self.attach('latent_down', nn.Linear(cfg.latent_flat_dim, cfg.manifold_dim))
        self.attach('latent_up', nn.Linear(cfg.manifold_dim, cfg.latent_flat_dim))

        # Learnable velocity mixing
        self.velocity_mix = nn.Parameter(torch.tensor(0.5))

    def patchify(self, z: Tensor) -> Tensor:
        """[B, 4, 32, 32] -> [B, num_patches, tower_dim]"""
        B, C, H, W = z.shape
        p = self.objects['patch_size']
        z = z.unfold(2, p, p).unfold(3, p, p)
        z = z.permute(0, 2, 3, 1, 4, 5).contiguous()
        z = z.view(B, -1, C * p * p)
        return self['patch_proj'](z) + self.patch_pos_embed
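    # Patchify bookkeeping (descriptive note): with latent_size=32 and
    # patch_size=4, the two unfolds yield [B, C, 8, 8, 4, 4]; the permute
    # reorders to [B, 8, 8, C, 4, 4] so the view flattens each 4x4 patch into
    # patch_dim = 4*4*4 = 64 features across num_patches = 64 positions,
    # which patch_proj then lifts to tower_dim.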
    def get_tower_outputs(self, z: Tensor, text_seq: Tensor) -> List[Tensor]:
        """
        Run all four tower collectives.
        Returns list of tower opinions [B, tower_dim] (32 total).
        """
        patches = self.patchify(z)
        text_bottlenecked = self['text_bottleneck_seq'](text_seq)

        # Run all collectives
        vision_geo = self['vision_geo'](patches)
        vision_conv_fused, vision_conv_ops = self['vision_conv'](patches)
        text_geo = self['text_geo'](text_bottlenecked)
        text_conv_fused, text_conv_ops = self['text_conv'](text_bottlenecked)

        # Collect opinions - use list comprehension (faster than append loop)
        return (
            [op.opinion for op in vision_geo.opinions.values()]
            + list(vision_conv_ops.values())
            + [op.opinion for op in text_geo.opinions.values()]
            + list(text_conv_ops.values())
        )

    def forward(
        self,
        z_0: Tensor,
        text_seq: Tensor,
        text_pooled: Tensor,
        labels: Tensor,
        t: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Training forward - single-step velocity prediction."""
        cfg = self.objects['cfg']
        B = z_0.shape[0]
        device = z_0.device

        if t is None:
            t = torch.rand(B, device=device)

        # Flatten latent [B, 4, 32, 32] -> [B, 4096]
        z_0_flat = z_0.flatten(1)

        # Noise + interpolate in full latent space
        eps = torch.randn_like(z_0)
        eps_flat = eps.flatten(1)
        t_exp = t.view(B, 1, 1, 1)
        z_t = (1 - t_exp) * z_0 + t_exp * eps
        z_t_flat = z_t.flatten(1)

        # Target velocity (in full latent space)
        v_target = eps_flat - z_0_flat

        # === PROJECT TO SMALLER MANIFOLD ===
        z_t_proj = self['latent_down'](z_t_flat)  # [B, 4096] -> [B, manifold_dim]

        # Bottleneck pooled text for reference
        text_pooled_bn = self['text_bottleneck_pool'](text_pooled)

        # Reference from bottlenecked text + time (in manifold space)
        time_emb = self['time_embed'](t)
        x_ref = self['text_to_ref'](text_pooled_bn) + self['time_to_ref'](time_emb)

        # Get all tower outputs (text_seq bottlenecked inside get_tower_outputs)
        tower_outputs = self.get_tower_outputs(z_t, text_seq)

        # Compute forces in manifold space
        osc = self['oscillator']
        tower_force, _ = osc.force_generator(z_t_proj, tower_outputs, state_fingerprint=None)
        spring_force = x_ref - z_t_proj

        # Velocity prediction in manifold space
        tau = torch.sigmoid(self.velocity_mix)
        v_pred_proj = (1 - tau) * spring_force + tau * tower_force

        # === PROJECT BACK TO FULL LATENT ===
        v_pred = self['latent_up'](v_pred_proj)  # [B, manifold_dim] -> [B, 4096]

        loss = F.mse_loss(v_pred, v_target)
        return {'loss': loss, 'tau': tau.detach()}
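    # Flow-matching note (descriptive): training regresses the straight-path
    # velocity. With z_t = (1 - t) * z_0 + t * eps, the path derivative is
    # dz_t/dt = eps - z_0, which is exactly v_target above. Sampling in
    # sample() below integrates dz/dt = v backwards from t=1 (noise) to t=0
    # (data) with Euler steps z <- z - dt * v.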
    @torch.no_grad()
    def sample(
        self,
        text_seq: Tensor,
        text_pooled: Tensor,
        vae: SD15VAE,
        num_steps: Optional[int] = None,
    ) -> Tensor:
        """Generate samples from text conditioning."""
        cfg = self.objects['cfg']
        B = text_seq.shape[0]
        device = text_seq.device
        num_steps = num_steps or cfg.num_flow_steps

        # Bottleneck pooled text once
        text_pooled_bn = self['text_bottleneck_pool'](text_pooled)

        # Start from noise
        z = torch.randn(B, cfg.latent_channels, cfg.latent_size, cfg.latent_size, device=device)

        dt = 1.0 / num_steps
        for step in range(num_steps):
            t_val = 1.0 - step * dt
            t = torch.full((B,), t_val, device=device)

            time_emb = self['time_embed'](t)
            x_ref = self['text_to_ref'](text_pooled_bn) + self['time_to_ref'](time_emb)

            z_flat = z.flatten(1)

            # Project to manifold
            z_proj = self['latent_down'](z_flat)

            tower_outputs = self.get_tower_outputs(z, text_seq)

            osc = self['oscillator']
            tower_force, _ = osc.force_generator(z_proj, tower_outputs, state_fingerprint=None)
            spring_force = x_ref - z_proj

            tau = torch.sigmoid(self.velocity_mix)
            v_pred_proj = (1 - tau) * spring_force + tau * tower_force

            # Project back and update (Euler step from t=1 toward t=0)
            v_pred = self['latent_up'](v_pred_proj)
            z_flat = z_flat - dt * v_pred
            z = z_flat.view(B, cfg.latent_channels, cfg.latent_size, cfg.latent_size)

        return vae.decode(z)
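# -----------------------------------------------------------------------------
# Sampling sketch (illustrative only, mirrors Trainer.sample_images below): a
# minimal generation pass assuming a trained model and the cached per-class
# T5 embeddings. Decoded images land in [-1, 1] and need the same
# (x + 1) / 2 rescale the trainer applies.
# -----------------------------------------------------------------------------

def _sample_demo(model: BeatrixFlowT5, dataset: CachedCIFAR10T5, class_idx: int = 3) -> Tensor:
    device = next(model.parameters()).device
    vae = SD15VAE(freeze=True).to(device)
    text_seq = dataset.text_sequence[class_idx:class_idx + 1].to(device)   # [1, L, raw_dim]
    text_pooled = dataset.text_pooled[class_idx:class_idx + 1].to(device)  # [1, raw_dim]
    images = model.sample(text_seq, text_pooled, vae)                      # [1, 3, 256, 256]
    return ((images + 1) / 2).clamp(0, 1)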
# =============================================================================
# TRAINER
# =============================================================================

class Trainer:
    def __init__(self, cfg: FlowConfig):
        self.cfg = cfg
        self.device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
        self.output_dir = Path(cfg.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        # GradScaler disables itself automatically when CUDA is unavailable
        self.scaler = torch.amp.GradScaler('cuda')

        # Dataset
        print("\n=== Building Cached Datasets ===")
        self.train_dataset = CachedCIFAR10T5(train=True, image_size=cfg.image_size,
                                             cache_dir=cfg.cache_dir, device=cfg.device)
        self.val_dataset = CachedCIFAR10T5(train=False, image_size=cfg.image_size,
                                           cache_dir=cfg.cache_dir, device=cfg.device)

        # Update config with the actual T5 raw dimension from the cache
        cfg.text_raw_dim = self.train_dataset.text_dim
        print(f"T5 raw dimension: {cfg.text_raw_dim} → bottleneck: {cfg.bottleneck_dim}")

        self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.batch_size,
                                       shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=cfg.batch_size,
                                     shuffle=False, num_workers=0, pin_memory=True)

        # Store raw text embeddings for sampling (bottleneck applied in the model)
        self.text_sequence = self.train_dataset.text_sequence.to(self.device)  # [10, L, raw_dim]
        self.text_pooled = self.train_dataset.text_pooled.to(self.device)      # [10, raw_dim]

        # Model
        print("\n=== Building Model (Vision + Text Towers) ===")
        self.model = BeatrixFlowT5(cfg).to(self.device)

        # Compile
        if hasattr(torch, 'compile'):
            print("Compiling with WideRouter.prepare_and_compile()...")
            self.model = self.model.prepare_and_compile(
                mode="reduce-overhead",
                fullgraph=False,
            )

        num_params = sum(p.numel() for p in self.model.parameters())
        print(f"Trainable parameters: {num_params:,}")

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr,
                                           weight_decay=cfg.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=cfg.num_epochs * len(self.train_loader))

        # Load the most recent checkpoint if one exists
        self.start_epoch = 0
        self.hf_repo = "AbstractPhil/beatrix-diffusion-proto"
        self._load_latest_checkpoint()

        self._vae = None

        # HuggingFace Hub setup
        self._setup_hf_repo()

    def _setup_hf_repo(self):
        """Create HF repo if needed and save initial config."""
        try:
            self.hf_api = HfApi()
            create_repo(self.hf_repo, exist_ok=True, repo_type="model")
            print(f"HF repo: {self.hf_repo}")

            # Save config
            config_dict = {
                'image_size': self.cfg.image_size,
                'num_classes': self.cfg.num_classes,
                'latent_channels': self.cfg.latent_channels,
                'latent_size': self.cfg.latent_size,
                'text_raw_dim': self.cfg.text_raw_dim,
                'bottleneck_dim': self.cfg.bottleneck_dim,
                'tower_dim': self.cfg.tower_dim,
                'tower_depth': self.cfg.tower_depth,
                'num_heads': self.cfg.num_heads,
                'geometric_types': self.cfg.geometric_types,
                'conv_types': self.cfg.conv_types,
                'conv_spatial_size': self.cfg.conv_spatial_size,
                'manifold_dim': self.cfg.manifold_dim,
                'fingerprint_dim': self.cfg.fingerprint_dim,
                'num_flow_steps': self.cfg.num_flow_steps,
            }
            config_path = self.output_dir / "config.json"
            with open(config_path, 'w') as f:
                json.dump(config_dict, f, indent=2)
            upload_file(
                path_or_fileobj=str(config_path),
                path_in_repo="config.json",
                repo_id=self.hf_repo,
            )
        except Exception as e:
            print(f"HF setup warning: {e}")
            self.hf_api = None

    def _upload_to_hf(self, epoch: int, sample_path: Path, metrics: dict = None):
        """Upload checkpoint, samples, and metrics to HuggingFace."""
        if self.hf_api is None:
            return
        try:
            # Upload checkpoint
            ckpt_path = self.output_dir / "ckpt_latest.pt"
            if ckpt_path.exists():
                upload_file(
                    path_or_fileobj=str(ckpt_path),
                    path_in_repo="ckpt_latest.pt",
                    repo_id=self.hf_repo,
                )

            # Upload samples
            if sample_path.exists():
                upload_file(
                    path_or_fileobj=str(sample_path),
                    path_in_repo=f"samples/epoch_{epoch:03d}.png",
                    repo_id=self.hf_repo,
                )
                # Also as latest
                upload_file(
                    path_or_fileobj=str(sample_path),
                    path_in_repo="samples/latest.png",
                    repo_id=self.hf_repo,
                )

            # Upload metrics log
            if metrics:
                metrics_path = self.output_dir / "metrics.jsonl"
                with open(metrics_path, 'a') as f:
                    f.write(json.dumps({'epoch': epoch, **metrics}) + '\n')
                upload_file(
                    path_or_fileobj=str(metrics_path),
                    path_in_repo="metrics.jsonl",
                    repo_id=self.hf_repo,
                )

            print(f"  → Uploaded to HF")
        except Exception as e:
            print(f"  → HF upload failed: {e}")

    def _load_latest_checkpoint(self):
        """Load the most recent checkpoint if available (local or HF)."""
        latest_path = self.output_dir / "ckpt_latest.pt"

        # Try local first
        if latest_path.exists():
            print(f"Resuming from local ckpt_latest.pt...")
            ckpt = torch.load(latest_path, weights_only=False)
        else:
            # Fall back to numbered checkpoints
            ckpts = sorted(self.output_dir.glob("ckpt_epoch*.pt"))
            if ckpts:
                latest_path = ckpts[-1]
                print(f"Resuming from {latest_path.name}...")
                ckpt = torch.load(latest_path, weights_only=False)
            else:
                # Try downloading from HuggingFace
                try:
                    from huggingface_hub import hf_hub_download
                    print(f"Checking HF for checkpoint...")
                    hf_path = hf_hub_download(
                        repo_id=self.hf_repo,
                        filename="ckpt_latest.pt",
                        local_dir=str(self.output_dir),
                    )
                    print(f"Downloaded checkpoint from HF")
                    ckpt = torch.load(hf_path, weights_only=False)
                except Exception as e:
                    print(f"No checkpoint found (local or HF): {e}")
                    return

        self.model.load_state_dict(ckpt['model'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.scheduler.load_state_dict(ckpt['scheduler'])
        self.start_epoch = ckpt['epoch']
        print(f"  Resumed at epoch {self.start_epoch}")

    def _load_vae(self):
        """Load VAE for sampling (temporary)."""
        print("Loading VAE for sampling...")
        return SD15VAE(freeze=True).to(self.device)

    def _unload_vae(self, vae):
        """Unload VAE after sampling."""
        del vae
        torch.cuda.empty_cache()

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        self.model.train()
        total_loss, total_tau, n = 0.0, 0.0, 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.cfg.num_epochs}", leave=False)
        for latents, text_seq, text_pooled, labels in pbar:
            latents = latents.to(self.device)
            text_seq = text_seq.to(self.device)
            text_pooled = text_pooled.to(self.device)
            labels = labels.to(self.device)

            with torch.amp.autocast('cuda'):
                out = self.model(latents, text_seq, text_pooled, labels)
                loss = out['loss']

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

            total_loss += loss.item()
            total_tau += out['tau'].item()
            n += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}", τ=f"{out['tau'].item():.2f}")

        return {'loss': total_loss / n, 'tau': total_tau / n}
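    # AMP note (descriptive): the step order in train_epoch matters.
    # scale(loss).backward() produces scaled gradients; unscale_() restores
    # their true magnitudes so clip_grad_norm_ clips correctly; scaler.step()
    # then applies the optimizer update and scaler.update() retunes the loss
    # scale for the next batch.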
    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        self.model.eval()
        total_loss, n = 0.0, 0
        for latents, text_seq, text_pooled, labels in self.val_loader:
            latents = latents.to(self.device)
            text_seq = text_seq.to(self.device)
            text_pooled = text_pooled.to(self.device)
            labels = labels.to(self.device)
            with torch.amp.autocast('cuda'):
                out = self.model(latents, text_seq, text_pooled, labels)
            total_loss += out['loss'].item()
            n += 1
        return {'val_loss': total_loss / n}

    @torch.no_grad()
    def sample_images(self, n_per_class: int = 10) -> Tensor:
        """Generate samples for each class (memory-efficient, batched)."""
        self.model.eval()
        torch.cuda.empty_cache()

        # Load VAE temporarily
        vae = self._load_vae()

        all_samples = []
        batch_size = 10  # Generate 10 images at a time

        for class_idx in range(self.cfg.num_classes):
            # Generate n_per_class images for this class
            for batch_start in range(0, n_per_class, batch_size):
                batch_n = min(batch_size, n_per_class - batch_start)
                text_seq = self.text_sequence[class_idx:class_idx+1].expand(batch_n, -1, -1)
                text_pooled = self.text_pooled[class_idx:class_idx+1].expand(batch_n, -1)

                with torch.amp.autocast('cuda'):
                    samples = self.model.sample(text_seq, text_pooled, vae)
                all_samples.append(samples.cpu())

        # Unload VAE
        self._unload_vae(vae)

        samples = torch.cat(all_samples, dim=0).to(self.device)
        return ((samples + 1) / 2).clamp(0, 1)

    def save_checkpoint(self, epoch: int, milestone: bool = False):
        ckpt = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        # Always save latest (for resume)
        torch.save(ckpt, self.output_dir / "ckpt_latest.pt")
        # Save milestone checkpoints
        if milestone:
            torch.save(ckpt, self.output_dir / f"ckpt_epoch{epoch:03d}.pt")

    def train(self):
        num_geo = len(self.cfg.geometric_types) * 2   # pos/neg
        num_conv = len(self.cfg.conv_types) * 2
        total_towers = (num_geo + num_conv) * 2       # × 2 modalities

        print(f"\n{'='*60}")
        print("BEATRIX FLOW - Dual Geometric + Conv Towers (Bottlenecked)")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Geometric towers: {self.cfg.geometric_types} (pos/neg)")
        print(f"Conv towers: {self.cfg.conv_types} (pos/neg)")
        print(f"Tower dim: {self.cfg.tower_dim}")
        print(f"T5 raw → bottleneck: {self.cfg.text_raw_dim} → {self.cfg.bottleneck_dim}")
        print(f"Latent → manifold: {self.cfg.latent_flat_dim} → {self.cfg.manifold_dim}")
        print(f"Total towers: {total_towers}")
        print(f"Batch size: {self.cfg.batch_size}")
        print(f"Starting epoch: {self.start_epoch}/{self.cfg.num_epochs}")
        print(f"{'='*60}\n")

        for epoch in range(self.start_epoch, self.cfg.num_epochs):
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate()
            lr = self.scheduler.get_last_lr()[0]

            print(f"Epoch {epoch+1:3d} │ loss={train_metrics['loss']:.4f} │ val={val_metrics['val_loss']:.4f} │ τ={train_metrics['tau']:.2f} │ lr={lr:.2e}")

            # Sample every epoch to track progress
            samples = self.sample_images(10)
            grid = make_grid(samples, nrow=10, padding=2)
            sample_path = self.output_dir / f"samples_epoch{epoch+1:03d}.png"
            save_image(grid, sample_path)
            print(f"  → Saved samples")

            # Checkpoint every epoch (latest), milestone every 10
            self.save_checkpoint(epoch + 1, milestone=((epoch + 1) % 10 == 0))

            # Upload to HuggingFace
            metrics = {
                'loss': train_metrics['loss'],
                'val_loss': val_metrics['val_loss'],
                'tau': train_metrics['tau'],
                'lr': lr,
            }
            self._upload_to_hf(epoch + 1, sample_path, metrics)

        # Final samples
        samples = self.sample_images(10)
        grid = make_grid(samples, nrow=10, padding=2)
        final_path = self.output_dir / "samples_final.png"
        save_image(grid, final_path)
        self.save_checkpoint(self.cfg.num_epochs, milestone=True)
        self._upload_to_hf(self.cfg.num_epochs, final_path)
        print(f"\nTraining complete!")

# =============================================================================
# MAIN
# =============================================================================

def main():
    # Lightweight config - 16 towers instead of 32
    cfg = FlowConfig(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        geometric_types=('cantor', 'beatrix'),         # 2 types × pos/neg = 4 per modality
        conv_types=('wide_resnet', 'squeeze_excite'),  # 2 types × pos/neg = 4 per modality
        conv_spatial_size=8,
        bottleneck_dim=256,
        manifold_dim=512,  # Smaller manifold
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
        output_dir="./beatrix_cifar_t5",
    )
    trainer = Trainer(cfg)
    trainer.train()

def main_full():
    """Full 32-tower configuration."""
    cfg = FlowConfig(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        geometric_types=('cantor', 'beatrix', 'helix', 'simplex'),
        conv_types=('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite'),
        conv_spatial_size=8,
        bottleneck_dim=256,
        manifold_dim=1024,
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
        output_dir="./beatrix_cifar_t5",
    )
    trainer = Trainer(cfg)
    trainer.train()

if __name__ == "__main__":
    main()
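# -----------------------------------------------------------------------------
# Inference sketch (illustrative only): pulling a trained checkpoint back from
# the Hub for sampling. Assumes the FlowConfig matches the training run
# (main()'s lightweight settings here; config.json in the repo records the
# actual values), and that the checkpoint layout follows save_checkpoint above.
# -----------------------------------------------------------------------------

def _load_for_inference(repo_id: str = "AbstractPhil/beatrix-diffusion-proto") -> BeatrixFlowT5:
    from huggingface_hub import hf_hub_download
    path = hf_hub_download(repo_id=repo_id, filename="ckpt_latest.pt")
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    cfg = FlowConfig(
        geometric_types=('cantor', 'beatrix'),
        conv_types=('wide_resnet', 'squeeze_excite'),
        manifold_dim=512,
    )
    model = BeatrixFlowT5(cfg)
    # NOTE: if the checkpoint was saved from a compiled wrapper, keys may carry
    # a compile-time prefix and need stripping before load_state_dict.
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model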