mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 09:33:27 +02:00
flash attn
This commit is contained in:
parent
eae0f04b8e
commit
7d78ebec79
@ -40,6 +40,12 @@ pip install open-mythos
|
||||
#uv pip install open-mythos
|
||||
```
|
||||
|
||||
To enable Flash Attention 2 in `GQAttention` (requires CUDA and build tools):
|
||||
|
||||
```bash
|
||||
pip install open-mythos[flash]
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
|
||||
@ -5,6 +5,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
_HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
_HAS_FLASH_ATTN = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MythosConfig:
|
||||
@ -169,12 +176,17 @@ def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
class GQAttention(nn.Module):
|
||||
"""
|
||||
Grouped Query Attention (Ainslie et al., 2023).
|
||||
Grouped Query Attention (Ainslie et al., 2023) with Flash Attention 2 (Dao et al., 2023).
|
||||
|
||||
Uses fewer KV heads than Q heads (n_kv_heads < n_heads). Each KV head is
|
||||
shared across n_heads // n_kv_heads query heads, reducing the KV cache size
|
||||
by that factor while keeping full query expressiveness.
|
||||
|
||||
When flash-attn is installed, uses flash_attn_func which handles GQA natively
|
||||
(no KV head expansion needed) and is IO-bound-optimal. Inputs are cast to
|
||||
bfloat16 for flash_attn and restored to the original dtype afterward.
|
||||
Falls back to manual scaled dot-product attention when flash-attn is absent.
|
||||
|
||||
RoPE is applied to both Q and K. K and V are stored in kv_cache after
|
||||
RoPE application so that cached values are already positionally encoded and
|
||||
do not need to be re-rotated on retrieval.
|
||||
@ -195,7 +207,7 @@ class GQAttention(nn.Module):
|
||||
self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=False)
|
||||
self.attn_drop = nn.Dropout(cfg.dropout)
|
||||
self.dropout_p = cfg.dropout
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -230,21 +242,35 @@ class GQAttention(nn.Module):
|
||||
v = torch.cat([kv_cache[cache_key]["v"], v], dim=1)
|
||||
kv_cache[cache_key] = {"k": k.detach(), "v": v.detach()}
|
||||
|
||||
# expand KV to match Q heads
|
||||
k = k.repeat_interleave(self.groups, dim=2)
|
||||
v = v.repeat_interleave(self.groups, dim=2)
|
||||
if _HAS_FLASH_ATTN:
|
||||
# flash_attn_func expects (B, T, H, head_dim) — GQA is handled natively
|
||||
# (n_kv_heads < n_heads is supported without repeat_interleave).
|
||||
# causal=True when mask is present (full-sequence prefill/training);
|
||||
# causal=False for single-token decode where T=1 and mask is None.
|
||||
orig_dtype = q.dtype
|
||||
q = q.to(torch.bfloat16)
|
||||
k = k.to(torch.bfloat16)
|
||||
v = v.to(torch.bfloat16)
|
||||
dropout_p = self.dropout_p if self.training else 0.0
|
||||
out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=(mask is not None))
|
||||
out = out.to(orig_dtype).contiguous().view(B, T, -1)
|
||||
else:
|
||||
# Fallback: manual scaled dot-product with explicit KV head expansion.
|
||||
k = k.repeat_interleave(self.groups, dim=2)
|
||||
v = v.repeat_interleave(self.groups, dim=2)
|
||||
q = q.transpose(1, 2) # (B, H, T, head_dim)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
scale = self.head_dim**-0.5
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
|
||||
if mask is not None:
|
||||
attn = attn + mask
|
||||
attn = F.dropout(
|
||||
F.softmax(attn, dim=-1), p=self.dropout_p, training=self.training
|
||||
)
|
||||
out = torch.matmul(attn, v)
|
||||
out = out.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
|
||||
q = q.transpose(1, 2) # (B, H, T, head_dim)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
scale = self.head_dim**-0.5
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
|
||||
if mask is not None:
|
||||
attn = attn + mask
|
||||
attn = self.attn_drop(F.softmax(attn, dim=-1))
|
||||
out = torch.matmul(attn, v)
|
||||
out = out.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
return self.wo(out)
|
||||
|
||||
|
||||
|
||||
@ -42,6 +42,13 @@ torch = "2.11.0"
|
||||
transformers = ">=4.40.0"
|
||||
datasets = ">=2.18.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
flash = ["flash-attn"]
|
||||
|
||||
[tool.poetry.dependencies.flash-attn]
|
||||
version = ">=2.8.3"
|
||||
optional = true
|
||||
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
black = ">=23.1,<27.0"
|
||||
|
||||
@ -2,3 +2,6 @@ torch>=2.1.0
|
||||
transformers>=4.40.0
|
||||
datasets>=2.18.0
|
||||
pytest>=7.0.0
|
||||
|
||||
# optional — enables Flash Attention 2 in GQAttention (requires CUDA + build tools)
|
||||
# flash-attn>=2.8.3
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user