OpenMythos/training/3b_fine_web_edu.py
Kye Gomez 5ffb897dcf [feat][training-script][add 3b fineweb-edu training
script][feat][tokenizer][add MythosTokenizer class with encode
  decode][improvement][deps][add transformers and datasets
  dependencies][docs][readme-training][add training section with run
  commands][improvement][pyproject][pin torch and add new deps]
2026-04-19 22:48:30 -04:00

303 lines
10 KiB
Python

#!/usr/bin/env python3
"""
OpenMythos pretraining on FineWeb-Edu with the Muon optimizer.

Single GPU:
    python training/3b_fine_web_edu.py

Multi-GPU (auto-detects GPU count):
    torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py
"""
import os
import math
import time
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
from contextlib import nullcontext
from datasets import load_dataset
from open_mythos import OpenMythos
from open_mythos.variants import mythos_3b
from open_mythos.tokenizer import MythosTokenizer
from torch.optim import Muon
def build_optimizers(model: nn.Module, muon_lr: float, adamw_lr: float, wd: float):
"""Muon for 2D weight matrices; AdamW for embeddings, norms, biases."""
muon_params, adamw_params = [], []
for name, p in model.named_parameters():
if not p.requires_grad:
continue
if (
p.ndim >= 2
and "embed" not in name
and "norm" not in name
and "scale" not in name
):
muon_params.append(p)
else:
adamw_params.append(p)
muon = Muon(muon_params, lr=muon_lr)
adamw = torch.optim.AdamW(
adamw_params, lr=adamw_lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
)
return muon, adamw
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class FineWebEduDataset(IterableDataset):
    """Stream FineWeb-Edu and yield fixed-length ``(input, target)`` token pairs.

    The text of successive samples is tokenized with ``encoding.encode`` and
    concatenated into one running token buffer. Each time the buffer holds at
    least ``seq_len + 1`` tokens, that many are cut off the front and emitted
    as two ``torch.long`` tensors of length ``seq_len``: the window without its
    last token (inputs) and without its first token (targets).
    """

    def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
        # Store the constructor arguments; nothing is loaded until iteration.
        self.world_size = world_size
        self.rank = rank
        self.subset = subset
        self.seq_len = seq_len
        self.encoding = encoding

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Not running inside a DataLoader worker process.
            n_workers, wid = 1, 0
        else:
            n_workers, wid = info.num_workers, info.id

        # One shard per (DDP rank, dataloader worker) pair: rank-major order.
        stream = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name=self.subset,
            split="train",
            streaming=True,
        )
        stream = stream.shard(
            num_shards=self.world_size * n_workers,
            index=self.rank * n_workers + wid,
        )

        pending = []
        for record in stream:
            pending.extend(self.encoding.encode(record["text"]))
            # Emit every complete window currently available; leftover tokens
            # stay in the buffer and are prepended to the next sample's tokens.
            while len(pending) >= self.seq_len + 1:
                window = self.seq_len + 1
                piece, pending = pending[:window], pending[window:]
                inputs = torch.tensor(piece[:-1], dtype=torch.long)
                targets = torch.tensor(piece[1:], dtype=torch.long)
                yield inputs, targets
