mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
Add affiliation disclaimer to README and extend RoPE test coverage
This commit is contained in:
parent
14a806c2a1
commit
53f786afda
@ -1,5 +1,7 @@
|
||||
# OpenMythos
|
||||
|
||||
> **Disclaimer:** OpenMythos is an independent, community-driven theoretical reconstruction based solely on publicly available research and speculation. It is not affiliated with, endorsed by, or connected to Anthropic or any of their proprietary systems.
|
||||
|
||||
OpenMythos is an open-source, theoretical implementation of the Claude Mythos model. It is built around a Recurrent-Depth Transformer (RDT) with three stages: **Prelude** (transformer blocks), a looped **Recurrent Block** (up to `max_loop_iters`), and a final **Coda**. Attention is switchable between MLA and GQA, and the feed-forward uses a sparse MoE with routed and shared experts — ideal for exploring compute-adaptive, depth-variable reasoning.
|
||||
|
||||
|
||||
|
||||
133
test_main.py
133
test_main.py
@ -121,6 +121,139 @@ class TestRoPE:
|
||||
assert not torch.allclose(out[0, 0], out[0, 1])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE extended — correctness invariants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRoPEExtended:
    """Extended invariant checks for ``precompute_rope_freqs`` and ``apply_rope``.

    Covers the frequency-table construction (unit phasors, angle layout,
    theta behaviour) and the rotation itself (identity at position zero,
    dtype preservation, invertibility, independence across batch and head
    dimensions, the relative-position property, and sequence-length limits).
    """

    # ---- precompute_rope_freqs -------------------------------------------

    def test_position_zero_is_unit_phasor(self):
        """At position 0 the rotation angle is zero, so every phasor is 1+0j."""
        table = precompute_rope_freqs(dim=16, max_len=8)
        unit = torch.ones(8, dtype=torch.complex64)
        assert torch.allclose(table[0], unit, atol=1e-6)

    def test_all_phasors_have_unit_magnitude(self):
        """RoPE is a pure rotation, so each phasor must lie on the unit circle."""
        table = precompute_rope_freqs(dim=16, max_len=32)
        magnitudes = table.abs()
        assert torch.allclose(magnitudes, torch.ones_like(magnitudes), atol=1e-6)

    def test_angles_equal_outer_product(self):
        """Phasor angles must be the outer product of positions and base frequencies."""
        dim, max_len, theta = 8, 6, 500000.0
        table = precompute_rope_freqs(dim=dim, max_len=max_len, theta=theta)
        # Standard RoPE frequency schedule: theta^(-2k/dim) for pair index k.
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        positions = torch.arange(max_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        reference = torch.polar(torch.ones(max_len, dim // 2), angles)
        assert torch.allclose(table.real, reference.real, atol=1e-6)
        assert torch.allclose(table.imag, reference.imag, atol=1e-6)

    def test_higher_theta_produces_smaller_angles(self):
        """Raising theta slows the frequencies, shrinking the per-step rotation."""
        dim, max_len = 16, 8
        fast = precompute_rope_freqs(dim=dim, max_len=max_len, theta=100.0)
        slow = precompute_rope_freqs(dim=dim, max_len=max_len, theta=500000.0)
        # Compare at position 1, where angle magnitude equals the base frequency.
        assert torch.all(fast[1].angle().abs() > slow[1].angle().abs())

    def test_default_theta_matches_explicit(self):
        """Leaving theta at its default must match passing theta=500000.0 explicitly."""
        implicit = precompute_rope_freqs(16, 8)
        explicit = precompute_rope_freqs(16, 8, theta=500000.0)
        assert torch.allclose(implicit.real, explicit.real)
        assert torch.allclose(implicit.imag, explicit.imag)

    # ---- apply_rope --------------------------------------------------------

    def test_position_zero_is_identity(self):
        """A length-1 sequence only sees freqs[0] = 1+0j, so rotation is a no-op."""
        table = precompute_rope_freqs(dim=16, max_len=8)
        x = torch.randn(2, 1, 4, 16)
        assert torch.allclose(apply_rope(x, table), x, atol=1e-6)

    def test_dtype_float32_preserved(self):
        """float32 in must give float32 out."""
        table = precompute_rope_freqs(dim=16, max_len=16)
        rotated = apply_rope(torch.randn(1, 4, 2, 16, dtype=torch.float32), table)
        assert rotated.dtype == torch.float32

    def test_dtype_float16_preserved(self):
        """float16 in must give float16 out."""
        table = precompute_rope_freqs(dim=16, max_len=16)
        rotated = apply_rope(torch.randn(1, 4, 2, 16).to(torch.float16), table)
        assert rotated.dtype == torch.float16

    def test_inverse_rotation_recovers_input(self):
        """Applying the conjugate phasors undoes the rotation exactly."""
        dim = 16
        table = precompute_rope_freqs(dim=dim, max_len=8)
        x = torch.randn(2, 4, 3, dim)
        rotated = apply_rope(x, table)
        seq_len = rotated.shape[1]
        # Re-pair adjacent reals into complex numbers, multiply by conj(freqs),
        # and flatten back to the original real layout.
        as_complex = torch.view_as_complex(
            rotated.float().reshape(*rotated.shape[:-1], -1, 2)
        )
        undo = table[:seq_len].conj().reshape(1, seq_len, 1, -1)
        recovered = torch.view_as_real(as_complex * undo).flatten(-2).to(x.dtype)
        assert torch.allclose(x, recovered, atol=1e-5)

    def test_batch_independence(self):
        """Each batch item must be rotated independently of its neighbours."""
        dim = 16
        table = precompute_rope_freqs(dim=dim, max_len=16)
        torch.manual_seed(7)
        first = torch.randn(1, 4, 2, dim)
        second = torch.randn(1, 4, 2, dim)
        alone = apply_rope(first, table)
        together = apply_rope(torch.cat((first, second), dim=0), table)
        assert torch.allclose(alone, together[:1], atol=1e-6)

    def test_head_independence(self):
        """Identical head inputs at a given position must yield identical outputs."""
        dim = 16
        table = precompute_rope_freqs(dim=dim, max_len=8)
        # Three copies of the same head: rotation depends only on position.
        x = torch.randn(1, 4, 1, dim).expand(1, 4, 3, dim).contiguous()
        rotated = apply_rope(x, table)
        for head in range(2):
            assert torch.allclose(rotated[:, :, head], rotated[:, :, head + 1], atol=1e-6)

    def test_relative_position_property(self):
        """
        The defining RoPE invariant: the dot product of a rotated query at
        position m and a rotated key at position n depends only on n - m,
        so two (m, n) pairs sharing an offset must produce the same value.
        """
        dim, max_len = 16, 32
        table = precompute_rope_freqs(dim=dim, max_len=max_len)
        torch.manual_seed(42)
        q = torch.randn(1, 1, 1, dim)
        k = torch.randn(1, 1, 1, dim)

        def rotate_at(vec, pos):
            """Embed *vec* at index *pos* of a zero sequence and rotate it there."""
            padded = torch.zeros(1, pos + 1, 1, dim)
            padded[0, pos] = vec[0, 0]
            return apply_rope(padded, table)[:, pos : pos + 1]

        # Offsets 9 - 3 and 7 - 1 are both 6.
        pair_a = (rotate_at(q, 3) * rotate_at(k, 9)).sum()
        pair_b = (rotate_at(q, 1) * rotate_at(k, 7)).sum()
        assert torch.allclose(pair_a, pair_b, atol=1e-5)

    def test_max_len_boundary(self):
        """A sequence exactly max_len long must rotate cleanly with no NaNs."""
        limit = 10
        table = precompute_rope_freqs(dim=8, max_len=limit)
        x = torch.randn(1, limit, 2, 8)
        rotated = apply_rope(x, table)
        assert rotated.shape == x.shape
        assert not torch.isnan(rotated).any()

    def test_exceeds_max_len_raises(self):
        """A sequence longer than the frequency table must raise RuntimeError."""
        table = precompute_rope_freqs(dim=8, max_len=4)
        too_long = torch.randn(1, 8, 2, 8)  # sequence length 8, table covers only 4
        with pytest.raises(RuntimeError):
            apply_rope(too_long, table)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GQAttention
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user