Mirror of https://github.com/kyegomez/OpenMythos.git (synced 2026-05-02 09:33:27 +02:00)

commit 0699c00c94 ("tests")
parent c258cdc8da
@@ -45,4 +45,6 @@ out = model.generate(ids, max_new_tokens=8, n_loops=8)
 print(f"[{attn_type.upper()}] Generated shape: {out.shape}")

 A = model.recurrent.injection.get_A()
-print(f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)")
+print(
+    f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)"
+)

@@ -1012,4 +1012,3 @@ class OpenMythos(nn.Module):
         next_tok = torch.multinomial(probs, num_samples=1)
         input_ids = torch.cat([input_ids, next_tok], dim=1)
         return input_ids

test_main.py (548 lines, new file)
@@ -0,0 +1,548 @@
"""
|
||||
pytest suite for OpenMythos (second.py).
|
||||
Tests every major class and feature: shapes, correctness invariants, and
|
||||
architecture-specific properties (LTI stability, ACT halting, depth extrapolation,
|
||||
KV cache consistency, GQA vs MLA swap).
|
||||
"""
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
from open_mythos.main import (
|
||||
ACTHalting,
|
||||
Expert,
|
||||
GQAttention,
|
||||
LTIInjection,
|
||||
LoRAAdapter,
|
||||
MLAttention,
|
||||
MoEFFN,
|
||||
MythosConfig,
|
||||
OpenMythos,
|
||||
RecurrentBlock,
|
||||
RMSNorm,
|
||||
TransformerBlock,
|
||||
apply_rope,
|
||||
loop_index_embedding,
|
||||
precompute_rope_freqs,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared small configs (kept tiny so tests run fast on CPU)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
B, T = 2, 8 # batch, sequence length
|
||||
|
||||
|
||||


def gqa_cfg(**overrides) -> MythosConfig:
    defaults = dict(
        vocab_size=200,
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        max_seq_len=32,
        max_loop_iters=3,
        prelude_layers=1,
        coda_layers=1,
        attn_type="gqa",
        n_experts=4,
        n_shared_experts=1,
        n_experts_per_tok=2,
        expert_dim=16,
        act_threshold=0.99,
        lora_rank=4,
        # MLA fields must be valid even when not used
        kv_lora_rank=16,
        q_lora_rank=32,
        qk_rope_head_dim=8,
        qk_nope_head_dim=8,
        v_head_dim=8,
    )
    defaults.update(overrides)
    return MythosConfig(**defaults)


def mla_cfg(**overrides) -> MythosConfig:
    return gqa_cfg(attn_type="mla", **overrides)
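

# Usage note (values here are illustrative, not taken from the suite):
# gqa_cfg(dim=128, n_heads=8) returns the same tiny baseline config with just
# those two fields changed, since defaults.update(overrides) runs last.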


# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------


class TestRMSNorm:
    def test_output_shape(self):
        norm = RMSNorm(64)
        x = torch.randn(2, 8, 64)
        assert norm(x).shape == x.shape

    def test_unit_rms(self):
        # after norm the RMS of each vector should be ≈ 1 when weight=1
        norm = RMSNorm(64)
        torch.nn.init.ones_(norm.weight)
        x = torch.randn(4, 64)
        out = norm(x)
        rms = out.pow(2).mean(-1).sqrt()
        assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)

    def test_learnable_weight(self):
        norm = RMSNorm(8)
        assert norm.weight.requires_grad


# ---------------------------------------------------------------------------
# RoPE utilities
# ---------------------------------------------------------------------------


class TestRoPE:
    def test_precompute_shape(self):
        freqs = precompute_rope_freqs(dim=16, max_len=32)
        assert freqs.shape == (32, 8)  # (max_len, dim//2)
        assert freqs.is_complex()

    def test_apply_rope_shape(self):
        freqs = precompute_rope_freqs(dim=16, max_len=32)
        x = torch.randn(B, T, 4, 16)
        out = apply_rope(x, freqs)
        assert out.shape == x.shape

    def test_apply_rope_preserves_norm(self):
        # rotation is an isometry — norms must be unchanged
        freqs = precompute_rope_freqs(dim=16, max_len=32)
        x = torch.randn(B, T, 4, 16)
        out = apply_rope(x, freqs)
        assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-5)

    def test_different_positions_differ(self):
        freqs = precompute_rope_freqs(dim=16, max_len=32)
        x = torch.ones(1, 2, 1, 16)
        out = apply_rope(x, freqs)
        # position 0 and position 1 should produce different rotations
        assert not torch.allclose(out[0, 0], out[0, 1])
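

# Why the isometry test above must hold, as a standalone sketch (a toy
# rotation in the same spirit, not apply_rope's actual code): multiplying each
# complex pair by a unit-modulus phasor changes angles but never magnitudes.
#
#   x = torch.randn(4, 16)
#   xc = torch.view_as_complex(x.reshape(4, 8, 2))
#   phase = torch.polar(torch.ones(8), torch.rand(8) * 6.28)  # |phase| == 1
#   out = torch.view_as_real(xc * phase).flatten(-2)
#   assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-5)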


# ---------------------------------------------------------------------------
# GQAttention
# ---------------------------------------------------------------------------


class TestGQAttention:
    def setup_method(self):
        self.cfg = gqa_cfg()
        self.freqs = precompute_rope_freqs(
            self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len
        )
        self.attn = GQAttention(self.cfg)

    def test_output_shape(self):
        x = torch.randn(B, T, self.cfg.dim)
        out = self.attn(x, self.freqs)
        assert out.shape == (B, T, self.cfg.dim)

    def test_kv_cache_accumulates(self):
        cache = {}
        x = torch.randn(B, T, self.cfg.dim)
        self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0")
        assert "layer0" in cache
        k_len = cache["layer0"]["k"].shape[1]
        # second call adds T more tokens
        self.attn(x, self.freqs, kv_cache=cache, cache_key="layer0")
        assert cache["layer0"]["k"].shape[1] == k_len + T

    def test_with_causal_mask(self):
        x = torch.randn(B, T, self.cfg.dim)
        mask = torch.full((1, 1, T, T), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        out = self.attn(x, self.freqs, mask=mask)
        assert out.shape == (B, T, self.cfg.dim)
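

# The cache contract the accumulation test relies on (a sketch of the assumed
# semantics, not GQAttention's internals; the (B, T, 2, 16) layout is a guess
# from this file's config): each forward appends its new keys and values along
# the time axis (dim=1, as the shape[1] assertions above confirm).
#
#   cache = {"k": torch.zeros(B, 0, 2, 16)}
#   k_new = torch.randn(B, T, 2, 16)
#   cache["k"] = torch.cat([cache["k"], k_new], dim=1)  # length grows by T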


# ---------------------------------------------------------------------------
# MLAttention
# ---------------------------------------------------------------------------


class TestMLAttention:
    def setup_method(self):
        self.cfg = mla_cfg()
        self.freqs = precompute_rope_freqs(
            self.cfg.qk_rope_head_dim, self.cfg.max_seq_len
        )
        self.attn = MLAttention(self.cfg)

    def test_output_shape(self):
        x = torch.randn(B, T, self.cfg.dim)
        out = self.attn(x, self.freqs)
        assert out.shape == (B, T, self.cfg.dim)

    def test_cache_stores_compressed_kv(self):
        cache = {}
        x = torch.randn(B, T, self.cfg.dim)
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        assert "c_kv" in cache["mla0"]
        assert "k_rope" in cache["mla0"]
        # c_kv should have kv_lora_rank as last dim, not full K/V
        assert cache["mla0"]["c_kv"].shape[-1] == self.cfg.kv_lora_rank

    def test_cache_accumulates_across_steps(self):
        cache = {}
        x = torch.randn(B, T, self.cfg.dim)
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        first_len = cache["mla0"]["c_kv"].shape[1]
        self.attn(x, self.freqs, kv_cache=cache, cache_key="mla0")
        assert cache["mla0"]["c_kv"].shape[1] == first_len + T

    def test_with_causal_mask(self):
        x = torch.randn(B, T, self.cfg.dim)
        mask = torch.triu(torch.full((1, 1, T, T), float("-inf")), diagonal=1)
        out = self.attn(x, self.freqs, mask=mask)
        assert out.shape == (B, T, self.cfg.dim)
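

# Rough per-token cache arithmetic with this file's tiny config (illustrative;
# exact tensor layouts are the module's business): GQA stores full K and V,
# 2 * n_kv_heads * head_dim = 2 * 2 * 16 = 64 channels per token, while MLA
# stores a compressed latent plus a shared RoPE key,
# kv_lora_rank + qk_rope_head_dim = 16 + 8 = 24 channels per token. That gap
# is what test_mla_fewer_kv_cache_bytes (further down) measures end to end.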


# ---------------------------------------------------------------------------
# Expert (dense SwiGLU FFN)
# ---------------------------------------------------------------------------


class TestExpert:
    def test_output_shape(self):
        expert = Expert(dim=64, expert_dim=32)
        x = torch.randn(B, T, 64)
        assert expert(x).shape == (B, T, 64)

    def test_flat_input(self):
        expert = Expert(dim=32, expert_dim=16)
        x = torch.randn(5, 32)
        assert expert(x).shape == (5, 32)


# ---------------------------------------------------------------------------
# MoEFFN
# ---------------------------------------------------------------------------


class TestMoEFFN:
    def setup_method(self):
        self.cfg = gqa_cfg()
        self.moe = MoEFFN(self.cfg)

    def test_output_shape(self):
        x = torch.randn(B, T, self.cfg.dim)
        assert self.moe(x).shape == (B, T, self.cfg.dim)

    def test_router_bias_not_grad(self):
        # router_bias is a buffer, not a parameter
        param_names = {n for n, _ in self.moe.named_parameters()}
        assert "router_bias" not in param_names

    def test_shared_experts_always_fire(self):
        # Zero out all routed experts; output should still be nonzero from shared
        for exp in self.moe.routed_experts:
            for p in exp.parameters():
                p.data.zero_()
        x = torch.randn(B, T, self.cfg.dim)
        out = self.moe(x)
        assert out.abs().sum() > 0
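

# The zero-out test above leans on a standard shared + routed decomposition
# (a sketch of the assumed dataflow; names here are hypothetical):
#
#   out = sum(e(x) for e in shared_experts)          # always active
#   out += sum(w_i * routed_experts[i](x)            # top-k gated
#              for i, w_i in topk(router(x), k=2))
#
# Zeroing every routed expert's parameters kills only the second term, so a
# nonzero output proves the shared experts fired.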


# ---------------------------------------------------------------------------
# loop_index_embedding
# ---------------------------------------------------------------------------


class TestLoopIndexEmbedding:
    def test_output_shape(self):
        h = torch.randn(B, T, 64)
        out = loop_index_embedding(h, loop_t=0, loop_dim=8)
        assert out.shape == h.shape

    def test_different_iterations_differ(self):
        h = torch.zeros(1, 1, 64)
        out0 = loop_index_embedding(h, loop_t=0, loop_dim=8)
        out1 = loop_index_embedding(h, loop_t=1, loop_dim=8)
        assert not torch.allclose(out0, out1)

    def test_only_first_dims_modified(self):
        h = torch.zeros(1, 1, 64)
        loop_dim = 8
        out = loop_index_embedding(h, loop_t=3, loop_dim=loop_dim)
        # channels beyond loop_dim should be unchanged (still 0)
        assert torch.all(out[..., loop_dim:] == 0)
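

# A scheme consistent with all three tests above (an assumption about the
# implementation, not a quote of it): write an encoding of the loop index t
# into the first loop_dim channels and pass the rest through, e.g.
#
#   h = h.clone()
#   h[..., :loop_dim] += encode(loop_t, loop_dim)  # encode() is hypothetical
#   return h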


# ---------------------------------------------------------------------------
# LoRAAdapter
# ---------------------------------------------------------------------------


class TestLoRAAdapter:
    def setup_method(self):
        self.lora = LoRAAdapter(dim=64, rank=8, max_loops=10)

    def test_output_shape(self):
        x = torch.randn(B, T, 64)
        out = self.lora(x, loop_t=0)
        assert out.shape == (B, T, 64)

    def test_different_loops_differ(self):
        x = torch.randn(B, T, 64)
        out0 = self.lora(x, loop_t=0)
        out1 = self.lora(x, loop_t=1)
        assert not torch.allclose(out0, out1)


# ---------------------------------------------------------------------------
# TransformerBlock
# ---------------------------------------------------------------------------


class TestTransformerBlock:
    def test_gqa_output_shape(self):
        cfg = gqa_cfg()
        block = TransformerBlock(cfg, use_moe=False)
        freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
        x = torch.randn(B, T, cfg.dim)
        assert block(x, freqs).shape == (B, T, cfg.dim)

    def test_mla_output_shape(self):
        cfg = mla_cfg()
        block = TransformerBlock(cfg, use_moe=False)
        freqs = precompute_rope_freqs(cfg.qk_rope_head_dim, cfg.max_seq_len)
        x = torch.randn(B, T, cfg.dim)
        assert block(x, freqs).shape == (B, T, cfg.dim)

    def test_moe_block_output_shape(self):
        cfg = gqa_cfg()
        block = TransformerBlock(cfg, use_moe=True)
        freqs = precompute_rope_freqs(cfg.dim // cfg.n_heads, cfg.max_seq_len)
        x = torch.randn(B, T, cfg.dim)
        assert block(x, freqs).shape == (B, T, cfg.dim)

    def test_attn_type_selection(self):
        assert isinstance(TransformerBlock(gqa_cfg()).attn, GQAttention)
        assert isinstance(TransformerBlock(mla_cfg()).attn, MLAttention)


# ---------------------------------------------------------------------------
# LTIInjection
# ---------------------------------------------------------------------------


class TestLTIInjection:
    def setup_method(self):
        self.inj = LTIInjection(dim=64)

    def test_output_shape(self):
        h = torch.randn(B, T, 64)
        e = torch.randn(B, T, 64)
        t = torch.randn(B, T, 64)
        assert self.inj(h, e, t).shape == (B, T, 64)

    def test_spectral_radius_lt_1(self):
        A = self.inj.get_A()
        assert A.max().item() < 1.0

    def test_spectral_radius_gt_0(self):
        A = self.inj.get_A()
        assert A.min().item() > 0.0

    def test_spectral_radius_stable_after_large_grad_step(self):
        # Simulate an aggressive gradient update and verify stability holds
        opt = torch.optim.SGD(self.inj.parameters(), lr=1e3)
        h = torch.randn(B, T, 64)
        e = torch.randn(B, T, 64)
        t = torch.randn(B, T, 64)
        loss = self.inj(h, e, t).sum()
        loss.backward()
        opt.step()
        A = self.inj.get_A()
        assert A.max().item() < 1.0
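

# Why ρ(A) < 1 is the invariant worth a dedicated SGD stress test (standalone
# sketch; assumes A acts elementwise, as the min()/max() checks suggest): the
# looped update h <- A * h + (injected input) is then a contraction in h, so
# extra inference-time loops cannot diverge.
#
#   A = torch.full((64,), 0.95)
#   h = torch.randn(64)
#   for _ in range(500):
#       h = A * h
#   assert h.abs().max() < 1e-9  # geometric decay: 0.95**500 ≈ 7e-12
#
# The lr=1e3 step checks that the bound is enforced by the parameterization
# itself rather than by staying close to initialization.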


# ---------------------------------------------------------------------------
# ACTHalting
# ---------------------------------------------------------------------------


class TestACTHalting:
    def setup_method(self):
        self.act = ACTHalting(dim=64)

    def test_output_shape(self):
        h = torch.randn(B, T, 64)
        p = self.act(h)
        assert p.shape == (B, T)

    def test_values_in_01(self):
        h = torch.randn(B, T, 64)
        p = self.act(h)
        assert p.min().item() >= 0.0
        assert p.max().item() <= 1.0
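

# The [0, 1] range check is consistent with a sigmoid halting head (an
# assumption about ACTHalting's internals, not its actual code):
#
#   p = torch.sigmoid(h @ w)  # any real logit lands in (0, 1)
#
# Under ACT, a token stops looping once its cumulative halting probability
# crosses act_threshold (0.99 in these configs).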


# ---------------------------------------------------------------------------
# RecurrentBlock
# ---------------------------------------------------------------------------


class TestRecurrentBlock:
    def setup_method(self):
        self.cfg = gqa_cfg()
        self.block = RecurrentBlock(self.cfg)
        self.freqs = precompute_rope_freqs(
            self.cfg.dim // self.cfg.n_heads, self.cfg.max_seq_len
        )

    def test_output_shape(self):
        h = torch.randn(B, T, self.cfg.dim)
        e = torch.randn(B, T, self.cfg.dim)
        out = self.block(h, e, self.freqs)
        assert out.shape == (B, T, self.cfg.dim)

    def test_more_loops_changes_output(self):
        h = torch.randn(B, T, self.cfg.dim)
        e = torch.randn(B, T, self.cfg.dim)
        out1 = self.block(h.clone(), e.clone(), self.freqs, n_loops=1)
        out3 = self.block(h.clone(), e.clone(), self.freqs, n_loops=3)
        assert not torch.allclose(out1, out3)

    def test_single_loop_runs(self):
        h = torch.randn(B, T, self.cfg.dim)
        e = torch.randn(B, T, self.cfg.dim)
        out = self.block(h, e, self.freqs, n_loops=1)
        assert out.shape == (B, T, self.cfg.dim)


# ---------------------------------------------------------------------------
# OpenMythos — GQA mode
# ---------------------------------------------------------------------------


class TestOpenMythosGQA:
    def setup_method(self):
        self.cfg = gqa_cfg()
        self.model = OpenMythos(self.cfg)
        self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))

    def test_forward_shape(self):
        logits = self.model(self.ids)
        assert logits.shape == (B, T, self.cfg.vocab_size)

    def test_forward_no_nan(self):
        logits = self.model(self.ids)
        assert not torch.isnan(logits).any()

    def test_generate_shape(self):
        out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
        assert out.shape == (B, T + 4)

    def test_weight_tying(self):
        assert self.model.head.weight is self.model.embed.weight

    def test_lti_spectral_radius(self):
        A = self.model.recurrent.injection.get_A()
        assert A.max().item() < 1.0

    def test_depth_extrapolation_changes_output(self):
        # More loops at inference should produce different (ideally better) output
        logits_shallow = self.model(self.ids, n_loops=1)
        logits_deep = self.model(self.ids, n_loops=3)
        assert not torch.allclose(logits_shallow, logits_deep)

    def test_kv_cache_generate_matches_no_cache(self):
        # Single-token generation with and without cache should agree
        torch.manual_seed(0)
        prompt = torch.randint(0, self.cfg.vocab_size, (1, T))
        with torch.no_grad():
            logits_no_cache = self.model(prompt, n_loops=2)[:, -1, :]
            cache = {}
            logits_cached = self.model(prompt, n_loops=2, kv_cache=cache)[:, -1, :]
        assert torch.allclose(logits_no_cache, logits_cached, atol=1e-4)

    def test_single_token_forward(self):
        # Mask is None when T=1; should not crash
        single = torch.randint(0, self.cfg.vocab_size, (B, 1))
        logits = self.model(single)
        assert logits.shape == (B, 1, self.cfg.vocab_size)
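

# Weight tying, as asserted by test_weight_tying above, means the LM head
# reuses the embedding matrix as its output projection. The canonical pattern
# (a sketch with torch.nn as nn, not the model's exact code):
#
#   embed = nn.Embedding(vocab_size, dim)   # weight shape: (vocab_size, dim)
#   head = nn.Linear(dim, vocab_size, bias=False)
#   head.weight = embed.weight              # same Parameter, so `is` passes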


# ---------------------------------------------------------------------------
# OpenMythos — MLA mode
# ---------------------------------------------------------------------------


class TestOpenMythosMLA:
    def setup_method(self):
        self.cfg = mla_cfg()
        self.model = OpenMythos(self.cfg)
        self.ids = torch.randint(0, self.cfg.vocab_size, (B, T))

    def test_forward_shape(self):
        logits = self.model(self.ids)
        assert logits.shape == (B, T, self.cfg.vocab_size)

    def test_forward_no_nan(self):
        assert not torch.isnan(self.model(self.ids)).any()

    def test_generate_shape(self):
        out = self.model.generate(self.ids, max_new_tokens=4, n_loops=2)
        assert out.shape == (B, T + 4)

    def test_lti_spectral_radius(self):
        A = self.model.recurrent.injection.get_A()
        assert A.max().item() < 1.0

    def test_mla_cache_is_compressed(self):
        # MLA cache should store c_kv (lora_rank), not full K/V (n_heads * head_dim)
        cache = {}
        with torch.no_grad():
            self.model(self.ids, kv_cache=cache)
        # find any MLA cache entry and check dimensions
        mla_entries = {k: v for k, v in cache.items() if "c_kv" in v}
        assert len(mla_entries) > 0
        for entry in mla_entries.values():
            assert entry["c_kv"].shape[-1] == self.cfg.kv_lora_rank


# ---------------------------------------------------------------------------
# GQA vs MLA: same config, different attn_type
# ---------------------------------------------------------------------------


class TestAttnTypeSwap:
    def test_gqa_and_mla_produce_different_outputs(self):
        cfg_gqa = gqa_cfg()
        cfg_mla = mla_cfg()
        ids = torch.randint(0, cfg_gqa.vocab_size, (B, T))
        logits_gqa = OpenMythos(cfg_gqa)(ids)
        logits_mla = OpenMythos(cfg_mla)(ids)
        # different architectures, different params → outputs must differ
        assert not torch.allclose(logits_gqa, logits_mla)

    def test_both_modes_produce_valid_shapes(self):
        ids = torch.randint(0, 200, (B, T))
        for attn_type in ("gqa", "mla"):
            cfg = gqa_cfg(attn_type=attn_type)
            logits = OpenMythos(cfg)(ids)
            assert logits.shape == (B, T, cfg.vocab_size)

    def test_mla_fewer_kv_cache_bytes(self):
        # MLA cache should be smaller than GQA cache for the same sequence
        ids = torch.randint(0, 200, (1, T))
        cache_gqa, cache_mla = {}, {}
        with torch.no_grad():
            OpenMythos(gqa_cfg())(ids, kv_cache=cache_gqa)
            OpenMythos(mla_cfg())(ids, kv_cache=cache_mla)

        def cache_bytes(cache):
            return sum(
                t.numel() * t.element_size()
                for entry in cache.values()
                for t in entry.values()
            )

        assert cache_bytes(cache_mla) < cache_bytes(cache_gqa)


if __name__ == "__main__":
    pytest.main([__file__, "--verbose"])

tests/__init__.py (0 lines, new file)