mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
tiny tests
This commit is contained in:
parent
7d78ebec79
commit
963e11277d
15
README.md
15
README.md
@ -97,8 +97,9 @@ out = model.generate(ids, max_new_tokens=8, n_loops=8)
|
|||||||
print(f"[{attn_type.upper()}] Generated shape: {out.shape}")
|
print(f"[{attn_type.upper()}] Generated shape: {out.shape}")
|
||||||
|
|
||||||
A = model.recurrent.injection.get_A()
|
A = model.recurrent.injection.get_A()
|
||||||
|
rho = torch.linalg.eigvals(A).abs().max().item()
|
||||||
print(
|
print(
|
||||||
f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)"
|
f"[{attn_type.upper()}] Spectral radius ρ(A) = {rho:.4f} (must be < 1)"
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -217,6 +218,17 @@ The injection of `e` at every step is what prevents the model from drifting —
|
|||||||
|
|
||||||
The full implementation is in [`open_mythos/main.py`](open_mythos/main.py). See the [`OpenMythos` class reference](docs/open_mythos.md) for a detailed API walkthrough, configuration options, and usage examples.
|
The full implementation is in [`open_mythos/main.py`](open_mythos/main.py). See the [`OpenMythos` class reference](docs/open_mythos.md) for a detailed API walkthrough, configuration options, and usage examples.
|
||||||
|
|
||||||
|
### Attention Implementations
|
||||||
|
|
||||||
|
The attention layer is switchable via `cfg.attn_type`:
|
||||||
|
|
||||||
|
| Option | Class | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `"gqa"` | `GQAttention` | Grouped Query Attention (Ainslie et al., 2023) — fewer KV heads than Q heads (`n_kv_heads < n_heads`), reducing KV-cache memory by `n_heads / n_kv_heads`. Uses **Flash Attention 2** (Dao et al., 2023) when `flash-attn>=2.8.3` is installed: GQA is handled natively (no KV head expansion), I/O-bound-optimal, with a transparent fallback to manual scaled dot-product attention when the package is absent. |
|
||||||
|
| `"mla"` | `MLAttention` | Multi-Latent Attention (DeepSeek-V2) — caches a compressed KV latent (`kv_lora_rank`) rather than full K/V, with split RoPE / no-RoPE head dims for position-aware compression. |
|
||||||
|
|
||||||
|
RoPE is applied to Q and K before caching, so cached values do not need to be re-rotated on retrieval.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Why This Explains Mythos
|
## Why This Explains Mythos
|
||||||
@ -371,6 +383,7 @@ Theoretical analysis suggests 2-3x improvements in inference throughput. For a d
|
|||||||
| Training stability | LTI-constrained injection parameters with spectral radius < 1 |
|
| Training stability | LTI-constrained injection parameters with spectral radius < 1 |
|
||||||
| Loop differentiation | Likely uses loop-index positional embedding (à la RoPE) per iteration |
|
| Loop differentiation | Likely uses loop-index positional embedding (à la RoPE) per iteration |
|
||||||
| Halting | Adaptive Computation Time or learned convergence criterion |
|
| Halting | Adaptive Computation Time or learned convergence criterion |
|
||||||
|
| Attention | GQA (with optional Flash Attention 2) or MLA with compressed KV latent cache |
|
||||||
| Scaling law | Optimal training scales looping and data together, not parameters alone |
|
| Scaling law | Optimal training scales looping and data together, not parameters alone |
|
||||||
| Reasoning vs. memory | Structurally biased toward composition; memorization requires separate treatment |
|
| Reasoning vs. memory | Structurally biased toward composition; memorization requires separate treatment |
|
||||||
| Deployment | Continuous Depth-wise Batching enables variable compute per request |
|
| Deployment | Continuous Depth-wise Batching enables variable compute per request |
|
||||||
|
|||||||
@ -252,7 +252,9 @@ class GQAttention(nn.Module):
|
|||||||
k = k.to(torch.bfloat16)
|
k = k.to(torch.bfloat16)
|
||||||
v = v.to(torch.bfloat16)
|
v = v.to(torch.bfloat16)
|
||||||
dropout_p = self.dropout_p if self.training else 0.0
|
dropout_p = self.dropout_p if self.training else 0.0
|
||||||
out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=(mask is not None))
|
out = flash_attn_func(
|
||||||
|
q, k, v, dropout_p=dropout_p, causal=(mask is not None)
|
||||||
|
)
|
||||||
out = out.to(orig_dtype).contiguous().view(B, T, -1)
|
out = out.to(orig_dtype).contiguous().view(B, T, -1)
|
||||||
else:
|
else:
|
||||||
# Fallback: manual scaled dot-product with explicit KV head expansion.
|
# Fallback: manual scaled dot-product with explicit KV head expansion.
|
||||||
@ -964,18 +966,27 @@ class OpenMythos(nn.Module):
|
|||||||
nn.init.normal_(m.weight, std=0.02)
|
nn.init.normal_(m.weight, std=0.02)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
|
def _causal_mask(
|
||||||
|
seq_len: int, device: torch.device, dtype: torch.dtype
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Build an additive causal mask: 0 on and below the diagonal, -inf above.
|
Build an additive causal mask: 0 on and below the diagonal, -inf above.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seq_len -- sequence length
|
seq_len -- sequence length
|
||||||
device -- target device
|
device -- target device
|
||||||
|
dtype -- tensor dtype (must match activation dtype so the additive
|
||||||
|
mask doesn't upcast the attention logits in the fallback
|
||||||
|
attention path — e.g. bf16 weights with an fp32 mask
|
||||||
|
promotes attn to fp32 and then breaks the fp32-vs-bf16
|
||||||
|
matmul against V)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor of shape (1, 1, seq_len, seq_len) broadcastable over (B, H, T, S)
|
Tensor of shape (1, 1, seq_len, seq_len) broadcastable over (B, H, T, S)
|
||||||
"""
|
"""
|
||||||
mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device)
|
mask = torch.full(
|
||||||
|
(1, 1, seq_len, seq_len), float("-inf"), device=device, dtype=dtype
|
||||||
|
)
|
||||||
return torch.triu(mask, diagonal=1)
|
return torch.triu(mask, diagonal=1)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -1009,7 +1020,7 @@ class OpenMythos(nn.Module):
|
|||||||
freqs_cis = (
|
freqs_cis = (
|
||||||
self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis
|
self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis
|
||||||
)[start_pos : start_pos + T]
|
)[start_pos : start_pos + T]
|
||||||
mask = self._causal_mask(T, device) if T > 1 else None
|
mask = self._causal_mask(T, device, x.dtype) if T > 1 else None
|
||||||
|
|
||||||
for i, layer in enumerate(self.prelude):
|
for i, layer in enumerate(self.prelude):
|
||||||
x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}")
|
x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"prelude_{i}")
|
||||||
|
|||||||
@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
|
|||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "open-mythos"
|
name = "open-mythos"
|
||||||
version = "0.4.0"
|
version = "0.5.0"
|
||||||
description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture"
|
description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
authors = ["Kye Gomez <kye@swarms.world>"]
|
authors = ["Kye Gomez <kye@swarms.world>"]
|
||||||
|
|||||||
480
tests/bench_vs_transformer.py
Normal file
480
tests/bench_vs_transformer.py
Normal file
@ -0,0 +1,480 @@
|
|||||||
|
"""
|
||||||
|
OpenMythos vs. vanilla GQA+MoE transformer benchmark.
|
||||||
|
|
||||||
|
Compares OpenMythos (Prelude + looped Recurrent Block + Coda with ACT halting,
|
||||||
|
LTI-stable injection, LoRA depth adapter) against a parameter-matched vanilla
|
||||||
|
transformer built from the same GQAttention + MoEFFN building blocks stacked
|
||||||
|
non-recurrently. The baseline reuses OpenMythos primitives so the comparison
|
||||||
|
isolates the recurrent-depth architecture, not the kernels.
|
||||||
|
|
||||||
|
Metrics reported:
|
||||||
|
- Parameter counts (total, MoE-active approximation)
|
||||||
|
- Prefill latency + throughput at several sequence lengths
|
||||||
|
- Decode (autoregressive step) latency with KV cache
|
||||||
|
- Peak memory (CUDA only)
|
||||||
|
- OpenMythos depth-scaling sweep: latency vs. n_loops
|
||||||
|
|
||||||
|
Run:
|
||||||
|
python benchmarks/bench_vs_transformer.py # small CPU/GPU smoke test
|
||||||
|
python benchmarks/bench_vs_transformer.py --size 1b --device cuda
|
||||||
|
python benchmarks/bench_vs_transformer.py --seq-lens 128,512,2048 --n-loops 1,4,8,16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import statistics
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from open_mythos import MythosConfig, OpenMythos, mythos_1b
|
||||||
|
from open_mythos.main import (
|
||||||
|
RMSNorm,
|
||||||
|
TransformerBlock,
|
||||||
|
precompute_rope_freqs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Baseline: non-looped GQA + MoE transformer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BaselineTransformer(nn.Module):
|
||||||
|
"""
|
||||||
|
Vanilla decoder-only transformer with GQA attention and MoE FFNs, stacked
|
||||||
|
non-recurrently. Shares TransformerBlock / GQAttention / MoEFFN kernels
|
||||||
|
with OpenMythos so any speed delta is attributable to the recurrent-depth
|
||||||
|
architecture rather than the underlying attention/FFN implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg: MythosConfig, n_layers: int):
|
||||||
|
super().__init__()
|
||||||
|
self.cfg = cfg
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[TransformerBlock(cfg, use_moe=True) for _ in range(n_layers)]
|
||||||
|
)
|
||||||
|
self.norm = RMSNorm(cfg.dim)
|
||||||
|
self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
|
||||||
|
|
||||||
|
head_dim = cfg.dim // cfg.n_heads
|
||||||
|
self.register_buffer(
|
||||||
|
"freqs_cis",
|
||||||
|
precompute_rope_freqs(head_dim, cfg.max_seq_len, cfg.rope_theta),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _causal_mask(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
mask = torch.full((1, 1, T, T), float("-inf"), device=device, dtype=dtype)
|
||||||
|
return torch.triu(mask, diagonal=1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
kv_cache: Optional[dict] = None,
|
||||||
|
start_pos: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
T = input_ids.shape[1]
|
||||||
|
x = self.embed(input_ids)
|
||||||
|
freqs_cis = self.freqs_cis[start_pos : start_pos + T]
|
||||||
|
mask = self._causal_mask(T, x.device, x.dtype) if T > 1 else None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"layer_{i}")
|
||||||
|
return self.head(self.norm(x))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Timing utilities
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _sync(device: torch.device) -> None:
|
||||||
|
if device.type == "cuda":
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def time_fn(fn, device: torch.device, warmup: int = 2, trials: int = 5) -> float:
|
||||||
|
"""Returns median wall-clock seconds over `trials` after `warmup` runs."""
|
||||||
|
for _ in range(warmup):
|
||||||
|
fn()
|
||||||
|
_sync(device)
|
||||||
|
times = []
|
||||||
|
for _ in range(trials):
|
||||||
|
_sync(device)
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
fn()
|
||||||
|
_sync(device)
|
||||||
|
times.append(time.perf_counter() - t0)
|
||||||
|
return statistics.median(times)
|
||||||
|
|
||||||
|
|
||||||
|
def peak_mem_mb(device: torch.device) -> float:
|
||||||
|
if device.type != "cuda":
|
||||||
|
return 0.0
|
||||||
|
return torch.cuda.max_memory_allocated() / (1024 * 1024)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_mem(device: torch.device) -> None:
|
||||||
|
gc.collect()
|
||||||
|
if device.type == "cuda":
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parameter counting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ParamCounts:
|
||||||
|
total: int
|
||||||
|
moe_active_est: int # active per token (shared + top-k routed)
|
||||||
|
|
||||||
|
|
||||||
|
def count_params(model: nn.Module, cfg: MythosConfig) -> ParamCounts:
|
||||||
|
total = sum(p.numel() for p in model.parameters())
|
||||||
|
# Rough active-per-token count for MoE layers: shared + top-k routed fraction.
|
||||||
|
# For simplicity we report total and an estimated activation ratio separately.
|
||||||
|
active_ratio = (cfg.n_shared_experts + cfg.n_experts_per_tok) / (
|
||||||
|
cfg.n_shared_experts + cfg.n_experts
|
||||||
|
)
|
||||||
|
# Only FFN parameters shrink under activation; attention + embed/head are always on.
|
||||||
|
# This is a coarse lower bound on active params.
|
||||||
|
ffn_params = 0
|
||||||
|
other_params = 0
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
if ".ffn." in name or name.startswith("ffn.") or ".experts." in name:
|
||||||
|
ffn_params += p.numel()
|
||||||
|
else:
|
||||||
|
other_params += p.numel()
|
||||||
|
active_est = other_params + int(ffn_params * active_ratio)
|
||||||
|
return ParamCounts(total=total, moe_active_est=active_est)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Benchmarks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def bench_prefill(
|
||||||
|
model: nn.Module,
|
||||||
|
vocab_size: int,
|
||||||
|
batch: int,
|
||||||
|
seq_len: int,
|
||||||
|
device: torch.device,
|
||||||
|
n_loops: Optional[int] = None,
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
"""Returns (median_seconds, tokens_per_sec)."""
|
||||||
|
ids = torch.randint(0, vocab_size, (batch, seq_len), device=device)
|
||||||
|
|
||||||
|
if isinstance(model, OpenMythos):
|
||||||
|
|
||||||
|
def run() -> None:
|
||||||
|
with torch.no_grad():
|
||||||
|
model(ids, n_loops=n_loops)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def run() -> None:
|
||||||
|
with torch.no_grad():
|
||||||
|
model(ids)
|
||||||
|
|
||||||
|
secs = time_fn(run, device)
|
||||||
|
tps = (batch * seq_len) / secs
|
||||||
|
return secs, tps
|
||||||
|
|
||||||
|
|
||||||
|
def bench_decode(
|
||||||
|
model: nn.Module,
|
||||||
|
vocab_size: int,
|
||||||
|
batch: int,
|
||||||
|
prompt_len: int,
|
||||||
|
decode_steps: int,
|
||||||
|
device: torch.device,
|
||||||
|
n_loops: Optional[int] = None,
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Prefill a `prompt_len` prompt, then time `decode_steps` single-token decode
|
||||||
|
steps with KV cache. Returns (avg_seconds_per_step, decode_tokens_per_sec).
|
||||||
|
"""
|
||||||
|
prompt = torch.randint(0, vocab_size, (batch, prompt_len), device=device)
|
||||||
|
|
||||||
|
def one_run() -> None:
|
||||||
|
kv_cache: dict = {}
|
||||||
|
with torch.no_grad():
|
||||||
|
if isinstance(model, OpenMythos):
|
||||||
|
model(prompt, n_loops=n_loops, kv_cache=kv_cache, start_pos=0)
|
||||||
|
else:
|
||||||
|
model(prompt, kv_cache=kv_cache, start_pos=0)
|
||||||
|
for i in range(decode_steps):
|
||||||
|
next_tok = torch.randint(0, vocab_size, (batch, 1), device=device)
|
||||||
|
if isinstance(model, OpenMythos):
|
||||||
|
model(
|
||||||
|
next_tok,
|
||||||
|
n_loops=n_loops,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
start_pos=prompt_len + i,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model(next_tok, kv_cache=kv_cache, start_pos=prompt_len + i)
|
||||||
|
|
||||||
|
secs = time_fn(one_run, device, warmup=1, trials=3)
|
||||||
|
per_step = secs / decode_steps
|
||||||
|
tps = batch * decode_steps / secs
|
||||||
|
return per_step, tps
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def small_cfg() -> MythosConfig:
|
||||||
|
"""Tiny config for smoke tests — runs on CPU in seconds."""
|
||||||
|
return MythosConfig(
|
||||||
|
vocab_size=1024,
|
||||||
|
dim=256,
|
||||||
|
n_heads=8,
|
||||||
|
n_kv_heads=2,
|
||||||
|
max_seq_len=1024,
|
||||||
|
max_loop_iters=4,
|
||||||
|
prelude_layers=1,
|
||||||
|
coda_layers=1,
|
||||||
|
attn_type="gqa",
|
||||||
|
n_experts=8,
|
||||||
|
n_shared_experts=1,
|
||||||
|
n_experts_per_tok=2,
|
||||||
|
expert_dim=128,
|
||||||
|
lora_rank=4,
|
||||||
|
dropout=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cfg(size: str) -> MythosConfig:
|
||||||
|
size = size.lower()
|
||||||
|
if size == "small":
|
||||||
|
return small_cfg()
|
||||||
|
if size == "1b":
|
||||||
|
cfg = mythos_1b()
|
||||||
|
# GQA for apples-to-apples; MLA changes KV shape semantics.
|
||||||
|
cfg.attn_type = "gqa"
|
||||||
|
return cfg
|
||||||
|
raise ValueError(f"unknown size: {size!r} (use 'small' or '1b')")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reporting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def fmt_count(n: int) -> str:
|
||||||
|
for unit in ("", "K", "M", "B", "T"):
|
||||||
|
if abs(n) < 1000:
|
||||||
|
return f"{n:.2f}{unit}"
|
||||||
|
n /= 1000
|
||||||
|
return f"{n:.2f}P"
|
||||||
|
|
||||||
|
|
||||||
|
def print_header(title: str) -> None:
|
||||||
|
bar = "=" * 72
|
||||||
|
print(f"\n{bar}\n{title}\n{bar}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
p = argparse.ArgumentParser(description=__doc__.splitlines()[0])
|
||||||
|
p.add_argument("--size", default="small", choices=["small", "1b"])
|
||||||
|
p.add_argument(
|
||||||
|
"--device",
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
choices=["cuda", "cpu"],
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
default="auto",
|
||||||
|
choices=["auto", "fp32", "bf16", "fp16"],
|
||||||
|
help="'auto' picks fp32 on CPU and bf16 on CUDA",
|
||||||
|
)
|
||||||
|
p.add_argument("--batch", type=int, default=1)
|
||||||
|
p.add_argument(
|
||||||
|
"--seq-lens",
|
||||||
|
default="128,512",
|
||||||
|
help="comma-separated prefill sequence lengths",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--n-loops",
|
||||||
|
default="1,4,8",
|
||||||
|
help="comma-separated loop counts to sweep (OpenMythos only)",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--decode-steps",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="number of autoregressive decode steps after prefill",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--decode-prompt-len",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help="prefill length before decode",
|
||||||
|
)
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
device = torch.device(args.device)
|
||||||
|
dtype_arg = args.dtype
|
||||||
|
if dtype_arg == "auto":
|
||||||
|
dtype_arg = "bf16" if device.type == "cuda" else "fp32"
|
||||||
|
dtype = {
|
||||||
|
"fp32": torch.float32,
|
||||||
|
"bf16": torch.bfloat16,
|
||||||
|
"fp16": torch.float16,
|
||||||
|
}[dtype_arg]
|
||||||
|
|
||||||
|
seq_lens = [int(s) for s in args.seq_lens.split(",") if s.strip()]
|
||||||
|
n_loops_sweep = [int(s) for s in args.n_loops.split(",") if s.strip()]
|
||||||
|
|
||||||
|
cfg = get_cfg(args.size)
|
||||||
|
print_header(f"Config: size={args.size} device={device} dtype={dtype_arg}")
|
||||||
|
print(
|
||||||
|
f" dim={cfg.dim} n_heads={cfg.n_heads} n_kv_heads={cfg.n_kv_heads} "
|
||||||
|
f"prelude={cfg.prelude_layers} coda={cfg.coda_layers} "
|
||||||
|
f"max_loop_iters={cfg.max_loop_iters}\n"
|
||||||
|
f" experts={cfg.n_experts} shared={cfg.n_shared_experts} "
|
||||||
|
f"top_k={cfg.n_experts_per_tok} expert_dim={cfg.expert_dim}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build models. Baseline depth = prelude + 1 (one unique recurrent block) + coda
|
||||||
|
# to match the unique-parameter depth of OpenMythos (parameter-matched baseline).
|
||||||
|
baseline_n_layers = cfg.prelude_layers + 1 + cfg.coda_layers
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
mythos = OpenMythos(cfg).to(device=device, dtype=dtype).eval()
|
||||||
|
torch.manual_seed(0)
|
||||||
|
baseline = (
|
||||||
|
BaselineTransformer(cfg, n_layers=baseline_n_layers)
|
||||||
|
.to(device=device, dtype=dtype)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
m_params = count_params(mythos, cfg)
|
||||||
|
b_params = count_params(baseline, cfg)
|
||||||
|
print_header(
|
||||||
|
"Parameters (block-matched: baseline depth = prelude + 1 recurrent + coda)"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" OpenMythos : total={fmt_count(m_params.total):>10} "
|
||||||
|
f"active/tok≈{fmt_count(m_params.moe_active_est):>10}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Baseline : total={fmt_count(b_params.total):>10} "
|
||||||
|
f"active/tok≈{fmt_count(b_params.moe_active_est):>10}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Baseline unique layers = {baseline_n_layers} "
|
||||||
|
f"(Mythos total runtime depth at max_loops = "
|
||||||
|
f"{cfg.prelude_layers + cfg.max_loop_iters + cfg.coda_layers})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Prefill ----
|
||||||
|
print_header("Prefill latency (batch={batch})".format(batch=args.batch))
|
||||||
|
header = f" {'model':<26} {'seq':>6} {'sec':>10} {'tok/s':>12} {'peak MB':>10}"
|
||||||
|
print(header)
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
if seq_len > cfg.max_seq_len:
|
||||||
|
print(f" skip seq_len={seq_len} (> max_seq_len={cfg.max_seq_len})")
|
||||||
|
continue
|
||||||
|
|
||||||
|
reset_mem(device)
|
||||||
|
secs, tps = bench_prefill(baseline, cfg.vocab_size, args.batch, seq_len, device)
|
||||||
|
mem = peak_mem_mb(device)
|
||||||
|
print(
|
||||||
|
f" {'Baseline (stacked)':<26} {seq_len:>6} "
|
||||||
|
f"{secs*1000:>9.2f}ms {tps:>12,.0f} {mem:>10.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for nl in n_loops_sweep:
|
||||||
|
reset_mem(device)
|
||||||
|
secs, tps = bench_prefill(
|
||||||
|
mythos, cfg.vocab_size, args.batch, seq_len, device, n_loops=nl
|
||||||
|
)
|
||||||
|
mem = peak_mem_mb(device)
|
||||||
|
print(
|
||||||
|
f" {'OpenMythos (loops=' + str(nl) + ')':<26} {seq_len:>6} "
|
||||||
|
f"{secs*1000:>9.2f}ms {tps:>12,.0f} {mem:>10.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Decode ----
|
||||||
|
print_header(
|
||||||
|
f"Decode latency (prefill {args.decode_prompt_len} tokens + "
|
||||||
|
f"{args.decode_steps} decode steps, batch={args.batch})"
|
||||||
|
)
|
||||||
|
print(f" {'model':<26} {'sec/step':>12} {'decode tok/s':>14}")
|
||||||
|
|
||||||
|
reset_mem(device)
|
||||||
|
per_step, tps = bench_decode(
|
||||||
|
baseline,
|
||||||
|
cfg.vocab_size,
|
||||||
|
args.batch,
|
||||||
|
args.decode_prompt_len,
|
||||||
|
args.decode_steps,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
print(f" {'Baseline (stacked)':<26} {per_step*1000:>10.2f}ms {tps:>14,.1f}")
|
||||||
|
|
||||||
|
for nl in n_loops_sweep:
|
||||||
|
reset_mem(device)
|
||||||
|
per_step, tps = bench_decode(
|
||||||
|
mythos,
|
||||||
|
cfg.vocab_size,
|
||||||
|
args.batch,
|
||||||
|
args.decode_prompt_len,
|
||||||
|
args.decode_steps,
|
||||||
|
device,
|
||||||
|
n_loops=nl,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" {'OpenMythos (loops=' + str(nl) + ')':<26} "
|
||||||
|
f"{per_step*1000:>10.2f}ms {tps:>14,.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Depth scaling ----
|
||||||
|
print_header(
|
||||||
|
"OpenMythos depth scaling (fixed seq={}, batch={})".format(
|
||||||
|
seq_lens[0], args.batch
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f" {'n_loops':>8} {'sec':>10} {'tok/s':>12} {'Δ vs loops=1':>14}")
|
||||||
|
base_secs = None
|
||||||
|
for nl in n_loops_sweep:
|
||||||
|
reset_mem(device)
|
||||||
|
secs, tps = bench_prefill(
|
||||||
|
mythos, cfg.vocab_size, args.batch, seq_lens[0], device, n_loops=nl
|
||||||
|
)
|
||||||
|
if base_secs is None:
|
||||||
|
base_secs = secs
|
||||||
|
delta = "1.00x"
|
||||||
|
else:
|
||||||
|
delta = f"{secs / base_secs:.2f}x"
|
||||||
|
print(f" {nl:>8} {secs*1000:>9.2f}ms {tps:>12,.0f} {delta:>14}")
|
||||||
|
|
||||||
|
print("\nDone.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
536
tests/small_benchmark.py
Normal file
536
tests/small_benchmark.py
Normal file
@ -0,0 +1,536 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Side-by-side training of OpenMythos vs. a vanilla GQA transformer on a small
|
||||||
|
HuggingFace dataset (wikitext-2 by default).
|
||||||
|
|
||||||
|
Both models share the same tiny config and see the exact same batches in the
|
||||||
|
same order, so per-step loss + throughput are directly comparable. The baseline
|
||||||
|
is a dense GQA + SwiGLU stack whose unique-layer depth matches the recurrent
|
||||||
|
block's unique-parameter depth (prelude + 1 + coda), so parameter counts land
|
||||||
|
in the same ballpark.
|
||||||
|
|
||||||
|
python training/small_benchmark.py
|
||||||
|
python training/small_benchmark.py --steps 500 --device cuda
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Deque
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from open_mythos import MythosConfig, OpenMythos
|
||||||
|
from open_mythos.main import (
|
||||||
|
RMSNorm,
|
||||||
|
TransformerBlock,
|
||||||
|
precompute_rope_freqs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Baseline: dense GQA + SwiGLU transformer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BaselineTransformer(nn.Module):
|
||||||
|
"""Vanilla decoder-only transformer with dense SwiGLU FFNs.
|
||||||
|
|
||||||
|
Reuses OpenMythos's TransformerBlock (attention + FFN kernels are identical)
|
||||||
|
so any measured delta reflects the looped recurrent-depth architecture, not
|
||||||
|
kernel differences. Supports both attn_type="gqa" and "mla".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg: MythosConfig, n_layers: int):
|
||||||
|
super().__init__()
|
||||||
|
self.cfg = cfg
|
||||||
|
self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[TransformerBlock(cfg, use_moe=False) for _ in range(n_layers)]
|
||||||
|
)
|
||||||
|
self.norm = RMSNorm(cfg.dim)
|
||||||
|
self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
|
||||||
|
self.head.weight = self.embed.weight # weight tying
|
||||||
|
|
||||||
|
# MLA applies RoPE to qk_rope_head_dim only; GQA rotates the full head_dim.
|
||||||
|
rope_dim = (
|
||||||
|
cfg.qk_rope_head_dim if cfg.attn_type == "mla" else cfg.dim // cfg.n_heads
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"freqs_cis",
|
||||||
|
precompute_rope_freqs(rope_dim, cfg.max_seq_len, cfg.rope_theta),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self) -> None:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
nn.init.normal_(m.weight, std=0.02)
|
||||||
|
elif isinstance(m, nn.Embedding):
|
||||||
|
nn.init.normal_(m.weight, std=0.02)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _causal_mask(T: int, device: torch.device) -> torch.Tensor:
|
||||||
|
mask = torch.full((1, 1, T, T), float("-inf"), device=device)
|
||||||
|
return torch.triu(mask, diagonal=1)
|
||||||
|
|
||||||
|
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
T = input_ids.shape[1]
|
||||||
|
x = self.embed(input_ids)
|
||||||
|
freqs_cis = self.freqs_cis[:T]
|
||||||
|
mask = self._causal_mask(T, x.device) if T > 1 else None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
x = layer(x, freqs_cis, mask, cache_key=f"layer_{i}")
|
||||||
|
return self.head(self.norm(x))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dataset: tokenize once, pack into fixed-length next-token pairs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class PackedLMDataset(Dataset):
|
||||||
|
"""Flatten an HF text dataset into one token buffer, slice fixed-length pairs.
|
||||||
|
|
||||||
|
Accepts either map-style or streaming (`IterableDataset`) HF datasets —
|
||||||
|
iteration stops once `max_tokens` are collected, so large corpora like
|
||||||
|
TinyStories can be streamed without downloading the whole thing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hf_ds,
|
||||||
|
tokenizer,
|
||||||
|
seq_len: int,
|
||||||
|
max_tokens: int,
|
||||||
|
text_field: str = "text",
|
||||||
|
):
|
||||||
|
buf: list[int] = []
|
||||||
|
for sample in hf_ds:
|
||||||
|
text = sample[text_field]
|
||||||
|
if not text or not text.strip():
|
||||||
|
continue
|
||||||
|
buf.extend(tokenizer.encode(text, add_special_tokens=False))
|
||||||
|
if len(buf) >= max_tokens:
|
||||||
|
break
|
||||||
|
self.seq_len = seq_len
|
||||||
|
n_pairs = max(1, (len(buf) - 1) // seq_len)
|
||||||
|
buf = buf[: n_pairs * seq_len + 1]
|
||||||
|
self.data = torch.tensor(buf, dtype=torch.long)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return (len(self.data) - 1) // self.seq_len
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
|
||||||
|
s = idx * self.seq_len
|
||||||
|
chunk = self.data[s : s + self.seq_len + 1]
|
||||||
|
return chunk[:-1], chunk[1:]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Metrics
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Metrics:
    """Running training-loop aggregates for one model.

    Tracks cumulative loss/token/time totals plus the first ten and the most
    recent ten per-step losses, so initial vs. final loss can be compared
    without storing the whole history.
    """

    total_loss: float = 0.0   # sum of per-step mean losses
    total_tokens: int = 0     # tokens processed across all steps
    total_time: float = 0.0   # wall-clock seconds across all steps
    steps: int = 0
    first_losses: list[float] = field(default_factory=list)
    # `deque[float]` (builtin generic, consistent with `list[float]` above)
    # instead of the deprecated `typing.Deque`.
    last_losses: deque[float] = field(default_factory=lambda: deque(maxlen=10))

    def update(self, loss: float, tokens: int, seconds: float) -> None:
        """Record one training step's loss, token count and duration."""
        self.total_loss += loss
        self.total_tokens += tokens
        self.total_time += seconds
        self.steps += 1
        if len(self.first_losses) < 10:
            self.first_losses.append(loss)
        # deque(maxlen=10) automatically evicts the oldest entry.
        self.last_losses.append(loss)

    @property
    def avg_loss(self) -> float:
        """Mean per-step loss over all recorded steps (0.0 if none)."""
        return self.total_loss / max(1, self.steps)

    @property
    def tok_per_sec(self) -> float:
        """Overall throughput in tokens per second."""
        return self.total_tokens / max(1e-9, self.total_time)

    @property
    def initial_loss(self) -> float:
        """Mean of the first (up to) 10 step losses."""
        return sum(self.first_losses) / max(1, len(self.first_losses))

    @property
    def final_loss(self) -> float:
        """Mean of the last (up to) 10 step losses."""
        return sum(self.last_losses) / max(1, len(self.last_losses))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Training step
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def train_step(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    vocab_size: int,
) -> tuple[float, float]:
    """Run one optimizer step; return (loss, wall-clock seconds).

    Args:
        model: module mapping token ids ``x`` to logits of shape
            ``(..., vocab_size)``.
        x: input token ids.
        y: next-token targets, same shape as ``x``.
        optimizer: optimizer holding ``model``'s parameters.
        device: device the step runs on (used only to decide whether to
            synchronize CUDA for accurate timing).
        vocab_size: size of the logits' final dimension.

    Returns:
        ``(loss_value, seconds)`` where ``seconds`` is the synchronized
        wall-clock duration of the full forward/backward/step.
    """
    # Drain any async GPU work queued *before* this step so it is not
    # billed to this step's timer (the original only synced at the end).
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    model.train()
    # set_to_none=True frees grad tensors instead of zeroing them in place.
    optimizer.zero_grad(set_to_none=True)
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    loss.backward()
    # Global-norm clipping at 1.0 guards against loss spikes early in training.
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if device.type == "cuda":
        torch.cuda.synchronize()
    return loss.item(), time.perf_counter() - t0
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    vocab_size: int,
    max_batches: int | None = None,
    n_loops: int | None = None,
) -> float:
    """Mean cross-entropy over (up to `max_batches`) of the loader.

    `n_loops` is only forwarded to OpenMythos; for any other module the kwarg
    is dropped, so the same function benchmarks baseline and mythos uniformly.
    """
    model.eval()
    loss_sum = 0.0
    token_count = 0
    for batch_idx, (inputs, targets) in enumerate(loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        # Only OpenMythos understands the n_loops kwarg.
        extra = {"n_loops": n_loops} if isinstance(model, OpenMythos) else {}
        logits = model(inputs, **extra)
        # sum-reduction weights the mean by token count, not batch count.
        batch_loss = F.cross_entropy(
            logits.view(-1, vocab_size), targets.view(-1), reduction="sum"
        )
        loss_sum += batch_loss.item()
        token_count += targets.numel()
    return loss_sum / max(1, token_count)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config + utilities
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def build_tiny_cfg(vocab_size: int, seq_len: int) -> MythosConfig:
    """Tiny shared config with MLA attention — runs in reasonable time on CPU.

    MLA LoRA ranks and head dims scale with `dim=128` instead of the
    2048-dim-sized defaults (q_lora_rank=1536, qk_nope_head_dim=128, ...),
    which would otherwise dominate the parameter count at this scale.
    """
    # Grouped for readability: core dims, loop structure, MLA ranks, MoE, misc.
    settings = dict(
        vocab_size=vocab_size,
        dim=128,
        n_heads=4,
        n_kv_heads=2,
        max_seq_len=seq_len,
        max_loop_iters=4,
        prelude_layers=1,
        coda_layers=1,
        attn_type="mla",
        kv_lora_rank=64,
        q_lora_rank=128,
        qk_rope_head_dim=16,
        qk_nope_head_dim=32,
        v_head_dim=32,
        n_experts=4,
        n_shared_experts=1,
        n_experts_per_tok=2,
        expert_dim=128,
        lora_rank=4,
        rope_theta=10000.0,
        dropout=0.0,
    )
    return MythosConfig(**settings)
|
||||||
|
|
||||||
|
|
||||||
|
def count_params(model: nn.Module) -> int:
    """Total number of parameters in *model* (trainable or not)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
||||||
|
|
||||||
|
|
||||||
|
def fmt_count(n: float) -> str:
    """Format a count with a K/M/B/T suffix, e.g. 1234 -> '1.23K'."""
    suffixes = ("", "K", "M", "B", "T")
    value = float(n)
    steps = 0
    # Divide by 1000 until the magnitude fits, capping at the 'T' suffix.
    while abs(value) >= 1000 and steps < 4:
        value /= 1000
        steps += 1
    return f"{value:.2f}{suffixes[steps]}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI for the mythos-vs-baseline comparison run."""
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    # Core training knobs.
    parser.add_argument("--steps", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--seq-len", type=int, default=256)
    parser.add_argument("--lr", type=float, default=3e-4)
    # Defaults point at TinyStories — simpler vocabulary + shorter documents
    # lets a dim=128 model actually reach a meaningful loss in modest time.
    parser.add_argument("--dataset", default="roneneldan/TinyStories")
    parser.add_argument(
        "--dataset-config",
        default="",
        help="pass '' for datasets with no config (e.g. TinyStories)",
    )
    parser.add_argument("--train-split", default="train")
    parser.add_argument("--eval-split", default="validation")
    parser.add_argument(
        "--train-tokens",
        type=int,
        default=5_000_000,
        help="max tokens to materialize for the training buffer",
    )
    parser.add_argument(
        "--eval-tokens",
        type=int,
        default=200_000,
        help="max tokens to materialize for the held-out eval buffer",
    )
    parser.add_argument("--text-field", default="text")
    parser.add_argument("--tokenizer", default="gpt2")
    # Logging / evaluation cadence.
    parser.add_argument("--log-every", type=int, default=25)
    parser.add_argument(
        "--eval-every",
        type=int,
        default=200,
        help="run held-out eval every N steps (0 disables)",
    )
    parser.add_argument("--eval-batches", type=int, default=20)
    parser.add_argument(
        "--depth-sweep",
        default="1,2,4,8,16",
        help="comma-separated n_loops values for OpenMythos depth-extrapolation eval",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_text_ds(name: str, config: str, split: str):
    """Streaming `load_dataset` with optional config (empty string == no config)."""
    common = dict(split=split, streaming=True)
    if config:
        return load_dataset(name, config, **common)
    return load_dataset(name, **common)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Train OpenMythos and a parameter-matched baseline side by side.

    Pipeline: parse CLI → stream+pack the dataset → build both models from
    the same seed → train them on identical batches with periodic held-out
    eval → print a summary table → sweep OpenMythos `n_loops` at inference
    to measure depth extrapolation.
    """
    args = parse_args()
    device = torch.device(args.device)

    print(
        f"[setup] device={device} batch={args.batch_size} "
        f"seq_len={args.seq_len} steps={args.steps}"
    )

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    # AutoTokenizer.vocab_size can be smaller than the head size for BPE
    # tokenizers with added tokens; use len(tokenizer) to be safe.
    vocab_size = len(tokenizer)
    print(f"[setup] tokenizer={args.tokenizer} vocab_size={vocab_size:,}")

    # ------------------------------------------------------------------
    # Data: streamed train + held-out eval splits
    # ------------------------------------------------------------------
    print(f"[setup] dataset={args.dataset} config={args.dataset_config or '∅'}")
    raw_train = load_text_ds(args.dataset, args.dataset_config, args.train_split)
    train_ds = PackedLMDataset(
        raw_train, tokenizer, args.seq_len, args.train_tokens, args.text_field
    )
    raw_eval = load_text_ds(args.dataset, args.dataset_config, args.eval_split)
    eval_ds = PackedLMDataset(
        raw_eval, tokenizer, args.seq_len, args.eval_tokens, args.text_field
    )
    print(
        f"[setup] train tokens={train_ds.data.numel():,} pairs={len(train_ds)} | "
        f"eval tokens={eval_ds.data.numel():,} pairs={len(eval_ds)}"
    )

    # Seed before the loaders so shuffling order is reproducible.
    torch.manual_seed(args.seed)
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True
    )
    eval_loader = DataLoader(
        eval_ds, batch_size=args.batch_size, shuffle=False, drop_last=False
    )

    # ------------------------------------------------------------------
    # Models — same init seed so both start from the same embedding
    # ------------------------------------------------------------------
    cfg = build_tiny_cfg(vocab_size, args.seq_len)

    # Re-seed before EACH model build so both get identical shared-layer init.
    torch.manual_seed(args.seed)
    mythos = OpenMythos(cfg).to(device)

    # Parameter-matched depth: prelude + one unique recurrent block + coda.
    baseline_layers = cfg.prelude_layers + 1 + cfg.coda_layers
    torch.manual_seed(args.seed)
    baseline = BaselineTransformer(cfg, n_layers=baseline_layers).to(device)

    n_m, n_b = count_params(mythos), count_params(baseline)
    print(
        f"[setup] OpenMythos params = {fmt_count(n_m)} ({n_m:,})\n"
        f"[setup] Baseline params = {fmt_count(n_b)} ({n_b:,}) "
        f"[{baseline_layers} layers]"
    )
    print(
        f"[setup] Mythos runtime depth = prelude({cfg.prelude_layers}) + "
        f"loops({cfg.max_loop_iters}) + coda({cfg.coda_layers}) = "
        f"{cfg.prelude_layers + cfg.max_loop_iters + cfg.coda_layers}"
    )

    # Identical optimizer hyperparameters for a fair comparison.
    opt_m = torch.optim.AdamW(
        mythos.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1
    )
    opt_b = torch.optim.AdamW(
        baseline.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.1
    )

    mm, bm = Metrics(), Metrics()
    # NOTE(review): collected but never read after the loop — kept for
    # potential post-hoc inspection / future plotting.
    eval_history: list[tuple[int, float, float]] = []  # (step, mythos_eval, base_eval)

    header = (
        f"\n{'step':>6} | {'mythos loss':>12} | {'base loss':>10} | "
        f"{'mythos tok/s':>13} | {'base tok/s':>11}"
    )
    print(header)
    print("-" * len(header))

    # ------------------------------------------------------------------
    # Training loop with periodic held-out eval
    # ------------------------------------------------------------------
    data_iter = iter(train_loader)
    t_total = time.perf_counter()
    for step in range(1, args.steps + 1):
        # Cycle the loader indefinitely: restart on exhaustion.
        try:
            x, y = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            x, y = next(data_iter)
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        tokens = x.numel()

        # Both models see the exact same batch each step.
        loss_m, dt_m = train_step(mythos, x, y, opt_m, device, vocab_size)
        loss_b, dt_b = train_step(baseline, x, y, opt_b, device, vocab_size)

        mm.update(loss_m, tokens, dt_m)
        bm.update(loss_b, tokens, dt_b)

        if step == 1 or step % args.log_every == 0:
            print(
                f"{step:>6} | {loss_m:>12.4f} | {loss_b:>10.4f} | "
                f"{tokens / dt_m:>13,.0f} | {tokens / dt_b:>11,.0f}"
            )

        # eval_every == 0 disables periodic evaluation entirely.
        if args.eval_every and step % args.eval_every == 0:
            eval_m = evaluate(
                mythos, eval_loader, device, vocab_size, args.eval_batches
            )
            eval_b = evaluate(
                baseline, eval_loader, device, vocab_size, args.eval_batches
            )
            eval_history.append((step, eval_m, eval_b))
            print(
                f" [eval @ step {step}] mythos {eval_m:.4f} baseline {eval_b:.4f} "
                f"(Δ = {eval_m - eval_b:+.4f})"
            )

    total_wall = time.perf_counter() - t_total

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    bar = "=" * 70
    print(f"\n{bar}\nSummary ({args.steps} steps, wall clock {total_wall:.1f}s)\n{bar}")
    print(f" {'':<24} {'OpenMythos':>16} {'Baseline':>16}")
    print(f" {'params':<24} {fmt_count(n_m):>16} {fmt_count(n_b):>16}")
    print(
        f" {'initial train (first 10)':<24} "
        f"{mm.initial_loss:>16.4f} {bm.initial_loss:>16.4f}"
    )
    print(
        f" {'final train (last 10)':<24} "
        f"{mm.final_loss:>16.4f} {bm.final_loss:>16.4f}"
    )
    print(
        f" {'avg train (all steps)':<24} "
        f"{mm.avg_loss:>16.4f} {bm.avg_loss:>16.4f}"
    )
    print(
        f" {'train time (sec)':<24} "
        f"{mm.total_time:>16.2f} {bm.total_time:>16.2f}"
    )
    print(
        f" {'avg tok/s':<24} " f"{mm.tok_per_sec:>16,.0f} {bm.tok_per_sec:>16,.0f}"
    )
    print(
        f" {'sec/step':<24} "
        f"{mm.total_time / max(1, mm.steps):>16.4f} "
        f"{bm.total_time / max(1, bm.steps):>16.4f}"
    )

    # ------------------------------------------------------------------
    # Depth extrapolation: OpenMythos eval loss as a function of n_loops.
    # Trained at cfg.max_loop_iters; we run inference with a sweep to see
    # whether additional loops keep improving (depth extrapolation) or the
    # model collapses outside the trained regime.
    # ------------------------------------------------------------------
    # Dedupe + sort user-supplied loop counts; empty fields are ignored.
    loops_sweep = sorted({int(s) for s in args.depth_sweep.split(",") if s.strip()})
    print(f"\n{bar}\nDepth extrapolation (held-out eval, full eval set)\n{bar}")
    baseline_eval = evaluate(baseline, eval_loader, device, vocab_size)
    print(f" Baseline (fixed depth) : eval loss = {baseline_eval:.4f}")
    # First collect all sweep losses, then print with deltas vs. the trained depth.
    sweep: list[tuple[int, float]] = []
    for nl in loops_sweep:
        sweep.append(
            (nl, evaluate(mythos, eval_loader, device, vocab_size, n_loops=nl))
        )
    # trained_loss is None when the sweep doesn't include the trained depth.
    trained_loss = next((loss for nl, loss in sweep if nl == cfg.max_loop_iters), None)
    print(f" OpenMythos (trained at n_loops={cfg.max_loop_iters}):")
    print(f" {'n_loops':>8} {'eval loss':>10} {'Δ vs trained':>14}")
    for nl, loss in sweep:
        if trained_loss is None or nl == cfg.max_loop_iters:
            delta_str = ""
        else:
            delta_str = f"{loss - trained_loss:+.4f}"
        marker = " ←trained" if nl == cfg.max_loop_iters else ""
        print(f" {nl:>8} {loss:>10.4f} {delta_str:>14}{marker}")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the full comparison when executed directly.
if __name__ == "__main__":
    main()
|
||||||
Loading…
x
Reference in New Issue
Block a user