From 963e11277d516cb48a488fb3f63bdceef779e0cc Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Wed, 22 Apr 2026 12:48:33 -0400 Subject: [PATCH] tiny tests --- README.md | 15 +- open_mythos/main.py | 19 +- pyproject.toml | 2 +- tests/bench_vs_transformer.py | 480 ++++++++++++++++ tests/small_benchmark.py | 536 ++++++++++++++++++ .../variants_example.py | 0 6 files changed, 1046 insertions(+), 6 deletions(-) create mode 100644 tests/bench_vs_transformer.py create mode 100644 tests/small_benchmark.py rename variants_example.py => tests/variants_example.py (100%) diff --git a/README.md b/README.md index 4e8e2df..af2d774 100644 --- a/README.md +++ b/README.md @@ -97,8 +97,9 @@ out = model.generate(ids, max_new_tokens=8, n_loops=8) print(f"[{attn_type.upper()}] Generated shape: {out.shape}") A = model.recurrent.injection.get_A() +rho = torch.linalg.eigvals(A).abs().max().item() print( - f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)" + f"[{attn_type.upper()}] Spectral radius ρ(A) = {rho:.4f} (must be < 1)" ) ``` @@ -217,6 +218,17 @@ The injection of `e` at every step is what prevents the model from drifting — The full implementation is in [`open_mythos/main.py`](open_mythos/main.py). See the [`OpenMythos` class reference](docs/open_mythos.md) for a detailed API walkthrough, configuration options, and usage examples. +### Attention Implementations + +The attention layer is switchable via `cfg.attn_type`: + +| Option | Class | Description | +|---|---|---| +| `"gqa"` | `GQAttention` | Grouped Query Attention (Ainslie et al., 2023) — fewer KV heads than Q heads (`n_kv_heads < n_heads`), reducing KV-cache memory by `n_heads / n_kv_heads`. Uses **Flash Attention 2** (Dao et al., 2023) when `flash-attn>=2.8.3` is installed: GQA is handled natively (no KV head expansion), I/O-bound-optimal, with a transparent fallback to manual scaled dot-product attention when the package is absent. | +| `"mla"` | `MLAttention` | Multi-Latent Attention (DeepSeek-V2) — caches a compressed KV latent (`kv_lora_rank`) rather than full K/V, with split RoPE / no-RoPE head dims for position-aware compression. | + +RoPE is applied to Q and K before caching, so cached values do not need to be re-rotated on retrieval. + --- ## Why This Explains Mythos @@ -371,6 +383,7 @@ Theoretical analysis suggests 2-3x improvements in inference throughput. For a d | Training stability | LTI-constrained injection parameters with spectral radius < 1 | | Loop differentiation | Likely uses loop-index positional embedding (à la RoPE) per iteration | | Halting | Adaptive Computation Time or learned convergence criterion | +| Attention | GQA (with optional Flash Attention 2) or MLA with compressed KV latent cache | | Scaling law | Optimal training scales looping and data together, not parameters alone | | Reasoning vs. 
memory | Structurally biased toward composition; memorization requires separate treatment | | Deployment | Continuous Depth-wise Batching enables variable compute per request | diff --git a/open_mythos/main.py b/open_mythos/main.py index 98ca58c..65b0fa8 100644 --- a/open_mythos/main.py +++ b/open_mythos/main.py @@ -252,7 +252,9 @@ class GQAttention(nn.Module): k = k.to(torch.bfloat16) v = v.to(torch.bfloat16) dropout_p = self.dropout_p if self.training else 0.0 - out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=(mask is not None)) + out = flash_attn_func( + q, k, v, dropout_p=dropout_p, causal=(mask is not None) + ) out = out.to(orig_dtype).contiguous().view(B, T, -1) else: # Fallback: manual scaled dot-product with explicit KV head expansion. @@ -964,18 +966,27 @@ class OpenMythos(nn.Module): nn.init.normal_(m.weight, std=0.02) @staticmethod - def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor: + def _causal_mask( + seq_len: int, device: torch.device, dtype: torch.dtype + ) -> torch.Tensor: """ Build an additive causal mask: 0 on and below the diagonal, -inf above. Args: seq_len -- sequence length device -- target device + dtype -- tensor dtype (must match activation dtype so the additive + mask doesn't upcast the attention logits in the fallback + attention path — e.g. bf16 weights with an fp32 mask + promotes attn to fp32 and then breaks the fp32-vs-bf16 + matmul against V) Returns: Tensor of shape (1, 1, seq_len, seq_len) broadcastable over (B, H, T, S) """ - mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device) + mask = torch.full( + (1, 1, seq_len, seq_len), float("-inf"), device=device, dtype=dtype + ) return torch.triu(mask, diagonal=1) def forward( @@ -1009,7 +1020,7 @@ class OpenMythos(nn.Module): freqs_cis = ( self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis )[start_pos : start_pos + T] - mask = self._causal_mask(T, device) if T > 1 else None + mask = self._causal_mask(T, device, x.dtype) if T > 1 else None for i, layer in enumerate(self.prelude): x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}") diff --git a/pyproject.toml b/pyproject.toml index ef6a3b6..8129e90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "open-mythos" -version = "0.4.0" +version = "0.5.0" description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture" license = "MIT" authors = ["Kye Gomez "] diff --git a/tests/bench_vs_transformer.py b/tests/bench_vs_transformer.py new file mode 100644 index 0000000..e929aa6 --- /dev/null +++ b/tests/bench_vs_transformer.py @@ -0,0 +1,480 @@ +""" +OpenMythos vs. vanilla GQA+MoE transformer benchmark. + +Compares OpenMythos (Prelude + looped Recurrent Block + Coda with ACT halting, +LTI-stable injection, LoRA depth adapter) against a parameter-matched vanilla +transformer built from the same GQAttention + MoEFFN building blocks stacked +non-recurrently. The baseline reuses OpenMythos primitives so the comparison +isolates the recurrent-depth architecture, not the kernels. + +Metrics reported: + - Parameter counts (total, MoE-active approximation) + - Prefill latency + throughput at several sequence lengths + - Decode (autoregressive step) latency with KV cache + - Peak memory (CUDA only) + - OpenMythos depth-scaling sweep: latency vs. 
n_loops + +Run: + python benchmarks/bench_vs_transformer.py # small CPU/GPU smoke test + python benchmarks/bench_vs_transformer.py --size 1b --device cuda + python benchmarks/bench_vs_transformer.py --seq-lens 128,512,2048 --n-loops 1,4,8,16 +""" + +from __future__ import annotations + +import argparse +import gc +import statistics +import time +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from open_mythos import MythosConfig, OpenMythos, mythos_1b +from open_mythos.main import ( + RMSNorm, + TransformerBlock, + precompute_rope_freqs, +) + + +# --------------------------------------------------------------------------- +# Baseline: non-looped GQA + MoE transformer +# --------------------------------------------------------------------------- + + +class BaselineTransformer(nn.Module): + """ + Vanilla decoder-only transformer with GQA attention and MoE FFNs, stacked + non-recurrently. Shares TransformerBlock / GQAttention / MoEFFN kernels + with OpenMythos so any speed delta is attributable to the recurrent-depth + architecture rather than the underlying attention/FFN implementation. + """ + + def __init__(self, cfg: MythosConfig, n_layers: int): + super().__init__() + self.cfg = cfg + self.n_layers = n_layers + self.embed = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [TransformerBlock(cfg, use_moe=True) for _ in range(n_layers)] + ) + self.norm = RMSNorm(cfg.dim) + self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + + head_dim = cfg.dim // cfg.n_heads + self.register_buffer( + "freqs_cis", + precompute_rope_freqs(head_dim, cfg.max_seq_len, cfg.rope_theta), + persistent=False, + ) + + @staticmethod + def _causal_mask(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + mask = torch.full((1, 1, T, T), float("-inf"), device=device, dtype=dtype) + return torch.triu(mask, diagonal=1) + + def forward( + self, + input_ids: torch.Tensor, + kv_cache: Optional[dict] = None, + start_pos: int = 0, + ) -> torch.Tensor: + T = input_ids.shape[1] + x = self.embed(input_ids) + freqs_cis = self.freqs_cis[start_pos : start_pos + T] + mask = self._causal_mask(T, x.device, x.dtype) if T > 1 else None + for i, layer in enumerate(self.layers): + x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"layer_{i}") + return self.head(self.norm(x)) + + +# --------------------------------------------------------------------------- +# Timing utilities +# --------------------------------------------------------------------------- + + +def _sync(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize() + + +def time_fn(fn, device: torch.device, warmup: int = 2, trials: int = 5) -> float: + """Returns median wall-clock seconds over `trials` after `warmup` runs.""" + for _ in range(warmup): + fn() + _sync(device) + times = [] + for _ in range(trials): + _sync(device) + t0 = time.perf_counter() + fn() + _sync(device) + times.append(time.perf_counter() - t0) + return statistics.median(times) + + +def peak_mem_mb(device: torch.device) -> float: + if device.type != "cuda": + return 0.0 + return torch.cuda.max_memory_allocated() / (1024 * 1024) + + +def reset_mem(device: torch.device) -> None: + gc.collect() + if device.type == "cuda": + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + +# --------------------------------------------------------------------------- +# Parameter counting +# --------------------------------------------------------------------------- + + 
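+# Worked example of the "active params" estimate in count_params() below, using
+# small_cfg(): active_ratio = (1 shared + 2 routed) / (1 shared + 8 experts)
+# = 3/9 ≈ 0.33, i.e. roughly a third of the expert-FFN parameters count as
+# active per token, while attention/embedding/head parameters always count in full.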
+@dataclass +class ParamCounts: + total: int + moe_active_est: int # active per token (shared + top-k routed) + + +def count_params(model: nn.Module, cfg: MythosConfig) -> ParamCounts: + total = sum(p.numel() for p in model.parameters()) + # Rough active-per-token count for MoE layers: shared + top-k routed fraction. + # For simplicity we report total and an estimated activation ratio separately. + active_ratio = (cfg.n_shared_experts + cfg.n_experts_per_tok) / ( + cfg.n_shared_experts + cfg.n_experts + ) + # Only FFN parameters shrink under activation; attention + embed/head are always on. + # This is a coarse lower bound on active params. + ffn_params = 0 + other_params = 0 + for name, p in model.named_parameters(): + if ".ffn." in name or name.startswith("ffn.") or ".experts." in name: + ffn_params += p.numel() + else: + other_params += p.numel() + active_est = other_params + int(ffn_params * active_ratio) + return ParamCounts(total=total, moe_active_est=active_est) + + +# --------------------------------------------------------------------------- +# Benchmarks +# --------------------------------------------------------------------------- + + +def bench_prefill( + model: nn.Module, + vocab_size: int, + batch: int, + seq_len: int, + device: torch.device, + n_loops: Optional[int] = None, +) -> tuple[float, float]: + """Returns (median_seconds, tokens_per_sec).""" + ids = torch.randint(0, vocab_size, (batch, seq_len), device=device) + + if isinstance(model, OpenMythos): + + def run() -> None: + with torch.no_grad(): + model(ids, n_loops=n_loops) + + else: + + def run() -> None: + with torch.no_grad(): + model(ids) + + secs = time_fn(run, device) + tps = (batch * seq_len) / secs + return secs, tps + + +def bench_decode( + model: nn.Module, + vocab_size: int, + batch: int, + prompt_len: int, + decode_steps: int, + device: torch.device, + n_loops: Optional[int] = None, +) -> tuple[float, float]: + """ + Prefill a `prompt_len` prompt, then time `decode_steps` single-token decode + steps with KV cache. Returns (avg_seconds_per_step, decode_tokens_per_sec). 
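+
+    Note: the timed region covers the prefill plus all decode steps, so the
+    per-step figure carries an amortized share of the prefill cost rather than
+    pure decode latency.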
+ """ + prompt = torch.randint(0, vocab_size, (batch, prompt_len), device=device) + + def one_run() -> None: + kv_cache: dict = {} + with torch.no_grad(): + if isinstance(model, OpenMythos): + model(prompt, n_loops=n_loops, kv_cache=kv_cache, start_pos=0) + else: + model(prompt, kv_cache=kv_cache, start_pos=0) + for i in range(decode_steps): + next_tok = torch.randint(0, vocab_size, (batch, 1), device=device) + if isinstance(model, OpenMythos): + model( + next_tok, + n_loops=n_loops, + kv_cache=kv_cache, + start_pos=prompt_len + i, + ) + else: + model(next_tok, kv_cache=kv_cache, start_pos=prompt_len + i) + + secs = time_fn(one_run, device, warmup=1, trials=3) + per_step = secs / decode_steps + tps = batch * decode_steps / secs + return per_step, tps + + +# --------------------------------------------------------------------------- +# Config helpers +# --------------------------------------------------------------------------- + + +def small_cfg() -> MythosConfig: + """Tiny config for smoke tests — runs on CPU in seconds.""" + return MythosConfig( + vocab_size=1024, + dim=256, + n_heads=8, + n_kv_heads=2, + max_seq_len=1024, + max_loop_iters=4, + prelude_layers=1, + coda_layers=1, + attn_type="gqa", + n_experts=8, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=128, + lora_rank=4, + dropout=0.0, + ) + + +def get_cfg(size: str) -> MythosConfig: + size = size.lower() + if size == "small": + return small_cfg() + if size == "1b": + cfg = mythos_1b() + # GQA for apples-to-apples; MLA changes KV shape semantics. + cfg.attn_type = "gqa" + return cfg + raise ValueError(f"unknown size: {size!r} (use 'small' or '1b')") + + +# --------------------------------------------------------------------------- +# Reporting +# --------------------------------------------------------------------------- + + +def fmt_count(n: int) -> str: + for unit in ("", "K", "M", "B", "T"): + if abs(n) < 1000: + return f"{n:.2f}{unit}" + n /= 1000 + return f"{n:.2f}P" + + +def print_header(title: str) -> None: + bar = "=" * 72 + print(f"\n{bar}\n{title}\n{bar}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + p.add_argument("--size", default="small", choices=["small", "1b"]) + p.add_argument( + "--device", + default="cuda" if torch.cuda.is_available() else "cpu", + choices=["cuda", "cpu"], + ) + p.add_argument( + "--dtype", + default="auto", + choices=["auto", "fp32", "bf16", "fp16"], + help="'auto' picks fp32 on CPU and bf16 on CUDA", + ) + p.add_argument("--batch", type=int, default=1) + p.add_argument( + "--seq-lens", + default="128,512", + help="comma-separated prefill sequence lengths", + ) + p.add_argument( + "--n-loops", + default="1,4,8", + help="comma-separated loop counts to sweep (OpenMythos only)", + ) + p.add_argument( + "--decode-steps", + type=int, + default=32, + help="number of autoregressive decode steps after prefill", + ) + p.add_argument( + "--decode-prompt-len", + type=int, + default=128, + help="prefill length before decode", + ) + return p.parse_args() + + +def main() -> None: + args = parse_args() + device = torch.device(args.device) + dtype_arg = args.dtype + if dtype_arg == "auto": + dtype_arg = "bf16" if device.type == "cuda" else "fp32" + dtype = { + "fp32": torch.float32, + "bf16": torch.bfloat16, + "fp16": torch.float16, + }[dtype_arg] + + seq_lens = 
[int(s) for s in args.seq_lens.split(",") if s.strip()] + n_loops_sweep = [int(s) for s in args.n_loops.split(",") if s.strip()] + + cfg = get_cfg(args.size) + print_header(f"Config: size={args.size} device={device} dtype={dtype_arg}") + print( + f" dim={cfg.dim} n_heads={cfg.n_heads} n_kv_heads={cfg.n_kv_heads} " + f"prelude={cfg.prelude_layers} coda={cfg.coda_layers} " + f"max_loop_iters={cfg.max_loop_iters}\n" + f" experts={cfg.n_experts} shared={cfg.n_shared_experts} " + f"top_k={cfg.n_experts_per_tok} expert_dim={cfg.expert_dim}" + ) + + # Build models. Baseline depth = prelude + 1 (one unique recurrent block) + coda + # to match the unique-parameter depth of OpenMythos (parameter-matched baseline). + baseline_n_layers = cfg.prelude_layers + 1 + cfg.coda_layers + + torch.manual_seed(0) + mythos = OpenMythos(cfg).to(device=device, dtype=dtype).eval() + torch.manual_seed(0) + baseline = ( + BaselineTransformer(cfg, n_layers=baseline_n_layers) + .to(device=device, dtype=dtype) + .eval() + ) + + m_params = count_params(mythos, cfg) + b_params = count_params(baseline, cfg) + print_header( + "Parameters (block-matched: baseline depth = prelude + 1 recurrent + coda)" + ) + print( + f" OpenMythos : total={fmt_count(m_params.total):>10} " + f"active/tok≈{fmt_count(m_params.moe_active_est):>10}" + ) + print( + f" Baseline : total={fmt_count(b_params.total):>10} " + f"active/tok≈{fmt_count(b_params.moe_active_est):>10}" + ) + print( + f" Baseline unique layers = {baseline_n_layers} " + f"(Mythos total runtime depth at max_loops = " + f"{cfg.prelude_layers + cfg.max_loop_iters + cfg.coda_layers})" + ) + + # ---- Prefill ---- + print_header("Prefill latency (batch={batch})".format(batch=args.batch)) + header = f" {'model':<26} {'seq':>6} {'sec':>10} {'tok/s':>12} {'peak MB':>10}" + print(header) + for seq_len in seq_lens: + if seq_len > cfg.max_seq_len: + print(f" skip seq_len={seq_len} (> max_seq_len={cfg.max_seq_len})") + continue + + reset_mem(device) + secs, tps = bench_prefill(baseline, cfg.vocab_size, args.batch, seq_len, device) + mem = peak_mem_mb(device) + print( + f" {'Baseline (stacked)':<26} {seq_len:>6} " + f"{secs*1000:>9.2f}ms {tps:>12,.0f} {mem:>10.1f}" + ) + + for nl in n_loops_sweep: + reset_mem(device) + secs, tps = bench_prefill( + mythos, cfg.vocab_size, args.batch, seq_len, device, n_loops=nl + ) + mem = peak_mem_mb(device) + print( + f" {'OpenMythos (loops=' + str(nl) + ')':<26} {seq_len:>6} " + f"{secs*1000:>9.2f}ms {tps:>12,.0f} {mem:>10.1f}" + ) + + # ---- Decode ---- + print_header( + f"Decode latency (prefill {args.decode_prompt_len} tokens + " + f"{args.decode_steps} decode steps, batch={args.batch})" + ) + print(f" {'model':<26} {'sec/step':>12} {'decode tok/s':>14}") + + reset_mem(device) + per_step, tps = bench_decode( + baseline, + cfg.vocab_size, + args.batch, + args.decode_prompt_len, + args.decode_steps, + device, + ) + print(f" {'Baseline (stacked)':<26} {per_step*1000:>10.2f}ms {tps:>14,.1f}") + + for nl in n_loops_sweep: + reset_mem(device) + per_step, tps = bench_decode( + mythos, + cfg.vocab_size, + args.batch, + args.decode_prompt_len, + args.decode_steps, + device, + n_loops=nl, + ) + print( + f" {'OpenMythos (loops=' + str(nl) + ')':<26} " + f"{per_step*1000:>10.2f}ms {tps:>14,.1f}" + ) + + # ---- Depth scaling ---- + print_header( + "OpenMythos depth scaling (fixed seq={}, batch={})".format( + seq_lens[0], args.batch + ) + ) + print(f" {'n_loops':>8} {'sec':>10} {'tok/s':>12} {'Δ vs loops=1':>14}") + base_secs = None + for nl in n_loops_sweep: + 
reset_mem(device) + secs, tps = bench_prefill( + mythos, cfg.vocab_size, args.batch, seq_lens[0], device, n_loops=nl + ) + if base_secs is None: + base_secs = secs + delta = "1.00x" + else: + delta = f"{secs / base_secs:.2f}x" + print(f" {nl:>8} {secs*1000:>9.2f}ms {tps:>12,.0f} {delta:>14}") + + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/tests/small_benchmark.py b/tests/small_benchmark.py new file mode 100644 index 0000000..c4e7df7 --- /dev/null +++ b/tests/small_benchmark.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python3 +""" +Side-by-side training of OpenMythos vs. a vanilla GQA transformer on a small +HuggingFace dataset (wikitext-2 by default). + +Both models share the same tiny config and see the exact same batches in the +same order, so per-step loss + throughput are directly comparable. The baseline +is a dense GQA + SwiGLU stack whose unique-layer depth matches the recurrent +block's unique-parameter depth (prelude + 1 + coda), so parameter counts land +in the same ballpark. + + python training/small_benchmark.py + python training/small_benchmark.py --steps 500 --device cuda +""" + +from __future__ import annotations + +import argparse +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Deque + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import load_dataset +from torch.utils.data import DataLoader, Dataset +from transformers import AutoTokenizer + +from open_mythos import MythosConfig, OpenMythos +from open_mythos.main import ( + RMSNorm, + TransformerBlock, + precompute_rope_freqs, +) + + +# --------------------------------------------------------------------------- +# Baseline: dense GQA + SwiGLU transformer +# --------------------------------------------------------------------------- + + +class BaselineTransformer(nn.Module): + """Vanilla decoder-only transformer with dense SwiGLU FFNs. + + Reuses OpenMythos's TransformerBlock (attention + FFN kernels are identical) + so any measured delta reflects the looped recurrent-depth architecture, not + kernel differences. Supports both attn_type="gqa" and "mla". + """ + + def __init__(self, cfg: MythosConfig, n_layers: int): + super().__init__() + self.cfg = cfg + self.embed = nn.Embedding(cfg.vocab_size, cfg.dim) + self.layers = nn.ModuleList( + [TransformerBlock(cfg, use_moe=False) for _ in range(n_layers)] + ) + self.norm = RMSNorm(cfg.dim) + self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) + self.head.weight = self.embed.weight # weight tying + + # MLA applies RoPE to qk_rope_head_dim only; GQA rotates the full head_dim. 
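+        # With build_tiny_cfg() below this resolves to qk_rope_head_dim = 16;
+        # a GQA config at this size would instead use dim // n_heads = 128 // 4 = 32.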
+ rope_dim = ( + cfg.qk_rope_head_dim if cfg.attn_type == "mla" else cfg.dim // cfg.n_heads + ) + self.register_buffer( + "freqs_cis", + precompute_rope_freqs(rope_dim, cfg.max_seq_len, cfg.rope_theta), + persistent=False, + ) + self._init_weights() + + def _init_weights(self) -> None: + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + + @staticmethod + def _causal_mask(T: int, device: torch.device) -> torch.Tensor: + mask = torch.full((1, 1, T, T), float("-inf"), device=device) + return torch.triu(mask, diagonal=1) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + T = input_ids.shape[1] + x = self.embed(input_ids) + freqs_cis = self.freqs_cis[:T] + mask = self._causal_mask(T, x.device) if T > 1 else None + for i, layer in enumerate(self.layers): + x = layer(x, freqs_cis, mask, cache_key=f"layer_{i}") + return self.head(self.norm(x)) + + +# --------------------------------------------------------------------------- +# Dataset: tokenize once, pack into fixed-length next-token pairs +# --------------------------------------------------------------------------- + + +class PackedLMDataset(Dataset): + """Flatten an HF text dataset into one token buffer, slice fixed-length pairs. + + Accepts either map-style or streaming (`IterableDataset`) HF datasets — + iteration stops once `max_tokens` are collected, so large corpora like + TinyStories can be streamed without downloading the whole thing. + """ + + def __init__( + self, + hf_ds, + tokenizer, + seq_len: int, + max_tokens: int, + text_field: str = "text", + ): + buf: list[int] = [] + for sample in hf_ds: + text = sample[text_field] + if not text or not text.strip(): + continue + buf.extend(tokenizer.encode(text, add_special_tokens=False)) + if len(buf) >= max_tokens: + break + self.seq_len = seq_len + n_pairs = max(1, (len(buf) - 1) // seq_len) + buf = buf[: n_pairs * seq_len + 1] + self.data = torch.tensor(buf, dtype=torch.long) + + def __len__(self) -> int: + return (len(self.data) - 1) // self.seq_len + + def __getitem__(self, idx: int): + s = idx * self.seq_len + chunk = self.data[s : s + self.seq_len + 1] + return chunk[:-1], chunk[1:] + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + + +@dataclass +class Metrics: + total_loss: float = 0.0 + total_tokens: int = 0 + total_time: float = 0.0 + steps: int = 0 + first_losses: list[float] = field(default_factory=list) + last_losses: Deque[float] = field(default_factory=lambda: deque(maxlen=10)) + + def update(self, loss: float, tokens: int, seconds: float) -> None: + self.total_loss += loss + self.total_tokens += tokens + self.total_time += seconds + self.steps += 1 + if len(self.first_losses) < 10: + self.first_losses.append(loss) + self.last_losses.append(loss) + + @property + def avg_loss(self) -> float: + return self.total_loss / max(1, self.steps) + + @property + def tok_per_sec(self) -> float: + return self.total_tokens / max(1e-9, self.total_time) + + @property + def initial_loss(self) -> float: + return sum(self.first_losses) / max(1, len(self.first_losses)) + + @property + def final_loss(self) -> float: + return sum(self.last_losses) / max(1, len(self.last_losses)) + + +# --------------------------------------------------------------------------- +# Training step +# 
--------------------------------------------------------------------------- + + +def train_step( + model: nn.Module, + x: torch.Tensor, + y: torch.Tensor, + optimizer: torch.optim.Optimizer, + device: torch.device, + vocab_size: int, +) -> tuple[float, float]: + """Run one optimizer step; return (loss, wall-clock seconds).""" + t0 = time.perf_counter() + model.train() + optimizer.zero_grad() + logits = model(x) + loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1)) + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + if device.type == "cuda": + torch.cuda.synchronize() + return loss.item(), time.perf_counter() - t0 + + +@torch.no_grad() +def evaluate( + model: nn.Module, + loader: DataLoader, + device: torch.device, + vocab_size: int, + max_batches: int | None = None, + n_loops: int | None = None, +) -> float: + """Mean cross-entropy over (up to `max_batches`) of the loader. + + `n_loops` is only forwarded to OpenMythos; for any other module the kwarg + is dropped, so the same function benchmarks baseline and mythos uniformly. + """ + model.eval() + total_loss = 0.0 + total_tokens = 0 + for i, (x, y) in enumerate(loader): + if max_batches is not None and i >= max_batches: + break + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + if isinstance(model, OpenMythos): + logits = model(x, n_loops=n_loops) + else: + logits = model(x) + # sum-reduction so we weight by token count, not batch count + loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1), reduction="sum") + total_loss += loss.item() + total_tokens += y.numel() + return total_loss / max(1, total_tokens) + + +# --------------------------------------------------------------------------- +# Config + utilities +# --------------------------------------------------------------------------- + + +def build_tiny_cfg(vocab_size: int, seq_len: int) -> MythosConfig: + """Tiny shared config with MLA attention — runs in reasonable time on CPU. + + MLA LoRA ranks and head dims scale with `dim=128` instead of the + 2048-dim-sized defaults (q_lora_rank=1536, qk_nope_head_dim=128, ...), + which would otherwise dominate the parameter count at this scale. + """ + return MythosConfig( + vocab_size=vocab_size, + dim=128, + n_heads=4, + n_kv_heads=2, + max_seq_len=seq_len, + max_loop_iters=4, + prelude_layers=1, + coda_layers=1, + attn_type="mla", + kv_lora_rank=64, + q_lora_rank=128, + qk_rope_head_dim=16, + qk_nope_head_dim=32, + v_head_dim=32, + n_experts=4, + n_shared_experts=1, + n_experts_per_tok=2, + expert_dim=128, + lora_rank=4, + rope_theta=10000.0, + dropout=0.0, + ) + + +def count_params(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def fmt_count(n: float) -> str: + for unit in ("", "K", "M", "B"): + if abs(n) < 1000: + return f"{n:.2f}{unit}" + n /= 1000 + return f"{n:.2f}T" + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + p.add_argument("--steps", type=int, default=1000) + p.add_argument("--batch-size", type=int, default=32) + p.add_argument("--seq-len", type=int, default=256) + p.add_argument("--lr", type=float, default=3e-4) + # Defaults point at TinyStories — simpler vocabulary + shorter documents + # lets a dim=128 model actually reach a meaningful loss in modest time. 
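+    # To try wikitext-2 (mentioned in the module docstring) something like
+    # `--dataset wikitext --dataset-config wikitext-2-raw-v1` should work with
+    # the same streaming loader (config name assumed, not pinned by this script).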
+ p.add_argument("--dataset", default="roneneldan/TinyStories") + p.add_argument( + "--dataset-config", + default="", + help="pass '' for datasets with no config (e.g. TinyStories)", + ) + p.add_argument("--train-split", default="train") + p.add_argument("--eval-split", default="validation") + p.add_argument( + "--train-tokens", + type=int, + default=5_000_000, + help="max tokens to materialize for the training buffer", + ) + p.add_argument( + "--eval-tokens", + type=int, + default=200_000, + help="max tokens to materialize for the held-out eval buffer", + ) + p.add_argument("--text-field", default="text") + p.add_argument("--tokenizer", default="gpt2") + p.add_argument("--log-every", type=int, default=25) + p.add_argument( + "--eval-every", + type=int, + default=200, + help="run held-out eval every N steps (0 disables)", + ) + p.add_argument("--eval-batches", type=int, default=20) + p.add_argument( + "--depth-sweep", + default="1,2,4,8,16", + help="comma-separated n_loops values for OpenMythos depth-extrapolation eval", + ) + p.add_argument("--seed", type=int, default=0) + p.add_argument( + "--device", + default="cuda" if torch.cuda.is_available() else "cpu", + ) + return p.parse_args() + + +def load_text_ds(name: str, config: str, split: str): + """Streaming `load_dataset` with optional config (empty string == no config).""" + if config: + return load_dataset(name, config, split=split, streaming=True) + return load_dataset(name, split=split, streaming=True) + + +def main() -> None: + args = parse_args() + device = torch.device(args.device) + + print( + f"[setup] device={device} batch={args.batch_size} " + f"seq_len={args.seq_len} steps={args.steps}" + ) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + # AutoTokenizer.vocab_size can be smaller than the head size for BPE + # tokenizers with added tokens; use len(tokenizer) to be safe. + vocab_size = len(tokenizer) + print(f"[setup] tokenizer={args.tokenizer} vocab_size={vocab_size:,}") + + # ------------------------------------------------------------------ + # Data: streamed train + held-out eval splits + # ------------------------------------------------------------------ + print(f"[setup] dataset={args.dataset} config={args.dataset_config or '∅'}") + raw_train = load_text_ds(args.dataset, args.dataset_config, args.train_split) + train_ds = PackedLMDataset( + raw_train, tokenizer, args.seq_len, args.train_tokens, args.text_field + ) + raw_eval = load_text_ds(args.dataset, args.dataset_config, args.eval_split) + eval_ds = PackedLMDataset( + raw_eval, tokenizer, args.seq_len, args.eval_tokens, args.text_field + ) + print( + f"[setup] train tokens={train_ds.data.numel():,} pairs={len(train_ds)} | " + f"eval tokens={eval_ds.data.numel():,} pairs={len(eval_ds)}" + ) + + torch.manual_seed(args.seed) + train_loader = DataLoader( + train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True + ) + eval_loader = DataLoader( + eval_ds, batch_size=args.batch_size, shuffle=False, drop_last=False + ) + + # ------------------------------------------------------------------ + # Models — same init seed so both start from the same embedding + # ------------------------------------------------------------------ + cfg = build_tiny_cfg(vocab_size, args.seq_len) + + torch.manual_seed(args.seed) + mythos = OpenMythos(cfg).to(device) + + # Parameter-matched depth: prelude + one unique recurrent block + coda. 
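+    # With build_tiny_cfg(): 1 (prelude) + 1 (recurrent block) + 1 (coda) = 3
+    # baseline layers, vs. an OpenMythos runtime depth of 1 + 4 loops + 1 = 6
+    # at max_loop_iters (matching the depth printout below).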
+ baseline_layers = cfg.prelude_layers + 1 + cfg.coda_layers + torch.manual_seed(args.seed) + baseline = BaselineTransformer(cfg, n_layers=baseline_layers).to(device) + + n_m, n_b = count_params(mythos), count_params(baseline) + print( + f"[setup] OpenMythos params = {fmt_count(n_m)} ({n_m:,})\n" + f"[setup] Baseline params = {fmt_count(n_b)} ({n_b:,}) " + f"[{baseline_layers} layers]" + ) + print( + f"[setup] Mythos runtime depth = prelude({cfg.prelude_layers}) + " + f"loops({cfg.max_loop_iters}) + coda({cfg.coda_layers}) = " + f"{cfg.prelude_layers + cfg.max_loop_iters + cfg.coda_layers}" + ) + + opt_m = torch.optim.AdamW( + mythos.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1 + ) + opt_b = torch.optim.AdamW( + baseline.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1 + ) + + mm, bm = Metrics(), Metrics() + eval_history: list[tuple[int, float, float]] = [] # (step, mythos_eval, base_eval) + + header = ( + f"\n{'step':>6} | {'mythos loss':>12} | {'base loss':>10} | " + f"{'mythos tok/s':>13} | {'base tok/s':>11}" + ) + print(header) + print("-" * len(header)) + + # ------------------------------------------------------------------ + # Training loop with periodic held-out eval + # ------------------------------------------------------------------ + data_iter = iter(train_loader) + t_total = time.perf_counter() + for step in range(1, args.steps + 1): + try: + x, y = next(data_iter) + except StopIteration: + data_iter = iter(train_loader) + x, y = next(data_iter) + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + tokens = x.numel() + + loss_m, dt_m = train_step(mythos, x, y, opt_m, device, vocab_size) + loss_b, dt_b = train_step(baseline, x, y, opt_b, device, vocab_size) + + mm.update(loss_m, tokens, dt_m) + bm.update(loss_b, tokens, dt_b) + + if step == 1 or step % args.log_every == 0: + print( + f"{step:>6} | {loss_m:>12.4f} | {loss_b:>10.4f} | " + f"{tokens / dt_m:>13,.0f} | {tokens / dt_b:>11,.0f}" + ) + + if args.eval_every and step % args.eval_every == 0: + eval_m = evaluate( + mythos, eval_loader, device, vocab_size, args.eval_batches + ) + eval_b = evaluate( + baseline, eval_loader, device, vocab_size, args.eval_batches + ) + eval_history.append((step, eval_m, eval_b)) + print( + f" [eval @ step {step}] mythos {eval_m:.4f} baseline {eval_b:.4f} " + f"(Δ = {eval_m - eval_b:+.4f})" + ) + + total_wall = time.perf_counter() - t_total + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + bar = "=" * 70 + print(f"\n{bar}\nSummary ({args.steps} steps, wall clock {total_wall:.1f}s)\n{bar}") + print(f" {'':<24} {'OpenMythos':>16} {'Baseline':>16}") + print(f" {'params':<24} {fmt_count(n_m):>16} {fmt_count(n_b):>16}") + print( + f" {'initial train (first 10)':<24} " + f"{mm.initial_loss:>16.4f} {bm.initial_loss:>16.4f}" + ) + print( + f" {'final train (last 10)':<24} " + f"{mm.final_loss:>16.4f} {bm.final_loss:>16.4f}" + ) + print( + f" {'avg train (all steps)':<24} " + f"{mm.avg_loss:>16.4f} {bm.avg_loss:>16.4f}" + ) + print( + f" {'train time (sec)':<24} " + f"{mm.total_time:>16.2f} {bm.total_time:>16.2f}" + ) + print( + f" {'avg tok/s':<24} " f"{mm.tok_per_sec:>16,.0f} {bm.tok_per_sec:>16,.0f}" + ) + print( + f" {'sec/step':<24} " + f"{mm.total_time / max(1, mm.steps):>16.4f} " + f"{bm.total_time / max(1, bm.steps):>16.4f}" + ) + + # ------------------------------------------------------------------ + # Depth extrapolation: 
OpenMythos eval loss as a function of n_loops. + # Trained at cfg.max_loop_iters; we run inference with a sweep to see + # whether additional loops keep improving (depth extrapolation) or the + # model collapses outside the trained regime. + # ------------------------------------------------------------------ + loops_sweep = sorted({int(s) for s in args.depth_sweep.split(",") if s.strip()}) + print(f"\n{bar}\nDepth extrapolation (held-out eval, full eval set)\n{bar}") + baseline_eval = evaluate(baseline, eval_loader, device, vocab_size) + print(f" Baseline (fixed depth) : eval loss = {baseline_eval:.4f}") + # First collect all sweep losses, then print with deltas vs. the trained depth. + sweep: list[tuple[int, float]] = [] + for nl in loops_sweep: + sweep.append( + (nl, evaluate(mythos, eval_loader, device, vocab_size, n_loops=nl)) + ) + trained_loss = next((loss for nl, loss in sweep if nl == cfg.max_loop_iters), None) + print(f" OpenMythos (trained at n_loops={cfg.max_loop_iters}):") + print(f" {'n_loops':>8} {'eval loss':>10} {'Δ vs trained':>14}") + for nl, loss in sweep: + if trained_loss is None or nl == cfg.max_loop_iters: + delta_str = "" + else: + delta_str = f"{loss - trained_loss:+.4f}" + marker = " ←trained" if nl == cfg.max_loop_iters else "" + print(f" {nl:>8} {loss:>10.4f} {delta_str:>14}{marker}") + + +if __name__ == "__main__": + main() diff --git a/variants_example.py b/tests/variants_example.py similarity index 100% rename from variants_example.py rename to tests/variants_example.py