# Mirror of https://github.com/kyegomez/OpenMythos.git
# (synced 2026-05-02 09:33:27 +02:00; 549 lines, 18 KiB, Python)
"""
pytest suite for OpenMythos (``open_mythos.main``).

Tests every major class and feature: shapes, correctness invariants, and
architecture-specific properties (LTI stability, ACT halting, depth extrapolation,
KV cache consistency, GQA vs MLA swap).
"""

import torch
|
||
import pytest
|
||
from open_mythos.main import (
|
||
ACTHalting,
|
||
Expert,
|
||
GQAttention,
|
||
LTIInjection,
|
||
LoRAAdapter,
|
||
MLAttention,
|
||
MoEFFN,
|
||
MythosConfig,
|
||
OpenMythos,
|
||
RecurrentBlock,
|
||
RMSNorm,
|
||
TransformerBlock,
|
||
apply_rope,
|
||
loop_index_embedding,
|
||
precompute_rope_freqs,
|
||
)
|
||
|
||
# ---------------------------------------------------------------------------
# Shared small configs (kept tiny so tests run fast on CPU)
# ---------------------------------------------------------------------------

# Default input geometry shared by most tests below.
B = 2  # batch size
T = 8  # sequence length


def gqa_cfg(**overrides) -> MythosConfig:
    """Return a tiny grouped-query-attention ``MythosConfig`` for CPU tests.

    Any keyword in ``overrides`` replaces (or adds to) the baseline fields
    below before everything is forwarded to ``MythosConfig``.
    """
    baseline = {
        "vocab_size": 200,
        "dim": 64,
        "n_heads": 4,
        "n_kv_heads": 2,
        "max_seq_len": 32,
        "max_loop_iters": 3,
        "prelude_layers": 1,
        "coda_layers": 1,
        "attn_type": "gqa",
        "n_experts": 4,
        "n_shared_experts": 1,
        "n_experts_per_tok": 2,
        "expert_dim": 16,
        "act_threshold": 0.99,
        "lora_rank": 4,
        # MLA fields must be valid even when not used
        "kv_lora_rank": 16,
        "q_lora_rank": 32,
        "qk_rope_head_dim": 8,
        "qk_nope_head_dim": 8,
        "v_head_dim": 8,
    }
    merged = {**baseline, **overrides}
    return MythosConfig(**merged)


def mla_cfg(**overrides) -> MythosConfig:
    """Return the ``gqa_cfg`` baseline switched to multi-head latent attention.

    ``attn_type`` is forced to ``"mla"``; passing it again in ``overrides``
    raises ``TypeError`` (duplicate keyword argument).
    """
    return gqa_cfg(**overrides, attn_type="mla")


# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------


class TestRMSNorm:
    """Shape, normalisation and trainability checks for ``RMSNorm``."""

    def test_output_shape(self):
        layer = RMSNorm(64)
        inp = torch.randn(2, 8, 64)
        result = layer(inp)
        assert result.shape == inp.shape

    def test_unit_rms(self):
        # With a unit gain, every output vector should have RMS ~= 1.
        layer = RMSNorm(64)
        torch.nn.init.ones_(layer.weight)
        inp = torch.randn(4, 64)
        result = layer(inp)
        rms = torch.sqrt(torch.mean(result ** 2, dim=-1))
        assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)

    def test_learnable_weight(self):
        assert RMSNorm(8).weight.requires_grad


# ---------------------------------------------------------------------------
# RoPE utilities
# ---------------------------------------------------------------------------


