diff --git a/open_mythos/main.py b/open_mythos/main.py
index 3def4d6..10de093 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -477,10 +477,14 @@ class MoEFFN(nn.Module):
         B, T, D = x.shape
         flat = x.view(B * T, D)
 
-        # router — bias shifts logits for load balancing without touching loss
-        logits = self.router(flat) + self.router_bias  # (B*T, n_experts)
+        # Aux-loss-free load balancing (DeepSeek-V3): the bias shifts only the
+        # selection of which experts fire so underused experts are picked more,
+        # but the gating weights come from unbiased softmax scores so the bias
+        # never shows up in the gradient.
+        logits = self.router(flat)  # (B*T, n_experts), unbiased
         scores = F.softmax(logits, dim=-1)
-        topk_scores, topk_idx = scores.topk(self.topk, dim=-1)
+        _, topk_idx = (logits + self.router_bias).topk(self.topk, dim=-1)
+        topk_scores = scores.gather(-1, topk_idx)
         topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # renorm
 
         # routed expert dispatch (token-level scatter)
@@ -577,7 +581,12 @@ class LoRAAdapter(nn.Module):
         Returns:
             Delta tensor of shape (B, T, dim) to be added to the block output
         """
-        s = self.scale(torch.tensor(loop_t, device=x.device))  # (rank,)
+        # Clamp for depth extrapolation: at inference n_loops can exceed the
+        # training max_loop_iters. Iterations beyond the trained range reuse
+        # the last learned per-loop scale rather than indexing out of range.
+        max_t = self.scale.num_embeddings - 1
+        t_idx = loop_t if loop_t <= max_t else max_t
+        s = self.scale(torch.tensor(t_idx, device=x.device))  # (rank,)
         down = self.down(x) * s  # (B, T, rank)
         return down @ self.B  # (B, T, dim)
 
@@ -829,19 +838,26 @@ class RecurrentBlock(nn.Module):
             still_running = ~halted
 
             # ACT remainder trick: once cumulative_p + p crosses threshold,
-            # assign the remaining probability mass as the final weight
+            # assign the remaining probability mass as the final weight.
+            # Gate by still_running so halted positions contribute exactly
+            # once (on the halting step) and zero thereafter — otherwise
+            # threshold<1 leaves a non-zero remainder that leaks every step.
             remainder = (1.0 - cumulative_p).clamp(min=0)
             weight = torch.where(
                 cumulative_p + p >= self.cfg.act_threshold,
                 remainder,
                 p,
             )
+            weight = weight * still_running.float()
             h_out = h_out + weight.unsqueeze(-1) * h
 
             cumulative_p = cumulative_p + p * still_running.float()
             halted = halted | (cumulative_p >= self.cfg.act_threshold)
 
-            if halted.all():
+            # Only short-circuit when there is no KV cache to keep consistent.
+            # With a cache, every loop depth must run on every forward pass so
+            # later decode steps find populated keys at every cache_key.
+            if halted.all() and kv_cache is None:
                 break
 
         return h_out
diff --git a/pyproject.toml b/pyproject.toml
index 5ff5a74..1d9f720 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "open-mythos"
-version = "0.3.0"
+version = "0.4.0"
 description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture"
 license = "MIT"
 authors = ["Kye Gomez "]
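
For reference, a minimal standalone sketch of the selection-vs-gating split introduced in the MoEFFN hunk, not part of the patch. The names logits, router_bias, scores, and topk mirror the diff; the toy sizes and the manually inflated bias value are illustrative assumptions only.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_tokens, n_experts, topk = 4, 8, 2

logits = torch.randn(n_tokens, n_experts)   # unbiased router logits, (n_tokens, n_experts)
router_bias = torch.zeros(n_experts)
router_bias[3] = 5.0                         # pretend expert 3 has been underused

scores = F.softmax(logits, dim=-1)           # gating weights: computed bias-free
_, topk_idx = (logits + router_bias).topk(topk, dim=-1)   # expert selection: biased
topk_scores = scores.gather(-1, topk_idx)                  # weights read from unbiased scores
topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # renormalize over the top-k

# The bias steers which experts are chosen (expert 3 now shows up in every
# token's top-k) while the mixing weights still come from the unbiased
# softmax, so no balancing term enters the loss or its gradient.
print(topk_idx)
print(topk_scores)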