from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class MythosConfig:
    """
    Hyperparameter configuration for OpenMythos.

    Core:
        vocab_size        -- token vocabulary size
        dim               -- model hidden dimension
        n_heads           -- number of query attention heads
        n_kv_heads        -- number of key/value heads (GQA; ignored by MLA)
        max_seq_len       -- maximum sequence length for RoPE precomputation
        max_loop_iters    -- default recurrent loop depth T at inference
        prelude_layers    -- number of standard transformer layers before the loop
        coda_layers       -- number of standard transformer layers after the loop

    Attention (attn_type selects between the two):
        attn_type         -- "gqa" for Grouped Query Attention, "mla" for Multi-Latent Attention
        kv_lora_rank      -- [MLA] compressed KV latent dimension stored in the cache
        q_lora_rank       -- [MLA] compressed Q latent dimension
        qk_rope_head_dim  -- [MLA] per-head dims that receive RoPE
        qk_nope_head_dim  -- [MLA] per-head dims without positional encoding
        v_head_dim        -- [MLA] per-head value dimension

    MoE FFN (used inside the recurrent block):
        n_experts         -- total number of routed expert FFNs
        n_shared_experts  -- number of always-active shared experts
        n_experts_per_tok -- top-K experts selected per token by the router
        expert_dim        -- hidden dimension inside each fine-grained expert

    Other:
        act_threshold     -- ACT halting threshold (cumulative probability to stop looping)
        rope_theta        -- RoPE base frequency
        lora_rank         -- rank of the per-loop depth-wise LoRA adapter
    """

    vocab_size: int = 32000
    dim: int = 2048
    n_heads: int = 16
    n_kv_heads: int = 4  # GQA: fewer KV heads than Q heads
    max_seq_len: int = 4096
    max_loop_iters: int = 16  # T: recurrent depth at inference
    prelude_layers: int = 2
    coda_layers: int = 2
    # Attention type: "gqa" | "mla"
    attn_type: str = "mla"
    # MLA params (only used when attn_type="mla")
    kv_lora_rank: int = 512  # compressed KV latent cached instead of full K/V
    q_lora_rank: int = 1536  # compressed Q latent dim
    qk_rope_head_dim: int = 64  # per-head dims that receive RoPE
    qk_nope_head_dim: int = 128  # per-head dims without RoPE
    v_head_dim: int = 128  # per-head value dim
    # MoE
    n_experts: int = 64
    n_shared_experts: int = 2
    n_experts_per_tok: int = 4  # top-K routed
    expert_dim: int = 512  # hidden dim inside each fine-grained expert
    # ACT halting
    act_threshold: float = 0.99
    # RoPE
    rope_theta: float = 500000.0
    # LoRA depth adaptation
    lora_rank: int = 16
    # Maximum tokens to generate per forward pass
    max_output_tokens: int = 4096
    # Dropout (set 0.0 to disable; 0.1 is standard for pretraining)
    dropout: float = 0.0

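
# Illustrative sketch (not part of the original file): a debug-sized config.
# All values below are arbitrary assumptions chosen so the model runs quickly
# on CPU; any combination where n_heads divides dim and n_kv_heads divides
# n_heads works the same way.
def _demo_config() -> MythosConfig:
    return MythosConfig(
        vocab_size=256,
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        max_seq_len=128,
        max_loop_iters=4,
        prelude_layers=1,
        coda_layers=1,
        n_experts=4,
        n_experts_per_tok=2,
        expert_dim=32,
        kv_lora_rank=16,
        q_lora_rank=32,
        qk_rope_head_dim=8,
        qk_nope_head_dim=8,
        v_head_dim=16,
        lora_rank=4,
    )
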
# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).

    Normalizes by the RMS of the input rather than mean and variance, with a
    learned per-channel rescaling weight. No bias term. Used in place of
    LayerNorm throughout the model for stability and efficiency.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim -- feature dimension to normalize over
            eps -- small constant added before sqrt for numerical stability
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x -- input tensor of shape (..., dim)
        Returns:
            RMS-normalized tensor of the same shape, rescaled by self.weight
        """
        # rsqrt yields the reciprocal RMS, so one multiply normalizes x
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

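
# Illustrative sketch (not part of the original file): with the default unit
# weight, RMSNorm rescales every vector to root-mean-square ~1 regardless of
# the input's scale.
def _demo_rmsnorm() -> None:
    norm = RMSNorm(dim=8)
    x = torch.randn(2, 3, 8) * 100.0  # wildly scaled input
    y = norm(x)
    print(y.pow(2).mean(-1).sqrt())  # all entries close to 1.0
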
# ---------------------------------------------------------------------------
# RoPE
# ---------------------------------------------------------------------------


def precompute_rope_freqs(
    dim: int, max_len: int, theta: float = 500000.0
) -> torch.Tensor:
    """
    Precompute complex-valued RoPE rotation matrices for positions 0..max_len-1.

    Each position gets a complex phasor e^{i·m·θ_k} for each frequency pair k.
    Stored as a complex tensor so that rotation is a single pointwise multiply.

    Args:
        dim     -- head dimension (must be even); frequencies are computed for dim//2 pairs
        max_len -- maximum sequence length to precompute
        theta   -- RoPE base (higher = slower frequency decay; 500k is the LLaMA-3 default)

    Returns:
        complex64 tensor of shape (max_len, dim//2)
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_len, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary positional embeddings to query or key tensors.

    Interprets each pair of adjacent features as a 2D complex number and
    multiplies by the precomputed phasor for that position, rotating the
    representation in the complex plane without changing its norm.

    Args:
        x         -- tensor of shape (B, T, H, head_dim); head_dim must be even
        freqs_cis -- precomputed complex frequencies of shape (T, head_dim//2),
                     already sliced to exactly the positions being processed
                     (caller is responsible for correct start_pos offset)

    Returns:
        Rotated tensor of the same shape and dtype as x
    """
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return (
        torch.view_as_real(xc * freqs_cis.unsqueeze(0).unsqueeze(2))
        .flatten(-2)
        .to(x.dtype)
    )

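
# Illustrative sketch (not part of the original file): RoPE is a pure rotation,
# so it preserves norms, and a rotated query-key dot product depends only on
# the relative offset between the two positions.
def _demo_rope() -> None:
    freqs = precompute_rope_freqs(dim=8, max_len=16)
    x = torch.randn(1, 16, 1, 8)  # (B, T, H, head_dim)
    assert torch.allclose(x.norm(dim=-1), apply_rope(x, freqs).norm(dim=-1), atol=1e-5)

    q = torch.randn(1, 1, 1, 8)
    k = torch.randn(1, 1, 1, 8)

    def score(m: int, n: int) -> torch.Tensor:
        # q placed at position m, k at position n
        return (apply_rope(q, freqs[m : m + 1]) * apply_rope(k, freqs[n : n + 1])).sum()

    # same relative offset (3) => same attention score
    assert torch.allclose(score(2, 5), score(4, 7), atol=1e-5)
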
# ---------------------------------------------------------------------------
# Grouped Query Attention with KV cache
# ---------------------------------------------------------------------------


class GQAttention(nn.Module):
    """
    Grouped Query Attention (Ainslie et al., 2023).

    Uses fewer KV heads than Q heads (n_kv_heads < n_heads). Each KV head is
    shared across n_heads // n_kv_heads query heads, reducing the KV cache size
    by that factor while keeping full query expressiveness.

    RoPE is applied to both Q and K. K and V are stored in kv_cache after
    RoPE application so that cached values are already positionally encoded and
    do not need to be re-rotated on retrieval.
    """

    def __init__(self, cfg: MythosConfig):
        """
        Args:
            cfg -- MythosConfig; uses dim, n_heads, n_kv_heads
        """
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.dim // cfg.n_heads
        self.groups = cfg.n_heads // cfg.n_kv_heads

        self.wq = nn.Linear(cfg.dim, cfg.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=False)
        self.attn_drop = nn.Dropout(cfg.dropout)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
        cache_key: str = "default",
    ) -> torch.Tensor:
        """
        Args:
            x         -- input of shape (B, T, dim)
            freqs_cis -- RoPE frequencies for head_dim, shape (T, head_dim//2)
            mask      -- additive causal mask of shape (1, 1, T, S) or None
            kv_cache  -- dict mutated in-place; stores {"k": ..., "v": ...} per cache_key
            cache_key -- unique key identifying this layer in the cache dict

        Returns:
            Output tensor of shape (B, T, dim)
        """
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        q = apply_rope(q, freqs_cis)
        k = apply_rope(k, freqs_cis)

        if kv_cache is not None:
            if cache_key in kv_cache:
                k = torch.cat([kv_cache[cache_key]["k"], k], dim=1)
                v = torch.cat([kv_cache[cache_key]["v"], v], dim=1)
            kv_cache[cache_key] = {"k": k.detach(), "v": v.detach()}

        # expand KV to match Q heads
        k = k.repeat_interleave(self.groups, dim=2)
        v = v.repeat_interleave(self.groups, dim=2)

        q = q.transpose(1, 2)  # (B, H, T, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = self.head_dim**-0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            attn = attn + mask
        attn = self.attn_drop(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out)

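
# Illustrative sketch (not part of the original file): prefill a short prompt,
# then decode one extra token through the same kv_cache and watch the cached
# K grow along the sequence axis. Sizes are arbitrary assumptions; the causal
# mask is omitted because this only checks cache shapes.
def _demo_gqa_cache() -> None:
    cfg = MythosConfig(dim=64, n_heads=4, n_kv_heads=2, max_seq_len=32)
    attn = GQAttention(cfg)
    freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
    cache: dict = {}
    attn(torch.randn(1, 5, 64), freqs[:5], kv_cache=cache, cache_key="l0")
    assert cache["l0"]["k"].shape == (1, 5, 2, 16)  # 5 cached positions
    attn(torch.randn(1, 1, 64), freqs[5:6], kv_cache=cache, cache_key="l0")
    assert cache["l0"]["k"].shape == (1, 6, 2, 16)  # cache grew by one token
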
# ---------------------------------------------------------------------------
# Multi-Latent Attention (DeepSeek-V2 style)
# ---------------------------------------------------------------------------


class MLAttention(nn.Module):
    """
    Multi-Latent Attention (DeepSeek-V2, 2024).

    The key insight: instead of caching full K and V tensors (each of size
    n_heads × head_dim per token), MLA compresses the KV path through a
    low-rank latent c_kv and only caches that plus the RoPE keys. K_nope and
    V are reconstructed from c_kv at each decoding step, trading a cheap
    linear projection for dramatically smaller cache memory.

    Q path:
        x → q_down (dim → q_lora_rank) → q_norm
          → q_up_nope (q_lora_rank → n_heads×qk_nope_head_dim)  [no RoPE]
          → q_up_rope (q_lora_rank → n_heads×qk_rope_head_dim)  [RoPE applied]
        q = cat(q_nope, q_rope) per head

    KV path:
        x → kv_down (dim → kv_lora_rank + qk_rope_head_dim)
          splits into c_kv (latent, cached) and k_rope_raw (shared across heads)
        k_rope = RoPE(expand(k_rope_raw))      -- applied before caching
        c_kv → kv_norm → kv_up → [k_nope | v]  -- reconstructed each step
        k = cat(k_nope, k_rope) per head

    Cache stores: c_kv (kv_lora_rank) + k_rope (n_heads × qk_rope_head_dim),
    versus the full GQA cache: n_kv_heads × head_dim × 2. At production scale
    this is roughly a 10–20× memory reduction.
    """

    def __init__(self, cfg: MythosConfig):
        """
        Args:
            cfg -- MythosConfig; uses dim, n_heads, kv_lora_rank, q_lora_rank,
                   qk_rope_head_dim, qk_nope_head_dim, v_head_dim
        """
        super().__init__()
        self.n_heads = cfg.n_heads
        self.kv_lora_rank = cfg.kv_lora_rank
        self.qk_rope_dim = cfg.qk_rope_head_dim
        self.qk_nope_dim = cfg.qk_nope_head_dim
        self.v_dim = cfg.v_head_dim
        self.q_head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim

        # Q compression
        self.q_down = nn.Linear(cfg.dim, cfg.q_lora_rank, bias=False)
        self.q_norm = RMSNorm(cfg.q_lora_rank)
        self.q_up_nope = nn.Linear(
            cfg.q_lora_rank, cfg.n_heads * cfg.qk_nope_head_dim, bias=False
        )
        self.q_up_rope = nn.Linear(
            cfg.q_lora_rank, cfg.n_heads * cfg.qk_rope_head_dim, bias=False
        )

        # KV compression: output is [c_kv | k_rope_raw] concatenated
        self.kv_down = nn.Linear(
            cfg.dim, cfg.kv_lora_rank + cfg.qk_rope_head_dim, bias=False
        )
        self.kv_norm = RMSNorm(cfg.kv_lora_rank)
        self.kv_up = nn.Linear(
            cfg.kv_lora_rank,
            cfg.n_heads * (cfg.qk_nope_head_dim + cfg.v_head_dim),
            bias=False,
        )

        self.wo = nn.Linear(cfg.n_heads * cfg.v_head_dim, cfg.dim, bias=False)
        self.attn_drop = nn.Dropout(cfg.dropout)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
        cache_key: str = "default",
    ) -> torch.Tensor:
        """
        Args:
            x         -- input of shape (B, T, dim)
            freqs_cis -- RoPE frequencies sized for qk_rope_head_dim, shape (T, rope_dim//2)
            mask      -- additive causal mask of shape (1, 1, T, S) or None
            kv_cache  -- dict mutated in-place; stores {"c_kv": ..., "k_rope": ...}
            cache_key -- unique key identifying this layer in the cache dict

        Returns:
            Output tensor of shape (B, T, dim)
        """
        B, T, _ = x.shape

        # Q
        c_q = self.q_norm(self.q_down(x))
        q_nope = self.q_up_nope(c_q).view(B, T, self.n_heads, self.qk_nope_dim)
        q_rope = self.q_up_rope(c_q).view(B, T, self.n_heads, self.qk_rope_dim)
        q_rope = apply_rope(q_rope, freqs_cis)
        q = torch.cat([q_nope, q_rope], dim=-1)  # (B, T, H, nope+rope)

        # KV compress
        kv_raw = self.kv_down(x)
        c_kv = kv_raw[..., : self.kv_lora_rank]  # (B, T, lora_rank) ← cached
        k_rope = kv_raw[..., self.kv_lora_rank :]  # (B, T, rope_dim)
        # expand rope keys across heads and apply RoPE before caching so
        # retrieved keys are already positionally encoded
        k_rope = (
            k_rope.unsqueeze(2)
            .expand(B, T, self.n_heads, self.qk_rope_dim)
            .contiguous()
        )
        k_rope = apply_rope(k_rope, freqs_cis)  # (B, T, H, rope_dim) ← cached

        if kv_cache is not None:
            if cache_key in kv_cache:
                c_kv = torch.cat([kv_cache[cache_key]["c_kv"], c_kv], dim=1)
                k_rope = torch.cat([kv_cache[cache_key]["k_rope"], k_rope], dim=1)
            kv_cache[cache_key] = {"c_kv": c_kv.detach(), "k_rope": k_rope.detach()}

        S = c_kv.shape[1]  # full sequence length including cache

        # reconstruct K_nope and V from latent (not cached, recomputed each step)
        kv = self.kv_up(self.kv_norm(c_kv))  # (B, S, H*(nope+v))
        kv = kv.view(B, S, self.n_heads, self.qk_nope_dim + self.v_dim)
        k_nope = kv[..., : self.qk_nope_dim]  # (B, S, H, nope)
        v = kv[..., self.qk_nope_dim :]  # (B, S, H, v_dim)
        k = torch.cat([k_nope, k_rope], dim=-1)  # (B, S, H, nope+rope)

        # attention
        q = q.transpose(1, 2)  # (B, H, T, q_head_dim)
        k = k.transpose(1, 2)  # (B, H, S, q_head_dim)
        v = v.transpose(1, 2)  # (B, H, S, v_dim)

        scale = self.q_head_dim**-0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        if mask is not None:
            attn = attn + mask
        attn = self.attn_drop(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, v)  # (B, H, T, v_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out)

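
# Illustrative sketch (not part of the original file): comparing per-token
# KV-cache footprints (in float elements) under the default config. MLA caches
# the latent c_kv plus the per-head RoPE keys; GQA caches grouped K and V;
# full MHA would cache K and V for every head. Production-scale ratios depend
# on head counts and are larger than this config suggests.
def _demo_cache_footprint() -> None:
    cfg = MythosConfig()
    head_dim = cfg.dim // cfg.n_heads
    mla = cfg.kv_lora_rank + cfg.n_heads * cfg.qk_rope_head_dim
    gqa = cfg.n_kv_heads * head_dim * 2
    mha = cfg.n_heads * head_dim * 2
    print(f"per-token cache: MLA={mla}, GQA={gqa}, MHA={mha}")
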
# ---------------------------------------------------------------------------
# DeepSeek-style MoE FFN
# ---------------------------------------------------------------------------


class Expert(nn.Module):
    """
    Single SwiGLU feed-forward expert.

    Implements the gated linear unit variant: output = down(silu(gate(x)) * up(x)).
    Used both as individual routed experts inside MoEFFN and as the standard dense
    FFN in prelude/coda blocks (where expert_dim = dim * 4 // 3).
    """

    def __init__(self, dim: int, expert_dim: int):
        """
        Args:
            dim        -- input and output feature dimension
            expert_dim -- inner (hidden) dimension of the expert
        """
        super().__init__()
        self.gate = nn.Linear(dim, expert_dim, bias=False)
        self.up = nn.Linear(dim, expert_dim, bias=False)
        self.down = nn.Linear(expert_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x -- input of shape (..., dim)
        Returns:
            Tensor of shape (..., dim)
        """
        return self.down(F.silu(self.gate(x)) * self.up(x))

class MoEFFN(nn.Module):
    """
    Fine-grained Mixture-of-Experts FFN (DeepSeekMoE, Dai et al., 2024).

    Two classes of experts:
    - Routed experts: n_experts small FFNs; each token activates top-K of them
      via a learned router. A per-expert bias on router logits is updated during
      training to keep load balanced across experts without distorting the loss.
    - Shared experts: n_shared_experts larger FFNs always activated for every token,
      absorbing common cross-domain patterns (syntax, basic reasoning) that would
      otherwise be redundantly learned by many routed experts.

    Total activated parameters per token ≈ topk/n_experts of routed + all shared,
    keeping compute sparse while the total parameter count stays large.
    """

    def __init__(self, cfg: MythosConfig):
        """
        Args:
            cfg -- MythosConfig; uses n_experts, n_shared_experts, n_experts_per_tok,
                   dim, expert_dim
        """
        super().__init__()
        self.n_experts = cfg.n_experts
        self.n_shared = cfg.n_shared_experts
        self.topk = cfg.n_experts_per_tok

        self.router = nn.Linear(cfg.dim, cfg.n_experts, bias=False)
        # load-balancing bias adjusted externally during training; not a gradient param
        self.register_buffer("router_bias", torch.zeros(cfg.n_experts))

        self.routed_experts = nn.ModuleList(
            [Expert(cfg.dim, cfg.expert_dim) for _ in range(cfg.n_experts)]
        )
        self.shared_experts = nn.ModuleList(
            [
                Expert(cfg.dim, cfg.expert_dim * cfg.n_experts_per_tok)
                for _ in range(self.n_shared)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x -- input of shape (B, T, dim)
        Returns:
            Tensor of shape (B, T, dim); shared expert outputs are summed on top
            of the weighted routed expert outputs
        """
        B, T, D = x.shape
        flat = x.view(B * T, D)

        # Aux-loss-free load balancing (DeepSeek-V3): the bias shifts only the
        # selection of which experts fire so underused experts are picked more,
        # but the gating weights come from unbiased softmax scores so the bias
        # never shows up in the gradient.
        logits = self.router(flat)  # (B*T, n_experts), unbiased
        scores = F.softmax(logits, dim=-1)
        _, topk_idx = (logits + self.router_bias).topk(self.topk, dim=-1)
        topk_scores = scores.gather(-1, topk_idx)
        topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # renorm

        # routed expert dispatch (token-level scatter)
        out = torch.zeros_like(flat)
        for i in range(self.topk):
            expert_ids = topk_idx[:, i]
            token_scores = topk_scores[:, i].unsqueeze(-1)
            for eid in range(self.n_experts):
                mask = expert_ids == eid
                if not mask.any():
                    continue
                out[mask] += token_scores[mask] * self.routed_experts[eid](flat[mask])

        # shared experts always fire for every token
        for shared in self.shared_experts:
            out = out + shared(flat)

        return out.view(B, T, D)

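
# Illustrative sketch (not part of the original file): one way to perform the
# external router_bias update that the buffer comment above alludes to,
# following a DeepSeek-V3-style sign rule. The rule and the `gamma` step size
# are assumptions; the actual training loop is not part of this file.
@torch.no_grad()
def _demo_update_router_bias(moe: MoEFFN, x: torch.Tensor, gamma: float = 1e-3) -> None:
    flat = x.view(-1, x.shape[-1])
    logits = moe.router(flat)
    _, topk_idx = (logits + moe.router_bias).topk(moe.topk, dim=-1)
    # tokens routed to each expert in this batch
    counts = torch.bincount(topk_idx.flatten(), minlength=moe.n_experts).float()
    target = counts.mean()  # perfectly balanced load
    # raise the bias of underloaded experts, lower it for overloaded ones
    moe.router_bias += gamma * torch.sign(target - counts)
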
# ---------------------------------------------------------------------------
# Loop-index RoPE (differentiates recurrent block across iterations)
# ---------------------------------------------------------------------------


def loop_index_embedding(
    h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0
) -> torch.Tensor:
    """
    Inject a sinusoidal loop-index signal into the first loop_dim channels of h.

    Analogous to RoPE for sequence position, but applied over recurrence depth
    instead of token position. Without this, the shared recurrent block weights
    must handle both early-stage pattern-matching and late-stage refinement with
    no signal distinguishing which loop they are on. Adding the loop index lets
    the same parameters implement functionally distinct operations per iteration.

    Args:
        h        -- hidden state tensor of shape (B, T, dim)
        loop_t   -- current loop iteration index (0-based)
        loop_dim -- number of leading channels to receive the embedding (must be even)
        theta    -- sinusoidal base frequency

    Returns:
        h with a sinusoidal bias added to its first loop_dim channels; same shape
    """
    freqs = 1.0 / (
        theta
        ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim)
    )
    angles = loop_t * freqs  # (loop_dim//2,)
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
    emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
    emb_full[:loop_dim] = emb
    return h + emb_full.unsqueeze(0).unsqueeze(0)

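
# Illustrative sketch (not part of the original file): the same hidden state
# receives a different bias at each loop depth, so the shared block can
# condition its behavior on the iteration index.
def _demo_loop_index() -> None:
    h = torch.zeros(1, 2, 16)
    e0 = loop_index_embedding(h, loop_t=0, loop_dim=8)
    e1 = loop_index_embedding(h, loop_t=1, loop_dim=8)
    assert not torch.allclose(e0, e1)  # depths are distinguishable
    assert torch.allclose(e0[..., 8:], e1[..., 8:])  # only leading channels shift
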
# ---------------------------------------------------------------------------
# Depth-wise LoRA adapter (per loop iteration)
# ---------------------------------------------------------------------------


class LoRAAdapter(nn.Module):
    """
    Depth-wise LoRA adaptation for the recurrent block (Bae et al., 2024).

    Pure weight-tying (identical weights every loop) limits expressiveness;
    fully distinct weights per loop eliminate the parameter savings. This
    adapter sits in between: the low-rank down-projection and the up-projection
    matrix B are shared across all loops, while a small per-loop scale vector
    shifts the effective transformation at each depth without adding
    significant parameters.

        delta(x, t) = (down(x) * scale[t]) @ B
    """

    def __init__(self, dim: int, rank: int, max_loops: int):
        """
        Args:
            dim       -- model hidden dimension (input and output size)
            rank      -- low-rank bottleneck dimension
            max_loops -- maximum number of loop iterations (determines embedding table size)
        """
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)  # shared A: dim → rank
        self.B = nn.Parameter(torch.randn(rank, dim) * 0.02)  # shared B: rank → dim
        self.scale = nn.Embedding(max_loops, rank)  # per-loop element-wise scale

    def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
        """
        Args:
            x      -- input tensor of shape (B, T, dim)
            loop_t -- current loop index used to look up the per-loop scale

        Returns:
            Delta tensor of shape (B, T, dim) to be added to the block output
        """
        # Clamp for depth extrapolation: at inference n_loops can exceed the
        # training max_loop_iters. Iterations beyond the trained range reuse
        # the last learned per-loop scale rather than indexing out of range.
        max_t = self.scale.num_embeddings - 1
        t_idx = loop_t if loop_t <= max_t else max_t
        s = self.scale(torch.tensor(t_idx, device=x.device))  # (rank,)
        down = self.down(x) * s  # (B, T, rank)
        return down @ self.B  # (B, T, dim)

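
# Illustrative sketch (not part of the original file): loop indices past the
# trained depth reuse the last learned scale instead of indexing out of range,
# which is what allows running more loops at inference than were trained.
def _demo_lora_clamp() -> None:
    adapter = LoRAAdapter(dim=16, rank=4, max_loops=8)
    x = torch.randn(1, 3, 16)
    d_last = adapter(x, loop_t=7)  # last trained depth
    d_beyond = adapter(x, loop_t=20)  # extrapolated depth, clamped to 7
    assert torch.allclose(d_last, d_beyond)
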
# ---------------------------------------------------------------------------
# Single Transformer Block (shared across recurrent loops)
# ---------------------------------------------------------------------------


class TransformerBlock(nn.Module):
    """
    Standard pre-norm transformer block with swappable attention and optional MoE FFN.

    Attention is selected by cfg.attn_type:
        "gqa" → GQAttention (Grouped Query Attention, fewer KV heads)
        "mla" → MLAttention (Multi-Latent Attention, compressed KV cache)

    FFN is selected by use_moe:
        True  → MoEFFN (fine-grained routed + shared experts; used in RecurrentBlock)
        False → Expert (dense SwiGLU FFN; used in Prelude and Coda)
    """

    def __init__(self, cfg: MythosConfig, use_moe: bool = False):
        """
        Args:
            cfg     -- MythosConfig; attn_type selects the attention class
            use_moe -- if True, use MoEFFN; otherwise use a dense Expert FFN
        """
        super().__init__()
        self.attn_norm = RMSNorm(cfg.dim)
        self.ffn_norm = RMSNorm(cfg.dim)
        self.attn = MLAttention(cfg) if cfg.attn_type == "mla" else GQAttention(cfg)
        self.ffn = MoEFFN(cfg) if use_moe else Expert(cfg.dim, cfg.dim * 4 // 3)
        self.resid_drop = nn.Dropout(cfg.dropout)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
        cache_key: str = "default",
    ) -> torch.Tensor:
        """
        Args:
            x         -- input of shape (B, T, dim)
            freqs_cis -- precomputed RoPE frequencies
            mask      -- additive causal mask or None
            kv_cache  -- cache dict mutated in-place by the attention layer
            cache_key -- key identifying this layer in the cache

        Returns:
            Output tensor of shape (B, T, dim)
        """
        x = x + self.resid_drop(
            self.attn(self.attn_norm(x), freqs_cis, mask, kv_cache, cache_key)
        )
        x = x + self.resid_drop(self.ffn(self.ffn_norm(x)))
        return x

# ---------------------------------------------------------------------------
# LTI-stable injection parameters (spectral radius < 1 by construction)
# ---------------------------------------------------------------------------


class LTIInjection(nn.Module):
    """
    Stable input injection for the recurrent update rule (Parcae, Prairie et al., 2026).

    The recurrent hidden state evolves as:
        h_{t+1} = A · h_t + B · e + Transformer(h_t, e)

    where e is the encoded input injected at every loop step to prevent drift.
    Without constraints, A can develop spectral radius ≥ 1, causing the hidden
    state to explode across loop iterations and destabilize training.

    This class guarantees ρ(A) < 1 by construction via a ZOH discretization:
        A_continuous = Diag(-exp(log_A))       (always a negative diagonal)
        A_discrete   = exp(Δt · A_continuous)  (element-wise, values in (0, 1))

    where log_A and log_dt are learned parameters and exp ensures positivity.
    This makes looped-model training robust to hyperparameter choices and stable
    even at high learning rates.
    """

    def __init__(self, dim: int):
        """
        Args:
            dim -- hidden state dimension; one scalar per channel for A and B
        """
        super().__init__()
        self.log_A = nn.Parameter(torch.zeros(dim))  # log of A_continuous magnitude
        self.log_dt = nn.Parameter(torch.zeros(1))  # log of discretization step Δt
        self.B = nn.Parameter(torch.ones(dim) * 0.1)

    def get_A(self) -> torch.Tensor:
        """
        Compute the discretized diagonal state matrix A_discrete.

        Returns:
            1-D tensor of shape (dim,) with all values strictly in (0, 1),
            guaranteeing ρ(A) < 1 regardless of learned parameter values.
        """
        # Compute in log space to avoid 0 * inf = NaN when log_dt → -∞, log_A → +∞.
        # dt * A_c = -exp(log_dt) * exp(log_A) = -exp(log_dt + log_A)
        # Clamp keeps the product finite in float32 for any gradient step size.
        return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20)))

    def forward(
        self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute h_{t+1} = A·h_t + B·e + transformer_out.

        Args:
            h               -- current hidden state (B, T, dim)
            e               -- encoded input from Prelude, frozen across loops (B, T, dim)
            transformer_out -- output of the recurrent TransformerBlock at this step (B, T, dim)

        Returns:
            Updated hidden state of shape (B, T, dim)
        """
        A = self.get_A()
        return A * h + self.B * e + transformer_out

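
# Illustrative sketch (not part of the original file): A_discrete stays inside
# (0, 1) even for adversarial parameter values, so the homogeneous part of the
# recurrence contracts the state instead of blowing it up.
def _demo_lti_stability() -> None:
    inj = LTIInjection(dim=8)
    with torch.no_grad():
        inj.log_A.fill_(2.0)  # deliberately large learned values
        inj.log_dt.fill_(2.0)
    A = inj.get_A()
    assert (A > 0).all() and (A < 1).all()
    h = torch.ones(1, 1, 8) * 1e3
    for _ in range(100):  # homogeneous part of h_{t+1} = A*h_t + ...
        h = A * h
    assert h.abs().max() < 1e3  # state contracted, no explosion
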
# ---------------------------------------------------------------------------
# ACT halting (Adaptive Computation Time)
# ---------------------------------------------------------------------------


class ACTHalting(nn.Module):
    """
    Adaptive Computation Time halting mechanism (Graves, 2016).

    Learns a per-position halting probability at each loop iteration. Positions
    where the hidden state has converged (high cumulative halting probability)
    stop accumulating updates, while positions still being refined continue.
    This lets easy tokens halt early and hard tokens receive more computation,
    all within the same batch. Also makes the model Turing-complete under
    certain assumptions about the expressiveness of the transformer block.
    """

    def __init__(self, dim: int):
        """
        Args:
            dim -- hidden state dimension; input to the halting scalar predictor
        """
        super().__init__()
        self.halt = nn.Linear(dim, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Predict per-position halting probability from the current hidden state.

        Args:
            h -- hidden state of shape (B, T, dim)

        Returns:
            Halting probability tensor of shape (B, T), values in (0, 1)
        """
        return torch.sigmoid(self.halt(h)).squeeze(-1)

# ---------------------------------------------------------------------------
# Recurrent Block (one set of weights, looped T times)
# ---------------------------------------------------------------------------


class RecurrentBlock(nn.Module):
    """
    The core recurrent block of OpenMythos: a single TransformerBlock looped T times.

    At each loop iteration t, the hidden state h is updated via:
        1. loop_index_embedding: inject sinusoidal loop-index signal into h
        2. TransformerBlock: compute attention + MoE FFN on normalized (h + e)
        3. LoRAAdapter: apply depth-wise LoRA delta to transformer output
        4. LTIInjection: stable update h = A·h + B·e + transformer_out
        5. ACTHalting: accumulate per-position halting probabilities;
           positions that have converged stop contributing

    The encoded input e (output of the Prelude) is injected at every step to keep
    the original input signal alive across arbitrary loop depth, preventing drift.
    The ACT mechanism produces a weighted sum of hidden states across iterations,
    where the weights reflect when each position converged.

    More loop iterations at inference = deeper reasoning chains, following the
    depth-extrapolation property of looped transformers (Saunshi et al., 2025).
    """

    def __init__(self, cfg: MythosConfig):
        """
        Args:
            cfg -- MythosConfig; uses dim, lora_rank, max_loop_iters, act_threshold
        """
        super().__init__()
        self.cfg = cfg
        self.block = TransformerBlock(cfg, use_moe=True)
        self.injection = LTIInjection(cfg.dim)
        self.act = ACTHalting(cfg.dim)
        self.lora = LoRAAdapter(cfg.dim, cfg.lora_rank, cfg.max_loop_iters)
        self.norm = RMSNorm(cfg.dim)
        # fraction of channels receiving the loop-index embedding
        self.loop_dim = cfg.dim // 8

    def forward(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        n_loops: Optional[int] = None,
        kv_cache: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Run the recurrent loop for up to n_loops iterations with ACT early exit.

        Args:
            h         -- initial hidden state from the Prelude, shape (B, T, dim)
            e         -- encoded input frozen for injection each step, shape (B, T, dim)
            freqs_cis -- precomputed RoPE frequencies
            mask      -- additive causal mask or None
            n_loops   -- number of loop iterations; defaults to cfg.max_loop_iters.
                         Can be increased at inference for deeper reasoning
                         (depth extrapolation).
            kv_cache  -- cache dict passed through to the inner TransformerBlock;
                         each loop iteration uses a separate cache key

        Returns:
            ACT-weighted sum of hidden states across iterations, shape (B, T, dim)
        """
        n_loops = n_loops or self.cfg.max_loop_iters
        B, T, D = h.shape

        halted = torch.zeros(B, T, device=h.device, dtype=torch.bool)
        cumulative_p = torch.zeros(B, T, device=h.device)
        h_out = torch.zeros_like(h)

        for t in range(n_loops):
            h_loop = loop_index_embedding(h, t, self.loop_dim)
            combined = self.norm(h_loop + e)
            cache_key = f"recurrent_loop_{t}"
            trans_out = self.block(combined, freqs_cis, mask, kv_cache, cache_key)
            trans_out = trans_out + self.lora(trans_out, t)
            h = self.injection(h, e, trans_out)

            p = self.act(h)  # (B, T)
            still_running = ~halted

            # ACT remainder trick: once cumulative_p + p crosses the threshold,
            # assign the remaining probability mass as the final weight.
            # Gate by still_running so halted positions contribute exactly
            # once (on the halting step) and zero thereafter -- otherwise
            # threshold < 1 leaves a non-zero remainder that leaks every step.
            remainder = (1.0 - cumulative_p).clamp(min=0)
            weight = torch.where(
                cumulative_p + p >= self.cfg.act_threshold,
                remainder,
                p,
            )
            weight = weight * still_running.float()
            h_out = h_out + weight.unsqueeze(-1) * h

            cumulative_p = cumulative_p + p * still_running.float()
            halted = halted | (cumulative_p >= self.cfg.act_threshold)

            # Only short-circuit when there is no KV cache to keep consistent.
            # With a cache, every loop depth must run on every forward pass so
            # later decode steps find populated keys at every cache_key.
            if halted.all() and kv_cache is None:
                break

        return h_out

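
# Illustrative sketch (not part of the original file): the ACT remainder trick
# with made-up per-step halting probabilities. The weight sequence sums to
# exactly 1 for any position that crosses the threshold, because the halting
# step receives the leftover probability mass.
def _demo_act_weights() -> None:
    threshold = 0.99
    p_steps = [0.3, 0.5, 0.4, 0.9]  # assumed per-step halting probs
    cumulative, weights, halted = 0.0, [], False
    for p in p_steps:
        if halted:
            weights.append(0.0)
            continue
        if cumulative + p >= threshold:
            weights.append(1.0 - cumulative)  # remainder on the halting step
            halted = True
        else:
            weights.append(p)
            cumulative += p
    print(weights)  # [0.3, 0.5, 0.2, 0.0]
    assert abs(sum(weights) - 1.0) < 1e-9
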
# ---------------------------------------------------------------------------
# Full Model
# ---------------------------------------------------------------------------


class OpenMythos(nn.Module):
    """
    OpenMythos: Recurrent-Depth Transformer language model.

    Implements the hypothesized Claude Mythos architecture as a Recurrent-Depth
    Transformer (RDT). The model divides computation into three functional blocks:

        Input tokens
             ↓
        [Prelude]          -- prelude_layers standard transformer blocks, run once
             ↓
        [Recurrent Block]  -- one transformer block looped T times with input injection
            ↑_______↓         h_{t+1} = A·h_t + B·e + Transformer(h_t, e)
             ↓
        [Coda]             -- coda_layers standard transformer blocks, run once
             ↓
        Output logits

    Key properties:
    - Same weights, more loops → deeper reasoning, no parameter growth
    - Depth extrapolation: train on N loops, test on N+k loops (emergent)
    - ACT halting: variable compute per position within a batch
    - MoE FFN in the recurrent block: breadth across domains
    - LTI-stable injection: spectral radius < 1 guaranteed by construction
    - Supports both GQA and MLA attention (set via cfg.attn_type)
    """

    def __init__(self, cfg: MythosConfig):
        """
        Args:
            cfg -- MythosConfig specifying all architecture hyperparameters
        """
        super().__init__()
        self.cfg = cfg

        self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)

        # GQA uses full head_dim for RoPE; MLA uses only qk_rope_head_dim (decoupled)
        freqs = precompute_rope_freqs(
            cfg.dim // cfg.n_heads, cfg.max_seq_len, cfg.rope_theta
        )
        self.register_buffer("freqs_cis", freqs)
        freqs_mla = precompute_rope_freqs(
            cfg.qk_rope_head_dim, cfg.max_seq_len, cfg.rope_theta
        )
        self.register_buffer("freqs_cis_mla", freqs_mla)

        self.prelude = nn.ModuleList(
            [TransformerBlock(cfg, use_moe=False) for _ in range(cfg.prelude_layers)]
        )
        self.recurrent = RecurrentBlock(cfg)
        self.coda = nn.ModuleList(
            [TransformerBlock(cfg, use_moe=False) for _ in range(cfg.coda_layers)]
        )

        self.norm = RMSNorm(cfg.dim)
        self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        self.head.weight = self.embed.weight  # weight tying

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize all linear and embedding weights with N(0, 0.02)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Build an additive causal mask: 0 on and below the diagonal, -inf above.

        Args:
            seq_len -- sequence length
            device  -- target device

        Returns:
            Tensor of shape (1, 1, seq_len, seq_len) broadcastable over (B, H, T, S)
        """
        mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device)
        return torch.triu(mask, diagonal=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        n_loops: Optional[int] = None,
        kv_cache: Optional[dict] = None,
        start_pos: int = 0,
    ) -> torch.Tensor:
        """
        Forward pass through Prelude → Recurrent Block → Coda.

        Args:
            input_ids -- token indices of shape (B, T)
            n_loops   -- recurrent loop depth; defaults to cfg.max_loop_iters.
                         Increase at inference to extrapolate to harder problems.
            kv_cache  -- dict mutated in-place for autoregressive KV caching;
                         pass an empty dict {} and reuse across decode steps
            start_pos -- index of the first token in input_ids within the full
                         sequence; used to select the correct RoPE frequencies
                         during incremental decoding (0 for prefill, then the
                         absolute position of the token being decoded on each
                         subsequent step)

        Returns:
            Logits of shape (B, T, vocab_size)
        """
        T = input_ids.shape[1]
        device = input_ids.device

        x = self.embed(input_ids)
        freqs_cis = (
            self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis
        )[start_pos : start_pos + T]
        mask = self._causal_mask(T, device) if T > 1 else None

        for i, layer in enumerate(self.prelude):
            x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}")

        e = x  # encoded input frozen for injection every loop
        x = self.recurrent(x, e, freqs_cis, mask, n_loops, kv_cache)

        for i, layer in enumerate(self.coda):
            x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"coda_{i}")

        return self.head(self.norm(x))

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 64,
        n_loops: int = 8,
        temperature: float = 1.0,
        top_k: int = 50,
    ) -> torch.Tensor:
        """
        Autoregressive token generation with KV caching.

        On step 0 the full prompt is processed. On subsequent steps only the
        last generated token is passed, with all previous keys and values
        retrieved from kv_cache. This keeps decode cost proportional to one
        token per step rather than the full growing sequence.

        n_loops can be set higher than the training value to extrapolate to
        harder problems at inference time (depth extrapolation property).

        Args:
            input_ids      -- prompt token indices of shape (B, T)
            max_new_tokens -- number of tokens to generate
            n_loops        -- recurrent loop depth for each decode step
            temperature    -- softmax temperature; lower = more greedy
            top_k          -- restrict sampling to top-K logits (0 = disabled)

        Returns:
            Token indices of shape (B, T + max_new_tokens)
        """
        kv_cache: dict = {}
        prompt_len = input_ids.shape[1]
        for step in range(max_new_tokens):
            if step == 0:
                cur_ids = input_ids
                start_pos = 0
            else:
                cur_ids = input_ids[:, -1:]
                start_pos = prompt_len + step - 1
            logits = self.forward(
                cur_ids, n_loops=n_loops, kv_cache=kv_cache, start_pos=start_pos
            )
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                v, _ = logits.topk(top_k)
                logits[logits < v[:, -1:]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_tok], dim=1)
        return input_ids

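
# Illustrative sketch (not part of the original file): an end-to-end smoke test
# with a deliberately tiny config (all sizes are arbitrary assumptions). It
# exercises the training-style forward pass and cached generation with a loop
# depth above the trained max_loop_iters, i.e. the depth-extrapolation path.
def _demo_openmythos() -> None:
    cfg = MythosConfig(
        vocab_size=256, dim=64, n_heads=4, n_kv_heads=2, max_seq_len=64,
        max_loop_iters=2, prelude_layers=1, coda_layers=1,
        n_experts=4, n_experts_per_tok=2, expert_dim=32,
        kv_lora_rank=16, q_lora_rank=32,
        qk_rope_head_dim=8, qk_nope_head_dim=8, v_head_dim=16, lora_rank=4,
    )
    model = OpenMythos(cfg)
    ids = torch.randint(0, cfg.vocab_size, (1, 10))
    assert model(ids).shape == (1, 10, cfg.vocab_size)  # full-sequence forward
    out = model.generate(ids, max_new_tokens=5, n_loops=4)  # 4 > trained depth 2
    assert out.shape == (1, 15)
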