mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
fix training
This commit is contained in:
parent
289981ba01
commit
eae0f04b8e
@ -41,7 +41,30 @@ from open_mythos.tokenizer import MythosTokenizer
|
|||||||
|
|
||||||
|
|
||||||
class FineWebEduDataset(IterableDataset):
|
class FineWebEduDataset(IterableDataset):
|
||||||
|
"""
|
||||||
|
Streaming FineWeb-Edu loader yielding fixed-length (input, target) pairs.
|
||||||
|
|
||||||
|
FineWeb-Edu is trillions of tokens, so `streaming=True` pulls shards on
|
||||||
|
demand instead of materializing to disk. Sharding is two-dimensional —
|
||||||
|
`world_size` ranks × `num_workers` DataLoader workers per rank — and each
|
||||||
|
`(rank, worker_id)` deterministically owns one shard of the global stream.
|
||||||
|
That gives disjoint coverage without any cross-process coordination.
|
||||||
|
|
||||||
|
Streaming datasets are not seekable, so a resumed run re-enters its shard
|
||||||
|
from the beginning. Acceptable at pretraining scale: the chance of
|
||||||
|
re-playing the same tokens before the run ends is negligible versus the
|
||||||
|
cost of a true resumable loader.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
|
def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoding -- tokenizer exposing `.encode(str) -> list[int]`
|
||||||
|
seq_len -- context length; every yielded pair has this many tokens
|
||||||
|
subset -- FineWeb-Edu config name (e.g. "sample-10BT", "default")
|
||||||
|
rank -- global rank of this process within the distributed job
|
||||||
|
world_size -- total number of distributed processes
|
||||||
|
"""
|
||||||
self.encoding = encoding
|
self.encoding = encoding
|
||||||
self.seq_len = seq_len
|
self.seq_len = seq_len
|
||||||
self.subset = subset
|
self.subset = subset
|
||||||
@ -49,6 +72,16 @@ class FineWebEduDataset(IterableDataset):
|
|||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
"""
|
||||||
|
Yield `(input_ids, target_ids)` tensors of length `seq_len` forever.
|
||||||
|
|
||||||
|
Inputs and targets are shifted by one for next-token prediction —
|
||||||
|
`target[i] == input[i + 1]`. Documents are concatenated into a rolling
|
||||||
|
buffer and sliced into fixed-length chunks, packing short docs together
|
||||||
|
and splitting long ones. This keeps every step at the same shape,
|
||||||
|
which under FSDP avoids recompute from variable-length inputs and
|
||||||
|
removes the need for a pad-aware attention mask.
|
||||||
|
"""
|
||||||
worker = get_worker_info()
|
worker = get_worker_info()
|
||||||
num_workers = worker.num_workers if worker else 1
|
num_workers = worker.num_workers if worker else 1
|
||||||
worker_id = worker.id if worker else 0
|
worker_id = worker.id if worker else 0
|
||||||
@ -81,6 +114,32 @@ class FineWebEduDataset(IterableDataset):
|
|||||||
|
|
||||||
|
|
||||||
def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
|
def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
|
||||||
|
"""
|
||||||
|
Linear warmup → half-cosine decay to `min_lr`.
|
||||||
|
|
||||||
|
Standard language-model pretraining schedule. The warmup phase prevents
|
||||||
|
Adam's second-moment estimate from collapsing to a huge LR in the first
|
||||||
|
few steps when gradients are noisy. The cosine tail lets the model make
|
||||||
|
small, increasingly conservative updates near the end of training rather
|
||||||
|
than crashing to `min_lr` at a fixed step.
|
||||||
|
|
||||||
|
Behavior by region:
|
||||||
|
step < warmup → linear ramp 0 → max_lr
|
||||||
|
warmup ≤ step < total → cosine decay max_lr → min_lr
|
||||||
|
step ≥ total → clamped at min_lr (safety for
|
||||||
|
off-by-one step counters at the end
|
||||||
|
of training)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step -- current global optimizer step (0-indexed)
|
||||||
|
warmup -- number of warmup steps before cosine decay begins
|
||||||
|
total -- step at which the cosine reaches `min_lr`
|
||||||
|
max_lr -- peak learning rate reached at the end of warmup
|
||||||
|
min_lr -- floor learning rate at and after `total` steps
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Scalar learning rate for this step.
|
||||||
|
"""
|
||||||
if step < warmup:
|
if step < warmup:
|
||||||
return max_lr * step / warmup
|
return max_lr * step / warmup
|
||||||
if step >= total:
|
if step >= total:
|
||||||
@ -89,12 +148,194 @@ def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) ->
|
|||||||
return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
|
return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Checkpointing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _list_ckpts(ckpt_dir: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Return checkpoint paths in `ckpt_dir` sorted oldest → newest.
|
||||||
|
|
||||||
|
Relies on the zero-padded `step_{0000000}.pt` filename convention so
|
||||||
|
lexicographic sort matches chronological order. Changing the filename
|
||||||
|
format elsewhere without updating the pad width would silently break
|
||||||
|
both `keep_last` pruning and resume-latest on startup, since both pick
|
||||||
|
the last element of this list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_dir -- directory to scan; missing directory returns []
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sorted list of absolute paths to matching checkpoint files.
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(ckpt_dir):
|
||||||
|
return []
|
||||||
|
return sorted(
|
||||||
|
os.path.join(ckpt_dir, f)
|
||||||
|
for f in os.listdir(ckpt_dir)
|
||||||
|
if f.startswith("step_") and f.endswith(".pt")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(
    model,
    optimizer,
    step: int,
    cfg,
    vocab_size: int,
    ckpt_dir: str,
    ddp: bool,
    master: bool,
    keep_last: int = 3,
) -> None:
    """
    Gather full model + optimizer state, write atomically, prune old files.

    Under FSDP both states are collected inside a single FULL_STATE_DICT
    context so the model and optimizer state dicts are produced under the
    same configuration. The temp-file + os.replace write means a kill
    mid-save leaves the previous checkpoint intact instead of a truncated
    .pt file; if the write itself fails, the partial temp file is removed
    before the error propagates. Non-master ranks participate in the FSDP
    gather (otherwise the collective would hang) but exit before touching
    disk.

    Args:
        model      -- FSDP-wrapped (ddp=True) or raw (ddp=False) model
        optimizer  -- the optimizer whose state should round-trip with the model
        step       -- global step number; encoded zero-padded into the filename
        cfg        -- model config object; stored in the checkpoint as-is
        vocab_size -- tokenizer vocab size at train time; stored in the
                      checkpoint as-is
        ckpt_dir   -- directory to write into; created if missing
        ddp        -- True if FSDP path; False for single-GPU / CPU
        master     -- whether this rank writes to disk (rank 0 only)
        keep_last  -- number of most-recent checkpoints to retain; older ones
                      are unlinked after a successful write. A value <= 0
                      disables pruning entirely (nothing is deleted).

    Returns:
        None. Writes to disk as a side effect on master rank.
    """
    if ddp:
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            model_state = model.state_dict()
            optim_state = FSDP.optim_state_dict(model, optimizer)
    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    # Every rank had to take part in the gather above; only master writes.
    if not master:
        return

    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
    tmp_path = final_path + ".tmp"
    try:
        torch.save(
            {
                "step": step,
                "model": model_state,
                "optimizer": optim_state,
                "cfg": cfg,
                "vocab_size": vocab_size,
            },
            tmp_path,
        )
        os.replace(tmp_path, final_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file (disk full, Ctrl-C, …);
        # the original error is what the caller needs to see.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    # `[:-keep_last]` is only meaningful for positive values: with 0 it is an
    # empty slice, and with a negative value it would select the *oldest*
    # files — possibly including the one just written. Treat <= 0 as
    # "pruning disabled".
    if keep_last > 0:
        for old in _list_ckpts(ckpt_dir)[:-keep_last]:
            if old == final_path:
                # Never delete the checkpoint that was just written, even if
                # the listing order puts it among the "old" ones.
                continue
            try:
                os.remove(old)
            except OSError as exc:
                logger.warning(f"Failed to prune old checkpoint {old}: {exc}")

    logger.success(f"Checkpoint saved → {final_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(model, optimizer, path: str, ddp: bool) -> int:
    """
    Restore model + optimizer from disk, returning the step to resume at.

    Every rank reads the file itself and loads under a FULL_STATE_DICT
    context with `rank0_only=False` — the complement to the
    `rank0_only=True` save path. Model and optimizer are loaded inside one
    `state_dict_type` block, mirroring how `save_checkpoint` gathers them.

    The checkpoint is validated *before* any state is touched: if a required
    key is missing or the step is not convertible to int, the error is raised
    while `model` and `optimizer` are still unmodified, rather than after a
    partial restore.

    `weights_only=False` is required because the checkpoint contains the
    pickled `cfg` object — flip to `weights_only=True` only if you separate
    config out. NOTE: this unpickles arbitrary objects, so only load
    checkpoint files from a trusted source.

    Args:
        model     -- same FSDP-wrapped or raw model used during save
        optimizer -- already-constructed optimizer to be filled in-place
        path      -- path to a checkpoint file produced by `save_checkpoint`
        ddp       -- whether the model is FSDP-wrapped; must match the save run

    Returns:
        The step number the checkpoint was taken at; the caller advances the
        training loop from this value.

    Raises:
        KeyError -- the file lacks one of "step", "model", "optimizer".
    """
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    # Fail fast, before mutating model/optimizer, on a malformed checkpoint.
    missing = [key for key in ("step", "model", "optimizer") if key not in ckpt]
    if missing:
        raise KeyError(f"Checkpoint {path} is missing required keys: {missing}")
    step = int(ckpt["step"])

    if ddp:
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        ):
            model.load_state_dict(ckpt["model"])
            # Convert the saved full optimizer state into the form this
            # rank's optimizer expects before loading it.
            optim_state = FSDP.optim_state_dict_to_load(
                model=model,
                optim=optimizer,
                optim_state_dict=ckpt["optimizer"],
            )
            optimizer.load_state_dict(optim_state)
    else:
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])

    return step
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Main
|
# Main
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
"""
|
||||||
|
End-to-end pretraining entry point.
|
||||||
|
|
||||||
|
Order matters: distributed init must run before any CUDA allocation, the
|
||||||
|
tokenizer must exist before the model is built (vocab_size flows into
|
||||||
|
cfg), and FSDP must wrap the model before the optimizer is constructed
|
||||||
|
(FSDP re-flattens parameters, so an optimizer built on the unwrapped
|
||||||
|
model would track stale param objects). Resume then loads state into the
|
||||||
|
already-constructed optimizer in-place.
|
||||||
|
|
||||||
|
Lifecycle:
|
||||||
|
1. Initialize torch.distributed (NCCL) if launched under torchrun.
|
||||||
|
2. Build tokenizer → derive vocab_size.
|
||||||
|
3. Construct OpenMythos with the 3B variant config.
|
||||||
|
4. Wrap in FSDP with FULL_SHARD + bf16/fp16 mixed precision (multi-GPU)
|
||||||
|
or move to device + autocast (single-GPU).
|
||||||
|
5. Build fused AdamW on (possibly sharded) parameters.
|
||||||
|
6. Resume from the latest checkpoint in `ckpt_dir` if one exists.
|
||||||
|
7. Stream FineWeb-Edu through grad-accumulation microbatches with
|
||||||
|
cosine LR schedule, per-step logging, and periodic checkpoints.
|
||||||
|
8. Write a final checkpoint if the last save wasn't aligned to
|
||||||
|
`ckpt_every`, then barrier + tear down the process group.
|
||||||
|
|
||||||
|
All hyperparameters are literal constants in this function by design —
|
||||||
|
pretraining runs are long-lived and each run pins exact settings; a
|
||||||
|
CLI/config layer is deliberately avoided to keep the file self-auditable.
|
||||||
|
"""
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Distributed init
|
# Distributed init
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -198,6 +439,22 @@ def main():
|
|||||||
model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
|
model.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Resume from latest checkpoint (if any)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Streaming datasets are not resumable by position, so re-iterating from
|
||||||
|
# the beginning is accepted — at pretraining scale the loss of dataset
|
||||||
|
# position is negligible vs. the cost of discarded training steps.
|
||||||
|
start_step = 0
|
||||||
|
existing_ckpts = _list_ckpts(ckpt_dir)
|
||||||
|
if existing_ckpts:
|
||||||
|
latest = existing_ckpts[-1]
|
||||||
|
if master:
|
||||||
|
logger.info(f"Resuming from checkpoint: {latest}")
|
||||||
|
start_step = load_checkpoint(model, optimizer, latest, ddp)
|
||||||
|
if master:
|
||||||
|
logger.success(f"Resumed at step {start_step}")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Dataset + DataLoader
|
# Dataset + DataLoader
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -213,7 +470,7 @@ def main():
|
|||||||
model.train()
|
model.train()
|
||||||
data_iter = iter(loader)
|
data_iter = iter(loader)
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
step = 0
|
step = start_step
|
||||||
|
|
||||||
while step < total_steps:
|
while step < total_steps:
|
||||||
cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
|
cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
|
||||||
@ -248,7 +505,13 @@ def main():
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
loss_accum += loss.item()
|
loss_accum += loss.item()
|
||||||
|
|
||||||
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
# FSDP shards parameters, so `nn.utils.clip_grad_norm_` would clip
|
||||||
|
# against each rank's local norm and miss the cross-shard gather.
|
||||||
|
# FSDP.clip_grad_norm_ computes the true global norm and returns it.
|
||||||
|
if ddp:
|
||||||
|
grad_norm = model.clip_grad_norm_(1.0)
|
||||||
|
else:
|
||||||
|
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
@ -258,36 +521,26 @@ def main():
|
|||||||
tokens_seen = step * global_batch_tok
|
tokens_seen = step * global_batch_tok
|
||||||
logger.info(
|
logger.info(
|
||||||
f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
|
f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
|
||||||
f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
|
f"| gnorm {float(grad_norm):.2f} | lr {cur_lr:.2e} "
|
||||||
|
f"| {tok_per_sec / 1e6:.2f}M tok/s "
|
||||||
f"| {tokens_seen / 1e9:.1f}B tokens seen"
|
f"| {tokens_seen / 1e9:.1f}B tokens seen"
|
||||||
)
|
)
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
if master and step % ckpt_every == 0:
|
if step % ckpt_every == 0:
|
||||||
path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
|
save_checkpoint(
|
||||||
logger.info(f"Saving checkpoint at step {step} → {path}")
|
model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master
|
||||||
if ddp:
|
|
||||||
with FSDP.state_dict_type(
|
|
||||||
model,
|
|
||||||
StateDictType.FULL_STATE_DICT,
|
|
||||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
|
|
||||||
):
|
|
||||||
model_state = model.state_dict()
|
|
||||||
else:
|
|
||||||
model_state = model.state_dict()
|
|
||||||
torch.save(
|
|
||||||
{
|
|
||||||
"step": step,
|
|
||||||
"model": model_state,
|
|
||||||
"optimizer": optimizer.state_dict(),
|
|
||||||
"cfg": cfg,
|
|
||||||
"vocab_size": vocab_size,
|
|
||||||
},
|
|
||||||
path,
|
|
||||||
)
|
)
|
||||||
logger.success(f"Checkpoint saved → {path}")
|
|
||||||
|
# Final checkpoint — total_steps may not be divisible by ckpt_every, so
|
||||||
|
# without this the tail of the run is lost if the schedule doesn't align.
|
||||||
|
if step > start_step and step % ckpt_every != 0:
|
||||||
|
save_checkpoint(model, optimizer, step, cfg, vocab_size, ckpt_dir, ddp, master)
|
||||||
|
|
||||||
if ddp:
|
if ddp:
|
||||||
|
# Barrier so no rank exits while another is still finishing its
|
||||||
|
# checkpoint gather — avoids NCCL "process group destroyed" noise.
|
||||||
|
dist.barrier()
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
if master:
|
if master:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user