From 537b116b3ede03cfe7f8888e84c50d58169b325b Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Sun, 19 Apr 2026 23:34:58 -0400 Subject: [PATCH] just use adam for now in training maybe add muon later --- training/3b_fine_web_edu.py | 152 +++++++++++++++++------------------- 1 file changed, 72 insertions(+), 80 deletions(-) diff --git a/training/3b_fine_web_edu.py b/training/3b_fine_web_edu.py index 601a2a0..215381d 100644 --- a/training/3b_fine_web_edu.py +++ b/training/3b_fine_web_edu.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 """ -OpenMythos pretraining on FineWeb-Edu with Muon optimizer. +OpenMythos pretraining on FineWeb-Edu with FSDP + AdamW. Single GPU: - python train.py + python training/3b_fine_web_edu.py -Multi-GPU (auto-detects GPU count): - torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train.py +Multi-GPU: + torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py """ import os @@ -15,38 +15,23 @@ import time import torch import torch.nn as nn import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + ShardingStrategy, + MixedPrecision, + FullStateDictConfig, + StateDictType, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.utils.data import IterableDataset, DataLoader, get_worker_info from contextlib import nullcontext from datasets import load_dataset from open_mythos import OpenMythos +from open_mythos.main import TransformerBlock, RecurrentBlock from open_mythos.variants import mythos_3b from open_mythos.tokenizer import MythosTokenizer -from torch.optim import Muon - - -def build_optimizers(model: nn.Module, muon_lr: float, adamw_lr: float, wd: float): - """Muon for 2D weight matrices; AdamW for embeddings, norms, biases.""" - muon_params, adamw_params = [], [] - for name, p in model.named_parameters(): - if not p.requires_grad: - continue - if ( - p.ndim >= 2 - and "embed" not in name - and "norm" not in name - and "scale" not in name - ): - muon_params.append(p) - else: - adamw_params.append(p) - muon = Muon(muon_params, lr=muon_lr) - adamw = torch.optim.AdamW( - adamw_params, lr=adamw_lr, weight_decay=wd, betas=(0.9, 0.95), fused=True - ) - return muon, adamw # --------------------------------------------------------------------------- @@ -67,7 +52,6 @@ class FineWebEduDataset(IterableDataset): num_workers = worker.num_workers if worker else 1 worker_id = worker.id if worker else 0 - # shard first by DDP rank, then by dataloader worker total_shards = self.world_size * num_workers shard_index = self.rank * num_workers + worker_id @@ -111,8 +95,7 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> def main(): # ------------------------------------------------------------------ - # Distributed init — works for single GPU (python train.py) - # and multi-GPU (torchrun --nproc_per_node=N train.py) + # Distributed init # ------------------------------------------------------------------ ddp = int(os.environ.get("RANK", -1)) != -1 if ddp: @@ -130,10 +113,7 @@ def main(): master = rank == 0 if master: - n_gpu = torch.cuda.device_count() - print( - f"GPUs detected: {n_gpu} | World size: {world_size} | Device: {device}" - ) + print(f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}") # ------------------------------------------------------------------ # Tokenizer @@ -148,14 +128,13 @@ def main(): # Hyperparameters # ------------------------------------------------------------------ seq_len = 2048 - micro_batch = 4 # sequences per GPU per grad-accum step - target_tokens = 30_000_000_000 # 30B token run + micro_batch = 4 + target_tokens = 30_000_000_000 grad_accum = max(1, 256 // (world_size * micro_batch)) global_batch_tok = world_size * micro_batch * grad_accum * seq_len total_steps = target_tokens // global_batch_tok warmup_steps = 2000 - muon_lr = 0.02 - adamw_lr = 3e-4 + lr = 3e-4 wd = 0.1 log_every = 10 ckpt_every = 1000 @@ -169,37 +148,52 @@ def main(): ) # ------------------------------------------------------------------ - # Model — override vocab_size to match tokenizer + # Model # ------------------------------------------------------------------ cfg = mythos_3b() cfg.vocab_size = vocab_size cfg.max_seq_len = seq_len - model = OpenMythos(cfg).to(device) + bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported() + amp_dtype = torch.bfloat16 if bf16_ok else torch.float16 - # Mixed precision - bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported() - amp_dtype = torch.bfloat16 if bf16_supported else torch.float16 - amp_ctx = ( - torch.amp.autocast(device_type="cuda", dtype=amp_dtype) - if "cuda" in device - else nullcontext() - ) - scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16)) + model = OpenMythos(cfg) + + if ddp: + mp_policy = MixedPrecision( + param_dtype=amp_dtype, + reduce_dtype=amp_dtype, + buffer_dtype=amp_dtype, + ) + wrap_policy = ModuleWrapPolicy({TransformerBlock, RecurrentBlock}) + model = FSDP( + model, + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mp_policy, + auto_wrap_policy=wrap_policy, + device_id=local_rank, + ) + else: + model = model.to(device) + amp_ctx = ( + torch.amp.autocast(device_type="cuda", dtype=amp_dtype) + if "cuda" in device + else nullcontext() + ) + + # FSDP handles its own mixed precision; only need autocast for single-GPU + amp_ctx = nullcontext() if ddp else amp_ctx # type: ignore[possibly-undefined] if master: n_params = sum(p.numel() for p in model.parameters()) print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}") - if ddp: - model = DDP(model, device_ids=[local_rank]) - - raw_model = model.module if ddp else model - # ------------------------------------------------------------------ - # Optimizers + # Optimizer # ------------------------------------------------------------------ - muon, adamw = build_optimizers(raw_model, muon_lr, adamw_lr, wd) + optimizer = torch.optim.AdamW( + model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True + ) # ------------------------------------------------------------------ # Dataset + DataLoader @@ -219,15 +213,11 @@ def main(): step = 0 while step < total_steps: - cur_muon_lr = get_lr(step, warmup_steps, total_steps, muon_lr, muon_lr * 0.1) - cur_adamw_lr = get_lr(step, warmup_steps, total_steps, adamw_lr, adamw_lr * 0.1) - for g in muon.param_groups: - g["lr"] = cur_muon_lr - for g in adamw.param_groups: - g["lr"] = cur_adamw_lr + cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1) + for g in optimizer.param_groups: + g["lr"] = cur_lr - muon.zero_grad() - adamw.zero_grad() + optimizer.zero_grad() loss_accum = 0.0 for micro_step in range(grad_accum): @@ -237,9 +227,9 @@ def main(): data_iter = iter(loader) x, y = next(data_iter) - x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) + x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) + y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True) - # Defer DDP gradient sync until the last micro-step sync = ( nullcontext() if (not ddp or micro_step == grad_accum - 1) @@ -252,17 +242,11 @@ def main(): ) loss = loss / grad_accum - scaler.scale(loss).backward() + loss.backward() loss_accum += loss.item() - # Unscale, clip, step both optimizers - scaler.unscale_(muon) - scaler.unscale_(adamw) nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - scaler.step(muon) - scaler.step(adamw) - scaler.update() - + optimizer.step() step += 1 if master and step % log_every == 0: @@ -271,19 +255,27 @@ def main(): tokens_seen = step * global_batch_tok print( f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} " - f"| muon_lr {cur_muon_lr:.2e} | adamw_lr {cur_adamw_lr:.2e} " - f"| {tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.1f}B tokens seen" + f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s " + f"| {tokens_seen / 1e9:.1f}B tokens seen" ) t0 = time.perf_counter() if master and step % ckpt_every == 0: path = os.path.join(ckpt_dir, f"step_{step:07d}.pt") + if ddp: + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + model_state = model.state_dict() + else: + model_state = model.state_dict() torch.save( { "step": step, - "model": raw_model.state_dict(), - "muon": muon.state_dict(), - "adamw": adamw.state_dict(), + "model": model_state, + "optimizer": optimizer.state_dict(), "cfg": cfg, "vocab_size": vocab_size, },