mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
use gpt-oss tokenizer because it's a great tokenizer
This commit is contained in:
parent
5cfef742b5
commit
97bc414977
@ -15,6 +15,7 @@ from open_mythos.main import (
|
||||
apply_rope,
|
||||
loop_index_embedding,
|
||||
)
|
||||
from open_mythos.tokenizer import MythosTokenizer
|
||||
from open_mythos.variants import (
|
||||
mythos_1b,
|
||||
mythos_3b,
|
||||
@ -48,4 +49,7 @@ __all__ = [
|
||||
"mythos_100b",
|
||||
"mythos_500b",
|
||||
"mythos_1t",
|
||||
"load_tokenizer",
|
||||
"get_vocab_size",
|
||||
"MythosTokenizer",
|
||||
]
|
||||
|
||||
@ -1,9 +1,3 @@
|
||||
"""
|
||||
OpenMythos v1 — Recurrent-Depth Transformer
|
||||
Architecture: Prelude → [Looped Recurrent Block]×T → Coda
|
||||
MoE FFN (DeepSeek-style), GQA or MLA, RoPE, RMSNorm, KV cache, LTI-stable injection, ACT halting
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@ -633,7 +627,9 @@ class TransformerBlock(nn.Module):
|
||||
Returns:
|
||||
Output tensor of shape (B, T, dim)
|
||||
"""
|
||||
x = x + self.resid_drop(self.attn(self.attn_norm(x), freqs_cis, mask, kv_cache, cache_key))
|
||||
x = x + self.resid_drop(
|
||||
self.attn(self.attn_norm(x), freqs_cis, mask, kv_cache, cache_key)
|
||||
)
|
||||
x = x + self.resid_drop(self.ffn(self.ffn_norm(x)))
|
||||
return x
|
||||
|
||||
|
||||
64
open_mythos/tokenizer.py
Normal file
64
open_mythos/tokenizer.py
Normal file
@ -0,0 +1,64 @@
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
DEFAULT_MODEL_ID = "openai/gpt-oss-20b"
|
||||
|
||||
|
||||
class MythosTokenizer:
|
||||
"""
|
||||
HuggingFace tokenizer wrapper for OpenMythos.
|
||||
|
||||
Args:
|
||||
model_id (str): The HuggingFace model ID or path to use with AutoTokenizer.
|
||||
Defaults to "openai/gpt-oss-20b".
|
||||
|
||||
Attributes:
|
||||
tokenizer: An instance of HuggingFace's AutoTokenizer.
|
||||
|
||||
Example:
|
||||
>>> tok = MythosTokenizer()
|
||||
>>> ids = tok.encode("Hello world")
|
||||
>>> s = tok.decode(ids)
|
||||
"""
|
||||
|
||||
def __init__(self, model_id: str = DEFAULT_MODEL_ID):
|
||||
"""
|
||||
Initialize the MythosTokenizer.
|
||||
|
||||
Args:
|
||||
model_id (str): HuggingFace model identifier or path to tokenizer files.
|
||||
"""
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
"""
|
||||
Return the size of the tokenizer vocabulary.
|
||||
|
||||
Returns:
|
||||
int: The number of unique tokens in the tokenizer vocabulary.
|
||||
"""
|
||||
return self.tokenizer.vocab_size
|
||||
|
||||
def encode(self, text: str) -> list[int]:
|
||||
"""
|
||||
Encode input text into a list of token IDs.
|
||||
|
||||
Args:
|
||||
text (str): The input text string to tokenize.
|
||||
|
||||
Returns:
|
||||
list[int]: List of integer token IDs representing the input text.
|
||||
"""
|
||||
return self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
def decode(self, token_ids: list[int]) -> str:
|
||||
"""
|
||||
Decode a list of token IDs back into a text string.
|
||||
|
||||
Args:
|
||||
token_ids (list[int]): A list of integer token IDs to decode.
|
||||
|
||||
Returns:
|
||||
str: Decoded string representation of the token IDs.
|
||||
"""
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=True)
|
||||
302
train.py
Normal file
302
train.py
Normal file
@ -0,0 +1,302 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OpenMythos pretraining on FineWeb-Edu with Muon optimizer.
|
||||
|
||||
Single GPU:
|
||||
python train.py
|
||||
|
||||
Multi-GPU (auto-detects GPU count):
|
||||
torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
|
||||
from contextlib import nullcontext
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from open_mythos import OpenMythos
|
||||
from open_mythos.variants import mythos_3b
|
||||
from open_mythos.tokenizer import MythosTokenizer
|
||||
from torch.optim import Muon
|
||||
|
||||
|
||||
def build_optimizers(model: nn.Module, muon_lr: float, adamw_lr: float, wd: float):
|
||||
"""Muon for 2D weight matrices; AdamW for embeddings, norms, biases."""
|
||||
muon_params, adamw_params = [], []
|
||||
for name, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if (
|
||||
p.ndim >= 2
|
||||
and "embed" not in name
|
||||
and "norm" not in name
|
||||
and "scale" not in name
|
||||
):
|
||||
muon_params.append(p)
|
||||
else:
|
||||
adamw_params.append(p)
|
||||
muon = Muon(muon_params, lr=muon_lr)
|
||||
adamw = torch.optim.AdamW(
|
||||
adamw_params, lr=adamw_lr, weight_decay=wd, betas=(0.9, 0.95), fused=True
|
||||
)
|
||||
return muon, adamw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FineWebEduDataset(IterableDataset):
|
||||
def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
|
||||
self.encoding = encoding
|
||||
self.seq_len = seq_len
|
||||
self.subset = subset
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
|
||||
def __iter__(self):
|
||||
worker = get_worker_info()
|
||||
num_workers = worker.num_workers if worker else 1
|
||||
worker_id = worker.id if worker else 0
|
||||
|
||||
# shard first by DDP rank, then by dataloader worker
|
||||
total_shards = self.world_size * num_workers
|
||||
shard_index = self.rank * num_workers + worker_id
|
||||
|
||||
ds = load_dataset(
|
||||
"HuggingFaceFW/fineweb-edu",
|
||||
name=self.subset,
|
||||
split="train",
|
||||
streaming=True,
|
||||
).shard(num_shards=total_shards, index=shard_index)
|
||||
|
||||
buf = []
|
||||
for sample in ds:
|
||||
buf.extend(self.encoding.encode(sample["text"]))
|
||||
while len(buf) >= self.seq_len + 1:
|
||||
chunk = buf[: self.seq_len + 1]
|
||||
buf = buf[self.seq_len + 1 :]
|
||||
yield (
|
||||
torch.tensor(chunk[:-1], dtype=torch.long),
|
||||
torch.tensor(chunk[1:], dtype=torch.long),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LR schedule: linear warmup → cosine decay
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
|
||||
if step < warmup:
|
||||
return max_lr * step / warmup
|
||||
if step >= total:
|
||||
return min_lr
|
||||
decay = (step - warmup) / (total - warmup)
|
||||
return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
# ------------------------------------------------------------------
|
||||
# Distributed init — works for single GPU (python train.py)
|
||||
# and multi-GPU (torchrun --nproc_per_node=N train.py)
|
||||
# ------------------------------------------------------------------
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1
|
||||
if ddp:
|
||||
dist.init_process_group("nccl")
|
||||
rank = int(os.environ["RANK"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
device = f"cuda:{local_rank}"
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
rank = local_rank = 0
|
||||
world_size = 1
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
master = rank == 0
|
||||
|
||||
if master:
|
||||
n_gpu = torch.cuda.device_count()
|
||||
print(
|
||||
f"GPUs detected: {n_gpu} | World size: {world_size} | Device: {device}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
# ------------------------------------------------------------------
|
||||
encoding = MythosTokenizer()
|
||||
vocab_size = encoding.vocab_size
|
||||
|
||||
if master:
|
||||
print(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hyperparameters
|
||||
# ------------------------------------------------------------------
|
||||
seq_len = 2048
|
||||
micro_batch = 4 # sequences per GPU per grad-accum step
|
||||
target_tokens = 30_000_000_000 # 30B token run
|
||||
grad_accum = max(1, 256 // (world_size * micro_batch))
|
||||
global_batch_tok = world_size * micro_batch * grad_accum * seq_len
|
||||
total_steps = target_tokens // global_batch_tok
|
||||
warmup_steps = 2000
|
||||
muon_lr = 0.02
|
||||
adamw_lr = 3e-4
|
||||
wd = 0.1
|
||||
log_every = 10
|
||||
ckpt_every = 1000
|
||||
ckpt_dir = "checkpoints"
|
||||
dataset_subset = "sample-10BT" # → sample-100BT or "default" for full run
|
||||
|
||||
if master:
|
||||
print(
|
||||
f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum}\n"
|
||||
f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Model — override vocab_size to match tokenizer
|
||||
# ------------------------------------------------------------------
|
||||
cfg = mythos_3b()
|
||||
cfg.vocab_size = vocab_size
|
||||
cfg.max_seq_len = seq_len
|
||||
|
||||
model = OpenMythos(cfg).to(device)
|
||||
|
||||
# Mixed precision
|
||||
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
||||
amp_dtype = torch.bfloat16 if bf16_supported else torch.float16
|
||||
amp_ctx = (
|
||||
torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
|
||||
if "cuda" in device
|
||||
else nullcontext()
|
||||
)
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
|
||||
|
||||
if master:
|
||||
n_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
|
||||
|
||||
if ddp:
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
|
||||
raw_model = model.module if ddp else model
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Optimizers
|
||||
# ------------------------------------------------------------------
|
||||
muon, adamw = build_optimizers(raw_model, muon_lr, adamw_lr, wd)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Dataset + DataLoader
|
||||
# ------------------------------------------------------------------
|
||||
dataset = FineWebEduDataset(encoding, seq_len, dataset_subset, rank, world_size)
|
||||
loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Training loop
|
||||
# ------------------------------------------------------------------
|
||||
if master:
|
||||
os.makedirs(ckpt_dir, exist_ok=True)
|
||||
|
||||
model.train()
|
||||
data_iter = iter(loader)
|
||||
t0 = time.perf_counter()
|
||||
step = 0
|
||||
|
||||
while step < total_steps:
|
||||
cur_muon_lr = get_lr(step, warmup_steps, total_steps, muon_lr, muon_lr * 0.1)
|
||||
cur_adamw_lr = get_lr(step, warmup_steps, total_steps, adamw_lr, adamw_lr * 0.1)
|
||||
for g in muon.param_groups:
|
||||
g["lr"] = cur_muon_lr
|
||||
for g in adamw.param_groups:
|
||||
g["lr"] = cur_adamw_lr
|
||||
|
||||
muon.zero_grad()
|
||||
adamw.zero_grad()
|
||||
loss_accum = 0.0
|
||||
|
||||
for micro_step in range(grad_accum):
|
||||
try:
|
||||
x, y = next(data_iter)
|
||||
except StopIteration:
|
||||
data_iter = iter(loader)
|
||||
x, y = next(data_iter)
|
||||
|
||||
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
|
||||
|
||||
# Defer DDP gradient sync until the last micro-step
|
||||
sync = (
|
||||
nullcontext()
|
||||
if (not ddp or micro_step == grad_accum - 1)
|
||||
else model.no_sync()
|
||||
)
|
||||
with sync, amp_ctx:
|
||||
logits = model(x)
|
||||
loss = nn.functional.cross_entropy(
|
||||
logits.view(-1, vocab_size), y.view(-1)
|
||||
)
|
||||
loss = loss / grad_accum
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
loss_accum += loss.item()
|
||||
|
||||
# Unscale, clip, step both optimizers
|
||||
scaler.unscale_(muon)
|
||||
scaler.unscale_(adamw)
|
||||
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
scaler.step(muon)
|
||||
scaler.step(adamw)
|
||||
scaler.update()
|
||||
|
||||
step += 1
|
||||
|
||||
if master and step % log_every == 0:
|
||||
dt = time.perf_counter() - t0
|
||||
tok_per_sec = global_batch_tok * log_every / dt
|
||||
tokens_seen = step * global_batch_tok
|
||||
print(
|
||||
f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
|
||||
f"| muon_lr {cur_muon_lr:.2e} | adamw_lr {cur_adamw_lr:.2e} "
|
||||
f"| {tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.1f}B tokens seen"
|
||||
)
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if master and step % ckpt_every == 0:
|
||||
path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
|
||||
torch.save(
|
||||
{
|
||||
"step": step,
|
||||
"model": raw_model.state_dict(),
|
||||
"muon": muon.state_dict(),
|
||||
"adamw": adamw.state_dict(),
|
||||
"cfg": cfg,
|
||||
"vocab_size": vocab_size,
|
||||
},
|
||||
path,
|
||||
)
|
||||
print(f"Checkpoint saved → {path}")
|
||||
|
||||
if ddp:
|
||||
dist.destroy_process_group()
|
||||
|
||||
if master:
|
||||
print("Training complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
x
Reference in New Issue
Block a user