disclaimer and rope tests

Kye Gomez 2026-04-18 23:44:31 -04:00
parent 14a806c2a1
commit 53f786afda
2 changed files with 135 additions and 0 deletions

View File

@@ -1,5 +1,7 @@
# OpenMythos
> **Disclaimer:** OpenMythos is an independent, community-driven theoretical reconstruction based solely on publicly available research and speculation. It is not affiliated with, endorsed by, or connected to Anthropic or any of their proprietary systems.
OpenMythos is an open-source, theoretical implementation of the Claude Mythos model. It implements a Recurrent-Depth Transformer (RDT) with three stages: a **Prelude** of transformer blocks, a looped **Recurrent Block** (applied up to `max_loop_iters` times), and a final **Coda**. Attention is switchable between MLA and GQA, and the feed-forward layer uses a sparse MoE with routed and shared experts, making the model well suited to exploring compute-adaptive, depth-variable reasoning.
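
A minimal sketch of that three-stage forward pass. The module names (`prelude`, `recurrent_block`, `coda`) are illustrative assumptions, and the vanilla `nn.TransformerEncoderLayer` stand-ins do not reflect the MLA/GQA attention or MoE feed-forward of the actual blocks; only `max_loop_iters` comes from the description above:

```python
import torch
import torch.nn as nn


class RecurrentDepthSketch(nn.Module):
    """Illustrative only: Prelude -> weight-tied Recurrent Block loop -> Coda."""

    def __init__(self, dim: int, n_heads: int, max_loop_iters: int):
        super().__init__()
        self.prelude = nn.TransformerEncoderLayer(dim, n_heads, batch_first=True)
        self.recurrent_block = nn.TransformerEncoderLayer(dim, n_heads, batch_first=True)
        self.coda = nn.TransformerEncoderLayer(dim, n_heads, batch_first=True)
        self.max_loop_iters = max_loop_iters

    def forward(self, x: torch.Tensor, loop_iters: int | None = None) -> torch.Tensor:
        h = self.prelude(x)  # fixed-depth entry stage
        # Re-applying the same weight-tied block varies effective depth at inference.
        for _ in range(min(loop_iters or self.max_loop_iters, self.max_loop_iters)):
            h = self.recurrent_block(h)
        return self.coda(h)  # fixed-depth exit stage
```

Requesting fewer loop iterations trades reasoning depth for compute, which is the compute-adaptive behavior the recurrent stage exists to expose.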

View File

@@ -121,6 +121,139 @@ class TestRoPE:
        assert not torch.allclose(out[0, 0], out[0, 1])


# ---------------------------------------------------------------------------
# RoPE extended — correctness invariants
# ---------------------------------------------------------------------------
class TestRoPEExtended:
    """Comprehensive correctness tests for precompute_rope_freqs and apply_rope."""

    # --- precompute_rope_freqs ---

    def test_position_zero_is_unit_phasor(self):
        """freqs[0] must be all 1+0j (angle = 0 * freq = 0 for every pair)."""
        freqs = precompute_rope_freqs(dim=16, max_len=8)
        expected = torch.ones(8, dtype=torch.complex64)
        assert torch.allclose(freqs[0], expected, atol=1e-6)

    def test_all_phasors_have_unit_magnitude(self):
        """Every phasor magnitude must be 1 — RoPE is an isometric rotation."""
        freqs = precompute_rope_freqs(dim=16, max_len=32)
        assert torch.allclose(freqs.abs(), torch.ones_like(freqs.abs()), atol=1e-6)

    def test_angles_equal_outer_product(self):
        """freqs[t, k].angle() must equal t × base_freq[k] for all t, k."""
        dim, max_len, theta = 8, 6, 500000.0
        freqs = precompute_rope_freqs(dim=dim, max_len=max_len, theta=theta)
        base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        t = torch.arange(max_len, dtype=torch.float32)
        expected = torch.polar(torch.ones(max_len, dim // 2), torch.outer(t, base))
        assert torch.allclose(freqs.real, expected.real, atol=1e-6)
        assert torch.allclose(freqs.imag, expected.imag, atol=1e-6)

    def test_higher_theta_produces_smaller_angles(self):
        """Larger theta → slower frequency decay → smaller rotation angle per step."""
        dim, max_len = 16, 8
        freqs_fast = precompute_rope_freqs(dim=dim, max_len=max_len, theta=100.0)
        freqs_slow = precompute_rope_freqs(dim=dim, max_len=max_len, theta=500000.0)
        # The first pair's base frequency is theta**0 == 1 for any theta, so the
        # strict inequality holds only from the second frequency pair onward.
        assert (freqs_fast[1].angle().abs()[1:] > freqs_slow[1].angle().abs()[1:]).all()
    def test_default_theta_matches_explicit(self):
        """Omitting theta must equal passing theta=500000.0."""
        f1 = precompute_rope_freqs(16, 8)
        f2 = precompute_rope_freqs(16, 8, theta=500000.0)
        assert torch.allclose(f1.real, f2.real) and torch.allclose(f1.imag, f2.imag)

    # --- apply_rope ---

    def test_position_zero_is_identity(self):
        """T=1 input uses only freqs[0] = 1+0j, so output must equal input."""
        freqs = precompute_rope_freqs(dim=16, max_len=8)
        x = torch.randn(2, 1, 4, 16)
        out = apply_rope(x, freqs)
        assert torch.allclose(x, out, atol=1e-6)

    def test_dtype_float32_preserved(self):
        freqs = precompute_rope_freqs(dim=16, max_len=16)
        x = torch.randn(1, 4, 2, 16).float()
        assert apply_rope(x, freqs).dtype == torch.float32

    def test_dtype_float16_preserved(self):
        freqs = precompute_rope_freqs(dim=16, max_len=16)
        x = torch.randn(1, 4, 2, 16).half()
        assert apply_rope(x, freqs).dtype == torch.float16

    def test_inverse_rotation_recovers_input(self):
        """Rotating by freqs then by conj(freqs) (inverse) must recover the original."""
        dim = 16
        freqs = precompute_rope_freqs(dim=dim, max_len=8)
        x = torch.randn(2, 4, 3, dim)
        rotated = apply_rope(x, freqs)
        xc = torch.view_as_complex(rotated.float().reshape(*rotated.shape[:-1], -1, 2))
        inv = freqs.conj()[: rotated.shape[1]].unsqueeze(0).unsqueeze(2)
        recovered = torch.view_as_real(xc * inv).flatten(-2).to(x.dtype)
        assert torch.allclose(x, recovered, atol=1e-5)

    def test_batch_independence(self):
        """Output for one batch item must not depend on other items in the batch."""
        dim = 16
        freqs = precompute_rope_freqs(dim=dim, max_len=16)
        torch.manual_seed(7)
        x_a = torch.randn(1, 4, 2, dim)
        x_b = torch.randn(1, 4, 2, dim)
        solo = apply_rope(x_a, freqs)
        batched = apply_rope(torch.cat([x_a, x_b], dim=0), freqs)[:1]
        assert torch.allclose(solo, batched, atol=1e-6)

    def test_head_independence(self):
        """All heads at the same position must receive identical rotations."""
        dim = 16
        freqs = precompute_rope_freqs(dim=dim, max_len=8)
        x = torch.randn(1, 4, 1, dim).expand(1, 4, 3, dim).contiguous()
        out = apply_rope(x, freqs)
        assert torch.allclose(out[:, :, 0], out[:, :, 1], atol=1e-6)
        assert torch.allclose(out[:, :, 1], out[:, :, 2], atol=1e-6)

    def test_relative_position_property(self):
        """
        Core RoPE invariant: <RoPE(q,m), RoPE(k,n)> depends only on (n-m).
        Two pairs with the same offset must produce the same dot product.
        """
        dim, max_len = 16, 32
        freqs = precompute_rope_freqs(dim=dim, max_len=max_len)
        torch.manual_seed(42)
        q = torch.randn(1, 1, 1, dim)
        k = torch.randn(1, 1, 1, dim)

        def rope_at(tensor, pos):
            """Rotate tensor at a specific position by embedding it in a zero sequence."""
            seq = torch.zeros(1, pos + 1, 1, dim)
            seq[0, pos] = tensor[0, 0]
            return apply_rope(seq, freqs)[:, pos : pos + 1]

        # Both pairs have relative offset n - m = 6
        dot_3_9 = (rope_at(q, 3) * rope_at(k, 9)).sum()
        dot_1_7 = (rope_at(q, 1) * rope_at(k, 7)).sum()
        assert torch.allclose(dot_3_9, dot_1_7, atol=1e-5)

    def test_max_len_boundary(self):
        """apply_rope must handle T == max_len without error or NaN."""
        max_len = 10
        freqs = precompute_rope_freqs(dim=8, max_len=max_len)
        x = torch.randn(1, max_len, 2, 8)
        out = apply_rope(x, freqs)
        assert out.shape == x.shape
        assert not torch.isnan(out).any()

    def test_exceeds_max_len_raises(self):
        """apply_rope must raise RuntimeError when T > max_len."""
        freqs = precompute_rope_freqs(dim=8, max_len=4)
        x = torch.randn(1, 8, 2, 8)  # T=8 > max_len=4
        with pytest.raises(RuntimeError):
            apply_rope(x, freqs)
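

# ---------------------------------------------------------------------------
# Reference sketch (illustrative assumption, not the library implementation):
# a minimal precompute/apply pair consistent with every invariant tested above.
# Underscore-prefixed names so the real imports are never shadowed.
# ---------------------------------------------------------------------------
def _reference_precompute_rope_freqs(dim: int, max_len: int, theta: float = 500000.0) -> torch.Tensor:
    """Unit phasors exp(1j * t * base_freq[k]), shape (max_len, dim // 2)."""
    base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_len, dtype=torch.float32)
    return torch.polar(torch.ones(max_len, dim // 2), torch.outer(t, base))


def _reference_apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate (B, T, H, D) inputs pairwise by freqs[:T]. When T > max_len the
    slice is shorter than T, so broadcasting fails with the RuntimeError that
    test_exceeds_max_len_raises asserts."""
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rot = freqs[: x.shape[1]].unsqueeze(0).unsqueeze(2)  # (1, T, 1, D // 2)
    return torch.view_as_real(xc * rot).flatten(-2).to(x.dtype)
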
# ---------------------------------------------------------------------------
# GQAttention
# ---------------------------------------------------------------------------