mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
[bugfix][act-halting] gate halted positions out of the weight accumulation
[bugfix][moe-router-bias] stop the load-balance bias gradient leak
[bugfix][act-cache-consistency] run all loop iterations when a KV cache is present
[bugfix][lora-depth-extrapolation] clamp the per-loop scale index beyond max_loop_iters
[improvement][pyproject-version] bump version to 0.4.0
This commit is contained in:
parent
7ba690797b
commit
289981ba01
@ -477,10 +477,14 @@ class MoEFFN(nn.Module):
|
|||||||
B, T, D = x.shape
|
B, T, D = x.shape
|
||||||
flat = x.view(B * T, D)
|
flat = x.view(B * T, D)
|
||||||
|
|
||||||
# router — bias shifts logits for load balancing without touching loss
|
# Aux-loss-free load balancing (DeepSeek-V3): the bias shifts only the
|
||||||
logits = self.router(flat) + self.router_bias # (B*T, n_experts)
|
# selection of which experts fire so underused experts are picked more,
|
||||||
|
# but the gating weights come from unbiased softmax scores so the bias
|
||||||
|
# never shows up in the gradient.
|
||||||
|
logits = self.router(flat) # (B*T, n_experts), unbiased
|
||||||
scores = F.softmax(logits, dim=-1)
|
scores = F.softmax(logits, dim=-1)
|
||||||
topk_scores, topk_idx = scores.topk(self.topk, dim=-1)
|
_, topk_idx = (logits + self.router_bias).topk(self.topk, dim=-1)
|
||||||
|
topk_scores = scores.gather(-1, topk_idx)
|
||||||
topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # renorm
|
topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True) # renorm
|
||||||
|
|
||||||
# routed expert dispatch (token-level scatter)
|
# routed expert dispatch (token-level scatter)
|
||||||
@ -577,7 +581,12 @@ class LoRAAdapter(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Delta tensor of shape (B, T, dim) to be added to the block output
|
Delta tensor of shape (B, T, dim) to be added to the block output
|
||||||
"""
|
"""
|
||||||
s = self.scale(torch.tensor(loop_t, device=x.device)) # (rank,)
|
# Clamp for depth extrapolation: at inference n_loops can exceed the
|
||||||
|
# training max_loop_iters. Iterations beyond the trained range reuse
|
||||||
|
# the last learned per-loop scale rather than indexing out of range.
|
||||||
|
max_t = self.scale.num_embeddings - 1
|
||||||
|
t_idx = loop_t if loop_t <= max_t else max_t
|
||||||
|
s = self.scale(torch.tensor(t_idx, device=x.device)) # (rank,)
|
||||||
down = self.down(x) * s # (B, T, rank)
|
down = self.down(x) * s # (B, T, rank)
|
||||||
return down @ self.B # (B, T, dim)
|
return down @ self.B # (B, T, dim)
|
||||||
|
|
||||||
@ -829,19 +838,26 @@ class RecurrentBlock(nn.Module):
|
|||||||
still_running = ~halted
|
still_running = ~halted
|
||||||
|
|
||||||
# ACT remainder trick: once cumulative_p + p crosses threshold,
|
# ACT remainder trick: once cumulative_p + p crosses threshold,
|
||||||
# assign the remaining probability mass as the final weight
|
# assign the remaining probability mass as the final weight.
|
||||||
|
# Gate by still_running so halted positions contribute exactly
|
||||||
|
# once (on the halting step) and zero thereafter — otherwise
|
||||||
|
# threshold<1 leaves a non-zero remainder that leaks every step.
|
||||||
remainder = (1.0 - cumulative_p).clamp(min=0)
|
remainder = (1.0 - cumulative_p).clamp(min=0)
|
||||||
weight = torch.where(
|
weight = torch.where(
|
||||||
cumulative_p + p >= self.cfg.act_threshold,
|
cumulative_p + p >= self.cfg.act_threshold,
|
||||||
remainder,
|
remainder,
|
||||||
p,
|
p,
|
||||||
)
|
)
|
||||||
|
weight = weight * still_running.float()
|
||||||
h_out = h_out + weight.unsqueeze(-1) * h
|
h_out = h_out + weight.unsqueeze(-1) * h
|
||||||
|
|
||||||
cumulative_p = cumulative_p + p * still_running.float()
|
cumulative_p = cumulative_p + p * still_running.float()
|
||||||
halted = halted | (cumulative_p >= self.cfg.act_threshold)
|
halted = halted | (cumulative_p >= self.cfg.act_threshold)
|
||||||
|
|
||||||
if halted.all():
|
# Only short-circuit when there is no KV cache to keep consistent.
|
||||||
|
# With a cache, every loop depth must run on every forward pass so
|
||||||
|
# later decode steps find populated keys at every cache_key.
|
||||||
|
if halted.all() and kv_cache is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
return h_out
|
return h_out
|
||||||
|
|||||||
@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
|
|||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "open-mythos"
|
name = "open-mythos"
|
||||||
version = "0.3.0"
|
version = "0.4.0"
|
||||||
description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture"
|
description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
authors = ["Kye Gomez <kye@swarms.world>"]
|
authors = ["Kye Gomez <kye@swarms.world>"]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user