"""
Test script to verify 250K context length support
Tests RoPE scaling and long context handling
"""

import logging
import math

from transformers import AutoConfig, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LongContextTester:
    """Test long context capabilities of Helion-OSC"""
    
    def __init__(self, model_path: str = "./inference"):
        """
        Initialize tester
        
        Args:
            model_path: Path to model inference directory
        """
        self.model_path = model_path
        logger.info("Loading model configuration...")
        
        # Load config
        self.config = AutoConfig.from_pretrained(model_path)
        
        # Verify context length
        max_pos = self.config.max_position_embeddings
        logger.info(f"Model max position embeddings: {max_pos:,}")
        
        if max_pos < 250000:
            logger.warning(f"Context length ({max_pos:,}) is less than 250K!")
        else:
            logger.info(f"βœ“ Context length supports 250K+ tokens ({max_pos:,})")
        
        # Check RoPE scaling
        rope_scaling = getattr(self.config, 'rope_scaling', None)
        rope_theta = getattr(self.config, 'rope_theta', None)
        
        if rope_scaling:
            logger.info(f"RoPE Scaling: {rope_scaling}")
        if rope_theta:
            logger.info(f"RoPE Theta: {rope_theta:,}")
    
    def test_tokenization_capacity(self, tokenizer_path: str = "DeepXR/Helion-OSC"):
        """Test that tokenizer supports long sequences"""
        logger.info("\n" + "="*80)
        logger.info("TEST 1: Tokenizer Capacity")
        logger.info("="*80)
        
        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            
            max_length = tokenizer.model_max_length
            logger.info(f"Tokenizer max length: {max_length:,}")
            
            if max_length >= 250000:
                logger.info("βœ“ Tokenizer supports 250K+ tokens")
            else:
                logger.warning(f"βœ— Tokenizer max length only {max_length:,}")
            
            # Test with a long sequence
            test_tokens = 10000
            test_text = "Hello world! " * (test_tokens // 2)
            
            logger.info(f"Testing tokenization of ~{test_tokens:,} tokens...")
            encoded = tokenizer(test_text, return_tensors="pt", truncation=False)
            actual_tokens = encoded['input_ids'].shape[1]
            
            logger.info(f"Successfully tokenized {actual_tokens:,} tokens")
            logger.info("βœ“ Tokenization test passed")
            
            return True
            
        except Exception as e:
            logger.error(f"βœ— Tokenization test failed: {e}")
            return False
    
    def test_position_embeddings(self):
        """Test position embedding capacity"""
        logger.info("\n" + "="*80)
        logger.info("TEST 2: Position Embeddings")
        logger.info("="*80)
        
        max_pos = self.config.max_position_embeddings
        hidden_size = self.config.hidden_size
        
        logger.info(f"Max positions: {max_pos:,}")
        logger.info(f"Hidden size: {hidden_size:,}")
        
        # Calculate memory requirement for position embeddings
        if hasattr(self.config, 'rope_theta'):
            logger.info("Using RoPE (Rotary Position Embeddings)")
            logger.info("βœ“ RoPE scales efficiently to long contexts")
            
            # RoPE doesn't store position embeddings, it computes them
            logger.info(f"RoPE Theta: {self.config.rope_theta:,}")
            
            scaling = getattr(self.config, 'rope_scaling', None)
            if scaling:
                logger.info("RoPE Scaling Configuration:")
                logger.info(f"  Type: {scaling.get('type', 'N/A')}")
                logger.info(f"  Factor: {scaling.get('factor', 'N/A')}")
                
                if scaling.get('factor', 0) >= 32:
                    logger.info("✓ RoPE scaling factor supports 250K+ context (32x from 8K base)")
                else:
                    logger.warning("✗ RoPE scaling factor may be insufficient")
            
            return True
        else:
            # Learned position embeddings
            pos_emb_size = max_pos * hidden_size * 2  # bfloat16
            pos_emb_gb = pos_emb_size / (1024**3)
            logger.info(f"Position embedding size: {pos_emb_gb:.2f} GB")
            
            if max_pos >= 250000:
                logger.info("βœ“ Sufficient position embeddings for 250K context")
                return True
            else:
                logger.warning("βœ— Insufficient position embeddings")
                return False
    
    def test_attention_computation(self, sequence_lengths: tuple = (1024, 8192, 32768, 131072)):
        """Test attention computation at various lengths"""
        logger.info("\n" + "="*80)
        logger.info("TEST 3: Attention Computation Scaling")
        logger.info("="*80)
        
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads
        
        logger.info(f"Attention heads: {num_heads}")
        logger.info(f"Head dimension: {head_dim}")
        
        for seq_len in sequence_lengths:
            # Calculate attention matrix size
            # For self-attention: (batch, heads, seq_len, seq_len)
            attn_size = 1 * num_heads * seq_len * seq_len * 2  # bfloat16
            attn_gb = attn_size / (1024**3)
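            # Worked example with an assumed head count (not read from the config):
            # 64 heads at seq_len = 131,072 -> 64 * 131072**2 * 2 bytes ~= 2,048 GB,
            # which is why full attention matrices cannot be materialized at this scale.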
            
            logger.info(f"\nSequence length: {seq_len:,} tokens")
            logger.info(f"  Attention matrix: {attn_gb:.2f} GB")
            
            if seq_len <= 32768:
                logger.info(f"  βœ“ Manageable size")
            elif seq_len <= 131072:
                logger.info(f"  ⚠ Large - may need Flash Attention")
            else:
                logger.info(f"  ⚠ Very large - requires optimizations")
        
        # Check for Flash Attention support (explicit config flag, or the attention
        # implementation selected at load time if the config exposes it)
        use_flash = (getattr(self.config, 'use_flash_attention_2', False)
                     or getattr(self.config, '_attn_implementation', None) == 'flash_attention_2')
        if use_flash:
            logger.info("\nβœ“ Flash Attention 2 enabled - efficient for long contexts")
        else:
            logger.warning("\n⚠ Flash Attention not configured - may be slow for long contexts")
        
        return True
    
    def test_memory_requirements(self):
        """Calculate memory requirements for 250K context"""
        logger.info("\n" + "="*80)
        logger.info("TEST 4: Memory Requirements")
        logger.info("="*80)
        
        context_length = 250000
        batch_size = 1
        hidden_size = self.config.hidden_size
        num_layers = self.config.num_hidden_layers
        
        logger.info(f"Configuration:")
        logger.info(f"  Context: {context_length:,} tokens")
        logger.info(f"  Batch size: {batch_size}")
        logger.info(f"  Hidden size: {hidden_size:,}")
        logger.info(f"  Layers: {num_layers}")
        
        # Calculate activation memory (rough estimate)
        # Main components: hidden states, attention outputs
        hidden_states_size = batch_size * context_length * hidden_size * 2  # bfloat16
        hidden_states_gb = hidden_states_size / (1024**3)
        
        # Per layer
        layer_memory_gb = hidden_states_gb * 2  # rough estimate with attention
        total_activation_gb = layer_memory_gb * num_layers
        
        logger.info(f"\nMemory estimates:")
        logger.info(f"  Hidden states per layer: {hidden_states_gb:.2f} GB")
        logger.info(f"  Total activation memory: {total_activation_gb:.2f} GB")
        logger.info(f"  Model weights: ~349 GB")
        logger.info(f"  Total (weights + activations): ~{349 + total_activation_gb:.2f} GB")
        
        logger.info(f"\nRecommendations:")
        if total_activation_gb < 50:
            logger.info("  βœ“ Should fit on 8x A100 (80GB) GPUs")
        elif total_activation_gb < 100:
            logger.info("  ⚠ May need gradient checkpointing")
        else:
            logger.info("  ⚠ Will need aggressive optimizations (checkpointing, CPU offload)")
        
        return True
    
    def test_rope_frequencies(self):
        """Test RoPE frequency calculations for long context"""
        logger.info("\n" + "="*80)
        logger.info("TEST 5: RoPE Frequency Analysis")
        logger.info("="*80)
        
        rope_theta = getattr(self.config, 'rope_theta', 10000)
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = hidden_size // num_heads
        
        logger.info(f"RoPE theta: {rope_theta:,}")
        logger.info(f"Head dimension: {head_dim}")
        
        # Calculate frequency range
        # RoPE frequencies: freq_i = theta^(-2i/d) for i in [0, d/2 - 1]
        min_freq = rope_theta ** (-(head_dim - 2) / head_dim)
        max_freq = 1.0  # i = 0
        
        logger.info(f"Frequency range: [{min_freq:.6f}, {max_freq:.6f}]")
        
        # Calculate wavelengths (2*pi / freq_i) at sampled indices, always including
        # the lowest frequency, which has the longest wavelength
        sample_indices = list(range(0, head_dim // 2, max(head_dim // 8, 1))) + [head_dim // 2 - 1]
        wavelengths = [2 * math.pi * rope_theta ** (2 * i / head_dim)
                       for i in sample_indices]
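        # Worked example with assumed values (not read from the config): for
        # theta = 10,000 and head_dim = 128, the longest wavelength is roughly
        # 2 * pi * 10000 ** (126 / 128) ~= 54,000 tokens, well short of 250K,
        # which is why a larger theta and/or RoPE scaling is needed.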
        
        logger.info(f"\nWavelengths (in tokens):")
        for i, wl in enumerate(wavelengths):
            logger.info(f"  Frequency {i}: {wl:,.0f} tokens")
        
        max_wavelength = max(wavelengths)
        if max_wavelength >= 250000:
            logger.info(f"\nβœ“ Maximum wavelength ({max_wavelength:,.0f}) supports 250K context")
        else:
            logger.warning(f"\n⚠ Maximum wavelength ({max_wavelength:,.0f}) may be insufficient")
        
        return True
    
    def run_all_tests(self):
        """Run all context length tests"""
        logger.info("\n" + "="*80)
        logger.info("HELION-OSC 250K CONTEXT LENGTH TEST SUITE")
        logger.info("="*80)
        
        results = {
            "tokenization": self.test_tokenization_capacity(),
            "position_embeddings": self.test_position_embeddings(),
            "attention_scaling": self.test_attention_computation(),
            "memory_requirements": self.test_memory_requirements(),
            "rope_frequencies": self.test_rope_frequencies()
        }
        
        # Summary
        logger.info("\n" + "="*80)
        logger.info("TEST SUMMARY")
        logger.info("="*80)
        
        for test_name, passed in results.items():
            status = "βœ“ PASS" if passed else "βœ— FAIL"
            logger.info(f"{test_name}: {status}")
        
        all_passed = all(results.values())
        
        if all_passed:
            logger.info("\nβœ“ All tests passed - Model supports 250K context length")
        else:
            logger.warning("\n⚠ Some tests failed - Check configuration")
        
        return all_passed


def main():
    """Main test script"""
    import argparse
    
    parser = argparse.ArgumentParser(description="Test Helion-OSC 250K context support")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./inference",
        help="Path to model inference directory"
    )
    parser.add_argument(
        "--test",
        choices=["all", "tokenization", "position", "attention", "memory", "rope"],
        default="all",
        help="Which test to run"
    )
    
    args = parser.parse_args()
    
    tester = LongContextTester(args.model_path)
    
    if args.test == "all":
        tester.run_all_tests()
    elif args.test == "tokenization":
        tester.test_tokenization_capacity()
    elif args.test == "position":
        tester.test_position_embeddings()
    elif args.test == "attention":
        tester.test_attention_computation()
    elif args.test == "memory":
        tester.test_memory_requirements()
    elif args.test == "rope":
        tester.test_rope_frequencies()


if __name__ == "__main__":
    main()