"""
|
||
Mixture-of-Depths Attention (MoDA) + DeepSeek Mixture-of-Experts FFN
|
||
======================================================================
|
||
Paper (attention): "Mixture-of-Depths Attention" arXiv 2603.15619
|
||
Paper (MoE): "DeepSeekMoE: Towards Ultimate Expert Specialization
|
||
in Mixture-of-Experts Language Models" arXiv 2401.06066
|
||
Reference impl (V3): https://github.com/deepseek-ai/DeepSeek-V3
|
||
|
||
Architecture
|
||
------------
|
||
This file fuses two independent architectural improvements:
|
||
|
||
1. **MoDA** — each attention head jointly attends to current-layer sequence
|
||
KV pairs (causal) *and* depth KV pairs from all preceding layers at the
|
||
same token position, under a single softmax.
|
||
|
||
2. **DeepSeek MoE** (replaces the dense SwiGLU FFN in every block):
|
||
* K_s *shared experts* — always activated, capture common knowledge.
|
||
* N_r *routed experts* — sparse; top-K activated per token.
|
||
* Fine-grained expert segmentation: each expert has a small hidden dim
|
||
(≈ dense_hidden / m) so that activating more experts keeps FLOPs
|
||
constant while improving specialisation.
|
||
* Expert-level balance loss prevents routing collapse.
|
||
|
||
Gate routing (faithful to DeepSeek-V3 model.py)
|
||
------------------------------------------------
|
||
scores = softmax(x W^T) # or sigmoid for V3 style
|
||
original = scores # saved for weight computation
|
||
[optional] scores += bias # V3 aux-loss-free routing
|
||
[optional] group-limited masking # V3 device-group routing
|
||
indices = topk(scores, K)
|
||
weights = original[indices] # un-biased original scores
|
||
[sigmoid] weights /= sum(weights) # re-normalise for sigmoid gating
|
||
weights *= route_scale
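
A tiny worked illustration of the routing steps (softmax gating, N_r = 4,
K = 2, no bias, no grouping — the numbers are made up for exposition only):

    scores   = [0.10, 0.20, 0.30, 0.40]   # softmax over 4 experts
    indices  = topk(scores, 2)   -> [3, 2]
    weights  = original[[3, 2]]  -> [0.40, 0.30]
    weights *= route_scale       -> unchanged for route_scale = 1.0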

Balance loss (DeepSeekMoE §3.3, used when training without V3 bias routing)
    L_ExpBal = Σ_i f_i · P_i
    f_i = (N_r / (K · T)) · #{tokens routing to i}   (normalised frequency)
    P_i = (1/T) Σ_t s_{i,t}                          (mean soft gate score)

Memory note
-----------
MoDA's unified attention has O(T·L) combined KV length. For long sequences
use the Triton kernel from https://github.com/hustvl/MoDA.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class MoDAConfig:
    """Configuration for a MoDA + DeepSeek-MoE decoder-only language model.

    Attention (MoDA)
    ----------------
    vocab_size:   Vocabulary size.
    d_model:      Hidden dimension (must equal n_heads_q * head_dim).
    n_layers:     Number of transformer layers.
    n_heads_q:    Query heads.
    n_heads_kv:   Key/value heads for GQA (must divide n_heads_q).
    head_dim:     Per-head dimension (usually d_model // n_heads_q).
    max_seq_len:  Maximum sequence length for the RoPE cache.
    rope_base:    RoPE frequency base.
    attn_dropout: Attention dropout (0 for inference).
    norm_eps:     RMSNorm epsilon.

    MoE FFN (DeepSeekMoE / DeepSeek-V3 style)
    ------------------------------------------
    n_shared_experts:    Always-active shared experts (K_s). Capture common
                         knowledge; excluded from routing and balance loss.
    n_routed_experts:    Total pool of routed experts (N_r).
    n_activated_experts: Top-K selected from routed experts per token (K').
    expert_hidden_dim:   Per-expert intermediate dimension.
                         Set to dense_ffn_hidden / m where m is the
                         fine-grained segmentation factor so that total
                         activated FLOPs match a dense FFN:
                         (n_shared + n_activated) × expert_hidden ≈ dense_hidden
    moe_balance_alpha:   Weight of the expert-level balance loss. Set to
                         0.0 to disable (e.g. when using V3 bias routing).
    moe_score_func:      "softmax" (DeepSeekMoE / V2) or "sigmoid" (V3).
    moe_n_groups:        Number of expert groups for group-limited routing
                         (V3 uses 8; set 1 to disable, default).
    moe_topk_groups:     Number of groups a token may route to
                         (V3 uses 3; set 1 to disable, default).
    moe_route_scale:     Scalar multiplied onto the selected gate weights
                         after normalisation (V3 uses 2.5446; default 1.0).

    Defaults approximate the DeepSeekMoE 2B configuration scaled to
    d_model = 2048, keeping per-token FLOPs equal to a dense SwiGLU with
    hidden_dim = 5 632 (≈ 8/3 × 2048):
        (n_shared + n_activated) × expert_hidden = (2+6) × 704 = 5 632.
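
    Example (a minimal sketch checking the FLOPs-matching identity above on
    the default values)::

        >>> cfg = MoDAConfig()
        >>> (cfg.n_shared_experts + cfg.n_activated_experts) * cfg.expert_hidden_dim
        5632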
    """

    # ---- Transformer / MoDA ----
    vocab_size: int = 32_000
    d_model: int = 2048
    n_layers: int = 24
    n_heads_q: int = 16
    n_heads_kv: int = 8
    head_dim: int = 128
    max_seq_len: int = 4_096
    rope_base: float = 10_000.0
    attn_dropout: float = 0.0
    norm_eps: float = 1e-6

    # ---- DeepSeek MoE FFN ----
    n_shared_experts: int = 2         # K_s
    n_routed_experts: int = 64        # N_r
    n_activated_experts: int = 6      # K' top-K from routed pool
    expert_hidden_dim: int = 704      # per-expert intermediate dim
    moe_balance_alpha: float = 0.001  # balance-loss weight (0 = disabled)
    moe_score_func: str = "softmax"   # "softmax" | "sigmoid"
    moe_n_groups: int = 1             # expert groups (1 = no grouping)
    moe_topk_groups: int = 1          # groups to route into (1 = no limit)
    moe_route_scale: float = 1.0      # gate-weight scale factor


# ---------------------------------------------------------------------------
# Primitives
# ---------------------------------------------------------------------------


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no bias, no mean subtraction)."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        """Create an RMSNorm layer.

        Args:
            dim: Feature dimension to normalise over (the last axis of input).
            eps: Stability constant added before the reciprocal square-root to
                 prevent division by zero when the RMS is near zero.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise *x* by its root-mean-square and apply a learnable scale.

        Args:
            x: Input tensor of arbitrary shape ``[..., dim]``.

        Returns:
            Normalised tensor, same shape as *x*.
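
        Example (shape-preserving sketch; input values are random)::

            >>> norm = RMSNorm(dim=8)
            >>> norm(torch.randn(2, 4, 8)).shape
            torch.Size([2, 4, 8])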
        """
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight


class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) with lazy cache extension.

    Args:
        dim: Per-head dimension (head_dim).
        max_seq_len: Initial cache size.
        base: Frequency base (default 10 000).
    """

    def __init__(
        self, dim: int, max_seq_len: int = 8_192, base: float = 10_000.0
    ) -> None:
        """Initialise RoPE and pre-compute the cos/sin cache.

        Args:
            dim:         Per-head dimension. Must be even.
            max_seq_len: Number of positions to cache on construction. The
                         cache doubles automatically when a longer sequence
                         is encountered.
            base:        Frequency base θ. Higher values slow the rotation
                         rate, extending effective context length.
        """
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int) -> None:
        """Pre-compute and register ``_cos`` / ``_sin`` buffers up to *seq_len*.

        Called once at init and again (doubling capacity) whenever ``forward``
        is asked for a sequence longer than the current cache.

        Args:
            seq_len: Number of positions to pre-compute.
        """
        t = torch.arange(
            seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # [T, dim/2]
        emb = torch.cat([freqs, freqs], dim=-1)            # [T, dim]
        self.register_buffer("_cos", emb.cos()[None, None], persistent=False)
        self.register_buffer("_sin", emb.sin()[None, None], persistent=False)

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cached (cos, sin) tables for the first *seq_len* positions.

        Args:
            seq_len: Number of positions required.

        Returns:
            Tuple of ``(cos, sin)``, each shaped ``[1, 1, seq_len, dim]``,
            ready to broadcast with ``[B, H, T, dim]`` query / key tensors.
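
        Example (cache shapes for a 16-position request)::

            >>> rope = RotaryEmbedding(dim=32, max_seq_len=64)
            >>> cos, sin = rope(16)
            >>> cos.shape, sin.shape
            (torch.Size([1, 1, 16, 32]), torch.Size([1, 1, 16, 32]))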
        """
        if seq_len > self._cos.shape[2]:
            self._build_cache(seq_len * 2)
        return self._cos[:, :, :seq_len], self._sin[:, :, :seq_len]


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return *x* with its last dimension split and swapped with negation.

    Given ``x = [x₁, x₂]`` (each half of the last dim), returns
    ``[-x₂, x₁]``. Combined with the cos/sin multiply in
    :func:`apply_rotary_emb` this implements the 2-D rotation matrix
    that defines RoPE.

    Args:
        x: Tensor with an even-sized last dimension.

    Returns:
        Rotated tensor with the same shape as *x*.
    """
    half = x.shape[-1] // 2
    return torch.cat([-x[..., half:], x[..., :half]], dim=-1)


def apply_rotary_emb(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply Rotary Position Embeddings to query or key tensors.

    Implements ``x_rot = x * cos + rotate_half(x) * sin``, which is
    equivalent to multiplying each consecutive pair of dimensions by a
    2-D rotation matrix whose angle depends on the sequence position and
    the dimension's frequency. The input is not modified; a new tensor
    is returned.

    Args:
        x:   Query or key tensor, shape ``[B, H, T, d]``.
        cos: Pre-computed cosines, shape ``[1, 1, T, d]``.
        sin: Pre-computed sines, shape ``[1, 1, T, d]``.

    Returns:
        Rotated tensor with the same shape and dtype as *x*.
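
    Example (a minimal sketch tying this to :class:`RotaryEmbedding`)::

        >>> rope = RotaryEmbedding(dim=16, max_seq_len=32)
        >>> cos, sin = rope(8)
        >>> q = torch.randn(2, 4, 8, 16)          # [B, H, T, d]
        >>> apply_rotary_emb(q, cos, sin).shape
        torch.Size([2, 4, 8, 16])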
    """
    return x * cos + _rotate_half(x) * sin


# ---------------------------------------------------------------------------
# DeepSeek MoE FFN
# ---------------------------------------------------------------------------


class DeepSeekExpert(nn.Module):
    """Single fine-grained SwiGLU expert.

    Faithful to DeepSeek-V3 ``Expert``:
        output = w2( SiLU(w1(x)) ⊙ w3(x) )

    where w1 is the gate projection, w3 the up-projection, w2 the
    down-projection — identical to a SwiGLU FFN at smaller hidden dim.

    Args:
        d_model:    Input / output dimension.
        hidden_dim: Expert intermediate dimension (≪ dense FFN hidden_dim).
    """

    def __init__(self, d_model: int, hidden_dim: int) -> None:
        """Create a single fine-grained SwiGLU expert.

        Args:
            d_model:    Token hidden dimension (input and output size).
            hidden_dim: Expert intermediate dimension. Typically much
                        smaller than the dense FFN hidden dim — set to
                        ``dense_hidden / m`` where *m* is the fine-grained
                        segmentation factor so total activated FLOPs match
                        the dense baseline.
        """
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate
        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)  # up
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)  # down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``w2( SiLU(w1(x)) ⊙ w3(x) )``.

        Args:
            x: Token features assigned to this expert, shape
               ``[num_assigned_tokens, d_model]``.

        Returns:
            Expert output, shape ``[num_assigned_tokens, d_model]``.
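
        Example (shape sketch for a handful of routed tokens)::

            >>> expert = DeepSeekExpert(d_model=64, hidden_dim=16)
            >>> expert(torch.randn(5, 64)).shape
            torch.Size([5, 64])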
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class DeepSeekGate(nn.Module):
    """Token-to-expert routing gate.

    Faithful to DeepSeek-V3 ``Gate`` (minus distributed sharding).

    Routing algorithm
    ~~~~~~~~~~~~~~~~~
    1. ``scores = softmax(x W^T)`` or ``sigmoid(x W^T)``
    2. ``original_scores = scores`` (saved — will be used for gate weights)
    3. [optional] ``scores += bias`` (V3 aux-loss-free bias routing)
    4. [optional] Group-limited masking:
       - reshape scores → [T, n_groups, experts_per_group]
       - keep only top-``topk_groups`` groups per token
       - mask the rest to −∞
    5. ``indices = topk(scores, K')`` (routing decision)
    6. ``weights = original_scores[indices]`` (un-biased weights)
    7. [sigmoid only] ``weights /= sum(weights)`` (re-normalise)
    8. ``weights *= route_scale``

    The ``original_scores`` (full distribution, before bias/masking) are also
    returned so the MoE layer can compute the expert-level balance loss.

    Args:
        d_model:          Token hidden dimension.
        n_routed_experts: Total routed expert pool size (N_r).
        n_activated:      Top-K experts to select (K').
        score_func:       ``"softmax"`` or ``"sigmoid"``.
        n_groups:         Number of expert groups (1 = disabled).
        topk_groups:      Groups a token may route to (1 = disabled).
        route_scale:      Scalar applied to final gate weights.
        use_bias:         If True, add a learnable per-expert bias used only
                          for the routing decision (V3 aux-loss-free scheme).
    """

    def __init__(
        self,
        d_model: int,
        n_routed_experts: int,
        n_activated: int,
        score_func: str = "softmax",
        n_groups: int = 1,
        topk_groups: int = 1,
        route_scale: float = 1.0,
        use_bias: bool = False,
    ) -> None:
        """Create the routing gate.

        Args:
            d_model:          Token hidden dimension.
            n_routed_experts: Total number of routed experts in the pool (N_r).
            n_activated:      How many experts to select per token (K').
            score_func:       Affinity function — ``"softmax"`` (original
                              DeepSeekMoE / V2) or ``"sigmoid"`` (V3).
            n_groups:         Number of expert groups for device-limited
                              routing. Set to 1 to disable (default).
            topk_groups:      Number of groups each token may route into.
                              Set to 1 to disable (default).
            route_scale:      Scalar multiplied onto gate weights after
                              optional sigmoid normalisation (V3 uses 2.5446;
                              default 1.0 leaves weights unchanged).
            use_bias:         If ``True``, initialise a learnable per-expert
                              float32 bias added to routing scores only (not
                              gate weights). Enables the V3 aux-loss-free
                              load-balancing scheme where the bias is adjusted
                              outside the optimizer to track expert loads.
        """
        super().__init__()
        self.n_routed_experts = n_routed_experts
        self.n_activated = n_activated
        self.score_func = score_func
        self.n_groups = n_groups
        self.topk_groups = topk_groups
        self.route_scale = route_scale

        # Gating projection: [N_r, D]
        self.weight = nn.Parameter(torch.empty(n_routed_experts, d_model))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        # Optional per-expert routing bias (V3 aux-loss-free load balancing).
        # Updated outside the optimizer by monitoring expert loads — not
        # trained through the balance loss. Initialised to zero.
        self.bias: Optional[nn.Parameter] = (
            nn.Parameter(torch.zeros(n_routed_experts, dtype=torch.float32))
            if use_bias
            else None
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute routing weights and expert indices.

        Args:
            x: ``[num_tokens, d_model]`` (flattened B × T).

        Returns:
            weights:         ``[num_tokens, K']`` gate weights (dtype = x.dtype).
            indices:         ``[num_tokens, K']`` selected expert indices (int64).
            original_scores: ``[num_tokens, N_r]`` full soft scores (float32),
                             used by :class:`DeepSeekMoE` for the balance loss.
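
        Example (softmax gating, 8 experts, top-2 — shapes only)::

            >>> gate = DeepSeekGate(d_model=32, n_routed_experts=8, n_activated=2)
            >>> w, idx, s = gate(torch.randn(10, 32))
            >>> w.shape, idx.shape, s.shape
            (torch.Size([10, 2]), torch.Size([10, 2]), torch.Size([10, 8]))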
        """
        # Affinity logits
        logits = F.linear(x, self.weight)  # [T, N_r]

        if self.score_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        else:  # sigmoid (V3)
            scores = logits.sigmoid().to(torch.float32)

        original_scores = scores  # un-biased; used for weights + balance loss

        # Routing scores (may differ from original_scores if bias is active)
        routing = scores
        if self.bias is not None:
            routing = routing + self.bias

        # Group-limited routing (V3 device-group constraint)
        if self.n_groups > 1:
            # [T, n_groups, experts_per_group]
            g = routing.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = g.amax(dim=-1)  # [T, G]
            else:
                # Top-2 sum per group (V3 heuristic)
                group_scores = g.topk(2, dim=-1)[0].sum(dim=-1)
            _, top_groups = group_scores.topk(self.topk_groups, dim=-1)  # [T, topk_g]
            mask = torch.ones(
                x.size(0), self.n_groups, dtype=torch.bool, device=x.device
            ).scatter_(1, top_groups, False)  # True = masked out
            routing = g.masked_fill(mask.unsqueeze(-1), float("-inf")).flatten(1)

        # Top-K selection (on routing scores which may include bias / group mask)
        _, indices = routing.topk(self.n_activated, dim=-1)  # [T, K']

        # Gate weights from original (un-biased) scores
        weights = original_scores.gather(1, indices)  # [T, K']

        if self.score_func == "sigmoid":
            weights = weights / weights.sum(dim=-1, keepdim=True).clamp(min=1e-9)

        weights = (weights * self.route_scale).to(x.dtype)
        return weights, indices, original_scores


class DeepSeekMoE(nn.Module):
    """DeepSeek Mixture-of-Experts layer — drop-in replacement for a dense FFN.

    Combines shared experts (always active) and routed experts (sparse top-K)
    exactly as in DeepSeek-V3 ``MoE``, adapted for single-device training
    (no ColumnParallel/RowParallel, no all_reduce).

    Forward pass
    ~~~~~~~~~~~~
    ::

        x_flat = x.view(-1, D)                   # [B*T, D]

        # Shared path (always executed)
        z = shared_experts(x_flat)               # [B*T, D]

        # Routed path (sparse)
        weights, indices, scores = gate(x_flat)  # [B*T, K'], [B*T, K'], [B*T, N_r]
        y = zeros_like(x_flat)
        for each expert i:
            toks = tokens that selected expert i
            y[toks] += experts[i](x_flat[toks]) * weights[toks, rank_of_i]

        output = (y + z).view(B, T, D)

        # Training: expert-level balance loss (DeepSeekMoE §3.3)
        L_ExpBal = Σ_i f_i · P_i
        f_i = (N_r / (K' · T)) · #{tokens → expert i}
        P_i = mean_t(scores_{t,i})

    Args:
        cfg: :class:`MoDAConfig` instance.
    """

    def __init__(self, cfg: MoDAConfig) -> None:
        """Build the MoE layer from a :class:`MoDAConfig`.

        Constructs:
            * ``shared_experts`` — one dense SwiGLU FFN with hidden dimension
              ``n_shared_experts × expert_hidden_dim``.
            * ``gate`` — :class:`DeepSeekGate` for top-K routing.
            * ``experts`` — ``nn.ModuleList`` of ``n_routed_experts``
              :class:`DeepSeekExpert` instances, each with ``expert_hidden_dim``
              intermediate units.

        Args:
            cfg: Model configuration. The relevant fields are
                 ``n_shared_experts``, ``n_routed_experts``,
                 ``n_activated_experts``, ``expert_hidden_dim``,
                 ``moe_balance_alpha``, ``moe_score_func``,
                 ``moe_n_groups``, ``moe_topk_groups``, and
                 ``moe_route_scale``.
        """
        super().__init__()
        self.d_model = cfg.d_model
        self.n_routed_experts = cfg.n_routed_experts
        self.n_activated_experts = cfg.n_activated_experts
        self.moe_balance_alpha = cfg.moe_balance_alpha

        # Shared experts: single dense SwiGLU with hidden = K_s × expert_hidden
        # (matches DeepSeek-V3's ``MLP(dim, n_shared_experts * moe_inter_dim)``)
        shared_hidden = cfg.n_shared_experts * cfg.expert_hidden_dim
        self.shared_experts = _SharedFFN(cfg.d_model, shared_hidden)

        # Routing gate
        self.gate = DeepSeekGate(
            d_model=cfg.d_model,
            n_routed_experts=cfg.n_routed_experts,
            n_activated=cfg.n_activated_experts,
            score_func=cfg.moe_score_func,
            n_groups=cfg.moe_n_groups,
            topk_groups=cfg.moe_topk_groups,
            route_scale=cfg.moe_route_scale,
            use_bias=False,
        )

        # Routed experts pool
        self.experts = nn.ModuleList(
            [
                DeepSeekExpert(cfg.d_model, cfg.expert_hidden_dim)
                for _ in range(cfg.n_routed_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the MoE layer.

        Args:
            x: ``[B, T, D]`` hidden states.

        Returns:
            output:       ``[B, T, D]`` updated hidden states.
            balance_loss: Scalar expert-level balance loss (during training),
                          or ``None`` during inference.
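
        Example (a minimal sketch on a small config; eval mode, so the
        balance loss is skipped)::

            >>> cfg = MoDAConfig(d_model=64, n_routed_experts=8,
            ...                  n_activated_experts=2, expert_hidden_dim=16)
            >>> moe = DeepSeekMoE(cfg).eval()
            >>> out, bal = moe(torch.randn(2, 5, 64))
            >>> out.shape, bal is None
            (torch.Size([2, 5, 64]), True)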
        """
        shape = x.shape
        x_flat = x.view(-1, self.d_model)  # [T_tot, D]
        n_tokens = x_flat.size(0)

        # ---- Shared experts (all tokens) ---------------------------------
        z = self.shared_experts(x_flat)  # [T_tot, D]

        # ---- Routed experts (sparse) -------------------------------------
        weights, indices, scores = self.gate(x_flat)
        # weights: [T_tot, K'], indices: [T_tot, K'], scores: [T_tot, N_r]

        y = torch.zeros_like(x_flat)

        # Dispatch: for each expert compute on its assigned tokens
        # (expert-major loop matches DeepSeek-V3's reference implementation)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts)
        for i, expert in enumerate(self.experts):
            if counts[i].item() == 0:
                continue
            # Which tokens selected expert i, and in which top-K slot
            tok_idx, rank_in_k = torch.where(indices == i)
            y[tok_idx] += expert(x_flat[tok_idx]) * weights[tok_idx, rank_in_k, None]

        output = (y + z).view(shape)

        # ---- Expert-level balance loss (DeepSeekMoE §3.3) ----------------
        balance_loss: Optional[torch.Tensor] = None
        if self.training and self.moe_balance_alpha > 0.0:
            balance_loss = self._balance_loss(indices, scores, n_tokens)

        return output, balance_loss

    def _balance_loss(
        self,
        indices: torch.Tensor,  # [T, K'] int64
        scores: torch.Tensor,   # [T, N_r] float32 (full distribution)
        n_tokens: int,
    ) -> torch.Tensor:
        """Compute the expert-level balance loss (DeepSeekMoE §3.3).

        Penalises routing imbalance by encouraging the model to spread tokens
        evenly across experts. Only the soft-score term ``P_i`` receives a
        gradient; the hard-count term ``f_i`` is non-differentiable and acts
        as a fixed weighting coefficient.

        ::

            f_i = (N_r / (K' × T)) × #{tokens routed to expert i}
            P_i = (1/T) Σ_t scores[t, i]
            L   = Σ_i f_i · P_i

        For perfect balance ``f_i = 1`` for all *i* and ``L = Σ P_i = 1``
        (softmax) or some constant (sigmoid). Overloaded experts produce
        large ``f_i``, so minimising the loss pushes their mean score ``P_i``
        down via the gradient, diverting tokens toward under-used experts and
        stabilising load over training.

        Args:
            indices:  ``[T, K']`` int64 — expert indices selected per token.
            scores:   ``[T, N_r]`` float32 — full soft distribution from the
                      gate (before top-K selection), used for ``P_i``.
            n_tokens: Total number of tokens in the batch (``B × T``).

        Returns:
            Scalar balance loss tensor.
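
        A worked micro-example (N_r = 4, K' = 1, T = 4): if each expert
        receives exactly one token, ``f_i = (4 / (1·4)) · 1 = 1`` for all
        *i*, so ``L = Σ_i P_i``, which is exactly 1 under softmax gating.
        If one expert instead receives all four tokens, its ``f = 4`` and
        the loss grows with that expert's mean score.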
        """
        Nr, K = self.n_routed_experts, self.n_activated_experts

        # Routing counts per expert (non-differentiable)
        counts = torch.zeros(Nr, dtype=torch.float32, device=indices.device)
        counts.scatter_add_(
            0,
            indices.flatten(),
            torch.ones(indices.numel(), dtype=torch.float32, device=indices.device),
        )
        f = counts * (Nr / (K * n_tokens))  # normalised frequency [N_r]

        # Mean soft gate score per expert (differentiable through softmax/sigmoid)
        P = scores.mean(dim=0)  # [N_r]

        # f is derived from hard top-K → no gradient; gradient flows through P only
        return (f * P).sum()


class _SharedFFN(nn.Module):
    """Dense SwiGLU FFN used for the always-active shared experts.

    Mirrors :class:`DeepSeekExpert` but is a single larger MLP whose
    ``hidden_dim`` equals ``n_shared_experts × expert_hidden_dim``. This
    matches DeepSeek-V3's ``MLP(dim, n_shared_experts * moe_inter_dim)``.

    Not part of the public API — instantiated only by :class:`DeepSeekMoE`.
    """

    def __init__(self, d_model: int, hidden_dim: int) -> None:
        """Create the shared-expert FFN.

        Args:
            d_model:    Token hidden dimension (input and output).
            hidden_dim: Combined intermediate size for all shared experts
                        (``n_shared_experts × expert_hidden_dim``).
        """
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the shared SwiGLU FFN to every token.

        Args:
            x: Flattened token features, shape ``[B*T, d_model]``.

        Returns:
            Transformed features, shape ``[B*T, d_model]``.
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


# ---------------------------------------------------------------------------
# MoDA Attention (unchanged from base file)
# ---------------------------------------------------------------------------


class MoDAAttention(nn.Module):
    """Mixture-of-Depths Attention — read side.

    Each query jointly attends (single softmax) to:
        * Sequence KVs at the current layer (causal GQA).
        * Depth KVs from all preceding layers at the *same* token position.

    Depth cache entries are written externally by :class:`MoDABlock` from
    the full block output X_l^out (after the MoE FFN).

    Args:
        cfg: :class:`MoDAConfig` instance.
    """

    def __init__(self, cfg: MoDAConfig) -> None:
        """Build the MoDA attention module.

        Creates four projection matrices (Q, K, V, O) sized for GQA and
        stores the attention scale and dropout rate.

        Args:
            cfg: Model configuration. Must satisfy
                 ``n_heads_q % n_heads_kv == 0`` (GQA requirement).

        Raises:
            ValueError: If ``n_heads_q`` is not divisible by ``n_heads_kv``.
        """
        super().__init__()
        if cfg.n_heads_q % cfg.n_heads_kv != 0:
            raise ValueError(
                f"n_heads_q ({cfg.n_heads_q}) must be divisible by "
                f"n_heads_kv ({cfg.n_heads_kv}) for GQA."
            )

        self.n_heads_q = cfg.n_heads_q
        self.n_heads_kv = cfg.n_heads_kv
        self.head_dim = cfg.head_dim
        self.gqa_group = cfg.n_heads_q // cfg.n_heads_kv
        self.scale = cfg.head_dim**-0.5
        self.dropout = cfg.attn_dropout

        inner_q = cfg.n_heads_q * cfg.head_dim
        inner_kv = cfg.n_heads_kv * cfg.head_dim

        self.q_proj = nn.Linear(cfg.d_model, inner_q, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, inner_kv, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, inner_kv, bias=False)
        self.o_proj = nn.Linear(inner_q, cfg.d_model, bias=False)

    def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads along dim 1 to match the number of query heads.

        With GQA group size G, each KV head is shared by G query heads.
        ``repeat_interleave(G, dim=1)`` produces the correct interleaved
        expansion so that query head ``h`` is paired with KV head ``h // G``.

        Args:
            kv: Key or value tensor whose dim 1 is the KV-head axis.
                Supported shapes: ``[B, Hk, T, d]`` (sequence) and
                ``[B, Hk, T, L, d]`` (depth stack).

        Returns:
            Tensor with dim 1 expanded from ``Hk`` to ``Hq = Hk × G``.
            Returns *kv* unchanged when ``gqa_group == 1``.
        """
        if self.gqa_group == 1:
            return kv
        return kv.repeat_interleave(self.gqa_group, dim=1)

    def forward(
        self,
        x: torch.Tensor,
        depth_k_cache: List[torch.Tensor],
        depth_v_cache: List[torch.Tensor],
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Compute MoDA attention output.

        Args:
            x:             ``[B, T, D]`` input hidden states.
            depth_k_cache: ``L`` tensors each ``[B, Hk, T, d]`` — depth keys.
            depth_v_cache: Matching depth values.
            cos/sin:       RoPE tables ``[1, 1, T, d]``.

        Returns:
            ``[B, T, D]`` output hidden states.
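
        Example (layer-0 sketch: empty depth caches, SDPA fast path)::

            >>> cfg = MoDAConfig(d_model=64, n_heads_q=4, n_heads_kv=2,
            ...                  head_dim=16, max_seq_len=32)
            >>> attn = MoDAAttention(cfg)
            >>> rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len)
            >>> cos, sin = rope(8)
            >>> attn(torch.randn(2, 8, 64), [], [], cos, sin).shape
            torch.Size([2, 8, 64])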
        """
        B, T, D = x.shape
        Hq, Hk, d = self.n_heads_q, self.n_heads_kv, self.head_dim

        Q = self.q_proj(x).view(B, T, Hq, d).transpose(1, 2)
        K = self.k_proj(x).view(B, T, Hk, d).transpose(1, 2)
        V = self.v_proj(x).view(B, T, Hk, d).transpose(1, 2)

        Q = apply_rotary_emb(Q, cos, sin)
        K = apply_rotary_emb(K, cos, sin)

        K_e = self._expand_kv(K)
        V_e = self._expand_kv(V)

        L = len(depth_k_cache)

        if L == 0:
            out = F.scaled_dot_product_attention(
                Q,
                K_e,
                V_e,
                is_causal=True,
                dropout_p=self.dropout if self.training else 0.0,
                scale=self.scale,
            )
        else:
            # Sequence logits [B, Hq, T, T] with causal mask
            seq_logits = torch.matmul(Q, K_e.transpose(-2, -1)) * self.scale
            causal_mask = torch.triu(
                torch.full((T, T), float("-inf"), device=x.device, dtype=Q.dtype),
                diagonal=1,
            )
            seq_logits = seq_logits + causal_mask

            # Depth KVs: [B, Hk, L, T, d] → [B, Hk, T, L, d]
            K_depth = torch.stack(depth_k_cache, dim=2).permute(0, 1, 3, 2, 4)
            V_depth = torch.stack(depth_v_cache, dim=2).permute(0, 1, 3, 2, 4)
            K_depth_e = self._expand_kv(K_depth)
            V_depth_e = self._expand_kv(V_depth)

            # Depth logits [B, Hq, T, L]
            depth_logits = torch.einsum("bhid,bhild->bhil", Q, K_depth_e) * self.scale

            # Unified softmax over T + L positions
            combined = torch.cat([seq_logits, depth_logits], dim=-1)
            weights = F.softmax(combined, dim=-1)
            if self.training and self.dropout > 0.0:
                weights = F.dropout(weights, p=self.dropout)

            seq_contrib = torch.matmul(weights[:, :, :, :T], V_e)
            depth_contrib = torch.einsum(
                "bhil,bhild->bhid", weights[:, :, :, T:], V_depth_e
            )
            out = seq_contrib + depth_contrib

        out = out.transpose(1, 2).reshape(B, T, Hq * d)
        return self.o_proj(out)


# ---------------------------------------------------------------------------
# MoDA Transformer Block
# ---------------------------------------------------------------------------


class MoDABlock(nn.Module):
    """Single MoDA + DeepSeek-MoE transformer block.

    Structure (post-norm, per MoDA paper recommendation):

    .. code-block::

        x     ──► Attention ──► + ──► RMSNorm ──► x_mid
        x ──────────────────────↑ (residual)
        x_mid ──► MoE ────────► + ──► RMSNorm ──► x_out
        x_mid ──────────────────↑ (residual)
        x_out ──► W_K^W ──► k_write }  appended to MoDA depth KV cache
              └─► W_V^W ──► v_write }  by MoDAModel for the next layer

    The MoE layer also returns an optional expert-level balance loss scalar
    which is propagated up to :class:`MoDAModel` for inclusion in the total
    training loss.

    Args:
        cfg: :class:`MoDAConfig` instance.
    """

    def __init__(self, cfg: MoDAConfig) -> None:
        """Build one MoDA + MoE transformer block.

        Constructs and wires together:
            * ``attn`` — :class:`MoDAAttention` (depth-aware GQA).
            * ``moe`` — :class:`DeepSeekMoE` (shared + routed experts).
            * ``norm_attn`` / ``norm_ffn`` — post-sublayer :class:`RMSNorm`.
            * ``k_write`` / ``v_write`` — depth-cache write projections
              ``D → n_heads_kv × head_dim``.

        Args:
            cfg: Model configuration.
        """
        super().__init__()
        inner_kv = cfg.n_heads_kv * cfg.head_dim

        self.attn = MoDAAttention(cfg)
        self.moe = DeepSeekMoE(cfg)
        self.norm_attn = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.norm_ffn = RMSNorm(cfg.d_model, cfg.norm_eps)

        # MoDA depth-cache write projections: K_l = X_l^out W_K^W, V_l = X_l^out W_V^W
        self.k_write = nn.Linear(cfg.d_model, inner_kv, bias=False)
        self.v_write = nn.Linear(cfg.d_model, inner_kv, bias=False)

        self._n_heads_kv = cfg.n_heads_kv
        self._head_dim = cfg.head_dim

    def forward(
        self,
        x: torch.Tensor,
        depth_k_cache: List[torch.Tensor],
        depth_v_cache: List[torch.Tensor],
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Run one MoDA + MoE transformer block.

        Args:
            x:             ``[B, T, D]`` input hidden states.
            depth_k_cache: Depth keys from all preceding layers, each
                           ``[B, Hk, T, d]``.
            depth_v_cache: Matching depth values.
            cos/sin:       RoPE tables ``[1, 1, T, d]``.

        Returns:
            x_out:        ``[B, T, D]`` updated hidden states.
            k_write:      ``[B, Hk, T, d]`` depth cache key for this layer.
            v_write:      ``[B, Hk, T, d]`` depth cache value for this layer.
            balance_loss: Scalar expert-level balance loss, or ``None`` at
                          inference.
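
        Example (first layer: no depth cache yet; eval mode)::

            >>> cfg = MoDAConfig(d_model=64, n_heads_q=4, n_heads_kv=2, head_dim=16,
            ...                  max_seq_len=32, n_routed_experts=8,
            ...                  n_activated_experts=2, expert_hidden_dim=16)
            >>> block = MoDABlock(cfg).eval()
            >>> rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len)
            >>> cos, sin = rope(8)
            >>> x, k_w, v_w, bal = block(torch.randn(2, 8, 64), [], [], cos, sin)
            >>> x.shape, k_w.shape
            (torch.Size([2, 8, 64]), torch.Size([2, 2, 8, 16]))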
        """
        B, T, _ = x.shape

        # Post-norm attention sub-layer
        x = self.norm_attn(x + self.attn(x, depth_k_cache, depth_v_cache, cos, sin))

        # Post-norm MoE sub-layer
        moe_out, balance_loss = self.moe(x)
        x = self.norm_ffn(x + moe_out)

        # Depth write projections from X_l^out (full block output, after MoE)
        k_write = (
            self.k_write(x).view(B, T, self._n_heads_kv, self._head_dim).transpose(1, 2)
        )
        v_write = (
            self.v_write(x).view(B, T, self._n_heads_kv, self._head_dim).transpose(1, 2)
        )

        # RoPE on k_write for positional consistency during future depth reads
        k_write = apply_rotary_emb(k_write, cos, sin)

        return x, k_write, v_write, balance_loss


# ---------------------------------------------------------------------------
# Full MoDA + MoE Language Model
# ---------------------------------------------------------------------------


class MoDAModel(nn.Module):
    """Decoder-only LM with Mixture-of-Depths Attention and DeepSeek MoE FFN.

    Loss = LM cross-entropy + moe_balance_alpha × mean(per-layer balance losses)

    The depth KV cache is a local list inside :meth:`forward` — never stored
    on ``self``, so autograd is clean across independent forward calls.

    Args:
        cfg: :class:`MoDAConfig` specifying the full model.
    """

    def __init__(self, cfg: MoDAConfig) -> None:
        """Build the full MoDA + MoE language model.

        Constructs the token embedding, RoPE, all transformer blocks, a final
        RMSNorm, and the language-model head. The embedding and LM-head
        weights are tied so they share the same parameter.

        Args:
            cfg: :class:`MoDAConfig` that fully specifies the model.
        """
        super().__init__()
        self.cfg = cfg

        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len, cfg.rope_base)
        self.blocks = nn.ModuleList([MoDABlock(cfg) for _ in range(cfg.n_layers)])
        self.norm_out = RMSNorm(cfg.d_model, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        self.lm_head.weight = self.embed.weight  # weight tying

        self._init_weights()

    def _init_weights(self) -> None:
        """Apply GPT-style weight initialisation to every sub-module.

        * :class:`nn.Linear` and :class:`nn.Embedding` weights are drawn from
          ``Normal(0, 0.02)`` — the standard initialisation used by GPT-2 and
          most subsequent transformer implementations.
        * :class:`DeepSeekGate` weight matrices are re-initialised with
          ``kaiming_uniform`` (fan-in), matching the default ``nn.Linear``
          init, because Normal(0, 0.02) is too narrow for a matrix used
          without a subsequent non-linearity.

        Called automatically at the end of :meth:`__init__`.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
            elif isinstance(m, DeepSeekGate):
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the full MoDA + MoE language model.

        Args:
            input_ids: ``[B, T]`` token indices.
            labels:    ``[B, T]`` targets for LM loss; -100 positions ignored.

        Returns:
            logits: ``[B, T, vocab_size]``.
            loss:   ``lm_loss + balance_loss`` if labels provided, else ``None``.
        """
        B, T = input_ids.shape
        if T > self.cfg.max_seq_len:
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len={self.cfg.max_seq_len}."
            )

        x = self.embed(input_ids)
        cos, sin = self.rope(T)

        depth_k_cache: List[torch.Tensor] = []
        depth_v_cache: List[torch.Tensor] = []
        balance_losses: List[torch.Tensor] = []

        for block in self.blocks:
            x, k_write, v_write, bal = block(x, depth_k_cache, depth_v_cache, cos, sin)
            depth_k_cache.append(k_write)
            depth_v_cache.append(v_write)
            if bal is not None:
                balance_losses.append(bal)

        x = self.norm_out(x)
        logits = self.lm_head(x)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            lm_loss = F.cross_entropy(
                logits.view(-1, self.cfg.vocab_size),
                labels.view(-1),
                ignore_index=-100,
            )
            if balance_losses and self.cfg.moe_balance_alpha > 0.0:
                avg_balance = torch.stack(balance_losses).mean()
                loss = lm_loss + self.cfg.moe_balance_alpha * avg_balance
            else:
                loss = lm_loss

        return logits, loss

    def num_parameters(self, trainable_only: bool = False) -> int:
        """Count the total number of scalar parameters in the model.

        Args:
            trainable_only: If ``True``, count only parameters with
                            ``requires_grad=True``, excluding frozen layers.

        Returns:
            Integer parameter count.
        """
        params = (
            self.parameters()
            if not trainable_only
            else (p for p in self.parameters() if p.requires_grad)
        )
        return sum(p.numel() for p in params)

    def extra_repr(self) -> str:
        """Return a one-line summary string shown inside ``repr(model)``.

        Displayed by PyTorch's ``__repr__`` directly after the class name,
        before the sub-module tree.

        Returns:
            Human-readable string listing key model dimensions and the total
            parameter count.
        """
        c = self.cfg
        return (
            f"vocab={c.vocab_size}, d_model={c.d_model}, layers={c.n_layers}, "
            f"heads={c.n_heads_q}/{c.n_heads_kv} (GQA), "
            f"experts=shared×{c.n_shared_experts}+routed×{c.n_routed_experts} "
            f"(top-{c.n_activated_experts}), "
            f"params={self.num_parameters():,}"
        )


# ---------------------------------------------------------------------------
# Smoke test
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    torch.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Tiny config: 4 layers, 8 routed experts, top-2
    cfg = MoDAConfig(
        vocab_size=512,
        d_model=128,
        n_layers=4,
        n_heads_q=4,
        n_heads_kv=2,
        head_dim=32,
        max_seq_len=64,
        # MoE: 2 shared + 8 routed, activate top-2
        # (2+2)*64 = 256 ≈ equivalent to dense SwiGLU hidden~256
        n_shared_experts=2,
        n_routed_experts=8,
        n_activated_experts=2,
        expert_hidden_dim=64,
        moe_balance_alpha=0.01,
        moe_score_func="softmax",
    )

    model = MoDAModel(cfg).to(device)
    print(f"Parameters: {model.num_parameters():,}")
    print(model)

    B, T = 2, 32
    input_ids = torch.randint(0, cfg.vocab_size, (B, T), device=device)
    labels = torch.randint(0, cfg.vocab_size, (B, T), device=device)

    logits, loss = model(input_ids, labels)
    assert logits.shape == (B, T, cfg.vocab_size)
    print(f"Logits shape : {logits.shape}")
    print(f"Loss (LM + balance): {loss.item():.4f}")

    loss.backward()

    # Verify gradients (the last block's depth writes are never read, so
    # they legitimately receive no gradient)
    last_writes = {
        f"blocks.{cfg.n_layers - 1}.k_write.weight",
        f"blocks.{cfg.n_layers - 1}.v_write.weight",
    }
    missing = [
        name
        for name, p in model.named_parameters()
        if p.grad is None and name not in last_writes
    ]
    if missing:
        print(f"WARNING — unexpected missing gradients: {missing}")
    else:
        print("All parameters received gradients (excluding last-block writes).")

    # Spot-check: MoE gate weights must receive gradients (through balance loss P_i)
    gate0_grad = model.blocks[0].moe.gate.weight.grad
    assert gate0_grad is not None, "blocks[0].moe.gate.weight has no gradient!"
    print(f"blocks[0].moe.gate.weight grad norm : {gate0_grad.norm().item():.6f}")

    # Spot-check: depth write projections gradient flows from layer ≥ 1 depth reads
    k0_grad = model.blocks[0].k_write.weight.grad
    assert k0_grad is not None, "blocks[0].k_write.weight has no gradient!"
    print(f"blocks[0].k_write.weight grad norm : {k0_grad.norm().item():.6f}")

    print("Smoke test passed.")