mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
Just use AdamW for training for now; maybe add Muon later.
This commit is contained in:
parent
137cd8832e
commit
537b116b3e
@ -1,12 +1,12 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
OpenMythos pretraining on FineWeb-Edu with Muon optimizer.
|
OpenMythos pretraining on FineWeb-Edu with FSDP + AdamW.
|
||||||
|
|
||||||
Single GPU:
|
Single GPU:
|
||||||
python train.py
|
python training/3b_fine_web_edu.py
|
||||||
|
|
||||||
Multi-GPU (auto-detects GPU count):
|
Multi-GPU:
|
||||||
torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train.py
|
torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -15,38 +15,23 @@ import time
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.distributed.fsdp import (
|
||||||
|
FullyShardedDataParallel as FSDP,
|
||||||
|
ShardingStrategy,
|
||||||
|
MixedPrecision,
|
||||||
|
FullStateDictConfig,
|
||||||
|
StateDictType,
|
||||||
|
)
|
||||||
|
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||||
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
|
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
from open_mythos import OpenMythos
|
from open_mythos import OpenMythos
|
||||||
|
from open_mythos.main import TransformerBlock, RecurrentBlock
|
||||||
from open_mythos.variants import mythos_3b
|
from open_mythos.variants import mythos_3b
|
||||||
from open_mythos.tokenizer import MythosTokenizer
|
from open_mythos.tokenizer import MythosTokenizer
|
||||||
from torch.optim import Muon
|
|
||||||
|
|
||||||
|
|
||||||
def build_optimizers(model: nn.Module, muon_lr: float, adamw_lr: float, wd: float):
|
|
||||||
"""Muon for 2D weight matrices; AdamW for embeddings, norms, biases."""
|
|
||||||
muon_params, adamw_params = [], []
|
|
||||||
for name, p in model.named_parameters():
|
|
||||||
if not p.requires_grad:
|
|
||||||
continue
|
|
||||||
if (
|
|
||||||
p.ndim >= 2
|
|
||||||
and "embed" not in name
|
|
||||||
and "norm" not in name
|
|
||||||
and "scale" not in name
|
|
||||||
):
|
|
||||||
muon_params.append(p)
|
|
||||||
else:
|
|
||||||
adamw_params.append(p)
|
|
||||||
muon = Muon(muon_params, lr=muon_lr)
|
|
||||||
adamw = torch.optim.AdamW(
|
|
||||||
adamw_params, lr=adamw_lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
|
|
||||||
)
|
|
||||||
return muon, adamw
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -67,7 +52,6 @@ class FineWebEduDataset(IterableDataset):
|
|||||||
num_workers = worker.num_workers if worker else 1
|
num_workers = worker.num_workers if worker else 1
|
||||||
worker_id = worker.id if worker else 0
|
worker_id = worker.id if worker else 0
|
||||||
|
|
||||||
# shard first by DDP rank, then by dataloader worker
|
|
||||||
total_shards = self.world_size * num_workers
|
total_shards = self.world_size * num_workers
|
||||||
shard_index = self.rank * num_workers + worker_id
|
shard_index = self.rank * num_workers + worker_id
|
||||||
|
|
||||||
@ -111,8 +95,7 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) ->
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Distributed init — works for single GPU (python train.py)
|
# Distributed init
|
||||||
# and multi-GPU (torchrun --nproc_per_node=N train.py)
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
ddp = int(os.environ.get("RANK", -1)) != -1
|
ddp = int(os.environ.get("RANK", -1)) != -1
|
||||||
if ddp:
|
if ddp:
|
||||||
@ -130,10 +113,7 @@ def main():
|
|||||||
master = rank == 0
|
master = rank == 0
|
||||||
|
|
||||||
if master:
|
if master:
|
||||||
n_gpu = torch.cuda.device_count()
|
print(f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}")
|
||||||
print(
|
|
||||||
f"GPUs detected: {n_gpu} | World size: {world_size} | Device: {device}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Tokenizer
|
# Tokenizer
|
||||||
@ -148,14 +128,13 @@ def main():
|
|||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
seq_len = 2048
|
seq_len = 2048
|
||||||
micro_batch = 4 # sequences per GPU per grad-accum step
|
micro_batch = 4
|
||||||
target_tokens = 30_000_000_000 # 30B token run
|
target_tokens = 30_000_000_000
|
||||||
grad_accum = max(1, 256 // (world_size * micro_batch))
|
grad_accum = max(1, 256 // (world_size * micro_batch))
|
||||||
global_batch_tok = world_size * micro_batch * grad_accum * seq_len
|
global_batch_tok = world_size * micro_batch * grad_accum * seq_len
|
||||||
total_steps = target_tokens // global_batch_tok
|
total_steps = target_tokens // global_batch_tok
|
||||||
warmup_steps = 2000
|
warmup_steps = 2000
|
||||||
muon_lr = 0.02
|
lr = 3e-4
|
||||||
adamw_lr = 3e-4
|
|
||||||
wd = 0.1
|
wd = 0.1
|
||||||
log_every = 10
|
log_every = 10
|
||||||
ckpt_every = 1000
|
ckpt_every = 1000
|
||||||
@ -169,37 +148,52 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Model — override vocab_size to match tokenizer
|
# Model
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
cfg = mythos_3b()
|
cfg = mythos_3b()
|
||||||
cfg.vocab_size = vocab_size
|
cfg.vocab_size = vocab_size
|
||||||
cfg.max_seq_len = seq_len
|
cfg.max_seq_len = seq_len
|
||||||
|
|
||||||
model = OpenMythos(cfg).to(device)
|
bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
||||||
|
amp_dtype = torch.bfloat16 if bf16_ok else torch.float16
|
||||||
|
|
||||||
# Mixed precision
|
model = OpenMythos(cfg)
|
||||||
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
|
||||||
amp_dtype = torch.bfloat16 if bf16_supported else torch.float16
|
if ddp:
|
||||||
|
mp_policy = MixedPrecision(
|
||||||
|
param_dtype=amp_dtype,
|
||||||
|
reduce_dtype=amp_dtype,
|
||||||
|
buffer_dtype=amp_dtype,
|
||||||
|
)
|
||||||
|
wrap_policy = ModuleWrapPolicy({TransformerBlock, RecurrentBlock})
|
||||||
|
model = FSDP(
|
||||||
|
model,
|
||||||
|
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
||||||
|
mixed_precision=mp_policy,
|
||||||
|
auto_wrap_policy=wrap_policy,
|
||||||
|
device_id=local_rank,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = model.to(device)
|
||||||
amp_ctx = (
|
amp_ctx = (
|
||||||
torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
|
torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
|
||||||
if "cuda" in device
|
if "cuda" in device
|
||||||
else nullcontext()
|
else nullcontext()
|
||||||
)
|
)
|
||||||
scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
|
|
||||||
|
# FSDP handles its own mixed precision; only need autocast for single-GPU
|
||||||
|
amp_ctx = nullcontext() if ddp else amp_ctx # type: ignore[possibly-undefined]
|
||||||
|
|
||||||
if master:
|
if master:
|
||||||
n_params = sum(p.numel() for p in model.parameters())
|
n_params = sum(p.numel() for p in model.parameters())
|
||||||
print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
|
print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
|
||||||
|
|
||||||
if ddp:
|
|
||||||
model = DDP(model, device_ids=[local_rank])
|
|
||||||
|
|
||||||
raw_model = model.module if ddp else model
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Optimizers
|
# Optimizer
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
muon, adamw = build_optimizers(raw_model, muon_lr, adamw_lr, wd)
|
optimizer = torch.optim.AdamW(
|
||||||
|
model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
|
||||||
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Dataset + DataLoader
|
# Dataset + DataLoader
|
||||||
@ -219,15 +213,11 @@ def main():
|
|||||||
step = 0
|
step = 0
|
||||||
|
|
||||||
while step < total_steps:
|
while step < total_steps:
|
||||||
cur_muon_lr = get_lr(step, warmup_steps, total_steps, muon_lr, muon_lr * 0.1)
|
cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
|
||||||
cur_adamw_lr = get_lr(step, warmup_steps, total_steps, adamw_lr, adamw_lr * 0.1)
|
for g in optimizer.param_groups:
|
||||||
for g in muon.param_groups:
|
g["lr"] = cur_lr
|
||||||
g["lr"] = cur_muon_lr
|
|
||||||
for g in adamw.param_groups:
|
|
||||||
g["lr"] = cur_adamw_lr
|
|
||||||
|
|
||||||
muon.zero_grad()
|
optimizer.zero_grad()
|
||||||
adamw.zero_grad()
|
|
||||||
loss_accum = 0.0
|
loss_accum = 0.0
|
||||||
|
|
||||||
for micro_step in range(grad_accum):
|
for micro_step in range(grad_accum):
|
||||||
@ -237,9 +227,9 @@ def main():
|
|||||||
data_iter = iter(loader)
|
data_iter = iter(loader)
|
||||||
x, y = next(data_iter)
|
x, y = next(data_iter)
|
||||||
|
|
||||||
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
|
x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
|
||||||
|
y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
|
||||||
|
|
||||||
# Defer DDP gradient sync until the last micro-step
|
|
||||||
sync = (
|
sync = (
|
||||||
nullcontext()
|
nullcontext()
|
||||||
if (not ddp or micro_step == grad_accum - 1)
|
if (not ddp or micro_step == grad_accum - 1)
|
||||||
@ -252,17 +242,11 @@ def main():
|
|||||||
)
|
)
|
||||||
loss = loss / grad_accum
|
loss = loss / grad_accum
|
||||||
|
|
||||||
scaler.scale(loss).backward()
|
loss.backward()
|
||||||
loss_accum += loss.item()
|
loss_accum += loss.item()
|
||||||
|
|
||||||
# Unscale, clip, step both optimizers
|
|
||||||
scaler.unscale_(muon)
|
|
||||||
scaler.unscale_(adamw)
|
|
||||||
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||||
scaler.step(muon)
|
optimizer.step()
|
||||||
scaler.step(adamw)
|
|
||||||
scaler.update()
|
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
if master and step % log_every == 0:
|
if master and step % log_every == 0:
|
||||||
@ -271,19 +255,27 @@ def main():
|
|||||||
tokens_seen = step * global_batch_tok
|
tokens_seen = step * global_batch_tok
|
||||||
print(
|
print(
|
||||||
f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
|
f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
|
||||||
f"| muon_lr {cur_muon_lr:.2e} | adamw_lr {cur_adamw_lr:.2e} "
|
f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
|
||||||
f"| {tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.1f}B tokens seen"
|
f"| {tokens_seen / 1e9:.1f}B tokens seen"
|
||||||
)
|
)
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
if master and step % ckpt_every == 0:
|
if master and step % ckpt_every == 0:
|
||||||
path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
|
path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
|
||||||
|
if ddp:
|
||||||
|
with FSDP.state_dict_type(
|
||||||
|
model,
|
||||||
|
StateDictType.FULL_STATE_DICT,
|
||||||
|
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
|
||||||
|
):
|
||||||
|
model_state = model.state_dict()
|
||||||
|
else:
|
||||||
|
model_state = model.state_dict()
|
||||||
torch.save(
|
torch.save(
|
||||||
{
|
{
|
||||||
"step": step,
|
"step": step,
|
||||||
"model": raw_model.state_dict(),
|
"model": model_state,
|
||||||
"muon": muon.state_dict(),
|
"optimizer": optimizer.state_dict(),
|
||||||
"adamw": adamw.state_dict(),
|
|
||||||
"cfg": cfg,
|
"cfg": cfg,
|
||||||
"vocab_size": vocab_size,
|
"vocab_size": vocab_size,
|
||||||
},
|
},
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user