tiny tests

Kye Gomez 2026-04-22 12:48:33 -04:00
parent 7d78ebec79
commit 963e11277d
6 changed files with 1046 additions and 6 deletions


@@ -97,8 +97,9 @@ out = model.generate(ids, max_new_tokens=8, n_loops=8)
print(f"[{attn_type.upper()}] Generated shape: {out.shape}")
A = model.recurrent.injection.get_A()
+rho = torch.linalg.eigvals(A).abs().max().item()
print(
-f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)"
+f"[{attn_type.upper()}] Spectral radius ρ(A) = {rho:.4f} (must be < 1)"
)
```
@@ -217,6 +218,17 @@ The injection of `e` at every step is what prevents the model from drifting —
The full implementation is in [`open_mythos/main.py`](open_mythos/main.py). See the [`OpenMythos` class reference](docs/open_mythos.md) for a detailed API walkthrough, configuration options, and usage examples.
+
+### Attention Implementations
+
+The attention layer is switchable via `cfg.attn_type`:
+
+| Option | Class | Description |
+|---|---|---|
+| `"gqa"` | `GQAttention` | Grouped Query Attention (Ainslie et al., 2023) — fewer KV heads than Q heads (`n_kv_heads < n_heads`), cutting KV-cache memory by a factor of `n_heads / n_kv_heads`. Uses **Flash Attention 2** (Dao, 2023) when `flash-attn>=2.8.3` is installed: GQA is handled natively (no KV-head expansion) by the IO-aware kernel, with a transparent fallback to manual scaled dot-product attention when the package is absent. |
+| `"mla"` | `MLAttention` | Multi-head Latent Attention (DeepSeek-V2) — caches a compressed KV latent (`kv_lora_rank`) rather than full K/V, with split RoPE / no-RoPE head dims for position-aware compression. |
+
+RoPE is applied to Q and K before caching, so cached values do not need to be re-rotated on retrieval.
---
## Why This Explains Mythos
@@ -371,6 +383,7 @@ Theoretical analysis suggests 2-3x improvements in inference throughput. For a d
| Training stability | LTI-constrained injection parameters with spectral radius < 1 |
| Loop differentiation | Likely uses loop-index positional embedding (à la RoPE) per iteration |
| Halting | Adaptive Computation Time or learned convergence criterion |
+| Attention | GQA (with optional Flash Attention 2) or MLA with compressed KV latent cache |
| Scaling law | Optimal training scales looping and data together, not parameters alone |
| Reasoning vs. memory | Structurally biased toward composition; memorization requires separate treatment |
| Deployment | Continuous Depth-wise Batching enables variable compute per request |
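As a rough, back-of-the-envelope check of the cache-size claims in the Attention Implementations table above: a minimal sketch, reusing the tiny-config values from `tests/small_benchmark.py`, and assuming the DeepSeek-V2 latent-caching scheme the table describes rather than anything read out of `MLAttention` itself.

```python
# Per-token, per-layer KV-cache size in elements (illustrative arithmetic only).
dim, n_heads, n_kv_heads = 128, 4, 2          # tiny config from tests/small_benchmark.py
kv_lora_rank, qk_rope_head_dim = 64, 16
head_dim = dim // n_heads                     # 32

mha_cache = 2 * n_heads * head_dim            # 256: full K and V for every head
gqa_cache = 2 * n_kv_heads * head_dim         # 128: smaller by n_heads / n_kv_heads = 2x
mla_cache = kv_lora_rank + qk_rope_head_dim   # 80: compressed latent + decoupled RoPE key
print(mha_cache, gqa_cache, mla_cache)        # 256 128 80
```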

open_mythos/main.py

@@ -252,7 +252,9 @@ class GQAttention(nn.Module):
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
dropout_p = self.dropout_p if self.training else 0.0
-out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=(mask is not None))
+out = flash_attn_func(
+q, k, v, dropout_p=dropout_p, causal=(mask is not None)
+)
out = out.to(orig_dtype).contiguous().view(B, T, -1)
else:
# Fallback: manual scaled dot-product with explicit KV head expansion.
@@ -964,18 +966,27 @@ class OpenMythos(nn.Module):
nn.init.normal_(m.weight, std=0.02)
@staticmethod
-def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
+def _causal_mask(
+seq_len: int, device: torch.device, dtype: torch.dtype
+) -> torch.Tensor:
"""
Build an additive causal mask: 0 on and below the diagonal, -inf above.
Args:
seq_len -- sequence length
device -- target device
+dtype -- tensor dtype (must match the activation dtype so the additive
+    mask doesn't upcast the attention logits in the fallback attention
+    path; e.g. bf16 weights with an fp32 mask promote the scores to
+    fp32, which then breaks the fp32-vs-bf16 matmul against V)
Returns:
Tensor of shape (1, 1, seq_len, seq_len) broadcastable over (B, H, T, S)
"""
-mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device)
+mask = torch.full(
+(1, 1, seq_len, seq_len), float("-inf"), device=device, dtype=dtype
+)
return torch.triu(mask, diagonal=1)
def forward(
@@ -1009,7 +1020,7 @@ class OpenMythos(nn.Module):
freqs_cis = (
self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis
)[start_pos : start_pos + T]
-mask = self._causal_mask(T, device) if T > 1 else None
+mask = self._causal_mask(T, device, x.dtype) if T > 1 else None
for i, layer in enumerate(self.prelude):
x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}")
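The fallback (non-flash) attention path adds this mask directly to the scores, which is why the dtype now travels with it. A minimal standalone sketch of the failure mode and the fix; shapes and names here are illustrative, not taken from `GQAttention`:

```python
import torch

q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
v = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)

fp32_mask = torch.triu(torch.full((1, 1, 4, 4), float("-inf")), diagonal=1)
scores = q @ k.transpose(-2, -1) / 8 ** 0.5    # bf16
promoted = scores + fp32_mask                  # bf16 + fp32 promotes to fp32
# promoted.softmax(-1) @ v would now mix fp32 with bf16 and fail, so build the
# mask in the activation dtype instead, as the updated _causal_mask does:
bf16_mask = fp32_mask.to(scores.dtype)
out = (scores + bf16_mask).softmax(-1) @ v     # stays bf16 end to end
print(promoted.dtype, out.dtype)               # torch.float32 torch.bfloat16
```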

pyproject.toml

@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "open-mythos"
-version = "0.4.0"
+version = "0.5.0"
description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture"
license = "MIT"
authors = ["Kye Gomez <kye@swarms.world>"]

benchmarks/bench_vs_transformer.py

@@ -0,0 +1,480 @@
"""
OpenMythos vs. vanilla GQA+MoE transformer benchmark.
Compares OpenMythos (Prelude + looped Recurrent Block + Coda with ACT halting,
LTI-stable injection, LoRA depth adapter) against a parameter-matched vanilla
transformer built from the same GQAttention + MoEFFN building blocks stacked
non-recurrently. The baseline reuses OpenMythos primitives so the comparison
isolates the recurrent-depth architecture, not the kernels.
Metrics reported:
- Parameter counts (total, MoE-active approximation)
- Prefill latency + throughput at several sequence lengths
- Decode (autoregressive step) latency with KV cache
- Peak memory (CUDA only)
- OpenMythos depth-scaling sweep: latency vs. n_loops
Run:
python benchmarks/bench_vs_transformer.py # small CPU/GPU smoke test
python benchmarks/bench_vs_transformer.py --size 1b --device cuda
python benchmarks/bench_vs_transformer.py --seq-lens 128,512,2048 --n-loops 1,4,8,16
"""
from __future__ import annotations
import argparse
import gc
import statistics
import time
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from open_mythos import MythosConfig, OpenMythos, mythos_1b
from open_mythos.main import (
RMSNorm,
TransformerBlock,
precompute_rope_freqs,
)
# ---------------------------------------------------------------------------
# Baseline: non-looped GQA + MoE transformer
# ---------------------------------------------------------------------------
class BaselineTransformer(nn.Module):
"""
Vanilla decoder-only transformer with GQA attention and MoE FFNs, stacked
non-recurrently. Shares TransformerBlock / GQAttention / MoEFFN kernels
with OpenMythos so any speed delta is attributable to the recurrent-depth
architecture rather than the underlying attention/FFN implementation.
"""
def __init__(self, cfg: MythosConfig, n_layers: int):
super().__init__()
self.cfg = cfg
self.n_layers = n_layers
self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)
self.layers = nn.ModuleList(
[TransformerBlock(cfg, use_moe=True) for _ in range(n_layers)]
)
self.norm = RMSNorm(cfg.dim)
self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
head_dim = cfg.dim // cfg.n_heads
self.register_buffer(
"freqs_cis",
precompute_rope_freqs(head_dim, cfg.max_seq_len, cfg.rope_theta),
persistent=False,
)
@staticmethod
def _causal_mask(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
mask = torch.full((1, 1, T, T), float("-inf"), device=device, dtype=dtype)
return torch.triu(mask, diagonal=1)
def forward(
self,
input_ids: torch.Tensor,
kv_cache: Optional[dict] = None,
start_pos: int = 0,
) -> torch.Tensor:
T = input_ids.shape[1]
x = self.embed(input_ids)
freqs_cis = self.freqs_cis[start_pos : start_pos + T]
mask = self._causal_mask(T, x.device, x.dtype) if T > 1 else None
for i, layer in enumerate(self.layers):
x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"layer_{i}")
return self.head(self.norm(x))
# ---------------------------------------------------------------------------
# Timing utilities
# ---------------------------------------------------------------------------
def _sync(device: torch.device) -> None:
if device.type == "cuda":
torch.cuda.synchronize()
def time_fn(fn, device: torch.device, warmup: int = 2, trials: int = 5) -> float:
"""Returns median wall-clock seconds over `trials` after `warmup` runs."""
for _ in range(warmup):
fn()
_sync(device)
times = []
for _ in range(trials):
_sync(device)
t0 = time.perf_counter()
fn()
_sync(device)
times.append(time.perf_counter() - t0)
return statistics.median(times)
def peak_mem_mb(device: torch.device) -> float:
if device.type != "cuda":
return 0.0
return torch.cuda.max_memory_allocated() / (1024 * 1024)
def reset_mem(device: torch.device) -> None:
gc.collect()
if device.type == "cuda":
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# ---------------------------------------------------------------------------
# Parameter counting
# ---------------------------------------------------------------------------
@dataclass
class ParamCounts:
total: int
moe_active_est: int # active per token (shared + top-k routed)
def count_params(model: nn.Module, cfg: MythosConfig) -> ParamCounts:
total = sum(p.numel() for p in model.parameters())
# Rough active-per-token count for MoE layers: shared + top-k routed fraction.
# For simplicity we report total and an estimated activation ratio separately.
active_ratio = (cfg.n_shared_experts + cfg.n_experts_per_tok) / (
cfg.n_shared_experts + cfg.n_experts
)
# Only FFN parameters shrink under activation; attention + embed/head are always on.
# This is a coarse lower bound on active params.
ffn_params = 0
other_params = 0
for name, p in model.named_parameters():
if ".ffn." in name or name.startswith("ffn.") or ".experts." in name:
ffn_params += p.numel()
else:
other_params += p.numel()
active_est = other_params + int(ffn_params * active_ratio)
return ParamCounts(total=total, moe_active_est=active_est)
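# Hypothetical worked example (illustrative numbers, not measured): with
# n_shared_experts=1, n_experts=8, n_experts_per_tok=2 the active ratio is
# (1 + 2) / (1 + 8) = 1/3, so a model with 20M non-FFN params and 60M of
# expert/FFN params reports total = 80M and moe_active_est ≈ 20M + 60M/3 = 40M.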
# ---------------------------------------------------------------------------
# Benchmarks
# ---------------------------------------------------------------------------
def bench_prefill(
model: nn.Module,
vocab_size: int,
batch: int,
seq_len: int,
device: torch.device,
n_loops: Optional[int] = None,
) -> tuple[float, float]:
"""Returns (median_seconds, tokens_per_sec)."""
ids = torch.randint(0, vocab_size, (batch, seq_len), device=device)
if isinstance(model, OpenMythos):
def run() -> None:
with torch.no_grad():
model(ids, n_loops=n_loops)
else:
def run() -> None:
with torch.no_grad():
model(ids)
secs = time_fn(run, device)
tps = (batch * seq_len) / secs
return secs, tps
def bench_decode(
model: nn.Module,
vocab_size: int,
batch: int,
prompt_len: int,
decode_steps: int,
device: torch.device,
n_loops: Optional[int] = None,
) -> tuple[float, float]:
"""
Prefill a `prompt_len` prompt, then time `decode_steps` single-token decode
steps with KV cache. Returns (avg_seconds_per_step, decode_tokens_per_sec).
"""
prompt = torch.randint(0, vocab_size, (batch, prompt_len), device=device)
def one_run() -> None:
kv_cache: dict = {}
with torch.no_grad():
if isinstance(model, OpenMythos):
model(prompt, n_loops=n_loops, kv_cache=kv_cache, start_pos=0)
else:
model(prompt, kv_cache=kv_cache, start_pos=0)
for i in range(decode_steps):
next_tok = torch.randint(0, vocab_size, (batch, 1), device=device)
if isinstance(model, OpenMythos):
model(
next_tok,
n_loops=n_loops,
kv_cache=kv_cache,
start_pos=prompt_len + i,
)
else:
model(next_tok, kv_cache=kv_cache, start_pos=prompt_len + i)
secs = time_fn(one_run, device, warmup=1, trials=3)
per_step = secs / decode_steps
tps = batch * decode_steps / secs
return per_step, tps
# ---------------------------------------------------------------------------
# Config helpers
# ---------------------------------------------------------------------------
def small_cfg() -> MythosConfig:
"""Tiny config for smoke tests — runs on CPU in seconds."""
return MythosConfig(
vocab_size=1024,
dim=256,
n_heads=8,
n_kv_heads=2,
max_seq_len=1024,
max_loop_iters=4,
prelude_layers=1,
coda_layers=1,
attn_type="gqa",
n_experts=8,
n_shared_experts=1,
n_experts_per_tok=2,
expert_dim=128,
lora_rank=4,
dropout=0.0,
)
def get_cfg(size: str) -> MythosConfig:
size = size.lower()
if size == "small":
return small_cfg()
if size == "1b":
cfg = mythos_1b()
# GQA for apples-to-apples; MLA changes KV shape semantics.
cfg.attn_type = "gqa"
return cfg
raise ValueError(f"unknown size: {size!r} (use 'small' or '1b')")
# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------
def fmt_count(n: int) -> str:
for unit in ("", "K", "M", "B", "T"):
if abs(n) < 1000:
return f"{n:.2f}{unit}"
n /= 1000
return f"{n:.2f}P"
def print_header(title: str) -> None:
bar = "=" * 72
print(f"\n{bar}\n{title}\n{bar}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__.splitlines()[0])
p.add_argument("--size", default="small", choices=["small", "1b"])
p.add_argument(
"--device",
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cuda", "cpu"],
)
p.add_argument(
"--dtype",
default="auto",
choices=["auto", "fp32", "bf16", "fp16"],
help="'auto' picks fp32 on CPU and bf16 on CUDA",
)
p.add_argument("--batch", type=int, default=1)
p.add_argument(
"--seq-lens",
default="128,512",
help="comma-separated prefill sequence lengths",
)
p.add_argument(
"--n-loops",
default="1,4,8",
help="comma-separated loop counts to sweep (OpenMythos only)",
)
p.add_argument(
"--decode-steps",
type=int,
default=32,
help="number of autoregressive decode steps after prefill",
)
p.add_argument(
"--decode-prompt-len",
type=int,
default=128,
help="prefill length before decode",
)
return p.parse_args()
def main() -> None:
args = parse_args()
device = torch.device(args.device)
dtype_arg = args.dtype
if dtype_arg == "auto":
dtype_arg = "bf16" if device.type == "cuda" else "fp32"
dtype = {
"fp32": torch.float32,
"bf16": torch.bfloat16,
"fp16": torch.float16,
}[dtype_arg]
seq_lens = [int(s) for s in args.seq_lens.split(",") if s.strip()]
n_loops_sweep = [int(s) for s in args.n_loops.split(",") if s.strip()]
cfg = get_cfg(args.size)
print_header(f"Config: size={args.size} device={device} dtype={dtype_arg}")
print(
f" dim={cfg.dim} n_heads={cfg.n_heads} n_kv_heads={cfg.n_kv_heads} "
f"prelude={cfg.prelude_layers} coda={cfg.coda_layers} "
f"max_loop_iters={cfg.max_loop_iters}\n"
f" experts={cfg.n_experts} shared={cfg.n_shared_experts} "
f"top_k={cfg.n_experts_per_tok} expert_dim={cfg.expert_dim}"
)
# Build models. Baseline depth = prelude + 1 (one unique recurrent block) + coda
# to match the unique-parameter depth of OpenMythos (parameter-matched baseline).
baseline_n_layers = cfg.prelude_layers + 1 + cfg.coda_layers
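# For example, with the small config (prelude_layers=1, coda_layers=1,
# max_loop_iters=4): baseline_n_layers = 3 unique layers, while OpenMythos
# runs prelude + n_loops + coda = 6 blocks at n_loops=4, reusing the single
# recurrent block's weights on every loop.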
torch.manual_seed(0)
mythos = OpenMythos(cfg).to(device=device, dtype=dtype).eval()
torch.manual_seed(0)
baseline = (
BaselineTransformer(cfg, n_layers=baseline_n_layers)
.to(device=device, dtype=dtype)
.eval()
)
m_params = count_params(mythos, cfg)
b_params = count_params(baseline, cfg)
print_header(
"Parameters (block-matched: baseline depth = prelude + 1 recurrent + coda)"
)
print(
f" OpenMythos : total={fmt_count(m_params.total):>10} "
f"active/tok≈{fmt_count(m_params.moe_active_est):>10}"
)
print(
f" Baseline : total={fmt_count(b_params.total):>10} "
f"active/tok≈{fmt_count(b_params.moe_active_est):>10}"
)
print(
f" Baseline unique layers = {baseline_n_layers} "
f"(Mythos total runtime depth at max_loops = "
f"{cfg.prelude_layers + cfg.max_loop_iters + cfg.coda_layers})"
)
# ---- Prefill ----
print_header("Prefill latency (batch={batch})".format(batch=args.batch))
header = f" {'model':<26} {'seq':>6} {'sec':>10} {'tok/s':>12} {'peak MB':>10}"
print(header)
for seq_len in seq_lens:
if seq_len > cfg.max_seq_len:
print(f" skip seq_len={seq_len} (> max_seq_len={cfg.max_seq_len})")
continue
reset_mem(device)
secs, tps = bench_prefill(baseline, cfg.vocab_size, args.batch, seq_len, device)
mem = peak_mem_mb(device)
print(
f" {'Baseline (stacked)':<26} {seq_len:>6} "
f"{secs*1000:>9.2f}ms {tps:>12,.0f} {mem:>10.1f}"
)
for nl in n_loops_sweep:
reset_mem(device)
secs, tps = bench_prefill(
mythos, cfg.vocab_size, args.batch, seq_len, device, n_loops=nl
)
mem = peak_mem_mb(device)
print(
f" {'OpenMythos (loops=' + str(nl) + ')':<26} {seq_len:>6} "
f"{secs*1000:>9.2f}ms {tps:>12,.0f} {mem:>10.1f}"
)
# ---- Decode ----
print_header(
f"Decode latency (prefill {args.decode_prompt_len} tokens + "
f"{args.decode_steps} decode steps, batch={args.batch})"
)
print(f" {'model':<26} {'sec/step':>12} {'decode tok/s':>14}")
reset_mem(device)
per_step, tps = bench_decode(
baseline,
cfg.vocab_size,
args.batch,
args.decode_prompt_len,
args.decode_steps,
device,
)
print(f" {'Baseline (stacked)':<26} {per_step*1000:>10.2f}ms {tps:>14,.1f}")
for nl in n_loops_sweep:
reset_mem(device)
per_step, tps = bench_decode(
mythos,
cfg.vocab_size,
args.batch,
args.decode_prompt_len,
args.decode_steps,
device,
n_loops=nl,
)
print(
f" {'OpenMythos (loops=' + str(nl) + ')':<26} "
f"{per_step*1000:>10.2f}ms {tps:>14,.1f}"
)
# ---- Depth scaling ----
print_header(
"OpenMythos depth scaling (fixed seq={}, batch={})".format(
seq_lens[0], args.batch
)
)
print(f" {'n_loops':>8} {'sec':>10} {'tok/s':>12} {'Δ vs loops=1':>14}")
base_secs = None
for nl in n_loops_sweep:
reset_mem(device)
secs, tps = bench_prefill(
mythos, cfg.vocab_size, args.batch, seq_lens[0], device, n_loops=nl
)
if base_secs is None:
base_secs = secs
delta = "1.00x"
else:
delta = f"{secs / base_secs:.2f}x"
print(f" {nl:>8} {secs*1000:>9.2f}ms {tps:>12,.0f} {delta:>14}")
print("\nDone.")
if __name__ == "__main__":
main()

tests/small_benchmark.py

@@ -0,0 +1,536 @@
#!/usr/bin/env python3
"""
Side-by-side training of OpenMythos vs. a vanilla GQA transformer on a small
HuggingFace dataset (wikitext-2 by default).
Both models share the same tiny config and see the exact same batches in the
same order, so per-step loss + throughput are directly comparable. The baseline
is a dense GQA + SwiGLU stack whose unique-layer depth matches the recurrent
block's unique-parameter depth (prelude + 1 + coda), so parameter counts land
in the same ballpark.
python tests/small_benchmark.py
python tests/small_benchmark.py --steps 500 --device cuda
"""
from __future__ import annotations
import argparse
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Deque
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from open_mythos import MythosConfig, OpenMythos
from open_mythos.main import (
RMSNorm,
TransformerBlock,
precompute_rope_freqs,
)
# ---------------------------------------------------------------------------
# Baseline: dense GQA + SwiGLU transformer
# ---------------------------------------------------------------------------
class BaselineTransformer(nn.Module):
"""Vanilla decoder-only transformer with dense SwiGLU FFNs.
Reuses OpenMythos's TransformerBlock (attention + FFN kernels are identical)
so any measured delta reflects the looped recurrent-depth architecture, not
kernel differences. Supports both attn_type="gqa" and "mla".
"""
def __init__(self, cfg: MythosConfig, n_layers: int):
super().__init__()
self.cfg = cfg
self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)
self.layers = nn.ModuleList(
[TransformerBlock(cfg, use_moe=False) for _ in range(n_layers)]
)
self.norm = RMSNorm(cfg.dim)
self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
self.head.weight = self.embed.weight # weight tying
# MLA applies RoPE to qk_rope_head_dim only; GQA rotates the full head_dim.
rope_dim = (
cfg.qk_rope_head_dim if cfg.attn_type == "mla" else cfg.dim // cfg.n_heads
)
self.register_buffer(
"freqs_cis",
precompute_rope_freqs(rope_dim, cfg.max_seq_len, cfg.rope_theta),
persistent=False,
)
self._init_weights()
def _init_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, std=0.02)
@staticmethod
def _causal_mask(T: int, device: torch.device) -> torch.Tensor:
mask = torch.full((1, 1, T, T), float("-inf"), device=device)
return torch.triu(mask, diagonal=1)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
T = input_ids.shape[1]
x = self.embed(input_ids)
freqs_cis = self.freqs_cis[:T]
mask = self._causal_mask(T, x.device) if T > 1 else None
for i, layer in enumerate(self.layers):
x = layer(x, freqs_cis, mask, cache_key=f"layer_{i}")
return self.head(self.norm(x))
# ---------------------------------------------------------------------------
# Dataset: tokenize once, pack into fixed-length next-token pairs
# ---------------------------------------------------------------------------
class PackedLMDataset(Dataset):
"""Flatten an HF text dataset into one token buffer, slice fixed-length pairs.
Accepts either map-style or streaming (`IterableDataset`) HF datasets
iteration stops once `max_tokens` are collected, so large corpora like
TinyStories can be streamed without downloading the whole thing.
"""
def __init__(
self,
hf_ds,
tokenizer,
seq_len: int,
max_tokens: int,
text_field: str = "text",
):
buf: list[int] = []
for sample in hf_ds:
text = sample[text_field]
if not text or not text.strip():
continue
buf.extend(tokenizer.encode(text, add_special_tokens=False))
if len(buf) >= max_tokens:
break
self.seq_len = seq_len
n_pairs = max(1, (len(buf) - 1) // seq_len)
buf = buf[: n_pairs * seq_len + 1]
self.data = torch.tensor(buf, dtype=torch.long)
def __len__(self) -> int:
return (len(self.data) - 1) // self.seq_len
def __getitem__(self, idx: int):
s = idx * self.seq_len
chunk = self.data[s : s + self.seq_len + 1]
return chunk[:-1], chunk[1:]
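# Illustrative only: with seq_len=256, item i covers tokens [i*256, i*256 + 257),
# returned as 256 inputs plus the same 256 positions shifted by one as targets,
# so a ~5,000,000-token buffer yields len(dataset) ≈ 19,531 non-overlapping pairs.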
# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
@dataclass
class Metrics:
total_loss: float = 0.0
total_tokens: int = 0
total_time: float = 0.0
steps: int = 0
first_losses: list[float] = field(default_factory=list)
last_losses: Deque[float] = field(default_factory=lambda: deque(maxlen=10))
def update(self, loss: float, tokens: int, seconds: float) -> None:
self.total_loss += loss
self.total_tokens += tokens
self.total_time += seconds
self.steps += 1
if len(self.first_losses) < 10:
self.first_losses.append(loss)
self.last_losses.append(loss)
@property
def avg_loss(self) -> float:
return self.total_loss / max(1, self.steps)
@property
def tok_per_sec(self) -> float:
return self.total_tokens / max(1e-9, self.total_time)
@property
def initial_loss(self) -> float:
return sum(self.first_losses) / max(1, len(self.first_losses))
@property
def final_loss(self) -> float:
return sum(self.last_losses) / max(1, len(self.last_losses))
# ---------------------------------------------------------------------------
# Training step
# ---------------------------------------------------------------------------
def train_step(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
optimizer: torch.optim.Optimizer,
device: torch.device,
vocab_size: int,
) -> tuple[float, float]:
"""Run one optimizer step; return (loss, wall-clock seconds)."""
t0 = time.perf_counter()
model.train()
optimizer.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
if device.type == "cuda":
torch.cuda.synchronize()
return loss.item(), time.perf_counter() - t0
@torch.no_grad()
def evaluate(
model: nn.Module,
loader: DataLoader,
device: torch.device,
vocab_size: int,
max_batches: int | None = None,
n_loops: int | None = None,
) -> float:
"""Mean cross-entropy over (up to `max_batches`) of the loader.
`n_loops` is only forwarded to OpenMythos; for any other module the kwarg
is dropped, so the same function benchmarks baseline and mythos uniformly.
"""
model.eval()
total_loss = 0.0
total_tokens = 0
for i, (x, y) in enumerate(loader):
if max_batches is not None and i >= max_batches:
break
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
if isinstance(model, OpenMythos):
logits = model(x, n_loops=n_loops)
else:
logits = model(x)
# sum-reduction so we weight by token count, not batch count
loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1), reduction="sum")
total_loss += loss.item()
total_tokens += y.numel()
return total_loss / max(1, total_tokens)
# ---------------------------------------------------------------------------
# Config + utilities
# ---------------------------------------------------------------------------
def build_tiny_cfg(vocab_size: int, seq_len: int) -> MythosConfig:
"""Tiny shared config with MLA attention — runs in reasonable time on CPU.
MLA LoRA ranks and head dims scale with `dim=128` instead of the
2048-dim-sized defaults (q_lora_rank=1536, qk_nope_head_dim=128, ...),
which would otherwise dominate the parameter count at this scale.
"""
return MythosConfig(
vocab_size=vocab_size,
dim=128,
n_heads=4,
n_kv_heads=2,
max_seq_len=seq_len,
max_loop_iters=4,
prelude_layers=1,
coda_layers=1,
attn_type="mla",
kv_lora_rank=64,
q_lora_rank=128,
qk_rope_head_dim=16,
qk_nope_head_dim=32,
v_head_dim=32,
n_experts=4,
n_shared_experts=1,
n_experts_per_tok=2,
expert_dim=128,
lora_rank=4,
rope_theta=10000.0,
dropout=0.0,
)
def count_params(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def fmt_count(n: float) -> str:
for unit in ("", "K", "M", "B"):
if abs(n) < 1000:
return f"{n:.2f}{unit}"
n /= 1000
return f"{n:.2f}T"
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__.splitlines()[0])
p.add_argument("--steps", type=int, default=1000)
p.add_argument("--batch-size", type=int, default=32)
p.add_argument("--seq-len", type=int, default=256)
p.add_argument("--lr", type=float, default=3e-4)
# Defaults point at TinyStories — simpler vocabulary + shorter documents
# lets a dim=128 model actually reach a meaningful loss in modest time.
p.add_argument("--dataset", default="roneneldan/TinyStories")
p.add_argument(
"--dataset-config",
default="",
help="pass '' for datasets with no config (e.g. TinyStories)",
)
p.add_argument("--train-split", default="train")
p.add_argument("--eval-split", default="validation")
p.add_argument(
"--train-tokens",
type=int,
default=5_000_000,
help="max tokens to materialize for the training buffer",
)
p.add_argument(
"--eval-tokens",
type=int,
default=200_000,
help="max tokens to materialize for the held-out eval buffer",
)
p.add_argument("--text-field", default="text")
p.add_argument("--tokenizer", default="gpt2")
p.add_argument("--log-every", type=int, default=25)
p.add_argument(
"--eval-every",
type=int,
default=200,
help="run held-out eval every N steps (0 disables)",
)
p.add_argument("--eval-batches", type=int, default=20)
p.add_argument(
"--depth-sweep",
default="1,2,4,8,16",
help="comma-separated n_loops values for OpenMythos depth-extrapolation eval",
)
p.add_argument("--seed", type=int, default=0)
p.add_argument(
"--device",
default="cuda" if torch.cuda.is_available() else "cpu",
)
return p.parse_args()
def load_text_ds(name: str, config: str, split: str):
"""Streaming `load_dataset` with optional config (empty string == no config)."""
if config:
return load_dataset(name, config, split=split, streaming=True)
return load_dataset(name, split=split, streaming=True)
def main() -> None:
args = parse_args()
device = torch.device(args.device)
print(
f"[setup] device={device} batch={args.batch_size} "
f"seq_len={args.seq_len} steps={args.steps}"
)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
# AutoTokenizer.vocab_size excludes added tokens, so it can undercount the true
# vocabulary for BPE tokenizers with added tokens; use len(tokenizer) to be safe.
vocab_size = len(tokenizer)
print(f"[setup] tokenizer={args.tokenizer} vocab_size={vocab_size:,}")
# ------------------------------------------------------------------
# Data: streamed train + held-out eval splits
# ------------------------------------------------------------------
print(f"[setup] dataset={args.dataset} config={args.dataset_config or ''}")
raw_train = load_text_ds(args.dataset, args.dataset_config, args.train_split)
train_ds = PackedLMDataset(
raw_train, tokenizer, args.seq_len, args.train_tokens, args.text_field
)
raw_eval = load_text_ds(args.dataset, args.dataset_config, args.eval_split)
eval_ds = PackedLMDataset(
raw_eval, tokenizer, args.seq_len, args.eval_tokens, args.text_field
)
print(
f"[setup] train tokens={train_ds.data.numel():,} pairs={len(train_ds)} | "
f"eval tokens={eval_ds.data.numel():,} pairs={len(eval_ds)}"
)
torch.manual_seed(args.seed)
train_loader = DataLoader(
train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True
)
eval_loader = DataLoader(
eval_ds, batch_size=args.batch_size, shuffle=False, drop_last=False
)
# ------------------------------------------------------------------
# Models — same init seed so both start from the same embedding
# ------------------------------------------------------------------
cfg = build_tiny_cfg(vocab_size, args.seq_len)
torch.manual_seed(args.seed)
mythos = OpenMythos(cfg).to(device)
# Parameter-matched depth: prelude + one unique recurrent block + coda.
baseline_layers = cfg.prelude_layers + 1 + cfg.coda_layers
torch.manual_seed(args.seed)
baseline = BaselineTransformer(cfg, n_layers=baseline_layers).to(device)
n_m, n_b = count_params(mythos), count_params(baseline)
print(
f"[setup] OpenMythos params = {fmt_count(n_m)} ({n_m:,})\n"
f"[setup] Baseline params = {fmt_count(n_b)} ({n_b:,}) "
f"[{baseline_layers} layers]"
)
print(
f"[setup] Mythos runtime depth = prelude({cfg.prelude_layers}) + "
f"loops({cfg.max_loop_iters}) + coda({cfg.coda_layers}) = "
f"{cfg.prelude_layers + cfg.max_loop_iters + cfg.coda_layers}"
)
opt_m = torch.optim.AdamW(
mythos.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1
)
opt_b = torch.optim.AdamW(
baseline.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1
)
mm, bm = Metrics(), Metrics()
eval_history: list[tuple[int, float, float]] = [] # (step, mythos_eval, base_eval)
header = (
f"\n{'step':>6} | {'mythos loss':>12} | {'base loss':>10} | "
f"{'mythos tok/s':>13} | {'base tok/s':>11}"
)
print(header)
print("-" * len(header))
# ------------------------------------------------------------------
# Training loop with periodic held-out eval
# ------------------------------------------------------------------
data_iter = iter(train_loader)
t_total = time.perf_counter()
for step in range(1, args.steps + 1):
try:
x, y = next(data_iter)
except StopIteration:
data_iter = iter(train_loader)
x, y = next(data_iter)
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
tokens = x.numel()
loss_m, dt_m = train_step(mythos, x, y, opt_m, device, vocab_size)
loss_b, dt_b = train_step(baseline, x, y, opt_b, device, vocab_size)
mm.update(loss_m, tokens, dt_m)
bm.update(loss_b, tokens, dt_b)
if step == 1 or step % args.log_every == 0:
print(
f"{step:>6} | {loss_m:>12.4f} | {loss_b:>10.4f} | "
f"{tokens / dt_m:>13,.0f} | {tokens / dt_b:>11,.0f}"
)
if args.eval_every and step % args.eval_every == 0:
eval_m = evaluate(
mythos, eval_loader, device, vocab_size, args.eval_batches
)
eval_b = evaluate(
baseline, eval_loader, device, vocab_size, args.eval_batches
)
eval_history.append((step, eval_m, eval_b))
print(
f" [eval @ step {step}] mythos {eval_m:.4f} baseline {eval_b:.4f} "
f"(Δ = {eval_m - eval_b:+.4f})"
)
total_wall = time.perf_counter() - t_total
# ------------------------------------------------------------------
# Summary
# ------------------------------------------------------------------
bar = "=" * 70
print(f"\n{bar}\nSummary ({args.steps} steps, wall clock {total_wall:.1f}s)\n{bar}")
print(f" {'':<24} {'OpenMythos':>16} {'Baseline':>16}")
print(f" {'params':<24} {fmt_count(n_m):>16} {fmt_count(n_b):>16}")
print(
f" {'initial train (first 10)':<24} "
f"{mm.initial_loss:>16.4f} {bm.initial_loss:>16.4f}"
)
print(
f" {'final train (last 10)':<24} "
f"{mm.final_loss:>16.4f} {bm.final_loss:>16.4f}"
)
print(
f" {'avg train (all steps)':<24} "
f"{mm.avg_loss:>16.4f} {bm.avg_loss:>16.4f}"
)
print(
f" {'train time (sec)':<24} "
f"{mm.total_time:>16.2f} {bm.total_time:>16.2f}"
)
print(
f" {'avg tok/s':<24} " f"{mm.tok_per_sec:>16,.0f} {bm.tok_per_sec:>16,.0f}"
)
print(
f" {'sec/step':<24} "
f"{mm.total_time / max(1, mm.steps):>16.4f} "
f"{bm.total_time / max(1, bm.steps):>16.4f}"
)
# ------------------------------------------------------------------
# Depth extrapolation: OpenMythos eval loss as a function of n_loops.
# Trained at cfg.max_loop_iters; we run inference with a sweep to see
# whether additional loops keep improving (depth extrapolation) or the
# model collapses outside the trained regime.
# ------------------------------------------------------------------
loops_sweep = sorted({int(s) for s in args.depth_sweep.split(",") if s.strip()})
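# e.g. the default "--depth-sweep 1,2,4,8,16" yields loops_sweep = [1, 2, 4, 8, 16];
# with max_loop_iters=4, the 8- and 16-loop rows probe depths beyond the trained regime.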
print(f"\n{bar}\nDepth extrapolation (held-out eval, full eval set)\n{bar}")
baseline_eval = evaluate(baseline, eval_loader, device, vocab_size)
print(f" Baseline (fixed depth) : eval loss = {baseline_eval:.4f}")
# First collect all sweep losses, then print with deltas vs. the trained depth.
sweep: list[tuple[int, float]] = []
for nl in loops_sweep:
sweep.append(
(nl, evaluate(mythos, eval_loader, device, vocab_size, n_loops=nl))
)
trained_loss = next((loss for nl, loss in sweep if nl == cfg.max_loop_iters), None)
print(f" OpenMythos (trained at n_loops={cfg.max_loop_iters}):")
print(f" {'n_loops':>8} {'eval loss':>10} {'Δ vs trained':>14}")
for nl, loss in sweep:
if trained_loss is None or nl == cfg.max_loop_iters:
delta_str = ""
else:
delta_str = f"{loss - trained_loss:+.4f}"
marker = " ←trained" if nl == cfg.max_loop_iters else ""
print(f" {nl:>8} {loss:>10.4f} {delta_str:>14}{marker}")
if __name__ == "__main__":
main()