mirror of
https://github.com/kyegomez/OpenMythos.git
synced 2026-05-02 17:43:27 +02:00
[bugf][lora-b-init][fix zero-init B making adapter always output zero][bugf][lti-get-a][fix 0 times
inf NaN in log space computation][improvement][rope-theta-test][exclude degenerate dim0 from theta angle comparison]
This commit is contained in:
parent
53f786afda
commit
806a8da1d6
@ -564,7 +564,7 @@ class LoRAAdapter(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.down = nn.Linear(dim, rank, bias=False) # shared A: dim → rank
|
||||
self.B = nn.Parameter(torch.zeros(rank, dim)) # shared B: rank → dim
|
||||
self.B = nn.Parameter(torch.randn(rank, dim) * 0.02) # shared B: rank → dim
|
||||
self.scale = nn.Embedding(max_loops, rank) # per-loop element-wise scale
|
||||
|
||||
def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
|
||||
@ -678,9 +678,10 @@ class LTIInjection(nn.Module):
|
||||
1-D tensor of shape (dim,) with all values strictly in (0, 1),
|
||||
guaranteeing ρ(A) < 1 regardless of learned parameter values.
|
||||
"""
|
||||
A_c = -torch.exp(self.log_A) # continuous diagonal, always < 0
|
||||
dt = torch.exp(self.log_dt) # always > 0
|
||||
return torch.exp(dt * A_c) # ZOH: values in (0, 1)
|
||||
# Compute in log space to avoid 0 * inf = NaN when log_dt → -∞, log_A → +∞.
|
||||
# dt * A_c = -exp(log_dt) * exp(log_A) = -exp(log_dt + log_A)
|
||||
# Clamp keeps the product finite in float32 for any gradient step size.
|
||||
return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20)))
|
||||
|
||||
def forward(
|
||||
self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor
|
||||
|
||||
@ -153,11 +153,15 @@ class TestRoPEExtended:
|
||||
assert torch.allclose(freqs.imag, expected.imag, atol=1e-6)
|
||||
|
||||
def test_higher_theta_produces_smaller_angles(self):
|
||||
"""Larger theta → slower frequency decay → smaller rotation angle per step."""
|
||||
"""Larger theta → slower frequency decay → smaller rotation angle per step.
|
||||
|
||||
Index 0 (dim_i=0) is excluded: its frequency is 1/(theta^0)=1 for any theta,
|
||||
so the comparison is not meaningful there.
|
||||
"""
|
||||
dim, max_len = 16, 8
|
||||
freqs_fast = precompute_rope_freqs(dim=dim, max_len=max_len, theta=100.0)
|
||||
freqs_slow = precompute_rope_freqs(dim=dim, max_len=max_len, theta=500000.0)
|
||||
assert (freqs_fast[1].angle().abs() > freqs_slow[1].angle().abs()).all()
|
||||
assert (freqs_fast[1, 1:].angle().abs() > freqs_slow[1, 1:].angle().abs()).all()
|
||||
|
||||
def test_default_theta_matches_explicit(self):
|
||||
"""Omitting theta must equal passing theta=500000.0."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user