# Helion-OSC / inference/test_long_context.py
# Author: Trouter-Library -- commit 04dc2a9 ("Create inference/test_long_context.py", verified)
"""
Test script to verify 250K context length support
Tests RoPE scaling and long context handling
"""
import logging
import math
import time
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
# Script-wide logging: configure the root logger at INFO (level given by
# name, which Logger.setLevel resolves to logging.INFO) and create this
# module's named logger used by every test below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
class LongContextTester:
    """Check whether a Helion-OSC model configuration supports 250K-token contexts.

    Only the model *configuration* (and, in the tokenizer test, the tokenizer)
    is loaded -- no weights.  Each ``test_*`` method logs its findings through
    the module-level ``logger`` and returns a bool; :meth:`run_all_tests` runs
    all of them and logs a PASS/FAIL summary.
    """

    # Context length, in tokens, that the suite checks the model against.
    TARGET_CONTEXT = 250000

    def __init__(self, model_path: str = "./inference"):
        """Load the model configuration and log its context-length settings.

        Args:
            model_path: Path to model inference directory; passed unchanged to
                ``AutoConfig.from_pretrained``.
        """
        self.model_path = model_path
        logger.info("Loading model configuration...")
        self.config = AutoConfig.from_pretrained(model_path)

        # Verify context length
        max_pos = self.config.max_position_embeddings
        logger.info(f"Model max position embeddings: {max_pos:,}")
        if max_pos < self.TARGET_CONTEXT:
            logger.warning(f"Context length ({max_pos:,}) is less than 250K!")
        else:
            logger.info(f"✓ Context length supports 250K+ tokens ({max_pos:,})")

        # Report RoPE settings when present (either may be missing or None).
        rope_scaling = getattr(self.config, 'rope_scaling', None)
        rope_theta = getattr(self.config, 'rope_theta', None)
        if rope_scaling:
            logger.info(f"RoPE Scaling: {rope_scaling}")
        if rope_theta:
            logger.info(f"RoPE Theta: {rope_theta:,}")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _banner(title: str) -> None:
        """Log *title* framed by 80-character '=' rules, preceded by a blank line."""
        logger.info("\n" + "=" * 80)
        logger.info(title)
        logger.info("=" * 80)

    @staticmethod
    def _resolve_head_dim(config) -> int:
        """Return the per-head attention dimension for *config*.

        Uses an explicit, truthy ``config.head_dim`` when the config defines one
        (some configs set it independently of the hidden size); otherwise
        falls back to ``hidden_size // num_attention_heads``.
        """
        head_dim = getattr(config, 'head_dim', None)
        if head_dim:
            return int(head_dim)
        return config.hidden_size // config.num_attention_heads

    @staticmethod
    def _rope_frequency_range(rope_theta: float, head_dim: int) -> tuple:
        """Return ``(min_freq, max_freq)`` of the RoPE inverse frequencies.

        The frequencies are ``rope_theta ** (-2 * i / head_dim)`` for ``i`` in
        ``[0, head_dim // 2)``: the maximum (1.0) is at ``i = 0`` and the
        minimum at the last index ``i = head_dim // 2 - 1``.
        """
        last = max(head_dim // 2 - 1, 0)
        return rope_theta ** (-2 * last / head_dim), 1.0

    @staticmethod
    def _rope_wavelengths(rope_theta: float, head_dim: int) -> list:
        """Return ``(i, wavelength_in_tokens)`` pairs for sampled RoPE indices.

        Samples every ``head_dim // 8``-th frequency index (every index when
        ``head_dim < 8``, avoiding a zero ``range`` step) and always includes
        the last index ``head_dim // 2 - 1``, which has the longest wavelength
        ``2*pi*rope_theta**(2*i/head_dim)``.  Returns ``[]`` if ``head_dim < 2``.
        """
        half = head_dim // 2
        if half < 1:
            return []
        indices = list(range(0, half, max(1, head_dim // 8)))
        if indices[-1] != half - 1:
            indices.append(half - 1)
        return [(i, 2 * math.pi * rope_theta ** (2 * i / head_dim)) for i in indices]

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_tokenization_capacity(self, tokenizer_path: str = "DeepXR/Helion-OSC"):
        """Test that the tokenizer advertises and handles long sequences.

        Args:
            tokenizer_path: Anything accepted by ``AutoTokenizer.from_pretrained``.
                Note: defaults to a hub id, not ``self.model_path``.

        Returns:
            True if the tokenizer loaded and encoded the sample text; False if
            any exception was raised.  A ``model_max_length`` below 250K only
            logs a warning.
        """
        self._banner("TEST 1: Tokenizer Capacity")
        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            max_length = tokenizer.model_max_length
            logger.info(f"Tokenizer max length: {max_length:,}")
            if max_length >= self.TARGET_CONTEXT:
                logger.info("✓ Tokenizer supports 250K+ tokens")
            else:
                logger.warning(f"✗ Tokenizer max length only {max_length:,}")

            # Encode a long sample without truncation.  The real token count
            # depends on the tokenizer, so it is reported rather than asserted.
            test_tokens = 10000
            test_text = "Hello world! " * (test_tokens // 2)
            logger.info(f"Testing tokenization of ~{test_tokens:,} tokens...")
            encoded = tokenizer(test_text, return_tensors="pt", truncation=False)
            actual_tokens = encoded['input_ids'].shape[1]
            logger.info(f"Successfully tokenized {actual_tokens:,} tokens")
        except Exception as e:
            # Test boundary: record the failure instead of aborting the suite.
            logger.error(f"✗ Tokenization test failed: {e}")
            return False
        logger.info("✓ Tokenization test passed")
        return True

    def test_position_embeddings(self):
        """Test position embedding capacity.

        RoPE models (non-None ``rope_theta``) compute positions instead of
        storing them; learned embeddings get a table-size estimate.  In both
        cases the capacity is bounded by ``max_position_embeddings``.

        Returns:
            True when ``max_position_embeddings`` reaches ``TARGET_CONTEXT``.
        """
        self._banner("TEST 2: Position Embeddings")
        max_pos = self.config.max_position_embeddings
        hidden_size = self.config.hidden_size
        logger.info(f"Max positions: {max_pos:,}")
        logger.info(f"Hidden size: {hidden_size:,}")

        rope_theta = getattr(self.config, 'rope_theta', None)
        if rope_theta is not None:
            logger.info("Using RoPE (Rotary Position Embeddings)")
            logger.info("✓ RoPE scales efficiently to long contexts")
            # RoPE doesn't store position embeddings, it computes them
            logger.info(f"RoPE Theta: {rope_theta:,}")

            # rope_scaling is frequently present but set to None, so test the
            # value rather than the attribute's existence before calling .get().
            scaling = getattr(self.config, 'rope_scaling', None)
            if scaling:
                logger.info("RoPE Scaling Configuration:")
                # Newer transformers configs use 'rope_type'; older ones 'type'.
                scaling_type = scaling.get('rope_type', scaling.get('type', 'N/A'))
                logger.info(f" Type: {scaling_type}")
                logger.info(f" Factor: {scaling.get('factor', 'N/A')}")
                if (scaling.get('factor') or 0) >= 32:
                    logger.info("✓ RoPE scaling factor supports 250K+ context (32x from 8K base)")
                else:
                    logger.warning("✗ RoPE scaling factor may be insufficient")
            else:
                logger.info("RoPE Scaling: none configured")

            supported = max_pos >= self.TARGET_CONTEXT
            if not supported:
                logger.warning(f"✗ max_position_embeddings ({max_pos:,}) is below 250K")
            return supported

        # Learned position embeddings: one (max_pos, hidden_size) bf16 table.
        pos_emb_size = max_pos * hidden_size * 2  # bfloat16
        pos_emb_gb = pos_emb_size / (1024**3)
        logger.info(f"Position embedding size: {pos_emb_gb:.2f} GB")
        if max_pos >= self.TARGET_CONTEXT:
            logger.info("✓ Sufficient position embeddings for 250K context")
            return True
        logger.warning("✗ Insufficient position embeddings")
        return False

    def test_attention_computation(self, sequence_lengths: Optional[list] = None):
        """Estimate the size of a fully materialized attention-score matrix.

        Args:
            sequence_lengths: Sequence lengths (tokens) to evaluate.  Defaults
                to ``[1024, 8192, 32768, 131072]``.

        Returns:
            Always True; the results are informational.
        """
        if sequence_lengths is None:
            sequence_lengths = [1024, 8192, 32768, 131072]
        self._banner("TEST 3: Attention Computation Scaling")
        num_heads = self.config.num_attention_heads
        head_dim = self._resolve_head_dim(self.config)
        logger.info(f"Attention heads: {num_heads}")
        logger.info(f"Head dimension: {head_dim}")

        for seq_len in sequence_lengths:
            # Scores for one layer: (batch=1, heads, seq_len, seq_len), 2 bytes each (bf16).
            attn_gb = num_heads * seq_len * seq_len * 2 / (1024**3)
            logger.info(f"\nSequence length: {seq_len:,} tokens")
            logger.info(f" Attention matrix: {attn_gb:.2f} GB")
            if seq_len <= 32768:
                logger.info(" ✓ Manageable size")
            elif seq_len <= 131072:
                logger.info(" ⚠ Large - may need Flash Attention")
            else:
                logger.info(" ⚠ Very large - requires optimizations")

        # Legacy boolean flag, or the attn_implementation recorded by newer transformers.
        use_flash = (
            getattr(self.config, 'use_flash_attention_2', False)
            or getattr(self.config, '_attn_implementation', None) == "flash_attention_2"
        )
        if use_flash:
            logger.info("\n✓ Flash Attention 2 enabled - efficient for long contexts")
        else:
            logger.warning("\n⚠ Flash Attention not configured - may be slow for long contexts")
        return True

    def test_memory_requirements(
        self,
        context_length: Optional[int] = None,
        batch_size: int = 1,
        weights_gb: float = 349.0,
    ):
        """Log rough bf16 memory estimates for running at a long context.

        These are back-of-the-envelope figures (2 bytes per element), not
        measurements.

        Args:
            context_length: Tokens per sequence; defaults to ``TARGET_CONTEXT``.
            batch_size: Number of sequences.
            weights_gb: Model weight footprint in GB.

        Returns:
            Always True; the results are informational.
        """
        if context_length is None:
            context_length = self.TARGET_CONTEXT
        self._banner("TEST 4: Memory Requirements")
        hidden_size = self.config.hidden_size
        num_layers = self.config.num_hidden_layers
        logger.info("Configuration:")
        logger.info(f" Context: {context_length:,} tokens")
        logger.info(f" Batch size: {batch_size}")
        logger.info(f" Hidden size: {hidden_size:,}")
        logger.info(f" Layers: {num_layers}")

        # One (batch, context, hidden) tensor in bf16.
        hidden_states_gb = batch_size * context_length * hidden_size * 2 / (1024**3)
        # Rough per-layer estimate: hidden states plus attention outputs.
        layer_memory_gb = hidden_states_gb * 2
        total_activation_gb = layer_memory_gb * num_layers

        # KV cache for generation: K and V per layer, each
        # (batch, kv_heads, context, head_dim) in bf16.
        num_kv_heads = (
            getattr(self.config, 'num_key_value_heads', None)
            or self.config.num_attention_heads
        )
        head_dim = self._resolve_head_dim(self.config)
        kv_cache_gb = (
            2 * num_layers * batch_size * context_length * num_kv_heads * head_dim * 2
        ) / (1024**3)

        logger.info("\nMemory estimates:")
        logger.info(f" Hidden states per layer: {hidden_states_gb:.2f} GB")
        logger.info(f" Total activation memory: {total_activation_gb:.2f} GB")
        logger.info(f" KV cache (inference): {kv_cache_gb:.2f} GB")
        logger.info(f" Model weights: ~{weights_gb:g} GB")
        logger.info(f" Total (weights + activations): ~{weights_gb + total_activation_gb:.2f} GB")

        logger.info("\nRecommendations:")
        if total_activation_gb < 50:
            logger.info(" ✓ Should fit on 8x A100 (80GB) GPUs")
        elif total_activation_gb < 100:
            logger.info(" ⚠ May need gradient checkpointing")
        else:
            logger.info(" ⚠ Will need aggressive optimizations (checkpointing, CPU offload)")
        return True

    def test_rope_frequencies(self):
        """Analyse RoPE frequencies/wavelengths relative to the 250K target.

        Returns:
            Always True; an insufficient maximum wavelength only logs a warning.
        """
        self._banner("TEST 5: RoPE Frequency Analysis")
        # Fall back to 10000 when rope_theta is absent *or* None.
        rope_theta = getattr(self.config, 'rope_theta', None) or 10000
        head_dim = self._resolve_head_dim(self.config)
        logger.info(f"RoPE theta: {rope_theta:,}")
        logger.info(f"Head dimension: {head_dim}")

        # Inverse frequencies are theta^(-2i/d) for i in [0, d/2).
        min_freq, max_freq = self._rope_frequency_range(rope_theta, head_dim)
        logger.info(f"Frequency range: [{min_freq:.3e}, {max_freq:.3e}]")

        wavelengths = self._rope_wavelengths(rope_theta, head_dim)
        logger.info("\nWavelengths (in tokens):")
        for index, wavelength in wavelengths:
            logger.info(f" Frequency {index}: {wavelength:,.0f} tokens")
        if not wavelengths:
            logger.warning("\n⚠ Head dimension too small to analyse RoPE wavelengths")
            return True

        max_wavelength = max(wavelength for _, wavelength in wavelengths)
        if max_wavelength >= self.TARGET_CONTEXT:
            logger.info(f"\n✓ Maximum wavelength ({max_wavelength:,.0f}) supports 250K context")
        else:
            logger.warning(f"\n⚠ Maximum wavelength ({max_wavelength:,.0f}) may be insufficient")
        return True

    def run_all_tests(self):
        """Run every test, log a PASS/FAIL summary, and return True if all passed."""
        self._banner("HELION-OSC 250K CONTEXT LENGTH TEST SUITE")
        results = {
            "tokenization": self.test_tokenization_capacity(),
            "position_embeddings": self.test_position_embeddings(),
            "attention_scaling": self.test_attention_computation(),
            "memory_requirements": self.test_memory_requirements(),
            "rope_frequencies": self.test_rope_frequencies(),
        }

        self._banner("TEST SUMMARY")
        for test_name, passed in results.items():
            status = "✓ PASS" if passed else "✗ FAIL"
            logger.info(f"{test_name}: {status}")

        all_passed = all(results.values())
        if all_passed:
            logger.info("\n✓ All tests passed - Model supports 250K context length")
        else:
            logger.warning("\n⚠ Some tests failed - Check configuration")
        return all_passed
def main():
    """Parse command-line options and run the requested long-context test(s).

    Returns:
        Process exit status: 0 when the selected test (or the whole suite)
        reported success, 1 otherwise.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Test Helion-OSC 250K context support")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./inference",
        help="Path to model inference directory"
    )
    parser.add_argument(
        "--test",
        choices=["all", "tokenization", "position", "attention", "memory", "rope"],
        default="all",
        help="Which test to run"
    )
    args = parser.parse_args()
    tester = LongContextTester(args.model_path)

    # Map each --test choice to the tester method implementing it.
    runners = {
        "all": tester.run_all_tests,
        "tokenization": tester.test_tokenization_capacity,
        "position": tester.test_position_embeddings,
        "attention": tester.test_attention_computation,
        "memory": tester.test_memory_requirements,
        "rope": tester.test_rope_frequencies,
    }
    passed = runners[args.test]()
    return 0 if passed else 1


if __name__ == "__main__":
    # Propagate the result as the process exit code so CI can detect failures.
    raise SystemExit(main())