"""
Test script to verify 250K context length support.

Tests RoPE scaling and long context handling.
"""

import logging
import math

from transformers import AutoConfig, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LongContextTester:
    """Test long context capabilities of Helion-OSC."""

    def __init__(self, model_path: str = "./inference"):
        """
        Initialize the tester.

        Args:
            model_path: Path to the model inference directory.
        """
        self.model_path = model_path
        logger.info("Loading model configuration...")

        self.config = AutoConfig.from_pretrained(model_path)

        max_pos = self.config.max_position_embeddings
        logger.info(f"Model max position embeddings: {max_pos:,}")

        if max_pos < 250000:
            logger.warning(f"Context length ({max_pos:,}) is less than 250K!")
        else:
            logger.info(f"✓ Context length supports 250K+ tokens ({max_pos:,})")

        rope_scaling = getattr(self.config, 'rope_scaling', None)
        rope_theta = getattr(self.config, 'rope_theta', None)

        if rope_scaling:
            logger.info(f"RoPE Scaling: {rope_scaling}")
        if rope_theta:
            logger.info(f"RoPE Theta: {rope_theta:,}")

    def test_tokenization_capacity(self, tokenizer_path: str = "DeepXR/Helion-OSC"):
        """Test that the tokenizer supports long sequences."""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 1: Tokenizer Capacity")
        logger.info("=" * 80)

        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            max_length = tokenizer.model_max_length
            logger.info(f"Tokenizer max length: {max_length:,}")

            if max_length >= 250000:
                logger.info("✓ Tokenizer supports 250K+ tokens")
            else:
                logger.warning(f"✗ Tokenizer max length only {max_length:,}")

            # Build a synthetic document of roughly `test_tokens` tokens
            # (the repetition count assumes ~2 tokens per "Hello world! ").
            test_tokens = 10000
            test_text = "Hello world! " * (test_tokens // 2)

            logger.info(f"Testing tokenization of ~{test_tokens:,} tokens...")
            encoded = tokenizer(test_text, return_tensors="pt", truncation=False)
            actual_tokens = encoded['input_ids'].shape[1]

            logger.info(f"Successfully tokenized {actual_tokens:,} tokens")
            logger.info("✓ Tokenization test passed")

            return True

        except Exception as e:
            logger.error(f"✗ Tokenization test failed: {e}")
            return False

    def test_position_embeddings(self):
        """Test position embedding capacity."""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 2: Position Embeddings")
        logger.info("=" * 80)

        max_pos = self.config.max_position_embeddings
        hidden_size = self.config.hidden_size

        logger.info(f"Max positions: {max_pos:,}")
        logger.info(f"Hidden size: {hidden_size:,}")

        if hasattr(self.config, 'rope_theta'):
            logger.info("Using RoPE (Rotary Position Embeddings)")
            logger.info("✓ RoPE scales efficiently to long contexts")

            logger.info(f"RoPE Theta: {self.config.rope_theta:,}")

            # rope_scaling may be present on the config but set to None, so guard both cases.
            scaling = getattr(self.config, 'rope_scaling', None)
            if scaling:
                logger.info("RoPE Scaling Configuration:")
                logger.info(f"  Type: {scaling.get('type', 'N/A')}")
                logger.info(f"  Factor: {scaling.get('factor', 'N/A')}")

                if scaling.get('factor', 0) >= 32:
                    logger.info("✓ RoPE scaling factor supports 250K+ context (32x from 8K base)")
                else:
                    logger.warning("✗ RoPE scaling factor may be insufficient")

            return True
        else:
            # Learned absolute position embeddings: estimate their size in fp16 (2 bytes per value).
            pos_emb_size = max_pos * hidden_size * 2
            pos_emb_gb = pos_emb_size / (1024**3)
            logger.info(f"Position embedding size: {pos_emb_gb:.2f} GB")

            if max_pos >= 250000:
                logger.info("✓ Sufficient position embeddings for 250K context")
                return True
            else:
                logger.warning("✗ Insufficient position embeddings")
                return False

    def test_attention_computation(self, sequence_lengths: tuple = (1024, 8192, 32768, 131072)):
        """Test attention computation at various sequence lengths."""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 3: Attention Computation Scaling")
        logger.info("=" * 80)

        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads

        logger.info(f"Attention heads: {num_heads}")
        logger.info(f"Head dimension: {head_dim}")

        for seq_len in sequence_lengths:
            # Full attention score matrix for batch size 1 in fp16:
            # num_heads x seq_len x seq_len entries at 2 bytes each.
            attn_size = 1 * num_heads * seq_len * seq_len * 2
            attn_gb = attn_size / (1024**3)

            logger.info(f"\nSequence length: {seq_len:,} tokens")
            logger.info(f"  Attention matrix: {attn_gb:.2f} GB")

            if seq_len <= 32768:
                logger.info("  ✓ Manageable size")
            elif seq_len <= 131072:
                logger.info("  ⚠ Large - may need Flash Attention")
            else:
                logger.info("  ⚠ Very large - requires optimizations")

        use_flash = getattr(self.config, 'use_flash_attention_2', False)
        if use_flash:
            logger.info("\n✓ Flash Attention 2 enabled - efficient for long contexts")
        else:
            logger.warning("\n⚠ Flash Attention not configured - may be slow for long contexts")

        return True

    def test_memory_requirements(self):
        """Estimate memory requirements for a 250K-token context."""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 4: Memory Requirements")
        logger.info("=" * 80)

        context_length = 250000
        batch_size = 1
        hidden_size = self.config.hidden_size
        num_layers = self.config.num_hidden_layers

        logger.info("Configuration:")
        logger.info(f"  Context: {context_length:,} tokens")
        logger.info(f"  Batch size: {batch_size}")
        logger.info(f"  Hidden size: {hidden_size:,}")
        logger.info(f"  Layers: {num_layers}")

        # Hidden states in fp16: batch x context x hidden_size at 2 bytes per element.
        hidden_states_size = batch_size * context_length * hidden_size * 2
        hidden_states_gb = hidden_states_size / (1024**3)

        # Rough per-layer activation estimate: assume ~2x the hidden-state
        # footprint per layer, summed over all layers.
        layer_memory_gb = hidden_states_gb * 2
        total_activation_gb = layer_memory_gb * num_layers

        logger.info("\nMemory estimates:")
        logger.info(f"  Hidden states per layer: {hidden_states_gb:.2f} GB")
        logger.info(f"  Total activation memory: {total_activation_gb:.2f} GB")
        logger.info("  Model weights: ~349 GB")
        logger.info(f"  Total (weights + activations): ~{349 + total_activation_gb:.2f} GB")

        logger.info("\nRecommendations:")
        if total_activation_gb < 50:
            logger.info("  ✓ Should fit on 8x A100 (80GB) GPUs")
        elif total_activation_gb < 100:
            logger.info("  ⚠ May need gradient checkpointing")
        else:
            logger.info("  ⚠ Will need aggressive optimizations (checkpointing, CPU offload)")

        return True

    def test_rope_frequencies(self):
        """Test RoPE frequency calculations for long context."""
        logger.info("\n" + "=" * 80)
        logger.info("TEST 5: RoPE Frequency Analysis")
        logger.info("=" * 80)

        rope_theta = getattr(self.config, 'rope_theta', 10000)
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads

        logger.info(f"RoPE theta: {rope_theta:,}")
        logger.info(f"Head dimension: {head_dim}")

        # RoPE uses frequencies theta^(-2i / head_dim) for i in [0, head_dim // 2).
        # The highest frequency is theta^0 = 1; the lowest uses i = head_dim // 2 - 1.
        min_freq = rope_theta ** (-2 * (head_dim // 2 - 1) / head_dim)
        max_freq = rope_theta ** 0

        logger.info(f"Frequency range: [{min_freq:.6f}, {max_freq:.6f}]")

        # Wavelength in tokens for a sampling of frequency bands: 2*pi / frequency.
        wavelengths = [2 * math.pi / (rope_theta ** (-2 * i / head_dim))
                       for i in range(0, head_dim // 2, head_dim // 8)]

        logger.info("\nWavelengths (in tokens):")
        for i, wl in enumerate(wavelengths):
            logger.info(f"  Frequency {i}: {wl:,.0f} tokens")

        max_wavelength = max(wavelengths)
        if max_wavelength >= 250000:
            logger.info(f"\n✓ Maximum wavelength ({max_wavelength:,.0f}) supports 250K context")
        else:
            logger.warning(f"\n⚠ Maximum wavelength ({max_wavelength:,.0f}) may be insufficient")

        return True

    def run_all_tests(self):
        """Run all context length tests."""
        logger.info("\n" + "=" * 80)
        logger.info("HELION-OSC 250K CONTEXT LENGTH TEST SUITE")
        logger.info("=" * 80)

        results = {
            "tokenization": self.test_tokenization_capacity(),
            "position_embeddings": self.test_position_embeddings(),
            "attention_scaling": self.test_attention_computation(),
            "memory_requirements": self.test_memory_requirements(),
            "rope_frequencies": self.test_rope_frequencies(),
        }

        logger.info("\n" + "=" * 80)
        logger.info("TEST SUMMARY")
        logger.info("=" * 80)

        for test_name, passed in results.items():
            status = "✓ PASS" if passed else "✗ FAIL"
            logger.info(f"{test_name}: {status}")

        all_passed = all(results.values())

        if all_passed:
            logger.info("\n✓ All tests passed - Model supports 250K context length")
        else:
            logger.warning("\n⚠ Some tests failed - Check configuration")

        return all_passed
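
# Programmatic usage sketch, as an alternative to the CLI entry point below
# (the "./inference" path is just the constructor's default, not a required layout):
#
#   tester = LongContextTester(model_path="./inference")
#   tester.run_all_tests()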


def main():
    """Main test script."""
    import argparse

    parser = argparse.ArgumentParser(description="Test Helion-OSC 250K context support")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./inference",
        help="Path to model inference directory",
    )
    parser.add_argument(
        "--test",
        choices=["all", "tokenization", "position", "attention", "memory", "rope"],
        default="all",
        help="Which test to run",
    )

    args = parser.parse_args()

    tester = LongContextTester(args.model_path)

    if args.test == "all":
        tester.run_all_tests()
    elif args.test == "tokenization":
        tester.test_tokenization_capacity()
    elif args.test == "position":
        tester.test_position_embeddings()
    elif args.test == "attention":
        tester.test_attention_computation()
    elif args.test == "memory":
        tester.test_memory_requirements()
    elif args.test == "rope":
        tester.test_rope_frequencies()


if __name__ == "__main__":
    main()