class TestRoPE:
    """Checks for ``precompute_rope_freqs`` and ``apply_rope``."""

    @staticmethod
    def _freqs():
        # Rotation table for 16-dim heads over 32 positions.
        return precompute_rope_freqs(dim=16, max_len=32)

    def test_precompute_shape(self):
        table = self._freqs()
        # one complex entry per (position, channel pair): (max_len, dim // 2)
        assert table.shape == (32, 8)
        assert table.is_complex()

    def test_apply_rope_shape(self):
        table = self._freqs()
        q = torch.randn(B, T, 4, 16)
        rotated = apply_rope(q, table)
        assert rotated.shape == q.shape

    def test_apply_rope_preserves_norm(self):
        # A rotation is an isometry: per-vector L2 norms must not change.
        table = self._freqs()
        q = torch.randn(B, T, 4, 16)
        rotated = apply_rope(q, table)
        before, after = q.norm(dim=-1), rotated.norm(dim=-1)
        assert torch.allclose(before, after, atol=1e-5)

    def test_different_positions_differ(self):
        table = self._freqs()
        q = torch.ones(1, 2, 1, 16)
        rotated = apply_rope(q, table)
        # identical content at positions 0 and 1 must be rotated differently
        pos0, pos1 = rotated[0, 0], rotated[0, 1]
        assert not torch.allclose(pos0, pos1)


# ---------------------------------------------------------------------------
# GQAttention
# ---------------------------------------------------------------------------


class TestGQAttention:
    """Grouped-query attention: output shape, KV-cache growth, causal masking."""

    def setup_method(self):
        self.cfg = gqa_cfg()
        # GQA rotates the full per-head dimension (dim // n_heads).
        self.freqs = precompute_rope_freqs(
            self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len
        )
        self.attn = GQAttention(self.cfg)

    def test_output_shape(self):
        x = torch.randn(B, T, self.cfg.dim)
        out = self.attn(x, self.freqs)
        assert out.shape == (B, T, self.cfg.dim)

    def test_kv_cache_accumulates(self):
        cache = {}
        x = torch.randn(B, T, self.cfg.dim)
        self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0")
        assert "layer0" in cache
        k_len = cache["layer0"]["k"].shape[1]
        # starting from an empty cache, exactly the T input tokens are stored
        assert k_len == T
        # second call adds T more tokens
        self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0")
        assert cache["layer0"]["k"].shape[1] == k_len + T

    def test_with_causal_mask(self):
        # Additive mask: 0 on/below the diagonal, -inf above it, so query i
        # may only attend to keys 0..i.
        mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1)
        x = torch.randn(B, T, self.cfg.dim)
        out = self.attn(x, self.freqs, mask=mask)
        assert out.shape == (B, T, self.cfg.dim)

        # A shape check alone would pass even if the mask were ignored.
        # Changing only the last token must leave every earlier position's
        # output untouched when the mask is honoured.
        self.attn.eval()  # rule out train-mode stochasticity, if any
        x_perturbed = x.clone()
        x_perturbed[:, -1] = torch.randn(B, self.cfg.dim)
        with torch.no_grad():
            ref = self.attn(x, self.freqs, mask=mask)
            pert = self.attn(x_perturbed, self.freqs, mask=mask)
        assert torch.allclose(ref[:, :-1], pert[:, :-1], atol=1e-6)
        # sanity: the perturbed position itself does change
        assert not torch.allclose(ref[:, -1], pert[:, -1])


# ---------------------------------------------------------------------------
# MLAttention
# ---------------------------------------------------------------------------


