From 97bc41497753d3c8a5d925f73c703a0948c99e4e Mon Sep 17 00:00:00 2001
From: Kye Gomez
Date: Sun, 19 Apr 2026 22:28:09 -0400
Subject: [PATCH] use gpt-oss tokenizer because it's a great tokenizer
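
Route all tokenization through a small AutoTokenizer wrapper
(MythosTokenizer) so the model trains on the gpt-oss vocabulary, and add
train.py, a FineWeb-Edu pretraining script that wires the wrapper in.

Illustrative round trip with the new wrapper (exact token IDs depend on
the gpt-oss vocabulary, so none are shown):

    from open_mythos import MythosTokenizer

    tok = MythosTokenizer()          # fetches the openai/gpt-oss-20b tokenizer
    ids = tok.encode("Hello world")  # list[int]; no special tokens added
    text = tok.decode(ids)           # round-trips for plain text like this
    print(tok.vocab_size)            # used by train.py to size the embedding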
+ """ + return self.tokenizer.decode(token_ids, skip_special_tokens=True) diff --git a/train.py b/train.py new file mode 100644 index 0000000..601a2a0 --- /dev/null +++ b/train.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +""" +OpenMythos pretraining on FineWeb-Edu with Muon optimizer. + +Single GPU: + python train.py + +Multi-GPU (auto-detects GPU count): + torchrun --nproc_per_node=$(python -c "import torch; print(torch.cuda.device_count())") train.py +""" + +import os +import math +import time +import torch +import torch.nn as nn +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import IterableDataset, DataLoader, get_worker_info +from contextlib import nullcontext + +from datasets import load_dataset + +from open_mythos import OpenMythos +from open_mythos.variants import mythos_3b +from open_mythos.tokenizer import MythosTokenizer +from torch.optim import Muon + + +def build_optimizers(model: nn.Module, muon_lr: float, adamw_lr: float, wd: float): + """Muon for 2D weight matrices; AdamW for embeddings, norms, biases.""" + muon_params, adamw_params = [], [] + for name, p in model.named_parameters(): + if not p.requires_grad: + continue + if ( + p.ndim >= 2 + and "embed" not in name + and "norm" not in name + and "scale" not in name + ): + muon_params.append(p) + else: + adamw_params.append(p) + muon = Muon(muon_params, lr=muon_lr) + adamw = torch.optim.AdamW( + adamw_params, lr=adamw_lr, weight_decay=wd, betas=(0.9, 0.95), fused=True + ) + return muon, adamw + + +# --------------------------------------------------------------------------- +# Dataset +# --------------------------------------------------------------------------- + + +class FineWebEduDataset(IterableDataset): + def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int): + self.encoding = encoding + self.seq_len = seq_len + self.subset = subset + self.rank = rank + self.world_size = world_size + + def __iter__(self): + worker = get_worker_info() + num_workers = worker.num_workers if worker else 1 + worker_id = worker.id if worker else 0 + + # shard first by DDP rank, then by dataloader worker + total_shards = self.world_size * num_workers + shard_index = self.rank * num_workers + worker_id + + ds = load_dataset( + "HuggingFaceFW/fineweb-edu", + name=self.subset, + split="train", + streaming=True, + ).shard(num_shards=total_shards, index=shard_index) + + buf = [] + for sample in ds: + buf.extend(self.encoding.encode(sample["text"])) + while len(buf) >= self.seq_len + 1: + chunk = buf[: self.seq_len + 1] + buf = buf[self.seq_len + 1 :] + yield ( + torch.tensor(chunk[:-1], dtype=torch.long), + torch.tensor(chunk[1:], dtype=torch.long), + ) + + +# --------------------------------------------------------------------------- +# LR schedule: linear warmup → cosine decay +# --------------------------------------------------------------------------- + + +def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float: + if step < warmup: + return max_lr * step / warmup + if step >= total: + return min_lr + decay = (step - warmup) / (total - warmup) + return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay)) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + # ------------------------------------------------------------------ + # Distributed init — works for single 
+
+
+# ---------------------------------------------------------------------------
+# Dataset
+# ---------------------------------------------------------------------------
+
+
+class FineWebEduDataset(IterableDataset):
+    def __init__(self, encoding, seq_len: int, subset: str, rank: int, world_size: int):
+        self.encoding = encoding
+        self.seq_len = seq_len
+        self.subset = subset
+        self.rank = rank
+        self.world_size = world_size
+
+    def __iter__(self):
+        worker = get_worker_info()
+        num_workers = worker.num_workers if worker else 1
+        worker_id = worker.id if worker else 0
+
+        # shard first by DDP rank, then by dataloader worker
+        total_shards = self.world_size * num_workers
+        shard_index = self.rank * num_workers + worker_id
+
+        ds = load_dataset(
+            "HuggingFaceFW/fineweb-edu",
+            name=self.subset,
+            split="train",
+            streaming=True,
+        ).shard(num_shards=total_shards, index=shard_index)
+
+        buf = []
+        for sample in ds:
+            buf.extend(self.encoding.encode(sample["text"]))
+            while len(buf) >= self.seq_len + 1:
+                chunk = buf[: self.seq_len + 1]
+                buf = buf[self.seq_len + 1 :]
+                yield (
+                    torch.tensor(chunk[:-1], dtype=torch.long),
+                    torch.tensor(chunk[1:], dtype=torch.long),
+                )
+
+
+# ---------------------------------------------------------------------------
+# LR schedule: linear warmup → cosine decay
+# ---------------------------------------------------------------------------
+
+
+def get_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
+    if step < warmup:
+        return max_lr * (step + 1) / warmup  # step + 1 so the first step is not lr=0
+    if step >= total:
+        return min_lr
+    decay = (step - warmup) / (total - warmup)
+    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * decay))
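+
+
+# Worked example of get_lr with warmup=2000 and the Muon settings used below
+# (max_lr=0.02, min_lr=0.002), assuming a hypothetical total=100_000:
+#   step 0      -> 0.00001  (linear ramp begins)
+#   step 1999   -> 0.02     (warmup complete, peak LR)
+#   step 51000  -> 0.011    (cosine midpoint: min_lr + 0.5 * (max_lr - min_lr))
+#   step 100000 -> 0.002    (decayed to the floor)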
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def main():
+    # ------------------------------------------------------------------
+    # Distributed init — works for single GPU (python train.py)
+    # and multi-GPU (torchrun --nproc_per_node=N train.py)
+    # ------------------------------------------------------------------
+    ddp = int(os.environ.get("RANK", -1)) != -1
+    if ddp:
+        dist.init_process_group("nccl")
+        rank = int(os.environ["RANK"])
+        local_rank = int(os.environ["LOCAL_RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])
+        device = f"cuda:{local_rank}"
+        torch.cuda.set_device(device)
+    else:
+        rank = local_rank = 0
+        world_size = 1
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    master = rank == 0
+
+    if master:
+        n_gpu = torch.cuda.device_count()
+        print(f"GPUs detected: {n_gpu} | World size: {world_size} | Device: {device}")
+
+    # ------------------------------------------------------------------
+    # Tokenizer
+    # ------------------------------------------------------------------
+    encoding = MythosTokenizer()
+    vocab_size = encoding.vocab_size
+
+    if master:
+        print(f"Tokenizer: gpt-oss-20b | Vocab size: {vocab_size:,}")
+
+    # ------------------------------------------------------------------
+    # Hyperparameters
+    # ------------------------------------------------------------------
+    seq_len = 2048
+    micro_batch = 4  # sequences per GPU per grad-accum step
+    target_tokens = 30_000_000_000  # 30B-token run
+    grad_accum = max(1, 256 // (world_size * micro_batch))  # e.g. 8 GPUs: 256 // 32 = 8
+    global_batch_tok = world_size * micro_batch * grad_accum * seq_len
+    total_steps = target_tokens // global_batch_tok
+    warmup_steps = 2000
+    muon_lr = 0.02
+    adamw_lr = 3e-4
+    wd = 0.1
+    log_every = 10
+    ckpt_every = 1000
+    ckpt_dir = "checkpoints"
+    dataset_subset = "sample-10BT"  # use "sample-100BT" or "default" for a full run
+
+    if master:
+        print(
+            f"seq_len={seq_len} | micro_batch={micro_batch} | grad_accum={grad_accum}\n"
+            f"global_batch_tokens={global_batch_tok:,} | total_steps={total_steps:,}"
+        )
+
+    # ------------------------------------------------------------------
+    # Model — override vocab_size to match tokenizer
+    # ------------------------------------------------------------------
+    cfg = mythos_3b()
+    cfg.vocab_size = vocab_size
+    cfg.max_seq_len = seq_len
+
+    model = OpenMythos(cfg).to(device)
+
+    # Mixed precision: bf16 where supported, else fp16 with gradient scaling
+    bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+    amp_dtype = torch.bfloat16 if bf16_supported else torch.float16
+    amp_ctx = (
+        torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
+        if "cuda" in device
+        else nullcontext()
+    )
+    scaler = torch.amp.GradScaler("cuda", enabled=(amp_dtype == torch.float16 and "cuda" in device))
+
+    if master:
+        n_params = sum(p.numel() for p in model.parameters())
+        print(f"Parameters: {n_params:,} | AMP dtype: {amp_dtype}")
+
+    if ddp:
+        model = DDP(model, device_ids=[local_rank])
+
+    raw_model = model.module if ddp else model
+
+    # ------------------------------------------------------------------
+    # Optimizers
+    # ------------------------------------------------------------------
+    muon, adamw = build_optimizers(raw_model, muon_lr, adamw_lr, wd)
+
+    # ------------------------------------------------------------------
+    # Dataset + DataLoader
+    # ------------------------------------------------------------------
+    dataset = FineWebEduDataset(encoding, seq_len, dataset_subset, rank, world_size)
+    loader = DataLoader(dataset, batch_size=micro_batch, num_workers=4, pin_memory=True)
+
+    # ------------------------------------------------------------------
+    # Training loop
+    # ------------------------------------------------------------------
+    if master:
+        os.makedirs(ckpt_dir, exist_ok=True)
+
+    model.train()
+    data_iter = iter(loader)
+    t0 = time.perf_counter()
+    step = 0
+
+    while step < total_steps:
+        cur_muon_lr = get_lr(step, warmup_steps, total_steps, muon_lr, muon_lr * 0.1)
+        cur_adamw_lr = get_lr(step, warmup_steps, total_steps, adamw_lr, adamw_lr * 0.1)
+        for g in muon.param_groups:
+            g["lr"] = cur_muon_lr
+        for g in adamw.param_groups:
+            g["lr"] = cur_adamw_lr
+
+        muon.zero_grad()
+        adamw.zero_grad()
+        loss_accum = 0.0
+
+        for micro_step in range(grad_accum):
+            try:
+                x, y = next(data_iter)
+            except StopIteration:
+                data_iter = iter(loader)
+                x, y = next(data_iter)
+
+            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
+
+            # Defer DDP gradient sync until the last micro-step; backward must
+            # run inside no_sync for the deferral to take effect
+            sync = (
+                nullcontext()
+                if (not ddp or micro_step == grad_accum - 1)
+                else model.no_sync()
+            )
+            with sync:
+                with amp_ctx:
+                    logits = model(x)
+                    loss = nn.functional.cross_entropy(
+                        logits.view(-1, vocab_size), y.view(-1)
+                    )
+                    loss = loss / grad_accum
+                scaler.scale(loss).backward()
+            loss_accum += loss.item()
+
+        # Unscale, clip, then step both optimizers
+        scaler.unscale_(muon)
+        scaler.unscale_(adamw)
+        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+        scaler.step(muon)
+        scaler.step(adamw)
+        scaler.update()
+
+        step += 1
+
+        if master and step % log_every == 0:
+            dt = time.perf_counter() - t0
+            tok_per_sec = global_batch_tok * log_every / dt
+            tokens_seen = step * global_batch_tok
+            print(
+                f"step {step:6d}/{total_steps} | loss {loss_accum:.4f} "
+                f"| muon_lr {cur_muon_lr:.2e} | adamw_lr {cur_adamw_lr:.2e} "
+                f"| {tok_per_sec / 1e6:.2f}M tok/s | {tokens_seen / 1e9:.1f}B tokens seen"
+            )
+            t0 = time.perf_counter()
+
+        if master and step % ckpt_every == 0:
+            path = os.path.join(ckpt_dir, f"step_{step:07d}.pt")
+            torch.save(
+                {
+                    "step": step,
+                    "model": raw_model.state_dict(),
+                    "muon": muon.state_dict(),
+                    "adamw": adamw.state_dict(),
+                    "cfg": cfg,
+                    "vocab_size": vocab_size,
+                },
+                path,
+            )
+            print(f"Checkpoint saved → {path}")
+
+    if ddp:
+        dist.destroy_process_group()
+
+    if master:
+        print("Training complete.")
+
+
+if __name__ == "__main__":
+    main()
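+
+
+# Resuming is not wired into main() above; a minimal manual sketch would be:
+#   ckpt = torch.load("checkpoints/step_0001000.pt", map_location="cpu")
+#   model = OpenMythos(ckpt["cfg"])
+#   model.load_state_dict(ckpt["model"])
+#   # then restore optimizer state from ckpt["muon"] / ckpt["adamw"]
+#   # and continue the loop from ckpt["step"]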