From 18cca894ddefd4591b48c6c7554eb83f9cd6ee87 Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Mon, 20 Apr 2026 08:19:14 -0400 Subject: [PATCH] [fix][rope] Every decode token was stuck at position 0, so lost the (n - m) term entirely --- open_mythos/main.py | 32 ++++-- test_main.py | 24 ++--- tests/test_rope_debug.py | 195 ++++++++++++++++++++++++++++++++++++ training/3b_fine_web_edu.py | 4 +- 4 files changed, 235 insertions(+), 20 deletions(-) create mode 100644 tests/test_rope_debug.py diff --git a/open_mythos/main.py b/open_mythos/main.py index 238eeed..3def4d6 100644 --- a/open_mythos/main.py +++ b/open_mythos/main.py @@ -147,14 +147,19 @@ def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: Args: x -- tensor of shape (B, T, H, head_dim); head_dim must be even - freqs_cis -- precomputed complex frequencies of shape (max_len, head_dim//2) + freqs_cis -- precomputed complex frequencies of shape (T, head_dim//2), + already sliced to exactly the positions being processed + (caller is responsible for correct start_pos offset) Returns: Rotated tensor of the same shape and dtype as x """ xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis[: x.shape[1]].unsqueeze(0).unsqueeze(2) - return torch.view_as_real(xc * freqs_cis).flatten(-2).to(x.dtype) + return ( + torch.view_as_real(xc * freqs_cis.unsqueeze(0).unsqueeze(2)) + .flatten(-2) + .to(x.dtype) + ) # --------------------------------------------------------------------------- @@ -936,6 +941,7 @@ class OpenMythos(nn.Module): input_ids: torch.Tensor, n_loops: Optional[int] = None, kv_cache: Optional[dict] = None, + start_pos: int = 0, ) -> torch.Tensor: """ Forward pass through Prelude → Recurrent Block → Coda. @@ -946,17 +952,21 @@ class OpenMythos(nn.Module): Increase at inference to extrapolate to harder problems. 
kv_cache -- dict mutated in-place for autoregressive KV caching; pass an empty dict {} and reuse across decode steps + start_pos -- index of the first token in input_ids within the full + sequence; used to select the correct RoPE frequencies + during incremental decoding (0 for prefill, prompt_len + for each subsequent decode step) Returns: Logits of shape (B, T, vocab_size) """ - B, T = input_ids.shape + T = input_ids.shape[1] device = input_ids.device x = self.embed(input_ids) freqs_cis = ( self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis - )[:T] + )[start_pos : start_pos + T] mask = self._causal_mask(T, device) if T > 1 else None for i, layer in enumerate(self.prelude): @@ -1001,9 +1011,17 @@ class OpenMythos(nn.Module): Token indices of shape (B, T + max_new_tokens) """ kv_cache: dict = {} + prompt_len = input_ids.shape[1] for step in range(max_new_tokens): - cur_ids = input_ids if step == 0 else input_ids[:, -1:] - logits = self.forward(cur_ids, n_loops=n_loops, kv_cache=kv_cache) + if step == 0: + cur_ids = input_ids + start_pos = 0 + else: + cur_ids = input_ids[:, -1:] + start_pos = prompt_len + step - 1 + logits = self.forward( + cur_ids, n_loops=n_loops, kv_cache=kv_cache, start_pos=start_pos + ) logits = logits[:, -1, :] / temperature if top_k > 0: v, _ = logits.topk(top_k) diff --git a/test_main.py b/test_main.py index eac767f..c54c462 100644 --- a/test_main.py +++ b/test_main.py @@ -96,20 +96,20 @@ class TestRoPE: def test_apply_rope_shape(self): freqs = precompute_rope_freqs(dim=16, max_len=32) x = torch.randn(B, T, 4, 16) - out = apply_rope(x, freqs) + out = apply_rope(x, freqs[:T]) assert out.shape == x.shape def test_apply_rope_preserves_norm(self): # rotation is an isometry — norms must be unchanged freqs = precompute_rope_freqs(dim=16, max_len=32) x = torch.randn(B, T, 4, 16) - out = apply_rope(x, freqs) + out = apply_rope(x, freqs[:T]) assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-5) def 
test_different_positions_differ(self): freqs = precompute_rope_freqs(dim=16, max_len=32) x = torch.ones(1, 2, 1, 16) - out = apply_rope(x, freqs) + out = apply_rope(x, freqs[:2]) # position 0 and position 1 should produce different rotations assert not torch.allclose(out[0, 0], out[0, 1]) @@ -168,27 +168,27 @@ class TestRoPEExtended: """T=1 input uses only freqs[0] = 1+0j, so output must equal input.""" freqs = precompute_rope_freqs(dim=16, max_len=8) x = torch.randn(2, 1, 4, 16) - out = apply_rope(x, freqs) + out = apply_rope(x, freqs[:1]) assert torch.allclose(x, out, atol=1e-6) def test_dtype_float32_preserved(self): freqs = precompute_rope_freqs(dim=16, max_len=16) x = torch.randn(1, 4, 2, 16).float() - assert apply_rope(x, freqs).dtype == torch.float32 + assert apply_rope(x, freqs[:4]).dtype == torch.float32 def test_dtype_float16_preserved(self): freqs = precompute_rope_freqs(dim=16, max_len=16) x = torch.randn(1, 4, 2, 16).half() - assert apply_rope(x, freqs).dtype == torch.float16 + assert apply_rope(x, freqs[:4]).dtype == torch.float16 def test_inverse_rotation_recovers_input(self): """Rotating by freqs then by conj(freqs) (inverse) must recover the original.""" dim = 16 freqs = precompute_rope_freqs(dim=dim, max_len=8) x = torch.randn(2, 4, 3, dim) - rotated = apply_rope(x, freqs) + rotated = apply_rope(x, freqs[:4]) xc = torch.view_as_complex(rotated.float().reshape(*rotated.shape[:-1], -1, 2)) - inv = freqs.conj()[: rotated.shape[1]].unsqueeze(0).unsqueeze(2) + inv = freqs.conj()[:4].unsqueeze(0).unsqueeze(2) recovered = torch.view_as_real(xc * inv).flatten(-2).to(x.dtype) assert torch.allclose(x, recovered, atol=1e-5) @@ -199,8 +199,8 @@ class TestRoPEExtended: torch.manual_seed(7) x_a = torch.randn(1, 4, 2, dim) x_b = torch.randn(1, 4, 2, dim) - solo = apply_rope(x_a, freqs) - batched = apply_rope(torch.cat([x_a, x_b], dim=0), freqs)[:1] + solo = apply_rope(x_a, freqs[:4]) + batched = apply_rope(torch.cat([x_a, x_b], dim=0), freqs[:4])[:1] assert 
torch.allclose(solo, batched, atol=1e-6) def test_head_independence(self): @@ -208,7 +208,7 @@ class TestRoPEExtended: dim = 16 freqs = precompute_rope_freqs(dim=dim, max_len=8) x = torch.randn(1, 4, 1, dim).expand(1, 4, 3, dim).contiguous() - out = apply_rope(x, freqs) + out = apply_rope(x, freqs[:4]) assert torch.allclose(out[:, :, 0], out[:, :, 1], atol=1e-6) assert torch.allclose(out[:, :, 1], out[:, :, 2], atol=1e-6) @@ -227,7 +227,7 @@ class TestRoPEExtended: """Rotate tensor at a specific position by embedding it in a zero sequence.""" seq = torch.zeros(1, pos + 1, 1, dim) seq[0, pos] = tensor[0, 0] - return apply_rope(seq, freqs)[:, pos : pos + 1] + return apply_rope(seq, freqs[: pos + 1])[:, pos : pos + 1] # Both pairs have relative offset n - m = 6 dot_3_9 = (rope_at(q, 3) * rope_at(k, 9)).sum() diff --git a/tests/test_rope_debug.py b/tests/test_rope_debug.py new file mode 100644 index 0000000..32e8487 --- /dev/null +++ b/tests/test_rope_debug.py @@ -0,0 +1,195 @@ +""" +Standalone RoPE debug test — logs tensor outputs and intermediate calculations +so you can visually verify correctness of precompute_rope_freqs and apply_rope. +""" + +import torch +from open_mythos.main import apply_rope, precompute_rope_freqs + +DIM = 8 +MAX_LEN = 6 +THETA = 500000.0 + + +def section(title: str) -> None: + print(f"\n{'=' * 60}") + print(f" {title}") + print("=" * 60) + + +def log(label: str, value) -> None: + print(f"\n[{label}]") + print(value) + + +# --------------------------------------------------------------------------- +# 1. precompute_rope_freqs — raw frequency table +# --------------------------------------------------------------------------- + +section("1. 
precompute_rope_freqs") + +freqs = precompute_rope_freqs(dim=DIM, max_len=MAX_LEN, theta=THETA) +log("freqs shape", freqs.shape) +log("freqs (complex)", freqs) +log("freqs.real", freqs.real) +log("freqs.imag", freqs.imag) +log("freqs magnitude (should be all 1.0)", freqs.abs()) +log("freqs angle (radians)", freqs.angle()) + +print("\nExpected: base frequencies (1 per dim pair)") +base = 1.0 / (THETA ** (torch.arange(0, DIM, 2, dtype=torch.float32) / DIM)) +log("base freqs", base) + +print("\nExpected: freqs[t, k].angle() == t * base[k]") +for t in range(MAX_LEN): + expected_angles = t * base + actual_angles = freqs[t].angle() + match = torch.allclose(actual_angles, expected_angles, atol=1e-5) + print(f" t={t}: angles match = {match} actual={actual_angles.tolist()}") + +# --------------------------------------------------------------------------- +# 2. Position 0 is identity (freqs[0] == 1+0j) +# --------------------------------------------------------------------------- + +section("2. freqs[0] is identity phasor (1+0j)") +log("freqs[0]", freqs[0]) +print(f" All magnitude=1: {torch.allclose(freqs[0].abs(), torch.ones(DIM // 2))}") +print(f" All angle=0: {torch.allclose(freqs[0].angle(), torch.zeros(DIM // 2))}") + +# --------------------------------------------------------------------------- +# 3. apply_rope — shape and dtype +# --------------------------------------------------------------------------- + +section("3. apply_rope — shape and dtype") + +B, T, H = 2, MAX_LEN, 3 +x = torch.randn(B, T, H, DIM) +log("input x shape", x.shape) + +out = apply_rope(x, freqs) +log("output shape", out.shape) +print(f" Shape preserved: {out.shape == x.shape}") + +# dtype float16 +x_half = x.half() +out_half = apply_rope(x_half, freqs) +print(f" float16 dtype preserved: {out_half.dtype == torch.float16}") + +# --------------------------------------------------------------------------- +# 4. 
apply_rope — isometry (norm preservation) +# --------------------------------------------------------------------------- + +section("4. apply_rope — norm preservation (isometry)") + +norms_in = x.norm(dim=-1) +norms_out = out.norm(dim=-1) +log("input norms (first batch item)", norms_in[0]) +log("output norms (first batch item)", norms_out[0]) +print( + f" Max absolute norm difference: {(norms_in - norms_out).abs().max().item():.2e}" +) +print( + f" Norms preserved (atol=1e-5): {torch.allclose(norms_in, norms_out, atol=1e-5)}" +) + +# --------------------------------------------------------------------------- +# 5. Position 0 is the identity transformation +# --------------------------------------------------------------------------- + +section("5. Position 0 is identity transformation") + +x1 = torch.randn(1, 1, 2, DIM) +out1 = apply_rope(x1, freqs[:1]) +log("input x[:,0]", x1[0, 0]) +log("output x[:,0]", out1[0, 0]) +log("diff (should be ~0)", (x1 - out1).abs()) +print(f" Identity at pos 0: {torch.allclose(x1, out1, atol=1e-6)}") + +# --------------------------------------------------------------------------- +# 6. Different positions produce different rotations +# --------------------------------------------------------------------------- + +section("6. Different positions produce different rotations") + +x2 = torch.ones(1, MAX_LEN, 1, DIM) +out2 = apply_rope(x2, freqs) +print("Per-position outputs (all input=1.0):") +for t in range(MAX_LEN): + print(f" pos={t}: {out2[0, t, 0].tolist()}") + +# --------------------------------------------------------------------------- +# 7. Inverse rotation recovers original +# --------------------------------------------------------------------------- + +section("7. 
Inverse rotation recovers original") + +x3 = torch.randn(2, T, H, DIM) +rotated = apply_rope(x3, freqs) +xc = torch.view_as_complex(rotated.float().reshape(*rotated.shape[:-1], -1, 2)) +inv_freqs = freqs.conj() +recovered = ( + torch.view_as_real(xc * inv_freqs.unsqueeze(0).unsqueeze(2)) + .flatten(-2) + .to(x3.dtype) +) +diff = (x3 - recovered).abs() +log("max recovery error", diff.max().item()) +print(f" Recovery succeeded (atol=1e-5): {torch.allclose(x3, recovered, atol=1e-5)}") + +# --------------------------------------------------------------------------- +# 8. start_pos correctness — the core generation bug fix +# --------------------------------------------------------------------------- + +section("8. start_pos correctness (generation bug)") + +print( + "Simulates: prompt of length 4, then decode step receives token at position 4.\n" + "Old buggy code: freqs[:1] → position 0 always.\n" + "Fixed code: freqs[4:5] → position 4." +) + +prompt_len = 4 +decode_token = torch.randn(1, 1, 1, DIM) + +out_buggy = apply_rope(decode_token, freqs[:1]) # wrong: always pos 0 +out_fixed = apply_rope( + decode_token, freqs[prompt_len : prompt_len + 1] +) # correct: pos 4 + +log("freqs[0] (pos 0)", freqs[0]) +log(f"freqs[{prompt_len}] (pos {prompt_len})", freqs[prompt_len]) +log("buggy output (pos 0 encoding)", out_buggy[0, 0, 0]) +log("fixed output (pos 4 encoding)", out_fixed[0, 0, 0]) +print(f" Outputs differ (they should): {not torch.allclose(out_buggy, out_fixed)}") + +# --------------------------------------------------------------------------- +# 9. Relative position property +# --------------------------------------------------------------------------- + +section("9. 
Relative position property: depends only on (n-m)") + +dim = 16 +max_len = 32 +freqs_big = precompute_rope_freqs(dim=dim, max_len=max_len, theta=THETA) +torch.manual_seed(42) +q = torch.randn(1, 1, 1, dim) +k = torch.randn(1, 1, 1, dim) + + +def rope_at(tensor, pos): + seq = torch.zeros(1, pos + 1, 1, dim) + seq[0, pos] = tensor[0, 0] + return apply_rope(seq, freqs_big[: pos + 1])[:, pos : pos + 1] + + +dot_3_9 = (rope_at(q, 3) * rope_at(k, 9)).sum() +dot_1_7 = (rope_at(q, 1) * rope_at(k, 7)).sum() +dot_0_6 = (rope_at(q, 0) * rope_at(k, 6)).sum() +print(f" dot(q3, k9): {dot_3_9.item():.6f}") +print(f" dot(q1, k7): {dot_1_7.item():.6f}") +print(f" dot(q0, k6): {dot_0_6.item():.6f}") +print( + f" All equal (offset=6): {torch.allclose(dot_3_9, dot_1_7, atol=1e-5) and torch.allclose(dot_1_7, dot_0_6, atol=1e-5)}" +) + +section("DONE — all checks complete") diff --git a/training/3b_fine_web_edu.py b/training/3b_fine_web_edu.py index 215381d..f9f20b8 100644 --- a/training/3b_fine_web_edu.py +++ b/training/3b_fine_web_edu.py @@ -113,7 +113,9 @@ def main(): master = rank == 0 if master: - print(f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}") + print( + f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}" + ) # ------------------------------------------------------------------ # Tokenizer