[bugf][lora-b-init][fix zero-init B making adapter always output zero]
[bugf][lti-get-a][fix 0 times inf NaN in log space computation]
[improvement][rope-theta-test][exclude degenerate dim0 from theta angle comparison]
Kye Gomez 2026-04-18 23:51:01 -04:00
parent 53f786afda
commit 806a8da1d6
2 changed files with 11 additions and 6 deletions

@@ -564,7 +564,7 @@ class LoRAAdapter(nn.Module):
        """
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)  # shared A: dim → rank
-        self.B = nn.Parameter(torch.zeros(rank, dim))  # shared B: rank → dim
+        self.B = nn.Parameter(torch.randn(rank, dim) * 0.02)  # shared B: rank → dim
        self.scale = nn.Embedding(max_loops, rank)  # per-loop element-wise scale

    def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
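
Why the zero init is a bug in this design rather than the usual LoRA convention: every gradient path into the per-loop scale embedding runs through B, so with B = 0 the adapter output is identically zero and scale receives no gradient at all. A minimal sketch of the failure mode; the forward body is truncated in the diff, so the ordering (down(x) * scale(loop_t)) @ B is an assumption:

    import torch
    import torch.nn as nn

    dim, rank, max_loops = 8, 2, 4
    down = nn.Linear(dim, rank, bias=False)
    B = nn.Parameter(torch.zeros(rank, dim))  # the old zero init
    scale = nn.Embedding(max_loops, rank)

    x = torch.randn(3, dim)
    # Assumed forward ordering: scale modulates the rank-space activations.
    out = (down(x) * scale(torch.tensor(1))) @ B
    out.sum().backward()

    print(out.abs().max())                # tensor(0.) -- output is always zero
    print(scale.weight.grad.abs().max())  # tensor(0.) -- scale gets no gradient
    print(B.grad.abs().max())             # > 0 -- only B can start to learn

With the small randn * 0.02 init, both B and scale receive gradient from the first step, at the cost of a small nonzero perturbation at initialization.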
@@ -678,9 +678,10 @@ class LTIInjection(nn.Module):
        1-D tensor of shape (dim,) with all values strictly in (0, 1),
        guaranteeing ρ(A) < 1 regardless of learned parameter values.
        """
-        A_c = -torch.exp(self.log_A)  # continuous diagonal, always < 0
-        dt = torch.exp(self.log_dt)  # always > 0
-        return torch.exp(dt * A_c)  # ZOH: values in (0, 1)
+        # Compute in log space to avoid 0 * inf = NaN when log_dt → -∞, log_A → +∞.
+        # dt * A_c = -exp(log_dt) * exp(log_A) = -exp(log_dt + log_A)
+        # Clamp keeps the product finite in float32 for any gradient step size.
+        return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20)))

    def forward(
        self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor

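A quick float32 check of the failure the new comment describes: once log_dt underflows and log_A overflows after a large gradient step, the old expression evaluates 0 * inf. Standalone sketch; the names mirror the patched method for illustration only:

    import torch

    log_A = torch.tensor(90.0)     # exp(90) overflows float32 -> inf
    log_dt = torch.tensor(-110.0)  # exp(-110) underflows float32 -> 0.

    # Old path: dt * A_c = 0 * -inf = nan, and exp(nan) is nan.
    A_c = -torch.exp(log_A)
    dt = torch.exp(log_dt)
    print(torch.exp(dt * A_c))     # tensor(nan)

    # New path: add the exponents first, clamp, then exponentiate once.
    print(torch.exp(-torch.exp((log_dt + log_A).clamp(-20, 20))))  # ~1.0, finite

The clamp bounds are comfortable: exp(±20) stays many orders of magnitude inside float32 range before the outer exp is applied.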
@@ -153,11 +153,15 @@ class TestRoPEExtended:
        assert torch.allclose(freqs.imag, expected.imag, atol=1e-6)

    def test_higher_theta_produces_smaller_angles(self):
-        """Larger theta → slower frequency decay → smaller rotation angle per step."""
+        """Larger theta → slower frequency decay → smaller rotation angle per step.
+
+        Index 0 (dim_i=0) is excluded: its frequency is 1/(theta^0)=1 for any theta,
+        so the comparison is not meaningful there.
+        """
        dim, max_len = 16, 8
        freqs_fast = precompute_rope_freqs(dim=dim, max_len=max_len, theta=100.0)
        freqs_slow = precompute_rope_freqs(dim=dim, max_len=max_len, theta=500000.0)
-        assert (freqs_fast[1].angle().abs() > freqs_slow[1].angle().abs()).all()
+        assert (freqs_fast[1, 1:].angle().abs() > freqs_slow[1, 1:].angle().abs()).all()

    def test_default_theta_matches_explicit(self):
        """Omitting theta must equal passing theta=500000.0."""