Just use Adam for now in training; maybe add Muon later.

This commit is contained in:
Kye Gomez 2026-04-19 23:34:58 -04:00
parent 137cd8832e
commit 537b116b3e

View File

@ -1,12 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
OpenMythos pretraining on FineWeb-Edu with Muon optimizer. OpenMythos pretraining on FineWeb-Edu with FSDP + AdamW.
Single GPU: Single GPU:
python train.py python training/3b_fine_web_edu.py
Multi-GPU (auto-detects GPU count): Multi-GPU:
torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train.py torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py
""" """
import os import os
@ -15,38 +15,23 @@ import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.distributed as dist import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.utils.data import IterableDataset, DataLoader, get_worker_info from torch.utils.data import IterableDataset, DataLoader, get_worker_info
from contextlib import nullcontext from contextlib import nullcontext
from datasets import load_dataset from datasets import load_dataset
from open_mythos import OpenMythos from open_mythos import OpenMythos
from open_mythos.main import TransformerBlock, RecurrentBlock
from open_mythos.variants import mythos_3b from open_mythos.variants import mythos_3b
from open_mythos.tokenizer import MythosTokenizer from open_mythos.tokenizer import MythosTokenizer
from torch.optim import Muon
def build_optimizers(model: nn.Module, muon_lr: float, adamw_lr: float, wd: float):
"""Muon for 2D weight matrices; AdamW for embeddings, norms, biases."""
muon_params, adamw_params = [], []
for name, p in model.named_parameters():
if not p.requires_grad:
continue
if (
p.ndim >= 2
and "embed" not in name
and "norm" not in name
and "scale" not in name
):
muon_params.append(p)
else:
adamw_params.append(p)
muon = Muon(muon_params, lr=muon_lr)
adamw = torch.optim.AdamW(
adamw_params, lr=adamw_lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
)
return muon, adamw
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -67,7 +52,6 @@ class FineWebEduDataset(IterableDataset):
num_workers = worker.num_workers if worker else 1 num_workers = worker.num_workers if worker else 1
worker_id = worker.id if worker else 0 worker_id = worker.id if worker else 0
# shard first by DDP rank, then by dataloader worker
total_shards = self.world_size * num_workers total_shards = self.world_size * num_workers
shard_index = self.rank * num_workers + worker_id shard_index = self.rank * num_workers + worker_id
@ -111,8 +95,7 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) ->
def main(): def main():
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Distributed init — works for single GPU (python train.py) # Distributed init
# and multi-GPU (torchrun --nproc_per_node=N train.py)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
ddp = int(os.environ.get("RANK", -1)) != -1 ddp = int(os.environ.get("RANK", -1)) != -1
if ddp: if ddp:
@ -130,10 +113,7 @@ def main():
master = rank == 0 master = rank == 0
if master: if master:
n_gpu = torch.cuda.device_count() print(f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}")
print(
f"GPUs detected: {n_gpu} | World size: {world_size} | Device: {device}"
)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Tokenizer # Tokenizer
@ -148,14 +128,13 @@ def main():
# Hyperparameters # Hyperparameters
# ------------------------------------------------------------------ # ------------------------------------------------------------------
seq_len = 2048 seq_len = 2048
micro_batch = 4 # sequences per GPU per grad-accum step micro_batch = 4
target_tokens = 30_000_000_000 # 30B token run target_tokens = 30_000_000_000
grad_accum = max(1, 256 // (world_size * micro_batch)) grad_accum = max(1, 256 // (world_size * micro_batch))
global_batch_tok = world_size * micro_batch * grad_accum * seq_len global_batch_tok = world_size * micro_batch * grad_accum * seq_len
total_steps = target_tokens // global_batch_tok total_steps = target_tokens // global_batch_tok
warmup_steps = 2000 warmup_steps = 2000
muon_lr = 0.02 lr = 3e-4
adamw_lr = 3e-4
wd = 0.1 wd = 0.1
log_every = 10 log_every = 10
ckpt_every = 1000 ckpt_every = 1000
@ -169,37 +148,52 @@ def main():
) )
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Model — override vocab_size to match tokenizer # Model
# ------------------------------------------------------------------ # ------------------------------------------------------------------
cfg = mythos_3b() cfg = mythos_3b()
cfg.vocab_size = vocab_size cfg.vocab_size = vocab_size
cfg.max_seq_len = seq_len cfg.max_seq_len = seq_len
model = OpenMythos(cfg).to(device) bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
# Mixed precision model = OpenMythos(cfg)
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if bf16_supported else torch.float16 if ddp:
mp_policy = MixedPrecision(
param_dtype=amp_dtype,
reduce_dtype=amp_dtype,
buffer_dtype=amp_dtype,
)
wrap_policy = ModuleWrapPolicy({TransformerBlock, RecurrentBlock})
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mp_policy,
auto_wrap_policy=wrap_policy,
device_id=local_rank,
)
else:
model = model.to(device)
amp_ctx = ( amp_ctx = (
torch.amp.autocast(device_type="cuda", dtype=amp_dtype) torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
if "cuda" in device if "cuda" in device
else nullcontext() else nullcontext()
) )
scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
# FSDP handles its own mixed precision; only need autocast for single-GPU
amp_ctx = nullcontext() if ddp else amp_ctx # type: ignore[possibly-undefined]
if master: if master:
n_params = sum(p.numel() for p in model.parameters()) n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}") print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
if ddp:
model = DDP(model, device_ids=[local_rank])
raw_model = model.module if ddp else model
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Optimizers # Optimizer
# ------------------------------------------------------------------ # ------------------------------------------------------------------
muon, adamw = build_optimizers(raw_model, muon_lr, adamw_lr, wd) optimizer = torch.optim.AdamW(
model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Dataset + DataLoader # Dataset + DataLoader
@ -219,15 +213,11 @@ def main():
step = 0 step = 0
while step < total_steps: while step < total_steps:
cur_muon_lr = get_lr(step, warmup_steps, total_steps, muon_lr, muon_lr * 0.1) cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
cur_adamw_lr = get_lr(step, warmup_steps, total_steps, adamw_lr, adamw_lr * 0.1) for g in optimizer.param_groups:
for g in muon.param_groups: g["lr"] = cur_lr
g["lr"] = cur_muon_lr
for g in adamw.param_groups:
g["lr"] = cur_adamw_lr
muon.zero_grad() optimizer.zero_grad()
adamw.zero_grad()
loss_accum = 0.0 loss_accum = 0.0
for micro_step in range(grad_accum): for micro_step in range(grad_accum):
@ -237,9 +227,9 @@ def main():
data_iter = iter(loader) data_iter = iter(loader)
x, y = next(data_iter) x, y = next(data_iter)
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
# Defer DDP gradient sync until the last micro-step
sync = ( sync = (
nullcontext() nullcontext()
if (not ddp or micro_step == grad_accum - 1) if (not ddp or micro_step == grad_accum - 1)
@ -252,17 +242,11 @@ def main():
) )
loss = loss / grad_accum loss = loss / grad_accum
scaler.scale(loss).backward() loss.backward()
loss_accum += loss.item() loss_accum += loss.item()
# Unscale, clip, step both optimizers
scaler.unscale_(muon)
scaler.unscale_(adamw)
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(muon) optimizer.step()
scaler.step(adamw)
scaler.update()
step += 1 step += 1
if master and step % log_every == 0: if master and step % log_every == 0:
@ -271,19 +255,27 @@ def main():
tokens_seen = step * global_batch_tok tokens_seen = step * global_batch_tok
print( print(
f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} " f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
f"| muon_lr {cur_muon_lr:.2e} | adamw_lr {cur_adamw_lr:.2e} " f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
f"| {tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.1f}B tokens seen" f"| {tokens_seen / 1e9:.1f}B tokens seen"
) )
t0 = time.perf_counter() t0 = time.perf_counter()
if master and step % ckpt_every == 0: if master and step % ckpt_every == 0:
path = os.path.join(ckpt_dir, f"step_{step:07d}.pt") path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
if ddp:
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
model_state = model.state_dict()
else:
model_state = model.state_dict()
torch.save( torch.save(
{ {
"step": step, "step": step,
"model": raw_model.state_dict(), "model": model_state,
"muon": muon.state_dict(), "optimizer": optimizer.state_dict(),
"adamw": adamw.state_dict(),
"cfg": cfg, "cfg": cfg,
"vocab_size": vocab_size, "vocab_size": vocab_size,
}, },