"""
Test script to verify 250K context length support
Tests RoPE scaling and long context handling
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import logging
from typing import Optional
import time
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LongContextTester:
    """Test long context capabilities of Helion-OSC"""

    def __init__(self, model_path: str = "./inference"):
        """
        Initialize tester

        Args:
            model_path: Path to model inference directory
        """
        self.model_path = model_path

        logger.info("Loading model configuration...")

        # Load config
        self.config = AutoConfig.from_pretrained(model_path)

        # Verify context length
        max_pos = self.config.max_position_embeddings
        logger.info(f"Model max position embeddings: {max_pos:,}")

        if max_pos < 250000:
            logger.warning(f"Context length ({max_pos:,}) is less than 250K!")
        else:
            logger.info(f"✓ Context length supports 250K+ tokens ({max_pos:,})")

        # Check RoPE scaling
        rope_scaling = getattr(self.config, 'rope_scaling', None)
        rope_theta = getattr(self.config, 'rope_theta', None)

        if rope_scaling:
            logger.info(f"RoPE Scaling: {rope_scaling}")
        if rope_theta:
            logger.info(f"RoPE Theta: {rope_theta:,}")

    def test_tokenization_capacity(self, tokenizer_path: str = "DeepXR/Helion-OSC"):
        """Test that tokenizer supports long sequences"""
        logger.info("\n" + "="*80)
        logger.info("TEST 1: Tokenizer Capacity")
        logger.info("="*80)

        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            max_length = tokenizer.model_max_length
            logger.info(f"Tokenizer max length: {max_length:,}")

            if max_length >= 250000:
                logger.info("✓ Tokenizer supports 250K+ tokens")
            else:
                logger.warning(f"⚠ Tokenizer max length only {max_length:,}")

            # Test with a long sequence
            test_tokens = 10000
            test_text = "Hello world! " * (test_tokens // 2)

            logger.info(f"Testing tokenization of ~{test_tokens:,} tokens...")
            encoded = tokenizer(test_text, return_tensors="pt", truncation=False)
            actual_tokens = encoded['input_ids'].shape[1]

            logger.info(f"Successfully tokenized {actual_tokens:,} tokens")
            logger.info("✓ Tokenization test passed")
            return True
        except Exception as e:
            logger.error(f"✗ Tokenization test failed: {e}")
            return False

    def test_position_embeddings(self):
        """Test position embedding capacity"""
        logger.info("\n" + "="*80)
        logger.info("TEST 2: Position Embeddings")
        logger.info("="*80)

        max_pos = self.config.max_position_embeddings
        hidden_size = self.config.hidden_size

        logger.info(f"Max positions: {max_pos:,}")
        logger.info(f"Hidden size: {hidden_size:,}")

        # Calculate memory requirement for position embeddings
        if hasattr(self.config, 'rope_theta'):
            logger.info("Using RoPE (Rotary Position Embeddings)")
            logger.info("✓ RoPE scales efficiently to long contexts")
            # RoPE doesn't store position embeddings, it computes them
            logger.info(f"RoPE Theta: {self.config.rope_theta:,}")

            # rope_scaling may be absent or explicitly None; only inspect it when set
            scaling = getattr(self.config, 'rope_scaling', None)
            if scaling:
                logger.info("RoPE Scaling Configuration:")
                logger.info(f"  Type: {scaling.get('type', 'N/A')}")
                logger.info(f"  Factor: {scaling.get('factor', 'N/A')}")

                if scaling.get('factor', 0) >= 32:
                    logger.info("✓ RoPE scaling factor supports 250K+ context (32x from 8K base)")
                else:
                    logger.warning("⚠ RoPE scaling factor may be insufficient")

            return True
        else:
            # Learned position embeddings
            pos_emb_size = max_pos * hidden_size * 2  # bfloat16
            pos_emb_gb = pos_emb_size / (1024**3)
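            # Illustrative arithmetic (hypothetical width): 250,000 positions
            # x 8,192 hidden dims x 2 bytes (bfloat16) ≈ 3.8 GB for the table alone.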
logger.info(f"Position embedding size: {pos_emb_gb:.2f} GB")
if max_pos >= 250000:
logger.info("β Sufficient position embeddings for 250K context")
return True
else:
logger.warning("β Insufficient position embeddings")
return False

    def test_attention_computation(self, sequence_lengths: Optional[list] = None):
        """Test attention computation at various lengths"""
        logger.info("\n" + "="*80)
        logger.info("TEST 3: Attention Computation Scaling")
        logger.info("="*80)

        # Default sweep of sequence lengths (avoids a mutable default argument)
        if sequence_lengths is None:
            sequence_lengths = [1024, 8192, 32768, 131072]

        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads

        logger.info(f"Attention heads: {num_heads}")
        logger.info(f"Head dimension: {head_dim}")

        for seq_len in sequence_lengths:
            # Calculate attention matrix size
            # For self-attention: (batch, heads, seq_len, seq_len)
            attn_size = 1 * num_heads * seq_len * seq_len * 2  # bfloat16
            attn_gb = attn_size / (1024**3)
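            # Illustrative arithmetic (hypothetical head count): 64 heads at
            # seq_len 131,072 -> 64 * 131,072^2 * 2 bytes ≈ 2 TB if materialized naively.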
logger.info(f"\nSequence length: {seq_len:,} tokens")
logger.info(f" Attention matrix: {attn_gb:.2f} GB")
if seq_len <= 32768:
logger.info(f" β Manageable size")
elif seq_len <= 131072:
logger.info(f" β Large - may need Flash Attention")
else:
logger.info(f" β Very large - requires optimizations")
# Check for Flash Attention support
use_flash = getattr(self.config, 'use_flash_attention_2', False)
if use_flash:
logger.info("\nβ Flash Attention 2 enabled - efficient for long contexts")
else:
logger.warning("\nβ Flash Attention not configured - may be slow for long contexts")
        return True

    def test_memory_requirements(self):
        """Calculate memory requirements for 250K context"""
        logger.info("\n" + "="*80)
        logger.info("TEST 4: Memory Requirements")
        logger.info("="*80)

        context_length = 250000
        batch_size = 1
        hidden_size = self.config.hidden_size
        num_layers = self.config.num_hidden_layers

        logger.info("Configuration:")
        logger.info(f"  Context: {context_length:,} tokens")
        logger.info(f"  Batch size: {batch_size}")
        logger.info(f"  Hidden size: {hidden_size:,}")
        logger.info(f"  Layers: {num_layers}")

        # Calculate activation memory (rough estimate)
        # Main components: hidden states, attention outputs
        hidden_states_size = batch_size * context_length * hidden_size * 2  # bfloat16
        hidden_states_gb = hidden_states_size / (1024**3)

        # Per layer
        layer_memory_gb = hidden_states_gb * 2  # rough estimate with attention
        total_activation_gb = layer_memory_gb * num_layers
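
        # Illustrative arithmetic (hypothetical width/depth): hidden_size 8,192 gives
        # 250,000 * 8,192 * 2 bytes ≈ 3.8 GB of hidden states, ~7.6 GB per layer under
        # this 2x heuristic, and ~610 GB across 80 layers before any checkpointing.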
logger.info(f"\nMemory estimates:")
logger.info(f" Hidden states per layer: {hidden_states_gb:.2f} GB")
logger.info(f" Total activation memory: {total_activation_gb:.2f} GB")
logger.info(f" Model weights: ~349 GB")
logger.info(f" Total (weights + activations): ~{349 + total_activation_gb:.2f} GB")
logger.info(f"\nRecommendations:")
if total_activation_gb < 50:
logger.info(" β Should fit on 8x A100 (80GB) GPUs")
elif total_activation_gb < 100:
logger.info(" β May need gradient checkpointing")
else:
logger.info(" β Will need aggressive optimizations (checkpointing, CPU offload)")
return True

    def test_rope_frequencies(self):
        """Test RoPE frequency calculations for long context"""
        logger.info("\n" + "="*80)
        logger.info("TEST 5: RoPE Frequency Analysis")
        logger.info("="*80)

        rope_theta = getattr(self.config, 'rope_theta', 10000)
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads

        logger.info(f"RoPE theta: {rope_theta:,}")
        logger.info(f"Head dimension: {head_dim}")

        # Calculate frequency range
        # freqs = theta^(-2i/d) for i in [0, d/2)
        min_freq = rope_theta ** (-2 * (head_dim // 2 - 1) / head_dim)
        max_freq = rope_theta ** 0
        logger.info(f"Frequency range: [{min_freq:.6f}, {max_freq:.6f}]")

        # Calculate wavelengths at different frequencies
        # wavelength_i = 2*pi / freq_i = 2*pi * theta^(2i/d)
        wavelengths = [2 * math.pi / (rope_theta ** (-2 * i / head_dim))
                       for i in range(0, head_dim // 2, head_dim // 8)]
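
        # Illustrative check (default theta, hypothetical head_dim = 128): the longest
        # wavelength sampled here (i = 48) is 2*pi * 10,000^0.75 ≈ 6,283 tokens -- far
        # below 250K, which is why a larger theta or explicit RoPE scaling is needed
        # for this context length.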
logger.info(f"\nWavelengths (in tokens):")
for i, wl in enumerate(wavelengths):
logger.info(f" Frequency {i}: {wl:,.0f} tokens")
max_wavelength = max(wavelengths)
if max_wavelength >= 250000:
logger.info(f"\nβ Maximum wavelength ({max_wavelength:,.0f}) supports 250K context")
else:
logger.warning(f"\nβ Maximum wavelength ({max_wavelength:,.0f}) may be insufficient")
return True

    def run_all_tests(self):
        """Run all context length tests"""
        logger.info("\n" + "="*80)
        logger.info("HELION-OSC 250K CONTEXT LENGTH TEST SUITE")
        logger.info("="*80)

        results = {
            "tokenization": self.test_tokenization_capacity(),
            "position_embeddings": self.test_position_embeddings(),
            "attention_scaling": self.test_attention_computation(),
            "memory_requirements": self.test_memory_requirements(),
            "rope_frequencies": self.test_rope_frequencies()
        }

        # Summary
        logger.info("\n" + "="*80)
        logger.info("TEST SUMMARY")
        logger.info("="*80)

        for test_name, passed in results.items():
            status = "✓ PASS" if passed else "✗ FAIL"
            logger.info(f"{test_name}: {status}")

        all_passed = all(results.values())

        if all_passed:
            logger.info("\n✓ All tests passed - Model supports 250K context length")
        else:
            logger.warning("\n⚠ Some tests failed - Check configuration")

        return all_passed


def main():
    """Main test script"""
    import argparse

    parser = argparse.ArgumentParser(description="Test Helion-OSC 250K context support")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./inference",
        help="Path to model inference directory"
    )
    parser.add_argument(
        "--test",
        choices=["all", "tokenization", "position", "attention", "memory", "rope"],
        default="all",
        help="Which test to run"
    )

    args = parser.parse_args()

    tester = LongContextTester(args.model_path)

    if args.test == "all":
        tester.run_all_tests()
    elif args.test == "tokenization":
        tester.test_tokenization_capacity()
    elif args.test == "position":
        tester.test_position_embeddings()
    elif args.test == "attention":
        tester.test_attention_computation()
    elif args.test == "memory":
        tester.test_memory_requirements()
    elif args.test == "rope":
        tester.test_rope_frequencies()
if __name__ == "__main__":
main() |
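
# Example invocations (script filename is assumed here; adjust to the actual file name):
#   python test_context_length.py --model-path ./inference --test all
#   python test_context_length.py --test rope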