[fix][rope] Every decode token was stuck at position 0, so <q_decoded, k_cached> lost the (n - m) term entirely

Kye Gomez 2026-04-20 08:19:14 -04:00
parent 537b116b3e
commit 18cca894dd
4 changed files with 235 additions and 20 deletions
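In plain terms: with KV caching, each decode step feeds the model a single token, and the old code always rotated that token with freqs[:1], i.e. as if it sat at position 0, while the cached keys had been rotated at their true positions. The query/key dot product therefore stopped depending on the relative offset (n - m) that RoPE exists to encode. A minimal standalone sketch of the failure mode (toy helpers that mirror the described behavior, not the repo's actual code; only torch is assumed):

    import torch

    def toy_freqs(dim: int, max_len: int, theta: float = 500000.0) -> torch.Tensor:
        # one unit phasor e^{i * t * theta_k} per (position, dim-pair)
        inv = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(max_len).float(), inv)
        return torch.polar(torch.ones_like(angles), angles)

    def rotate(x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        # rotate a real (dim,) vector by complex phasors f of shape (dim//2,)
        xc = torch.view_as_complex(x.reshape(-1, 2))
        return torch.view_as_real(xc * f).flatten()

    freqs = toy_freqs(dim=8, max_len=16)
    q, k = torch.randn(8), torch.randn(8)

    # correct: query rotated at its true position m=10, cached key at n=4
    good = rotate(q, freqs[10]) @ rotate(k, freqs[4])
    # the bug: query pinned to m=0, so the score ignores where the query is
    buggy = rotate(q, freqs[0]) @ rotate(k, freqs[4])

    # RoPE invariant: shifting both positions by +3 must not change the score
    shifted = rotate(q, freqs[13]) @ rotate(k, freqs[7])
    assert torch.allclose(good, shifted, atol=1e-5)       # holds
    assert not torch.allclose(buggy, shifted, atol=1e-5)  # broken by the bug (almost surely, for random q, k)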

open_mythos/main.py

@@ -147,14 +147,19 @@ def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     Args:
         x -- tensor of shape (B, T, H, head_dim); head_dim must be even
-        freqs_cis -- precomputed complex frequencies of shape (max_len, head_dim//2)
+        freqs_cis -- precomputed complex frequencies of shape (T, head_dim//2),
+            already sliced to exactly the positions being processed
+            (caller is responsible for correct start_pos offset)
 
     Returns:
         Rotated tensor of the same shape and dtype as x
     """
     xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    freqs_cis = freqs_cis[: x.shape[1]].unsqueeze(0).unsqueeze(2)
-    return torch.view_as_real(xc * freqs_cis).flatten(-2).to(x.dtype)
+    return (
+        torch.view_as_real(xc * freqs_cis.unsqueeze(0).unsqueeze(2))
+        .flatten(-2)
+        .to(x.dtype)
+    )
 
 
 # ---------------------------------------------------------------------------
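The contract change above is the heart of the fix: apply_rope no longer slices freqs_cis itself (the old freqs_cis[: x.shape[1]] silently pinned every decode step to positions 0..T-1), so callers must now pass exactly the rows for the positions occupied by x. A sketch of the two call shapes under the new contract (variable names illustrative):

    # prefill: T tokens occupy absolute positions 0 .. T-1
    out = apply_rope(x, freqs_cis[:T])

    # incremental decode: a single token (T == 1) at absolute position p
    out = apply_rope(x_step, freqs_cis[p : p + 1])  # p, not 0: the fix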
@@ -936,6 +941,7 @@ class OpenMythos(nn.Module):
         input_ids: torch.Tensor,
         n_loops: Optional[int] = None,
         kv_cache: Optional[dict] = None,
+        start_pos: int = 0,
     ) -> torch.Tensor:
         """
         Forward pass through Prelude Recurrent Block Coda.
@@ -946,17 +952,21 @@ class OpenMythos(nn.Module):
                 Increase at inference to extrapolate to harder problems.
             kv_cache -- dict mutated in-place for autoregressive KV caching;
                 pass an empty dict {} and reuse across decode steps
+            start_pos -- index of the first token in input_ids within the full
+                sequence; used to select the correct RoPE frequencies during
+                incremental decoding (0 for prefill, then the absolute
+                position of the single new token on each decode step)
 
         Returns:
             Logits of shape (B, T, vocab_size)
         """
-        B, T = input_ids.shape
+        T = input_ids.shape[1]
         device = input_ids.device
         x = self.embed(input_ids)
         freqs_cis = (
             self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis
-        )[:T]
+        )[start_pos : start_pos + T]
         mask = self._causal_mask(T, device) if T > 1 else None
 
         for i, layer in enumerate(self.prelude):
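Concretely, for a prompt of length 5 followed by two decode steps, the [start_pos : start_pos + T] slice selects these rows (hypothetical call site, names illustrative):

    cache: dict = {}
    logits = model(prompt_ids, kv_cache=cache, start_pos=0)  # T=5 -> rows 0..4
    logits = model(tok_5, kv_cache=cache, start_pos=5)       # T=1 -> row 5
    logits = model(tok_6, kv_cache=cache, start_pos=6)       # T=1 -> row 6

Note that the existing mask logic already cooperates with this: for T == 1 the mask is None, which is correct because the lone query sits at the newest position and may attend to every cached key.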
@@ -1001,9 +1011,17 @@ class OpenMythos(nn.Module):
             Token indices of shape (B, T + max_new_tokens)
         """
         kv_cache: dict = {}
+        prompt_len = input_ids.shape[1]
         for step in range(max_new_tokens):
-            cur_ids = input_ids if step == 0 else input_ids[:, -1:]
-            logits = self.forward(cur_ids, n_loops=n_loops, kv_cache=kv_cache)
+            if step == 0:
+                cur_ids = input_ids
+                start_pos = 0
+            else:
+                cur_ids = input_ids[:, -1:]
+                start_pos = prompt_len + step - 1
+            logits = self.forward(
+                cur_ids, n_loops=n_loops, kv_cache=kv_cache, start_pos=start_pos
+            )
             logits = logits[:, -1, :] / temperature
             if top_k > 0:
                 v, _ = logits.topk(top_k)
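The position arithmetic is worth verifying by hand for, say, a prompt of length 4:

    # step 0: cur_ids = whole prompt (T=4), start_pos = 0             -> positions 0..3
    # step 1: cur_ids = newest token (T=1), start_pos = 4 + 1 - 1 = 4 -> position 4
    # step 2: cur_ids = newest token (T=1), start_pos = 4 + 2 - 1 = 5 -> position 5

The "- 1" accounts for input_ids having already grown by step sampled tokens when step >= 1; the token being fed, input_ids[:, -1:], sits at index prompt_len + step - 1 of the full sequence.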


@@ -96,20 +96,20 @@ class TestRoPE:
     def test_apply_rope_shape(self):
         freqs = precompute_rope_freqs(dim=16, max_len=32)
         x = torch.randn(B, T, 4, 16)
-        out = apply_rope(x, freqs)
+        out = apply_rope(x, freqs[:T])
         assert out.shape == x.shape
 
     def test_apply_rope_preserves_norm(self):
         # rotation is an isometry — norms must be unchanged
         freqs = precompute_rope_freqs(dim=16, max_len=32)
         x = torch.randn(B, T, 4, 16)
-        out = apply_rope(x, freqs)
+        out = apply_rope(x, freqs[:T])
         assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-5)
 
     def test_different_positions_differ(self):
         freqs = precompute_rope_freqs(dim=16, max_len=32)
         x = torch.ones(1, 2, 1, 16)
-        out = apply_rope(x, freqs)
+        out = apply_rope(x, freqs[:2])
         # position 0 and position 1 should produce different rotations
         assert not torch.allclose(out[0, 0], out[0, 1])
@@ -168,27 +168,27 @@ class TestRoPEExtended:
         """T=1 input uses only freqs[0] = 1+0j, so output must equal input."""
         freqs = precompute_rope_freqs(dim=16, max_len=8)
         x = torch.randn(2, 1, 4, 16)
-        out = apply_rope(x, freqs)
+        out = apply_rope(x, freqs[:1])
         assert torch.allclose(x, out, atol=1e-6)
 
     def test_dtype_float32_preserved(self):
         freqs = precompute_rope_freqs(dim=16, max_len=16)
         x = torch.randn(1, 4, 2, 16).float()
-        assert apply_rope(x, freqs).dtype == torch.float32
+        assert apply_rope(x, freqs[:4]).dtype == torch.float32
 
     def test_dtype_float16_preserved(self):
         freqs = precompute_rope_freqs(dim=16, max_len=16)
         x = torch.randn(1, 4, 2, 16).half()
-        assert apply_rope(x, freqs).dtype == torch.float16
+        assert apply_rope(x, freqs[:4]).dtype == torch.float16
 
     def test_inverse_rotation_recovers_input(self):
         """Rotating by freqs then by conj(freqs) (inverse) must recover the original."""
         dim = 16
         freqs = precompute_rope_freqs(dim=dim, max_len=8)
         x = torch.randn(2, 4, 3, dim)
-        rotated = apply_rope(x, freqs)
+        rotated = apply_rope(x, freqs[:4])
         xc = torch.view_as_complex(rotated.float().reshape(*rotated.shape[:-1], -1, 2))
-        inv = freqs.conj()[: rotated.shape[1]].unsqueeze(0).unsqueeze(2)
+        inv = freqs.conj()[:4].unsqueeze(0).unsqueeze(2)
         recovered = torch.view_as_real(xc * inv).flatten(-2).to(x.dtype)
         assert torch.allclose(x, recovered, atol=1e-5)
@@ -199,8 +199,8 @@ class TestRoPEExtended:
         torch.manual_seed(7)
         x_a = torch.randn(1, 4, 2, dim)
         x_b = torch.randn(1, 4, 2, dim)
-        solo = apply_rope(x_a, freqs)
-        batched = apply_rope(torch.cat([x_a, x_b], dim=0), freqs)[:1]
+        solo = apply_rope(x_a, freqs[:4])
+        batched = apply_rope(torch.cat([x_a, x_b], dim=0), freqs[:4])[:1]
         assert torch.allclose(solo, batched, atol=1e-6)
 
     def test_head_independence(self):
@@ -208,7 +208,7 @@ class TestRoPEExtended:
         dim = 16
         freqs = precompute_rope_freqs(dim=dim, max_len=8)
         x = torch.randn(1, 4, 1, dim).expand(1, 4, 3, dim).contiguous()
-        out = apply_rope(x, freqs)
+        out = apply_rope(x, freqs[:4])
         assert torch.allclose(out[:, :, 0], out[:, :, 1], atol=1e-6)
         assert torch.allclose(out[:, :, 1], out[:, :, 2], atol=1e-6)
@@ -227,7 +227,7 @@ class TestRoPEExtended:
             """Rotate tensor at a specific position by embedding it in a zero sequence."""
             seq = torch.zeros(1, pos + 1, 1, dim)
             seq[0, pos] = tensor[0, 0]
-            return apply_rope(seq, freqs)[:, pos : pos + 1]
+            return apply_rope(seq, freqs[: pos + 1])[:, pos : pos + 1]
 
         # Both pairs have relative offset n - m = 6
         dot_3_9 = (rope_at(q, 3) * rope_at(k, 9)).sum()
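The invariance this fragment pins down is the defining property of RoPE: treating each head-dim pair as one complex coordinate, the rotation sends q_j to q_j * e^{i*m*theta_j} and k_j to k_j * e^{i*n*theta_j}, so the real dot product is Re(sum_j q_j * conj(k_j) * e^{i*(m-n)*theta_j}), a function of the offset (m - n) alone. That is precisely the term the position-0 bug erased for decoded queries.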

tests/test_rope_debug.py (new file, 195 lines)

@@ -0,0 +1,195 @@
"""
Standalone RoPE debug test logs tensor outputs and intermediate calculations
so you can visually verify correctness of precompute_rope_freqs and apply_rope.
"""
import torch
from open_mythos.main import apply_rope, precompute_rope_freqs
DIM = 8
MAX_LEN = 6
THETA = 500000.0
def section(title: str) -> None:
    print(f"\n{'=' * 60}")
    print(f" {title}")
    print("=" * 60)


def log(label: str, value) -> None:
    print(f"\n[{label}]")
    print(value)
# ---------------------------------------------------------------------------
# 1. precompute_rope_freqs — raw frequency table
# ---------------------------------------------------------------------------
section("1. precompute_rope_freqs")
freqs = precompute_rope_freqs(dim=DIM, max_len=MAX_LEN, theta=THETA)
log("freqs shape", freqs.shape)
log("freqs (complex)", freqs)
log("freqs.real", freqs.real)
log("freqs.imag", freqs.imag)
log("freqs magnitude (should be all 1.0)", freqs.abs())
log("freqs angle (radians)", freqs.angle())
print("\nExpected: base frequencies (1 per dim pair)")
base = 1.0 / (THETA ** (torch.arange(0, DIM, 2, dtype=torch.float32) / DIM))
log("base freqs", base)
print("\nExpected: freqs[t, k].angle() == t * base[k]")
for t in range(MAX_LEN):
    expected_angles = t * base
    actual_angles = freqs[t].angle()
    match = torch.allclose(actual_angles, expected_angles, atol=1e-5)
    print(f" t={t}: angles match = {match} actual={actual_angles.tolist()}")
# ---------------------------------------------------------------------------
# 2. Position 0 is identity (freqs[0] == 1+0j)
# ---------------------------------------------------------------------------
section("2. freqs[0] is identity phasor (1+0j)")
log("freqs[0]", freqs[0])
print(f" All magnitude=1: {torch.allclose(freqs[0].abs(), torch.ones(DIM // 2))}")
print(f" All angle=0: {torch.allclose(freqs[0].angle(), torch.zeros(DIM // 2))}")
# ---------------------------------------------------------------------------
# 3. apply_rope — shape and dtype
# ---------------------------------------------------------------------------
section("3. apply_rope — shape and dtype")
B, T, H = 2, MAX_LEN, 3
x = torch.randn(B, T, H, DIM)
log("input x shape", x.shape)
out = apply_rope(x, freqs)
log("output shape", out.shape)
print(f" Shape preserved: {out.shape == x.shape}")
# dtype float16
x_half = x.half()
out_half = apply_rope(x_half, freqs)
print(f" float16 dtype preserved: {out_half.dtype == torch.float16}")
# ---------------------------------------------------------------------------
# 4. apply_rope — isometry (norm preservation)
# ---------------------------------------------------------------------------
section("4. apply_rope — norm preservation (isometry)")
norms_in = x.norm(dim=-1)
norms_out = out.norm(dim=-1)
log("input norms (first batch item)", norms_in[0])
log("output norms (first batch item)", norms_out[0])
print(
    f" Max absolute norm difference: {(norms_in - norms_out).abs().max().item():.2e}"
)
print(
    f" Norms preserved (atol=1e-5): {torch.allclose(norms_in, norms_out, atol=1e-5)}"
)
# ---------------------------------------------------------------------------
# 5. Position 0 is the identity transformation
# ---------------------------------------------------------------------------
section("5. Position 0 is identity transformation")
x1 = torch.randn(1, 1, 2, DIM)
out1 = apply_rope(x1, freqs[:1])
log("input x[:,0]", x1[0, 0])
log("output x[:,0]", out1[0, 0])
log("diff (should be ~0)", (x1 - out1).abs())
print(f" Identity at pos 0: {torch.allclose(x1, out1, atol=1e-6)}")
# ---------------------------------------------------------------------------
# 6. Different positions produce different rotations
# ---------------------------------------------------------------------------
section("6. Different positions produce different rotations")
x2 = torch.ones(1, MAX_LEN, 1, DIM)
out2 = apply_rope(x2, freqs)
print("Per-position outputs (all input=1.0):")
for t in range(MAX_LEN):
    print(f" pos={t}: {out2[0, t, 0].tolist()}")
# ---------------------------------------------------------------------------
# 7. Inverse rotation recovers original
# ---------------------------------------------------------------------------
section("7. Inverse rotation recovers original")
x3 = torch.randn(2, T, H, DIM)
rotated = apply_rope(x3, freqs)
xc = torch.view_as_complex(rotated.float().reshape(*rotated.shape[:-1], -1, 2))
inv_freqs = freqs.conj()
recovered = (
    torch.view_as_real(xc * inv_freqs.unsqueeze(0).unsqueeze(2))
    .flatten(-2)
    .to(x3.dtype)
)
diff = (x3 - recovered).abs()
log("max recovery error", diff.max().item())
print(f" Recovery succeeded (atol=1e-5): {torch.allclose(x3, recovered, atol=1e-5)}")
# ---------------------------------------------------------------------------
# 8. start_pos correctness — the core generation bug fix
# ---------------------------------------------------------------------------
section("8. start_pos correctness (generation bug)")
print(
    "Simulates: prompt of length 4, then decode step receives token at position 4.\n"
    "Old buggy code: freqs[:1] → position 0 always.\n"
    "Fixed code: freqs[4:5] → position 4."
)
prompt_len = 4
decode_token = torch.randn(1, 1, 1, DIM)
out_buggy = apply_rope(decode_token, freqs[:1]) # wrong: always pos 0
out_fixed = apply_rope(
    decode_token, freqs[prompt_len : prompt_len + 1]
)  # correct: pos 4
log("freqs[0] (pos 0)", freqs[0])
log(f"freqs[{prompt_len}] (pos {prompt_len})", freqs[prompt_len])
log("buggy output (pos 0 encoding)", out_buggy[0, 0, 0])
log("fixed output (pos 4 encoding)", out_fixed[0, 0, 0])
print(f" Outputs differ (they should): {not torch.allclose(out_buggy, out_fixed)}")
# ---------------------------------------------------------------------------
# 9. Relative position property
# ---------------------------------------------------------------------------
section("9. Relative position property: <RoPE(q,m), RoPE(k,n)> depends only on (n-m)")
dim = 16
max_len = 32
freqs_big = precompute_rope_freqs(dim=dim, max_len=max_len, theta=THETA)
torch.manual_seed(42)
q = torch.randn(1, 1, 1, dim)
k = torch.randn(1, 1, 1, dim)
def rope_at(tensor, pos):
    seq = torch.zeros(1, pos + 1, 1, dim)
    seq[0, pos] = tensor[0, 0]
    return apply_rope(seq, freqs_big[: pos + 1])[:, pos : pos + 1]
dot_3_9 = (rope_at(q, 3) * rope_at(k, 9)).sum()
dot_1_7 = (rope_at(q, 1) * rope_at(k, 7)).sum()
dot_0_6 = (rope_at(q, 0) * rope_at(k, 6)).sum()
print(f" <RoPE(q,3), RoPE(k,9)>: {dot_3_9.item():.6f}")
print(f" <RoPE(q,1), RoPE(k,7)>: {dot_1_7.item():.6f}")
print(f" <RoPE(q,0), RoPE(k,6)>: {dot_0_6.item():.6f}")
print(
    f" All equal (offset=6): {torch.allclose(dot_3_9, dot_1_7, atol=1e-5) and torch.allclose(dot_1_7, dot_0_6, atol=1e-5)}"
)
section("DONE — all checks complete")


@@ -113,7 +113,9 @@ def main():
     master = rank == 0
     if master:
-        print(f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}")
+        print(
+            f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
+        )
 
     # ------------------------------------------------------------------
     # Tokenizer