class TestMLAttention:
    """Multi-head latent attention: shape, compressed KV cache, causal masking."""

    def setup_method(self):
        self.cfg = mla_cfg()
        # RoPE tables are sized for qk_rope_head_dim, not the full head dim.
        self.freqs = precompute_rope_freqs(
            self.cfg.qk_rope_head_dim, self.cfg.max_seq_len
        )
        self.attn = MLAttention(self.cfg)

    def test_output_shape(self):
        x = torch.randn(B, T, self.cfg.dim)
        out = self.attn(x, self.freqs)
        assert out.shape == (B, T, self.cfg.dim)

    def test_cache_stores_compressed_kv(self):
        cache = {}
        x = torch.randn(B, T, self.cfg.dim)
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        assert "c_kv" in cache["mla0"]
        assert "k_rope" in cache["mla0"]
        # c_kv should have kv_lora_rank as last dim, not full K/V
        assert cache["mla0"]["c_kv"].shape[-1] == self.cfg.kv_lora_rank

    def test_cache_accumulates_across_steps(self):
        cache = {}
        x = torch.randn(B, T, self.cfg.dim)
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        first_len = cache["mla0"]["c_kv"].shape[1]
        # starting from an empty cache, exactly the T input tokens are stored
        assert first_len == T
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        assert cache["mla0"]["c_kv"].shape[1] == first_len + T

    def test_with_causal_mask(self):
        # Additive mask: 0 on/below the diagonal, -inf above it.
        mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1)
        x = torch.randn(B, T, self.cfg.dim)
        out = self.attn(x, self.freqs, mask=mask)
        assert out.shape == (B, T, self.cfg.dim)

        # Changing only the last token must leave every earlier position's
        # output untouched when the mask is honoured.
        self.attn.eval()  # rule out train-mode stochasticity, if any
        x_perturbed = x.clone()
        x_perturbed[:, -1] = torch.randn(B, self.cfg.dim)
        with torch.no_grad():
            ref = self.attn(x, self.freqs, mask=mask)
            pert = self.attn(x_perturbed, self.freqs, mask=mask)
        assert torch.allclose(ref[:, :-1], pert[:, :-1], atol=1e-6)
        # sanity: the perturbed position itself does change
        assert not torch.allclose(ref[:, -1], pert[:, -1])


# ---------------------------------------------------------------------------
# Expert (dense SwiGLU FFN)
# ---------------------------------------------------------------------------


class TestExpert:
    """A single dense expert must preserve the model dimension."""

    def test_output_shape(self):
        ffn = Expert(dim=64, expert_dim=32)
        inp = torch.randn(B, T, 64)
        result = ffn(inp)
        assert result.shape == (B, T, 64)

    def test_flat_input(self):
        # 2-D (tokens, dim) input is accepted as well as 3-D.
        ffn = Expert(dim=32, expert_dim=16)
        inp = torch.randn(5, 32)
        result = ffn(inp)
        assert result.shape == (5, 32)


# ---------------------------------------------------------------------------
# MoEFFN
# ---------------------------------------------------------------------------


class TestMoEFFN:
    """Mixture-of-experts FFN: shape, router-bias bookkeeping, shared experts."""

    def setup_method(self):
        self.cfg = gqa_cfg()
        self.moe = MoEFFN(self.cfg)

    def test_output_shape(self):
        x = torch.randn(B, T, self.cfg.dim)
        assert self.moe(x).shape == (B, T, self.cfg.dim)

    def test_router_bias_not_grad(self):
        # router_bias is a buffer, not a parameter. Check both sides so the
        # test cannot pass vacuously if the attribute is renamed or missing.
        param_names = [n for n, _ in self.moe.named_parameters()]
        buffer_names = [n for n, _ in self.moe.named_buffers()]
        assert not any(n.endswith("router_bias") for n in param_names)
        assert any(n.endswith("router_bias") for n in buffer_names)

    def test_shared_experts_always_fire(self):
        # Zero out all routed experts; output should still be nonzero from
        # the shared experts. Mutate under no_grad instead of via `.data`.
        with torch.no_grad():
            for expert in self.moe.routed_experts:
                for p in expert.parameters():
                    p.zero_()
        x = torch.randn(B, T, self.cfg.dim)
        out = self.moe(x)
        assert out.abs().sum() > 0


# ---------------------------------------------------------------------------
# loop_index_embedding
# ---------------------------------------------------------------------------


