flash attn

Kye Gomez 2026-04-22 12:15:37 -04:00
parent eae0f04b8e
commit 7d78ebec79
4 changed files with 58 additions and 16 deletions

@@ -40,6 +40,12 @@ pip install open-mythos
#uv pip install open-mythos
```
+
+To enable Flash Attention 2 in `GQAttention` (requires CUDA and build tools):
+
+```bash
+pip install open-mythos[flash]
+```

## Usage

```python

@@ -5,6 +5,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
+
+try:
+    from flash_attn import flash_attn_func
+    _HAS_FLASH_ATTN = True
+except ImportError:
+    _HAS_FLASH_ATTN = False
+

@dataclass
class MythosConfig:
@@ -169,12 +176,17 @@ def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
class GQAttention(nn.Module):
    """
-    Grouped Query Attention (Ainslie et al., 2023).
+    Grouped Query Attention (Ainslie et al., 2023) with Flash Attention 2 (Dao, 2023).

    Uses fewer KV heads than Q heads (n_kv_heads < n_heads). Each KV head is
    shared across n_heads // n_kv_heads query heads, reducing the KV cache size
    by that factor while keeping full query expressiveness.

+    When flash-attn is installed, uses flash_attn_func, which handles GQA natively
+    (no KV head expansion needed) and is IO-aware. Inputs are cast to bfloat16
+    for flash_attn and restored to the original dtype afterward. Falls back to
+    manual scaled dot-product attention when flash-attn is absent.
+
    RoPE is applied to both Q and K. K and V are stored in kv_cache after
    RoPE application so that cached values are already positionally encoded and
    do not need to be re-rotated on retrieval.
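As a rough illustration of the cache saving the docstring describes, here is a short sketch; the head counts, sequence length, and dtype below are made-up examples, not the repo's MythosConfig defaults:

```python
# Hypothetical sizes for illustration only; not taken from MythosConfig.
n_heads, n_kv_heads, head_dim, seq_len, batch = 32, 8, 64, 4096, 1
groups = n_heads // n_kv_heads  # each KV head is shared by 4 query heads

bytes_per_elem = 2  # bf16
# K and V are both cached, one vector per token per head
mha_cache_bytes = 2 * batch * seq_len * n_heads * head_dim * bytes_per_elem
gqa_cache_bytes = 2 * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem

print(mha_cache_bytes / gqa_cache_bytes)  # 4.0 == n_heads // n_kv_heads
```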
@@ -195,7 +207,7 @@ class GQAttention(nn.Module):
        self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=False)
-        self.attn_drop = nn.Dropout(cfg.dropout)
+        self.dropout_p = cfg.dropout

    def forward(
        self,
@@ -230,21 +242,35 @@ class GQAttention(nn.Module):
            v = torch.cat([kv_cache[cache_key]["v"], v], dim=1)
        kv_cache[cache_key] = {"k": k.detach(), "v": v.detach()}

-        # expand KV to match Q heads
-        k = k.repeat_interleave(self.groups, dim=2)
-        v = v.repeat_interleave(self.groups, dim=2)
-        q = q.transpose(1, 2)  # (B, H, T, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-        scale = self.head_dim**-0.5
-        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
-        if mask is not None:
-            attn = attn + mask
-        attn = self.attn_drop(F.softmax(attn, dim=-1))
-        out = torch.matmul(attn, v)
-        out = out.transpose(1, 2).contiguous().view(B, T, -1)
+        if _HAS_FLASH_ATTN:
+            # flash_attn_func expects (B, T, H, head_dim) — GQA is handled natively
+            # (n_kv_heads < n_heads is supported without repeat_interleave).
+            # causal=True when mask is present (full-sequence prefill/training);
+            # causal=False for single-token decode where T=1 and mask is None.
+            orig_dtype = q.dtype
+            q = q.to(torch.bfloat16)
+            k = k.to(torch.bfloat16)
+            v = v.to(torch.bfloat16)
+            dropout_p = self.dropout_p if self.training else 0.0
+            out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=(mask is not None))
+            out = out.to(orig_dtype).contiguous().view(B, T, -1)
+        else:
+            # Fallback: manual scaled dot-product with explicit KV head expansion.
+            k = k.repeat_interleave(self.groups, dim=2)
+            v = v.repeat_interleave(self.groups, dim=2)
+            q = q.transpose(1, 2)  # (B, H, T, head_dim)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            scale = self.head_dim**-0.5
+            attn = torch.matmul(q, k.transpose(-2, -1)) * scale
+            if mask is not None:
+                attn = attn + mask
+            attn = F.dropout(
+                F.softmax(attn, dim=-1), p=self.dropout_p, training=self.training
+            )
+            out = torch.matmul(attn, v)
+            out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out)
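A standalone sanity check (my sketch, not part of the commit) comparing the two paths on random GQA-shaped tensors. It needs a CUDA device and the optional flash-attn package, and all shapes are arbitrary:

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # only available with the flash extra

B, T, n_heads, n_kv_heads, head_dim = 1, 16, 8, 2, 32
groups = n_heads // n_kv_heads
dev, dt = "cuda", torch.bfloat16

q = torch.randn(B, T, n_heads, head_dim, device=dev, dtype=dt)
k = torch.randn(B, T, n_kv_heads, head_dim, device=dev, dtype=dt)
v = torch.randn(B, T, n_kv_heads, head_dim, device=dev, dtype=dt)

# Flash path: GQA handled natively, causal mask applied inside the kernel.
out_flash = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# Fallback path: expand KV heads, transpose to (B, H, T, head_dim), masked softmax.
k_e = k.repeat_interleave(groups, dim=2).transpose(1, 2)
v_e = v.repeat_interleave(groups, dim=2).transpose(1, 2)
attn = (q.transpose(1, 2) @ k_e.transpose(-2, -1)) * head_dim**-0.5
causal_mask = torch.triu(
    torch.full((T, T), float("-inf"), device=dev, dtype=dt), diagonal=1
)
attn = F.softmax(attn + causal_mask, dim=-1)
out_manual = (attn @ v_e).transpose(1, 2)

print((out_flash - out_manual).abs().max())  # small bf16 rounding error if both agree
```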

@@ -42,6 +42,13 @@ torch = "2.11.0"
transformers = ">=4.40.0"
datasets = ">=2.18.0"
+
+[tool.poetry.extras]
+flash = ["flash-attn"]
+
+[tool.poetry.dependencies.flash-attn]
+version = ">=2.8.3"
+optional = true

[tool.poetry.group.lint.dependencies]
black = ">=23.1,<27.0"
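For reference, the extra declared above can be pulled in either through pip (as the README shows) or, when working from a clone managed by Poetry, via the extras flag:

```bash
# from PyPI (quotes keep the shell from expanding the brackets)
pip install "open-mythos[flash]"

# from a checkout managed by Poetry
poetry install --extras flash
```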

@@ -2,3 +2,6 @@ torch>=2.1.0
transformers>=4.40.0
datasets>=2.18.0
pytest>=7.0.0
+
+# optional — enables Flash Attention 2 in GQAttention (requires CUDA + build tools)
+# flash-attn>=2.8.3
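Because flash-attn compiles CUDA extensions against the torch that is already installed, its upstream instructions recommend installing it without build isolation. A typical sequence when working from this requirements file would be:

```bash
# ninja speeds up the CUDA extension build considerably
pip install ninja packaging
pip install -r requirements.txt
pip install "flash-attn>=2.8.3" --no-build-isolation
```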