diff --git a/example.py b/example.py
index 0a2eb66..15e2c56 100644
--- a/example.py
+++ b/example.py
@@ -45,4 +45,6 @@
 out = model.generate(ids, max_new_tokens=8, n_loops=8)
 print(f"[{attn_type.upper()}] Generated shape: {out.shape}")
 A = model.recurrent.injection.get_A()
-print(f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)")
+print(
+    f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)"
+)
diff --git a/open_mythos/main.py b/open_mythos/main.py
index aa2f878..77555c1 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -1012,4 +1012,3 @@ class OpenMythos(nn.Module):
             next_tok = torch.multinomial(probs, num_samples=1)
             input_ids = torch.cat([input_ids, next_tok], dim=1)
         return input_ids
-
diff --git a/test_main.py b/test_main.py
new file mode 100644
index 0000000..6c9d4dc
--- /dev/null
+++ b/test_main.py
@@ -0,0 +1,548 @@
+"""
+pytest suite for OpenMythos (open_mythos/main.py).
+Tests every major class and feature: shapes, correctness invariants, and
+architecture-specific properties (LTI stability, ACT halting, depth extrapolation,
+KV cache consistency, GQA vs MLA swap).
+"""
+
+import torch
+import pytest
+from open_mythos.main import (
+    ACTHalting,
+    Expert,
+    GQAttention,
+    LTIInjection,
+    LoRAAdapter,
+    MLAttention,
+    MoEFFN,
+    MythosConfig,
+    OpenMythos,
+    RecurrentBlock,
+    RMSNorm,
+    TransformerBlock,
+    apply_rope,
+    loop_index_embedding,
+    precompute_rope_freqs,
+)
+
+# ---------------------------------------------------------------------------
+# Shared small configs (kept tiny so tests run fast on CPU)
+# ---------------------------------------------------------------------------
+
+B, T = 2, 8  # batch, sequence length
+
+
+def gqa_cfg(**overrides) -> MythosConfig:
+    defaults = dict(
+        vocab_size=200,
+        dim=64,
+        n_heads=4,
+        n_kv_heads=2,
+        max_seq_len=32,
+        max_loop_iters=3,
+        prelude_layers=1,
+        coda_layers=1,
+        attn_type="gqa",
+        n_experts=4,
+        n_shared_experts=1,
+        n_experts_per_tok=2,
+        expert_dim=16,
+        act_threshold=0.99,
+        lora_rank=4,
+        # MLA fields must be valid even when not used
+        kv_lora_rank=16,
+        q_lora_rank=32,
+        qk_rope_head_dim=8,
+        qk_nope_head_dim=8,
+        v_head_dim=8,
+    )
+    defaults.update(overrides)
+    return MythosConfig(**defaults)
+
+
+def mla_cfg(**overrides) -> MythosConfig:
+    return gqa_cfg(attn_type="mla", **overrides)
+
+
+# ---------------------------------------------------------------------------
+# RMSNorm
+# ---------------------------------------------------------------------------
+
+
+class TestRMSNorm:
+    def test_output_shape(self):
+        norm = RMSNorm(64)
+        x = torch.randn(2, 8, 64)
+        assert norm(x).shape == x.shape
+
+    def test_unit_rms(self):
+        # after norm the RMS of each vector should be ≈ 1 when weight=1
+        norm = RMSNorm(64)
+        torch.nn.init.ones_(norm.weight)
+        x = torch.randn(4, 64)
+        out = norm(x)
+        rms = out.pow(2).mean(-1).sqrt()
+        assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)
+
+    def test_learnable_weight(self):
+        norm = RMSNorm(8)
+        assert norm.weight.requires_grad
+
+
+# ---------------------------------------------------------------------------
+# RoPE utilities
+# ---------------------------------------------------------------------------
+
+
+class TestRoPE:
+    def test_precompute_shape(self):
+        freqs = precompute_rope_freqs(dim=16, max_len=32)
+        assert freqs.shape == (32, 8)  # (max_len, dim//2)
+        assert freqs.is_complex()
+
+    def test_apply_rope_shape(self):
+        freqs = precompute_rope_freqs(dim=16, max_len=32)
+        x = torch.randn(B, T, 4, 16)
+        out = 
apply_rope(x, freqs) + assert out.shape == x.shape + + def test_apply_rope_preserves_norm(self): + # rotation is an isometry — norms must be unchanged + freqs = precompute_rope_freqs(dim=16, max_len=32) + x = torch.randn(B, T, 4, 16) + out = apply_rope(x, freqs) + assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-5) + + def test_different_positions_differ(self): + freqs = precompute_rope_freqs(dim=16, max_len=32) + x = torch.ones(1, 2, 1, 16) + out = apply_rope(x, freqs) + # position 0 and position 1 should produce different rotations + assert not torch.allclose(out[0, 0], out[0, 1]) + + +# --------------------------------------------------------------------------- +# GQAttention +# --------------------------------------------------------------------------- + + +class TestGQAttention: + def setup_method(self): + self.cfg = gqa_cfg() + self.freqs = precompute_rope_freqs( + self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len + ) + self.attn = GQAttention(self.cfg) + + def test_output_shape(self): + x = torch.randn(B, T, self.cfg.dim) + out = self.attn(x, self.freqs) + assert out.shape == (B, T, self.cfg.dim) + + def test_kv_cache_accumulates(self): + cache = {} + x = torch.randn(B, T, self.cfg.dim) + self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0") + assert "layer0" in cache + k_len = cache["layer0"]["k"].shape[1] + # second call adds T more tokens + self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0") + assert cache["layer0"]["k"].shape[1] == k_len + T + + def test_with_causal_mask(self): + x = torch.randn(B, T, self.cfg.dim) + mask = torch.full((1, 1, T, T), float("-inf")) + mask = torch.triu(mask, diagonal=1) + out = self.attn(x, self.freqs, mask=mask) + assert out.shape == (B, T, self.cfg.dim) + + +# --------------------------------------------------------------------------- +# MLAttention +# --------------------------------------------------------------------------- + + +class TestMLAttention: + def setup_method(self): + self.cfg = mla_cfg() + self.freqs = precompute_rope_freqs( + self.cfg.qk_rope_head_dim, self.cfg.max_seq_len + ) + self.attn = MLAttention(self.cfg) + + def test_output_shape(self): + x = torch.randn(B, T, self.cfg.dim) + out = self.attn(x, self.freqs) + assert out.shape == (B, T, self.cfg.dim) + + def test_cache_stores_compressed_kv(self): + cache = {} + x = torch.randn(B, T, self.cfg.dim) + self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0") + assert "c_kv" in cache["mla0"] + assert "k_rope" in cache["mla0"] + # c_kv should have kv_lora_rank as last dim, not full K/V + assert cache["mla0"]["c_kv"].shape[-1] == self.cfg.kv_lora_rank + + def test_cache_accumulates_across_steps(self): + cache = {} + x = torch.randn(B, T, self.cfg.dim) + self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0") + first_len = cache["mla0"]["c_kv"].shape[1] + self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0") + assert cache["mla0"]["c_kv"].shape[1] == first_len + T + + def test_with_causal_mask(self): + x = torch.randn(B, T, self.cfg.dim) + mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1) + out = self.attn(x, self.freqs, mask=mask) + assert out.shape == (B, T, self.cfg.dim) + + +# --------------------------------------------------------------------------- +# Expert (dense SwiGLU FFN) +# --------------------------------------------------------------------------- + + +class TestExpert: + def test_output_shape(self): + expert = Expert(dim=64, expert_dim=32) + x = torch.randn(B, T, 64) + assert expert(x).shape == (B, T, 
64) + + def test_flat_input(self): + expert = Expert(dim=32, expert_dim=16) + x = torch.randn(5, 32) + assert expert(x).shape == (5, 32) + + +# --------------------------------------------------------------------------- +# MoEFFN +# --------------------------------------------------------------------------- + + +class TestMoEFFN: + def setup_method(self): + self.cfg = gqa_cfg() + self.moe = MoEFFN(self.cfg) + + def test_output_shape(self): + x = torch.randn(B, T, self.cfg.dim) + assert self.moe(x).shape == (B, T, self.cfg.dim) + + def test_router_bias_not_grad(self): + # router_bias is a buffer, not a parameter + param_names = {n for n, _ in self.moe.named_parameters()} + assert "router_bias" not in param_names + + def test_shared_experts_always_fire(self): + # Zero out all routed experts; output should still be nonzero from shared + for exp in self.moe.routed_experts: + for p in exp.parameters(): + p.data.zero_() + x = torch.randn(B, T, self.cfg.dim) + out = self.moe(x) + assert out.abs().sum() > 0 + + +# --------------------------------------------------------------------------- +# loop_index_embedding +# --------------------------------------------------------------------------- + + +class TestLoopIndexEmbedding: + def test_output_shape(self): + h = torch.randn(B, T, 64) + out = loop_index_embedding(h, loop_t=0, loop_dim=8) + assert out.shape == h.shape + + def test_different_iterations_differ(self): + h = torch.zeros(1, 1, 64) + out0 = loop_index_embedding(h, loop_t=0, loop_dim=8) + out1 = loop_index_embedding(h, loop_t=1, loop_dim=8) + assert not torch.allclose(out0, out1) + + def test_only_first_dims_modified(self): + h = torch.zeros(1, 1, 64) + loop_dim = 8 + out = loop_index_embedding(h, loop_t=3, loop_dim=loop_dim) + # channels beyond loop_dim should be unchanged (still 0) + assert torch.all(out[..., loop_dim:] == 0) + + +# --------------------------------------------------------------------------- +# LoRAAdapter +# --------------------------------------------------------------------------- + + +class TestLoRAAdapter: + def setup_method(self): + self.lora = LoRAAdapter(dim=64, rank=8, max_loops=10) + + def test_output_shape(self): + x = torch.randn(B, T, 64) + out = self.lora(x, loop_t=0) + assert out.shape == (B, T, 64) + + def test_different_loops_differ(self): + x = torch.randn(B, T, 64) + out0 = self.lora(x, loop_t=0) + out1 = self.lora(x, loop_t=1) + assert not torch.allclose(out0, out1) + + +# --------------------------------------------------------------------------- +# TransformerBlock +# --------------------------------------------------------------------------- + + +class TestTransformerBlock: + def test_gqa_output_shape(self): + cfg = gqa_cfg() + block = TransformerBlock(cfg, use_moe=False) + freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len) + x = torch.randn(B, T, cfg.dim) + assert block(x, freqs).shape == (B, T, cfg.dim) + + def test_mla_output_shape(self): + cfg = mla_cfg() + block = TransformerBlock(cfg, use_moe=False) + freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len) + x = torch.randn(B, T, cfg.dim) + assert block(x, freqs).shape == (B, T, cfg.dim) + + def test_moe_block_output_shape(self): + cfg = gqa_cfg() + block = TransformerBlock(cfg, use_moe=True) + freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len) + x = torch.randn(B, T, cfg.dim) + assert block(x, freqs).shape == (B, T, cfg.dim) + + def test_attn_type_selection(self): + assert isinstance(TransformerBlock(gqa_cfg()).attn, GQAttention) + assert 
isinstance(TransformerBlock(mla_cfg()).attn, MLAttention) + + +# --------------------------------------------------------------------------- +# LTIInjection +# --------------------------------------------------------------------------- + + +class TestLTIInjection: + def setup_method(self): + self.inj = LTIInjection(dim=64) + + def test_output_shape(self): + h = torch.randn(B, T, 64) + e = torch.randn(B, T, 64) + t = torch.randn(B, T, 64) + assert self.inj(h, e, t).shape == (B, T, 64) + + def test_spectral_radius_lt_1(self): + A = self.inj.get_A() + assert A.max().item() < 1.0 + + def test_spectral_radius_gt_0(self): + A = self.inj.get_A() + assert A.min().item() > 0.0 + + def test_spectral_radius_stable_after_large_grad_step(self): + # Simulate an aggressive gradient update and verify stability holds + opt = torch.optim.SGD(self.inj.parameters(), lr=1e3) + h = torch.randn(B, T, 64) + e = torch.randn(B, T, 64) + t = torch.randn(B, T, 64) + loss = self.inj(h, e, t).sum() + loss.backward() + opt.step() + A = self.inj.get_A() + assert A.max().item() < 1.0 + + +# --------------------------------------------------------------------------- +# ACTHalting +# --------------------------------------------------------------------------- + + +class TestACTHalting: + def setup_method(self): + self.act = ACTHalting(dim=64) + + def test_output_shape(self): + h = torch.randn(B, T, 64) + p = self.act(h) + assert p.shape == (B, T) + + def test_values_in_01(self): + h = torch.randn(B, T, 64) + p = self.act(h) + assert p.min().item() >= 0.0 + assert p.max().item() <= 1.0 + + +# --------------------------------------------------------------------------- +# RecurrentBlock +# --------------------------------------------------------------------------- + + +class TestRecurrentBlock: + def setup_method(self): + self.cfg = gqa_cfg() + self.block = RecurrentBlock(self.cfg) + self.freqs = precompute_rope_freqs( + self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len + ) + + def test_output_shape(self): + h = torch.randn(B, T, self.cfg.dim) + e = torch.randn(B, T, self.cfg.dim) + out = self.block(h, e, self.freqs) + assert out.shape == (B, T, self.cfg.dim) + + def test_more_loops_changes_output(self): + h = torch.randn(B, T, self.cfg.dim) + e = torch.randn(B, T, self.cfg.dim) + out1 = self.block(h.clone(), e.clone(), self.freqs, n_loops=1) + out3 = self.block(h.clone(), e.clone(), self.freqs, n_loops=3) + assert not torch.allclose(out1, out3) + + def test_single_loop_runs(self): + h = torch.randn(B, T, self.cfg.dim) + e = torch.randn(B, T, self.cfg.dim) + out = self.block(h, e, self.freqs, n_loops=1) + assert out.shape == (B, T, self.cfg.dim) + + +# --------------------------------------------------------------------------- +# OpenMythos — GQA mode +# --------------------------------------------------------------------------- + + +class TestOpenMythosGQA: + def setup_method(self): + self.cfg = gqa_cfg() + self.model = OpenMythos(self.cfg) + self.ids = torch.randint(0, self.cfg.vocab_size, (B, T)) + + def test_forward_shape(self): + logits = self.model(self.ids) + assert logits.shape == (B, T, self.cfg.vocab_size) + + def test_forward_no_nan(self): + logits = self.model(self.ids) + assert not torch.isnan(logits).any() + + def test_generate_shape(self): + out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2) + assert out.shape == (B, T + 4) + + def test_weight_tying(self): + assert self.model.head.weight is self.model.embed.weight + + def test_lti_spectral_radius(self): + A = 
self.model.recurrent.injection.get_A()
+        assert A.max().item() < 1.0
+
+    def test_depth_extrapolation_changes_output(self):
+        # More loops at inference should produce different (ideally better) output
+        logits_shallow = self.model(self.ids, n_loops=1)
+        logits_deep = self.model(self.ids, n_loops=3)
+        assert not torch.allclose(logits_shallow, logits_deep)
+
+    def test_kv_cache_generate_matches_no_cache(self):
+        # Forward with and without KV cache should yield the same last-token logits
+        torch.manual_seed(0)
+        prompt = torch.randint(0, self.cfg.vocab_size, (1, T))
+        with torch.no_grad():
+            logits_no_cache = self.model(prompt, n_loops=2)[:, -1, :]
+            cache = {}
+            logits_cached = self.model(prompt, n_loops=2, kv_cache=cache)[:, -1, :]
+        assert torch.allclose(logits_no_cache, logits_cached, atol=1e-4)
+
+    def test_single_token_forward(self):
+        # Mask is None when T=1; should not crash
+        single = torch.randint(0, self.cfg.vocab_size, (B, 1))
+        logits = self.model(single)
+        assert logits.shape == (B, 1, self.cfg.vocab_size)
+
+
+# ---------------------------------------------------------------------------
+# OpenMythos — MLA mode
+# ---------------------------------------------------------------------------
+
+
+class TestOpenMythosMLA:
+    def setup_method(self):
+        self.cfg = mla_cfg()
+        self.model = OpenMythos(self.cfg)
+        self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))
+
+    def test_forward_shape(self):
+        logits = self.model(self.ids)
+        assert logits.shape == (B, T, self.cfg.vocab_size)
+
+    def test_forward_no_nan(self):
+        assert not torch.isnan(self.model(self.ids)).any()
+
+    def test_generate_shape(self):
+        out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
+        assert out.shape == (B, T + 4)
+
+    def test_lti_spectral_radius(self):
+        A = self.model.recurrent.injection.get_A()
+        assert A.max().item() < 1.0
+
+    def test_mla_cache_is_compressed(self):
+        # MLA cache should store c_kv (lora_rank), not full K/V (n_heads * head_dim)
+        cache = {}
+        with torch.no_grad():
+            self.model(self.ids, kv_cache=cache)
+        # find any MLA cache entry and check dimensions
+        mla_entries = {k: v for k, v in cache.items() if "c_kv" in v}
+        assert len(mla_entries) > 0
+        for entry in mla_entries.values():
+            assert entry["c_kv"].shape[-1] == self.cfg.kv_lora_rank
+
+
+# ---------------------------------------------------------------------------
+# GQA vs MLA: same config, different attn_type
+# ---------------------------------------------------------------------------
+
+
+class TestAttnTypeSwap:
+    def test_gqa_and_mla_produce_different_outputs(self):
+        cfg_gqa = gqa_cfg()
+        cfg_mla = mla_cfg()
+        ids = torch.randint(0, cfg_gqa.vocab_size, (B, T))
+        logits_gqa = OpenMythos(cfg_gqa)(ids)
+        logits_mla = OpenMythos(cfg_mla)(ids)
+        # different architectures, different params → outputs must differ
+        assert not torch.allclose(logits_gqa, logits_mla)
+
+    def test_both_modes_produce_valid_shapes(self):
+        ids = torch.randint(0, 200, (B, T))
+        for attn_type in ("gqa", "mla"):
+            cfg = gqa_cfg(attn_type=attn_type)
+            logits = OpenMythos(cfg)(ids)
+            assert logits.shape == (B, T, cfg.vocab_size)
+
+    def test_mla_fewer_kv_cache_bytes(self):
+        # MLA cache should be smaller than GQA cache for the same sequence
+        ids = torch.randint(0, 200, (1, T))
+        cache_gqa, cache_mla = {}, {}
+        with torch.no_grad():
+            OpenMythos(gqa_cfg())(ids, kv_cache=cache_gqa)
+            OpenMythos(mla_cfg())(ids, kv_cache=cache_mla)
+
+        def cache_bytes(cache):
+            return sum(
+                t.numel() * t.element_size()
+                for entry in cache.values()
+                for t in entry.values()
+            )
+
+        assert cache_bytes(cache_mla) < cache_bytes(cache_gqa)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "--verbose"])
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
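
Note on the LTI stability tests: test_spectral_radius_stable_after_large_grad_step can only pass if LTIInjection keeps A inside (0, 1) by construction, not by clipping after the fact. This patch does not show get_A() itself, so below is a minimal sketch of one parameterization that would satisfy the assertions; the sigmoid reparameterization and the names DiagonalLTISketch / a_logit are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn


class DiagonalLTISketch(nn.Module):
    """Hypothetical stand-in for LTIInjection's A parameterization."""

    def __init__(self, dim: int):
        super().__init__()
        # Unconstrained logit; sigmoid(0) = 0.5 at init.
        self.a_logit = nn.Parameter(torch.zeros(dim))

    def get_A(self) -> torch.Tensor:
        # Elementwise sigmoid keeps every entry strictly inside (0, 1),
        # so no optimizer step can push the spectral radius to 1 or beyond.
        return torch.sigmoid(self.a_logit)


# Mirrors the test: even an absurd lr=1e3 SGD step cannot break the bound.
m = DiagonalLTISketch(64)
m.get_A().sum().backward()
torch.optim.SGD(m.parameters(), lr=1e3).step()
assert m.get_A().max().item() < 1.0

Because the constraint lives in the parameterization, stability is a structural invariant rather than a property the optimizer has to preserve, which is exactly what the test exercises.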
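
A companion back-of-envelope check for test_mla_fewer_kv_cache_bytes, using the numbers from gqa_cfg(). It assumes the cache layouts the tests already imply (GQA caches full per-head K and V; MLA caches the compressed c_kv plus one shared k_rope per token), so treat it as a sketch of why the byte comparison holds, not the model's exact accounting.

# Cached floats per token, per layer, under the test config (assumed layouts).
dim, n_heads, n_kv_heads = 64, 4, 2
head_dim = dim // n_heads                       # 64 / 4 = 16
gqa_floats = 2 * n_kv_heads * head_dim          # K + V: 2 * 2 * 16 = 64

kv_lora_rank, qk_rope_head_dim = 16, 8
mla_floats = kv_lora_rank + qk_rope_head_dim    # c_kv + k_rope: 16 + 8 = 24

assert mla_floats < gqa_floats                  # 24 < 64, per token per layer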