|
|
"""
BEATRIX FLOW-MATCHING - CIFAR-10 (T5 Text Encoder)
===================================================

SD 1.5 VAE + Flan-T5-XL text encoder
Dual tower collectives: vision towers + text towers

Text prompts for CIFAR-10 classes:
    "a photo of an airplane"
    "a photo of an automobile"
    etc.

Requirements:
    pip install transformers diffusers torchvision tqdm
    pip install git+https://github.com/AbstractEyes/geofractal

Currently running like a turtle, will optimize tomorrow.

Apache 2.0 license
"""
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, Tuple, Optional, List |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from torchvision import datasets, transforms |
|
|
from torchvision.utils import make_grid, save_image |
|
|
from huggingface_hub import HfApi, upload_file, create_repo |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from geofractal.router.wide_router import WideRouter |
|
|
from geofractal.router.prefab.agatha.beatrix_tension_oscillator import ( |
|
|
BeatrixOscillator, |
|
|
ScheduleType, |
|
|
) |
|
|
from geofractal.router.prefab.geometric_tower_builder import ( |
|
|
TowerConfig, |
|
|
FusionType, |
|
|
ConfigurableCollective, |
|
|
build_tower_collective, |
|
|
preset_pos_neg_pairs, |
|
|
) |
|
|
from geofractal.router.prefab.geometric_conv_tower_builder import ( |
|
|
ConvTowerConfig, |
|
|
ConvTowerCollective, |
|
|
build_conv_collective, |
|
|
preset_conv_pos_neg, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One text prompt per CIFAR-10 class. The dataset looks up cached text
# embeddings by label, so position i must be the prompt for label i.
CIFAR10_PROMPTS = [
    f"a photo of {phrase}"
    for phrase in (
        "an airplane",
        "an automobile",
        "a bird",
        "a cat",
        "a deer",
        "a dog",
        "a frog",
        "a horse",
        "a ship",
        "a truck",
    )
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SD15VAE(nn.Module):
    """Thin wrapper around the Stable Diffusion 1.5 ``AutoencoderKL``.

    ``encode`` multiplies sampled latents by ``scale_factor`` and ``decode``
    divides by it again before decoding, so the two are inverse with respect
    to that scaling. Both run without gradient tracking.
    """

    def __init__(self, freeze: bool = True):
        """Load the pretrained VAE weights.

        Args:
            freeze: When true, put the VAE in eval mode and disable gradients
                on all of its parameters.
        """
        super().__init__()
        # Function-scope import: diffusers is only required once a VAE is built.
        from diffusers import AutoencoderKL

        self.scale_factor = 0.18215
        self.vae = AutoencoderKL.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="vae",
            torch_dtype=torch.float32,
        )

        if not freeze:
            return

        self.vae.eval()
        for param in self.vae.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode(self, x: Tensor) -> Tensor:
        """Return a scaled sample from the VAE posterior for images ``x``."""
        posterior = self.vae.encode(x).latent_dist
        return posterior.sample() * self.scale_factor

    @torch.no_grad()
    def decode(self, z: Tensor) -> Tensor:
        """Undo the latent scaling and decode ``z`` back to image space."""
        decoded = self.vae.decode(z / self.scale_factor)
        return decoded.sample
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class T5TextEncoder(nn.Module):
    """Flan-T5 encoder followed by a small trainable bottleneck MLP.

    ``forward`` returns bottlenecked embeddings; ``encode_raw`` returns the
    encoder's hidden states unchanged (used when caching embeddings).

    The bottleneck is a regular trainable module, so ``forward`` only disables
    gradient tracking around the T5 encoder, and only when the encoder was
    frozen at construction time.
    """

    def __init__(
        self,
        model_name: str = "google/flan-t5-xl",
        freeze: bool = True,
        max_length: int = 77,
        bottleneck_dim: int = 256,
    ):
        """Load tokenizer and encoder weights and build the bottleneck.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            freeze: When true, keep the encoder in eval mode with gradients
                disabled on its parameters.
            max_length: Fixed token length used for padding and truncation.
            bottleneck_dim: Output width of the bottleneck MLP.
        """
        super().__init__()
        # Function-scope import: transformers is only required on construction.
        from transformers import T5EncoderModel, T5Tokenizer

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.max_length = max_length
        self.raw_dim = self.encoder.config.d_model
        self.output_dim = bottleneck_dim
        self._encoder_frozen = freeze

        self.bottleneck = nn.Sequential(
            nn.Linear(self.raw_dim, bottleneck_dim),
            nn.GELU(),
            nn.Linear(bottleneck_dim, bottleneck_dim),
        )

        if freeze:
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad = False

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping a frozen encoder in eval mode.

        Without this, a ``.train()`` call on this module or a parent would
        put the frozen encoder back into training mode.
        """
        super().train(mode)
        if getattr(self, "_encoder_frozen", False):
            self.encoder.eval()
        return self

    def _encode_tokens(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """Tokenize ``texts`` and return ``(last_hidden_state, attention_mask)``."""
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state, attention_mask

    @staticmethod
    def _masked_mean(sequence: Tensor, attention_mask: Tensor) -> Tensor:
        """Average ``sequence`` over dim 1, counting only unmasked positions."""
        mask = attention_mask.unsqueeze(-1).float()
        # Clamp so a row with no attended tokens yields zeros instead of NaN.
        return (sequence * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Encode text prompts with bottleneck.

        Returns:
            sequence: [B, L, bottleneck_dim] - compressed sequence embeddings
            pooled: [B, bottleneck_dim] - compressed mean pooled embedding
        """
        # Gradients through the encoder are only useful when it is not frozen.
        track_encoder = torch.is_grad_enabled() and not self._encoder_frozen
        with torch.set_grad_enabled(track_encoder):
            sequence_raw, attention_mask = self._encode_tokens(texts, device)

        # The bottleneck runs under the caller's grad mode so it can be trained.
        sequence = self.bottleneck(sequence_raw)
        pooled = self._masked_mean(sequence, attention_mask)
        return sequence, pooled

    @torch.no_grad()
    def encode_raw(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Encode text prompts WITHOUT bottleneck (for caching raw embeddings).

        Returns:
            sequence: [B, L, raw_dim] - raw T5 embeddings
            pooled: [B, raw_dim] - raw mean pooled embedding
        """
        sequence, attention_mask = self._encode_tokens(texts, device)
        pooled = self._masked_mean(sequence, attention_mask)
        return sequence, pooled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CachedCIFAR10T5(Dataset):
    """
    Pre-cached CIFAR-10 with VAE latents.
    T5 embeddings are computed per-class (not per-image).

    The first construction for a given split and image size encodes all
    images with ``SD15VAE`` and all ``CIFAR10_PROMPTS`` with the raw T5
    encoder, then writes the result under ``cache_dir``. Later constructions
    load that file instead.

    Each item is ``(latent, text_sequence, text_pooled, label)``; the two text
    tensors are looked up by the item's label.
    """

    T5_MODEL = "google/flan-t5-xl"

    def __init__(
        self,
        train: bool = True,
        image_size: int = 256,
        cache_dir: str = "./cache",
        device: str = "cuda",
    ):
        """Load the cache for this split, building it first if it is missing.

        Args:
            train: Select the training split (True) or the test split (False).
            image_size: Side length images are resized to before VAE encoding.
            cache_dir: Directory that holds the cache files.
            device: Device used for the encoders while building the cache.
        """
        self.train = train

        split = 'train' if train else 'val'
        t5_suffix = self.T5_MODEL.replace("/", "_")
        self.cache_path = Path(cache_dir) / f"cifar10_{t5_suffix}_{split}_{image_size}.pt"

        if self.cache_path.exists():
            print(f"Loading cache: {self.cache_path}")
            # The cache holds only tensors and an int, so the restricted
            # (weights_only) unpickler is sufficient and avoids running
            # arbitrary pickled code from the cache directory.
            cache = torch.load(self.cache_path, map_location="cpu", weights_only=True)
            self.latents = cache['latents']
            self.labels = cache['labels']
            self.text_sequence = cache['text_sequence']
            self.text_pooled = cache['text_pooled']
            self.text_dim = cache.get('text_dim', self.text_pooled.shape[-1])
        else:
            print(f"Building cache for {split} set...")
            self._build_cache(image_size, device)

    def _build_cache(self, image_size: int, device: str):
        """Encode prompts and images, keep the tensors on CPU, and save them."""
        print(" Loading VAE...")
        vae = SD15VAE(freeze=True).to(device)
        print(f" Loading T5 ({self.T5_MODEL})...")
        t5 = T5TextEncoder(model_name=self.T5_MODEL, freeze=True).to(device)

        # Text: one raw (un-bottlenecked) embedding per class prompt.
        print(f" Encoding text prompts (T5 raw_dim={t5.raw_dim})...")
        text_seq, text_pool = t5.encode_raw(CIFAR10_PROMPTS, device)
        self.text_sequence = text_seq.cpu()
        self.text_pooled = text_pool.cpu()
        self.text_dim = t5.raw_dim

        # Images: resize, scale to [-1, 1], then encode with the VAE.
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        dataset = datasets.CIFAR10('./data', train=self.train, download=True, transform=transform)
        loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

        all_latents, all_labels = [], []

        print(" Encoding images...")
        with torch.no_grad():
            for images, labels in tqdm(loader, desc=" Caching", leave=False):
                images = images.to(device)
                all_latents.append(vae.encode(images).cpu())
                all_labels.append(labels)

        self.latents = torch.cat(all_latents, dim=0)
        self.labels = torch.cat(all_labels, dim=0)

        # Release the encoders before the trainer allocates its own model.
        del vae, t5
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Write to a temporary file and rename, so an interrupted run never
        # leaves a truncated cache that a later run would try to load.
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.cache_path.with_name(self.cache_path.name + ".tmp")
        torch.save({
            'latents': self.latents,
            'labels': self.labels,
            'text_sequence': self.text_sequence,
            'text_pooled': self.text_pooled,
            'text_dim': self.text_dim,
        }, tmp_path)
        tmp_path.replace(self.cache_path)
        print(f" Saved: {self.cache_path}")

    def __len__(self):
        """Return the number of cached images."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(latent, text_sequence, text_pooled, label)`` for ``idx``."""
        label = self.labels[idx]
        return (
            self.latents[idx],
            self.text_sequence[label],
            self.text_pooled[label],
            label,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SinusoidalEmbed(nn.Module):
    """Sinusoidal embedding of a scalar per batch element.

    The output concatenates ``cos`` and ``sin`` of the input times ``dim // 2``
    geometrically spaced frequencies (from 1 down towards 1/10000). For an odd
    ``dim`` one zero column is appended so the last dimension is always
    exactly ``dim``.
    """

    def __init__(self, dim: int):
        """Store the embedding width ``dim``."""
        super().__init__()
        self.dim = dim

    def forward(self, t: Tensor) -> Tensor:
        """Embed ``t`` and return a tensor with a trailing dimension of ``dim``."""
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        args = t.unsqueeze(-1) * freqs
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2 == 1:
            # cos/sin halves only cover 2 * (dim // 2) features; pad the last one.
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class FlowConfig:
    """Hyper-parameters shared by the flow model and its trainer."""

    # --- Images and VAE latents ---
    image_size: int = 256
    num_classes: int = 10
    latent_channels: int = 4
    latent_size: int = 32

    # --- Text conditioning ---
    # Width of the raw T5 embeddings; Trainer overwrites this with the
    # text_dim stored in the cached dataset.
    text_raw_dim: int = 2048
    text_seq_len: int = 77
    bottleneck_dim: int = 256

    # --- Geometric (transformer) towers ---
    tower_dim: int = 256
    tower_depth: int = 2
    num_heads: int = 8
    geometric_types: Tuple[str, ...] = ('cantor', 'beatrix', 'helix', 'simplex')

    # --- Conv towers ---
    conv_types: Tuple[str, ...] = ('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite')
    conv_spatial_size: int = 8

    # --- Oscillator / manifold ---
    manifold_dim: int = 1024
    num_tower_pairs: int = 16
    osc_steps: int = 50
    fingerprint_dim: int = 64

    # --- Flow sampling ---
    num_flow_steps: int = 50
    sigma_min: float = 0.001

    # --- Optimisation ---
    batch_size: int = 64
    lr: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 100

    # --- Runtime paths and device ---
    cache_dir: str = "./cache"
    device: str = "cuda"
    output_dir: str = "./beatrix_cifar_t5"

    @property
    def latent_flat_dim(self) -> int:
        """Element count of one flattened latent (4 x 32 x 32 = 4096 by default)."""
        spatial_elems = self.latent_size * self.latent_size
        return self.latent_channels * spatial_elems
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BeatrixFlowT5(WideRouter):
    """
    Flow model with dual tower collectives per modality:

    Vision side:
    - Geometric towers (transformer): cantor, beatrix, helix, simplex (pos/neg)
    - Conv towers: wide_resnet, frequency, bottleneck, squeeze_excite (pos/neg)

    Text side (mirrored):
    - Geometric towers (transformer): cantor, beatrix, helix, simplex (pos/neg)
    - Conv towers: wide_resnet, frequency, bottleneck, squeeze_excite (pos/neg)

    All towers output opinions that combine for velocity prediction.

    Training uses the straight-line interpolation z_t = (1 - t) * z_0 + t * eps
    with target velocity eps - z_0; sampling integrates that velocity
    backwards from t = 1 (noise) to t = 0 with fixed Euler steps.
    """

    def __init__(self, cfg: FlowConfig):
        """Build and attach all sub-modules described by ``cfg``.

        The attach order is kept stable because it determines parameter
        registration order, which saved checkpoints depend on.
        """
        super().__init__(name='beatrix_flow_t5', strict=False, auto_discover=False)
        self.objects['cfg'] = cfg

        # Raw T5 width -> bottleneck, separately for sequence and pooled inputs.
        self.attach('text_bottleneck_seq', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))
        self.attach('text_bottleneck_pool', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))

        # Vision: geometric towers.
        vision_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))

        vision_geo_collective = build_tower_collective(
            configs=vision_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='vision_geo',
        )
        self.attach('vision_geo', vision_geo_collective)

        # Vision: conv towers.
        vision_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))

        vision_conv_collective = build_conv_collective(
            configs=vision_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='vision_conv',
        )
        self.attach('vision_conv', vision_conv_collective)

        # Text: geometric towers (same presets as the vision side).
        text_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))

        text_geo_collective = build_tower_collective(
            configs=text_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='text_geo',
        )
        self.attach('text_geo', text_geo_collective)

        # Text: conv towers.
        text_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))

        text_conv_collective = build_conv_collective(
            configs=text_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='text_conv',
        )
        self.attach('text_conv', text_conv_collective)

        # Latent patch embedding: non-overlapping 4x4 patches -> tower_dim.
        patch_size = 4
        num_patches = (cfg.latent_size // patch_size) ** 2
        patch_dim = cfg.latent_channels * patch_size * patch_size

        self.attach('patch_proj', nn.Linear(patch_dim, cfg.tower_dim))
        self.patch_pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.tower_dim) * 0.02)
        self.objects['patch_size'] = patch_size
        self.objects['num_patches'] = num_patches

        # Oscillator sized for every tower across both modalities.
        num_geo_towers = len(vision_geo_configs)
        num_conv_towers = len(vision_conv_configs)
        total_towers = (num_geo_towers + num_conv_towers) * 2

        oscillator = BeatrixOscillator(
            name='oscillator',
            manifold_dim=cfg.manifold_dim,
            tower_dim=cfg.tower_dim,
            num_tower_pairs=total_towers // 2,
            num_theta_probes=4,
            fingerprint_dim=cfg.fingerprint_dim,
            kappa_schedule=ScheduleType.TESLA_369,
            use_intrinsic_tension=True,
        )
        self.attach('oscillator', oscillator)

        # Timestep embedding.
        time_embed = nn.Sequential(
            SinusoidalEmbed(256),
            nn.Linear(256, cfg.tower_dim),
            nn.GELU(),
            nn.Linear(cfg.tower_dim, cfg.tower_dim),
        )
        self.attach('time_embed', time_embed)

        # Pooled text -> reference point in manifold space.
        self.attach('text_to_ref', nn.Sequential(
            nn.Linear(cfg.bottleneck_dim, cfg.manifold_dim),
            nn.GELU(),
            nn.Linear(cfg.manifold_dim, cfg.manifold_dim),
        ))

        # Time embedding -> additive offset on the reference point.
        self.attach('time_to_ref', nn.Linear(cfg.tower_dim, cfg.manifold_dim))

        # Flattened latent <-> manifold space.
        self.attach('latent_down', nn.Linear(cfg.latent_flat_dim, cfg.manifold_dim))
        self.attach('latent_up', nn.Linear(cfg.manifold_dim, cfg.latent_flat_dim))

        # Learned logit of the spring/tower mixing weight (sigmoid applied on use).
        self.velocity_mix = nn.Parameter(torch.tensor(0.5))

    def patchify(self, z: Tensor) -> Tensor:
        """[B, 4, 32, 32] -> [B, num_patches, tower_dim]"""
        B, C, H, W = z.shape
        p = self.objects['patch_size']

        # Cut non-overlapping p x p windows along H and W, then flatten each
        # window together with its channels into one patch vector.
        z = z.unfold(2, p, p).unfold(3, p, p)
        z = z.permute(0, 2, 3, 1, 4, 5).contiguous()
        z = z.view(B, -1, C * p * p)

        return self['patch_proj'](z) + self.patch_pos_embed

    def get_tower_outputs(self, z: Tensor, text_seq: Tensor) -> List[Tensor]:
        """
        Run all four tower collectives.

        Returns the individual tower opinions as one flat list, ordered
        vision-geometric, vision-conv, text-geometric, text-conv. The list
        length depends on the configured tower types (32 for the full preset).
        """
        patches = self.patchify(z)
        text_bottlenecked = self['text_bottleneck_seq'](text_seq)

        # Geometric collectives return an object exposing ``.opinions``;
        # conv collectives return ``(fused, opinions_dict)``.
        vision_geo = self['vision_geo'](patches)
        vision_conv_fused, vision_conv_ops = self['vision_conv'](patches)
        text_geo = self['text_geo'](text_bottlenecked)
        text_conv_fused, text_conv_ops = self['text_conv'](text_bottlenecked)

        return (
            [op.opinion for op in vision_geo.opinions.values()] +
            list(vision_conv_ops.values()) +
            [op.opinion for op in text_geo.opinions.values()] +
            list(text_conv_ops.values())
        )

    def _predict_velocity(
        self,
        z: Tensor,
        text_seq: Tensor,
        text_ref: Tensor,
        t: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Predict the flattened latent velocity at state ``z`` and time ``t``.

        Shared by ``forward`` and ``sample`` so both use the same computation.

        Args:
            z: Current latent, shaped like the VAE latents.
            text_seq: Raw text sequence embeddings fed to the text towers.
            text_ref: Output of ``text_to_ref`` for the pooled text embedding.
            t: One time value per batch element.

        Returns:
            ``(v_pred, tau)``: the predicted velocity in flattened latent
            space and the current mixing weight.
        """
        z_proj = self['latent_down'](z.flatten(1))

        time_emb = self['time_embed'](t)
        x_ref = text_ref + self['time_to_ref'](time_emb)

        tower_outputs = self.get_tower_outputs(z, text_seq)

        osc = self['oscillator']
        tower_force, _ = osc.force_generator(z_proj, tower_outputs, state_fingerprint=None)
        spring_force = x_ref - z_proj

        # Blend the pull towards the text/time reference with the tower force.
        tau = torch.sigmoid(self.velocity_mix)
        v_pred_proj = (1 - tau) * spring_force + tau * tower_force

        return self['latent_up'](v_pred_proj), tau

    def forward(
        self,
        z_0: Tensor,
        text_seq: Tensor,
        text_pooled: Tensor,
        labels: Tensor,
        t: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Training forward - single step velocity prediction.

        Args:
            z_0: Clean latents.
            text_seq: Raw text sequence embeddings.
            text_pooled: Raw pooled text embeddings.
            labels: Class labels; accepted for interface compatibility and not
                used by the computation.
            t: Optional per-sample times; drawn uniformly from [0, 1) if omitted.

        Returns:
            Dict with ``'loss'`` (MSE between predicted and target velocity)
            and ``'tau'`` (detached mixing weight).
        """
        B = z_0.shape[0]

        if t is None:
            t = torch.rand(B, device=z_0.device)

        # Interpolate between data (t = 0) and noise (t = 1).
        eps = torch.randn_like(z_0)
        t_exp = t.view(B, 1, 1, 1)
        z_t = (1 - t_exp) * z_0 + t_exp * eps

        # d z_t / d t for the straight-line path.
        v_target = eps.flatten(1) - z_0.flatten(1)

        text_ref = self['text_to_ref'](self['text_bottleneck_pool'](text_pooled))
        v_pred, tau = self._predict_velocity(z_t, text_seq, text_ref, t)

        loss = F.mse_loss(v_pred, v_target)

        return {'loss': loss, 'tau': tau.detach()}

    @torch.no_grad()
    def sample(
        self,
        text_seq: Tensor,
        text_pooled: Tensor,
        vae: SD15VAE,
        num_steps: Optional[int] = None,
    ) -> Tensor:
        """Generate samples from text conditioning.

        Starts from Gaussian noise at t = 1 and takes ``num_steps`` Euler
        steps towards t = 0, then decodes the final latent with ``vae``.
        ``num_steps`` falls back to ``cfg.num_flow_steps`` when omitted.
        """
        cfg = self.objects['cfg']
        B = text_seq.shape[0]
        device = text_seq.device
        num_steps = num_steps or cfg.num_flow_steps

        # The text reference does not depend on time or on the latent, so it
        # is computed once instead of once per integration step.
        text_ref = self['text_to_ref'](self['text_bottleneck_pool'](text_pooled))

        latent_shape = (B, cfg.latent_channels, cfg.latent_size, cfg.latent_size)
        z = torch.randn(*latent_shape, device=device)

        dt = 1.0 / num_steps

        for step in range(num_steps):
            t = torch.full((B,), 1.0 - step * dt, device=device)
            v_pred, _ = self._predict_velocity(z, text_seq, text_ref, t)
            # Step against the velocity: it points from data towards noise.
            z = (z.flatten(1) - dt * v_pred).view(latent_shape)

        return vae.decode(z)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Trainer:
    """Training loop for ``BeatrixFlowT5`` on cached CIFAR-10 latents.

    Handles dataset caching, model construction and compilation, checkpoint
    resume (local directory first, then the Hugging Face repo), mixed
    precision training, per-epoch sampling, and best-effort uploads to the
    Hugging Face Hub. Hub failures are reported and never abort training.
    """

    def __init__(self, cfg: FlowConfig):
        """Build datasets, model, optimizer and scheduler, then try to resume."""
        self.cfg = cfg
        self.device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
        self.output_dir = Path(cfg.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        # CUDA autocast / loss scaling only make sense when we really run on CUDA.
        self.use_amp = self.device.type == "cuda"
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        print("\n=== Building Cached Datasets ===")
        # Pass the resolved device, not cfg.device, so cache building also
        # works when CUDA was requested but is not available.
        dataset_device = str(self.device)
        self.train_dataset = CachedCIFAR10T5(train=True, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=dataset_device)
        self.val_dataset = CachedCIFAR10T5(train=False, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=dataset_device)

        # The model's text input width must match the cached embeddings.
        cfg.text_raw_dim = self.train_dataset.text_dim
        print(f"T5 raw dimension: {cfg.text_raw_dim} β bottleneck: {cfg.bottleneck_dim}")

        self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0, pin_memory=True)

        # Per-class text embeddings kept on device for sampling.
        self.text_sequence = self.train_dataset.text_sequence.to(self.device)
        self.text_pooled = self.train_dataset.text_pooled.to(self.device)

        print("\n=== Building Model (Vision + Text Towers) ===")
        self.model = BeatrixFlowT5(cfg).to(self.device)

        if hasattr(torch, 'compile'):
            print("Compiling with WideRouter.prepare_and_compile()...")
            self.model = self.model.prepare_and_compile(
                mode="reduce-overhead",
                fullgraph=False,
            )

        num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {num_params:,}")

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        # The scheduler is stepped once per batch, hence T_max in batches.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=cfg.num_epochs * len(self.train_loader))

        self.start_epoch = 0
        self.hf_repo = "AbstractPhil/beatrix-diffusion-proto"
        self._load_latest_checkpoint()

        self._vae = None

        self._setup_hf_repo()

    def _setup_hf_repo(self):
        """Create HF repo if needed and save initial config."""
        config_dict = {
            'image_size': self.cfg.image_size,
            'num_classes': self.cfg.num_classes,
            'latent_channels': self.cfg.latent_channels,
            'latent_size': self.cfg.latent_size,
            'text_raw_dim': self.cfg.text_raw_dim,
            'bottleneck_dim': self.cfg.bottleneck_dim,
            'tower_dim': self.cfg.tower_dim,
            'tower_depth': self.cfg.tower_depth,
            'num_heads': self.cfg.num_heads,
            'geometric_types': self.cfg.geometric_types,
            'conv_types': self.cfg.conv_types,
            'conv_spatial_size': self.cfg.conv_spatial_size,
            'manifold_dim': self.cfg.manifold_dim,
            'fingerprint_dim': self.cfg.fingerprint_dim,
            'num_flow_steps': self.cfg.num_flow_steps,
        }
        # Written before touching the Hub so the local copy exists even when
        # the Hub is unreachable.
        config_path = self.output_dir / "config.json"
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=2)

        try:
            self.hf_api = HfApi()
            create_repo(self.hf_repo, exist_ok=True, repo_type="model")
            print(f"HF repo: {self.hf_repo}")

            upload_file(
                path_or_fileobj=str(config_path),
                path_in_repo="config.json",
                repo_id=self.hf_repo,
            )
        except Exception as e:
            # Deliberately best-effort: training continues without uploads.
            print(f"HF setup warning: {e}")
            self.hf_api = None

    def _append_metrics(self, epoch: int, metrics: dict) -> Path:
        """Append one JSON line of metrics to the local log and return its path."""
        metrics_path = self.output_dir / "metrics.jsonl"
        with open(metrics_path, 'a') as f:
            f.write(json.dumps({'epoch': epoch, **metrics}) + '\n')
        return metrics_path

    def _upload_to_hf(self, epoch: int, sample_path: Path, metrics: Optional[dict] = None):
        """Upload checkpoint, samples, and metrics to HuggingFace.

        Metrics are always appended to the local ``metrics.jsonl`` first, so
        the local log stays complete when the Hub is disabled or failing.
        """
        metrics_path = self._append_metrics(epoch, metrics) if metrics else None

        if self.hf_api is None:
            return

        try:
            ckpt_path = self.output_dir / "ckpt_latest.pt"
            if ckpt_path.exists():
                upload_file(
                    path_or_fileobj=str(ckpt_path),
                    path_in_repo="ckpt_latest.pt",
                    repo_id=self.hf_repo,
                )

            if sample_path.exists():
                # Stored once under its epoch number and once as "latest".
                upload_file(
                    path_or_fileobj=str(sample_path),
                    path_in_repo=f"samples/epoch_{epoch:03d}.png",
                    repo_id=self.hf_repo,
                )

                upload_file(
                    path_or_fileobj=str(sample_path),
                    path_in_repo="samples/latest.png",
                    repo_id=self.hf_repo,
                )

            if metrics_path is not None:
                upload_file(
                    path_or_fileobj=str(metrics_path),
                    path_in_repo="metrics.jsonl",
                    repo_id=self.hf_repo,
                )

            print(f" β Uploaded to HF")
        except Exception as e:
            # Best-effort upload: report and keep training.
            print(f" β HF upload failed: {e}")

    @staticmethod
    def _read_checkpoint(path):
        """Load a checkpoint dict onto the CPU.

        NOTE(security): ``weights_only=False`` unpickles arbitrary objects.
        That is acceptable for files this script wrote itself, but the same
        call is also used for the checkpoint downloaded from the Hub; only
        point ``hf_repo`` at a repository you trust.
        """
        # Loading onto the CPU lets a CUDA-saved checkpoint resume anywhere;
        # load_state_dict moves the values to the parameters' device.
        return torch.load(path, map_location="cpu", weights_only=False)

    def _load_latest_checkpoint(self):
        """Load most recent checkpoint if available (local or HF)."""
        latest_path = self.output_dir / "ckpt_latest.pt"

        if latest_path.exists():
            print(f"Resuming from local ckpt_latest.pt...")
            ckpt = self._read_checkpoint(latest_path)
        else:
            # Fall back to the newest milestone; names are zero-padded so a
            # lexicographic sort is also a numeric sort.
            ckpts = sorted(self.output_dir.glob("ckpt_epoch*.pt"))
            if ckpts:
                latest_path = ckpts[-1]
                print(f"Resuming from {latest_path.name}...")
                ckpt = self._read_checkpoint(latest_path)
            else:
                try:
                    from huggingface_hub import hf_hub_download
                    print(f"Checking HF for checkpoint...")
                    hf_path = hf_hub_download(
                        repo_id=self.hf_repo,
                        filename="ckpt_latest.pt",
                        local_dir=str(self.output_dir),
                    )
                    print(f"Downloaded checkpoint from HF")
                    ckpt = self._read_checkpoint(hf_path)
                except Exception as e:
                    print(f"No checkpoint found (local or HF): {e}")
                    return

        self.model.load_state_dict(ckpt['model'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.scheduler.load_state_dict(ckpt['scheduler'])
        self.start_epoch = ckpt['epoch']
        print(f" Resumed at epoch {self.start_epoch}")

    def _load_vae(self):
        """Load VAE for sampling (temporary)."""
        print("Loading VAE for sampling...")
        return SD15VAE(freeze=True).to(self.device)

    def _unload_vae(self, vae):
        """Unload VAE after sampling."""
        # Only drops this function's reference; the memory is returned once
        # the caller's reference is gone as well.
        del vae
        torch.cuda.empty_cache()

    def _move_batch(self, batch):
        """Move every tensor of a dataset batch to the training device."""
        return tuple(item.to(self.device) for item in batch)

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Run one optimisation pass over the training set.

        Returns the mean loss and mean mixing weight over all batches.
        """
        self.model.train()
        total_loss, total_tau, n = 0.0, 0.0, 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.cfg.num_epochs}", leave=False)
        for batch in pbar:
            latents, text_seq, text_pooled, labels = self._move_batch(batch)

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                out = self.model(latents, text_seq, text_pooled, labels)
                loss = out['loss']

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

            # Read each scalar back once; every .item() forces a device sync.
            loss_val = loss.item()
            tau_val = out['tau'].item()
            total_loss += loss_val
            total_tau += tau_val
            n += 1

            pbar.set_postfix(loss=f"{loss_val:.4f}", Ο=f"{tau_val:.2f}")

        n = max(n, 1)
        return {'loss': total_loss / n, 'tau': total_tau / n}

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Return the mean flow-matching loss over the validation set.

        The loss uses freshly drawn times and noise, so it varies from call
        to call.
        """
        self.model.eval()
        total_loss, n = 0.0, 0

        for batch in self.val_loader:
            latents, text_seq, text_pooled, labels = self._move_batch(batch)

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                out = self.model(latents, text_seq, text_pooled, labels)
            total_loss += out['loss'].item()
            n += 1

        return {'val_loss': total_loss / max(n, 1)}

    @torch.no_grad()
    def sample_images(self, n_per_class: int = 10) -> Tensor:
        """Generate samples for each class (memory-efficient batched).

        Returns images scaled to [0, 1], ordered class by class.
        """
        self.model.eval()
        torch.cuda.empty_cache()

        # The VAE is only resident while sampling.
        vae = self._load_vae()

        all_samples = []
        batch_size = 10

        for class_idx in range(10):
            for batch_start in range(0, n_per_class, batch_size):
                batch_n = min(batch_size, n_per_class - batch_start)

                # Broadcast the single class embedding across the batch.
                text_seq = self.text_sequence[class_idx:class_idx+1].expand(batch_n, -1, -1)
                text_pooled = self.text_pooled[class_idx:class_idx+1].expand(batch_n, -1)

                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    samples = self.model.sample(text_seq, text_pooled, vae)

                all_samples.append(samples.cpu())

        self._unload_vae(vae)

        samples = torch.cat(all_samples, dim=0).to(self.device)
        # Map from the VAE's [-1, 1] range to [0, 1] for saving.
        return ((samples + 1) / 2).clamp(0, 1)

    @staticmethod
    def _atomic_save(obj, path: Path):
        """Save ``obj`` to a temporary file, then rename it over ``path``.

        An interrupted save therefore never truncates the checkpoint that
        resume relies on.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        torch.save(obj, tmp_path)
        tmp_path.replace(path)

    def save_checkpoint(self, epoch: int, milestone: bool = False):
        """Write ``ckpt_latest.pt`` and, for milestones, a numbered copy."""
        ckpt = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }

        self._atomic_save(ckpt, self.output_dir / "ckpt_latest.pt")

        if milestone:
            self._atomic_save(ckpt, self.output_dir / f"ckpt_epoch{epoch:03d}.pt")

    def train(self):
        """Run the full loop: train, validate, sample, checkpoint, upload."""
        num_geo = len(self.cfg.geometric_types) * 2
        num_conv = len(self.cfg.conv_types) * 2
        total_towers = (num_geo + num_conv) * 2

        print(f"\n{'='*60}")
        print("BEATRIX FLOW - Dual Geometric + Conv Towers (Bottlenecked)")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Geometric towers: {self.cfg.geometric_types} (pos/neg)")
        print(f"Conv towers: {self.cfg.conv_types} (pos/neg)")
        print(f"Tower dim: {self.cfg.tower_dim}")
        print(f"T5 raw β bottleneck: {self.cfg.text_raw_dim} β {self.cfg.bottleneck_dim}")
        print(f"Latent β manifold: {self.cfg.latent_flat_dim} β {self.cfg.manifold_dim}")
        print(f"Total towers: {total_towers}")
        print(f"Batch size: {self.cfg.batch_size}")
        print(f"Epochs: {self.start_epoch}/{self.cfg.num_epochs}")
        print(f"{'='*60}\n")

        for epoch in range(self.start_epoch, self.cfg.num_epochs):
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate()

            lr = self.scheduler.get_last_lr()[0]
            print(f"Epoch {epoch+1:3d} β loss={train_metrics['loss']:.4f} β val={val_metrics['val_loss']:.4f} β Ο={train_metrics['tau']:.2f} β lr={lr:.2e}")

            # Sample grid: one row of 10 images per class.
            samples = self.sample_images(10)
            grid = make_grid(samples, nrow=10, padding=2)
            sample_path = self.output_dir / f"samples_epoch{epoch+1:03d}.png"
            save_image(grid, sample_path)
            print(f" β Saved samples")

            # Checkpoint every epoch; keep a numbered copy every 10th epoch.
            self.save_checkpoint(epoch + 1, milestone=((epoch + 1) % 10 == 0))

            metrics = {
                'loss': train_metrics['loss'],
                'val_loss': val_metrics['val_loss'],
                'tau': train_metrics['tau'],
                'lr': lr,
            }
            self._upload_to_hf(epoch + 1, sample_path, metrics)

        samples = self.sample_images(10)
        grid = make_grid(samples, nrow=10, padding=2)
        final_path = self.output_dir / "samples_final.png"
        save_image(grid, final_path)
        self.save_checkpoint(self.cfg.num_epochs, milestone=True)
        self._upload_to_hf(self.cfg.num_epochs, final_path)
        print(f"\nTraining complete!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Train the reduced configuration: two geometric and two conv tower types."""
    settings = dict(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        geometric_types=('cantor', 'beatrix'),
        conv_types=('wide_resnet', 'squeeze_excite'),
        conv_spatial_size=8,
        bottleneck_dim=256,
        manifold_dim=512,
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
        output_dir="./beatrix_cifar_t5",
    )
    Trainer(FlowConfig(**settings)).train()
|
|
|
|
|
|
|
|
def main_full():
    """Full 32-tower configuration."""
    # All four geometric and all four conv tower types, with the larger
    # manifold; everything else matches main().
    settings = dict(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        geometric_types=('cantor', 'beatrix', 'helix', 'simplex'),
        conv_types=('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite'),
        conv_spatial_size=8,
        bottleneck_dim=256,
        manifold_dim=1024,
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
        output_dir="./beatrix_cifar_t5",
    )
    full_cfg = FlowConfig(**settings)
    Trainer(full_cfg).train()
|
|
|
|
|
|
|
|
# Script entry point: runs the reduced configuration defined in main().
if __name__ == "__main__":
    main()