Text Generation
Transformers
Safetensors
English
monoid
causal-lm
linear-attention
state-space
O(1)-inference
vector-decay
reasoning
conversational
custom_code
Instructions to use NoesisLab/Spartacus-1B-Instruct with libraries, inference providers, notebooks, and local apps.
- Libraries
- Transformers
How to use NoesisLab/Spartacus-1B-Instruct with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="NoesisLab/Spartacus-1B-Instruct", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("NoesisLab/Spartacus-1B-Instruct", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use NoesisLab/Spartacus-1B-Instruct with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "NoesisLab/Spartacus-1B-Instruct"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "NoesisLab/Spartacus-1B-Instruct",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
Use Docker
docker model run hf.co/NoesisLab/Spartacus-1B-Instruct
- SGLang
How to use NoesisLab/Spartacus-1B-Instruct with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "NoesisLab/Spartacus-1B-Instruct" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "NoesisLab/Spartacus-1B-Instruct",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "NoesisLab/Spartacus-1B-Instruct" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "NoesisLab/Spartacus-1B-Instruct",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
- Docker Model Runner
How to use NoesisLab/Spartacus-1B-Instruct with Docker Model Runner:
docker model run hf.co/NoesisLab/Spartacus-1B-Instruct
"""
MonoidForCausalLM: Causal Monoid Language Model (HuggingFace compatible)

Architecture:
    Replace softmax attention with a monoid parallel-scan recurrence.

Core idea:
    Softmax attention computes o_t = Σ_{i≤t} softmax(q_t·k_i) v_i
    → requires an O(T) KV-cache per layer at inference.

    Monoid attention compresses the entire causal history into a
    fixed-size state matrix S_t ∈ ℝ^{d×d} per head:
        S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t    (vector decay recurrence)
        o_t = q_t · S_t                          (state readout)
    where α_t ∈ ℝ^d is a per-dimension vector decay gate.

    This is a monoid because the binary operator
        (α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
    is associative (with identity (1, 0)). Associativity enables a parallel
    prefix scan for training and an O(1) sequential update for inference.

Key properties:
    - Explicit causal modeling: the α_t gate explicitly controls how fast
      past information decays, making causality a first-class part of the
      formulation rather than a constraint imposed by a mask.
    - Monoid state compression: the full causal prefix x_{1:t} is
      lossily compressed into a fixed-size (d×d) state matrix per head.
      No O(T) KV-cache is needed; inference is O(1) per token per layer.
    - Parallel training: associativity of ⊕ enables an O(T) parallel
      prefix scan (vs O(T²) for softmax attention).

Reuses LlamaMLP + LlamaRMSNorm from HuggingFace Transformers.
"""
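Unrolling the recurrence from the docstring gives an attention-like sum, o_t = Σ_{i≤t} q_t · diag(∏_{j=i+1}^{t} α_j) · (k_i ⊗ v_i): linear attention with a per-dimension decay in place of softmax weights. The sketch below checks this equivalence numerically for a single head on random tensors. It assumes S_0 = 0 (the learnable h0 is left out) and uses arbitrary small sizes; it is an illustration, not code from the model.

```python
import torch

torch.manual_seed(0)
T, d = 6, 4  # illustrative sizes: sequence length, head dim
q = torch.randn(T, d, dtype=torch.float64)
k = torch.randn(T, d, dtype=torch.float64)
v = torch.randn(T, d, dtype=torch.float64)
alpha = torch.rand(T, d, dtype=torch.float64)  # per-dimension decay in (0, 1)

# Recurrent form: S_t = diag(alpha_t) S_{t-1} + k_t v_t^T, o_t = q_t S_t
S = torch.zeros(d, d, dtype=torch.float64)
o_rec = []
for t in range(T):
    S = alpha[t].unsqueeze(-1) * S + torch.outer(k[t], v[t])  # row i scaled by alpha_t[i]
    o_rec.append(q[t] @ S)
o_rec = torch.stack(o_rec)

# Unrolled form: o_t = sum_{i<=t} <q_t, decay_{i->t} * k_i> v_i
o_unr = torch.zeros(T, d, dtype=torch.float64)
for t in range(T):
    for i in range(t + 1):
        decay = alpha[i + 1 : t + 1].prod(dim=0)  # empty product (i == t) gives ones
        score = (q[t] * decay * k[i]).sum()       # scalar, unnormalized "attention score"
        o_unr[t] += score * v[i]

print(torch.allclose(o_rec, o_unr))  # True
```

The score between positions t and i is q_t · diag(∏α) · k_i with no softmax normalization, which is why the whole history can be folded into one d×d matrix.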
from __future__ import annotations

from typing import Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin, AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm

try:
    from monoid_scan_cuda import parallel_scan, parallel_scan_with_state
except ImportError:
    # Pure-PyTorch fallback (sequential scan): works on CPU / MPS / any device.
    # Slower than the fused CUDA kernel but mathematically equivalent
    # (results may differ only by floating-point rounding).

    def parallel_scan(alpha: Tensor, kv: Tensor) -> Tensor:
        """Sequential prefix scan fallback: S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]."""
        B, H, T, d1, d2 = kv.shape
        states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
        S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
        for t in range(T):
            decay = alpha[:, :, t]  # [B, H, d]
            while decay.dim() < S.dim():
                decay = decay.unsqueeze(-1)  # → [B, H, d, 1]
            S = S * decay + kv[:, :, t]
            states[:, :, t] = S
        return states

    def parallel_scan_with_state(alpha: Tensor, kv: Tensor):
        """Sequential prefix scan that also returns the final (decay_acc, S) state."""
        B, H, T, d1, d2 = kv.shape
        states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
        S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
        decay_acc = torch.ones(B, H, d1, device=alpha.device, dtype=alpha.dtype)
        for t in range(T):
            decay = alpha[:, :, t]
            while decay.dim() < S.dim():
                decay = decay.unsqueeze(-1)
            S = S * decay + kv[:, :, t]
            states[:, :, t] = S
            decay_acc = decay_acc * alpha[:, :, t]  # running product of all decays
        return states, (decay_acc, S)
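The fallback above is strictly sequential: step t needs step t-1. Associativity of ⊕ is what permits a parallel alternative. The sketch below is a Hillis-Steele inclusive scan over (α, kv) pairs in plain PyTorch, using the same layout as the fallback ([B, H, T, d] decays, [B, H, T, d, d] outer products), checked against a sequential loop. It illustrates the technique only; it is not the `monoid_scan_cuda` kernel, whose implementation is not shown here, and the helper names are hypothetical.

```python
import torch


def combine(a_decay, a_kv, b_decay, b_kv):
    """(α, S) ⊕ (β, X) = (α·β, diag(β)·S + X); `a` is the earlier segment."""
    return a_decay * b_decay, a_kv * b_decay.unsqueeze(-1) + b_kv


def hillis_steele_scan(alpha, kv):
    """Inclusive scan along dim 2 in ceil(log2 T) rounds of combines."""
    decay, state = alpha.clone(), kv.clone()
    T = alpha.shape[2]
    offset = 1
    while offset < T:
        # Position t absorbs the segment ending at t - offset (earlier one on the left).
        new_decay, new_state = combine(
            decay[:, :, :-offset], state[:, :, :-offset],
            decay[:, :, offset:], state[:, :, offset:],
        )
        decay = torch.cat([decay[:, :, :offset], new_decay], dim=2)
        state = torch.cat([state[:, :, :offset], new_state], dim=2)
        offset *= 2
    return state


def sequential_scan(alpha, kv):
    """Reference: S_t = diag(α_t)·S_{t-1} + kv_t, starting from S = 0."""
    S = torch.zeros_like(kv[:, :, 0])
    out = []
    for t in range(kv.shape[2]):
        S = S * alpha[:, :, t].unsqueeze(-1) + kv[:, :, t]
        out.append(S)
    return torch.stack(out, dim=2)


torch.manual_seed(0)
B, H, T, d = 2, 3, 11, 4
alpha = torch.rand(B, H, T, d, dtype=torch.float64)
k = torch.randn(B, H, T, d, dtype=torch.float64)
v = torch.randn(B, H, T, d, dtype=torch.float64)
kv = torch.einsum("bhtd,bhte->bhtde", k, v)

print(torch.allclose(hillis_steele_scan(alpha, kv), sequential_scan(alpha, kv)))  # True
```

Hillis-Steele performs O(T log T) combines in ⌈log₂ T⌉ rounds; a work-efficient Blelloch (up-sweep/down-sweep) scan keeps the total work at O(T), the figure the module docstring refers to, at the cost of a more involved schedule.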
# ─────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────
class MonoidConfig(PretrainedConfig):
    """
    Configuration for the Monoid causal language model.

    Mirrors LlamaConfig for the shared components (MLP, RMSNorm, embedding)
    so that weights can be transferred directly from Llama checkpoints.
    """

    model_type = "monoid"

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 576,
        intermediate_size: int = 1536,
        num_hidden_layers: int = 30,
        num_attention_heads: int = 9,
        head_dim: int = 64,
        max_position_embeddings: int = 2048,
        rms_norm_eps: float = 1e-5,
        hidden_act: str = "silu",
        mlp_bias: bool = False,
        attention_bias: bool = False,
        tie_word_embeddings: bool = True,
        initializer_range: float = 0.041666666666666664,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.hidden_act = hidden_act
        self.mlp_bias = mlp_bias
        self.attention_bias = attention_bias
        self.initializer_range = initializer_range
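To put numbers on the O(1)-vs-O(T) memory claim, the arithmetic below compares one layer's monoid state with a softmax-attention KV-cache of the same shape, using the MonoidConfig defaults above (9 heads, head_dim 64). It assumes one KV head per query head (no GQA), counts tensor elements per batch element rather than bytes, and the helper names are hypothetical; the released 1B checkpoint may use a different config.

```python
def monoid_state_elems(num_heads: int, head_dim: int) -> int:
    # (decay_acc, S): decay_acc is [H, d], S is [H, d, d]; independent of sequence length.
    return num_heads * head_dim + num_heads * head_dim * head_dim


def kv_cache_elems(num_heads: int, head_dim: int, seq_len: int) -> int:
    # Keys and values: 2 x [H, T, d]; grows linearly with T.
    return 2 * num_heads * head_dim * seq_len


H, d = 9, 64  # MonoidConfig defaults
state = monoid_state_elems(H, d)
print(state)  # 37440

for T in (16, 32, 33, 2048):
    kv = kv_cache_elems(H, d, T)
    print(T, kv, kv > state)
```

With these defaults the fixed state is 37,440 elements per layer; a KV-cache of the same shape overtakes it at 33 tokens and is about 63× larger at the 2,048-token `max_position_embeddings`. The trade-off is that the state is a lossy summary, while the KV-cache keeps every token exactly.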
# ─────────────────────────────────────────────
# Monoid Cache: O(1) state replaces O(T) KV-Cache
# ─────────────────────────────────────────────
class MonoidCache:
    """
    Per-layer monoid state cache for autoregressive inference.

    Unlike a Transformer KV-Cache, which stores all past keys and values
    (O(T) memory), each layer here stores exactly ONE state tuple:
        (decay_acc, S)  where S ∈ ℝ^{B×H×d×d}
    This is the monoid "sum" of all past (α_i, k_i⊗v_i) under ⊕.
    Memory is O(1) per layer regardless of sequence length.
    """

    def __init__(self):
        self.states: list[tuple[Tensor, Tensor] | None] = []
        self.seen_tokens: int = 0

    def get_seq_length(self, layer_idx: int = 0) -> int:
        return self.seen_tokens

    def update(self, layer_idx: int, state: tuple[Tensor, Tensor]):
        """Store the accumulated monoid state for a given layer."""
        while len(self.states) <= layer_idx:
            self.states.append(None)
        self.states[layer_idx] = state

    def get_state(self, layer_idx: int) -> tuple[Tensor, Tensor] | None:
        """Retrieve the accumulated monoid state for a given layer."""
        if layer_idx < len(self.states):
            return self.states[layer_idx]
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cache along the batch dimension for beam search."""
        for i, state in enumerate(self.states):
            if state is not None:
                decay_acc, S = state
                self.states[i] = (decay_acc[beam_idx], S[beam_idx])
# ─────────────────────────────────────────────
# Monoid Operator: the algebraic heart
# ─────────────────────────────────────────────
def monoid_op(
    a: tuple[Tensor, Tensor],
    b: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
    """
    The monoid binary operator ⊕ on (vector decay, state matrix) pairs.

    Definition:
        (α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
    where α, β ∈ (0,1)^d are per-dimension vector decay gates (sigmoid output).

    Why this is a monoid:
        • Associativity:
            (a ⊕ b) ⊕ c = a ⊕ (b ⊕ c)  ✓
          This enables a parallel prefix scan for training (reduction tree)
          and an O(1) left-fold for inference (sequential append).
        • Identity:
            e = (1, 0)  →  e ⊕ a = a ⊕ e = a  ✓

    Causal semantics:
        S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t
    The decay α_t ∈ (0,1)^d explicitly controls how much of the past
    the model retains. This is *explicit causal modeling*: the model
    must learn to balance retaining old information against absorbing
    new information at every timestep.
    """
    decay_a, kv_a = a
    decay_b, kv_b = b
    new_decay = decay_a * decay_b  # α·β (element-wise product)
    while decay_b.dim() < kv_a.dim():
        decay_b = decay_b.unsqueeze(-1)  # [B, H, d] → [B, H, d, 1]: scales row i of S by β[i]
    return new_decay, kv_a * decay_b + kv_b  # diag(β)·S + X
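The monoid laws stated in the docstring can be checked numerically. The sketch below re-declares the operator standalone (same broadcasting as `monoid_op` for [B, H, d] decays and [B, H, d, d] states) and verifies associativity and the identity e = (1, 0) on random inputs. It also shows the operator is not commutative, which is why a scan must always keep the earlier prefix on the left. Sizes and the seed are arbitrary.

```python
import torch


def op(a, b):
    """(α, S) ⊕ (β, X) = (α·β, diag(β)·S + X); α, β: [B,H,d], S, X: [B,H,d,d]."""
    (da, Sa), (db, Sb) = a, b
    return da * db, Sa * db.unsqueeze(-1) + Sb


torch.manual_seed(0)
B, H, d = 2, 3, 4


def rand_elem():
    return torch.rand(B, H, d, dtype=torch.float64), torch.randn(B, H, d, d, dtype=torch.float64)


a, b, c = rand_elem(), rand_elem(), rand_elem()

# Associativity: (a ⊕ b) ⊕ c == a ⊕ (b ⊕ c)
assoc_ok = all(torch.allclose(x, y) for x, y in zip(op(op(a, b), c), op(a, op(b, c))))

# Identity: e = (1, 0) on both sides
e = (torch.ones(B, H, d, dtype=torch.float64), torch.zeros(B, H, d, d, dtype=torch.float64))
ident_ok = all(torch.allclose(x, y) for pair in (op(e, a), op(a, e)) for x, y in zip(pair, a))

# Order matters: a ⊕ b generally differs from b ⊕ a in the state component
commutes = all(torch.allclose(x, y) for x, y in zip(op(a, b), op(b, a)))

print(assoc_ok, ident_ok, commutes)  # True True False
```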
# ─────────────────────────────────────────────
# Monoid Attention: the core innovation
# ─────────────────────────────────────────────
class MonoidAttention(nn.Module):
    """
    Monoid Causal Attention: replaces softmax attention entirely.

    Key differences from standard attention:
        - No RoPE / positional encoding: position is implicitly encoded
          by the causal decay gate α_t. The model learns *when* to forget
          rather than encoding *where* tokens are.
        - No KV-Cache: replaced by MonoidCache with O(1) state per layer.
          Each state S ∈ ℝ^{H×d×d} is a compressed summary of ALL past tokens.
        - No causal attention mask: causality is built into the recurrence
          itself, since S_t depends only on S_{t-1} and the current token.
          (attention_mask is still used, but only to neutralize PAD tokens.)

    Computation:
        Training (parallel scan, O(T)):
            k_t = SiLU(RMSNorm(k_proj(x_t)))       # approximately non-negative keys
            S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t  # monoid recurrence via prefix scan
            o_t = q_t · S_t                        # linear readout from state
        Inference (RNN mode, O(1) per token):
            Same recurrence, but applied one token at a time.
    """

    def __init__(self, config: MonoidConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.scaling = self.head_dim ** -0.5  # 1/√d, scale factor for the q·S readout

        # --- Projections (transferred from Llama) ---
        # q_proj, o_proj: same dims as Llama, copied directly.
        # k_proj, v_proj: Llama GQA has fewer KV heads; they are tiled to the full head count.
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)

        # --- Output gate (new component, randomly initialized) ---
        # Modulates the multi-head readout before o_proj, similar to GLA/RetNet.
        #   gate = SiLU(gate_proj(x)), output = gate ⊙ concat_heads(o)
        # This lets the model suppress or amplify specific head outputs
        # conditioned on the current input, increasing expressiveness.
        self.gate_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # --- Decay gate (new component, randomly initialized) ---
        # Projects hidden_size → num_heads * head_dim, yielding a VECTOR per head.
        # Activation: α = sigmoid(Wx + b), giving α ∈ (0, 1) (see forward()).
        # Vector decay: S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t
        # Different feature dimensions can have independent lifetimes:
        #   - fast-decaying dims for local syntax
        #   - slow-decaying dims for global entity/fact memory
        self.decay_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True)

        # --- QK-Norm + output norm (new components, RMSNorm weights start at 1) ---
        # Stabilizes the scale of the q·S readout. Without it, the state
        # matrix S (a sum of outer products) can grow very large.
        self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.o_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # --- Learnable initial state h0 (new component, zero-initialized) ---
        # S_0 = h0 ∈ ℝ^{1×H×d×d}, shared across the batch.
        # Zero init means the model starts with "no memory": a clean slate.
        # The model can learn a non-zero h0 as a kind of "system prompt" state.
        self.h0 = nn.Parameter(torch.zeros(1, self.num_heads, self.head_dim, self.head_dim))
    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor | None = None,
        monoid_cache: MonoidCache | None = None,
        use_cache: bool = False,
    ) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
        """
        Args:
            hidden_states: [B, T, hidden_size]
            attention_mask: [B, T] with 1 = real token, 0 = pad.
                For PAD positions: α = 1 (preserve state), kv = 0 (no contribution).
            monoid_cache: O(1) state cache for inference.
            use_cache: whether to use/update the cache.

        Returns:
            output: [B, T, hidden_size]
            final_state: (decay_acc, S), or None on the training path
        """
        B, T, _ = hidden_states.shape
        H, d = self.num_heads, self.head_dim

        # --- Project to multi-head Q, K, V ---
        q = self.q_proj(hidden_states).view(B, T, H, d).transpose(1, 2)  # [B,H,T,d]
        k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)

        # --- Output gate: computed from the input, applied before o_proj ---
        gate = torch.nn.functional.silu(self.gate_proj(hidden_states))  # [B,T,H*d]

        # --- QK-Norm: stabilize the q·S readout scale ---
        q = self.q_norm(q) * self.scaling
        k = self.k_norm(k)

        # --- Approximately non-negative keys via SiLU ---
        # Why: the state S_t = Σ_i diag(∏_{j=i+1}^{t} α_j) k_i⊗v_i is a decayed
        # sum of outer products. Non-negative k means every token writes into
        # the rows of S with non-negative weights, which is intended to prevent
        # "feature erasure", where one token's contribution cancels another's,
        # so information accumulates rather than cancels along key dimensions.
        # (SiLU's minimum is about -0.28, so keys are only approximately non-negative.)
        k = torch.nn.functional.silu(k)

        # --- Per-dimension vector decay gate α_t ---
        # Sigmoid: α = σ(Wx + b), so α ∈ (0, 1).
        #   Wx + b → -∞: σ → 0 (complete forgetting)
        #   Wx + b → +∞: σ → 1 (perfect memory, no forgetting)
        # Each dimension of the d-vector decays independently:
        #   S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
        raw = self.decay_proj(hidden_states)  # [B,T,H*d]
        alpha = torch.sigmoid(raw)  # [B,T,H*d]
        alpha = alpha.view(B, T, H, d).transpose(1, 2)  # [B,H,T,d]

        # --- Apply attention_mask: PAD tokens must be invisible to the recurrence ---
        # For PAD positions (mask = 0): set α = 1 (preserve state) and k = v = 0 (so kv = 0).
        # Then S_t = 1·S_{t-1} + 0 = S_{t-1}, i.e. PAD is a no-op on the state.
        if attention_mask is not None:
            # attention_mask: [B, T] → [B, 1, T, 1] for broadcasting with [B, H, T, d]
            mask = attention_mask[:, None, :, None].to(alpha.dtype)  # [B,1,T,1]
            alpha = alpha * mask + (1 - mask)  # PAD → α = 1 (preserve state)
            k = k * mask  # PAD → k = 0
            v = v * mask  # PAD → v = 0, hence kv = 0
        # ──────────────────────────────────────────────────────────
        # Inference path (RNN mode): O(1) per token per layer
        # ──────────────────────────────────────────────────────────
        # When generating, T = 1. The monoid operator is applied once
        # to fold the new token into the accumulated state.
        # This is where "O(1) inference" materializes:
        #   S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t   (one monoid_op call)
        #   o_t = q_t · S_t                          (one matmul)
        # Total: O(H·d²) per layer, independent of sequence length.
        if use_cache and T == 1:
            # Outer product: k_t ⊗ v_t ∈ ℝ^{B×H×d×d}
            kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
            alpha_t = alpha[:, :, 0]  # [B,H,d]
            prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
            if prev is None:
                # First token: initialize from the learnable h0
                decay_t = alpha_t
                while decay_t.dim() < self.h0.dim():
                    decay_t = decay_t.unsqueeze(-1)
                new_state = (alpha_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t)
            else:
                # Subsequent tokens: fold via monoid_op, O(1)
                new_state = monoid_op(prev, (alpha_t, kv_t))
            if monoid_cache is not None:
                monoid_cache.update(self.layer_idx, new_state)
            # Readout: o_t = q_t · S_t
            o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
            o = self.o_norm(o)
            # Reshape [B,H,d] → [B,1,H*d] (heads contiguous, matching the scan path)
            o = o.contiguous().view(B, 1, -1)
            return self.o_proj(gate * o), new_state
        # ──────────────────────────────────────────────────────────
        # Inference prefill (use_cache=True, T>1): parallel scan + readout
        # ──────────────────────────────────────────────────────────
        # Uses the same parallel_scan_with_state as training so that the
        # Triton CUDA kernel (when installed) replaces an O(T) Python loop.
        # Memory: O(B·H·T·d²), the same as the training path.
        if use_cache:
            kv = torch.einsum('bhtd, bhte -> bhtde', k, v)  # [B,H,T,d,d]
            states, (decay_acc, S_T) = parallel_scan_with_state(alpha, kv)

            # Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
            cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2))  # [B,H,T,d]
            h0_decay = cum_alpha.unsqueeze(-1)  # [B,H,T,d,1]
            states = states + h0_decay * self.h0.unsqueeze(2)  # broadcast h0 as [1,H,1,d,d]

            # Final state includes the h0 contribution
            total_h0_decay = decay_acc.unsqueeze(-1)  # [B,H,d,1]
            # h0 is [1,H,d,d]; squeeze(0) gives [H,d,d], which broadcasts against [B,H,d,1]
            S_final = S_T + total_h0_decay * self.h0.squeeze(0)  # [B,H,d,d]
            final_state = (decay_acc, S_final)
            if monoid_cache is not None:
                monoid_cache.update(self.layer_idx, final_state)

            # Vectorized readout: o_t = q_t · S_t for all t at once
            o = torch.einsum('bhtd, bhtde -> bhte', q, states)  # [B,H,T,d]
            o = self.o_norm(o)
            o = o.transpose(1, 2).contiguous().view(B, T, -1)
            return self.o_proj(gate * o), final_state
        # ──────────────────────────────────────────────────────────
        # Training path: parallel scan + vectorized readout
        # ──────────────────────────────────────────────────────────
        # Materialize the full kv tensor [B,H,T,d,d] and scan it in one pass.
        # Memory: O(B·H·T·d²), trading memory for speed.
        # This avoids T × num_hidden_layers (T×30 with the default config)
        # Python-loop kernel launches for the outer product and readout;
        # the scan itself is parallel when the CUDA kernel is available.

        # Vectorized outer product: kv_t = k_t ⊗ v_t for all t at once
        kv = torch.einsum('bhtd, bhte -> bhtde', k, v)  # [B,H,T,d,d]

        # Parallel prefix scan: S_t = diag(α_t)·S_{t-1} + kv_t (starting from S = 0)
        # alpha is [B,H,T,d]: vector decay per dimension.
        states = parallel_scan(alpha, kv)  # [B,H,T,d,d]

        # Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
        cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2))  # [B,H,T,d]
        h0_decay = cum_alpha.unsqueeze(-1)  # [B,H,T,d,1]
        states = states + h0_decay * self.h0.unsqueeze(2)  # broadcast h0 as [1,H,1,d,d]

        # Vectorized readout: o_t = q_t · S_t for all t at once
        o = torch.einsum('bhtd, bhtde -> bhte', q, states)  # [B,H,T,d]
        o = self.o_norm(o)
        o = o.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(gate * o), None
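The forward pass reaches the same states two ways: the T == 1 decode path (start from h0, then fold one token at a time with `monoid_op`) and the scan path (scan from S = 0, then add diag(∏α)·h0 via the exp-cumsum-log product). The standalone sketch below checks that both give the same final state on one random sequence, and that a left-padded copy (α = 1, k = v = 0 at PAD positions, as the attention_mask branch sets them) ends in the same state too. It re-implements only the recurrence on random tensors; projections, norms and gates of `MonoidAttention` are omitted, and the helper names are hypothetical.

```python
import torch

torch.manual_seed(0)
B, H, T, d = 1, 2, 5, 4
dt = torch.float64
alpha = torch.rand(B, H, T, d, dtype=dt) * 0.5 + 0.5  # decays in [0.5, 1)
k = torch.randn(B, H, T, d, dtype=dt)
v = torch.randn(B, H, T, d, dtype=dt)
h0 = torch.randn(1, H, d, d, dtype=dt)  # stands in for the learnable initial state


def op(a, b):
    (da, Sa), (db, Sb) = a, b
    return da * db, Sa * db.unsqueeze(-1) + Sb


def decode_path(alpha, k, v):
    """T == 1 path: first token starts from h0, later tokens fold via op."""
    state = None
    for t in range(alpha.shape[2]):
        kv_t = torch.einsum("bhd,bhe->bhde", k[:, :, t], v[:, :, t])
        a_t = alpha[:, :, t]
        if state is None:
            state = (a_t, h0.expand(alpha.shape[0], -1, -1, -1) * a_t.unsqueeze(-1) + kv_t)
        else:
            state = op(state, (a_t, kv_t))
    return state[1]


def scan_path(alpha, k, v):
    """Prefill/training path: scan from S = 0, then add the decayed h0 term."""
    kv = torch.einsum("bhtd,bhte->bhtde", k, v)
    S = torch.zeros_like(kv[:, :, 0])
    states = []
    for t in range(kv.shape[2]):
        S = S * alpha[:, :, t].unsqueeze(-1) + kv[:, :, t]
        states.append(S)
    states = torch.stack(states, dim=2)
    cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2))
    states = states + cum_alpha.unsqueeze(-1) * h0.unsqueeze(2)
    return states[:, :, -1]


same_path = torch.allclose(decode_path(alpha, k, v), scan_path(alpha, k, v))

# Left-pad with 3 PAD steps and apply the same masking rule as forward().
pad = 3
mask = torch.cat([torch.zeros(B, pad, dtype=dt), torch.ones(B, T, dtype=dt)], dim=1)[:, None, :, None]
alpha_p = torch.cat([torch.rand(B, H, pad, d, dtype=dt), alpha], dim=2) * mask + (1 - mask)
k_p = torch.cat([torch.randn(B, H, pad, d, dtype=dt), k], dim=2) * mask
v_p = torch.cat([torch.randn(B, H, pad, d, dtype=dt), v], dim=2) * mask
same_padded = torch.allclose(scan_path(alpha_p, k_p, v_p), scan_path(alpha, k, v))

print(same_path, same_padded)  # True True
```

The padded positions are filled with random values before masking, so the check confirms that the mask, not the pad content, makes PAD a no-op.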
# ─────────────────────────────────────────────
# Decoder Layer: MonoidAttn + LlamaMLP + LlamaRMSNorm
# ─────────────────────────────────────────────
class MonoidDecoderLayer(nn.Module):
    """
    Pre-Norm Transformer block with Monoid attention.

    Data flow:
        x → RMSNorm → MonoidAttn → +residual → RMSNorm → LlamaMLP → +residual → out

    The MLP and RMSNorm are identical to Llama (weights transferred directly).
    Only MonoidAttention is a new component.
    """

    gradient_checkpointing = False

    def __init__(self, config: MonoidConfig, layer_idx: int):
        super().__init__()
        self.self_attn = MonoidAttention(config, layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor | None = None,
        monoid_cache: MonoidCache | None = None,
        use_cache: bool = False,
    ) -> Tensor:
        # --- Attention block with residual ---
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache)
        hidden_states = residual + hidden_states

        # --- FFN block with residual ---
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
# ─────────────────────────────────────────────
# MonoidModel (backbone)
# ─────────────────────────────────────────────
class MonoidPreTrainedModel(PreTrainedModel):
    config_class = MonoidConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MonoidDecoderLayer"]

    def _init_weights(self, module: nn.Module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        if isinstance(module, MonoidAttention):
            # decay_proj: bias init so sigmoid(bias) ≈ 0.95, i.e. mostly remembering at start
            nn.init.constant_(module.decay_proj.bias, 3.0)
            # gate_proj: small init keeps gate_proj(x) near 0 at the start; since
            # SiLU(0) = 0, the gate begins with small values rather than at identity
            nn.init.normal_(module.gate_proj.weight, mean=0.0, std=0.01)
            # o_norm: RMSNorm weight defaults to 1.0 (identity); set explicitly for clarity
            nn.init.ones_(module.o_norm.weight)
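The `decay_proj` bias of 3.0 sets the initial decay to σ(3) ≈ 0.953, assuming the weight term Wx is small at initialization. One way to read this is as a memory half-life: the number of steps n after which a contribution has been scaled by 0.5, n = ln 0.5 / ln α. The arithmetic below computes it for the chosen bias and two other values; the alternative biases (0.0 and 5.0) are illustrative only and do not come from the model.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def half_life(alpha: float) -> float:
    # Steps n such that alpha ** n == 0.5, for a constant per-step decay alpha.
    return math.log(0.5) / math.log(alpha)


for bias in (0.0, 3.0, 5.0):
    a = sigmoid(bias)
    print(f"bias={bias:.1f}  alpha={a:.4f}  half-life={half_life(a):.1f} tokens")
```

With bias 3.0, a token's contribution halves after roughly 14 steps at initialization (versus 1 step for bias 0 and about 103 steps for bias 5), so the layer starts out favoring fairly local memory that training can then stretch or shrink per dimension.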


class MonoidModel(MonoidPreTrainedModel):
    """
    Stack of MonoidDecoderLayers with token embedding and final norm.

    Forward: embed_tokens → N × MonoidDecoderLayer → norm (final RMSNorm)
    """

    def __init__(self, config: MonoidConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MonoidDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        monoid_cache: MonoidCache | None = None,
        use_cache: bool = False,
    ) -> BaseModelOutputWithPast:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
        for layer in self.layers:
            if self.gradient_checkpointing and self.training and not use_cache:
                hidden_states = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    monoid_cache,
                    use_cache,
                )
            else:
                hidden_states = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    monoid_cache=monoid_cache,
                    use_cache=use_cache,
                )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=monoid_cache,
        )
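
# Illustrative sketch (NOT the model's actual MonoidAttention): why a fixed-size MonoidCache is
# enough for O(1) per-token decoding. We ASSUME a linear-attention style update
#     S_t = a_t * S_{t-1} + k_t v_t^T,    y_t = q_t^T S_t
# with a scalar retention a_t in (0, 1). The real layer, defined earlier in this file, may use
# per-dimension ("vector") decay and extra gating. The helper name below is hypothetical.
import torch


def _toy_scan(q, k, v, a, state=None):
    """Run the toy recurrence over T steps; return (outputs [T, d_v], final state [d_k, d_v])."""
    S = torch.zeros(k.shape[-1], v.shape[-1]) if state is None else state
    ys = []
    for t in range(q.shape[0]):
        S = a[t] * S + torch.outer(k[t], v[t])  # state stays d_k x d_v no matter how large t gets
        ys.append(q[t] @ S)                     # read out with the current query
    return torch.stack(ys), S

# Prefilling a prompt and then decoding from the carried state reproduces the full pass:
#   y_full, _ = _toy_scan(q, k, v, a)
#   y_pre, S = _toy_scan(q[:6], k[:6], v[:6], a[:6])
#   y_dec, _ = _toy_scan(q[6:], k[6:], v[6:], a[6:], state=S)
#   torch.allclose(torch.cat([y_pre, y_dec]), y_full)  # True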


# ─────────────────────────────────────────────
# MonoidForCausalLM: the full causal LM
# ─────────────────────────────────────────────
class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
    """
    Monoid-based causal language model with an LM head.

    The architecture in one sentence:
    "Llama body + Monoid mind": reuse Llama's proven MLP and embedding layers,
    and replace attention with monoid state compression for O(1) per-token inference.
    """

    _tied_weights_keys = ["lm_head.weight"]

    # Tell HuggingFace GenerationMixin NOT to create a DynamicCache.
    # Monoid uses its own O(1) MonoidCache, not a KV-cache.
    _is_stateful = True

    def __init__(self, config: MonoidConfig):
        super().__init__(config)
        self.model = MonoidModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: Tensor,
        past_key_values=None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        **kwargs,
    ) -> dict:
        """
        Called by GenerationMixin at each decoding step.

        HuggingFace may pass a DynamicCache. Because this model does not use a
        standard KV-cache, we discard it here so that forward() creates a
        MonoidCache instead.
        """
        # Intercept non-MonoidCache objects (e.g. a DynamicCache created by GenerationMixin)
        if past_key_values is not None and not isinstance(past_key_values, MonoidCache):
            past_key_values = None
        if past_key_values is not None and past_key_values.seen_tokens > 0:
            # Cache exists: only feed the latest token (O(1) inference)
            input_ids = input_ids[:, -1:]
            # Decode step: a single real token with no PAD, so no mask is needed
            attention_mask = None
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "monoid_cache": past_key_values,
            "use_cache": True,
        }
        return model_inputs

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,  # [B, T] 1=real, 0=pad; used to mask PAD out of the recurrence
        position_ids: Tensor | None = None,  # kept for API compat; monoid ignores this
        past_key_values: MonoidCache | None = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        use_cache: bool | None = None,
        monoid_cache: MonoidCache | None = None,
        output_attentions: bool | None = None,  # kept for API compat
        output_hidden_states: bool | None = None,  # kept for API compat
        logits_to_keep: int | Tensor = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        # monoid_cache takes priority; fall back to past_key_values for GenerationMixin compat.
        # Explicit None checks ensure a cache object is never skipped just because it is falsy.
        cache = monoid_cache if monoid_cache is not None else past_key_values
        # Discard any non-MonoidCache (e.g. a DynamicCache injected by GenerationMixin)
        if cache is not None and not isinstance(cache, MonoidCache):
            cache = None
        if use_cache and cache is None:
            cache = MonoidCache()
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            monoid_cache=cache,
            use_cache=bool(use_cache),
        )
        hidden_states = outputs.last_hidden_state
        # Optionally compute logits only for the last K tokens (saves memory)
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int) and logits_to_keep > 0
            else slice(None)
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Standard causal LM loss: cross-entropy with a one-token shift
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        if cache is not None:
            cache.seen_tokens += (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1])
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=cache,
        )
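
# Illustrative sketch of the label shift used in MonoidForCausalLM.forward: the logits at
# position t are scored against the token at t + 1, and labels equal to -100 are skipped.
# The helper name is hypothetical; it mirrors the loss computation above on plain tensors.
import torch
import torch.nn.functional as F


def _shifted_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """logits: [B, T, V], labels: [B, T]; returns the mean next-token cross-entropy."""
    vocab = logits.shape[-1]
    shift_logits = logits[..., :-1, :].reshape(-1, vocab)  # predictions made at steps 0..T-2
    shift_labels = labels[..., 1:].reshape(-1)              # their targets: tokens 1..T-1
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

# Uniform logits over V classes give a loss of ln(V) on every unmasked target,
# e.g. ln(5) ≈ 1.609 for V = 5, no matter which positions carry -100.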


# ─────────────────────────────────────────────
# AutoModel Registration
# ─────────────────────────────────────────────
AutoConfig.register("monoid", MonoidConfig)
AutoModelForCausalLM.register(MonoidConfig, MonoidForCausalLM)


# ─────────────────────────────────────────────
# Smoke Tests
# ─────────────────────────────────────────────
if __name__ == '__main__':
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f'Device: {device}')
    config = MonoidConfig(
        vocab_size=49152,
        hidden_size=576,
        intermediate_size=1536,
        num_hidden_layers=30,
        num_attention_heads=9,
        head_dim=64,
        rms_norm_eps=1e-5,
        hidden_act="silu",
        tie_word_embeddings=True,
    )
    model = MonoidForCausalLM(config).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Parameters: {n_params:,}')

    # -- Training smoke test --
    B, T = 2, 64
    ids = torch.randint(0, config.vocab_size, (B, T), device=device)
    out = model(ids, labels=ids)
    print(f'Train → logits: {out.logits.shape}, loss: {out.loss.item():.4f}')

    # -- Inference smoke test (manual RNN loop) --
    prompt = torch.randint(0, config.vocab_size, (1, 8), device=device)
    cache = MonoidCache()
    # Prefill
    prefill_out = model(prompt, use_cache=True, monoid_cache=cache)
    print(f'Prefill → logits: {prefill_out.logits.shape}, cache seen: {cache.seen_tokens}')
    # Decode 1 token
    next_tok = prefill_out.logits[:, -1:].argmax(dim=-1)
    step_out = model(next_tok, use_cache=True, monoid_cache=cache)
    print(f'Decode → logits: {step_out.logits.shape}, cache seen: {cache.seen_tokens}')

    # -- Monoid associativity check --
    print('\nMonoid associativity check:')
    a = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    b = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    c = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    ab_c = monoid_op(monoid_op(a, b), c)
    a_bc = monoid_op(a, monoid_op(b, c))
    err = (ab_c[1] - a_bc[1]).abs().max().item()
    print(f'  |(a⊕b)⊕c - a⊕(b⊕c)| = {err:.2e}')
    print('\nDone.')
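
# Illustrative sketch of what the associativity check above verifies, and why it matters.
# We ASSUME monoid elements are (log_decay, state) pairs combined as
#     (la1, S1) ⊕ (la2, S2) = (la1 + la2, exp(la2) * S1 + S2),
# which matches a toy recurrence S_t = a_t * S_{t-1} + u_t with a_t = exp(la_t). The real
# monoid_op defined earlier in this file may differ in detail; helper names are hypothetical.
# Associativity lets a prefix scan be evaluated in any bracketing (e.g. a parallel tree
# reduction during training) and still match the left-to-right RNN fold used for decoding.
import functools

import torch


def _toy_monoid_op(x, y):
    la1, S1 = x
    la2, S2 = y
    # la: [H] per-head log-decay, S: [H, d, d] state; broadcast the decay over the state
    return la1 + la2, torch.exp(la2)[..., None, None] * S1 + S2


def _tree_reduce(elems):
    """Combine neighbours pairwise, level by level (a different bracketing from a left fold)."""
    while len(elems) > 1:
        paired = [_toy_monoid_op(elems[i], elems[i + 1]) for i in range(0, len(elems) - 1, 2)]
        if len(elems) % 2:
            paired.append(elems[-1])  # odd element carries over unchanged, order preserved
        elems = paired
    return elems[0]

# Usage: both bracketings agree up to floating-point rounding.
#   elems = [(-torch.rand(2, dtype=torch.float64), torch.randn(2, 4, 4, dtype=torch.float64))
#            for _ in range(7)]
#   seq = functools.reduce(_toy_monoid_op, elems)
#   par = _tree_reduce(elems)
#   torch.allclose(seq[1], par[1])  # True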