"""
|
|
OpenMythos vs. vanilla GQA+MoE transformer benchmark.
|
|
|
|
Compares OpenMythos (Prelude + looped Recurrent Block + Coda with ACT halting,
|
|
LTI-stable injection, LoRA depth adapter) against a parameter-matched vanilla
|
|
transformer built from the same GQAttention + MoEFFN building blocks stacked
|
|
non-recurrently. The baseline reuses OpenMythos primitives so the comparison
|
|
isolates the recurrent-depth architecture, not the kernels.
|
|
|
|
Metrics reported:
|
|
- Parameter counts (total, MoE-active approximation)
|
|
- Prefill latency + throughput at several sequence lengths
|
|
- Decode (autoregressive step) latency with KV cache
|
|
- Peak memory (CUDA only)
|
|
- OpenMythos depth-scaling sweep: latency vs. n_loops
|
|
|
|
Run:
|
|
python benchmarks/bench_vs_transformer.py # small CPU/GPU smoke test
|
|
python benchmarks/bench_vs_transformer.py --size 1b --device cuda
|
|
python benchmarks/bench_vs_transformer.py --seq-lens 128,512,2048 --n-loops 1,4,8,16
|
|
"""
|
|
|
|

from __future__ import annotations

import argparse
import gc
import statistics
import time
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from open_mythos import MythosConfig, OpenMythos, mythos_1b
from open_mythos.main import (
    RMSNorm,
    TransformerBlock,
    precompute_rope_freqs,
)


# ---------------------------------------------------------------------------
# Baseline: non-looped GQA + MoE transformer
# ---------------------------------------------------------------------------


class BaselineTransformer(nn.Module):
    """
    Vanilla decoder-only transformer with GQA attention and MoE FFNs, stacked
    non-recurrently. Shares TransformerBlock / GQAttention / MoEFFN kernels
    with OpenMythos so any speed delta is attributable to the recurrent-depth
    architecture rather than the underlying attention/FFN implementation.
    """

    def __init__(self, cfg: MythosConfig, n_layers: int):
        super().__init__()
        self.cfg = cfg
        self.n_layers = n_layers
        self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.layers = nn.ModuleList(
            [TransformerBlock(cfg, use_moe=True) for _ in range(n_layers)]
        )
        self.norm = RMSNorm(cfg.dim)
        self.head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)

        head_dim = cfg.dim // cfg.n_heads
        self.register_buffer(
            "freqs_cis",
            precompute_rope_freqs(head_dim, cfg.max_seq_len, cfg.rope_theta),
            persistent=False,
        )

    @staticmethod
    def _causal_mask(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        mask = torch.full((1, 1, T, T), float("-inf"), device=device, dtype=dtype)
        return torch.triu(mask, diagonal=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        kv_cache: Optional[dict] = None,
        start_pos: int = 0,
    ) -> torch.Tensor:
        T = input_ids.shape[1]
        x = self.embed(input_ids)
        freqs_cis = self.freqs_cis[start_pos : start_pos + T]
        mask = self._causal_mask(T, x.device, x.dtype) if T > 1 else None
        for i, layer in enumerate(self.layers):
            x = layer(x, freqs_cis, mask, kv_cache, cache_key=f"layer_{i}")
        return self.head(self.norm(x))
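
# A minimal usage sketch (kept as a comment so it never runs; it mirrors how main()
# below constructs and calls the baseline, using the small config defined later in
# this file):
#
#   cfg = small_cfg()
#   model = BaselineTransformer(cfg, n_layers=3).eval()
#   ids = torch.randint(0, cfg.vocab_size, (1, 16))
#   logits = model(ids)  # (1, 16, vocab_size)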


# ---------------------------------------------------------------------------
# Timing utilities
# ---------------------------------------------------------------------------


def _sync(device: torch.device) -> None:
    if device.type == "cuda":
        torch.cuda.synchronize()
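
# CUDA kernel launches are asynchronous, so time_fn below brackets every trial with
# an explicit device synchronization; without it, wall-clock timing would mostly
# capture launch overhead rather than compute. Typical call, as used by the
# benchmark helpers further down:
#
#   secs = time_fn(lambda: model(ids), device)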


def time_fn(fn, device: torch.device, warmup: int = 2, trials: int = 5) -> float:
    """Returns median wall-clock seconds over `trials` after `warmup` runs."""
    for _ in range(warmup):
        fn()
    _sync(device)
    times = []
    for _ in range(trials):
        _sync(device)
        t0 = time.perf_counter()
        fn()
        _sync(device)
        times.append(time.perf_counter() - t0)
    return statistics.median(times)


def peak_mem_mb(device: torch.device) -> float:
    if device.type != "cuda":
        return 0.0
    return torch.cuda.max_memory_allocated() / (1024 * 1024)


def reset_mem(device: torch.device) -> None:
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


# ---------------------------------------------------------------------------
# Parameter counting
# ---------------------------------------------------------------------------


@dataclass
class ParamCounts:
    total: int
    moe_active_est: int  # active per token (shared + top-k routed)


def count_params(model: nn.Module, cfg: MythosConfig) -> ParamCounts:
    total = sum(p.numel() for p in model.parameters())
    # Rough active-per-token count for MoE layers: shared + top-k routed fraction.
    # For simplicity we report total and an estimated activation ratio separately.
    active_ratio = (cfg.n_shared_experts + cfg.n_experts_per_tok) / (
        cfg.n_shared_experts + cfg.n_experts
    )
    # Only FFN parameters shrink under activation; attention + embed/head are always on.
    # This is a coarse lower bound on active params.
    ffn_params = 0
    other_params = 0
    for name, p in model.named_parameters():
        if ".ffn." in name or name.startswith("ffn.") or ".experts." in name:
            ffn_params += p.numel()
        else:
            other_params += p.numel()
    active_est = other_params + int(ffn_params * active_ratio)
    return ParamCounts(total=total, moe_active_est=active_est)
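
# Worked example with the small config below (n_shared_experts=1, n_experts=8,
# n_experts_per_tok=2): active_ratio = (1 + 2) / (1 + 8) = 1/3, i.e. roughly a third
# of the MoE FFN parameters are counted as active per token, on top of the
# always-active attention, embedding, and head parameters.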


# ---------------------------------------------------------------------------
# Benchmarks
# ---------------------------------------------------------------------------


def bench_prefill(
    model: nn.Module,
    vocab_size: int,
    batch: int,
    seq_len: int,
    device: torch.device,
    n_loops: Optional[int] = None,
) -> tuple[float, float]:
    """Returns (median_seconds, tokens_per_sec)."""
    ids = torch.randint(0, vocab_size, (batch, seq_len), device=device)

    if isinstance(model, OpenMythos):

        def run() -> None:
            with torch.no_grad():
                model(ids, n_loops=n_loops)

    else:

        def run() -> None:
            with torch.no_grad():
                model(ids)

    secs = time_fn(run, device)
    tps = (batch * seq_len) / secs
    return secs, tps
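
# Note: bench_prefill runs a single no-cache forward over the full prompt, so tok/s
# here measures prefill throughput, not autoregressive generation speed (that is
# what bench_decode below reports).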


def bench_decode(
    model: nn.Module,
    vocab_size: int,
    batch: int,
    prompt_len: int,
    decode_steps: int,
    device: torch.device,
    n_loops: Optional[int] = None,
) -> tuple[float, float]:
    """
    Prefill a `prompt_len` prompt, then run `decode_steps` single-token decode
    steps with KV cache. The timed window covers the prefill plus all decode
    steps, so the per-step figure amortizes the prefill cost over the decode
    steps. Returns (avg_seconds_per_step, decode_tokens_per_sec).
    """
    prompt = torch.randint(0, vocab_size, (batch, prompt_len), device=device)

    def one_run() -> None:
        kv_cache: dict = {}
        with torch.no_grad():
            if isinstance(model, OpenMythos):
                model(prompt, n_loops=n_loops, kv_cache=kv_cache, start_pos=0)
            else:
                model(prompt, kv_cache=kv_cache, start_pos=0)
            for i in range(decode_steps):
                next_tok = torch.randint(0, vocab_size, (batch, 1), device=device)
                if isinstance(model, OpenMythos):
                    model(
                        next_tok,
                        n_loops=n_loops,
                        kv_cache=kv_cache,
                        start_pos=prompt_len + i,
                    )
                else:
                    model(next_tok, kv_cache=kv_cache, start_pos=prompt_len + i)

    secs = time_fn(one_run, device, warmup=1, trials=3)
    per_step = secs / decode_steps
    tps = batch * decode_steps / secs
    return per_step, tps
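
# The KV cache is a plain dict rebuilt at the start of every timed run, so trials
# stay independent of one another. The baseline keys it per layer via
# cache_key=f"layer_{i}"; OpenMythos receives the same dict and is assumed to
# populate it with its own per-block entries.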


# ---------------------------------------------------------------------------
# Config helpers
# ---------------------------------------------------------------------------


def small_cfg() -> MythosConfig:
    """Tiny config for smoke tests — runs on CPU in seconds."""
    return MythosConfig(
        vocab_size=1024,
        dim=256,
        n_heads=8,
        n_kv_heads=2,
        max_seq_len=1024,
        max_loop_iters=4,
        prelude_layers=1,
        coda_layers=1,
        attn_type="gqa",
        n_experts=8,
        n_shared_experts=1,
        n_experts_per_tok=2,
        expert_dim=128,
        lora_rank=4,
        dropout=0.0,
    )


def get_cfg(size: str) -> MythosConfig:
    size = size.lower()
    if size == "small":
        return small_cfg()
    if size == "1b":
        cfg = mythos_1b()
        # GQA for apples-to-apples; MLA changes KV shape semantics.
        cfg.attn_type = "gqa"
        return cfg
    raise ValueError(f"unknown size: {size!r} (use 'small' or '1b')")


# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------


def fmt_count(n: int) -> str:
    for unit in ("", "K", "M", "B", "T"):
        if abs(n) < 1000:
            return f"{n:.2f}{unit}"
        n /= 1000
    return f"{n:.2f}P"
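
# e.g. fmt_count(1_234_567) -> "1.23M"; values below one thousand are printed with
# no suffix.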


def print_header(title: str) -> None:
    bar = "=" * 72
    print(f"\n{bar}\n{title}\n{bar}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def parse_args() -> argparse.Namespace:
    # __doc__ begins with a newline, so strip before taking the first line as the
    # parser description (otherwise the description would be empty).
    p = argparse.ArgumentParser(description=__doc__.strip().splitlines()[0])
    p.add_argument("--size", default="small", choices=["small", "1b"])
    p.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cuda", "cpu"],
    )
    p.add_argument(
        "--dtype",
        default="auto",
        choices=["auto", "fp32", "bf16", "fp16"],
        help="'auto' picks fp32 on CPU and bf16 on CUDA",
    )
    p.add_argument("--batch", type=int, default=1)
    p.add_argument(
        "--seq-lens",
        default="128,512",
        help="comma-separated prefill sequence lengths",
    )
    p.add_argument(
        "--n-loops",
        default="1,4,8",
        help="comma-separated loop counts to sweep (OpenMythos only)",
    )
    p.add_argument(
        "--decode-steps",
        type=int,
        default=32,
        help="number of autoregressive decode steps after prefill",
    )
    p.add_argument(
        "--decode-prompt-len",
        type=int,
        default=128,
        help="prefill length before decode",
    )
    return p.parse_args()


def main() -> None:
    args = parse_args()
    device = torch.device(args.device)
    dtype_arg = args.dtype
    if dtype_arg == "auto":
        dtype_arg = "bf16" if device.type == "cuda" else "fp32"
    dtype = {
        "fp32": torch.float32,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }[dtype_arg]

    seq_lens = [int(s) for s in args.seq_lens.split(",") if s.strip()]
    n_loops_sweep = [int(s) for s in args.n_loops.split(",") if s.strip()]

    cfg = get_cfg(args.size)
    print_header(f"Config: size={args.size} device={device} dtype={dtype_arg}")
    print(
        f" dim={cfg.dim} n_heads={cfg.n_heads} n_kv_heads={cfg.n_kv_heads} "
        f"prelude={cfg.prelude_layers} coda={cfg.coda_layers} "
        f"max_loop_iters={cfg.max_loop_iters}\n"
        f" experts={cfg.n_experts} shared={cfg.n_shared_experts} "
        f"top_k={cfg.n_experts_per_tok} expert_dim={cfg.expert_dim}"
    )

    # Build models. Baseline depth = prelude + 1 (one unique recurrent block) + coda
    # to match the unique-parameter depth of OpenMythos (parameter-matched baseline).
    baseline_n_layers = cfg.prelude_layers + 1 + cfg.coda_layers
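    # e.g. with the small config (prelude_layers=1, coda_layers=1) the baseline gets
    # 3 unique layers, while OpenMythos unrolls to a runtime depth of 1 + 4 + 1 = 6
    # when the shared recurrent block runs max_loop_iters=4 times.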

    torch.manual_seed(0)
    mythos = OpenMythos(cfg).to(device=device, dtype=dtype).eval()
    torch.manual_seed(0)
    baseline = (
        BaselineTransformer(cfg, n_layers=baseline_n_layers)
        .to(device=device, dtype=dtype)
        .eval()
    )

    m_params = count_params(mythos, cfg)
    b_params = count_params(baseline, cfg)
    print_header(
        "Parameters (block-matched: baseline depth = prelude + 1 recurrent + coda)"
    )
    print(
        f" OpenMythos : total={fmt_count(m_params.total):>10} "
        f"active/tok≈{fmt_count(m_params.moe_active_est):>10}"
    )
    print(
        f" Baseline   : total={fmt_count(b_params.total):>10} "
        f"active/tok≈{fmt_count(b_params.moe_active_est):>10}"
    )
    print(
        f" Baseline unique layers = {baseline_n_layers} "
        f"(Mythos total runtime depth at max_loops = "
        f"{cfg.prelude_layers + cfg.max_loop_iters + cfg.coda_layers})"
    )

    # ---- Prefill ----
    print_header(f"Prefill latency (batch={args.batch})")
    header = f" {'model':<26} {'seq':>6} {'latency':>10} {'tok/s':>12} {'peak MB':>10}"
    print(header)
    for seq_len in seq_lens:
        if seq_len > cfg.max_seq_len:
            print(f" skip seq_len={seq_len} (> max_seq_len={cfg.max_seq_len})")
            continue

        reset_mem(device)
        secs, tps = bench_prefill(baseline, cfg.vocab_size, args.batch, seq_len, device)
        mem = peak_mem_mb(device)
        print(
            f" {'Baseline (stacked)':<26} {seq_len:>6} "
            f"{secs*1000:>9.2f}ms {tps:>12,.0f} {mem:>10.1f}"
        )

        for nl in n_loops_sweep:
            reset_mem(device)
            secs, tps = bench_prefill(
                mythos, cfg.vocab_size, args.batch, seq_len, device, n_loops=nl
            )
            mem = peak_mem_mb(device)
            print(
                f" {'OpenMythos (loops=' + str(nl) + ')':<26} {seq_len:>6} "
                f"{secs*1000:>9.2f}ms {tps:>12,.0f} {mem:>10.1f}"
            )

    # ---- Decode ----
    print_header(
        f"Decode latency (prefill {args.decode_prompt_len} tokens + "
        f"{args.decode_steps} decode steps, batch={args.batch})"
    )
    print(f" {'model':<26} {'ms/step':>12} {'decode tok/s':>14}")

    reset_mem(device)
    per_step, tps = bench_decode(
        baseline,
        cfg.vocab_size,
        args.batch,
        args.decode_prompt_len,
        args.decode_steps,
        device,
    )
    print(f" {'Baseline (stacked)':<26} {per_step*1000:>10.2f}ms {tps:>14,.1f}")

    for nl in n_loops_sweep:
        reset_mem(device)
        per_step, tps = bench_decode(
            mythos,
            cfg.vocab_size,
            args.batch,
            args.decode_prompt_len,
            args.decode_steps,
            device,
            n_loops=nl,
        )
        print(
            f" {'OpenMythos (loops=' + str(nl) + ')':<26} "
            f"{per_step*1000:>10.2f}ms {tps:>14,.1f}"
        )

    # ---- Depth scaling ----
    print_header(
        f"OpenMythos depth scaling (fixed seq={seq_lens[0]}, batch={args.batch})"
    )
    # The Δ column is relative to the first entry of the sweep, not necessarily 1.
    print(
        f" {'n_loops':>8} {'latency':>10} {'tok/s':>12} "
        f"{'Δ vs loops=' + str(n_loops_sweep[0]):>14}"
    )
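    # The recurrent block reuses one set of weights, so each extra loop adds roughly
    # one block's worth of compute; prefill latency should therefore grow close to
    # linearly in n_loops, which the Δ column makes easy to eyeball.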
    base_secs = None
    for nl in n_loops_sweep:
        reset_mem(device)
        secs, tps = bench_prefill(
            mythos, cfg.vocab_size, args.batch, seq_lens[0], device, n_loops=nl
        )
        if base_secs is None:
            base_secs = secs
            delta = "1.00x"
        else:
            delta = f"{secs / base_secs:.2f}x"
        print(f" {nl:>8} {secs*1000:>9.2f}ms {tps:>12,.0f} {delta:>14}")

    print("\nDone.")


if __name__ == "__main__":
    main()