diff --git a/README.md b/README.md
index afc5517..4e8e2df 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,12 @@
 pip install open-mythos
 #uv pip install open-mythos
 ```
 
+To enable Flash Attention 2 in `GQAttention` (requires CUDA and build tools):
+
+```bash
+pip install "open-mythos[flash]"
+```
+
 ## Usage
 
 ```python
diff --git a/open_mythos/main.py b/open_mythos/main.py
index 10de093..98ca58c 100644
--- a/open_mythos/main.py
+++ b/open_mythos/main.py
@@ -5,6 +5,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+try:
+    from flash_attn import flash_attn_func
+
+    _HAS_FLASH_ATTN = True
+except ImportError:
+    _HAS_FLASH_ATTN = False
+
 @dataclass
 class MythosConfig:
@@ -169,12 +176,17 @@ def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
 class GQAttention(nn.Module):
     """
-    Grouped Query Attention (Ainslie et al., 2023).
+    Grouped Query Attention (Ainslie et al., 2023) with Flash Attention 2 (Dao, 2023).
 
     Uses fewer KV heads than Q heads (n_kv_heads < n_heads). Each KV head is
     shared across n_heads // n_kv_heads query heads, reducing the KV cache
     size by that factor while keeping full query expressiveness.
 
+    When flash-attn is installed, uses flash_attn_func, which handles GQA natively
+    (no KV head expansion needed) and avoids materializing the full attention
+    matrix. Inputs are cast to bfloat16 for flash_attn and restored to the original
+    dtype afterward. Falls back to manual scaled dot-product attention otherwise.
+
     RoPE is applied to both Q and K. K and V are stored in kv_cache after RoPE
     application so that cached values are already positionally encoded and do
     not need to be re-rotated on retrieval.
 
@@ -195,7 +207,7 @@ class GQAttention(nn.Module):
         self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=False)
-        self.attn_drop = nn.Dropout(cfg.dropout)
+        self.dropout_p = cfg.dropout
 
     def forward(
         self,
@@ -230,21 +242,35 @@
             v = torch.cat([kv_cache[cache_key]["v"], v], dim=1)
         kv_cache[cache_key] = {"k": k.detach(), "v": v.detach()}
 
-        # expand KV to match Q heads
-        k = k.repeat_interleave(self.groups, dim=2)
-        v = v.repeat_interleave(self.groups, dim=2)
+        if _HAS_FLASH_ATTN:
+            # flash_attn_func expects (B, T, H, head_dim); GQA is handled natively
+            # (n_kv_heads < n_heads is supported without repeat_interleave).
+            # causal=True when mask is present (full-sequence prefill/training);
+            # causal=False for single-token decode where T=1 and mask is None.
+            orig_dtype = q.dtype
+            q = q.to(torch.bfloat16)
+            k = k.to(torch.bfloat16)
+            v = v.to(torch.bfloat16)
+            dropout_p = self.dropout_p if self.training else 0.0
+            out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=(mask is not None))
+            out = out.to(orig_dtype).contiguous().view(B, T, -1)
+        else:
+            # Fallback: manual scaled dot-product with explicit KV head expansion.
+            k = k.repeat_interleave(self.groups, dim=2)
+            v = v.repeat_interleave(self.groups, dim=2)
+            q = q.transpose(1, 2)  # (B, H, T, head_dim)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            scale = self.head_dim**-0.5
+            attn = torch.matmul(q, k.transpose(-2, -1)) * scale
+            if mask is not None:
+                attn = attn + mask
+            attn = F.dropout(
+                F.softmax(attn, dim=-1), p=self.dropout_p, training=self.training
+            )
+            out = torch.matmul(attn, v)
+            out = out.transpose(1, 2).contiguous().view(B, T, -1)
 
-        q = q.transpose(1, 2)  # (B, H, T, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        scale = self.head_dim**-0.5
-        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
-        if mask is not None:
-            attn = attn + mask
-        attn = self.attn_drop(F.softmax(attn, dim=-1))
-        out = torch.matmul(attn, v)
-        out = out.transpose(1, 2).contiguous().view(B, T, -1)
-
         return self.wo(out)
diff --git a/pyproject.toml b/pyproject.toml
index 1d9f720..ef6a3b6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,13 @@
 torch = "2.11.0"
 transformers = ">=4.40.0"
 datasets = ">=2.18.0"
 
+[tool.poetry.extras]
+flash = ["flash-attn"]
+
+[tool.poetry.dependencies.flash-attn]
+version = ">=2.8.3"
+optional = true
+
 [tool.poetry.group.lint.dependencies]
 black = ">=23.1,<27.0"
 
diff --git a/requirements.txt b/requirements.txt
index 3b01619..580fdf8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,6 @@
 torch>=2.1.0
 transformers>=4.40.0
 datasets>=2.18.0
 pytest>=7.0.0
+
+# optional: enables Flash Attention 2 in GQAttention (requires CUDA + build tools)
+# flash-attn>=2.8.3
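For reference, the fallback branch in this patch is ordinary grouped-query attention with explicit KV-head expansion. The standalone sketch below traces that same computation on CPU; the batch size, sequence length, and head counts are illustrative placeholders, not values taken from `MythosConfig`.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only (hypothetical, not from MythosConfig).
B, T, n_heads, n_kv_heads, head_dim = 2, 8, 8, 2, 16
groups = n_heads // n_kv_heads  # each KV head serves `groups` query heads

q = torch.randn(B, T, n_heads, head_dim)
k = torch.randn(B, T, n_kv_heads, head_dim)
v = torch.randn(B, T, n_kv_heads, head_dim)

# Expand KV heads to match the query heads, as in the non-flash branch.
k = k.repeat_interleave(groups, dim=2)
v = v.repeat_interleave(groups, dim=2)

# Move to (B, H, T, head_dim) layout for the attention matmuls.
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

# Additive causal mask: -inf above the diagonal, 0 elsewhere.
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

attn = F.softmax(q @ k.transpose(-2, -1) * head_dim**-0.5 + mask, dim=-1)
out = (attn @ v).transpose(1, 2).contiguous().view(B, T, -1)
print(out.shape)  # torch.Size([2, 8, 128]) == (B, T, n_heads * head_dim)
```

With `n_kv_heads=2` against `n_heads=8`, the K/V projections and KV cache are 4x smaller than a full multi-head layout, which is the reduction factor described in the `GQAttention` docstring; the flash path reaches the same result without ever materializing the expanded K and V tensors.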