"""Smoke test for OpenMythos.

Builds a small model with either MLA or GQA attention, counts parameters,
runs a forward pass and autoregressive generation, and prints a stability
check on the recurrent injection matrix.

Run as a script; importing this module has no side effects.
"""

import torch

from open_mythos.main import MythosConfig, OpenMythos


def _make_config(attn_type: str) -> MythosConfig:
    """Return a small test config for the given attention variant.

    ``attn_type`` is ``"gqa"`` or ``"mla"``; the two variants require
    different extra hyperparameters on top of the shared base settings.
    """
    base = {
        "vocab_size": 1000,
        "dim": 256,
        "n_heads": 8,
        "max_seq_len": 128,
        "max_loop_iters": 4,
        "prelude_layers": 1,
        "coda_layers": 1,
        "n_experts": 8,
        "n_shared_experts": 1,
        "n_experts_per_tok": 2,
        "expert_dim": 64,
        "lora_rank": 8,
        "attn_type": attn_type,
    }
    if attn_type == "gqa":
        # GQA only needs the number of key/value heads.
        return MythosConfig(**base, n_kv_heads=2)
    # MLA needs the low-rank / decoupled-RoPE head dimensions as well.
    return MythosConfig(
        **base,
        n_kv_heads=8,
        kv_lora_rank=32,
        q_lora_rank=64,
        qk_rope_head_dim=16,
        qk_nope_head_dim=16,
        v_head_dim=16,
    )


def main(attn_type: str = "mla") -> None:
    """Build the model and exercise forward, generate, and the A-matrix check.

    ``attn_type``: ``"mla"`` (default) or ``"gqa"``.
    """
    cfg = _make_config(attn_type)
    model = OpenMythos(cfg)

    total = sum(p.numel() for p in model.parameters())
    print(f"\n[{attn_type.upper()}] Parameters: {total:,}")

    # Forward pass on a small random batch.
    ids = torch.randint(0, cfg.vocab_size, (2, 16))
    logits = model(ids, n_loops=4)
    print(f"[{attn_type.upper()}] Logits shape: {logits.shape}")

    # Autoregressive generation from the same prompt, with more loop iters.
    out = model.generate(ids, max_new_tokens=8, n_loops=8)
    print(f"[{attn_type.upper()}] Generated shape: {out.shape}")

    # NOTE(review): this prints the max entry of get_A(); it is a spectral
    # radius only if get_A() returns per-unit radii — confirm against the
    # injection module's implementation.
    A = model.recurrent.injection.get_A()
    print(
        f"[{attn_type.upper()}] Spectral radius ρ(A) max: {A.max().item():.4f} (must be < 1)"
    )


if __name__ == "__main__":
    main()