class TestLoopIndexEmbedding:
    """``loop_index_embedding`` marks the hidden state with the loop index."""

    def test_output_shape(self):
        hidden = torch.randn(B, T, 64)
        tagged = loop_index_embedding(hidden, loop_t=0, loop_dim=8)
        assert tagged.shape == hidden.shape

    def test_different_iterations_differ(self):
        hidden = torch.zeros(1, 1, 64)
        first, second = (
            loop_index_embedding(hidden, loop_t=step, loop_dim=8) for step in (0, 1)
        )
        assert not torch.allclose(first, second)

    def test_only_first_dims_modified(self):
        hidden = torch.zeros(1, 1, 64)
        loop_dim = 8
        tagged = loop_index_embedding(hidden, loop_t=3, loop_dim=loop_dim)
        # channels past the first loop_dim must stay at their original 0
        untouched = tagged[..., loop_dim:]
        assert (untouched == 0).all()


# ---------------------------------------------------------------------------
# LoRAAdapter
# ---------------------------------------------------------------------------


class TestLoRAAdapter:
    """Loop-indexed low-rank adapter."""

    def setup_method(self):
        # max_loops=10: presumably the number of loop indices supported — TODO confirm
        self.lora = LoRAAdapter(dim=64, rank=8, max_loops=10)

    def test_output_shape(self):
        inp = torch.randn(B, T, 64)
        assert self.lora(inp, loop_t=0).shape == (B, T, 64)

    def test_different_loops_differ(self):
        inp = torch.randn(B, T, 64)
        at_zero, at_one = (self.lora(inp, loop_t=step) for step in (0, 1))
        assert not torch.allclose(at_zero, at_one)


# ---------------------------------------------------------------------------
# TransformerBlock
# ---------------------------------------------------------------------------


