Mirror of https://github.com/kyegomez/OpenMythos.git, synced 2026-05-02 17:43:27 +02:00

just use adam for now in training maybe add muon later
commit 537b116b3e
parent 137cd8832e
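Note on the change: the removed `build_optimizers` below split parameters between Muon (2D weight matrices) and AdamW (embeddings, norms, biases); this commit collapses everything onto a single fused AdamW. For context on the "maybe add muon later" part, Muon's distinguishing step is orthogonalizing each 2D update with a Newton-Schulz iteration. A minimal sketch of that step, based on the publicly described Muon recipe rather than any code in this repo:

```python
import torch

def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a 2D update (public Muon recipe, not repo code)."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic iteration coefficients
    x = g.bfloat16()
    transposed = x.size(0) > x.size(1)
    if transposed:
        x = x.mT
    x = x / (x.norm() + 1e-7)  # normalize so the iteration converges
    for _ in range(steps):
        A = x @ x.mT
        x = a * x + (b * A + c * A @ A) @ x
    return (x.mT if transposed else x).to(g.dtype)
```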
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
 """
-OpenMythos pretraining on FineWeb-Edu with Muon optimizer.
+OpenMythos pretraining on FineWeb-Edu with FSDP + AdamW.
 
 Single GPU:
-    python train.py
+    python training/3b_fine_web_edu.py
 
-Multi-GPU (auto-detects GPU count):
-    torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train.py
+Multi-GPU:
+    torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py
 """
 
 import os
@@ -15,38 +15,23 @@ import time
 import torch
 import torch.nn as nn
 import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
+    MixedPrecision,
+    FullStateDictConfig,
+    StateDictType,
+)
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.utils.data import IterableDataset, DataLoader, get_worker_info
 from contextlib import nullcontext
 
 from datasets import load_dataset
 
 from open_mythos import OpenMythos
+from open_mythos.main import TransformerBlock, RecurrentBlock
 from open_mythos.variants import mythos_3b
 from open_mythos.tokenizer import MythosTokenizer
-from torch.optim import Muon
 
 
-def build_optimizers(model: nn.Module, muon_lr: float, adamw_lr: float, wd: float):
-    """Muon for 2D weight matrices; AdamW for embeddings, norms, biases."""
-    muon_params, adamw_params = [], []
-    for name, p in model.named_parameters():
-        if not p.requires_grad:
-            continue
-        if (
-            p.ndim >= 2
-            and "embed" not in name
-            and "norm" not in name
-            and "scale" not in name
-        ):
-            muon_params.append(p)
-        else:
-            adamw_params.append(p)
-    muon = Muon(muon_params, lr=muon_lr)
-    adamw = torch.optim.AdamW(
-        adamw_params, lr=adamw_lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
-    )
-    return muon, adamw
-
-
 # ---------------------------------------------------------------------------
@@ -67,7 +52,6 @@ class FineWebEduDataset(IterableDataset):
         num_workers = worker.num_workers if worker else 1
         worker_id = worker.id if worker else 0
 
-        # shard first by DDP rank, then by dataloader worker
        total_shards = self.world_size * num_workers
        shard_index = self.rank * num_workers + worker_id
 
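The sharding above composes two levels of parallelism: first the distributed rank, then the dataloader worker. A self-contained sketch of the resulting interleave (the `shard_stream` helper is illustrative, not part of the diff):

```python
from itertools import islice

def shard_stream(stream, rank, world_size, worker_id, num_workers):
    # Each (rank, worker) pair reads every total_shards-th example,
    # offset by its shard_index, so the shards partition the stream exactly.
    total_shards = world_size * num_workers
    shard_index = rank * num_workers + worker_id
    yield from islice(stream, shard_index, None, total_shards)

# e.g. rank=1, world_size=2, worker_id=0, num_workers=2 -> indices 2, 6, 10, ...
```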
@@ -111,8 +95,7 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) ->
 
 def main():
     # ------------------------------------------------------------------
-    # Distributed init — works for single GPU (python train.py)
-    # and multi-GPU (torchrun --nproc_per_node=N train.py)
+    # Distributed init
     # ------------------------------------------------------------------
     ddp = int(os.environ.get("RANK", -1)) != -1
     if ddp:
@@ -130,10 +113,7 @@ def main():
     master = rank == 0
 
     if master:
-        n_gpu = torch.cuda.device_count()
-        print(
-            f"GPUs detected: {n_gpu} | World size: {world_size} | Device: {device}"
-        )
+        print(f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}")
 
     # ------------------------------------------------------------------
     # Tokenizer
@@ -148,14 +128,13 @@ def main():
     # Hyperparameters
     # ------------------------------------------------------------------
     seq_len = 2048
-    micro_batch = 4  # sequences per GPU per grad-accum step
-    target_tokens = 30_000_000_000  # 30B token run
+    micro_batch = 4
+    target_tokens = 30_000_000_000
     grad_accum = max(1, 256 // (world_size * micro_batch))
     global_batch_tok = world_size * micro_batch * grad_accum * seq_len
     total_steps = target_tokens // global_batch_tok
     warmup_steps = 2000
-    muon_lr = 0.02
-    adamw_lr = 3e-4
+    lr = 3e-4
     wd = 0.1
     log_every = 10
     ckpt_every = 1000
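Worked numbers for the batch math above, assuming an 8-GPU node (the diff itself fixes nothing beyond micro_batch = 4 and a 256-sequence global batch):

```python
# world_size = 8, micro_batch = 4:
#   grad_accum       = max(1, 256 // (8 * 4))    = 8
#   global_batch_tok = 8 * 4 * 8 * 2048          = 524_288 tokens per step
#   total_steps      = 30_000_000_000 // 524_288 = 57_220
# A single GPU gets grad_accum = 64, so global_batch_tok is unchanged.
```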
@@ -169,37 +148,52 @@ def main():
     )
 
     # ------------------------------------------------------------------
-    # Model — override vocab_size to match tokenizer
+    # Model
     # ------------------------------------------------------------------
     cfg = mythos_3b()
     cfg.vocab_size = vocab_size
     cfg.max_seq_len = seq_len
 
-    model = OpenMythos(cfg).to(device)
-    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
-    amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
+    # Mixed precision
+    bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+    amp_dtype = torch.bfloat16 if bf16_supported else torch.float16
+    model = OpenMythos(cfg)
+
+    if ddp:
+        mp_policy = MixedPrecision(
+            param_dtype=amp_dtype,
+            reduce_dtype=amp_dtype,
+            buffer_dtype=amp_dtype,
+        )
+        wrap_policy = ModuleWrapPolicy({TransformerBlock, RecurrentBlock})
+        model = FSDP(
+            model,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+            mixed_precision=mp_policy,
+            auto_wrap_policy=wrap_policy,
+            device_id=local_rank,
+        )
+    else:
+        model = model.to(device)
 
     amp_ctx = (
         torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
         if "cuda" in device
         else nullcontext()
     )
-    scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
 
+    # FSDP handles its own mixed precision; only need autocast for single-GPU
+    amp_ctx = nullcontext() if ddp else amp_ctx  # type: ignore[possibly-undefined]
 
     if master:
         n_params = sum(p.numel() for p in model.parameters())
         print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
 
-    if ddp:
-        model = DDP(model, device_ids=[local_rank])
-
-    raw_model = model.module if ddp else model
-
     # ------------------------------------------------------------------
-    # Optimizers
+    # Optimizer
     # ------------------------------------------------------------------
-    muon, adamw = build_optimizers(raw_model, muon_lr, adamw_lr, wd)
+    optimizer = torch.optim.AdamW(
+        model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
+    )
 
     # ------------------------------------------------------------------
     # Dataset + DataLoader
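With `ModuleWrapPolicy({TransformerBlock, RecurrentBlock})`, each block becomes its own FSDP shard unit under FULL_SHARD, and the `MixedPrecision` policy takes over the job the old autocast + GradScaler pair did. A quick way to sanity-check the wrapping (illustrative snippet run after the `FSDP(...)` call above, not part of the diff):

```python
# Print every submodule FSDP wrapped as a separate unit; each
# TransformerBlock / RecurrentBlock should appear exactly once.
for name, module in model.named_modules():
    if isinstance(module, FSDP):
        print(name or "<root>")
```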
@@ -219,15 +213,11 @@ def main():
     step = 0
 
     while step < total_steps:
-        cur_muon_lr = get_lr(step, warmup_steps, total_steps, muon_lr, muon_lr * 0.1)
-        cur_adamw_lr = get_lr(step, warmup_steps, total_steps, adamw_lr, adamw_lr * 0.1)
-        for g in muon.param_groups:
-            g["lr"] = cur_muon_lr
-        for g in adamw.param_groups:
-            g["lr"] = cur_adamw_lr
+        cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
+        for g in optimizer.param_groups:
+            g["lr"] = cur_lr
 
-        muon.zero_grad()
-        adamw.zero_grad()
+        optimizer.zero_grad()
         loss_accum = 0.0
 
         for micro_step in range(grad_accum):
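The body of `get_lr` is outside this diff; given the signature in the hunk header and the `lr * 0.1` floor used above, a conventional warmup-plus-cosine implementation would look roughly like this (an assumed sketch, not the repo's actual function):

```python
import math

def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    # Linear warmup to max_lr, then cosine decay down to min_lr.
    if step < warmup:
        return max_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```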
@@ -237,9 +227,9 @@ def main():
                 data_iter = iter(loader)
                 x, y = next(data_iter)
 
-            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
+            x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
+            y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
 
-            # Defer DDP gradient sync until the last micro-step
             sync = (
                 nullcontext()
                 if (not ddp or micro_step == grad_accum - 1)
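The `sync` expression is cut off by the hunk; its `else` branch (unchanged, outside the diff) presumably yields `model.no_sync()`, the standard way to skip the gradient all-reduce on every micro-step except the last. The full pattern, sketched under that assumption:

```python
# Deferred-sync pattern for gradient accumulation; both DDP and FSDP
# wrappers expose no_sync() as a context manager.
sync = (
    nullcontext()
    if (not ddp or micro_step == grad_accum - 1)
    else model.no_sync()
)
```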
@@ -252,17 +242,11 @@ def main():
                 )
                 loss = loss / grad_accum
 
-            scaler.scale(loss).backward()
+            loss.backward()
             loss_accum += loss.item()
 
-        # Unscale, clip, step both optimizers
-        scaler.unscale_(muon)
-        scaler.unscale_(adamw)
         nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
-        scaler.step(muon)
-        scaler.step(adamw)
-        scaler.update()
-
+        optimizer.step()
         step += 1
 
         if master and step % log_every == 0:
@@ -271,19 +255,27 @@ def main():
             tokens_seen = step * global_batch_tok
             print(
                 f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
-                f"| muon_lr {cur_muon_lr:.2e} | adamw_lr {cur_adamw_lr:.2e} "
-                f"| {tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.1f}B tokens seen"
+                f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
+                f"| {tokens_seen / 1e9:.1f}B tokens seen"
             )
             t0 = time.perf_counter()
 
         if master and step % ckpt_every == 0:
             path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
+            if ddp:
+                with FSDP.state_dict_type(
+                    model,
+                    StateDictType.FULL_STATE_DICT,
+                    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+                ):
+                    model_state = model.state_dict()
+            else:
+                model_state = model.state_dict()
             torch.save(
                 {
                     "step": step,
-                    "model": raw_model.state_dict(),
-                    "muon": muon.state_dict(),
-                    "adamw": adamw.state_dict(),
+                    "model": model_state,
+                    "optimizer": optimizer.state_dict(),
                     "cfg": cfg,
                     "vocab_size": vocab_size,
                 },
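Since `rank0_only=True` gathers the full state dict on rank 0, the saved file is an ordinary unsharded checkpoint. Restoring it for evaluation might look like the following (hypothetical path; `weights_only=False` is likely needed on newer PyTorch because `cfg` is a custom object):

```python
ckpt = torch.load("checkpoints/step_0001000.pt", map_location="cpu", weights_only=False)
model = OpenMythos(ckpt["cfg"])
model.load_state_dict(ckpt["model"])
model.eval()
```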