mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
299 lines
9.8 KiB
Python
299 lines
9.8 KiB
Python
#!/usr/bin/env python3
"""
OpenMythos pretraining on FineWeb-Edu with FSDP + AdamW.

Single GPU:
    python training/3b_fine_web_edu.py

Multi-GPU:
    torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") training/3b_fine_web_edu.py
"""
|
|
|
|
import os
|
|
import math
|
|
import time
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.distributed as dist
|
|
from loguru import logger
|
|
from torch.distributed.fsdp import (
|
|
FullyShardedDataParallel as FSDP,
|
|
ShardingStrategy,
|
|
MixedPrecision,
|
|
FullStateDictConfig,
|
|
StateDictType,
|
|
)
|
|
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
|
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
|
|
from contextlib import nullcontext
|
|
|
|
from datasets import load_dataset
|
|
|
|
from open_mythos import OpenMythos
|
|
from open_mythos.main import TransformerBlock, RecurrentBlock
|
|
from open_mythos.variants import mythos_3b
|
|
from open_mythos.tokenizer import MythosTokenizer
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dataset
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class FineWebEduDataset(IterableDataset):
    """Streams FineWeb-Edu and yields fixed-length (input, target) token pairs.

    The HuggingFace stream is sharded across (rank, dataloader-worker) pairs so
    each shard reads a disjoint slice of the corpus. Tokens accumulate in a
    rolling buffer and are emitted as contiguous windows of ``seq_len + 1``
    tokens, split into next-token-prediction pairs: x = chunk[:-1],
    y = chunk[1:].
    """

    def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
        # encoding: tokenizer exposing .encode(text) -> list[int]
        # subset: FineWeb-Edu config name (e.g. "sample-10BT")
        self.encoding = encoding
        self.seq_len = seq_len
        self.subset = subset
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Derive a unique shard index for this (rank, worker) pair so every
        # DataLoader worker on every rank streams disjoint data.
        worker = get_worker_info()
        num_workers = worker.num_workers if worker else 1
        worker_id = worker.id if worker else 0

        total_shards = self.world_size * num_workers
        shard_index = self.rank * num_workers + worker_id

        ds = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name=self.subset,
            split="train",
            streaming=True,
        ).shard(num_shards=total_shards, index=shard_index)

        buf: list[int] = []
        for sample in ds:
            buf.extend(self.encoding.encode(sample["text"]))
            while len(buf) >= self.seq_len + 1:
                chunk = buf[: self.seq_len + 1]
                # Trim the buffer in place instead of re-slicing: the original
                # `buf = buf[n:]` copied the (potentially large) remainder on
                # every emitted chunk.
                del buf[: self.seq_len + 1]
                yield (
                    torch.tensor(chunk[:-1], dtype=torch.long),
                    torch.tensor(chunk[1:], dtype=torch.long),
                )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LR schedule: linear warmup → cosine decay
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Learning-rate schedule: linear warmup to ``max_lr``, then cosine decay.

    Args:
        step: Current optimizer step (0-indexed).
        warmup: Number of linear warmup steps.
        total: Total number of scheduled training steps.
        max_lr: Peak learning rate, reached exactly at ``step == warmup``.
        min_lr: Floor learning rate, held for ``step >= total``.

    Returns:
        The learning rate to apply at ``step``.
    """
    # Warmup phase: ramp linearly from 0 up to max_lr.
    if step < warmup:
        return max_lr * step / warmup
    # Past the end of the schedule: hold at the floor.
    if step >= total:
        return min_lr
    # Cosine phase: progress runs 0 -> 1 across the decay window, so the
    # cosine factor falls smoothly from 1 down to 0.
    progress = (step - warmup) / (total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def main():
    """Pretrain OpenMythos-3B on streaming FineWeb-Edu.

    Runs single-device (optionally with autocast AMP) or multi-GPU via FSDP
    full sharding, depending on whether torchrun environment variables are
    present. Uses AdamW with a linear-warmup/cosine-decay LR schedule,
    gradient accumulation targeting ~256 sequences per global batch, periodic
    throughput logging, and full-state-dict checkpointing.
    """
    # ------------------------------------------------------------------
    # Distributed init
    # ------------------------------------------------------------------
    # torchrun exports RANK; a missing RANK means plain single-process mode.
    ddp = int(os.environ.get("RANK", -1)) != -1
    if ddp:
        dist.init_process_group("nccl")
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = f"cuda:{local_rank}"
        torch.cuda.set_device(device)
    else:
        rank = local_rank = 0
        world_size = 1
        device = "cuda" if torch.cuda.is_available() else "cpu"

    master = rank == 0  # rank 0 owns logging and checkpoint writes

    if master:
        logger.info(
            f"GPUs: {torch.cuda.device_count()} | World size: {world_size} | Device: {device}"
        )

    # ------------------------------------------------------------------
    # Tokenizer
    # ------------------------------------------------------------------
    encoding = MythosTokenizer()
    vocab_size = encoding.vocab_size

    if master:
        logger.info(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}")

    # ------------------------------------------------------------------
    # Hyperparameters
    # ------------------------------------------------------------------
    seq_len = 2048
    micro_batch = 4
    target_tokens = 30_000_000_000  # total training budget, in tokens
    # Keep the effective global batch near 256 sequences regardless of the
    # number of ranks by scaling gradient accumulation.
    grad_accum = max(1, 256 // (world_size * micro_batch))
    global_batch_tok = world_size * micro_batch * grad_accum * seq_len
    total_steps = target_tokens // global_batch_tok
    warmup_steps = 2000
    lr = 3e-4
    wd = 0.1
    log_every = 10
    ckpt_every = 1000
    ckpt_dir = "checkpoints"
    dataset_subset = "sample-10BT"  # → sample-100BT or "default" for full run

    if master:
        logger.info(
            f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum} | "
            f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
        )

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    cfg = mythos_3b()
    cfg.vocab_size = vocab_size
    cfg.max_seq_len = seq_len

    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    amp_dtype = torch.bfloat16 if bf16_ok else torch.float16

    model = OpenMythos(cfg)

    if ddp:
        mp_policy = MixedPrecision(
            param_dtype=amp_dtype,
            reduce_dtype=amp_dtype,
            buffer_dtype=amp_dtype,
        )
        # Shard at the transformer/recurrent block granularity.
        wrap_policy = ModuleWrapPolicy({TransformerBlock, RecurrentBlock})
        model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mp_policy,
            auto_wrap_policy=wrap_policy,
            device_id=local_rank,
        )
    else:
        model = model.to(device)
        amp_ctx = (
            torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
            if "cuda" in device
            else nullcontext()
        )

    # FSDP handles its own mixed precision; autocast is only for single-GPU.
    # (The ternary never evaluates `amp_ctx` when ddp is True, so the
    # conditionally-defined name above is safe.)
    amp_ctx = nullcontext() if ddp else amp_ctx  # type: ignore[possibly-undefined]

    if master:
        n_params = sum(p.numel() for p in model.parameters())
        logger.info(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")

    # ------------------------------------------------------------------
    # Optimizer
    # ------------------------------------------------------------------
    # FIX: the fused AdamW kernel is CUDA-only; an unconditional `fused=True`
    # raises on CPU-only runs.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd,
        betas=(0.9, 0.95),
        fused=torch.cuda.is_available(),
    )

    # ------------------------------------------------------------------
    # Dataset + DataLoader
    # ------------------------------------------------------------------
    dataset = FineWebEduDataset(encoding, seq_len, dataset_subset, rank, world_size)
    loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True)

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    if master:
        os.makedirs(ckpt_dir, exist_ok=True)

    model.train()
    data_iter = iter(loader)
    t0 = time.perf_counter()
    step = 0

    while step < total_steps:
        # Cosine schedule decays to 10% of the peak LR.
        cur_lr = get_lr(step, warmup_steps, total_steps, lr, lr * 0.1)
        for g in optimizer.param_groups:
            g["lr"] = cur_lr

        optimizer.zero_grad()
        loss_accum = 0.0

        for micro_step in range(grad_accum):
            try:
                x, y = next(data_iter)
            except StopIteration:
                # Streaming dataset exhausted — restart from the beginning.
                data_iter = iter(loader)
                x, y = next(data_iter)

            x = x.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)
            y = y.to(device if not ddp else f"cuda:{local_rank}", non_blocking=True)

            # Skip gradient reduce-scatter on all but the last micro-step.
            sync = (
                nullcontext()
                if (not ddp or micro_step == grad_accum - 1)
                else model.no_sync()
            )
            with sync, amp_ctx:
                logits = model(x)
                loss = nn.functional.cross_entropy(
                    logits.view(-1, vocab_size), y.view(-1)
                )
                loss = loss / grad_accum  # average across accumulation steps

            loss.backward()
            loss_accum += loss.item()

        # FIX: under FSDP gradients are sharded, so the global norm must be
        # computed collectively via FSDP.clip_grad_norm_. Calling
        # nn.utils.clip_grad_norm_ on an FSDP model clips against each rank's
        # local shard norm, which is incorrect.
        if ddp:
            model.clip_grad_norm_(max_norm=1.0)
        else:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        step += 1

        if master and step % log_every == 0:
            dt = time.perf_counter() - t0
            tok_per_sec = global_batch_tok * log_every / dt
            tokens_seen = step * global_batch_tok
            # NOTE(review): loss_accum is rank 0's local loss, not averaged
            # across ranks — acceptable as a progress signal.
            logger.info(
                f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
                f"| lr {cur_lr:.2e} | {tok_per_sec / 1e6:.2f}M tok/s "
                f"| {tokens_seen / 1e9:.1f}B tokens seen"
            )
            t0 = time.perf_counter()

        if step % ckpt_every == 0:
            # FIX: FULL_STATE_DICT gathering is a collective — every rank must
            # call state_dict(), even though only rank 0 materializes the
            # result (rank0_only=True). The original guarded this whole block
            # with `if master`, which deadlocks any multi-GPU run at the first
            # checkpoint. All ranks now gather; only rank 0 writes.
            path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
            if master:
                logger.info(f"Saving checkpoint at step {step} → {path}")
            if ddp:
                with FSDP.state_dict_type(
                    model,
                    StateDictType.FULL_STATE_DICT,
                    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
                ):
                    model_state = model.state_dict()
            else:
                model_state = model.state_dict()
            if master:
                torch.save(
                    {
                        "step": step,
                        "model": model_state,
                        # NOTE(review): under FSDP this saves only rank 0's
                        # shard of the optimizer state; a fully resumable
                        # checkpoint needs FSDP.optim_state_dict(model, optimizer).
                        "optimizer": optimizer.state_dict(),
                        "cfg": cfg,
                        "vocab_size": vocab_size,
                    },
                    path,
                )
                logger.success(f"Checkpoint saved → {path}")

    if ddp:
        dist.destroy_process_group()

    if master:
        logger.success("Training complete.")
|
|
|
|
|
|
# Script entry point: one process standalone, or one process per rank under torchrun.
if __name__ == "__main__":
    main()
|