Just use AdamW for now in training; maybe add Muon later.

This commit is contained in:
Kye Gomez 2026-04-19 23:34:58 -04:00
parent 137cd8832e
commit 537b116b3e

View File

@ -1,12 +1,12 @@
#!/usr/bin/env python3
"""
OpenMythos pretraining on FineWeb-Edu with Muon optimizer.
OpenMythos pretraining on FineWeb-Edu with FSDP + AdamW.
Single GPU:
python train.py
python training/3b_fine_web_edu.py
Multi-GPU (auto-detects GPU count):
torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train.py
Multi-GPU:
torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py
"""
import os
@ -15,38 +15,23 @@ import time
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
from contextlib import nullcontext
from datasets import load_dataset
from open_mythos import OpenMythos
from open_mythos.main import TransformerBlock, RecurrentBlock
from open_mythos.variants import mythos_3b
from open_mythos.tokenizer import MythosTokenizer
from torch.optim import Muon
def build_optimizers(model: nn.Module, muon_lr: float, adamw_lr: float, wd: float):
    """Partition *model*'s trainable parameters between Muon and AdamW.

    Muon receives weight matrices (ndim >= 2) whose names do not contain
    "embed", "norm", or "scale"; every other trainable parameter
    (embeddings, norm/scale vectors, biases) goes to a fused AdamW.

    Args:
        model: module whose ``named_parameters()`` are partitioned.
        muon_lr: learning rate for the Muon parameter group.
        adamw_lr: learning rate for the AdamW parameter group.
        wd: weight decay — applied only to the AdamW group.

    Returns:
        The ``(muon, adamw)`` optimizer pair; callers must step both.
    """
    muon_params, adamw_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            # Frozen parameters get no optimizer state at all.
            continue
        # Name-based exclusion: embeddings/norms/scales stay on AdamW even
        # when they happen to be 2-D (e.g. embedding tables).
        if (
            p.ndim >= 2
            and "embed" not in name
            and "norm" not in name
            and "scale" not in name
        ):
            muon_params.append(p)
        else:
            adamw_params.append(p)
    muon = Muon(muon_params, lr=muon_lr)
    # NOTE(review): fused=True requires CUDA tensors — presumably the model
    # is already on GPU when this is called; confirm against the caller.
    adamw = torch.optim.AdamW(
        adamw_params, lr=adamw_lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
    )
    return muon, adamw