class TestTransformerBlock:
    """``TransformerBlock`` with either attention type and dense or MoE FFN."""

    @staticmethod
    def _run(cfg, use_moe, rope_dim):
        # Build the block, its RoPE table and a random input, then run it.
        block = TransformerBlock(cfg, use_moe=use_moe)
        freqs = precompute_rope_freqs(rope_dim, cfg.max_seq_len)
        x = torch.randn(B, T, cfg.dim)
        return block(x, freqs)

    def test_gqa_output_shape(self):
        cfg = gqa_cfg()
        out = self._run(cfg, use_moe=False, rope_dim=cfg.dim // cfg.n_heads)
        assert out.shape == (B, T, cfg.dim)

    def test_mla_output_shape(self):
        cfg = mla_cfg()
        out = self._run(cfg, use_moe=False, rope_dim=cfg.qk_rope_head_dim)
        assert out.shape == (B, T, cfg.dim)

    def test_moe_block_output_shape(self):
        cfg = gqa_cfg()
        out = self._run(cfg, use_moe=True, rope_dim=cfg.dim // cfg.n_heads)
        assert out.shape == (B, T, cfg.dim)

    def test_attn_type_selection(self):
        pairs = ((gqa_cfg, GQAttention), (mla_cfg, MLAttention))
        for make_cfg, attn_cls in pairs:
            assert isinstance(TransformerBlock(make_cfg()).attn, attn_cls)


# ---------------------------------------------------------------------------
# LTIInjection
# ---------------------------------------------------------------------------


class TestLTIInjection:
    """Shape and (0, 1) bounds of ``LTIInjection.get_A()``, including after training."""

    def setup_method(self):
        self.inj = LTIInjection(dim=64)

    @staticmethod
    def _inputs():
        # (h, e, t), each a random (B, T, 64) tensor
        return tuple(torch.randn(B, T, 64) for _ in range(3))

    def test_output_shape(self):
        h, e, t = self._inputs()
        assert self.inj(h, e, t).shape == (B, T, 64)

    def test_spectral_radius_lt_1(self):
        assert self.inj.get_A().max().item() < 1.0

    def test_spectral_radius_gt_0(self):
        assert self.inj.get_A().min().item() > 0.0

    def test_spectral_radius_stable_after_large_grad_step(self):
        # One deliberately huge SGD step; A must still stay below 1.
        optimiser = torch.optim.SGD(self.inj.parameters(), lr=1e3)
        h, e, t = self._inputs()
        self.inj(h, e, t).sum().backward()
        optimiser.step()
        assert self.inj.get_A().max().item() < 1.0


# ---------------------------------------------------------------------------
# ACTHalting
# ---------------------------------------------------------------------------


class TestACTHalting:
    """Per-token halting probabilities produced by ``ACTHalting``."""

    def setup_method(self):
        self.act = ACTHalting(dim=64)

    def test_output_shape(self):
        probs = self.act(torch.randn(B, T, 64))
        # one value per token: the hidden dimension is reduced away
        assert probs.shape == (B, T)

    def test_values_in_01(self):
        probs = self.act(torch.randn(B, T, 64))
        lowest, highest = probs.min().item(), probs.max().item()
        assert lowest >= 0.0
        assert highest <= 1.0


# ---------------------------------------------------------------------------
# RecurrentBlock
# ---------------------------------------------------------------------------


class TestRecurrentBlock:
    """Looped ``RecurrentBlock``: shape and dependence on ``n_loops``."""

    def setup_method(self):
        self.cfg = gqa_cfg()
        self.block = RecurrentBlock(self.cfg)
        head_dim = self.cfg.dim // self.cfg.n_heads
        self.freqs = precompute_rope_freqs(head_dim, self.cfg.max_seq_len)

    def _inputs(self):
        # (h, e): hidden state and input embedding, both (B, T, dim)
        shape = (B, T, self.cfg.dim)
        return torch.randn(shape), torch.randn(shape)

    def test_output_shape(self):
        h, e = self._inputs()
        result = self.block(h, e, self.freqs)
        assert result.shape == (B, T, self.cfg.dim)

    def test_more_loops_changes_output(self):
        h, e = self._inputs()
        one_loop, three_loops = (
            self.block(h.clone(), e.clone(), self.freqs, n_loops=n) for n in (1, 3)
        )
        assert not torch.allclose(one_loop, three_loops)

    def test_single_loop_runs(self):
        h, e = self._inputs()
        result = self.block(h, e, self.freqs, n_loops=1)
        assert result.shape == (B, T, self.cfg.dim)


# ---------------------------------------------------------------------------
# OpenMythos — GQA mode
# ---------------------------------------------------------------------------


class TestOpenMythosGQA:
    """End-to-end ``OpenMythos`` model using grouped-query attention."""

    def setup_method(self):
        self.cfg = gqa_cfg()
        self.model = OpenMythos(self.cfg)
        self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))

    def test_forward_shape(self):
        logits = self.model(self.ids)
        assert logits.shape == (B, T, self.cfg.vocab_size)

    def test_forward_no_nan(self):
        logits = self.model(self.ids)
        assert not torch.isnan(logits).any()

    def test_generate_shape(self):
        out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
        assert out.shape == (B, T + 4)

    def test_weight_tying(self):
        assert self.model.head.weight is self.model.embed.weight

    def test_lti_spectral_radius(self):
        A = self.model.recurrent.injection.get_A()
        assert A.max().item() < 1.0

    def test_depth_extrapolation_changes_output(self):
        # More loops at inference should produce different (ideally better) output
        logits_shallow = self.model(self.ids, n_loops=1)
        logits_deep = self.model(self.ids, n_loops=3)
        assert not torch.allclose(logits_shallow, logits_deep)

    def test_kv_cache_generate_matches_no_cache(self):
        """A prefill pass that fills a KV cache must match a cache-free pass.

        Both calls process the full prompt, so this checks that populating
        the cache does not perturb the prefill logits. It does not exercise
        incremental one-token decoding.
        """
        # eval() rules out train-mode stochasticity (e.g. dropout, if any).
        self.model.eval()
        torch.manual_seed(0)
        prompt = torch.randint(0, self.cfg.vocab_size, (1, T))
        with torch.no_grad():
            logits_no_cache = self.model(prompt, n_loops=2)[:, -1, :]
            cache = {}
            logits_cached = self.model(prompt, n_loops=2, kv_cache=cache)[:, -1, :]
        # Guard against a vacuous pass: the cached forward must write entries.
        assert cache
        assert torch.allclose(logits_no_cache, logits_cached, atol=1e-4)

    def test_single_token_forward(self):
        # T=1 edge case (presumably the causal mask is skipped); must not crash.
        single = torch.randint(0, self.cfg.vocab_size, (B, 1))
        logits = self.model(single)
        assert logits.shape == (B, 1, self.cfg.vocab_size)


