# JiRackTernary_70b / contrast_optimizer_70b.py
# Author: kgrabko
# Last commit: "Update contrast_optimizer_70b.py" (902b48b, verified)
# ==============================================================================
# COPYRIGHT (C) 2025-2026 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
# ==============================================================================
import torch
import numpy as np
from transformers import AutoTokenizer
from JiRackTernaryPyTorch_70b import JiRackTernaryConfig
from load_packed_70b import load_jirack_70b_packed
def test_contrast_levels():
    """Sweep logit contrast scales and print the one with the lowest perplexity.

    Loads the packed model from ``JiRack_BitNet_70B_Packed`` onto ``cuda:0``,
    runs one forward pass over a fixed test sentence and then re-scores the
    same logits at several contrast scales.  For every scale the logits are
    mean-centred along their last axis and multiplied by the scale before the
    next-token cross-entropy loss and its exponent (perplexity) are computed.

    A table with one row per scale is printed, followed by the scale that
    produced the lowest perplexity.

    Returns:
        None. All results are written to stdout.
    """
    PATH = "JiRack_BitNet_70B_Packed"
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B")
    config = JiRackTernaryConfig.from_pretrained(PATH)
    model = load_jirack_70b_packed(PATH, config)
    model.eval().to("cuda:0")

    # Contrast levels to evaluate.
    contrast_scales = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 15.0]
    test_text = "The solar system consists of the Sun and the objects that orbit it."

    # The model is called with input_ids only (a single, unpadded sequence),
    # so no attention mask is prepared here.
    inputs = tokenizer(test_text, return_tensors="pt").to("cuda:0")
    input_ids = inputs["input_ids"]

    print("\n🔍 Поиск оптимальной резкости логитов...")
    print(f"{'Contrast':<10} | {'Average Loss':<15} | {'Perplexity':<12}")
    print("-" * 45)

    results = []
    with torch.no_grad():
        # Single base pass to obtain the unscaled logits.
        outputs = model(input_ids)

        # Score in float32: if the model emits half-precision logits, scaling
        # them up and reducing over the whole vocabulary loses too much
        # precision for the losses of different scales to be comparable.
        base_logits = outputs.logits.float()

        # Everything below is independent of the scale, so compute it once.
        # Position t predicts token t + 1: drop the last logit row and the
        # first label.
        centered = base_logits - base_logits.mean(dim=-1, keepdim=True)
        shift_centered = centered[..., :-1, :]
        shift_labels = input_ids[..., 1:].reshape(-1)
        loss_fct = torch.nn.CrossEntropyLoss()

        for scale in contrast_scales:
            # Apply the contrast scale manually on top of the model output.
            shift_logits = (shift_centered * scale).reshape(-1, shift_centered.size(-1))
            loss = loss_fct(shift_logits, shift_labels)
            ppl = torch.exp(loss).item()
            print(f"{scale:<10.1f} | {loss.item():>15.4f} | {ppl:>12.2f}")
            results.append((scale, ppl))

    # Ties resolve to the first (smallest) scale, as min() keeps the first hit.
    best_scale = min(results, key=lambda x: x[1])[0]
    print("-" * 45)
    print(f"🎯 Рекомендованная контрастность: {best_scale}")
# Run the contrast sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_contrast_levels()