From eae0f04b8e0540f876f7da7f34089a6d11c1fc6e Mon Sep 17 00:00:00 2001
From: Kye Gomez
Date: Mon, 20 Apr 2026 09:43:25 -0400
Subject: [PATCH] fix training

---
 training/3b_fine_web_edu.py | 303 +++++++++++++++++++++++++++++++++---
 1 file changed, 278 insertions(+), 25 deletions(-)

diff --git a/training/3b_fine_web_edu.py b/training/3b_fine_web_edu.py
index 92603f3..e980302 100644
--- a/training/3b_fine_web_edu.py
+++ b/training/3b_fine_web_edu.py
@@ -41,7 +41,30 @@ from open_mythos.tokenizer import MythosTokenizer
 
 
 class FineWebEduDataset(IterableDataset):
+    """
+    Streaming FineWeb-Edu loader yielding fixed-length (input, target) pairs.
+
+    FineWeb-Edu is over a trillion tokens, so `streaming=True` pulls shards on
+    demand instead of materializing to disk. Sharding is two-dimensional —
+    `world_size` ranks × `num_workers` DataLoader workers per rank — and each
+    `(rank, worker_id)` deterministically owns one shard of the global stream.
+    That gives disjoint coverage without any cross-process coordination.
+
+    Streaming datasets are not seekable, so a resumed run re-enters its shard
+    from the beginning. Acceptable at pretraining scale: the chance of
+    re-playing the same tokens before the run ends is negligible versus the
+    cost of a true resumable loader.
+    """
+
     def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
+        """
+        Args:
+            encoding   -- tokenizer exposing `.encode(str) -> list[int]`
+            seq_len    -- context length; every yielded pair has this many tokens
+            subset     -- FineWeb-Edu config name (e.g. "sample-10BT", "default")
+            rank       -- global rank of this process within the distributed job
+            world_size -- total number of distributed processes
+        """
         self.encoding = encoding
         self.seq_len = seq_len
         self.subset = subset
@@ -49,6 +72,16 @@ class FineWebEduDataset(IterableDataset):
         self.world_size = world_size
 
     def __iter__(self):
+        """
+        Yield `(input_ids, target_ids)` tensors of length `seq_len` forever.
+
+        Inputs and targets are shifted by one for next-token prediction —
+        `target[i] == input[i + 1]`. Documents are concatenated into a rolling
+        buffer and sliced into fixed-length chunks, packing short docs together
+        and splitting long ones. This keeps every step at the same shape,
+        which under FSDP avoids recompute from variable-length inputs and
+        removes the need for a pad-aware attention mask.
+        """
         worker = get_worker_info()
         num_workers = worker.num_workers if worker else 1
         worker_id = worker.id if worker else 0
@@ -81,6 +114,32 @@ class FineWebEduDataset(IterableDataset):
 
 
 def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
+    """
+    Linear warmup → half-cosine decay to `min_lr`.
+
+    Standard language-model pretraining schedule. The warmup phase keeps the
+    effective update size in check while Adam's second-moment estimate is
+    still being built from noisy early gradients. The cosine tail lets the
+    model make small, increasingly conservative updates near the end of
+    training rather than crashing to `min_lr` at a fixed step.
+
+    Behavior by region:
+        step < warmup          → linear ramp 0 → max_lr
+        warmup ≤ step < total  → cosine decay max_lr → min_lr
+        step ≥ total           → clamped at min_lr (safety for
+                                 off-by-one step counters at the end
+                                 of training)
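+
+    Illustrative shape (hypothetical numbers, not this run's settings): with
+    warmup=100, total=1000, max_lr=3e-4, min_lr=3e-5 the schedule returns
+        get_lr(0)    → 0.0       (start of warmup)
+        get_lr(50)   → 1.5e-4    (halfway up the linear ramp)
+        get_lr(100)  → 3.0e-4    (peak; cosine decay begins)
+        get_lr(550)  → 1.65e-4   (midpoint of the cosine)
+        get_lr(1000) → 3.0e-5    (clamped at the floor)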
+
+    Args:
+        step   -- current global optimizer step (0-indexed)
+        warmup -- number of warmup steps before cosine decay begins
+        total  -- step at which the cosine reaches `min_lr`
+        max_lr -- peak learning rate reached at the end of warmup
+        min_lr -- floor learning rate at and after `total` steps
+
+    Returns:
+        Scalar learning rate for this step.
+    """
     if step < warmup:
         return max_lr * step / warmup
     if step >= total:
@@ -89,12 +148,194 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) ->
     return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
 
 
+# ---------------------------------------------------------------------------
+# Checkpointing
+# ---------------------------------------------------------------------------
+
+
+def _list_ckpts(ckpt_dir: str) -> list[str]:
+    """
+    Return checkpoint paths in `ckpt_dir` sorted oldest → newest.
+
+    Relies on the zero-padded `step_{step:07d}.pt` filename convention so
+    lexicographic sort matches chronological order. Changing the filename
+    format elsewhere without updating the pad width would silently break
+    both `keep_last` pruning and resume-latest on startup, since both pick
+    the last element of this list.
+
+    Args:
+        ckpt_dir -- directory to scan; missing directory returns []
+
+    Returns:
+        Sorted list of absolute paths to matching checkpoint files.
+    """
+    if not os.path.isdir(ckpt_dir):
+        return []
+    return sorted(
+        os.path.join(ckpt_dir, f)
+        for f in os.listdir(ckpt_dir)
+        if f.startswith("step_") and f.endswith(".pt")
+    )
+
+
+def save_checkpoint(
+    model,
+    optimizer,
+    step: int,
+    cfg,
+    vocab_size: int,
+    ckpt_dir: str,
+    ddp: bool,
+    master: bool,
+    keep_last: int = 3,
+) -> None:
+    """
+    Gather full model + optimizer state, write atomically, prune old files.
+
+    Under FSDP both states are collected inside a single FULL_STATE_DICT
+    context so the optim-state tensors bind to fully-unsharded parameters;
+    mixing contexts between model and optimizer has caused silent divergence
+    on resume in past torch versions. The temp-file + os.replace write means
+    a kill mid-save leaves the previous checkpoint intact instead of a
+    truncated .pt file. Non-master ranks participate in the FSDP gather
+    (otherwise the collective would hang) but exit before touching disk.
+
+    Args:
+        model      -- FSDP-wrapped (ddp=True) or raw (ddp=False) model
+        optimizer  -- the optimizer whose state should round-trip with the model
+        step       -- global step number; encoded zero-padded into the filename
+        cfg        -- model config object; saved so downstream eval can
+                      reconstruct the model without re-importing the variant
+        vocab_size -- tokenizer vocab size at train time; saved for sanity-check
+                      on load against a (possibly updated) tokenizer
+        ckpt_dir   -- directory to write into; created if missing
+        ddp        -- True if FSDP path; False for single-GPU / CPU
+        master     -- whether this rank writes to disk (rank 0 only)
+        keep_last  -- number of most-recent checkpoints to retain; older ones
+                      are unlinked after a successful write
+
+    Returns:
+        None. Writes to disk as a side effect on master rank.
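+
+    Example directory state (illustrative step numbers) with keep_last=3,
+    immediately after the save at step 5000:
+        <ckpt_dir>/step_0003000.pt
+        <ckpt_dir>/step_0004000.pt
+        <ckpt_dir>/step_0005000.pt   (written as step_0005000.pt.tmp, then renamed)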
+ """ + if ddp: + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + model_state = model.state_dict() + optim_state = FSDP.optim_state_dict(model, optimizer) + else: + model_state = model.state_dict() + optim_state = optimizer.state_dict() + + if not master: + return + + os.makedirs(ckpt_dir, exist_ok=True) + final_path = os.path.join(ckpt_dir, f"step_{step:07d}.pt") + tmp_path = final_path + ".tmp" + torch.save( + { + "step": step, + "model": model_state, + "optimizer": optim_state, + "cfg": cfg, + "vocab_size": vocab_size, + }, + tmp_path, + ) + os.replace(tmp_path, final_path) + + for old in _list_ckpts(ckpt_dir)[:-keep_last]: + try: + os.remove(old) + except OSError as exc: + logger.warning(f"Failed to prune old checkpoint {old}: {exc}") + + logger.success(f"Checkpoint saved → {final_path}") + + +def load_checkpoint(model, optimizer, path: str, ddp: bool) -> int: + """ + Restore model + optimizer from disk, returning the step to resume at. + + Every rank reads the file (`rank0_only=False` on load) so FSDP has access + to the full state on each rank — the complement to the `rank0_only=True` + save path. Must mirror save's single-context pattern; splitting the model + and optimizer loads across two `state_dict_type` blocks has historically + produced optimizer state bound to the wrong shard shapes. + + `weights_only=False` is required because the checkpoint contains the + pickled `cfg` dataclass — flip to `weights_only=True` only if you + separate config out. + + Args: + model -- same FSDP-wrapped or raw model used during save + optimizer -- freshly constructed optimizer to be filled in-place + path -- absolute path to a `step_{N:07d}.pt` file produced by + `save_checkpoint` + ddp -- whether the model is FSDP-wrapped; must match the save run + + Returns: + The step number the checkpoint was taken at; the caller advances the + training loop from this value. + """ + ckpt = torch.load(path, map_location="cpu", weights_only=False) + + if ddp: + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=False), + ): + model.load_state_dict(ckpt["model"]) + optim_state = FSDP.optim_state_dict_to_load( + model=model, + optim=optimizer, + optim_state_dict=ckpt["optimizer"], + ) + optimizer.load_state_dict(optim_state) + else: + model.load_state_dict(ckpt["model"]) + optimizer.load_state_dict(ckpt["optimizer"]) + + return int(ckpt["step"]) + + # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): + """ + End-to-end pretraining entry point. + + Order matters: distributed init must run before any CUDA allocation, the + tokenizer must exist before the model is built (vocab_size flows into + cfg), and FSDP must wrap the model before the optimizer is constructed + (FSDP re-flattens parameters, so an optimizer built on the unwrapped + model would track stale param objects). Resume then loads state into the + already-constructed optimizer in-place. + + Lifecycle: + 1. Initialize torch.distributed (NCCL) if launched under torchrun. + 2. Build tokenizer → derive vocab_size. + 3. Construct OpenMythos with the 3B variant config. + 4. Wrap in FSDP with FULL_SHARD + bf16/fp16 mixed precision (multi-GPU) + or move to device + autocast (single-GPU). + 5. Build fused AdamW on (possibly sharded) parameters. + 6. 
+        6. Resume from the latest checkpoint in `ckpt_dir` if one exists.
+        7. Stream FineWeb-Edu through grad-accumulation microbatches with
+           cosine LR schedule, per-step logging, and periodic checkpoints.
+        8. Write a final checkpoint if the last save wasn't aligned to
+           `ckpt_every`, then barrier + tear down the process group.
+
+    All hyperparameters are literal constants in this function by design —
+    pretraining runs are long-lived and each run pins exact settings; a
+    CLI/config layer is deliberately avoided to keep the file self-auditable.
+    """
     # ------------------------------------------------------------------
     # Distributed init
     # ------------------------------------------------------------------
@@ -198,6 +439,22 @@ def main():
         model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
     )
 
+    # ------------------------------------------------------------------
+    # Resume from latest checkpoint (if any)
+    # ------------------------------------------------------------------
+    # Streaming datasets are not resumable by position, so re-iterating from
+    # the beginning is accepted — at pretraining scale the loss of dataset
+    # position is negligible vs. the cost of discarded training steps.
+    start_step = 0
+    existing_ckpts = _list_ckpts(ckpt_dir)
+    if existing_ckpts:
+        latest = existing_ckpts[-1]
+        if master:
+            logger.info(f"Resuming from checkpoint: {latest}")
+        start_step = load_checkpoint(model, optimizer, latest, ddp)
+        if master:
+            logger.success(f"Resumed at step {start_step}")
+
     # ------------------------------------------------------------------
     # Dataset + DataLoader
     # ------------------------------------------------------------------
@@ -213,7 +470,7 @@ def main():
     model.train()
     data_iter = iter(loader)
     t0 = time.perf_counter()
-    step = 0
+    step = start_step
 
     while step < total_steps:
         cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
@@ -248,7 +505,13 @@ def main():
             loss.backward()
             loss_accum += loss.item()
 
-        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+        # FSDP shards parameters, so `nn.utils.clip_grad_norm_` would clip
+        # against each rank's local norm and miss the cross-shard gather.
+        # FSDP.clip_grad_norm_ computes the true global norm and returns it.
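+        # Both branches return the total (pre-clip) gradient norm, which the
+        # logging block below reports as `gnorm`.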
+        if ddp:
+            grad_norm = model.clip_grad_norm_(1.0)
+        else:
+            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
         optimizer.step()
         step += 1
 
@@ -258,36 +521,26 @@ def main():
             tokens_seen = step * global_batch_tok
             logger.info(
                 f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
-                f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
+                f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
+                f"| {tok_per_sec / 1e6:.2f}M tok/s "
                 f"| {tokens_seen / 1e9:.1f}B tokens seen"
             )
             t0 = time.perf_counter()
 
-        if master and step % ckpt_every == 0:
-            path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
-            logger.info(f"Saving checkpoint at step {step} → {path}")
-            if ddp:
-                with FSDP.state_dict_type(
-                    model,
-                    StateDictType.FULL_STATE_DICT,
-                    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
-                ):
-                    model_state = model.state_dict()
-            else:
-                model_state = model.state_dict()
-            torch.save(
-                {
-                    "step": step,
-                    "model": model_state,
-                    "optimizer": optimizer.state_dict(),
-                    "cfg": cfg,
-                    "vocab_size": vocab_size,
-                },
-                path,
+        if step % ckpt_every == 0:
+            save_checkpoint(
+                model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master
             )
-            logger.success(f"Checkpoint saved → {path}")
+
+    # Final checkpoint — total_steps may not be divisible by ckpt_every, so
+    # without this the tail of the run is lost if the schedule doesn't align.
+    if step > start_step and step % ckpt_every != 0:
+        save_checkpoint(model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master)
 
     if ddp:
+        # Barrier so no rank exits while another is still finishing its
+        # checkpoint gather — avoids NCCL "process group destroyed" noise.
+        dist.barrier()
         dist.destroy_process_group()
 
     if master: