[improvement][loguru-logging][replace print with loguru in training script][feat][ckpt-logging][add checkpoint start and success log events][docs][readme-optimizer][remove muon optimizer reference][feat][train-requirements][add requirements txt to training folder]
This commit is contained in:
Kye Gomez 2026-04-20 08:25:00 -04:00
parent 18cca894dd
commit 7ba690797b
3 changed files with 15 additions and 9 deletions

View File

@ -151,7 +151,7 @@ Key design choices:
| Feature | Detail |
|---|---|
| Optimizer | Muon for 2D weight matrices, AdamW for embeddings/norms |
| Optimizer | AdamW |
| Dataset | `HuggingFaceFW/fineweb-edu` (`sample-10BT` by default, swap to `sample-100BT` or `default` for full run) |
| Tokenizer | `openai/gpt-oss-20b` via `MythosTokenizer` |
| Parallelism | PyTorch DDP via `torchrun`, sharded streaming dataset |
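A minimal sketch of the AdamW-only setup the updated table documents, replacing the earlier Muon (2D weights) + AdamW (embeddings/norms) split. The model and hyperparameter values below are placeholders, not taken from the training script.

```python
import torch

# Placeholder model; the real script builds its own transformer.
model = torch.nn.Linear(1024, 1024)

# A single AdamW optimizer over all parameters, instead of routing
# weight matrices to Muon and the rest to AdamW.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,            # placeholder learning rate
    betas=(0.9, 0.95),  # placeholder betas
    weight_decay=0.1,   # placeholder weight decay
)
```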

View File

@ -15,6 +15,7 @@ import time
import torch
import torch.nn as nn
import torch.distributed as dist
from loguru import logger
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
@ -113,7 +114,7 @@ def main():
master = rank == 0
if master:
print(
logger.info(
f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
)
@ -124,7 +125,7 @@ def main():
vocab_size = encoding.vocab_size
if master:
print(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}")
logger.info(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}")
# ------------------------------------------------------------------
# Hyperparameters
@ -144,8 +145,8 @@ def main():
dataset_subset = "sample-10BT" # → sample-100BT or "default" for full run
if master:
print(
f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum}\n"
logger.info(
f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
)
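For context on the quantities in this log line, a small worked example of how a global token budget is usually derived under DDP. The values and the formula below are assumptions for illustration; the script's actual hyperparameters are defined above this hunk.

```python
# Placeholder values, not the script's real settings.
seq_len = 2048
micro_batch = 8
grad_accum = 4
world_size = 8

# Assumed convention: tokens consumed per optimizer step across all ranks.
global_batch_tok = seq_len * micro_batch * grad_accum * world_size
# 2048 * 8 * 4 * 8 = 524,288 tokens per step
```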
@ -188,7 +189,7 @@ def main():
if master:
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
logger.info(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
# ------------------------------------------------------------------
# Optimizer
@ -255,7 +256,7 @@ def main():
dt = time.perf_counter() - t0
tok_per_sec = global_batch_tok * log_every / dt
tokens_seen = step * global_batch_tok
print(
logger.info(
f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
f"| {tokens_seen / 1e9:.1f}B tokens seen"
@ -264,6 +265,7 @@ def main():
if master and step % ckpt_every == 0:
path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
logger.info(f"Saving checkpoint at step {step}{path}")
if ddp:
with FSDP.state_dict_type(
model,
@ -283,13 +285,13 @@ def main():
},
path,
)
print(f"Checkpoint saved → {path}")
logger.success(f"Checkpoint saved → {path}")
if ddp:
dist.destroy_process_group()
if master:
print("Training complete.")
logger.success("Training complete.")
if __name__ == "__main__":
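As a standalone illustration of the print → loguru swap and the new checkpoint events, here is a minimal sketch of the logging pattern the diff introduces: `logger.info` before the save and `logger.success` once the file is on disk. The helper name, path layout, and state-dict contents are placeholders, not the script's actual checkpointing code.

```python
import os
import torch
from loguru import logger

def save_checkpoint(model, step, ckpt_dir="checkpoints"):
    # Announce the save before the (potentially slow) torch.save call.
    path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
    logger.info(f"Saving checkpoint at step {step} → {path}")
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save({"step": step, "model": model.state_dict()}, path)
    # loguru's SUCCESS level (between INFO and WARNING) marks completion.
    logger.success(f"Checkpoint saved → {path}")
```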

View File

@ -0,0 +1,4 @@
torch>=2.11.0
datasets>=3.6.0
loguru>=0.7.3
open-mythos