#!/usr/bin/env python3 # n.py — Joint AR+SAT Trainer with Expansion Ratio Testing # Enhanced inference: checkpoint name, tok/s, UK time from __future__ import annotations import argparse, json, math, pathlib, random, time, os, sys, threading, hashlib, re, subprocess from pathlib import Path from contextlib import nullcontext from typing import Dict, Any, List, Optional, Tuple from datetime import datetime, timezone STATUS_SCRIPT_PATH = Path(__file__).resolve() STATUS_DEFAULT_LOG = STATUS_SCRIPT_PATH.parent / "train.log" STATUS_DEFAULT_SAVE_DIR = STATUS_SCRIPT_PATH.parent / "ckpts_expansion" _STATUS_PROGRESS_RE = re.compile( r"^\[(?P\d+(?:\.\d+)?)%\]\s+" r"(?P[\d,]+)/(?P[\d,]+)\s+tok\s+\|\s+" r"(?P[\d.]+)\s+tok/s\s+\|\s+" r"loss=(?P-?[\d.]+)\s+B=(?P\d+)\s+L=(?P\d+)\s*$" ) _STATUS_DELTA_RE = re.compile(r"\[delta\]\s+saved\s+(?P\S+?\.pt)\s+\((?P[0-9a-f]+)\.\.\.\)") _STATUS_STEP_RE = re.compile(r"step(?P\d+)") def _status_iso(ts: Optional[float]) -> Optional[str]: if ts is None: return None return datetime.fromtimestamp(ts, tz=timezone.utc).astimezone().isoformat(timespec="seconds") def _status_human_duration(seconds: Optional[float]) -> Optional[str]: if seconds is None: return None total = max(0, int(seconds)) days, rem = divmod(total, 86400) hours, rem = divmod(rem, 3600) minutes, secs = divmod(rem, 60) parts = [] if days: parts.append(f"{days}d") if hours or parts: parts.append(f"{hours}h") if minutes or parts: parts.append(f"{minutes}m") parts.append(f"{secs}s") return " ".join(parts) def _status_format_int(value: Optional[int]) -> str: return "?" if value is None else f"{value:,}" def _status_parse_step(text: str) -> Optional[int]: match = _STATUS_STEP_RE.search(text) return int(match.group("step")) if match else None def _status_resolve_ckpt_path(raw_path: str, base_dir: Path) -> Path: ckpt_path = Path(raw_path) return ckpt_path if ckpt_path.is_absolute() else (base_dir / ckpt_path).resolve() def _status_read_cmdline(proc_dir: Path) -> Optional[List[str]]: try: data = (proc_dir / "cmdline").read_bytes().split(b"\0") return [item.decode("utf-8", errors="ignore") for item in data if item] except Exception: return None def _status_resolve_proc_arg(proc_dir: Path, raw_arg: str) -> Optional[Path]: try: arg_path = Path(raw_arg) if arg_path.is_absolute(): return arg_path.resolve() cwd = Path(os.readlink(proc_dir / "cwd")) return (cwd / arg_path).resolve() except Exception: return None def _status_proc_uptime(proc_dir: Path) -> Optional[float]: try: proc_uptime = float((Path("/proc") / "uptime").read_text().split()[0]) stat_text = (proc_dir / "stat").read_text() after = stat_text[stat_text.rfind(")") + 2:].split() start_ticks = float(after[19]) clock_ticks = os.sysconf(os.sysconf_names["SC_CLK_TCK"]) return max(0.0, proc_uptime - (start_ticks / clock_ticks)) except Exception: return None def _status_find_trainers(script_path: Path) -> List[Dict[str, Any]]: matches: List[Dict[str, Any]] = [] for proc_dir in Path("/proc").iterdir(): if not proc_dir.name.isdigit(): continue args = _status_read_cmdline(proc_dir) if not args or "train" not in args: continue resolved_script = None for arg in args: if Path(arg).name != script_path.name: continue candidate = _status_resolve_proc_arg(proc_dir, arg) if candidate == script_path: resolved_script = candidate break if resolved_script is None: continue uptime_seconds = _status_proc_uptime(proc_dir) try: cwd = str(Path(os.readlink(proc_dir / "cwd"))) except Exception: cwd = None matches.append({ "pid": int(proc_dir.name), "cmdline": " ".join(args), "args": args, "cwd": cwd, "uptime_seconds": round(uptime_seconds, 3) if uptime_seconds is not None else None, "uptime_human": _status_human_duration(uptime_seconds), }) return sorted(matches, key=lambda item: item["pid"]) def _status_parse_progress_line(line: str) -> Optional[Dict[str, Any]]: match = _STATUS_PROGRESS_RE.match(line.strip()) if not match: return None tok_per_sec = float(match.group("tok_s")) loss = float(match.group("loss")) return { "raw_line": line.strip(), "percent": float(match.group("percent")), "seen_tokens": int(match.group("seen").replace(",", "")), "target_tokens": int(match.group("target").replace(",", "")), "tok_per_sec": int(tok_per_sec) if tok_per_sec.is_integer() else tok_per_sec, "loss": loss, "batch": int(match.group("batch")), "block": int(match.group("block")), } def _status_parse_delta_line(line: str) -> Optional[Dict[str, Any]]: match = _STATUS_DELTA_RE.search(line) if not match: return None name = match.group("name") return { "raw_line": line.strip(), "name": name, "step": _status_parse_step(name), "sha_prefix": match.group("sha"), "source": "log", } def _status_scan_log(log_path: Path) -> tuple[Dict[str, Any], Optional[Dict[str, Any]], Optional[Dict[str, Any]], List[str]]: now = time.time() info: Dict[str, Any] = { "path": str(log_path), "exists": log_path.exists(), "mtime": None, "mtime_iso": None, "age_seconds": None, "age_human": None, "size_bytes": None, } warnings: List[str] = [] if not log_path.exists(): warnings.append(f"train log missing: {log_path}") return info, None, None, warnings try: st = log_path.stat() info["mtime"] = st.st_mtime info["mtime_iso"] = _status_iso(st.st_mtime) info["age_seconds"] = round(max(0.0, now - st.st_mtime), 3) info["age_human"] = _status_human_duration(info["age_seconds"]) info["size_bytes"] = st.st_size except Exception as exc: warnings.append(f"failed to stat train log: {exc}") last_progress = None last_delta = None try: with log_path.open("r", encoding="utf-8", errors="ignore") as handle: for raw_line in handle: line = raw_line.rstrip("\n") progress = _status_parse_progress_line(line) if progress is not None: last_progress = progress delta = _status_parse_delta_line(line) if delta is not None: last_delta = delta except Exception as exc: warnings.append(f"failed to read train log: {exc}") return info, last_progress, last_delta, warnings def _status_latest_full_checkpoint(save_dir: Path, base_dir: Path) -> tuple[Dict[str, Any], List[str]]: latest_path = save_dir / "latest.json" info: Dict[str, Any] = { "metadata_path": str(latest_path), "exists": latest_path.exists(), "raw_path": None, "checkpoint_path": None, "checkpoint_name": None, "checkpoint_exists": None, "step": None, "checkpoint_mtime": None, "checkpoint_mtime_iso": None, } warnings: List[str] = [] if not latest_path.exists(): warnings.append(f"latest.json missing: {latest_path}") return info, warnings try: payload = json.loads(latest_path.read_text(encoding="utf-8")) except Exception as exc: warnings.append(f"failed to parse latest.json: {exc}") return info, warnings raw_path = payload.get("path") info["raw_path"] = raw_path info["step"] = payload.get("step") if raw_path: ckpt_path = _status_resolve_ckpt_path(raw_path, base_dir) info["checkpoint_path"] = str(ckpt_path) info["checkpoint_name"] = ckpt_path.name info["checkpoint_exists"] = ckpt_path.exists() if ckpt_path.exists(): try: st = ckpt_path.stat() info["checkpoint_mtime"] = st.st_mtime info["checkpoint_mtime_iso"] = _status_iso(st.st_mtime) except Exception as exc: warnings.append(f"failed to stat full checkpoint: {exc}") else: warnings.append(f"latest.json points to missing checkpoint: {ckpt_path}") return info, warnings def _status_newest_delta(save_dir: Path) -> tuple[Optional[Dict[str, Any]], List[str]]: warnings: List[str] = [] if not save_dir.exists(): warnings.append(f"save dir missing: {save_dir}") return None, warnings try: candidates = [item for item in save_dir.glob("*_delta_step*.pt") if item.is_file()] except Exception as exc: warnings.append(f"failed to list delta checkpoints: {exc}") return None, warnings if not candidates: warnings.append(f"no delta checkpoints found in {save_dir}") return None, warnings newest = max(candidates, key=lambda item: item.stat().st_mtime) st = newest.stat() return { "path": str(newest), "name": newest.name, "step": _status_parse_step(newest.name), "mtime": st.st_mtime, "mtime_iso": _status_iso(st.st_mtime), "size_bytes": st.st_size, "source": "disk", }, warnings def _status_gpu_info() -> tuple[Optional[Dict[str, Any]], List[str]]: warnings: List[str] = [] try: result = subprocess.run( [ "nvidia-smi", "--query-gpu=name,utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw", "--format=csv,noheader,nounits", ], capture_output=True, text=True, timeout=5, check=False, ) except FileNotFoundError: return None, warnings except Exception as exc: warnings.append(f"failed to query GPU status: {exc}") return None, warnings if result.returncode != 0: warnings.append(result.stderr.strip() or "nvidia-smi returned non-zero exit status") return None, warnings lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] if not lines: return None, warnings if len(lines) > 1: warnings.append("multiple GPUs detected; reporting the first GPU only") parts = [part.strip() for part in lines[0].split(",")] if len(parts) != 6: warnings.append(f"unexpected nvidia-smi format: {lines[0]}") return None, warnings def _parse_int(raw: str) -> Optional[int]: try: return int(float(raw)) except Exception: return None def _parse_float(raw: str) -> Optional[float]: try: return float(raw) except Exception: return None return { "name": parts[0], "utilization_gpu": _parse_int(parts[1]), "memory_used_mib": _parse_int(parts[2]), "memory_total_mib": _parse_int(parts[3]), "temperature_c": _parse_int(parts[4]), "power_draw_w": _parse_float(parts[5]), }, warnings def _status_choose_delta(from_log: Optional[Dict[str, Any]], from_disk: Optional[Dict[str, Any]], warnings: List[str]) -> Optional[Dict[str, Any]]: if from_log and from_disk: log_step = from_log.get("step") disk_step = from_disk.get("step") if log_step is not None and disk_step is not None: if log_step != disk_step: warnings.append( f"log delta step {log_step} and newest on-disk delta step {disk_step} differ; using the newer step" ) if disk_step >= log_step: merged = dict(from_disk) merged["source"] = "disk+log" if disk_step == log_step else "disk" if disk_step == log_step: merged["sha_prefix"] = from_log.get("sha_prefix") return merged return dict(from_log) return dict(from_disk) if from_disk: return dict(from_disk) if from_log: return dict(from_log) return None def _collect_status(log_path: Path, save_dir: Path) -> tuple[Dict[str, Any], int]: checked_at = time.time() requested_save_dir = save_dir.expanduser() log_path = log_path.expanduser() status: Dict[str, Any] = { "checked_at": checked_at, "checked_at_iso": _status_iso(checked_at), "running": False, "process": None, "progress": None, "delta_checkpoint": None, "delta_from_log": None, "delta_on_disk": None, "latest_full_checkpoint": None, "log": None, "gpu": None, "save_dir": { "requested_path": str(requested_save_dir), "path": str(requested_save_dir), "exists": requested_save_dir.exists(), "source": "requested", }, "warnings": [], } warnings = status["warnings"] matches = _status_find_trainers(STATUS_SCRIPT_PATH) if len(matches) > 1: status["error"] = "multiple active n.py train processes found" status["processes"] = matches return status, 1 if matches: status["running"] = True status["process"] = matches[0] save_dir = requested_save_dir if status["process"] and status["process"].get("cwd"): proc_cwd = Path(status["process"]["cwd"]) alt_save_dir = (proc_cwd / requested_save_dir.name).resolve() if alt_save_dir != requested_save_dir and alt_save_dir.exists(): requested_delta, _ = _status_newest_delta(requested_save_dir) requested_full, _ = _status_latest_full_checkpoint(requested_save_dir, STATUS_SCRIPT_PATH.parent) alt_delta, _ = _status_newest_delta(alt_save_dir) alt_full, _ = _status_latest_full_checkpoint(alt_save_dir, proc_cwd) requested_score = int(requested_delta is not None) + int(bool(requested_full.get("checkpoint_exists"))) alt_score = int(alt_delta is not None) + int(bool(alt_full.get("checkpoint_exists"))) if alt_score > requested_score: save_dir = alt_save_dir status["save_dir"] = { "requested_path": str(requested_save_dir), "path": str(save_dir), "exists": save_dir.exists(), "source": "process_cwd_fallback", } warnings.append( f"using process cwd save dir fallback: {save_dir} (requested {requested_save_dir})" ) log_info, progress, delta_from_log, log_warnings = _status_scan_log(log_path) warnings.extend(log_warnings) status["log"] = log_info status["progress"] = progress status["delta_from_log"] = delta_from_log latest_base_dir = STATUS_SCRIPT_PATH.parent if status["save_dir"].get("source") == "process_cwd_fallback" and status["process"] and status["process"].get("cwd"): latest_base_dir = Path(status["process"]["cwd"]) latest_full, latest_warnings = _status_latest_full_checkpoint(save_dir, latest_base_dir) warnings.extend(latest_warnings) status["latest_full_checkpoint"] = latest_full delta_on_disk, delta_warnings = _status_newest_delta(save_dir) warnings.extend(delta_warnings) status["delta_on_disk"] = delta_on_disk status["delta_checkpoint"] = _status_choose_delta(delta_from_log, delta_on_disk, warnings) gpu, gpu_warnings = _status_gpu_info() warnings.extend(gpu_warnings) status["gpu"] = gpu if status["running"] and log_info.get("age_seconds") is not None and log_info["age_seconds"] > 600: warnings.append(f"train log appears stale while trainer is running ({log_info['age_human']} old)") if log_info.get("exists") and progress is None: warnings.append("no parseable progress line found in train log") latest_step = latest_full.get("step") if latest_full else None delta_step = status["delta_checkpoint"].get("step") if status["delta_checkpoint"] else None if latest_step is not None and delta_step is not None and latest_step < delta_step: warnings.append(f"latest.json step {latest_step} lags newest delta step {delta_step}") if not status["running"] and progress is None: warnings.append("no active trainer process found") return status, 0 def _format_status_text(status: Dict[str, Any]) -> str: lines = [f"AGILLM status @ {status.get('checked_at_iso')}"] if status.get("error"): lines.append(f"Error: {status['error']}") for proc in status.get("processes", []): lines.append(f"- pid {proc.get('pid')}: {proc.get('cmdline')}") return "\n".join(lines) process = status.get("process") if status.get("running") and process: lines.append(f"Process: RUNNING | pid {process.get('pid')} | uptime {process.get('uptime_human') or 'unknown'}") lines.append(f"Cmd: {process.get('cmdline')}") else: lines.append("Process: NOT RUNNING") progress = status.get("progress") if progress: lines.append( "Progress: " f"{progress['percent']:.1f}% | " f"{_status_format_int(progress['seen_tokens'])}/{_status_format_int(progress['target_tokens'])} tok | " f"{progress['tok_per_sec']} tok/s | loss {progress['loss']:.3f} | " f"B={progress['batch']} L={progress['block']}" ) else: lines.append("Progress: unavailable") log_info = status.get("log") or {} if log_info.get("exists"): lines.append( f"Log: {log_info.get('path')} | updated {log_info.get('age_human') or 'unknown'} ago | " f"mtime {log_info.get('mtime_iso')}" ) else: lines.append(f"Log: missing ({log_info.get('path')})") delta = status.get("delta_checkpoint") if delta: line = f"Delta: {delta.get('name')} | step {delta.get('step')} | source {delta.get('source')}" if delta.get("path"): line += f" | {delta['path']}" lines.append(line) else: lines.append("Delta: unavailable") latest_full = status.get("latest_full_checkpoint") or {} if latest_full.get("exists"): lines.append( f"Latest full: step {latest_full.get('step')} | {latest_full.get('checkpoint_path') or latest_full.get('raw_path')}" ) else: lines.append(f"Latest full: unavailable ({latest_full.get('metadata_path')})") gpu = status.get("gpu") if gpu: lines.append( f"GPU: {gpu.get('name')} | {gpu.get('utilization_gpu')}% | " f"{gpu.get('memory_used_mib')}/{gpu.get('memory_total_mib')} MiB | " f"{gpu.get('temperature_c')}C | {gpu.get('power_draw_w')} W" ) warnings = status.get("warnings") or [] if warnings: lines.append("Warnings:") lines.extend(f"- {warning}" for warning in warnings) return "\n".join(lines) def _emit_status(log_path: Path, save_dir: Path, as_json: bool) -> int: status, exit_code = _collect_status(log_path, save_dir) if as_json: print(json.dumps(status, indent=2, sort_keys=True)) else: print(_format_status_text(status)) return exit_code def _run_status_command(argv: List[str]) -> int: parser = argparse.ArgumentParser(prog=f"{STATUS_SCRIPT_PATH.name} status", description="Read-only training status") parser.add_argument("--json", dest="json_output", action="store_true", help="Emit machine-readable JSON") parser.add_argument("--log", type=Path, default=STATUS_DEFAULT_LOG, help="Path to the training log") parser.add_argument("--save_dir", type=Path, default=STATUS_DEFAULT_SAVE_DIR, help="Checkpoint directory") args = parser.parse_args(argv) return _emit_status(args.log, args.save_dir, args.json_output) def _maybe_handle_status_fastpath() -> None: if len(sys.argv) > 1 and sys.argv[1] == "status": raise SystemExit(_run_status_command(sys.argv[2:])) _maybe_handle_status_fastpath() import torch # SafeProgress - Claude-safe progress (discrete lines, not single growing line) class SafeProgress: def __init__(self, total, initial=0, unit="tok", print_every=500): self.total, self.n, self.unit = total, initial, unit self.initial = initial self.last_print, self.postfix = initial, {} self.start_time = __import__('time').time() def update(self, n=1): self.n += n if self.n - self.last_print >= 1000000: # print every ~1M tokens self._print(); self.last_print = self.n def set_postfix(self, **kwargs): self.postfix = kwargs def _print(self): elapsed = __import__('time').time() - self.start_time rate = (self.n - self.initial) / elapsed if elapsed > 0 else 0 pct = 100 * self.n / self.total if self.total > 0 else 0 pf = ' '.join(f"{k}={v}" for k,v in self.postfix.items()) print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:.0f} tok/s | {pf}") def close(self): self._print(); print("Done.") import torch.nn as nn import torch.nn.functional as F from datasets import load_dataset, DownloadConfig from transformers import AutoTokenizer, logging as hf_log # from tqdm.auto import tqdm # DISABLED - kills Claude context # ─────────────────────────────── HOT DATASET LOADING ─────────────────────────────── HOT_CONFIG_PATH = Path("/workspace/hot_config.json") _hot_config_cache = {"mtime": 0, "data": {}} def get_hot_config() -> dict: """Load hot_config.json with caching, return empty dict if missing""" try: if HOT_CONFIG_PATH.exists(): mtime = HOT_CONFIG_PATH.stat().st_mtime if mtime > _hot_config_cache["mtime"]: with open(HOT_CONFIG_PATH) as f: _hot_config_cache["data"] = json.load(f) _hot_config_cache["mtime"] = mtime return _hot_config_cache["data"] except Exception as e: print(f"[hot_config] Error loading: {e}") return {} def get_hot_datasets(default_sources: str) -> str: """Get datasets from hot_config if present, else use default""" cfg = get_hot_config() if "datasets" in cfg and cfg["datasets"]: hot_ds = cfg["datasets"] if isinstance(hot_ds, list): hot_ds = ",".join(hot_ds) print(f"[hot_config] Using hot datasets: {hot_ds}") return hot_ds return default_sources # DISABLED: # Auto-rotating log to prevent context-window suicide # DISABLED: try: # DISABLED: from rotating_log import install_rotating_log # DISABLED: install_rotating_log() # DISABLED: except ImportError: # pass # Running without rotation # ───────────────────────── ANSI Colors ───────────────────────── class Colors: RESET = "\033[0m" BOLD = "\033[1m" PROMPT = "\033[36m" GEN = "\033[0m" INFO = "\033[90m" WARN = "\033[93m" # ───────────────────────── Globals ───────────────────────── hf_log.set_verbosity_error() DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.backends.cuda.matmul.allow_tf32 = True try: torch.set_float32_matmul_precision("high") except Exception: pass TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V3.2") tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True) if tok.pad_token is None: tok.add_special_tokens({"pad_token": "<|pad|>"}) # ─── Fix tokenizer Ġ/▁ mismatch ─── # The DeepSeek-V3.2 vocab uses Ġ (U+0120) for space-prefixed tokens, # but some transformers versions set the Metaspace pre-tokenizer to use # ▁ (U+2581) instead, causing encode/decode to lose all spaces. def _fix_tokenizer_space_mismatch(tokenizer): try: import json as _json from tokenizers import Tokenizer as _Tokenizer bt = tokenizer.backend_tokenizer tj = _json.loads(bt.to_str()) pre = tj.get("pre_tokenizer", {}) needs_fix = (pre.get("type") == "Metaspace" and pre.get("replacement") == "\u2581") if not needs_fix: return # Check if vocab actually uses Ġ (U+0120) for spaces vocab = tj.get("model", {}).get("vocab", {}) has_gpt2_space = any(k.startswith("\u0120") for k in list(vocab.keys())[:500]) if not has_gpt2_space: return # Patch pre_tokenizer: ▁ -> Ġ tj["pre_tokenizer"]["replacement"] = "\u0120" # Patch decoder: ▁ -> Ġ in Replace step for step in tj.get("decoder", {}).get("decoders", []): if step.get("type") == "Replace": pat = step.get("pattern", {}) if pat.get("String") == "\u2581": pat["String"] = "\u0120" # Rebuild backend tokenizer fixed = _Tokenizer.from_str(_json.dumps(tj)) tokenizer.backend_tokenizer = fixed # Verify fix test_ids = tokenizer.encode("hello world") test_dec = tokenizer.decode(test_ids, skip_special_tokens=True) if "hello world" in test_dec: print("[tokenizer] Fixed Ġ/▁ space mismatch") else: print(f"[tokenizer] WARNING: fix applied but decode test failed: {repr(test_dec)}") except Exception as e: print(f"[tokenizer] Could not fix space mismatch: {e}") _fix_tokenizer_space_mismatch(tok) # ─── Tokenizer startup health check ─── # Abort early if tokenizer can't roundtrip spaces — prevents silent data corruption def _tokenizer_health_check(tokenizer): import transformers as _tf ver = _tf.__version__ print(f"[tokenizer] transformers={ver}, tokenizers={__import__('tokenizers').__version__}") # Warn on known-bad versions try: from packaging.version import Version if Version(ver) >= Version('5.0.0'): print(f'[tokenizer] WARNING: transformers {ver} may have Metaspace bug — verify carefully') except ImportError: pass # Roundtrip tests — must preserve spaces tests = [ 'Water boils at one hundred degrees', 'The quick brown fox jumps over the lazy dog', 'Hello world! This is a test sentence with spaces.', ] for text in tests: ids = tokenizer.encode(text) decoded = tokenizer.decode(ids, skip_special_tokens=True) if ' ' not in decoded: print(f'[tokenizer] FATAL: Roundtrip lost all spaces!') print(f' Input: {repr(text)}') print(f' Encoded: {ids[:20]}...') print(f' Decoded: {repr(decoded)}') print(f'[tokenizer] ABORTING — fix tokenizer before training!') sys.exit(1) # Check decoded is reasonably close to input if text.lower().split()[:3] != decoded.lower().split()[:3]: print(f'[tokenizer] WARNING: Roundtrip diverged:') print(f' Input: {repr(text[:60])}') print(f' Decoded: {repr(decoded[:60])}') print(f'[tokenizer] Health check PASSED — spaces preserved in roundtrip') _tokenizer_health_check(tok) VOCAB, EOS = ( max(tok.get_vocab().values()) + 1, tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id ) # ───────────────────────── PRESETS ───────────────────────── PRESETS: Dict[str, Dict[str, int]] = { "femto_1x": dict(d=16, layers=1, heads=1, rank=16), "femto_12x": dict(d=16, layers=1, heads=1, rank=192), "femto_24x": dict(d=16, layers=1, heads=1, rank=384), "pico_1x": dict(d=32, layers=1, heads=2, rank=16), "pico_3x": dict(d=32, layers=1, heads=2, rank=48), "pico_6x": dict(d=32, layers=1, heads=2, rank=96), "pico_12x": dict(d=32, layers=1, heads=2, rank=192), "pico_24x": dict(d=32, layers=1, heads=2, rank=384), "pico_48x": dict(d=32, layers=1, heads=2, rank=768), "nano_1x": dict(d=64, layers=2, heads=4, rank=16), "nano_3x": dict(d=64, layers=2, heads=4, rank=48), "nano_6x": dict(d=64, layers=2, heads=4, rank=96), "nano_12x": dict(d=64, layers=2, heads=4, rank=192), "nano_24x": dict(d=64, layers=2, heads=4, rank=384), "nano_48x": dict(d=64, layers=2, heads=4, rank=768), "nano_96x": dict(d=64, layers=2, heads=4, rank=1536), "micro_3x": dict(d=128, layers=4, heads=8, rank=48), "micro_6x": dict(d=128, layers=4, heads=8, rank=96), "micro_12x": dict(d=128, layers=4, heads=8, rank=192), "micro_24x": dict(d=128, layers=4, heads=8, rank=384), "small": dict(d=512, layers=8, heads=16, rank=64), "smallx2": dict(d=512, layers=16, heads=16, rank=64), "base": dict(d=768, layers=12, heads=24, rank=96), "base18": dict(d=768, layers=18, heads=24, rank=96), "large": dict(d=1024, layers=24, heads=16, rank=128), } DEFAULT_BLOCK = 1122 DEFAULT_BATCH = 4 SAT_BLOCK = 2 LR_CORE, LR_HEAD = 5e-5, 2e-4 EMIT_LAMBDA = 0.1 DEFAULT_SAVE_SEC = 24 * 3600 DEFAULT_DELTA_STEPS = 500 # lightweight weight-only save every N steps DEFAULT_MAX_DELTAS = 5 # keep last N deltas (older pruned after full save) CKDIR = pathlib.Path("ckpts_expansion") DEFAULT_PRETRAIN_SOURCES = "OpenTransformer/goddess-crawl,OpenTransformer/agillm-crawl-data,OpenTransformer/web-crawl-2026,OpenTransformer/web-crawl-clean-v2,OpenTransformer/scraped-web-data,OpenTransformer/turbo-crawl,OpenTransformer/sft-data-clean,OpenTransformer/web-crawl-v1,HuggingFaceFW/fineweb,wikimedia/wikipedia:20231101.en,allenai/c4:en" DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k" DEFAULT_AFTER_SFT_BLOCK = 1122 # ───────────────────────── UK Time Helper ───────────────────────── def get_uk_time() -> str: utc_now = datetime.now(timezone.utc) year = utc_now.year march_last = datetime(year, 3, 31, 1, 0, tzinfo=timezone.utc) while march_last.weekday() != 6: march_last = march_last.replace(day=march_last.day - 1) oct_last = datetime(year, 10, 31, 1, 0, tzinfo=timezone.utc) while oct_last.weekday() != 6: oct_last = oct_last.replace(day=oct_last.day - 1) if march_last <= utc_now < oct_last: uk_offset = 1 tz_name = "BST" else: uk_offset = 0 tz_name = "GMT" from datetime import timedelta uk_time = utc_now + timedelta(hours=uk_offset) return uk_time.strftime(f'%Y-%m-%d %H:%M:%S {tz_name}') # ───────────────────────── Utilities ───────────────────────── def rng_state(): if DEV.type == "cuda": try: return torch.cuda.get_rng_state(DEV) except TypeError: return torch.cuda.get_rng_state() return torch.get_rng_state() def _is_probably_ckpt(path: pathlib.Path) -> bool: try: return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20) except Exception: return False def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None: try: if path.is_dir(): cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)], key=lambda p: p.stat().st_mtime, reverse=True) return cands[0] if cands else None if path.suffix == ".tmp": solid = path.with_suffix("") return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent) return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent) except Exception: return None def _try_load(path: pathlib.Path, map_location="cpu"): try: return torch.load(path, map_location="cpu") except Exception as e: print(f"[ckpt-skip] {path} not usable: {e}") return None def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: int): if max_ckpts is None or max_ckpts <= 0: return try: pattern = f"{phase_name}_step*.pt" ckpts = sorted( [p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)], key=lambda p: p.stat().st_mtime ) excess = len(ckpts) - max_ckpts if excess > 0: for p in ckpts[:excess]: try: p.unlink() print(f" [prune] deleted old {p.name}") except Exception: pass except Exception as e: print(f"[ckpt-prune] error: {e}") def print_expansion_info(cfg: dict, tie_weights: bool = False): d_k = cfg["d"] // cfg["heads"] rank = cfg["rank"] ratio = rank / d_k regime = "COMPRESSION" if ratio < 1 else ("IDENTITY" if ratio == 1 else "EXPANSION") tie_str = "YES" if tie_weights else "NO" print(f"┌─────────────────────────────────────────┐") print(f"│ TUNEABLE ATTENTION CONFIG │") print(f"├─────────────────────────────────────────┤") print(f"│ d_model: {cfg['d']:4d} heads: {cfg['heads']:2d} d_k: {d_k:3d} │") print(f"│ layers: {cfg['layers']:4d} tie_weights: {tie_str:3s} │") print(f"│ rank: {rank:4d} ratio: {ratio:.1f}x [{regime:11s}] │") print(f"└─────────────────────────────────────────┘") # ───────────────────────── AMP helper ───────────────────────── try: from torch.amp import autocast as _ac, GradScaler except ImportError: from torch.cuda.amp import autocast as _ac, GradScaler def _auto_amp_dtype(): if DEV.type == "cuda": try: if torch.cuda.is_bf16_supported(): return torch.bfloat16 return torch.float16 except Exception: return torch.float16 return torch.float32 def amp(enabled: bool): return nullcontext() if not (enabled and DEV.type == "cuda") else _ac(device_type="cuda", dtype=_auto_amp_dtype()) # ───────────────────────── Chat & Data Stream ───────────────────────── def _coerce_role(r: str) -> str: r = (r or "").lower() if r in {"user", "human", "customer"}: return "user" if r in {"assistant", "gpt", "bot"}: return "assistant" if r in {"system", "context"}: return "system" return r or "user" def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]: msgs = ex.get(messages_key) if msgs is None: for alt in ("conversations", "dialog", "turns"): if isinstance(ex.get(alt), list): msgs = ex[alt]; break if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict): try: norm = [] for m in msgs: role = _coerce_role(m.get("role", "")); content = m.get("content", m.get("text", "")) if not isinstance(content, str): continue norm.append({"role": role, "content": content}) if not norm: return None return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt) except Exception: return None for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")): if isinstance(ex.get(a), str) and isinstance(ex.get(b), str): return f"User: {ex[a]}\nAssistant: {ex[b]}" return None def _open_stream_one(ds_name: str, seed: int, streaming: bool = True): dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True) if ":" in ds_name: base, config = ds_name.split(":", 1) else: base, config = ds_name, None if not streaming: print(f"[download] Downloading {ds_name} (non-streaming)...") if base == "json": data_files = {"train": config} ds = load_dataset("json", data_files=data_files, split="train", streaming=streaming, download_config=dc) else: ds = load_dataset(base, config, split="train", streaming=streaming, download_config=dc) if config else \ load_dataset(base, split="train", streaming=streaming, download_config=dc) if streaming: return iter(ds.shuffle(buffer_size=1000, seed=seed)) else: print(f"[download] Got {len(ds):,} examples. Shuffling...") ds = ds.shuffle(seed=seed) return iter(ds) def token_stream(ds_names: str, target: int, seed: int = 42, chat: bool = False, chat_messages_key: str = "messages", sft_add_generation_prompt: bool = False, dataset_field_text: str = "text", streaming: bool = True): ds_names = get_hot_datasets(ds_names) # HOT LOAD sources = [s.strip() for s in ds_names.split(",") if s.strip()] if not sources: return src_idx = 0; emitted = 0; it = None; attempts = 0; backoff_base = 2.0 while emitted < target: try: if it is None: it = _open_stream_one(sources[src_idx], seed, streaming=streaming) ex = next(it) text = None if isinstance(ex, dict): if chat: text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt) if text is None: if dataset_field_text and isinstance(ex.get(dataset_field_text), str): text = ex[dataset_field_text] elif isinstance(ex.get("text"), str): text = ex["text"] if not isinstance(text, str): attempts = 0; continue enc = tok.encode(text) if EOS is not None and (len(enc) == 0 or enc[-1] != EOS): enc = enc + [EOS] for t in enc: yield t emitted += 1 if emitted >= target: return attempts = 0 except StopIteration: it = None; src_idx = (src_idx + 1) % len(sources) except Exception as e: attempts += 1 sleep_s = min(60.0, backoff_base ** min(attempts, 6)) print(f"[stream-retry] {sources[src_idx]} error: {type(e).__name__}, sleeping {sleep_s:.1f}s") time.sleep(sleep_s); it = None if attempts % 5 == 0 and len(sources) > 1: src_idx = (src_idx + 1) % len(sources) # ───────────────────────── ALiBi ───────────────────────── def _alibi_slopes(n_heads: int): def pow2slopes(n): start = 2 ** (-2 ** -(math.log2(n) - 3)) ratio = start return [start * (ratio ** i) for i in range(n)] if math.log2(n_heads).is_integer(): vals = pow2slopes(n_heads) else: closest = 2 ** math.floor(math.log2(n_heads)) vals = pow2slopes(closest) extra = pow2slopes(2 * closest) vals += extra[0::2][: n_heads - closest] return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1) def alibi_bias(n_heads: int, n_tokens: int): i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1) j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens) dist = (j - i).clamp_min(0) return -_alibi_slopes(n_heads) * dist # ───────────────────────── Model components ───────────────────────── class TuneableAttentionMHA(nn.Module): def __init__(self, d: int, h: int, r: int, use_relpos: bool = True): super().__init__() assert d % h == 0 self.h, self.dk, self.r = h, d // h, r self.use_relpos = use_relpos self.q = nn.Linear(d, d, bias=False) self.k = nn.Linear(d, d, bias=False) self.v = nn.Linear(d, d, bias=False) self.U = nn.Parameter(torch.randn(self.dk, r)) nn.init.orthogonal_(self.U) self.proj = nn.Linear(h * self.dk, d, bias=False) self.drop = nn.Dropout(0.1) def _proj_qk(self, x): B, N, _ = x.shape return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U) def _reshape_v(self, x): B, N, _ = x.shape return x.view(B, N, self.h, self.dk).transpose(1, 2) def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False): q = self._proj_qk(self.q(x)) k_new = self._proj_qk(self.k(x)) v_new = self._reshape_v(self.v(x)) if kv_cache is None: k, v = k_new, v_new else: k_cached, v_cached = kv_cache if use_cache: k = torch.cat([k_cached, k_new], dim=2) v = torch.cat([v_cached, v_new], dim=2) else: k, v = k_new, v_new att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) if self.use_relpos and rel_bias_tokens is not None: att = att + alibi_bias(self.h, rel_bias_tokens)[:, :, -q.size(2):, :] if mask is not None: att = att + mask z = (att.softmax(-1) @ v).transpose(1, 2).reshape(x.size(0), x.size(1), -1) out = self.drop(self.proj(z)) return (out, (k, v)) if use_cache else out class Block(nn.Module): def __init__(self, d: int, h: int, r: int): super().__init__() self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d) self.mha = TuneableAttentionMHA(d, h, r) self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d)) def forward(self, x, mask, kv=None, use_cache=False, total_seq_len=None): if use_cache: y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=total_seq_len, kv_cache=kv, use_cache=True) x = x + y + self.ff(self.ln2(x + y)) return x, new_kv else: n = x.size(1) x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n) return x + self.ff(self.ln2(x)) class Encoder(nn.Module): def __init__(self, cfg, tie_weights: bool = False): super().__init__() d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"] self.emb = nn.Embedding(VOCAB, d) self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)]) self.ln = nn.LayerNorm(d) self.tie_weights = tie_weights def forward(self, ids, mask, kv_caches=None, use_cache=False, total_seq_len=None): x = self.emb(ids) if not use_cache: for blk in self.blocks: x = blk(x, mask) return self.ln(x) new_kvs = [] for i, blk in enumerate(self.blocks): kv = kv_caches[i] if kv_caches else None x, kv_out = blk(x, mask, kv, use_cache=True, total_seq_len=total_seq_len) new_kvs.append(kv_out) return self.ln(x), new_kvs class ARHead(nn.Module): def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None): super().__init__() self.tie_weights = tie_weights if tie_weights and embedding_weight is not None: self.proj = nn.Linear(d, VOCAB, bias=False) self.proj.weight = embedding_weight else: self.proj = nn.Linear(d, VOCAB) def forward(self, h): return self.proj(h) class SATHead(nn.Module): def __init__(self, d, mode="var"): super().__init__() self.proj = nn.Linear(d, VOCAB) self.gate = nn.Linear(d, 2) if mode == "var" else None def forward(self, h_last): return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None) # ───────────────────────── Masks ───────────────────────── def causal_mask(n): return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1) def sat_mask(n, block=SAT_BLOCK): idx = torch.arange(n, device=DEV) grp = idx.unsqueeze(0) // block allow = (grp.T == grp) | (grp.T > grp) return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0) def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK): total_len = cached_len + new_len mask = torch.zeros((1, 1, new_len, total_len), device=DEV) return mask # ───────────────────────── Checkpoint helpers ───────────────────────── # ───────────────────────── Delta Checkpoints (weight-only, async) ───────────────────────── _delta_lock = threading.Lock() _delta_thread: Optional[threading.Thread] = None def _sha256_file(path: pathlib.Path) -> str: """Compute SHA256 of a file for integrity verification.""" h = hashlib.sha256() with open(path, "rb") as f: for chunk in iter(lambda: f.read(1 << 20), b""): h.update(chunk) return h.hexdigest() def _do_delta_save(tensors: dict, path: pathlib.Path, meta: dict): """Background worker: write weight-only checkpoint + checksum.""" try: path.parent.mkdir(exist_ok=True, parents=True) tmp = path.with_suffix(path.suffix + ".dtmp") torch.save({"weights": tensors, **meta}, tmp, _use_new_zipfile_serialization=False) digest = _sha256_file(tmp) tmp.replace(path) # Write sidecar checksum path.with_suffix(".sha256").write_text(f"{digest} {path.name}\n") print(f" [delta] saved {path.name} ({digest[:12]}...)") except Exception as e: print(f" [delta] FAILED {path.name}: {e}") def save_delta(core, ar_h, sat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str): """Save weight-only delta in background thread. Non-blocking.""" global _delta_thread # Wait for any previous delta write to finish if _delta_thread is not None and _delta_thread.is_alive(): _delta_thread.join(timeout=60) # Snapshot weights to CPU (detach from GPU graph) with _delta_lock: tensors = { "core": {k: v.detach().cpu() for k, v in core.state_dict().items()}, "ar": {k: v.detach().cpu() for k, v in ar_h.state_dict().items()}, "sat": {k: v.detach().cpu() for k, v in sat_h.state_dict().items()}, } meta = {"step": step, "seen_tok": seen_tok, "wall_time": time.time(), "delta": True} path = save_dir / f"{phase_name}_delta_step{step:08d}.pt" _delta_thread = threading.Thread(target=_do_delta_save, args=(tensors, path, meta), daemon=True) _delta_thread.start() def _prune_deltas(save_dir: pathlib.Path, phase_name: str, max_deltas: int): """Keep only the most recent max_deltas delta files.""" if max_deltas is None or max_deltas <= 0: return try: pattern = f"{phase_name}_delta_step*.pt" deltas = sorted( [p for p in save_dir.glob(pattern) if p.stat().st_size > 0], key=lambda p: p.stat().st_mtime ) excess = len(deltas) - max_deltas if excess > 0: for p in deltas[:excess]: try: p.unlink() sha = p.with_suffix(".sha256") if sha.exists(): sha.unlink() print(f" [delta-prune] deleted {p.name}") except Exception: pass except Exception as e: print(f" [delta-prune] error: {e}") def load_delta(path: pathlib.Path, core, ar_h, sat_h): """Load weight-only delta. Returns (step, seen_tok) or raises.""" # Verify checksum if sidecar exists sha_path = path.with_suffix(".sha256") if sha_path.exists(): expected = sha_path.read_text().split()[0] actual = _sha256_file(path) if expected != actual: raise ValueError(f"Checksum mismatch for {path.name}: expected {expected[:12]}... got {actual[:12]}...") print(f" [delta] checksum OK for {path.name}") ck = torch.load(path, map_location="cpu", weights_only=False) if not ck.get("delta"): raise ValueError(f"{path.name} is not a delta checkpoint") core.load_state_dict(ck["weights"]["core"]) ar_h.load_state_dict(ck["weights"]["ar"]) sat_h.load_state_dict(ck["weights"]["sat"]) return ck.get("step", 0), ck.get("seen_tok", 0) def _flush_delta(): """Wait for any in-flight delta save to complete.""" global _delta_thread if _delta_thread is not None and _delta_thread.is_alive(): print(" [delta] flushing in-flight write...") _delta_thread.join(timeout=120) def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, opt, scaler, meta): path.parent.mkdir(exist_ok=True, parents=True) tmp = path.with_suffix(path.suffix + ".tmp") state = { "core": core.state_dict(), "ar": ar_h.state_dict(), "sat": sat_h.state_dict(), "opt": opt.state_dict(), "scaler": scaler.state_dict(), "cfg": meta.get("cfg"), "tokenizer_id": TOKENIZER_ID, "tokenizer_json": tok.backend_tokenizer.to_str(), "transformers_version": __import__("transformers").__version__, "tokenizers_version": __import__("tokenizers").__version__, "tie_weights": meta.get("tie_weights", False), **{k: v for k, v in meta.items() if k not in ("cfg", "tie_weights")} } torch.save(state, tmp, _use_new_zipfile_serialization=False) tmp.replace(path) (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]})) print(f"\n✓ saved checkpoint {path.name}") def load_ckpt(path, core, ar_h, sat_h, opt, scaler): p = _resolve_ckpt(path) or path ck = _try_load(p, map_location="cpu") if ck is None: raise FileNotFoundError(f"No valid checkpoint at {p}") core.load_state_dict(ck["core"]) ar_h.load_state_dict(ck["ar"]) sat_h.load_state_dict(ck["sat"]) opt.load_state_dict(ck["opt"]) scaler.load_state_dict(ck["scaler"]) # Restore tokenizer from checkpoint if available if "tokenizer_json" in ck: try: from tokenizers import Tokenizer as _Tokenizer tok.backend_tokenizer = _Tokenizer.from_str(ck["tokenizer_json"]) print("[tokenizer] Restored from checkpoint") except Exception as e: print(f"[tokenizer] WARNING: could not restore from checkpoint: {e}") # Warn if transformers version changed since checkpoint was saved if "transformers_version" in ck: import transformers as _tf if ck["transformers_version"] != _tf.__version__: print(f"[tokenizer] WARNING: checkpoint saved with transformers={ck['transformers_version']}, now running {_tf.__version__}") return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time()) def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None): p = _resolve_ckpt(path) or path if not p.exists(): return 0 ck = _try_load(p, map_location="cpu") if ck is None: return 0 sd = ck.get(key, ck) if key else ck if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"] tgt_sd = tgt.state_dict() filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape} if filt: tgt.load_state_dict(filt, strict=False) return len(filt) def infer_cfg_from_ckpt(path: pathlib.Path): p = _resolve_ckpt(path) or path if not p.exists(): return None sd = _try_load(p, map_location="cpu") if sd is None: return None if "cfg" in sd: return dict(sd["cfg"]) return None # ───────────────────────── Training Logic ───────────────────────── def _parse_grow_plan(s: str) -> List[int]: return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128])) def _count_enabled_params(*modules) -> int: seen_data_ptrs = set() total = 0 for m in modules: if m is None: continue for p in m.parameters(): if p.data_ptr() not in seen_data_ptrs: seen_data_ptrs.add(p.data_ptr()) total += p.numel() return total def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool): for p in core.parameters(): p.requires_grad = not freeze_core if freeze_core: if unfreeze_ln: for blk in core.blocks: for p in blk.ln1.parameters(): p.requires_grad = True for p in blk.ln2.parameters(): p.requires_grad = True for p in core.ln.parameters(): p.requires_grad = True if train_emb: for p in core.emb.parameters(): p.requires_grad = True def _train_phase( args, phase_name: str, core, ar_h, sat_h, opt, scaler, start_step, seen_tok, resume_wall_time, cfg, source, steps, block_size, batch_size, chat_cfg: dict, max_ckpts: int, target_tokens_override: Optional[int] = None, tie_weights: bool = False, streaming: bool = True ): BLOCK = block_size BATCH = batch_size if target_tokens_override is not None: target_tokens = target_tokens_override else: ratio = 51.2 if args.chilla_max_double else 25 param_count = _count_enabled_params(core, ar_h, sat_h) target_tokens = int(ratio * param_count) if steps: phase_target_tokens = steps * BLOCK * BATCH total_tokens_needed = seen_tok + phase_target_tokens else: total_tokens_needed = target_tokens if total_tokens_needed <= seen_tok: print(f"[{phase_name}] target {total_tokens_needed} already reached.") return start_step, seen_tok, resume_wall_time stream = token_stream( source, total_tokens_needed, seed=42, chat=chat_cfg.get("chat", False), chat_messages_key=chat_cfg.get("key", "messages"), sft_add_generation_prompt=chat_cfg.get("gen_prompt", False), dataset_field_text=chat_cfg.get("text_field", "text"), streaming=streaming ) ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1) ce_gate = nn.CrossEntropyLoss() pbar = SafeProgress(total=total_tokens_needed, initial=seen_tok, unit="tok") grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else [] buf: list[int] = [] batch_accum: list[list[int]] = [] step = start_step steps_since_last_grow = 0 oom_retries = 0 MAX_OOM_RETRIES = 2 now_wall = time.time() last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall)) last_delta_step = start_step print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}") print(f"[{phase_name}] AR_ONLY={args.ar_only}, TIE_WEIGHTS={tie_weights}, STREAMING={streaming}") while seen_tok < total_tokens_needed: try: while len(buf) < BLOCK: buf.append(next(stream)) except StopIteration: break seq = buf[:BLOCK] buf = buf[BLOCK:] batch_accum.append(seq) if len(batch_accum) < BATCH: continue ids = torch.tensor(batch_accum, device=DEV) batch_accum = [] tgt_ar = ids.clone() try: with amp(args.amp): h_ar = core(ids, causal_mask(ids.size(1))) logits_ar = ar_h(h_ar)[:, :-1] loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1)) if args.ar_only: loss = loss_ar else: h_sat = core(ids, sat_mask(ids.size(1))) logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:]) tgt_sat = ids[:, 1:SAT_BLOCK+1] loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1)) if gate is not None: loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long)) loss = loss_ar + loss_sat scaler.scale(loss).backward() scaler.unscale_(opt) nn.utils.clip_grad_norm_(core.parameters(), 1.0) scaler.step(opt) scaler.update() opt.zero_grad(set_to_none=True) except RuntimeError as e: msg = str(e).lower() if "out of memory" in msg or "cuda error" in msg: batch_accum = [] opt.zero_grad(set_to_none=True) if DEV.type == "cuda": torch.cuda.empty_cache() torch.cuda.synchronize() oom_retries += 1 if oom_retries <= MAX_OOM_RETRIES: print(f"\n[{phase_name} OOM] Retry {oom_retries}/{MAX_OOM_RETRIES} at Batch={BATCH}, clearing VRAM...") time.sleep(2) continue oom_retries = 0 if BATCH > 1: print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1} (after {MAX_OOM_RETRIES} retries)") BATCH -= 1 time.sleep(2) else: new_block = max(128, BLOCK // 2) print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}") BLOCK = new_block time.sleep(2) steps_since_last_grow = 0 continue raise step += 1 # Periodic tokenizer spot-check: verify training data has spaces if step % 1000 == 0: try: sample_text = tok.decode(ids[0][:50].tolist(), skip_special_tokens=True) if len(sample_text) > 20 and " " not in sample_text: print(f"\n[tokenizer] ALERT step {step}: decoded batch has NO SPACES!") print(f" Sample: {repr(sample_text[:80])}") print(" Check transformers version!") except Exception: pass oom_retries = 0 toks_processed = BLOCK * BATCH seen_tok += toks_processed pbar.update(toks_processed) pbar.set_postfix(loss=f"{loss.item():.3f}", B=BATCH, L=BLOCK) if args.save_every_sec > 0: now_mono = time.monotonic() if now_mono - last_save_mono >= args.save_every_sec: ck_name = f"{phase_name}_step{step:08d}.pt" _flush_delta() # wait for any in-flight delta before full save _prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts) save_ckpt(pathlib.Path(args.save_dir) / ck_name, core, ar_h, sat_h, opt, scaler, meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights}) last_save_mono = now_mono # Prune old deltas after a full save (they're superseded) _prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep) last_delta_step = step # reset delta counter after full save # ── Delta checkpoint (step-based, weight-only, async) ── if args.delta_every_steps > 0 and (step - last_delta_step) >= args.delta_every_steps: _prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep) save_delta(core, ar_h, sat_h, step, seen_tok, pathlib.Path(args.save_dir), phase_name) last_delta_step = step if args.auto_grow: steps_since_last_grow += 1 if steps_since_last_grow >= args.grow_every_steps: steps_since_last_grow = 0 try: idx = grow_plan.index(BLOCK) if idx + 1 < len(grow_plan): BLOCK = grow_plan[idx + 1] print(f"[{phase_name} Grow] Block -> {BLOCK}") if DEV.type == "cuda": torch.cuda.empty_cache() except ValueError: grow_plan = sorted(set(grow_plan + [BLOCK])) pbar.close() _flush_delta() # ensure any in-flight delta completes before final save save_ckpt(pathlib.Path(args.save_dir) / f"{phase_name}_final.pt", core, ar_h, sat_h, opt, scaler, meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights}) return step, seen_tok, time.time() # ───────────────────────── Main Orchestrator ───────────────────────── def train(args): cfg = PRESETS[args.preset].copy() tie_weights = args.tie_weights print_expansion_info(cfg, tie_weights) if not args.fresh: src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt" prev_cfg = infer_cfg_from_ckpt(src_probe) else: prev_cfg = None if prev_cfg: cfg.update({k: v for k, v in prev_cfg.items() if k in cfg}) if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2) if args.rank: cfg["rank"] = args.rank if args.x2 and not prev_cfg: cfg["layers"] *= 2 print(f"Config: {cfg}") core = Encoder(cfg, tie_weights=tie_weights).to(DEV) ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) sat_h = SATHead(cfg["d"], mode="var").to(DEV) total_params = _count_enabled_params(core, ar_h, sat_h) print(f"Total parameters: {total_params:,}") if tie_weights: print(f"{Colors.WARN}[weight-tying] Embedding and LM head share weights{Colors.RESET}") if not args.fresh: src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt" src = _resolve_ckpt(src) if src: loaded = _safe_load_any(src, core, key="core") _safe_load_any(src, ar_h, key="ar") _safe_load_any(src, sat_h, key="sat") if loaded: print(f"Warm-start loaded from {src}") _phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb) opt = torch.optim.AdamW([ {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.lr_core}, {"params": ar_h.parameters(), "lr": args.lr_head}, {"params": sat_h.parameters(), "lr": args.lr_head}, ]) scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda")) start_step, seen_tok, last_wall = 0, 0, None if args.resume_delta and not args.fresh: delta_step, delta_tok = load_delta(pathlib.Path(args.resume_delta), core, ar_h, sat_h) start_step, seen_tok, last_wall = delta_step, delta_tok, None print(f"Resumed from DELTA at step {start_step} (optimizer state reset — momentum rebuilds in ~100 steps)") elif args.resume and not args.fresh: start_step, seen_tok, last_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler) print(f"Resumed from step {start_step}") # torch.compile AFTER loading checkpoint (key names differ) if args.compile: print("[torch.compile] Compiling model...") core = torch.compile(core, mode="reduce-overhead") ar_h = torch.compile(ar_h, mode="reduce-overhead") sat_h = torch.compile(sat_h, mode="reduce-overhead") print("[torch.compile] Done.") step, seen_tok, last_wall = _train_phase( args, "pretrain", core, ar_h, sat_h, opt, scaler, start_step, seen_tok, last_wall, cfg, args.source, args.steps, args.block or DEFAULT_BLOCK, args.batch_size or DEFAULT_BATCH, chat_cfg={"chat": args.chat, "key": args.chat_messages_key, "gen_prompt": args.sft_add_generation_prompt, "text_field": args.dataset_field_text}, max_ckpts=args.max_ckpts, target_tokens_override=args.target_tokens, tie_weights=tie_weights ) if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0): args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES args.after_sft_chat = True if args.after_sft_add_generation_prompt is None: args.after_sft_add_generation_prompt = True if not args.after_sft_block: args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0: print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...") _phase_freeze(core, freeze_core=args.after_sft_freeze_core, unfreeze_ln=args.after_sft_unfreeze_ln, train_emb=args.after_sft_train_emb) opt = torch.optim.AdamW([ {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.after_sft_lr_core or args.lr_core}, {"params": ar_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head}, {"params": sat_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head}, ]) step, seen_tok, last_wall = _train_phase( args, "sft", core, ar_h, sat_h, opt, scaler, step, seen_tok, last_wall, cfg, args.after_sft_source, args.after_sft_steps, args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK, args.batch_size or DEFAULT_BATCH, chat_cfg={ "chat": args.after_sft_chat, "key": args.after_sft_chat_messages_key, "gen_prompt": args.after_sft_add_generation_prompt if args.after_sft_add_generation_prompt is not None else args.sft_add_generation_prompt, "text_field": args.after_sft_dataset_field_text }, max_ckpts=args.max_ckpts, target_tokens_override=None, tie_weights=tie_weights, streaming=False ) save_ckpt(pathlib.Path(args.save_dir) / "final.pt", core, ar_h, sat_h, opt, scaler, meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights}) print("🎉 All Training Complete") # ───────────────────────── Sampling ───────────────────────── def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p): if ids.numel() == 0: return logits hist = ids[0, -n:].long() if n > 0 else ids[0].long() uniq, counts = torch.unique(hist, return_counts=True) if pres_p or freq_p: logits[..., uniq] -= (pres_p + freq_p * counts.float()) if rep_p != 1.0: sel = logits[..., uniq] logits[..., uniq] = torch.where(sel > 0, sel / rep_p, sel * rep_p) return logits def _sample(logits, T, top_k, top_p, min_p, greedy): if greedy: return logits.argmax(-1, keepdim=True) probs = (logits / max(T, 1e-8)).softmax(-1) if top_k: v, i = torch.topk(probs, min(top_k, probs.size(-1))) probs = torch.zeros_like(probs).scatter_(-1, i, v) if top_p < 1.0: s_probs, s_idx = torch.sort(probs, descending=True, dim=-1) probs = torch.zeros_like(probs).scatter_(-1, s_idx, s_probs * (torch.cumsum(s_probs, -1) <= top_p).float()) if min_p > 0: probs[probs < min_p] = 0 if probs.sum() == 0: return logits.argmax(-1, keepdim=True) return probs.div_(probs.sum()).multinomial(1) @torch.no_grad() def infer(args): if args.mode == "ar": if args.temperature is None: args.temperature = 0.7 if args.top_k is None: args.top_k = 0 if args.repetition_penalty is None: args.repetition_penalty = 1.3 if args.presence_penalty is None: args.presence_penalty = 0.0 if args.frequency_penalty is None: args.frequency_penalty = 0.3 if args.penalty_last_n is None: args.penalty_last_n = 128 if args.var is None: args.var = False else: if args.temperature is None: args.temperature = 0.5 if args.top_k is None: args.top_k = 30 if args.repetition_penalty is None: args.repetition_penalty = 2.0 if args.presence_penalty is None: args.presence_penalty = 0.6 if args.frequency_penalty is None: args.frequency_penalty = 1.0 if args.penalty_last_n is None: args.penalty_last_n = 200 if args.var is None: args.var = True path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt) sd = torch.load(path, map_location="cpu") # Restore tokenizer from checkpoint if available if "tokenizer_json" in sd: try: from tokenizers import Tokenizer as _Tokenizer tok.backend_tokenizer = _Tokenizer.from_str(sd["tokenizer_json"]) print("[tokenizer] Restored from checkpoint") except Exception as e: print(f"[tokenizer] WARNING: could not restore from checkpoint: {e}") # Warn if transformers version changed since checkpoint was saved if "transformers_version" in sd: import transformers as _tf if sd["transformers_version"] != _tf.__version__: print(f"[tokenizer] WARNING: checkpoint saved with transformers={sd['transformers_version']}, now running {_tf.__version__}") # Handle delta checkpoints (weight-only, no cfg) if sd.get("delta"): print("[infer] Delta checkpoint detected, using large preset cfg") cfg = PRESETS["large"].copy() tie_weights = False # Remap: delta stores under sd["weights"]["core"/"ar"/"sat"] sd["core"] = sd["weights"]["core"] sd["ar"] = sd["weights"]["ar"] sd["sat"] = sd["weights"]["sat"] else: cfg = sd["cfg"] tie_weights = sd.get("tie_weights", False) uk_time = get_uk_time() ckpt_name = path.name print(f"┌─────────────────────────────────────────────────┐") print(f"│ INFERENCE @ {uk_time:<35s} │") print(f"├─────────────────────────────────────────────────┤") print(f"│ Checkpoint: {ckpt_name:<35s} │") print(f"└─────────────────────────────────────────────────┘") print_expansion_info(cfg, tie_weights) core = Encoder(cfg, tie_weights=tie_weights).to(DEV) ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) sat_h = SATHead(cfg["d"]).to(DEV) core.load_state_dict(sd["core"]) ar_h.load_state_dict(sd["ar"]) sat_h.load_state_dict(sd["sat"]) core.eval() ar_h.eval() sat_h.eval() total_params = _count_enabled_params(core, ar_h, sat_h) if total_params >= 1_000_000_000: param_str = f"{total_params / 1_000_000_000:.2f}B" elif total_params >= 1_000_000: param_str = f"{total_params / 1_000_000:.2f}M" elif total_params >= 1_000: param_str = f"{total_params / 1_000:.2f}K" else: param_str = f"{total_params}" print(f"Model size: {param_str} parameters ({total_params:,})") prompt_tokens = tok.encode(args.prompt) prompt_len = len(prompt_tokens) ids = torch.tensor([prompt_tokens], device=DEV) if ids.size(1) == 0: ids = torch.tensor([[EOS]], device=DEV) prompt_len = 1 mode_str = args.mode if args.mode == "sat": mode_str = f"sat-{'var' if args.var else 'fixed'}" print(f"{Colors.INFO}Generating ({mode_str})...{Colors.RESET}") start = time.time() if args.mode == "ar": h, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True, total_seq_len=ids.size(1)) for _ in range(args.max_new): logits = ar_h(h)[:, -1] logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty) nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy) ids = torch.cat([ids, nxt], 1) h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1)) else: cached_len = ids.size(1) h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len) added = 0 while added < args.max_new: logits_all, gate = sat_h(h[:, -SAT_BLOCK:]) stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1) stride = min(int(stride), logits_all.size(1)) new_tokens = [] for i in range(int(stride)): logits = logits_all[:, i] logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty) nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy) new_tokens.append(nxt) ids = torch.cat([ids, nxt], 1) added += 1 if added >= args.max_new: break if added >= args.max_new: break new_ids = torch.cat(new_tokens, dim=1) mask = sat_mask_cached(new_ids.size(1), cached_len) h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1)) cached_len = ids.size(1) elapsed = time.time() - start gen_tokens = len(ids[0]) - prompt_len tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0 all_tokens = ids[0].tolist() prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True) gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True) print(f"{Colors.PROMPT}{prompt_text}{Colors.RESET}{gen_text}") print(f"{Colors.INFO}[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]{Colors.RESET}") if getattr(args, "claude_friendly", False): print("[CLAUDE_FRIENDLY_START]") print(f"[mode={mode_str}]") print("[prompt_input]") print(prompt_text) print("[completion]") print(gen_text) print(f"[stats] {elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s") print("[CLAUDE_FRIENDLY_END]") # ───────────────────────── CLI ───────────────────────── def main(): ap = argparse.ArgumentParser(description="AGILLM Expansion Ratio Testing") sub = ap.add_subparsers(dest="cmd", required=True) tr = sub.add_parser("train") tr.add_argument("--preset", choices=PRESETS.keys(), default="nano_3x") tr.add_argument("--rank", type=int) tr.add_argument("--block", type=int, default=DEFAULT_BLOCK) tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH) tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES) tr.add_argument("--target_tokens", type=int) tr.add_argument("--steps", type=int) tr.add_argument("--amp", action="store_true") tr.add_argument("--compile", action="store_true", help="Use torch.compile for speedup") tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC) tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)") tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep") tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)") tr.add_argument("--save_dir", default=str(CKDIR)) tr.add_argument("--resume", type=str) tr.add_argument("--x2", action="store_true") tr.add_argument("--warmstart_from", type=str) tr.add_argument("--fresh", action="store_true") tr.add_argument("--max_ckpts", type=int, default=None) tr.add_argument("--chilla_max_double", action="store_true") tr.add_argument("--tie_weights", action="store_true") tr.add_argument("--ar_only", action="store_true") tr.add_argument("--freeze_core", action="store_true") tr.add_argument("--unfreeze_ln", action="store_true") tr.add_argument("--train_emb", action="store_true") tr.add_argument("--lr_core", type=float, default=LR_CORE) tr.add_argument("--lr_head", type=float, default=LR_HEAD) tr.add_argument("--chat", action="store_true") tr.add_argument("--chat_messages_key", default="messages") tr.add_argument("--dataset_field_text", default="text") tr.add_argument("--sft_add_generation_prompt", action="store_true") tr.add_argument("--auto_grow", action="store_true") tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122") tr.add_argument("--grow_every_steps", type=int, default=50000) tr.add_argument("--after_sft_source", default="") tr.add_argument("--after_sft_steps", type=int, default=0) tr.add_argument("--after_sft_chat", action="store_true") tr.add_argument("--after_sft_chat_messages_key", default="messages") tr.add_argument("--after_sft_dataset_field_text", default="text") tr.add_argument("--after_sft_add_generation_prompt", type=bool, default=None) tr.add_argument("--after_sft_block", type=int, default=0) tr.add_argument("--after_sft_freeze_core", action="store_true") tr.add_argument("--after_sft_unfreeze_ln", action="store_true") tr.add_argument("--after_sft_train_emb", action="store_true") tr.add_argument("--after_sft_lr_core", type=float, default=0.0) tr.add_argument("--after_sft_lr_head", type=float, default=0.0) inf = sub.add_parser("infer") inf.add_argument("--mode", choices=["ar", "sat"], required=True) inf.add_argument("--ckpt", required=True) inf.add_argument("--prompt", required=True) inf.add_argument("--max_new", type=int, default=120) inf.add_argument("--temperature", type=float, default=None) inf.add_argument("--greedy", action="store_true") inf.add_argument("--top_k", type=int, default=None) inf.add_argument("--top_p", type=float, default=0.9) inf.add_argument("--min_p", type=float, default=0.0) inf.add_argument("--repetition_penalty", type=float, default=None) inf.add_argument("--presence_penalty", type=float, default=None) inf.add_argument("--frequency_penalty", type=float, default=None) inf.add_argument("--penalty_last_n", type=int, default=None) inf.add_argument("--var", action="store_true", default=None) inf.add_argument("--no-var", dest="var", action="store_false") inf.add_argument("--claude-friendly", action="store_true", help="Also print an artifact-free prompt/completion block for downstream JSON consumers") st = sub.add_parser("status", help="Read-only training status") st.add_argument("--json", dest="json_output", action="store_true") st.add_argument("--log", type=str, default=str(STATUS_DEFAULT_LOG)) st.add_argument("--save_dir", type=str, default=str(STATUS_DEFAULT_SAVE_DIR)) args = ap.parse_args() if args.cmd == "train": train(args) elif args.cmd == "infer": infer(args) else: raise SystemExit(_emit_status(Path(args.log), Path(args.save_dir), args.json_output)) if __name__ == "__main__": main()