mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 09:33:27 +02:00
flash attn
This commit is contained in:
parent
eae0f04b8e
commit
7d78ebec79
@ -40,6 +40,12 @@ pip install open-mythos
|
|||||||
#uv pip install open-mythos
|
#uv pip install open-mythos
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To enable Flash Attention 2 in `GQAttention` (requires CUDA and build tools):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install "open-mythos[flash]"
|
||||||
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|||||||
@ -5,6 +5,13 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn import flash_attn_func
|
||||||
|
|
||||||
|
_HAS_FLASH_ATTN = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_FLASH_ATTN = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MythosConfig:
|
class MythosConfig:
|
||||||
@ -169,12 +176,17 @@ def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
class GQAttention(nn.Module):
|
class GQAttention(nn.Module):
|
||||||
"""
|
"""
|
||||||
Grouped Query Attention (Ainslie et al., 2023).
|
Grouped Query Attention (Ainslie et al., 2023) with optional Flash Attention 2 (Dao et al., 2023).
|
||||||
|
|
||||||
Uses fewer KV heads than Q heads (n_kv_heads < n_heads). Each KV head is
|
Uses fewer KV heads than Q heads (n_kv_heads < n_heads). Each KV head is
|
||||||
shared across n_heads // n_kv_heads query heads, reducing the KV cache size
|
shared across n_heads // n_kv_heads query heads, reducing the KV cache size
|
||||||
by that factor while keeping full query expressiveness.
|
by that factor while keeping full query expressiveness.
|
||||||
|
|
||||||
|
When flash-attn is installed, uses flash_attn_func which handles GQA natively
|
||||||
|
(no KV head expansion needed) and is IO-aware (memory-efficient). Inputs are cast to
|
||||||
|
bfloat16 for flash_attn and restored to the original dtype afterward.
|
||||||
|
Falls back to manual scaled dot-product attention when flash-attn is absent.
|
||||||
|
|
||||||
RoPE is applied to both Q and K. K and V are stored in kv_cache after
|
RoPE is applied to both Q and K. K and V are stored in kv_cache after
|
||||||
RoPE application so that cached values are already positionally encoded and
|
RoPE application so that cached values are already positionally encoded and
|
||||||
do not need to be re-rotated on retrieval.
|
do not need to be re-rotated on retrieval.
|
||||||
@ -195,7 +207,7 @@ class GQAttention(nn.Module):
|
|||||||
self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
|
self.wk = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
|
||||||
self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
|
self.wv = nn.Linear(cfg.dim, cfg.n_kv_heads * self.head_dim, bias=False)
|
||||||
self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=False)
|
self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.dim, bias=False)
|
||||||
self.attn_drop = nn.Dropout(cfg.dropout)
|
self.dropout_p = cfg.dropout
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -230,21 +242,35 @@ class GQAttention(nn.Module):
|
|||||||
v = torch.cat([kv_cache[cache_key]["v"], v], dim=1)
|
v = torch.cat([kv_cache[cache_key]["v"], v], dim=1)
|
||||||
kv_cache[cache_key] = {"k": k.detach(), "v": v.detach()}
|
kv_cache[cache_key] = {"k": k.detach(), "v": v.detach()}
|
||||||
|
|
||||||
# expand KV to match Q heads
|
if _HAS_FLASH_ATTN:
|
||||||
k = k.repeat_interleave(self.groups, dim=2)
|
# flash_attn_func expects (B, T, H, head_dim) — GQA is handled natively
|
||||||
v = v.repeat_interleave(self.groups, dim=2)
|
# (n_kv_heads < n_heads is supported without repeat_interleave).
|
||||||
|
# causal=True when mask is present (full-sequence prefill/training);
|
||||||
|
# causal=False for single-token decode where T=1 and mask is None.
|
||||||
|
orig_dtype = q.dtype
|
||||||
|
q = q.to(torch.bfloat16)
|
||||||
|
k = k.to(torch.bfloat16)
|
||||||
|
v = v.to(torch.bfloat16)
|
||||||
|
dropout_p = self.dropout_p if self.training else 0.0
|
||||||
|
out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=(mask is not None))
|
||||||
|
out = out.to(orig_dtype).contiguous().view(B, T, -1)
|
||||||
|
else:
|
||||||
|
# Fallback: manual scaled dot-product with explicit KV head expansion.
|
||||||
|
k = k.repeat_interleave(self.groups, dim=2)
|
||||||
|
v = v.repeat_interleave(self.groups, dim=2)
|
||||||
|
q = q.transpose(1, 2) # (B, H, T, head_dim)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
scale = self.head_dim**-0.5
|
||||||
|
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
|
||||||
|
if mask is not None:
|
||||||
|
attn = attn + mask
|
||||||
|
attn = F.dropout(
|
||||||
|
F.softmax(attn, dim=-1), p=self.dropout_p, training=self.training
|
||||||
|
)
|
||||||
|
out = torch.matmul(attn, v)
|
||||||
|
out = out.transpose(1, 2).contiguous().view(B, T, -1)
|
||||||
|
|
||||||
q = q.transpose(1, 2) # (B, H, T, head_dim)
|
|
||||||
k = k.transpose(1, 2)
|
|
||||||
v = v.transpose(1, 2)
|
|
||||||
|
|
||||||
scale = self.head_dim**-0.5
|
|
||||||
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
|
|
||||||
if mask is not None:
|
|
||||||
attn = attn + mask
|
|
||||||
attn = self.attn_drop(F.softmax(attn, dim=-1))
|
|
||||||
out = torch.matmul(attn, v)
|
|
||||||
out = out.transpose(1, 2).contiguous().view(B, T, -1)
|
|
||||||
return self.wo(out)
|
return self.wo(out)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -42,6 +42,13 @@ torch = "2.11.0"
|
|||||||
transformers = ">=4.40.0"
|
transformers = ">=4.40.0"
|
||||||
datasets = ">=2.18.0"
|
datasets = ">=2.18.0"
|
||||||
|
|
||||||
|
[tool.poetry.extras]
|
||||||
|
flash = ["flash-attn"]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies.flash-attn]
|
||||||
|
version = ">=2.8.3"
|
||||||
|
optional = true
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.lint.dependencies]
|
[tool.poetry.group.lint.dependencies]
|
||||||
black = ">=23.1,<27.0"
|
black = ">=23.1,<27.0"
|
||||||
|
|||||||
@ -2,3 +2,6 @@ torch>=2.1.0
|
|||||||
transformers>=4.40.0
|
transformers>=4.40.0
|
||||||
datasets>=2.18.0
|
datasets>=2.18.0
|
||||||
pytest>=7.0.0
|
pytest>=7.0.0
|
||||||
|
|
||||||
|
# optional — enables Flash Attention 2 in GQAttention (requires CUDA + build tools)
|
||||||
|
# flash-attn>=2.8.3
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user