mirror of https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 09:33:27 +02:00
[bugfix][act-halting][gate halted positions from weight accumulation]
[bugfix][moe-router-bias][stop load-balance bias gradient leak]
[bugfix][act-cache-consistency][keep all loops with kv cache]
[bugfix][lora-depth-extrapolation][clamp scale index beyond max loops]
[improvement][pyproject-version][bump version to 0.4.0]
This commit is contained in:
parent 7ba690797b
commit 289981ba01
@@ -477,10 +477,14 @@ class MoEFFN(nn.Module):
         B, T, D = x.shape
         flat = x.view(B * T, D)
 
-        # router — bias shifts logits for load balancing without touching loss
-        logits = self.router(flat) + self.router_bias  # (B*T, n_experts)
+        # Aux-loss-free load balancing (DeepSeek-V3): the bias shifts only the
+        # selection of which experts fire so underused experts are picked more,
+        # but the gating weights come from unbiased softmax scores so the bias
+        # never shows up in the gradient.
+        logits = self.router(flat)  # (B*T, n_experts), unbiased
         scores = F.softmax(logits, dim=-1)
-        topk_scores, topk_idx = scores.topk(self.topk, dim=-1)
+        _, topk_idx = (logits + self.router_bias).topk(self.topk, dim=-1)
+        topk_scores = scores.gather(-1, topk_idx)
         topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)  # renorm
 
         # routed expert dispatch (token-level scatter)
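To see the fix outside the diff, here is a minimal, self-contained sketch of the routing path. The constructor, the choice of a non-trainable buffer for router_bias, and the update_bias helper are assumptions for illustration; the diff never shows how the bias is maintained (DeepSeek-V3 nudges it against each expert's observed load, which the sign-step rule below approximates):

import torch
import torch.nn.functional as F
from torch import nn

class AuxFreeRouter(nn.Module):
    """Sketch of aux-loss-free top-k routing (DeepSeek-V3 style).

    The bias steers only *which* experts are selected; gating weights
    are gathered from the unbiased softmax, so no gradient ever flows
    through the bias.
    """

    def __init__(self, dim: int, n_experts: int, topk: int):
        super().__init__()
        self.router = nn.Linear(dim, n_experts, bias=False)
        # A buffer, not a Parameter: updated out-of-band, never by autograd.
        self.register_buffer("router_bias", torch.zeros(n_experts))
        self.topk = topk

    def forward(self, flat: torch.Tensor):
        logits = self.router(flat)                 # (N, n_experts), unbiased
        scores = F.softmax(logits, dim=-1)
        # Selection uses the biased logits...
        _, topk_idx = (logits + self.router_bias).topk(self.topk, dim=-1)
        # ...but the gating weights come from the unbiased scores.
        topk_scores = scores.gather(-1, topk_idx)
        topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)
        return topk_scores, topk_idx

    @torch.no_grad()
    def update_bias(self, topk_idx: torch.Tensor, gamma: float = 1e-3) -> None:
        # Hypothetical load-balancing step (assumption, not in the diff):
        # raise the bias for underused experts, lower it for overused ones.
        load = torch.bincount(
            topk_idx.flatten(), minlength=self.router_bias.numel()
        ).float()
        self.router_bias += gamma * torch.sign(load.mean() - load)

Keeping router_bias as a buffer rather than an nn.Parameter makes the no-gradient property structural: autograd cannot reach it even if a later refactor reintroduces the bias into a differentiable path.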
@@ -577,7 +581,12 @@ class LoRAAdapter(nn.Module):
         Returns:
             Delta tensor of shape (B, T, dim) to be added to the block output
         """
-        s = self.scale(torch.tensor(loop_t, device=x.device))  # (rank,)
+        # Clamp for depth extrapolation: at inference n_loops can exceed the
+        # training max_loop_iters. Iterations beyond the trained range reuse
+        # the last learned per-loop scale rather than indexing out of range.
+        max_t = self.scale.num_embeddings - 1
+        t_idx = loop_t if loop_t <= max_t else max_t
+        s = self.scale(torch.tensor(t_idx, device=x.device))  # (rank,)
         down = self.down(x) * s  # (B, T, rank)
         return down @ self.B  # (B, T, dim)
 
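A standalone sketch of the clamped per-loop scaling, under the assumption (consistent with num_embeddings above) that self.scale is an nn.Embedding holding one rank-sized scale vector per training loop iteration; the constructor details and the class name are illustrative, not taken from the repo:

import torch
from torch import nn

class LoopScaledLoRA(nn.Module):
    """Sketch of a per-loop-scaled LoRA delta; shapes follow the diff,
    constructor details are assumptions."""

    def __init__(self, dim: int, rank: int, max_loop_iters: int):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)     # A: dim -> rank
        self.B = nn.Parameter(torch.zeros(rank, dim))    # B: rank -> dim, zero-init
        self.scale = nn.Embedding(max_loop_iters, rank)  # one scale per loop

    def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
        # Clamp so inference with n_loops > max_loop_iters reuses the last
        # learned scale instead of indexing the embedding out of range.
        t_idx = min(loop_t, self.scale.num_embeddings - 1)
        s = self.scale(torch.tensor(t_idx, device=x.device))  # (rank,)
        down = self.down(x) * s                               # (B, T, rank)
        return down @ self.B                                  # (B, T, dim)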
@@ -829,19 +838,26 @@ class RecurrentBlock(nn.Module):
             still_running = ~halted
 
             # ACT remainder trick: once cumulative_p + p crosses threshold,
-            # assign the remaining probability mass as the final weight
+            # assign the remaining probability mass as the final weight.
+            # Gate by still_running so halted positions contribute exactly
+            # once (on the halting step) and zero thereafter — otherwise
+            # threshold<1 leaves a non-zero remainder that leaks every step.
             remainder = (1.0 - cumulative_p).clamp(min=0)
             weight = torch.where(
                 cumulative_p + p >= self.cfg.act_threshold,
                 remainder,
                 p,
             )
+            weight = weight * still_running.float()
             h_out = h_out + weight.unsqueeze(-1) * h
 
             cumulative_p = cumulative_p + p * still_running.float()
             halted = halted | (cumulative_p >= self.cfg.act_threshold)
 
-            if halted.all():
+            # Only short-circuit when there is no KV cache to keep consistent.
+            # With a cache, every loop depth must run on every forward pass so
+            # later decode steps find populated keys at every cache_key.
+            if halted.all() and kv_cache is None:
                 break
 
         return h_out
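Putting the two RecurrentBlock fixes together, here is a sketch of the whole ACT loop as it would run after this commit. block, halt_head, and n_loops are hypothetical stand-ins for internals the diff does not show; only the halting bookkeeping mirrors the code above:

import torch

def act_recurrent_loop(h, block, halt_head, n_loops, act_threshold,
                       kv_cache=None):
    """Sketch of ACT halting with the remainder trick and halted-position
    gating; block/halt_head/n_loops are hypothetical stand-ins."""
    B, T, _ = h.shape
    h_out = torch.zeros_like(h)
    cumulative_p = h.new_zeros(B, T)
    halted = torch.zeros(B, T, dtype=torch.bool, device=h.device)

    for t in range(n_loops):
        h = block(h, t)                              # assumed per-loop transform
        p = torch.sigmoid(halt_head(h)).squeeze(-1)  # (B, T) halt probability
        still_running = ~halted

        # Remainder trick: the step that crosses the threshold takes the
        # leftover probability mass as its weight...
        remainder = (1.0 - cumulative_p).clamp(min=0)
        weight = torch.where(cumulative_p + p >= act_threshold, remainder, p)
        # ...and gating zeroes positions that already halted, so each
        # position contributes exactly once at its halting step.
        weight = weight * still_running.float()
        h_out = h_out + weight.unsqueeze(-1) * h

        cumulative_p = cumulative_p + p * still_running.float()
        halted = halted | (cumulative_p >= act_threshold)

        # Short-circuit only without a KV cache: with one, every loop depth
        # must run so later decode steps find populated keys at every slot.
        if halted.all() and kv_cache is None:
            break

    return h_out

With the gating in place, each position's total assigned weight is its pre-halt probabilities plus the remainder at the halting step, which sums to exactly 1; without it, every post-halt iteration would re-add the leftover mass whenever act_threshold < 1.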
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "open-mythos"
-version = "0.3.0"
+version = "0.4.0"
 description = "OpenMythos — open-source theoretical reconstruction of the Claude Mythos Recurrent-Depth Transformer architecture"
 license = "MIT"
 authors = ["Kye Gomez <kye@swarms.world>"]