# ---------------------------------------------------------------------------
# OpenMythos — MLA mode
# ---------------------------------------------------------------------------


# NOTE: the class name previously ended in a Cyrillic "А" homoglyph; it is now
# plain ASCII so `pytest -k TestOpenMythosMLA` and text search find it.
class TestOpenMythosMLA:
    """End-to-end ``OpenMythos`` model using multi-head latent attention."""

    def setup_method(self):
        self.cfg = mla_cfg()
        self.model = OpenMythos(self.cfg)
        self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))

    def test_forward_shape(self):
        logits = self.model(self.ids)
        assert logits.shape == (B, T, self.cfg.vocab_size)

    def test_forward_no_nan(self):
        assert not torch.isnan(self.model(self.ids)).any()

    def test_generate_shape(self):
        out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
        assert out.shape == (B, T + 4)

    def test_lti_spectral_radius(self):
        A = self.model.recurrent.injection.get_A()
        assert A.max().item() < 1.0

    def test_mla_cache_is_compressed(self):
        # MLA cache should store c_kv (lora_rank), not full K/V (n_heads * head_dim)
        cache = {}
        with torch.no_grad():
            self.model(self.ids, kv_cache=cache)
        # find any MLA cache entry and check dimensions
        mla_entries = {k: v for k, v in cache.items() if "c_kv" in v}
        assert len(mla_entries) > 0
        for entry in mla_entries.values():
            assert entry["c_kv"].shape[-1] == self.cfg.kv_lora_rank


# ---------------------------------------------------------------------------
# GQA vs MLA: same config, different attn_type
# ---------------------------------------------------------------------------


class TestAttnTypeSwap:
    """Swapping ``attn_type`` on an otherwise identical tiny config."""

    def test_gqa_and_mla_produce_different_outputs(self):
        cfg_gqa = gqa_cfg()
        cfg_mla = mla_cfg()
        ids = torch.randint(0, cfg_gqa.vocab_size, (B, T))
        logits_gqa = OpenMythos(cfg_gqa)(ids)
        logits_mla = OpenMythos(cfg_mla)(ids)
        # different architectures, different params → outputs must differ
        assert not torch.allclose(logits_gqa, logits_mla)

    def test_both_modes_produce_valid_shapes(self):
        # Derive the token range from the config rather than hard-coding it.
        ids = torch.randint(0, gqa_cfg().vocab_size, (B, T))
        for attn_type in ("gqa", "mla"):
            cfg = gqa_cfg(attn_type=attn_type)
            logits = OpenMythos(cfg)(ids)
            assert logits.shape == (B, T, cfg.vocab_size)

    def test_mla_fewer_kv_cache_bytes(self):
        # MLA cache should be smaller than GQA cache for the same sequence
        cfg_gqa, cfg_mla = gqa_cfg(), mla_cfg()
        ids = torch.randint(0, cfg_gqa.vocab_size, (1, T))
        cache_gqa, cache_mla = {}, {}
        with torch.no_grad():
            OpenMythos(cfg_gqa)(ids, kv_cache=cache_gqa)
            OpenMythos(cfg_mla)(ids, kv_cache=cache_mla)

        def cache_bytes(cache):
            # total storage of every tensor held in every cache entry
            return sum(
                t.numel() * t.element_size()
                for entry in cache.values()
                for t in entry.values()
            )

        # Both caches must be populated, otherwise the comparison is vacuous.
        assert cache_gqa and cache_mla
        assert cache_bytes(cache_mla) < cache_bytes(cache_gqa)


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails (the return value was previously dropped).
    raise SystemExit(pytest.main([__file__, "--verbose"]))