mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
tests
This commit is contained in:
parent
c258cdc8da
commit
0699c00c94
@ -45,4 +45,6 @@ out = model.generate(ids, max_new_tokens=8, n_loops=8)
|
|||||||
print(f"[{attn_type.upper()}] Generated shape: {out.shape}")
|
print(f"[{attn_type.upper()}] Generated shape: {out.shape}")
|
||||||
|
|
||||||
A = model.recurrent.injection.get_A()
|
A = model.recurrent.injection.get_A()
|
||||||
print(f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)")
|
print(
|
||||||
|
f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)"
|
||||||
|
)
|
||||||
|
|||||||
@ -1012,4 +1012,3 @@ class OpenMythos(nn.Module):
|
|||||||
next_tok = torch.multinomial(probs, num_samples=1)
|
next_tok = torch.multinomial(probs, num_samples=1)
|
||||||
input_ids = torch.cat([input_ids, next_tok], dim=1)
|
input_ids = torch.cat([input_ids, next_tok], dim=1)
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
|
|||||||
548
test_main.py
Normal file
548
test_main.py
Normal file
@ -0,0 +1,548 @@
|
|||||||
|
"""
|
||||||
|
pytest suite for OpenMythos (second.py).
|
||||||
|
Tests every major class and feature: shapes, correctness invariants, and
|
||||||
|
architecture-specific properties (LTI stability, ACT halting, depth extrapolation,
|
||||||
|
KV cache consistency, GQA vs MLA swap).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from open_mythos.main import (
|
||||||
|
ACTHalting,
|
||||||
|
Expert,
|
||||||
|
GQAttention,
|
||||||
|
LTIInjection,
|
||||||
|
LoRAAdapter,
|
||||||
|
MLAttention,
|
||||||
|
MoEFFN,
|
||||||
|
MythosConfig,
|
||||||
|
OpenMythos,
|
||||||
|
RecurrentBlock,
|
||||||
|
RMSNorm,
|
||||||
|
TransformerBlock,
|
||||||
|
apply_rope,
|
||||||
|
loop_index_embedding,
|
||||||
|
precompute_rope_freqs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared small configs (kept tiny so tests run fast on CPU)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
B, T = 2, 8 # batch, sequence length
|
||||||
|
|
||||||
|
|
||||||
|
def gqa_cfg(**overrides) -> MythosConfig:
|
||||||
|
defaults = dict(
|
||||||
|
vocab_size=200,
|
||||||
|
dim=64,
|
||||||
|
n_heads=4,
|
||||||
|
n_kv_heads=2,
|
||||||
|
max_seq_len=32,
|
||||||
|
max_loop_iters=3,
|
||||||
|
prelude_layers=1,
|
||||||
|
coda_layers=1,
|
||||||
|
attn_type="gqa",
|
||||||
|
n_experts=4,
|
||||||
|
n_shared_experts=1,
|
||||||
|
n_experts_per_tok=2,
|
||||||
|
expert_dim=16,
|
||||||
|
act_threshold=0.99,
|
||||||
|
lora_rank=4,
|
||||||
|
# MLA fields must be valid even when not used
|
||||||
|
kv_lora_rank=16,
|
||||||
|
q_lora_rank=32,
|
||||||
|
qk_rope_head_dim=8,
|
||||||
|
qk_nope_head_dim=8,
|
||||||
|
v_head_dim=8,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return MythosConfig(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def mla_cfg(**overrides) -> MythosConfig:
|
||||||
|
return gqa_cfg(attn_type="mla", **overrides)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RMSNorm
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRMSNorm:
|
||||||
|
def test_output_shape(self):
|
||||||
|
norm = RMSNorm(64)
|
||||||
|
x = torch.randn(2, 8, 64)
|
||||||
|
assert norm(x).shape == x.shape
|
||||||
|
|
||||||
|
def test_unit_rms(self):
|
||||||
|
# after norm the RMS of each vector should be ≈ 1 when weight=1
|
||||||
|
norm = RMSNorm(64)
|
||||||
|
torch.nn.init.ones_(norm.weight)
|
||||||
|
x = torch.randn(4, 64)
|
||||||
|
out = norm(x)
|
||||||
|
rms = out.pow(2).mean(-1).sqrt()
|
||||||
|
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)
|
||||||
|
|
||||||
|
def test_learnable_weight(self):
|
||||||
|
norm = RMSNorm(8)
|
||||||
|
assert norm.weight.requires_grad
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RoPE utilities
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoPE:
|
||||||
|
def test_precompute_shape(self):
|
||||||
|
freqs = precompute_rope_freqs(dim=16, max_len=32)
|
||||||
|
assert freqs.shape == (32, 8) # (max_len, dim//2)
|
||||||
|
assert freqs.is_complex()
|
||||||
|
|
||||||
|
def test_apply_rope_shape(self):
|
||||||
|
freqs = precompute_rope_freqs(dim=16, max_len=32)
|
||||||
|
x = torch.randn(B, T, 4, 16)
|
||||||
|
out = apply_rope(x, freqs)
|
||||||
|
assert out.shape == x.shape
|
||||||
|
|
||||||
|
def test_apply_rope_preserves_norm(self):
|
||||||
|
# rotation is an isometry — norms must be unchanged
|
||||||
|
freqs = precompute_rope_freqs(dim=16, max_len=32)
|
||||||
|
x = torch.randn(B, T, 4, 16)
|
||||||
|
out = apply_rope(x, freqs)
|
||||||
|
assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-5)
|
||||||
|
|
||||||
|
def test_different_positions_differ(self):
|
||||||
|
freqs = precompute_rope_freqs(dim=16, max_len=32)
|
||||||
|
x = torch.ones(1, 2, 1, 16)
|
||||||
|
out = apply_rope(x, freqs)
|
||||||
|
# position 0 and position 1 should produce different rotations
|
||||||
|
assert not torch.allclose(out[0, 0], out[0, 1])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GQAttention
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGQAttention:
|
||||||
|
def setup_method(self):
|
||||||
|
self.cfg = gqa_cfg()
|
||||||
|
self.freqs = precompute_rope_freqs(
|
||||||
|
self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len
|
||||||
|
)
|
||||||
|
self.attn = GQAttention(self.cfg)
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
out = self.attn(x, self.freqs)
|
||||||
|
assert out.shape == (B, T, self.cfg.dim)
|
||||||
|
|
||||||
|
def test_kv_cache_accumulates(self):
|
||||||
|
cache = {}
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0")
|
||||||
|
assert "layer0" in cache
|
||||||
|
k_len = cache["layer0"]["k"].shape[1]
|
||||||
|
# second call adds T more tokens
|
||||||
|
self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0")
|
||||||
|
assert cache["layer0"]["k"].shape[1] == k_len + T
|
||||||
|
|
||||||
|
def test_with_causal_mask(self):
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
mask = torch.full((1, 1, T, T), float("-inf"))
|
||||||
|
mask = torch.triu(mask, diagonal=1)
|
||||||
|
out = self.attn(x, self.freqs, mask=mask)
|
||||||
|
assert out.shape == (B, T, self.cfg.dim)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MLAttention
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMLAttention:
|
||||||
|
def setup_method(self):
|
||||||
|
self.cfg = mla_cfg()
|
||||||
|
self.freqs = precompute_rope_freqs(
|
||||||
|
self.cfg.qk_rope_head_dim, self.cfg.max_seq_len
|
||||||
|
)
|
||||||
|
self.attn = MLAttention(self.cfg)
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
out = self.attn(x, self.freqs)
|
||||||
|
assert out.shape == (B, T, self.cfg.dim)
|
||||||
|
|
||||||
|
def test_cache_stores_compressed_kv(self):
|
||||||
|
cache = {}
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
|
||||||
|
assert "c_kv" in cache["mla0"]
|
||||||
|
assert "k_rope" in cache["mla0"]
|
||||||
|
# c_kv should have kv_lora_rank as last dim, not full K/V
|
||||||
|
assert cache["mla0"]["c_kv"].shape[-1] == self.cfg.kv_lora_rank
|
||||||
|
|
||||||
|
def test_cache_accumulates_across_steps(self):
|
||||||
|
cache = {}
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
|
||||||
|
first_len = cache["mla0"]["c_kv"].shape[1]
|
||||||
|
self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
|
||||||
|
assert cache["mla0"]["c_kv"].shape[1] == first_len + T
|
||||||
|
|
||||||
|
def test_with_causal_mask(self):
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1)
|
||||||
|
out = self.attn(x, self.freqs, mask=mask)
|
||||||
|
assert out.shape == (B, T, self.cfg.dim)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Expert (dense SwiGLU FFN)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpert:
|
||||||
|
def test_output_shape(self):
|
||||||
|
expert = Expert(dim=64, expert_dim=32)
|
||||||
|
x = torch.randn(B, T, 64)
|
||||||
|
assert expert(x).shape == (B, T, 64)
|
||||||
|
|
||||||
|
def test_flat_input(self):
|
||||||
|
expert = Expert(dim=32, expert_dim=16)
|
||||||
|
x = torch.randn(5, 32)
|
||||||
|
assert expert(x).shape == (5, 32)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MoEFFN
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoEFFN:
|
||||||
|
def setup_method(self):
|
||||||
|
self.cfg = gqa_cfg()
|
||||||
|
self.moe = MoEFFN(self.cfg)
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
assert self.moe(x).shape == (B, T, self.cfg.dim)
|
||||||
|
|
||||||
|
def test_router_bias_not_grad(self):
|
||||||
|
# router_bias is a buffer, not a parameter
|
||||||
|
param_names = {n for n, _ in self.moe.named_parameters()}
|
||||||
|
assert "router_bias" not in param_names
|
||||||
|
|
||||||
|
def test_shared_experts_always_fire(self):
|
||||||
|
# Zero out all routed experts; output should still be nonzero from shared
|
||||||
|
for exp in self.moe.routed_experts:
|
||||||
|
for p in exp.parameters():
|
||||||
|
p.data.zero_()
|
||||||
|
x = torch.randn(B, T, self.cfg.dim)
|
||||||
|
out = self.moe(x)
|
||||||
|
assert out.abs().sum() > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# loop_index_embedding
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoopIndexEmbedding:
|
||||||
|
def test_output_shape(self):
|
||||||
|
h = torch.randn(B, T, 64)
|
||||||
|
out = loop_index_embedding(h, loop_t=0, loop_dim=8)
|
||||||
|
assert out.shape == h.shape
|
||||||
|
|
||||||
|
def test_different_iterations_differ(self):
|
||||||
|
h = torch.zeros(1, 1, 64)
|
||||||
|
out0 = loop_index_embedding(h, loop_t=0, loop_dim=8)
|
||||||
|
out1 = loop_index_embedding(h, loop_t=1, loop_dim=8)
|
||||||
|
assert not torch.allclose(out0, out1)
|
||||||
|
|
||||||
|
def test_only_first_dims_modified(self):
|
||||||
|
h = torch.zeros(1, 1, 64)
|
||||||
|
loop_dim = 8
|
||||||
|
out = loop_index_embedding(h, loop_t=3, loop_dim=loop_dim)
|
||||||
|
# channels beyond loop_dim should be unchanged (still 0)
|
||||||
|
assert torch.all(out[..., loop_dim:] == 0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LoRAAdapter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoRAAdapter:
|
||||||
|
def setup_method(self):
|
||||||
|
self.lora = LoRAAdapter(dim=64, rank=8, max_loops=10)
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
x = torch.randn(B, T, 64)
|
||||||
|
out = self.lora(x, loop_t=0)
|
||||||
|
assert out.shape == (B, T, 64)
|
||||||
|
|
||||||
|
def test_different_loops_differ(self):
|
||||||
|
x = torch.randn(B, T, 64)
|
||||||
|
out0 = self.lora(x, loop_t=0)
|
||||||
|
out1 = self.lora(x, loop_t=1)
|
||||||
|
assert not torch.allclose(out0, out1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TransformerBlock
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransformerBlock:
|
||||||
|
def test_gqa_output_shape(self):
|
||||||
|
cfg = gqa_cfg()
|
||||||
|
block = TransformerBlock(cfg, use_moe=False)
|
||||||
|
freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
|
||||||
|
x = torch.randn(B, T, cfg.dim)
|
||||||
|
assert block(x, freqs).shape == (B, T, cfg.dim)
|
||||||
|
|
||||||
|
def test_mla_output_shape(self):
|
||||||
|
cfg = mla_cfg()
|
||||||
|
block = TransformerBlock(cfg, use_moe=False)
|
||||||
|
freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len)
|
||||||
|
x = torch.randn(B, T, cfg.dim)
|
||||||
|
assert block(x, freqs).shape == (B, T, cfg.dim)
|
||||||
|
|
||||||
|
def test_moe_block_output_shape(self):
|
||||||
|
cfg = gqa_cfg()
|
||||||
|
block = TransformerBlock(cfg, use_moe=True)
|
||||||
|
freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
|
||||||
|
x = torch.randn(B, T, cfg.dim)
|
||||||
|
assert block(x, freqs).shape == (B, T, cfg.dim)
|
||||||
|
|
||||||
|
def test_attn_type_selection(self):
|
||||||
|
assert isinstance(TransformerBlock(gqa_cfg()).attn, GQAttention)
|
||||||
|
assert isinstance(TransformerBlock(mla_cfg()).attn, MLAttention)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LTIInjection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLTIInjection:
|
||||||
|
def setup_method(self):
|
||||||
|
self.inj = LTIInjection(dim=64)
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
h = torch.randn(B, T, 64)
|
||||||
|
e = torch.randn(B, T, 64)
|
||||||
|
t = torch.randn(B, T, 64)
|
||||||
|
assert self.inj(h, e, t).shape == (B, T, 64)
|
||||||
|
|
||||||
|
def test_spectral_radius_lt_1(self):
|
||||||
|
A = self.inj.get_A()
|
||||||
|
assert A.max().item() < 1.0
|
||||||
|
|
||||||
|
def test_spectral_radius_gt_0(self):
|
||||||
|
A = self.inj.get_A()
|
||||||
|
assert A.min().item() > 0.0
|
||||||
|
|
||||||
|
def test_spectral_radius_stable_after_large_grad_step(self):
|
||||||
|
# Simulate an aggressive gradient update and verify stability holds
|
||||||
|
opt = torch.optim.SGD(self.inj.parameters(), lr=1e3)
|
||||||
|
h = torch.randn(B, T, 64)
|
||||||
|
e = torch.randn(B, T, 64)
|
||||||
|
t = torch.randn(B, T, 64)
|
||||||
|
loss = self.inj(h, e, t).sum()
|
||||||
|
loss.backward()
|
||||||
|
opt.step()
|
||||||
|
A = self.inj.get_A()
|
||||||
|
assert A.max().item() < 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ACTHalting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestACTHalting:
|
||||||
|
def setup_method(self):
|
||||||
|
self.act = ACTHalting(dim=64)
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
h = torch.randn(B, T, 64)
|
||||||
|
p = self.act(h)
|
||||||
|
assert p.shape == (B, T)
|
||||||
|
|
||||||
|
def test_values_in_01(self):
|
||||||
|
h = torch.randn(B, T, 64)
|
||||||
|
p = self.act(h)
|
||||||
|
assert p.min().item() >= 0.0
|
||||||
|
assert p.max().item() <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RecurrentBlock
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecurrentBlock:
|
||||||
|
def setup_method(self):
|
||||||
|
self.cfg = gqa_cfg()
|
||||||
|
self.block = RecurrentBlock(self.cfg)
|
||||||
|
self.freqs = precompute_rope_freqs(
|
||||||
|
self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
h = torch.randn(B, T, self.cfg.dim)
|
||||||
|
e = torch.randn(B, T, self.cfg.dim)
|
||||||
|
out = self.block(h, e, self.freqs)
|
||||||
|
assert out.shape == (B, T, self.cfg.dim)
|
||||||
|
|
||||||
|
def test_more_loops_changes_output(self):
|
||||||
|
h = torch.randn(B, T, self.cfg.dim)
|
||||||
|
e = torch.randn(B, T, self.cfg.dim)
|
||||||
|
out1 = self.block(h.clone(), e.clone(), self.freqs, n_loops=1)
|
||||||
|
out3 = self.block(h.clone(), e.clone(), self.freqs, n_loops=3)
|
||||||
|
assert not torch.allclose(out1, out3)
|
||||||
|
|
||||||
|
def test_single_loop_runs(self):
|
||||||
|
h = torch.randn(B, T, self.cfg.dim)
|
||||||
|
e = torch.randn(B, T, self.cfg.dim)
|
||||||
|
out = self.block(h, e, self.freqs, n_loops=1)
|
||||||
|
assert out.shape == (B, T, self.cfg.dim)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OpenMythos — GQA mode
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenMythosGQA:
|
||||||
|
def setup_method(self):
|
||||||
|
self.cfg = gqa_cfg()
|
||||||
|
self.model = OpenMythos(self.cfg)
|
||||||
|
self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))
|
||||||
|
|
||||||
|
def test_forward_shape(self):
|
||||||
|
logits = self.model(self.ids)
|
||||||
|
assert logits.shape == (B, T, self.cfg.vocab_size)
|
||||||
|
|
||||||
|
def test_forward_no_nan(self):
|
||||||
|
logits = self.model(self.ids)
|
||||||
|
assert not torch.isnan(logits).any()
|
||||||
|
|
||||||
|
def test_generate_shape(self):
|
||||||
|
out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
|
||||||
|
assert out.shape == (B, T + 4)
|
||||||
|
|
||||||
|
def test_weight_tying(self):
|
||||||
|
assert self.model.head.weight is self.model.embed.weight
|
||||||
|
|
||||||
|
def test_lti_spectral_radius(self):
|
||||||
|
A = self.model.recurrent.injection.get_A()
|
||||||
|
assert A.max().item() < 1.0
|
||||||
|
|
||||||
|
def test_depth_extrapolation_changes_output(self):
|
||||||
|
# More loops at inference should produce different (ideally better) output
|
||||||
|
logits_shallow = self.model(self.ids, n_loops=1)
|
||||||
|
logits_deep = self.model(self.ids, n_loops=3)
|
||||||
|
assert not torch.allclose(logits_shallow, logits_deep)
|
||||||
|
|
||||||
|
def test_kv_cache_generate_matches_no_cache(self):
|
||||||
|
# Single-token generation with and without cache should agree
|
||||||
|
torch.manual_seed(0)
|
||||||
|
prompt = torch.randint(0, self.cfg.vocab_size, (1, T))
|
||||||
|
with torch.no_grad():
|
||||||
|
logits_no_cache = self.model(prompt, n_loops=2)[:, -1, :]
|
||||||
|
cache = {}
|
||||||
|
logits_cached = self.model(prompt, n_loops=2, kv_cache=cache)[:, -1, :]
|
||||||
|
assert torch.allclose(logits_no_cache, logits_cached, atol=1e-4)
|
||||||
|
|
||||||
|
def test_single_token_forward(self):
|
||||||
|
# Mask is None when T=1; should not crash
|
||||||
|
single = torch.randint(0, self.cfg.vocab_size, (B, 1))
|
||||||
|
logits = self.model(single)
|
||||||
|
assert logits.shape == (B, 1, self.cfg.vocab_size)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OpenMythos — MLA mode
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenMythosMLА:
|
||||||
|
def setup_method(self):
|
||||||
|
self.cfg = mla_cfg()
|
||||||
|
self.model = OpenMythos(self.cfg)
|
||||||
|
self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))
|
||||||
|
|
||||||
|
def test_forward_shape(self):
|
||||||
|
logits = self.model(self.ids)
|
||||||
|
assert logits.shape == (B, T, self.cfg.vocab_size)
|
||||||
|
|
||||||
|
def test_forward_no_nan(self):
|
||||||
|
assert not torch.isnan(self.model(self.ids)).any()
|
||||||
|
|
||||||
|
def test_generate_shape(self):
|
||||||
|
out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
|
||||||
|
assert out.shape == (B, T + 4)
|
||||||
|
|
||||||
|
def test_lti_spectral_radius(self):
|
||||||
|
A = self.model.recurrent.injection.get_A()
|
||||||
|
assert A.max().item() < 1.0
|
||||||
|
|
||||||
|
def test_mla_cache_is_compressed(self):
|
||||||
|
# MLA cache should store c_kv (lora_rank), not full K/V (n_heads * head_dim)
|
||||||
|
cache = {}
|
||||||
|
with torch.no_grad():
|
||||||
|
self.model(self.ids, kv_cache=cache)
|
||||||
|
# find any MLA cache entry and check dimensions
|
||||||
|
mla_entries = {k: v for k, v in cache.items() if "c_kv" in v}
|
||||||
|
assert len(mla_entries) > 0
|
||||||
|
for entry in mla_entries.values():
|
||||||
|
assert entry["c_kv"].shape[-1] == self.cfg.kv_lora_rank
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GQA vs MLA: same config, different attn_type
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAttnTypeSwap:
|
||||||
|
def test_gqa_and_mla_produce_different_outputs(self):
|
||||||
|
cfg_gqa = gqa_cfg()
|
||||||
|
cfg_mla = mla_cfg()
|
||||||
|
ids = torch.randint(0, cfg_gqa.vocab_size, (B, T))
|
||||||
|
logits_gqa = OpenMythos(cfg_gqa)(ids)
|
||||||
|
logits_mla = OpenMythos(cfg_mla)(ids)
|
||||||
|
# different architectures, different params → outputs must differ
|
||||||
|
assert not torch.allclose(logits_gqa, logits_mla)
|
||||||
|
|
||||||
|
def test_both_modes_produce_valid_shapes(self):
|
||||||
|
ids = torch.randint(0, 200, (B, T))
|
||||||
|
for attn_type in ("gqa", "mla"):
|
||||||
|
cfg = gqa_cfg(attn_type=attn_type)
|
||||||
|
logits = OpenMythos(cfg)(ids)
|
||||||
|
assert logits.shape == (B, T, cfg.vocab_size)
|
||||||
|
|
||||||
|
def test_mla_fewer_kv_cache_bytes(self):
|
||||||
|
# MLA cache should be smaller than GQA cache for the same sequence
|
||||||
|
ids = torch.randint(0, 200, (1, T))
|
||||||
|
cache_gqa, cache_mla = {}, {}
|
||||||
|
with torch.no_grad():
|
||||||
|
OpenMythos(gqa_cfg())(ids, kv_cache=cache_gqa)
|
||||||
|
OpenMythos(mla_cfg())(ids, kv_cache=cache_mla)
|
||||||
|
|
||||||
|
def cache_bytes(cache):
|
||||||
|
return sum(
|
||||||
|
t.numel() * t.element_size()
|
||||||
|
for entry in cache.values()
|
||||||
|
for t in entry.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cache_bytes(cache_mla) < cache_bytes(cache_gqa)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "--verbose"])
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
Loading…
x
Reference in New Issue
Block a user