mirror of https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 09:33:27 +02:00

fix training

parent 289981ba01
commit eae0f04b8e
@@ -41,7 +41,30 @@ from open_mythos.tokenizer import MythosTokenizer


class FineWebEduDataset(IterableDataset):
    """
    Streaming FineWeb-Edu loader yielding fixed-length (input, target) pairs.

    FineWeb-Edu is trillions of tokens, so `streaming=True` pulls shards on
    demand instead of materializing to disk. Sharding is two-dimensional —
    `world_size` ranks × `num_workers` DataLoader workers per rank — and each
    `(rank, worker_id)` deterministically owns one shard of the global stream.
    That gives disjoint coverage without any cross-process coordination.

    Streaming datasets are not seekable, so a resumed run re-enters its shard
    from the beginning. Acceptable at pretraining scale: the chance of
    replaying the same tokens before the run ends is negligible versus the
    cost of a true resumable loader.
    """

    def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
        """
        Args:
            encoding -- tokenizer exposing `.encode(str) -> list[int]`
            seq_len -- context length; every yielded pair has this many tokens
            subset -- FineWeb-Edu config name (e.g. "sample-10BT", "default")
            rank -- global rank of this process within the distributed job
            world_size -- total number of distributed processes
        """
        self.encoding = encoding
        self.seq_len = seq_len
        self.subset = subset
@@ -49,6 +72,16 @@ class FineWebEduDataset(IterableDataset):
        self.world_size = world_size

    def __iter__(self):
        """
        Yield `(input_ids, target_ids)` tensors of length `seq_len` forever.

        Inputs and targets are shifted by one for next-token prediction —
        `target[i] == input[i + 1]`. Documents are concatenated into a rolling
        buffer and sliced into fixed-length chunks, packing short docs together
        and splitting long ones. This keeps every batch the same shape,
        which under FSDP avoids recompute from variable-length inputs and
        removes the need for a pad-aware attention mask.
        """
        worker = get_worker_info()
        num_workers = worker.num_workers if worker else 1
        worker_id = worker.id if worker else 0
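
The hunk stops before the packing loop itself. As a hedged sketch of the shape the elided body likely takes (the `load_dataset` call, the `.shard(...)` split, and the buffer logic below are illustrative assumptions, not the committed code; it assumes a `datasets` version whose streaming `IterableDataset` supports `.shard`):

import torch
from datasets import load_dataset

def _iter_sketch(self):
    # Sketch only, not the committed __iter__ body.
    worker = get_worker_info()
    num_workers = worker.num_workers if worker else 1
    worker_id = worker.id if worker else 0
    # Flatten the 2-D (rank, worker) grid into one global shard index.
    shard_id = self.rank * num_workers + worker_id
    num_shards = self.world_size * num_workers
    stream = load_dataset(
        "HuggingFaceFW/fineweb-edu",
        name=self.subset,
        split="train",
        streaming=True,
    ).shard(num_shards=num_shards, index=shard_id)
    buf: list[int] = []
    while True:  # yield forever, per the docstring
        for doc in stream:
            buf.extend(self.encoding.encode(doc["text"]))
            # Slice seq_len + 1 tokens so input and target shift by one.
            while len(buf) > self.seq_len:
                chunk = torch.tensor(buf[: self.seq_len + 1], dtype=torch.long)
                buf = buf[self.seq_len :]
                yield chunk[:-1], chunk[1:]
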
@@ -81,6 +114,32 @@ class FineWebEduDataset(IterableDataset):


def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """
    Linear warmup → half-cosine decay to `min_lr`.

    Standard language-model pretraining schedule. Warmup keeps the effective
    step size small while Adam's second-moment estimate is still poorly
    calibrated on noisy early gradients, which would otherwise blow up the
    first few updates. The cosine tail lets the model make small,
    increasingly conservative updates near the end of training rather than
    dropping abruptly to `min_lr` at a fixed step.

    Behavior by region:
        step < warmup          → linear ramp 0 → max_lr
        warmup ≤ step < total  → cosine decay max_lr → min_lr
        step ≥ total           → clamped at min_lr (safety for off-by-one
                                 step counters at the end of training)

    Args:
        step -- current global optimizer step (0-indexed)
        warmup -- number of warmup steps before cosine decay begins
        total -- step at which the cosine reaches `min_lr`
        max_lr -- peak learning rate reached at the end of warmup
        min_lr -- floor learning rate at and after `total` steps

    Returns:
        Scalar learning rate for this step.
    """
    if step < warmup:
        return max_lr * step / warmup
    if step >= total:
@@ -89,12 +148,194 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) ->
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))

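The hunk break above elides the middle of `get_lr`. A self-contained version consistent with the visible branches would be the following; the `return min_lr` clamp and the `decay` definition are inferred from the docstring, not shown in the diff:

import math

def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    # Linear ramp 0 -> max_lr over the warmup steps.
    if step < warmup:
        return max_lr * step / warmup
    # Clamp after the schedule ends (inferred from the docstring's third region).
    if step >= total:
        return min_lr
    # Fraction of the cosine phase completed, in [0, 1).
    decay = (step - warmup) / (total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
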
# ---------------------------------------------------------------------------
# Checkpointing
# ---------------------------------------------------------------------------


def _list_ckpts(ckpt_dir: str) -> list[str]:
    """
    Return checkpoint paths in `ckpt_dir` sorted oldest → newest.

    Relies on the zero-padded `step_0000000.pt` filename convention so
    lexicographic sort matches chronological order. Changing the filename
    format elsewhere without updating the pad width would silently break
    both `keep_last` pruning and resume-latest on startup, since both pick
    the last element of this list.

    Args:
        ckpt_dir -- directory to scan; missing directory returns []

    Returns:
        Sorted list of paths to matching checkpoint files.
    """
    if not os.path.isdir(ckpt_dir):
        return []
    return sorted(
        os.path.join(ckpt_dir, f)
        for f in os.listdir(ckpt_dir)
        if f.startswith("step_") and f.endswith(".pt")
    )

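A quick illustration of why the pad width matters for the sort (hypothetical filenames):

>>> sorted(["step_900.pt", "step_1000.pt"])        # unpadded: wrong order
['step_1000.pt', 'step_900.pt']
>>> sorted(["step_0000900.pt", "step_0001000.pt"])  # zero-padded: correct
['step_0000900.pt', 'step_0001000.pt']
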
def save_checkpoint(
    model,
    optimizer,
    step: int,
    cfg,
    vocab_size: int,
    ckpt_dir: str,
    ddp: bool,
    master: bool,
    keep_last: int = 3,
) -> None:
    """
    Gather full model + optimizer state, write atomically, prune old files.

    Under FSDP both states are collected inside a single FULL_STATE_DICT
    context so the optim-state tensors bind to fully-unsharded parameters;
    mixing contexts between model and optimizer has caused silent divergence
    on resume in past torch versions. The temp-file + os.replace write means
    a kill mid-save leaves the previous checkpoint intact instead of a
    truncated .pt file. Non-master ranks participate in the FSDP gather
    (otherwise the collective would hang) but return before touching disk.

    Args:
        model -- FSDP-wrapped (ddp=True) or raw (ddp=False) model
        optimizer -- the optimizer whose state should round-trip with the model
        step -- global step number; encoded zero-padded into the filename
        cfg -- model config object; saved so downstream eval can
               reconstruct the model without re-importing the variant
        vocab_size -- tokenizer vocab size at train time; saved for a sanity
                      check on load against a (possibly updated) tokenizer
        ckpt_dir -- directory to write into; created if missing
        ddp -- True if FSDP path; False for single-GPU / CPU
        master -- whether this rank writes to disk (rank 0 only)
        keep_last -- number of most-recent checkpoints to retain; older ones
                     are unlinked after a successful write

    Returns:
        None. Writes to disk as a side effect on the master rank.
    """
    if ddp:
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            model_state = model.state_dict()
            optim_state = FSDP.optim_state_dict(model, optimizer)
    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    if not master:
        return

    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
    tmp_path = final_path + ".tmp"
    torch.save(
        {
            "step": step,
            "model": model_state,
            "optimizer": optim_state,
            "cfg": cfg,
            "vocab_size": vocab_size,
        },
        tmp_path,
    )
    os.replace(tmp_path, final_path)

    for old in _list_ckpts(ckpt_dir)[:-keep_last]:
        try:
            os.remove(old)
        except OSError as exc:
            logger.warning(f"Failed to prune old checkpoint {old}: {exc}")

    logger.success(f"Checkpoint saved → {final_path}")

def load_checkpoint(model, optimizer, path: str, ddp: bool) -> int:
    """
    Restore model + optimizer from disk, returning the step to resume at.

    Every rank reads the file (`rank0_only=False` on load) so FSDP has access
    to the full state on each rank — the complement to the `rank0_only=True`
    save path. Must mirror save's single-context pattern; splitting the model
    and optimizer loads across two `state_dict_type` blocks has historically
    produced optimizer state bound to the wrong shard shapes.

    `weights_only=False` is required because the checkpoint contains the
    pickled `cfg` dataclass — flip to `weights_only=True` only if you
    separate the config out.

    Args:
        model -- same FSDP-wrapped or raw model used during save
        optimizer -- freshly constructed optimizer to be filled in-place
        path -- path to a `step_{N:07d}.pt` file produced by
                `save_checkpoint`
        ddp -- whether the model is FSDP-wrapped; must match the save run

    Returns:
        The step number the checkpoint was taken at; the caller advances the
        training loop from this value.
    """
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    if ddp:
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        ):
            model.load_state_dict(ckpt["model"])
            optim_state = FSDP.optim_state_dict_to_load(
                model=model,
                optim=optimizer,
                optim_state_dict=ckpt["optimizer"],
            )
            optimizer.load_state_dict(optim_state)
    else:
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])

    return int(ckpt["step"])

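`load_checkpoint` does not itself verify the stored `vocab_size`; the save-side docstring implies a caller-side check. A sketch of what that might look like (the diff never shows where, or whether, it runs; `tokenizer.vocab_size` is an assumed attribute name):

# Hedged sketch of the vocab sanity check hinted at by save_checkpoint's
# docstring.
ckpt = torch.load(path, map_location="cpu", weights_only=False)
if ckpt["vocab_size"] != tokenizer.vocab_size:
    raise ValueError(
        f"checkpoint vocab_size={ckpt['vocab_size']} does not match "
        f"tokenizer vocab_size={tokenizer.vocab_size}; embedding rows "
        "would silently misalign"
    )
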
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main():
    """
    End-to-end pretraining entry point.

    Order matters: distributed init must run before any CUDA allocation, the
    tokenizer must exist before the model is built (vocab_size flows into
    cfg), and FSDP must wrap the model before the optimizer is constructed
    (FSDP re-flattens parameters, so an optimizer built on the unwrapped
    model would track stale param objects; a minimal sketch follows this
    docstring). Resume then loads state into the already-constructed
    optimizer in-place.

    Lifecycle:
        1. Initialize torch.distributed (NCCL) if launched under torchrun.
        2. Build tokenizer → derive vocab_size.
        3. Construct OpenMythos with the 3B variant config.
        4. Wrap in FSDP with FULL_SHARD + bf16/fp16 mixed precision
           (multi-GPU) or move to device + autocast (single-GPU).
        5. Build fused AdamW on the (possibly sharded) parameters.
        6. Resume from the latest checkpoint in `ckpt_dir` if one exists.
        7. Stream FineWeb-Edu through grad-accumulation microbatches with
           cosine LR schedule, per-step logging, and periodic checkpoints.
        8. Write a final checkpoint if the last save wasn't aligned to
           `ckpt_every`, then barrier + tear down the process group.

    All hyperparameters are literal constants in this function by design —
    pretraining runs are long-lived and each run pins exact settings; a
    CLI/config layer is deliberately avoided to keep the file self-auditable.
    """
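
The stale-parameter hazard behind the wrap-before-optimizer rule can be shown without any distributed setup. A minimal sketch that simulates what a wrap does to parameter identity (plain `nn.Linear` stands in for the model; the FSDP flattening itself is not reproduced here):

import torch
import torch.nn as nn

model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Simulate a wrap: the module's parameter is replaced by a new tensor,
# just as FSDP replaces parameters with flattened shards.
model.weight = nn.Parameter(model.weight.detach().clone())
# The optimizer still holds the *old* parameter object, so its steps no
# longer touch the tensor the model actually uses in forward/backward.
assert optimizer.param_groups[0]["params"][0] is not model.weight
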
    # ------------------------------------------------------------------
    # Distributed init
    # ------------------------------------------------------------------
@@ -198,6 +439,22 @@ def main():
        model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
    )

    # ------------------------------------------------------------------
    # Resume from latest checkpoint (if any)
    # ------------------------------------------------------------------
    # Streaming datasets are not resumable by position, so re-iterating from
    # the beginning is accepted — at pretraining scale the loss of dataset
    # position is negligible vs. the cost of discarded training steps.
    start_step = 0
    existing_ckpts = _list_ckpts(ckpt_dir)
    if existing_ckpts:
        latest = existing_ckpts[-1]
        if master:
            logger.info(f"Resuming from checkpoint: {latest}")
        start_step = load_checkpoint(model, optimizer, latest, ddp)
        if master:
            logger.success(f"Resumed at step {start_step}")

    # ------------------------------------------------------------------
    # Dataset + DataLoader
    # ------------------------------------------------------------------
@@ -213,7 +470,7 @@ def main():
    model.train()
    data_iter = iter(loader)
    t0 = time.perf_counter()
-   step = 0
+   step = start_step

    while step < total_steps:
        cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
@@ -248,7 +505,13 @@ def main():
            loss.backward()
            loss_accum += loss.item()

-       nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+       # FSDP shards parameters, so `nn.utils.clip_grad_norm_` would clip
+       # against each rank's local norm and miss the cross-shard gather.
+       # FSDP.clip_grad_norm_ computes the true global norm and returns it.
+       if ddp:
+           grad_norm = model.clip_grad_norm_(1.0)
+       else:
+           grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
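        # Aside (not part of the commit): the global gradient norm is the
        # root of the sum of squared per-shard norms, so each rank's local
        # norm understates it. With two shards of norms 3.0 and 4.0 the
        # global norm is sqrt(9 + 16) = 5.0; clipping each shard locally
        # against 1.0 would scale by 1/3 and 1/4 instead of the correct 1/5.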
        optimizer.step()
        step += 1

@@ -258,36 +521,26 @@ def main():
            tokens_seen = step * global_batch_tok
            logger.info(
                f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
-               f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
+               f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
+               f"| {tok_per_sec / 1e6:.2f}M tok/s "
                f"| {tokens_seen / 1e9:.1f}B tokens seen"
            )
            t0 = time.perf_counter()

-       if master and step % ckpt_every == 0:
-           path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
-           logger.info(f"Saving checkpoint at step {step} → {path}")
-           if ddp:
-               with FSDP.state_dict_type(
-                   model,
-                   StateDictType.FULL_STATE_DICT,
-                   FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
-               ):
-                   model_state = model.state_dict()
-           else:
-               model_state = model.state_dict()
-           torch.save(
-               {
-                   "step": step,
-                   "model": model_state,
-                   "optimizer": optimizer.state_dict(),
-                   "cfg": cfg,
-                   "vocab_size": vocab_size,
-               },
-               path,
-           )
-           logger.success(f"Checkpoint saved → {path}")
+       if step % ckpt_every == 0:
+           save_checkpoint(
+               model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master
+           )

    # Final checkpoint — total_steps may not be divisible by ckpt_every, so
    # without this the tail of the run is lost if the schedule doesn't align.
    if step > start_step and step % ckpt_every != 0:
        save_checkpoint(model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master)

    if ddp:
        # Barrier so no rank exits while another is still finishing its
        # checkpoint gather — avoids NCCL "process group destroyed" noise.
        dist.barrier()
        dist.destroy_process_group()

    if master: