[bugf][act-halting][gate halted positions from weight accumulation][bugf][moe-router-bias][stop load balance bias gradient leak][bugf][act-cache-consistency][keep all loops with kv cache][bugf][lora-depth-extrapolation][clamp scale index beyond max loops][improvement][pyproject-version][bump version to 0.4.0]
This commit is contained in:
Kye Gomez 2026-04-20 09:17:43 -04:00
parent 7ba690797b
commit 289981ba01
2 changed files with 23 additions and 7 deletions

View File

@ -477,10 +477,14 @@ class MoEFFN(nn.Module):
B, T, D = x.shape B, T, D = x.shape
flat = x.view(B * T, D) flat = x.view(B * T, D)
# router — bias shifts logits for load balancing without touching loss # Aux-loss-free load balancing (DeepSeek-V3): the bias shifts only the
logits = self.router(flat) + self.router_bias # (B*T, n_experts) # selection of which experts fire so underused experts are picked more,
# but the gating weights come from unbiased softmax scores so the bias
# never shows up in the gradient.
logits = self.router(flat) # (B*T, n_experts), unbiased
scores = F.softmax(logits, dim=-1) scores = F.softmax(logits, dim=-1)
topk_scores, topk_idx = scores.topk(self.topk, dim=-1) _, topk_idx = (logits + self.router_bias).topk(self.topk, dim=-1)
topk_scores = scores.gather(-1, topk_idx)
topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # renorm topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # renorm
# routed expert dispatch (token-level scatter) # routed expert dispatch (token-level scatter)
@ -577,7 +581,12 @@ class LoRAAdapter(nn.Module):
Returns: Returns:
Delta tensor of shape (B, T, dim) to be added to the block output Delta tensor of shape (B, T, dim) to be added to the block output
""" """
s = self.scale(torch.tensor(loop_t, device=x.device)) # (rank,) # Clamp for depth extrapolation: at inference n_loops can exceed the
# training max_loop_iters. Iterations beyond the trained range reuse
# the last learned per-loop scale rather than indexing out of range.
max_t = self.scale.num_embeddings - 1
t_idx = loop_t if loop_t <= max_t else max_t
s = self.scale(torch.tensor(t_idx, device=x.device)) # (rank,)
down = self.down(x) * s # (B, T, rank) down = self.down(x) * s # (B, T, rank)
return down @ self.B # (B, T, dim) return down @ self.B # (B, T, dim)
@ -829,19 +838,26 @@ class RecurrentBlock(nn.Module):
still_running = ~halted still_running = ~halted
# ACT remainder trick: once cumulative_p + p crosses threshold, # ACT remainder trick: once cumulative_p + p crosses threshold,
# assign the remaining probability mass as the final weight # assign the remaining probability mass as the final weight.
# Gate by still_running so halted positions contribute exactly
# once (on the halting step) and zero thereafter — otherwise
# threshold<1 leaves a non-zero remainder that leaks every step.
remainder = (1.0 - cumulative_p).clamp(min=0) remainder = (1.0 - cumulative_p).clamp(min=0)
weight = torch.where( weight = torch.where(
cumulative_p + p >= self.cfg.act_threshold, cumulative_p + p >= self.cfg.act_threshold,
remainder, remainder,
p, p,
) )
weight = weight * still_running.float()
h_out = h_out + weight.unsqueeze(-1) * h h_out = h_out + weight.unsqueeze(-1) * h
cumulative_p = cumulative_p + p * still_running.float() cumulative_p = cumulative_p + p * still_running.float()
halted = halted | (cumulative_p >= self.cfg.act_threshold) halted = halted | (cumulative_p >= self.cfg.act_threshold)
if halted.all(): # Only short-circuit when there is no KV cache to keep consistent.
# With a cache, every loop depth must run on every forward pass so
# later decode steps find populated keys at every cache_key.
if halted.all() and kv_cache is None:
break break
return h_out return h_out

View File

@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "open-mythos" name = "open-mythos"
version = "0.3.0" version = "0.4.0"
description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture" description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@swarms.world>"] authors = ["Kye Gomez <kye@swarms.world>"]