""" pytest suite for OpenMythos (second.py). Tests every major class and feature: shapes, correctness invariants, and architecture-specific properties (LTI stability, ACT halting, depth extrapolation, KV cache consistency, GQA vs MLA swap). """ import torch import pytest from open_mythos.main import ( ACTHalting, Expert, GQAttention, LTIInjection, LoRAAdapter, MLAttention, MoEFFN, MythosConfig, OpenMythos, RecurrentBlock, RMSNorm, TransformerBlock, apply_rope, loop_index_embedding, precompute_rope_freqs, ) # --------------------------------------------------------------------------- # Shared small configs (kept tiny so tests run fast on CPU) # --------------------------------------------------------------------------- B, T = 2, 8 # batch, sequence length def gqa_cfg(**overrides) -> MythosConfig: defaults = dict( vocab_size=200, dim=64, n_heads=4, n_kv_heads=2, max_seq_len=32, max_loop_iters=3, prelude_layers=1, coda_layers=1, attn_type="gqa", n_experts=4, n_shared_experts=1, n_experts_per_tok=2, expert_dim=16, act_threshold=0.99, lora_rank=4, # MLA fields must be valid even when not used kv_lora_rank=16, q_lora_rank=32, qk_rope_head_dim=8, qk_nope_head_dim=8, v_head_dim=8, ) defaults.update(overrides) return MythosConfig(**defaults) def mla_cfg(**overrides) -> MythosConfig: return gqa_cfg(attn_type="mla", **overrides) # --------------------------------------------------------------------------- # RMSNorm # --------------------------------------------------------------------------- class TestRMSNorm: def test_output_shape(self): norm = RMSNorm(64) x = torch.randn(2, 8, 64) assert norm(x).shape == x.shape def test_unit_rms(self): # after norm the RMS of each vector should be ≈ 1 when weight=1 norm = RMSNorm(64) torch.nn.init.ones_(norm.weight) x = torch.randn(4, 64) out = norm(x) rms = out.pow(2).mean(-1).sqrt() assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4) def test_learnable_weight(self): norm = RMSNorm(8) assert norm.weight.requires_grad # --------------------------------------------------------------------------- # RoPE utilities # --------------------------------------------------------------------------- class TestRoPE: def test_precompute_shape(self): freqs = precompute_rope_freqs(dim=16, max_len=32) assert freqs.shape == (32, 8) # (max_len, dim//2) assert freqs.is_complex() def test_apply_rope_shape(self): freqs = precompute_rope_freqs(dim=16, max_len=32) x = torch.randn(B, T, 4, 16) out = apply_rope(x, freqs) assert out.shape == x.shape def test_apply_rope_preserves_norm(self): # rotation is an isometry — norms must be unchanged freqs = precompute_rope_freqs(dim=16, max_len=32) x = torch.randn(B, T, 4, 16) out = apply_rope(x, freqs) assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-5) def test_different_positions_differ(self): freqs = precompute_rope_freqs(dim=16, max_len=32) x = torch.ones(1, 2, 1, 16) out = apply_rope(x, freqs) # position 0 and position 1 should produce different rotations assert not torch.allclose(out[0, 0], out[0, 1]) # --------------------------------------------------------------------------- # GQAttention # --------------------------------------------------------------------------- class TestGQAttention: def setup_method(self): self.cfg = gqa_cfg() self.freqs = precompute_rope_freqs( self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len ) self.attn = GQAttention(self.cfg) def test_output_shape(self): x = torch.randn(B, T, self.cfg.dim) out = self.attn(x, self.freqs) assert out.shape == (B, T, self.cfg.dim) def 

# ---------------------------------------------------------------------------
# MLAttention
# ---------------------------------------------------------------------------
class TestMLAttention:
    def setup_method(self):
        self.cfg = mla_cfg()
        self.freqs = precompute_rope_freqs(
            self.cfg.qk_rope_head_dim, self.cfg.max_seq_len
        )
        self.attn = MLAttention(self.cfg)

    def test_output_shape(self):
        x = torch.randn(B, T, self.cfg.dim)
        out = self.attn(x, self.freqs)
        assert out.shape == (B, T, self.cfg.dim)

    def test_cache_stores_compressed_kv(self):
        cache = {}
        x = torch.randn(B, T, self.cfg.dim)
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        assert "c_kv" in cache["mla0"]
        assert "k_rope" in cache["mla0"]
        # c_kv should have kv_lora_rank as last dim, not full K/V
        assert cache["mla0"]["c_kv"].shape[-1] == self.cfg.kv_lora_rank

    def test_cache_accumulates_across_steps(self):
        cache = {}
        x = torch.randn(B, T, self.cfg.dim)
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        first_len = cache["mla0"]["c_kv"].shape[1]
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        assert cache["mla0"]["c_kv"].shape[1] == first_len + T

    def test_with_causal_mask(self):
        x = torch.randn(B, T, self.cfg.dim)
        mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1)
        out = self.attn(x, self.freqs, mask=mask)
        assert out.shape == (B, T, self.cfg.dim)


# ---------------------------------------------------------------------------
# Expert (dense SwiGLU FFN)
# ---------------------------------------------------------------------------
class TestExpert:
    def test_output_shape(self):
        expert = Expert(dim=64, expert_dim=32)
        x = torch.randn(B, T, 64)
        assert expert(x).shape == (B, T, 64)

    def test_flat_input(self):
        expert = Expert(dim=32, expert_dim=16)
        x = torch.randn(5, 32)
        assert expert(x).shape == (5, 32)


# ---------------------------------------------------------------------------
# MoEFFN
# ---------------------------------------------------------------------------
class TestMoEFFN:
    def setup_method(self):
        self.cfg = gqa_cfg()
        self.moe = MoEFFN(self.cfg)

    def test_output_shape(self):
        x = torch.randn(B, T, self.cfg.dim)
        assert self.moe(x).shape == (B, T, self.cfg.dim)

    def test_router_bias_not_grad(self):
        # router_bias is a buffer, not a parameter
        param_names = {n for n, _ in self.moe.named_parameters()}
        assert "router_bias" not in param_names

    def test_shared_experts_always_fire(self):
        # Zero out all routed experts; output should still be nonzero from shared
        for exp in self.moe.routed_experts:
            for p in exp.parameters():
                p.data.zero_()
        x = torch.randn(B, T, self.cfg.dim)
        out = self.moe(x)
        assert out.abs().sum() > 0


# ---------------------------------------------------------------------------
# loop_index_embedding
# ---------------------------------------------------------------------------
class TestLoopIndexEmbedding:
    def test_output_shape(self):
        h = torch.randn(B, T, 64)
        out = loop_index_embedding(h, loop_t=0, loop_dim=8)
        assert out.shape == h.shape

    def test_different_iterations_differ(self):
        h = torch.zeros(1, 1, 64)
        out0 = loop_index_embedding(h, loop_t=0, loop_dim=8)
        out1 = loop_index_embedding(h, loop_t=1, loop_dim=8)
        assert not torch.allclose(out0, out1)

    def test_only_first_dims_modified(self):
        h = torch.zeros(1, 1, 64)
        loop_dim = 8
        out = loop_index_embedding(h, loop_t=3, loop_dim=loop_dim)
        # channels beyond loop_dim should be unchanged (still 0)
        assert torch.all(out[..., loop_dim:] == 0)
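
# Complementary sketch: assumes loop_index_embedding writes the loop signal
# only into the first loop_dim channels, so any difference between iterations
# must land in those channels.
class TestLoopIndexEmbeddingSignal:
    def test_signal_lands_in_first_dims(self):
        h = torch.zeros(1, 1, 64)
        out0 = loop_index_embedding(h, loop_t=0, loop_dim=8)
        out1 = loop_index_embedding(h, loop_t=1, loop_dim=8)
        # iterations differ and later channels stay zero, so the difference
        # has to live in the first loop_dim channels
        assert not torch.allclose(out0[..., :8], out1[..., :8])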

# ---------------------------------------------------------------------------
# LoRAAdapter
# ---------------------------------------------------------------------------
class TestLoRAAdapter:
    def setup_method(self):
        self.lora = LoRAAdapter(dim=64, rank=8, max_loops=10)

    def test_output_shape(self):
        x = torch.randn(B, T, 64)
        out = self.lora(x, loop_t=0)
        assert out.shape == (B, T, 64)

    def test_different_loops_differ(self):
        x = torch.randn(B, T, 64)
        out0 = self.lora(x, loop_t=0)
        out1 = self.lora(x, loop_t=1)
        assert not torch.allclose(out0, out1)


# ---------------------------------------------------------------------------
# TransformerBlock
# ---------------------------------------------------------------------------
class TestTransformerBlock:
    def test_gqa_output_shape(self):
        cfg = gqa_cfg()
        block = TransformerBlock(cfg, use_moe=False)
        freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
        x = torch.randn(B, T, cfg.dim)
        assert block(x, freqs).shape == (B, T, cfg.dim)

    def test_mla_output_shape(self):
        cfg = mla_cfg()
        block = TransformerBlock(cfg, use_moe=False)
        freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len)
        x = torch.randn(B, T, cfg.dim)
        assert block(x, freqs).shape == (B, T, cfg.dim)

    def test_moe_block_output_shape(self):
        cfg = gqa_cfg()
        block = TransformerBlock(cfg, use_moe=True)
        freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
        x = torch.randn(B, T, cfg.dim)
        assert block(x, freqs).shape == (B, T, cfg.dim)

    def test_attn_type_selection(self):
        assert isinstance(TransformerBlock(gqa_cfg()).attn, GQAttention)
        assert isinstance(TransformerBlock(mla_cfg()).attn, MLAttention)


# ---------------------------------------------------------------------------
# LTIInjection
# ---------------------------------------------------------------------------
class TestLTIInjection:
    def setup_method(self):
        self.inj = LTIInjection(dim=64)

    def test_output_shape(self):
        h = torch.randn(B, T, 64)
        e = torch.randn(B, T, 64)
        t = torch.randn(B, T, 64)
        assert self.inj(h, e, t).shape == (B, T, 64)

    def test_spectral_radius_lt_1(self):
        A = self.inj.get_A()
        assert A.max().item() < 1.0

    def test_spectral_radius_gt_0(self):
        A = self.inj.get_A()
        assert A.min().item() > 0.0

    def test_spectral_radius_stable_after_large_grad_step(self):
        # Simulate an aggressive gradient update and verify stability holds
        opt = torch.optim.SGD(self.inj.parameters(), lr=1e3)
        h = torch.randn(B, T, 64)
        e = torch.randn(B, T, 64)
        t = torch.randn(B, T, 64)
        loss = self.inj(h, e, t).sum()
        loss.backward()
        opt.step()
        A = self.inj.get_A()
        assert A.max().item() < 1.0


# ---------------------------------------------------------------------------
# ACTHalting
# ---------------------------------------------------------------------------
class TestACTHalting:
    def setup_method(self):
        self.act = ACTHalting(dim=64)

    def test_output_shape(self):
        h = torch.randn(B, T, 64)
        p = self.act(h)
        assert p.shape == (B, T)

    def test_values_in_01(self):
        h = torch.randn(B, T, 64)
        p = self.act(h)
        assert p.min().item() >= 0.0
        assert p.max().item() <= 1.0
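
# Behavioural sketch for ACTHalting: assumes the halting head is a plain
# deterministic projection (no sampling in eval mode) and that its output
# depends on the hidden state rather than being constant.
class TestACTHaltingBehaviour:
    def test_eval_mode_deterministic(self):
        act = ACTHalting(dim=64).eval()
        h = torch.randn(B, T, 64)
        # two identical forward passes should agree exactly in eval mode
        assert torch.allclose(act(h), act(h))

    def test_depends_on_hidden_state(self):
        act = ACTHalting(dim=64)
        p_zero = act(torch.zeros(B, T, 64))
        p_rand = act(torch.randn(B, T, 64))
        assert not torch.allclose(p_zero, p_rand)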

# ---------------------------------------------------------------------------
# RecurrentBlock
# ---------------------------------------------------------------------------
class TestRecurrentBlock:
    def setup_method(self):
        self.cfg = gqa_cfg()
        self.block = RecurrentBlock(self.cfg)
        self.freqs = precompute_rope_freqs(
            self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len
        )

    def test_output_shape(self):
        h = torch.randn(B, T, self.cfg.dim)
        e = torch.randn(B, T, self.cfg.dim)
        out = self.block(h, e, self.freqs)
        assert out.shape == (B, T, self.cfg.dim)

    def test_more_loops_changes_output(self):
        h = torch.randn(B, T, self.cfg.dim)
        e = torch.randn(B, T, self.cfg.dim)
        out1 = self.block(h.clone(), e.clone(), self.freqs, n_loops=1)
        out3 = self.block(h.clone(), e.clone(), self.freqs, n_loops=3)
        assert not torch.allclose(out1, out3)

    def test_single_loop_runs(self):
        h = torch.randn(B, T, self.cfg.dim)
        e = torch.randn(B, T, self.cfg.dim)
        out = self.block(h, e, self.freqs, n_loops=1)
        assert out.shape == (B, T, self.cfg.dim)


# ---------------------------------------------------------------------------
# OpenMythos — GQA mode
# ---------------------------------------------------------------------------
class TestOpenMythosGQA:
    def setup_method(self):
        self.cfg = gqa_cfg()
        self.model = OpenMythos(self.cfg)
        self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))

    def test_forward_shape(self):
        logits = self.model(self.ids)
        assert logits.shape == (B, T, self.cfg.vocab_size)

    def test_forward_no_nan(self):
        logits = self.model(self.ids)
        assert not torch.isnan(logits).any()

    def test_generate_shape(self):
        out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
        assert out.shape == (B, T + 4)

    def test_weight_tying(self):
        assert self.model.head.weight is self.model.embed.weight

    def test_lti_spectral_radius(self):
        A = self.model.recurrent.injection.get_A()
        assert A.max().item() < 1.0

    def test_depth_extrapolation_changes_output(self):
        # More loops at inference should produce different (ideally better) output
        logits_shallow = self.model(self.ids, n_loops=1)
        logits_deep = self.model(self.ids, n_loops=3)
        assert not torch.allclose(logits_shallow, logits_deep)

    def test_kv_cache_generate_matches_no_cache(self):
        # A forward pass over the same prompt, with and without a KV cache,
        # should produce the same last-token logits
        torch.manual_seed(0)
        prompt = torch.randint(0, self.cfg.vocab_size, (1, T))
        with torch.no_grad():
            logits_no_cache = self.model(prompt, n_loops=2)[:, -1, :]
            cache = {}
            logits_cached = self.model(prompt, n_loops=2, kv_cache=cache)[:, -1, :]
        assert torch.allclose(logits_no_cache, logits_cached, atol=1e-4)

    def test_single_token_forward(self):
        # Mask is None when T=1; should not crash
        single = torch.randint(0, self.cfg.vocab_size, (B, 1))
        logits = self.model(single)
        assert logits.shape == (B, 1, self.cfg.vocab_size)
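
# End-to-end backward sketch: uses a plain sum over the logits as a stand-in
# loss and assumes gradients should reach the tied embedding matrix
# (model.embed.weight, as referenced in test_weight_tying above).
class TestOpenMythosGQABackward:
    def test_backward_reaches_embedding(self):
        cfg = gqa_cfg()
        model = OpenMythos(cfg)
        ids = torch.randint(0, cfg.vocab_size, (B, T))
        model(ids).sum().backward()
        assert model.embed.weight.grad is not None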

# ---------------------------------------------------------------------------
# OpenMythos — MLA mode
# ---------------------------------------------------------------------------
class TestOpenMythosMLA:
    def setup_method(self):
        self.cfg = mla_cfg()
        self.model = OpenMythos(self.cfg)
        self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))

    def test_forward_shape(self):
        logits = self.model(self.ids)
        assert logits.shape == (B, T, self.cfg.vocab_size)

    def test_forward_no_nan(self):
        assert not torch.isnan(self.model(self.ids)).any()

    def test_generate_shape(self):
        out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
        assert out.shape == (B, T + 4)

    def test_lti_spectral_radius(self):
        A = self.model.recurrent.injection.get_A()
        assert A.max().item() < 1.0

    def test_mla_cache_is_compressed(self):
        # MLA cache should store c_kv (lora_rank), not full K/V (n_heads * head_dim)
        cache = {}
        with torch.no_grad():
            self.model(self.ids, kv_cache=cache)
        # find any MLA cache entry and check dimensions
        mla_entries = {k: v for k, v in cache.items() if "c_kv" in v}
        assert len(mla_entries) > 0
        for entry in mla_entries.values():
            assert entry["c_kv"].shape[-1] == self.cfg.kv_lora_rank


# ---------------------------------------------------------------------------
# GQA vs MLA: same config, different attn_type
# ---------------------------------------------------------------------------
class TestAttnTypeSwap:
    def test_gqa_and_mla_produce_different_outputs(self):
        cfg_gqa = gqa_cfg()
        cfg_mla = mla_cfg()
        ids = torch.randint(0, cfg_gqa.vocab_size, (B, T))
        logits_gqa = OpenMythos(cfg_gqa)(ids)
        logits_mla = OpenMythos(cfg_mla)(ids)
        # different architectures, different params → outputs must differ
        assert not torch.allclose(logits_gqa, logits_mla)

    def test_both_modes_produce_valid_shapes(self):
        ids = torch.randint(0, 200, (B, T))
        for attn_type in ("gqa", "mla"):
            cfg = gqa_cfg(attn_type=attn_type)
            logits = OpenMythos(cfg)(ids)
            assert logits.shape == (B, T, cfg.vocab_size)

    def test_mla_fewer_kv_cache_bytes(self):
        # MLA cache should be smaller than GQA cache for the same sequence
        ids = torch.randint(0, 200, (1, T))
        cache_gqa, cache_mla = {}, {}
        with torch.no_grad():
            OpenMythos(gqa_cfg())(ids, kv_cache=cache_gqa)
            OpenMythos(mla_cfg())(ids, kv_cache=cache_mla)

        def cache_bytes(cache):
            return sum(
                t.numel() * t.element_size()
                for entry in cache.values()
                for t in entry.values()
            )

        assert cache_bytes(cache_mla) < cache_bytes(cache_gqa)


if __name__ == "__main__":
    pytest.main([__file__, "--verbose"])