# ---------------------------------------------------------------------------
# LR schedule: linear warmup → cosine decay
# ---------------------------------------------------------------------------
def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    * ``step < warmup``: linear ramp ``max_lr * step / warmup`` (0 at step 0).
    * ``step >= total`` (and past warmup): constant ``min_lr``.
    * otherwise: half-cosine from ``max_lr`` at ``step == warmup`` down
      towards ``min_lr`` at ``step == total``.
    """
    in_warmup = step < warmup
    if in_warmup:
        return max_lr * step / warmup
    finished = step >= total
    if finished:
        return min_lr
    # Reaching here implies warmup <= step < total, so the denominator is > 0.
    progress = (step - warmup) / (total - warmup)
    half_range = 0.5 * (max_lr - min_lr)
    return min_lr + half_range * (1.0 + math.cos(math.pi * progress))
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Pretrain the ``mythos_3b`` OpenMythos variant on FineWeb-Edu.

    Sets up optional DDP (when the ``RANK`` environment variable is present),
    the tokenizer, the model, the Muon + AdamW optimizers and a streaming
    DataLoader, then runs a gradient-accumulation training loop. Rank 0 logs
    every ``log_every`` steps and writes a checkpoint every ``ckpt_every``
    steps and once more at the final step.
    """
    # ------------------------------------------------------------------
    # Distributed init — works for a plain single-process launch and for
    # a multi-process torchrun launch (which sets RANK / LOCAL_RANK /
    # WORLD_SIZE).
    # ------------------------------------------------------------------
    ddp = int(os.environ.get("RANK", -1)) != -1
    if ddp:
        dist.init_process_group("nccl")
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = f"cuda:{local_rank}"
        torch.cuda.set_device(device)
    else:
        rank = local_rank = 0
        world_size = 1
        device = "cuda" if torch.cuda.is_available() else "cpu"
    master = rank == 0
    on_cuda = "cuda" in device
    if master:
        n_gpu = torch.cuda.device_count()
        print(
            f"GPUs detected: {n_gpu} | World size: {world_size} | Device: {device}"
        )

    # ------------------------------------------------------------------
    # Tokenizer
    # ------------------------------------------------------------------
    encoding = MythosTokenizer()
    vocab_size = encoding.vocab_size
    if master:
        print(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}")

    # ------------------------------------------------------------------
    # Hyperparameters
    # ------------------------------------------------------------------
    seq_len = 2048
    micro_batch = 4  # sequences per GPU per grad-accum step
    target_tokens = 30_000_000_000  # 30B token run
    # Aim for 256 sequences per optimizer step across all ranks.
    grad_accum = max(1, 256 // (world_size * micro_batch))
    global_batch_tok = world_size * micro_batch * grad_accum * seq_len
    total_steps = target_tokens // global_batch_tok
    warmup_steps = 2000
    muon_lr = 0.02
    adamw_lr = 3e-4
    wd = 0.1
    log_every = 10
    ckpt_every = 1000
    ckpt_dir = "checkpoints"
    dataset_subset = "sample-10BT"  # → sample-100BT or "default" for full run
    if master:
        print(
            f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum}\n"
            f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
        )

    # ------------------------------------------------------------------
    # Model — override vocab_size to match tokenizer
    # ------------------------------------------------------------------
    cfg = mythos_3b()
    cfg.vocab_size = vocab_size
    cfg.max_seq_len = seq_len
    model = OpenMythos(cfg).to(device)

    # Mixed precision: bf16 when the GPU supports it, otherwise fp16.
    bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    amp_dtype = torch.bfloat16 if bf16_supported else torch.float16
    amp_ctx = (
        torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
        if on_cuda
        else nullcontext()
    )
    # Loss scaling is only needed for fp16 autocast on CUDA; in every other
    # configuration the scaler is a transparent pass-through. Uses the
    # device-agnostic torch.amp API instead of the deprecated
    # torch.cuda.amp.GradScaler constructor.
    scaler = torch.amp.GradScaler(
        "cuda", enabled=(on_cuda and amp_dtype == torch.float16)
    )
    if master:
        n_params = sum(p.numel() for p in model.parameters())
        print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
    if ddp:
        model = DDP(model, device_ids=[local_rank])
    # Unwrapped module: used for optimizer construction and checkpointing so
    # parameter names / state-dict keys carry no DDP wrapper prefix.
    raw_model = model.module if ddp else model

    # ------------------------------------------------------------------
    # Optimizers
    # ------------------------------------------------------------------
    muon, adamw = build_optimizers(raw_model, muon_lr, adamw_lr, wd)

    # ------------------------------------------------------------------
    # Dataset + DataLoader
    # ------------------------------------------------------------------
    dataset = FineWebEduDataset(encoding, seq_len, dataset_subset, rank, world_size)
    loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True)

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    if master:
        os.makedirs(ckpt_dir, exist_ok=True)
    model.train()
    data_iter = iter(loader)
    t0 = time.perf_counter()
    step = 0
    while step < total_steps:
        # Both optimizers share one schedule shape, each decaying to 10% of
        # its own peak learning rate.
        cur_muon_lr = get_lr(step, warmup_steps, total_steps, muon_lr, muon_lr * 0.1)
        cur_adamw_lr = get_lr(step, warmup_steps, total_steps, adamw_lr, adamw_lr * 0.1)
        for g in muon.param_groups:
            g["lr"] = cur_muon_lr
        for g in adamw.param_groups:
            g["lr"] = cur_adamw_lr
        muon.zero_grad()
        adamw.zero_grad()

        loss_accum = 0.0  # this rank's mean loss over the accumulated micro-batches
        for micro_step in range(grad_accum):
            try:
                x, y = next(data_iter)
            except StopIteration:
                # Stream exhausted: start another pass over the dataset.
                data_iter = iter(loader)
                x, y = next(data_iter)
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            # Defer DDP gradient sync until the last micro-step
            sync = (
                nullcontext()
                if (not ddp or micro_step == grad_accum - 1)
                else model.no_sync()
            )
            with sync, amp_ctx:
                logits = model(x)
                loss = nn.functional.cross_entropy(
                    logits.view(-1, vocab_size), y.view(-1)
                )
                # Scale so the accumulated gradient equals the mean over
                # all micro-batches of this step.
                loss = loss / grad_accum
                scaler.scale(loss).backward()
            loss_accum += loss.item()

        # Unscale, clip, step both optimizers
        scaler.unscale_(muon)
        scaler.unscale_(adamw)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(muon)
        scaler.step(adamw)
        scaler.update()
        step += 1

        if master and step % log_every == 0:
            dt = time.perf_counter() - t0
            tok_per_sec = global_batch_tok * log_every / dt
            tokens_seen = step * global_batch_tok
            print(
                f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
                f"| muon_lr {cur_muon_lr:.2e} | adamw_lr {cur_adamw_lr:.2e} "
                f"| {tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.1f}B tokens seen"
            )
            t0 = time.perf_counter()

        # Checkpoint periodically, and always at the final step so the tail of
        # training is not lost when total_steps is not a multiple of ckpt_every.
        if master and (step % ckpt_every == 0 or step == total_steps):
            path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
            torch.save(
                {
                    "step": step,
                    "model": raw_model.state_dict(),
                    "muon": muon.state_dict(),
                    "adamw": adamw.state_dict(),
                    "cfg": cfg,
                    "vocab_size": vocab_size,
                },
                path,
            )
            print(f"Checkpoint saved → {path}")

    if ddp:
        dist.destroy_process_group()
    if master:
        print("Training complete.")
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()