""" Test script to verify 250K context length support Tests RoPE scaling and long context handling """ import torch from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig import logging from typing import Optional import time logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class LongContextTester: """Test long context capabilities of Helion-OSC""" def __init__(self, model_path: str = "./inference"): """ Initialize tester Args: model_path: Path to model inference directory """ self.model_path = model_path logger.info("Loading model configuration...") # Load config self.config = AutoConfig.from_pretrained(model_path) # Verify context length max_pos = self.config.max_position_embeddings logger.info(f"Model max position embeddings: {max_pos:,}") if max_pos < 250000: logger.warning(f"Context length ({max_pos:,}) is less than 250K!") else: logger.info(f"✓ Context length supports 250K+ tokens ({max_pos:,})") # Check RoPE scaling rope_scaling = getattr(self.config, 'rope_scaling', None) rope_theta = getattr(self.config, 'rope_theta', None) if rope_scaling: logger.info(f"RoPE Scaling: {rope_scaling}") if rope_theta: logger.info(f"RoPE Theta: {rope_theta:,}") def test_tokenization_capacity(self, tokenizer_path: str = "DeepXR/Helion-OSC"): """Test that tokenizer supports long sequences""" logger.info("\n" + "="*80) logger.info("TEST 1: Tokenizer Capacity") logger.info("="*80) try: tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) max_length = tokenizer.model_max_length logger.info(f"Tokenizer max length: {max_length:,}") if max_length >= 250000: logger.info("✓ Tokenizer supports 250K+ tokens") else: logger.warning(f"✗ Tokenizer max length only {max_length:,}") # Test with a long sequence test_tokens = 10000 test_text = "Hello world! " * (test_tokens // 2) logger.info(f"Testing tokenization of ~{test_tokens:,} tokens...") encoded = tokenizer(test_text, return_tensors="pt", truncation=False) actual_tokens = encoded['input_ids'].shape[1] logger.info(f"Successfully tokenized {actual_tokens:,} tokens") logger.info("✓ Tokenization test passed") return True except Exception as e: logger.error(f"✗ Tokenization test failed: {e}") return False def test_position_embeddings(self): """Test position embedding capacity""" logger.info("\n" + "="*80) logger.info("TEST 2: Position Embeddings") logger.info("="*80) max_pos = self.config.max_position_embeddings hidden_size = self.config.hidden_size logger.info(f"Max positions: {max_pos:,}") logger.info(f"Hidden size: {hidden_size:,}") # Calculate memory requirement for position embeddings if hasattr(self.config, 'rope_theta'): logger.info("Using RoPE (Rotary Position Embeddings)") logger.info("✓ RoPE scales efficiently to long contexts") # RoPE doesn't store position embeddings, it computes them logger.info(f"RoPE Theta: {self.config.rope_theta:,}") if hasattr(self.config, 'rope_scaling'): scaling = self.config.rope_scaling logger.info(f"RoPE Scaling Configuration:") logger.info(f" Type: {scaling.get('type', 'N/A')}") logger.info(f" Factor: {scaling.get('factor', 'N/A')}") if scaling.get('factor', 0) >= 32: logger.info("✓ RoPE scaling factor supports 250K+ context (32x from 8K base)") else: logger.warning("✗ RoPE scaling factor may be insufficient") return True else: # Learned position embeddings pos_emb_size = max_pos * hidden_size * 2 # bfloat16 pos_emb_gb = pos_emb_size / (1024**3) logger.info(f"Position embedding size: {pos_emb_gb:.2f} GB") if max_pos >= 250000: logger.info("✓ Sufficient position embeddings for 250K context") return True else: logger.warning("✗ Insufficient position embeddings") return False def test_attention_computation(self, sequence_lengths: list = [1024, 8192, 32768, 131072]): """Test attention computation at various lengths""" logger.info("\n" + "="*80) logger.info("TEST 3: Attention Computation Scaling") logger.info("="*80) hidden_size = self.config.hidden_size num_heads = self.config.num_attention_heads head_dim = hidden_size // num_heads logger.info(f"Attention heads: {num_heads}") logger.info(f"Head dimension: {head_dim}") for seq_len in sequence_lengths: # Calculate attention matrix size # For self-attention: (batch, heads, seq_len, seq_len) attn_size = 1 * num_heads * seq_len * seq_len * 2 # bfloat16 attn_gb = attn_size / (1024**3) logger.info(f"\nSequence length: {seq_len:,} tokens") logger.info(f" Attention matrix: {attn_gb:.2f} GB") if seq_len <= 32768: logger.info(f" ✓ Manageable size") elif seq_len <= 131072: logger.info(f" ⚠ Large - may need Flash Attention") else: logger.info(f" ⚠ Very large - requires optimizations") # Check for Flash Attention support use_flash = getattr(self.config, 'use_flash_attention_2', False) if use_flash: logger.info("\n✓ Flash Attention 2 enabled - efficient for long contexts") else: logger.warning("\n⚠ Flash Attention not configured - may be slow for long contexts") return True def test_memory_requirements(self): """Calculate memory requirements for 250K context""" logger.info("\n" + "="*80) logger.info("TEST 4: Memory Requirements") logger.info("="*80) context_length = 250000 batch_size = 1 hidden_size = self.config.hidden_size num_layers = self.config.num_hidden_layers logger.info(f"Configuration:") logger.info(f" Context: {context_length:,} tokens") logger.info(f" Batch size: {batch_size}") logger.info(f" Hidden size: {hidden_size:,}") logger.info(f" Layers: {num_layers}") # Calculate activation memory (rough estimate) # Main components: hidden states, attention outputs hidden_states_size = batch_size * context_length * hidden_size * 2 # bfloat16 hidden_states_gb = hidden_states_size / (1024**3) # Per layer layer_memory_gb = hidden_states_gb * 2 # rough estimate with attention total_activation_gb = layer_memory_gb * num_layers logger.info(f"\nMemory estimates:") logger.info(f" Hidden states per layer: {hidden_states_gb:.2f} GB") logger.info(f" Total activation memory: {total_activation_gb:.2f} GB") logger.info(f" Model weights: ~349 GB") logger.info(f" Total (weights + activations): ~{349 + total_activation_gb:.2f} GB") logger.info(f"\nRecommendations:") if total_activation_gb < 50: logger.info(" ✓ Should fit on 8x A100 (80GB) GPUs") elif total_activation_gb < 100: logger.info(" ⚠ May need gradient checkpointing") else: logger.info(" ⚠ Will need aggressive optimizations (checkpointing, CPU offload)") return True def test_rope_frequencies(self): """Test RoPE frequency calculations for long context""" logger.info("\n" + "="*80) logger.info("TEST 5: RoPE Frequency Analysis") logger.info("="*80) rope_theta = getattr(self.config, 'rope_theta', 10000) hidden_size = self.config.hidden_size num_heads = self.config.num_attention_heads head_dim = hidden_size // num_heads logger.info(f"RoPE theta: {rope_theta:,}") logger.info(f"Head dimension: {head_dim}") # Calculate frequency range # freqs = theta^(-2i/d) for i in [0, d/2] min_freq = rope_theta ** (-2 * (head_dim-1) / head_dim) max_freq = rope_theta ** 0 logger.info(f"Frequency range: [{min_freq:.6f}, {max_freq:.6f}]") # Calculate wavelengths at different frequencies wavelengths = [2 * 3.14159 / (rope_theta ** (-2 * i / head_dim)) for i in range(0, head_dim // 2, head_dim // 8)] logger.info(f"\nWavelengths (in tokens):") for i, wl in enumerate(wavelengths): logger.info(f" Frequency {i}: {wl:,.0f} tokens") max_wavelength = max(wavelengths) if max_wavelength >= 250000: logger.info(f"\n✓ Maximum wavelength ({max_wavelength:,.0f}) supports 250K context") else: logger.warning(f"\n⚠ Maximum wavelength ({max_wavelength:,.0f}) may be insufficient") return True def run_all_tests(self): """Run all context length tests""" logger.info("\n" + "="*80) logger.info("HELION-OSC 250K CONTEXT LENGTH TEST SUITE") logger.info("="*80) results = { "tokenization": self.test_tokenization_capacity(), "position_embeddings": self.test_position_embeddings(), "attention_scaling": self.test_attention_computation(), "memory_requirements": self.test_memory_requirements(), "rope_frequencies": self.test_rope_frequencies() } # Summary logger.info("\n" + "="*80) logger.info("TEST SUMMARY") logger.info("="*80) for test_name, passed in results.items(): status = "✓ PASS" if passed else "✗ FAIL" logger.info(f"{test_name}: {status}") all_passed = all(results.values()) if all_passed: logger.info("\n✓ All tests passed - Model supports 250K context length") else: logger.warning("\n⚠ Some tests failed - Check configuration") return all_passed def main(): """Main test script""" import argparse parser = argparse.ArgumentParser(description="Test Helion-OSC 250K context support") parser.add_argument( "--model-path", type=str, default="./inference", help="Path to model inference directory" ) parser.add_argument( "--test", choices=["all", "tokenization", "position", "attention", "memory", "rope"], default="all", help="Which test to run" ) args = parser.parse_args() tester = LongContextTester(args.model_path) if args.test == "all": tester.run_all_tests() elif args.test == "tokenization": tester.test_tokenization_capacity() elif args.test == "position": tester.test_position_embeddings() elif args.test == "attention": tester.test_attention_computation() elif args.test == "memory": tester.test_memory_requirements() elif args.test == "rope": tester.test_rope_frequencies() if __name__ == "__main__": main()