"""
BEATRIX FLOW-MATCHING - CIFAR-10 (T5 Text Encoder)
===================================================
SD 1.5 VAE + Flan-T5-Large text encoder
Dual tower collectives: vision towers + text towers
Text prompts for CIFAR-10 classes:
"a photo of an airplane"
"a photo of an automobile"
etc.
Requirements:
pip install transformers diffusers torchvision tqdm
pip install git+https://github.com/AbstractEyes/geofractal
Currently running like a turtle, will optimize tomorrow.
apache 2.0 license
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Dict, Tuple, Optional, List
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from huggingface_hub import HfApi, upload_file, create_repo
import json
from tqdm import tqdm
# =============================================================================
# GEOFRACTAL IMPORTS
# =============================================================================
from geofractal.router.wide_router import WideRouter
from geofractal.router.prefab.agatha.beatrix_tension_oscillator import (
BeatrixOscillator,
ScheduleType,
)
from geofractal.router.prefab.geometric_tower_builder import (
TowerConfig,
FusionType,
ConfigurableCollective,
build_tower_collective,
preset_pos_neg_pairs,
)
from geofractal.router.prefab.geometric_conv_tower_builder import (
ConvTowerConfig,
ConvTowerCollective,
build_conv_collective,
preset_conv_pos_neg,
)
# =============================================================================
# CIFAR-10 CLASS PROMPTS
# =============================================================================
CIFAR10_PROMPTS = [
"a photo of an airplane",
"a photo of an automobile",
"a photo of a bird",
"a photo of a cat",
"a photo of a deer",
"a photo of a dog",
"a photo of a frog",
"a photo of a horse",
"a photo of a ship",
"a photo of a truck",
]
# =============================================================================
# SD 1.5 VAE
# =============================================================================
class SD15VAE(nn.Module):
def __init__(self, freeze: bool = True):
super().__init__()
from diffusers import AutoencoderKL
self.vae = AutoencoderKL.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="vae",
torch_dtype=torch.float32,
)
if freeze:
self.vae.eval()
for p in self.vae.parameters():
p.requires_grad = False
self.scale_factor = 0.18215
@torch.no_grad()
def encode(self, x: Tensor) -> Tensor:
return self.vae.encode(x).latent_dist.sample() * self.scale_factor
@torch.no_grad()
def decode(self, z: Tensor) -> Tensor:
return self.vae.decode(z / self.scale_factor).sample
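
# A minimal usage sketch for SD15VAE (illustration only; not called during
# training). Assumes the diffusers weights download successfully. The SD VAE
# downsamples 8x spatially, so a 256x256 RGB image becomes a [4, 32, 32] latent.
def _vae_roundtrip_sketch(device: str = "cuda") -> None:
    vae = SD15VAE(freeze=True).to(device)
    images = torch.rand(2, 3, 256, 256, device=device) * 2 - 1  # [-1, 1] range
    latents = vae.encode(images)   # [2, 4, 32, 32], scaled by 0.18215
    recon = vae.decode(latents)    # [2, 3, 256, 256]
    print(latents.shape, recon.shape)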
# =============================================================================
# FLAN-T5 TEXT ENCODER (Flan-T5-XL by default)
# =============================================================================
class T5TextEncoder(nn.Module):
"""Flan-T5 encoder with bottleneck projection."""
def __init__(
self,
model_name: str = "google/flan-t5-xl",
freeze: bool = True,
max_length: int = 77,
bottleneck_dim: int = 256,
):
super().__init__()
from transformers import T5EncoderModel, T5Tokenizer
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.encoder = T5EncoderModel.from_pretrained(model_name)
self.max_length = max_length
self.raw_dim = self.encoder.config.d_model # 2048 for XL
self.output_dim = bottleneck_dim
# Bottleneck projection
self.bottleneck = nn.Sequential(
nn.Linear(self.raw_dim, bottleneck_dim),
nn.GELU(),
nn.Linear(bottleneck_dim, bottleneck_dim),
)
if freeze:
self.encoder.eval()
for p in self.encoder.parameters():
p.requires_grad = False
        # Note: this bottleneck is never trained here; caching uses encode_raw(),
        # and the flow model learns its own bottleneck (see BeatrixFlowT5).
@torch.no_grad()
def forward(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
"""
Encode text prompts with bottleneck.
Returns:
sequence: [B, L, bottleneck_dim] - compressed sequence embeddings
pooled: [B, bottleneck_dim] - compressed mean pooled embedding
"""
tokens = self.tokenizer(
texts,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)
input_ids = tokens.input_ids.to(device)
attention_mask = tokens.attention_mask.to(device)
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
sequence_raw = outputs.last_hidden_state # [B, L, raw_dim]
# Apply bottleneck
sequence = self.bottleneck(sequence_raw) # [B, L, bottleneck_dim]
# Mean pool over non-padding tokens
mask_expanded = attention_mask.unsqueeze(-1).float()
pooled = (sequence * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
return sequence, pooled
@torch.no_grad()
def encode_raw(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
"""
Encode text prompts WITHOUT bottleneck (for caching raw embeddings).
Returns:
sequence: [B, L, raw_dim] - raw T5 embeddings
pooled: [B, raw_dim] - raw mean pooled embedding
"""
tokens = self.tokenizer(
texts,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_tensors="pt",
)
input_ids = tokens.input_ids.to(device)
attention_mask = tokens.attention_mask.to(device)
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
sequence = outputs.last_hidden_state # [B, L, raw_dim]
# Mean pool over non-padding tokens
mask_expanded = attention_mask.unsqueeze(-1).float()
pooled = (sequence * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
return sequence, pooled
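
# A minimal sketch contrasting the two encoding paths (illustration only).
# Assumes flan-t5-xl weights are available; with max_length=77, forward()
# yields bottlenecked [B, 77, 256] while encode_raw() yields raw [B, 77, 2048].
def _t5_encoder_sketch(device: str = "cuda") -> None:
    t5 = T5TextEncoder(freeze=True).to(device)
    dev = torch.device(device)
    seq_bn, pool_bn = t5(["a photo of a cat"], dev)
    seq_raw, pool_raw = t5.encode_raw(["a photo of a cat"], dev)
    print(seq_bn.shape, pool_bn.shape)    # [1, 77, 256], [1, 256]
    print(seq_raw.shape, pool_raw.shape)  # [1, 77, 2048], [1, 2048]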
# =============================================================================
# CACHED DATASET (VAE latents + T5 text embeddings per class)
# =============================================================================
class CachedCIFAR10T5(Dataset):
"""
Pre-cached CIFAR-10 with VAE latents.
T5 embeddings are computed per-class (not per-image).
"""
T5_MODEL = "google/flan-t5-xl" # Change this to use different T5 variant
def __init__(
self,
train: bool = True,
image_size: int = 256,
cache_dir: str = "./cache",
device: str = "cuda",
):
self.train = train
# Include T5 model name in cache path
t5_suffix = self.T5_MODEL.replace("/", "_")
self.cache_path = Path(cache_dir) / f"cifar10_{t5_suffix}_{'train' if train else 'val'}_{image_size}.pt"
if self.cache_path.exists():
print(f"Loading cache: {self.cache_path}")
cache = torch.load(self.cache_path, weights_only=False)
self.latents = cache['latents']
self.labels = cache['labels']
self.text_sequence = cache['text_sequence'] # [10, L, dim]
self.text_pooled = cache['text_pooled'] # [10, dim]
self.text_dim = cache.get('text_dim', self.text_pooled.shape[-1])
else:
print(f"Building cache for {'train' if train else 'val'} set...")
self._build_cache(image_size, device)
def _build_cache(self, image_size: int, device: str):
# Load encoders
print(" Loading VAE...")
vae = SD15VAE(freeze=True).to(device)
print(f" Loading T5 ({self.T5_MODEL})...")
t5 = T5TextEncoder(model_name=self.T5_MODEL, freeze=True).to(device)
# Encode class prompts - save RAW embeddings (bottleneck is in model)
print(f" Encoding text prompts (T5 raw_dim={t5.raw_dim})...")
text_seq, text_pool = t5.encode_raw(CIFAR10_PROMPTS, device)
self.text_sequence = text_seq.cpu() # [10, L, raw_dim]
self.text_pooled = text_pool.cpu() # [10, raw_dim]
self.text_dim = t5.raw_dim # Store raw dim for bottleneck sizing
# Encode images
transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.CIFAR10('./data', train=self.train, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
all_latents, all_labels = [], []
print(" Encoding images...")
with torch.no_grad():
for images, labels in tqdm(loader, desc=" Caching", leave=False):
images = images.to(device)
all_latents.append(vae.encode(images).cpu())
all_labels.append(labels)
self.latents = torch.cat(all_latents, dim=0)
self.labels = torch.cat(all_labels, dim=0)
del vae, t5
torch.cuda.empty_cache()
# Save
self.cache_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
'latents': self.latents,
'labels': self.labels,
'text_sequence': self.text_sequence,
'text_pooled': self.text_pooled,
'text_dim': self.text_dim,
}, self.cache_path)
print(f" Saved: {self.cache_path}")
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
label = self.labels[idx]
return (
self.latents[idx],
self.text_sequence[label], # [L, raw_dim]
self.text_pooled[label], # [raw_dim]
label,
)
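
# A minimal sketch of consuming the cache (illustration only). Each item pairs
# a per-image VAE latent with the per-class raw T5 embeddings, so text encoding
# costs O(num_classes) rather than O(num_images).
def _cached_dataset_sketch() -> None:
    ds = CachedCIFAR10T5(train=False, image_size=256)
    latent, text_seq, text_pooled, label = ds[0]
    print(latent.shape)       # [4, 32, 32]
    print(text_seq.shape)     # [77, raw_dim] (2048 for flan-t5-xl)
    print(text_pooled.shape)  # [raw_dim]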
# =============================================================================
# SINUSOIDAL EMBEDDING
# =============================================================================
class SinusoidalEmbed(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, t: Tensor) -> Tensor:
half = self.dim // 2
freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
args = t.unsqueeze(-1) * freqs
return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
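
# A quick shape check for SinusoidalEmbed (illustration only): scalar timesteps
# map to dim-d vectors via cos/sin halves over a log-spaced frequency ladder,
# as in the standard transformer positional encoding.
def _sinusoidal_sketch() -> None:
    embed = SinusoidalEmbed(256)
    t = torch.rand(8)  # timesteps in [0, 1]
    emb = embed(t)     # [8, 256], values in [-1, 1]
    print(emb.shape)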
# =============================================================================
# CONFIG
# =============================================================================
@dataclass
class FlowConfig:
image_size: int = 256
num_classes: int = 10
latent_channels: int = 4
latent_size: int = 32
# T5 dimensions
text_raw_dim: int = 2048 # Raw T5-XL output, overridden by dataset
text_seq_len: int = 77
bottleneck_dim: int = 256 # Compressed text dim
# Tower collective (transformer-based)
tower_dim: int = 256
tower_depth: int = 2
num_heads: int = 8
geometric_types: Tuple[str, ...] = ('cantor', 'beatrix', 'helix', 'simplex')
# Conv tower types (convolutional)
conv_types: Tuple[str, ...] = ('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite')
conv_spatial_size: int = 8 # Spatial size for conv towers
# Oscillator
manifold_dim: int = 1024 # Projected manifold (smaller than latent)
num_tower_pairs: int = 16 # 32 towers / 2
osc_steps: int = 50 # For sampling only
fingerprint_dim: int = 64
# Flow
num_flow_steps: int = 50
sigma_min: float = 0.001
# Training
batch_size: int = 64
lr: float = 1e-4
weight_decay: float = 0.01
num_epochs: int = 100
cache_dir: str = "./cache"
device: str = "cuda"
output_dir: str = "./beatrix_cifar_t5"
@property
def latent_flat_dim(self) -> int:
"""Full flattened latent size: 4 Γ— 32 Γ— 32 = 4096"""
return self.latent_channels * self.latent_size * self.latent_size
# =============================================================================
# BEATRIX FLOW MODEL (Vision + Text Towers)
# =============================================================================
class BeatrixFlowT5(WideRouter):
"""
Flow model with dual tower collectives per modality:
Vision side:
- Geometric towers (transformer): cantor, beatrix, helix, simplex (pos/neg)
- Conv towers: wide_resnet, frequency, bottleneck, squeeze_excite (pos/neg)
Text side (mirrored):
- Geometric towers (transformer): cantor, beatrix, helix, simplex (pos/neg)
- Conv towers: wide_resnet, frequency, bottleneck, squeeze_excite (pos/neg)
All towers output opinions that combine for velocity prediction.
"""
def __init__(self, cfg: FlowConfig):
super().__init__(name='beatrix_flow_t5', strict=False, auto_discover=False)
self.objects['cfg'] = cfg
# =================================================================
# TEXT BOTTLENECK (trainable)
# =================================================================
self.attach('text_bottleneck_seq', nn.Sequential(
nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
nn.GELU(),
nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
))
self.attach('text_bottleneck_pool', nn.Sequential(
nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
nn.GELU(),
nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
))
# =================================================================
# VISION GEOMETRIC TOWERS (pos/neg pairs)
# =================================================================
vision_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))
vision_geo_collective = build_tower_collective(
configs=vision_geo_configs,
dim=cfg.tower_dim,
default_depth=cfg.tower_depth,
num_heads=cfg.num_heads,
ffn_mult=4.0,
dropout=0.1,
fingerprint_dim=cfg.fingerprint_dim,
fusion_type='adaptive',
name='vision_geo',
)
self.attach('vision_geo', vision_geo_collective)
# =================================================================
# VISION CONV TOWERS (pos/neg pairs)
# =================================================================
vision_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))
vision_conv_collective = build_conv_collective(
configs=vision_conv_configs,
dim=cfg.tower_dim,
default_depth=cfg.tower_depth,
fingerprint_dim=cfg.fingerprint_dim,
spatial_size=cfg.conv_spatial_size,
name='vision_conv',
)
self.attach('vision_conv', vision_conv_collective)
# =================================================================
# TEXT GEOMETRIC TOWERS (pos/neg pairs) - MIRRORED
# =================================================================
text_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))
text_geo_collective = build_tower_collective(
configs=text_geo_configs,
dim=cfg.tower_dim,
default_depth=cfg.tower_depth,
num_heads=cfg.num_heads,
ffn_mult=4.0,
dropout=0.1,
fingerprint_dim=cfg.fingerprint_dim,
fusion_type='adaptive',
name='text_geo',
)
self.attach('text_geo', text_geo_collective)
# =================================================================
# TEXT CONV TOWERS (pos/neg pairs) - MIRRORED
# =================================================================
text_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))
text_conv_collective = build_conv_collective(
configs=text_conv_configs,
dim=cfg.tower_dim,
default_depth=cfg.tower_depth,
fingerprint_dim=cfg.fingerprint_dim,
spatial_size=cfg.conv_spatial_size,
name='text_conv',
)
self.attach('text_conv', text_conv_collective)
# =================================================================
# PROJECTIONS
# =================================================================
# Latent patchifier
patch_size = 4
num_patches = (cfg.latent_size // patch_size) ** 2
patch_dim = cfg.latent_channels * patch_size * patch_size
self.attach('patch_proj', nn.Linear(patch_dim, cfg.tower_dim))
self.patch_pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.tower_dim) * 0.02)
self.objects['patch_size'] = patch_size
self.objects['num_patches'] = num_patches
        # Text is already at bottleneck_dim, which equals tower_dim (256 here),
        # so no extra projection is needed.
# =================================================================
# OSCILLATOR (for sampling)
# =================================================================
# Total towers: (4 geo + 4 conv) Γ— pos/neg Γ— 2 modalities = 32 towers
num_geo_towers = len(vision_geo_configs)
num_conv_towers = len(vision_conv_configs)
total_towers = (num_geo_towers + num_conv_towers) * 2 # Γ— 2 for vision + text
oscillator = BeatrixOscillator(
name='oscillator',
manifold_dim=cfg.manifold_dim,
tower_dim=cfg.tower_dim,
num_tower_pairs=total_towers // 2,
num_theta_probes=4,
fingerprint_dim=cfg.fingerprint_dim,
kappa_schedule=ScheduleType.TESLA_369,
use_intrinsic_tension=True,
)
self.attach('oscillator', oscillator)
# =================================================================
# CONDITIONING
# =================================================================
# Time embedding
time_embed = nn.Sequential(
SinusoidalEmbed(256),
nn.Linear(256, cfg.tower_dim),
nn.GELU(),
nn.Linear(cfg.tower_dim, cfg.tower_dim),
)
self.attach('time_embed', time_embed)
# Bottlenecked text -> reference anchor
self.attach('text_to_ref', nn.Sequential(
nn.Linear(cfg.bottleneck_dim, cfg.manifold_dim),
nn.GELU(),
nn.Linear(cfg.manifold_dim, cfg.manifold_dim),
))
# Time modulation for reference
self.attach('time_to_ref', nn.Linear(cfg.tower_dim, cfg.manifold_dim))
# =================================================================
# LATENT PROJECTION (4096 <-> manifold_dim)
# =================================================================
self.attach('latent_down', nn.Linear(cfg.latent_flat_dim, cfg.manifold_dim))
self.attach('latent_up', nn.Linear(cfg.manifold_dim, cfg.latent_flat_dim))
# Learnable velocity mixing
self.velocity_mix = nn.Parameter(torch.tensor(0.5))
def patchify(self, z: Tensor) -> Tensor:
"""[B, 4, 32, 32] -> [B, num_patches, tower_dim]"""
B, C, H, W = z.shape
p = self.objects['patch_size']
        # [B, C, H, W] -> [B, C, H/p, W/p, p, p] via two non-overlapping unfolds
        z = z.unfold(2, p, p).unfold(3, p, p)
        # -> [B, H/p, W/p, C, p, p] so each patch is contiguous in memory
        z = z.permute(0, 2, 3, 1, 4, 5).contiguous()
        # -> [B, num_patches, C*p*p]
        z = z.view(B, -1, C * p * p)
return self['patch_proj'](z) + self.patch_pos_embed
def get_tower_outputs(self, z: Tensor, text_seq: Tensor) -> List[Tensor]:
"""
Run all four tower collectives.
Returns list of tower opinions [B, tower_dim] (32 total).
"""
patches = self.patchify(z)
text_bottlenecked = self['text_bottleneck_seq'](text_seq)
# Run all collectives
vision_geo = self['vision_geo'](patches)
vision_conv_fused, vision_conv_ops = self['vision_conv'](patches)
text_geo = self['text_geo'](text_bottlenecked)
text_conv_fused, text_conv_ops = self['text_conv'](text_bottlenecked)
        # Collect all tower opinions in a fixed order:
        # vision geo, vision conv, text geo, text conv
return (
[op.opinion for op in vision_geo.opinions.values()] +
list(vision_conv_ops.values()) +
[op.opinion for op in text_geo.opinions.values()] +
list(text_conv_ops.values())
)
def forward(
self,
z_0: Tensor,
text_seq: Tensor,
text_pooled: Tensor,
labels: Tensor,
t: Optional[Tensor] = None,
) -> Dict[str, Tensor]:
"""Training forward - single step velocity prediction."""
cfg = self.objects['cfg']
B = z_0.shape[0]
device = z_0.device
if t is None:
t = torch.rand(B, device=device)
# Flatten latent [B, 4, 32, 32] -> [B, 4096]
z_0_flat = z_0.flatten(1)
# Noise + interpolate in full latent space
eps = torch.randn_like(z_0)
eps_flat = eps.flatten(1)
t_exp = t.view(B, 1, 1, 1)
z_t = (1 - t_exp) * z_0 + t_exp * eps
z_t_flat = z_t.flatten(1)
# Target velocity (in full latent space)
v_target = eps_flat - z_0_flat
# === PROJECT TO SMALLER MANIFOLD ===
z_t_proj = self['latent_down'](z_t_flat) # [B, 4096] -> [B, manifold_dim]
# Bottleneck pooled text for reference
text_pooled_bn = self['text_bottleneck_pool'](text_pooled)
# Reference from bottlenecked text + time (in manifold space)
time_emb = self['time_embed'](t)
x_ref = self['text_to_ref'](text_pooled_bn) + self['time_to_ref'](time_emb)
# Get all tower outputs (text_seq bottlenecked inside get_tower_outputs)
tower_outputs = self.get_tower_outputs(z_t, text_seq)
# Compute forces in manifold space
osc = self['oscillator']
tower_force, _ = osc.force_generator(z_t_proj, tower_outputs, state_fingerprint=None)
spring_force = x_ref - z_t_proj
# Velocity prediction in manifold space
tau = torch.sigmoid(self.velocity_mix)
v_pred_proj = (1 - tau) * spring_force + tau * tower_force
# === PROJECT BACK TO FULL LATENT ===
v_pred = self['latent_up'](v_pred_proj) # [B, manifold_dim] -> [B, 4096]
loss = F.mse_loss(v_pred, v_target)
return {'loss': loss, 'tau': tau.detach()}
@torch.no_grad()
def sample(
self,
text_seq: Tensor,
text_pooled: Tensor,
vae: SD15VAE,
num_steps: Optional[int] = None,
) -> Tensor:
"""Generate samples from text conditioning."""
cfg = self.objects['cfg']
B = text_seq.shape[0]
device = text_seq.device
num_steps = num_steps or cfg.num_flow_steps
# Bottleneck pooled text once
text_pooled_bn = self['text_bottleneck_pool'](text_pooled)
# Start from noise
z = torch.randn(B, cfg.latent_channels, cfg.latent_size, cfg.latent_size, device=device)
dt = 1.0 / num_steps
for step in range(num_steps):
t_val = 1.0 - step * dt
t = torch.full((B,), t_val, device=device)
time_emb = self['time_embed'](t)
x_ref = self['text_to_ref'](text_pooled_bn) + self['time_to_ref'](time_emb)
z_flat = z.flatten(1)
# Project to manifold
z_proj = self['latent_down'](z_flat)
tower_outputs = self.get_tower_outputs(z, text_seq)
osc = self['oscillator']
tower_force, _ = osc.force_generator(z_proj, tower_outputs, state_fingerprint=None)
spring_force = x_ref - z_proj
tau = torch.sigmoid(self.velocity_mix)
v_pred_proj = (1 - tau) * spring_force + tau * tower_force
# Project back and update
v_pred = self['latent_up'](v_pred_proj)
z_flat = z_flat - dt * v_pred
z = z_flat.view(B, cfg.latent_channels, cfg.latent_size, cfg.latent_size)
return vae.decode(z)
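
# A self-contained toy sketch of the flow-matching recipe used above
# (illustration only; the real model replaces `toy_velocity` with the
# tower/oscillator prediction). Training regresses v = eps - z_0 along the
# straight path z_t = (1 - t) * z_0 + t * eps; sampling integrates that
# velocity field from t = 1 (noise) down to t = 0 with Euler steps, exactly
# as in sample() above.
def _flow_matching_sketch(num_steps: int = 50) -> Tensor:
    def toy_velocity(z_t: Tensor, t: Tensor) -> Tensor:
        # Hypothetical predictor for data concentrated at the origin (z_0 = 0):
        # there z_t = t * eps, so the target velocity eps - z_0 equals z_t / t.
        return z_t / t.clamp_min(1e-3).view(-1, 1)
    z = torch.randn(16, 2)  # start from noise at t = 1
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = torch.full((16,), 1.0 - step * dt)
        z = z - dt * toy_velocity(z, t)  # Euler step toward t = 0
    return z  # should land near the origin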
# =============================================================================
# TRAINER
# =============================================================================
class Trainer:
def __init__(self, cfg: FlowConfig):
self.cfg = cfg
self.device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
self.output_dir = Path(cfg.output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
self.scaler = torch.amp.GradScaler('cuda')
# Dataset
print("\n=== Building Cached Datasets ===")
self.train_dataset = CachedCIFAR10T5(train=True, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=cfg.device)
self.val_dataset = CachedCIFAR10T5(train=False, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=cfg.device)
# Update config with actual T5 raw dimension from cache
cfg.text_raw_dim = self.train_dataset.text_dim
print(f"T5 raw dimension: {cfg.text_raw_dim} β†’ bottleneck: {cfg.bottleneck_dim}")
self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
self.val_loader = DataLoader(self.val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0, pin_memory=True)
# Store raw text embeddings for sampling (bottleneck applied in model)
self.text_sequence = self.train_dataset.text_sequence.to(self.device) # [10, L, raw_dim]
self.text_pooled = self.train_dataset.text_pooled.to(self.device) # [10, raw_dim]
# Model
print("\n=== Building Model (Vision + Text Towers) ===")
self.model = BeatrixFlowT5(cfg).to(self.device)
# Compile
if hasattr(torch, 'compile'):
print("Compiling with WideRouter.prepare_and_compile()...")
self.model = self.model.prepare_and_compile(
mode="reduce-overhead",
fullgraph=False,
)
num_params = sum(p.numel() for p in self.model.parameters())
print(f"Trainable parameters: {num_params:,}")
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=cfg.num_epochs * len(self.train_loader))
# Load most recent checkpoint if exists
self.start_epoch = 0
self.hf_repo = "AbstractPhil/beatrix-diffusion-proto"
self._load_latest_checkpoint()
self._vae = None
# HuggingFace Hub setup
self._setup_hf_repo()
def _setup_hf_repo(self):
"""Create HF repo if needed and save initial config."""
try:
self.hf_api = HfApi()
create_repo(self.hf_repo, exist_ok=True, repo_type="model")
print(f"HF repo: {self.hf_repo}")
# Save config
config_dict = {
'image_size': self.cfg.image_size,
'num_classes': self.cfg.num_classes,
'latent_channels': self.cfg.latent_channels,
'latent_size': self.cfg.latent_size,
'text_raw_dim': self.cfg.text_raw_dim,
'bottleneck_dim': self.cfg.bottleneck_dim,
'tower_dim': self.cfg.tower_dim,
'tower_depth': self.cfg.tower_depth,
'num_heads': self.cfg.num_heads,
'geometric_types': self.cfg.geometric_types,
'conv_types': self.cfg.conv_types,
'conv_spatial_size': self.cfg.conv_spatial_size,
'manifold_dim': self.cfg.manifold_dim,
'fingerprint_dim': self.cfg.fingerprint_dim,
'num_flow_steps': self.cfg.num_flow_steps,
}
config_path = self.output_dir / "config.json"
with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=2)
upload_file(
path_or_fileobj=str(config_path),
path_in_repo="config.json",
repo_id=self.hf_repo,
)
except Exception as e:
print(f"HF setup warning: {e}")
self.hf_api = None
    def _upload_to_hf(self, epoch: int, sample_path: Path, metrics: Optional[dict] = None):
"""Upload checkpoint, samples, and metrics to HuggingFace."""
if self.hf_api is None:
return
try:
# Upload checkpoint
ckpt_path = self.output_dir / "ckpt_latest.pt"
if ckpt_path.exists():
upload_file(
path_or_fileobj=str(ckpt_path),
path_in_repo="ckpt_latest.pt",
repo_id=self.hf_repo,
)
# Upload samples
if sample_path.exists():
upload_file(
path_or_fileobj=str(sample_path),
path_in_repo=f"samples/epoch_{epoch:03d}.png",
repo_id=self.hf_repo,
)
# Also as latest
upload_file(
path_or_fileobj=str(sample_path),
path_in_repo="samples/latest.png",
repo_id=self.hf_repo,
)
# Upload metrics log
if metrics:
metrics_path = self.output_dir / "metrics.jsonl"
with open(metrics_path, 'a') as f:
f.write(json.dumps({'epoch': epoch, **metrics}) + '\n')
upload_file(
path_or_fileobj=str(metrics_path),
path_in_repo="metrics.jsonl",
repo_id=self.hf_repo,
)
print(f" β†’ Uploaded to HF")
except Exception as e:
print(f" β†’ HF upload failed: {e}")
def _load_latest_checkpoint(self):
"""Load most recent checkpoint if available (local or HF)."""
latest_path = self.output_dir / "ckpt_latest.pt"
# Try local first
if latest_path.exists():
print(f"Resuming from local ckpt_latest.pt...")
ckpt = torch.load(latest_path, weights_only=False)
else:
# Fall back to numbered checkpoints
ckpts = sorted(self.output_dir.glob("ckpt_epoch*.pt"))
if ckpts:
latest_path = ckpts[-1]
print(f"Resuming from {latest_path.name}...")
ckpt = torch.load(latest_path, weights_only=False)
else:
# Try downloading from HuggingFace
try:
from huggingface_hub import hf_hub_download
print(f"Checking HF for checkpoint...")
hf_path = hf_hub_download(
repo_id=self.hf_repo,
filename="ckpt_latest.pt",
local_dir=str(self.output_dir),
)
print(f"Downloaded checkpoint from HF")
ckpt = torch.load(hf_path, weights_only=False)
except Exception as e:
print(f"No checkpoint found (local or HF): {e}")
return
self.model.load_state_dict(ckpt['model'])
self.optimizer.load_state_dict(ckpt['optimizer'])
self.scheduler.load_state_dict(ckpt['scheduler'])
self.start_epoch = ckpt['epoch']
print(f" Resumed at epoch {self.start_epoch}")
def _load_vae(self):
"""Load VAE for sampling (temporary)."""
print("Loading VAE for sampling...")
return SD15VAE(freeze=True).to(self.device)
def _unload_vae(self, vae):
"""Unload VAE after sampling."""
del vae
torch.cuda.empty_cache()
def train_epoch(self, epoch: int) -> Dict[str, float]:
self.model.train()
total_loss, total_tau, n = 0.0, 0.0, 0
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.cfg.num_epochs}", leave=False)
for latents, text_seq, text_pooled, labels in pbar:
latents = latents.to(self.device)
text_seq = text_seq.to(self.device)
text_pooled = text_pooled.to(self.device)
labels = labels.to(self.device)
with torch.amp.autocast('cuda'):
out = self.model(latents, text_seq, text_pooled, labels)
loss = out['loss']
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
self.scheduler.step()
total_loss += loss.item()
total_tau += out['tau'].item()
n += 1
pbar.set_postfix(loss=f"{loss.item():.4f}", Ο„=f"{out['tau'].item():.2f}")
return {'loss': total_loss / n, 'tau': total_tau / n}
@torch.no_grad()
def validate(self) -> Dict[str, float]:
self.model.eval()
total_loss, n = 0.0, 0
for latents, text_seq, text_pooled, labels in self.val_loader:
latents = latents.to(self.device)
text_seq = text_seq.to(self.device)
text_pooled = text_pooled.to(self.device)
labels = labels.to(self.device)
with torch.amp.autocast('cuda'):
out = self.model(latents, text_seq, text_pooled, labels)
total_loss += out['loss'].item()
n += 1
return {'val_loss': total_loss / n}
@torch.no_grad()
def sample_images(self, n_per_class: int = 10) -> Tensor:
"""Generate samples for each class (memory-efficient batched)."""
self.model.eval()
torch.cuda.empty_cache()
# Load VAE temporarily
vae = self._load_vae()
all_samples = []
batch_size = 10 # Generate 10 images at a time
        for class_idx in range(self.cfg.num_classes):
# Generate n_per_class images for this class
for batch_start in range(0, n_per_class, batch_size):
batch_n = min(batch_size, n_per_class - batch_start)
text_seq = self.text_sequence[class_idx:class_idx+1].expand(batch_n, -1, -1)
text_pooled = self.text_pooled[class_idx:class_idx+1].expand(batch_n, -1)
with torch.amp.autocast('cuda'):
samples = self.model.sample(text_seq, text_pooled, vae)
all_samples.append(samples.cpu())
# Unload VAE
self._unload_vae(vae)
samples = torch.cat(all_samples, dim=0).to(self.device)
return ((samples + 1) / 2).clamp(0, 1)
def save_checkpoint(self, epoch: int, milestone: bool = False):
ckpt = {
'epoch': epoch,
'model': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict(),
}
# Always save latest (for resume)
torch.save(ckpt, self.output_dir / "ckpt_latest.pt")
# Save milestone checkpoints
if milestone:
torch.save(ckpt, self.output_dir / f"ckpt_epoch{epoch:03d}.pt")
def train(self):
num_geo = len(self.cfg.geometric_types) * 2 # pos/neg
num_conv = len(self.cfg.conv_types) * 2
total_towers = (num_geo + num_conv) * 2 # Γ— 2 modalities
print(f"\n{'='*60}")
print("BEATRIX FLOW - Dual Geometric + Conv Towers (Bottlenecked)")
print(f"{'='*60}")
print(f"Device: {self.device}")
print(f"Geometric towers: {self.cfg.geometric_types} (pos/neg)")
print(f"Conv towers: {self.cfg.conv_types} (pos/neg)")
print(f"Tower dim: {self.cfg.tower_dim}")
print(f"T5 raw β†’ bottleneck: {self.cfg.text_raw_dim} β†’ {self.cfg.bottleneck_dim}")
print(f"Latent β†’ manifold: {self.cfg.latent_flat_dim} β†’ {self.cfg.manifold_dim}")
print(f"Total towers: {total_towers}")
print(f"Batch size: {self.cfg.batch_size}")
print(f"Epochs: {self.start_epoch}/{self.cfg.num_epochs}")
print(f"{'='*60}\n")
for epoch in range(self.start_epoch, self.cfg.num_epochs):
train_metrics = self.train_epoch(epoch)
val_metrics = self.validate()
lr = self.scheduler.get_last_lr()[0]
print(f"Epoch {epoch+1:3d} β”‚ loss={train_metrics['loss']:.4f} β”‚ val={val_metrics['val_loss']:.4f} β”‚ Ο„={train_metrics['tau']:.2f} β”‚ lr={lr:.2e}")
# Sample every epoch to track progress
samples = self.sample_images(10)
grid = make_grid(samples, nrow=10, padding=2)
sample_path = self.output_dir / f"samples_epoch{epoch+1:03d}.png"
save_image(grid, sample_path)
print(f" β†’ Saved samples")
# Checkpoint every epoch (latest), milestone every 10
self.save_checkpoint(epoch + 1, milestone=((epoch + 1) % 10 == 0))
# Upload to HuggingFace
metrics = {
'loss': train_metrics['loss'],
'val_loss': val_metrics['val_loss'],
'tau': train_metrics['tau'],
'lr': lr,
}
self._upload_to_hf(epoch + 1, sample_path, metrics)
samples = self.sample_images(10)
grid = make_grid(samples, nrow=10, padding=2)
final_path = self.output_dir / "samples_final.png"
save_image(grid, final_path)
self.save_checkpoint(self.cfg.num_epochs, milestone=True)
self._upload_to_hf(self.cfg.num_epochs, final_path)
print(f"\nTraining complete!")
# =============================================================================
# MAIN
# =============================================================================
def main():
# Lightweight config - 16 towers instead of 32
cfg = FlowConfig(
image_size=256,
tower_dim=256,
tower_depth=2,
num_heads=8,
geometric_types=('cantor', 'beatrix'), # 2 types Γ— pos/neg = 4 per modality
conv_types=('wide_resnet', 'squeeze_excite'), # 2 types Γ— pos/neg = 4 per modality
conv_spatial_size=8,
bottleneck_dim=256,
manifold_dim=512, # Smaller manifold
batch_size=64,
num_epochs=100,
cache_dir="./cache",
output_dir="./beatrix_cifar_t5",
)
trainer = Trainer(cfg)
trainer.train()
def main_full():
"""Full 32-tower configuration."""
cfg = FlowConfig(
image_size=256,
tower_dim=256,
tower_depth=2,
num_heads=8,
geometric_types=('cantor', 'beatrix', 'helix', 'simplex'),
conv_types=('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite'),
conv_spatial_size=8,
bottleneck_dim=256,
manifold_dim=1024,
batch_size=64,
num_epochs=100,
cache_dir="./cache",
output_dir="./beatrix_cifar_t5",
)
trainer = Trainer(cfg)
trainer.train()
if __name__ == "__main__":
main()