# ---------------------------------------------------------------------------
@ -67,7 +52,6 @@ class FineWebEduDataset(IterableDataset):
num_workers = worker.num_workers if worker else 1
worker_id = worker.id if worker else 0
# shard first by DDP rank, then by dataloader worker
total_shards = self.world_size * num_workers
shard_index = self.rank * num_workers + worker_id
@ -111,8 +95,7 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) ->
def main():
# ------------------------------------------------------------------
# Distributed init — works for single GPU (python train.py)
# and multi-GPU (torchrun --nproc_per_node=N train.py)
# Distributed init
# ------------------------------------------------------------------
ddp = int(os.environ.get("RANK", -1)) != -1
if ddp:
@ -130,10 +113,7 @@ def main():
master = rank == 0
if master:
n_gpu = torch.cuda.device_count()
print(
f"GPUs detected: {n_gpu} | World size: {world_size} | Device: {device}"
)
print(f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}")
# ------------------------------------------------------------------
# Tokenizer
@ -148,14 +128,13 @@ def main():
# Hyperparameters
# ------------------------------------------------------------------
seq_len = 2048
micro_batch = 4 # sequences per GPU per grad-accum step
target_tokens = 30_000_000_000 # 30B token run
micro_batch = 4
target_tokens = 30_000_000_000
grad_accum = max(1, 256 // (world_size * micro_batch))
global_batch_tok = world_size * micro_batch * grad_accum * seq_len
total_steps = target_tokens // global_batch_tok
warmup_steps = 2000
muon_lr = 0.02
adamw_lr = 3e-4
lr = 3e-4
wd = 0.1
log_every = 10
ckpt_every = 1000
@ -169,37 +148,52 @@ def main():
)
# ------------------------------------------------------------------
# Model — override vocab_size to match tokenizer
# Model
# ------------------------------------------------------------------
cfg = mythos_3b()
cfg.vocab_size = vocab_size
cfg.max_seq_len = seq_len
model = OpenMythos(cfg).to(device)
bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
# Mixed precision
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if bf16_supported else torch.float16
amp_ctx = (
torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
if "cuda" in device
else nullcontext()
)
scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
model = OpenMythos(cfg)
if ddp:
mp_policy = MixedPrecision(
param_dtype=amp_dtype,
reduce_dtype=amp_dtype,
buffer_dtype=amp_dtype,
)
wrap_policy = ModuleWrapPolicy({TransformerBlock, RecurrentBlock})
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mp_policy,
auto_wrap_policy=wrap_policy,
device_id=local_rank,
)
else:
model = model.to(device)
amp_ctx = (
torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
if "cuda" in device
else nullcontext()
)
# FSDP handles its own mixed precision; only need autocast for single-GPU
amp_ctx = nullcontext() if ddp else amp_ctx # type: ignore[possibly-undefined]
if master:
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
if ddp:
model = DDP(model, device_ids=[local_rank])
raw_model = model.module if ddp else model
# ------------------------------------------------------------------
# Optimizers
# Optimizer
# ------------------------------------------------------------------
muon, adamw = build_optimizers(raw_model, muon_lr, adamw_lr, wd)
optimizer = torch.optim.AdamW(
model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
)
# ------------------------------------------------------------------
# Dataset + DataLoader
@ -219,15 +213,11 @@ def main():
step = 0
while step < total_steps:
cur_muon_lr = get_lr(step, warmup_steps, total_steps, muon_lr, muon_lr * 0.1)
cur_adamw_lr = get_lr(step, warmup_steps, total_steps, adamw_lr, adamw_lr * 0.1)
for g in muon.param_groups:
g["lr"] = cur_muon_lr
for g in adamw.param_groups:
g["lr"] = cur_adamw_lr
cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
for g in optimizer.param_groups:
g["lr"] = cur_lr
muon.zero_grad()
adamw.zero_grad()
optimizer.zero_grad()
loss_accum = 0.0
for micro_step in range(grad_accum):
@ -237,9 +227,9 @@ def main():
data_iter = iter(loader)
x, y = next(data_iter)
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
# Defer DDP gradient sync until the last micro-step
sync = (
nullcontext()
if (not ddp or micro_step == grad_accum - 1)
@ -252,17 +242,11 @@ def main():
)
loss = loss / grad_accum
scaler.scale(loss).backward()
loss.backward()
loss_accum += loss.item()
# Unscale, clip, step both optimizers
scaler.unscale_(muon)
scaler.unscale_(adamw)
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(muon)
scaler.step(adamw)
scaler.update()
optimizer.step()
step += 1
if master and step % log_every == 0:
@ -271,19 +255,27 @@ def main():
tokens_seen = step * global_batch_tok
print(
f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
f"| muon_lr {cur_muon_lr:.2e} | adamw_lr {cur_adamw_lr:.2e} "
f"| {tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.1f}B tokens seen"
f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
f"| {tokens_seen / 1e9:.1f}B tokens seen"
)
t0 = time.perf_counter()
if master and step % ckpt_every == 0:
path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
if ddp:
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
model_state = model.state_dict()
else:
model_state = model.state_dict()
torch.save(
{
"step": step,
"model": raw_model.state_dict(),
"muon": muon.state_dict(),
"adamw": adamw.state_dict(),
"model": model_state,
"optimizer": optimizer.state_dict(),
"cfg": cfg,
"vocab_size": vocab_size,
},