Efficient Attention Mechanisms for Long-Context Transformers
Abstract
Standard self-attention in transformers scales quadratically with sequence length, limiting their application to long documents. We survey recent approaches to achieve linear or near-linear complexity while maintaining model quality.
Background
Standard Attention Complexity
Traditional self-attention:
Attention(Q, K, V) = softmax(QK^T / √d)V
Complexity: O(n²·d), where n is the sequence length and d is the per-head dimension
Memory: O(n²) for the attention score matrix
This becomes prohibitive for sequences longer than a few thousand tokens.
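For reference, a minimal PyTorch sketch of standard scaled dot-product attention (function name and shapes are illustrative); the (batch, n, n) scores tensor is exactly the quadratic cost being discussed:

import math
import torch
import torch.nn.functional as F

def standard_attention(q, k, v):
    # q, k, v: (batch, n, d). The scores tensor is (batch, n, n),
    # which is the source of the O(n^2) time and memory cost.
    d = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
    return torch.matmul(F.softmax(scores, dim=-1), v)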
Efficient Attention Methods
1. Sparse Attention
Approach: Limit attention to a subset of positions
Variants:
- Local attention: Fixed window around each position
- Strided attention: Every k-th position
- Global attention: A few special tokens attend to (and are attended by) all positions
Examples: Longformer, BigBird
Complexity: O(n·w·d), where w << n is the number of positions each token attends to
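As a rough illustration of the sparse patterns above, here is a sketch of how a combined local + global boolean mask could be built (the window size and the set of global token indices are illustrative assumptions, not Longformer's exact implementation):

import torch

def local_global_mask(seq_len, window, global_idx):
    # Boolean (seq_len, seq_len) mask; True means attention is allowed.
    idx = torch.arange(seq_len)
    mask = (idx[None, :] - idx[:, None]).abs() <= window  # local band
    mask[global_idx, :] = True   # global tokens attend to every position
    mask[:, global_idx] = True   # every position attends to global tokens
    return mask

# Example: 512 tokens, window of 64, token 0 (e.g., a [CLS] token) is global
mask = local_global_mask(512, 64, torch.tensor([0]))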
2. Low-Rank Approximation
Approach: Factor attention matrix into low-rank components
Method: Linformer
Attention ≈ softmax(Q (E K)^T / √d) (F V)
Where E and F are learned k×n projection matrices that compress the sequence-length dimension of K and V from n down to k.
Complexity: O(n·k·d), where k is the projection length (k << n)
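A minimal single-head sketch of this projection idea, under the assumption that the input length equals max_len; the modules proj_k and proj_v play the role of E and F above (names are illustrative, not Linformer's reference code):

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinformerAttention(nn.Module):
    def __init__(self, max_len, proj_len=256):
        super().__init__()
        # Learned projections that compress the sequence-length axis of K and V
        # from max_len down to proj_len (the k in the text above).
        self.proj_k = nn.Linear(max_len, proj_len, bias=False)
        self.proj_v = nn.Linear(max_len, proj_len, bias=False)

    def forward(self, q, k, v):
        # q, k, v: (batch, max_len, d); sequence length must equal max_len here.
        d = q.shape[-1]
        k_proj = self.proj_k(k.transpose(-2, -1)).transpose(-2, -1)  # (batch, proj_len, d)
        v_proj = self.proj_v(v.transpose(-2, -1)).transpose(-2, -1)  # (batch, proj_len, d)
        # Scores are (batch, max_len, proj_len) instead of (batch, max_len, max_len).
        scores = torch.matmul(q, k_proj.transpose(-2, -1)) / math.sqrt(d)
        return torch.matmul(F.softmax(scores, dim=-1), v_proj)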
3. Kernel Methods
Approach: Replace softmax with kernel function
Example: Performer
Attention(Q, K, V) ≈ D^(-1) φ(Q) (φ(K)^T V), where D = diag(φ(Q) (φ(K)^T 1)) is a row-wise normalizer
Where φ is a random feature map chosen so that φ(q)^T φ(k) approximates the softmax kernel
Complexity: O(n·r·d), where r is the number of random features (see the sketch after the benefits list)
Benefits:
- True linear complexity
- Maintains global attention pattern
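A minimal sketch of the kernel trick, using the simple elu(x) + 1 feature map from Katharopoulos et al.'s linear attention rather than Performer's random features; the point is that computing φ(K)^T V first keeps everything linear in n:

import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    # phi(x) = elu(x) + 1: a simple positive feature map (Performer instead
    # draws random features that approximate the softmax kernel).
    q_f = F.elu(q) + 1                                     # (batch, n, d)
    k_f = F.elu(k) + 1                                     # (batch, n, d)
    kv = torch.matmul(k_f.transpose(-2, -1), v)            # (batch, d, d) = phi(K)^T V
    z = torch.matmul(q_f, k_f.sum(dim=-2).unsqueeze(-1))   # (batch, n, 1) normalizer
    return torch.matmul(q_f, kv) / (z + eps)               # (batch, n, d)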
4. Recurrent Memory
Approach: Compress the past into a fixed-size memory that new segments attend to
Examples:
- Transformer-XL: Segment-level recurrence
- Compressive Transformer: Hierarchical memory
Trade-off: Some long-range information is lost, since older activations are compressed or discarded
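A minimal sketch of segment-level recurrence in the spirit of Transformer-XL: keys and values from the previous segment are cached (detached from the graph) and prepended to the current segment. Relative position encodings and multi-head details are omitted, and the function name is illustrative:

import math
import torch
import torch.nn.functional as F

def attention_with_memory(q, k, v, memory=None):
    # memory: (k_prev, v_prev) cached from the previous segment, or None.
    if memory is not None:
        k_prev, v_prev = memory
        k = torch.cat([k_prev, k], dim=1)   # attend over cached + current positions
        v = torch.cat([v_prev, v], dim=1)
    d = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
    out = torch.matmul(F.softmax(scores, dim=-1), v)
    # Cache this segment's keys/values (no gradient) for the next call.
    seg = q.shape[1]
    new_memory = (k[:, -seg:].detach(), v[:, -seg:].detach())
    return out, new_memory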
Experimental Results
Task Performance
| Model | WikiText-103 (perplexity ↓) | LRA average (accuracy % ↑) | Speed (tokens/sec ↑) |
|---|---|---|---|
| Standard | 18.3 | - | 1000 |
| Longformer | 18.5 | 53.2 | 3500 |
| Performer | 19.1 | 51.8 | 4200 |
| BigBird | 18.7 | 54.1 | 3200 |
Context Length Scaling
Efficient methods enable processing of:
- 16K tokens: Document-level understanding
- 64K tokens: Book chapters
- 1M+ tokens: Full books (e.g., Gemini 1.5)
Future Directions
Hybrid Approaches
Combining multiple techniques:
- Local + global attention (Longformer style)
- Sparse + low-rank approximation (e.g., Scatterbrain)
Hardware Optimization
- FlashAttention: IO-aware, memory-efficient exact attention (see the sketch below)
- Sparse matrix hardware support
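In PyTorch 2.x, the built-in torch.nn.functional.scaled_dot_product_attention dispatches to fused FlashAttention-style kernels when the device, dtypes, and shapes allow it (the example below assumes a CUDA device and half precision; on CPU it falls back to a standard implementation):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)  # (batch, heads, n, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Exact attention; with a fused kernel the full n x n score matrix is never materialized.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)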
Dynamic Sparsity
Learning which positions to attend to, rather than using a fixed pattern; a naive illustration follows below.
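As an illustration only (not any specific published method), the sketch below keeps the top-k scores per query and masks the rest, so the sparsity pattern depends on content. Note that this naive version still computes the full score matrix; practical approaches such as the Routing Transformer use clustering to avoid that cost.

import math
import torch
import torch.nn.functional as F

def topk_sparse_attention(q, k, v, top_k=64):
    # Content-dependent sparsity: each query keeps only its top_k highest scores.
    d = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
    top_k = min(top_k, scores.shape[-1])
    kth_best = scores.topk(top_k, dim=-1).values[..., -1:]  # k-th largest score per query
    scores = scores.masked_fill(scores < kth_best, float("-inf"))
    return torch.matmul(F.softmax(scores, dim=-1), v)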
Practical Considerations
When to use each:
- Sparse attention: When local context is sufficient
- Low-rank: When computational budget is tight
- Kernel methods: When linear O(n) scaling is needed
- Recurrent memory: When streaming or incremental processing is required
Code Example: Local Windowed Attention (PyTorch)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class EfficientAttention(nn.Module):
    def __init__(self, d_model, window_size=256):
        super().__init__()
        self.d_model = d_model
        self.window_size = window_size

    def _create_window_mask(self, seq_len, device):
        # True where position i may attend to position j, i.e. |i - j| <= window_size
        idx = torch.arange(seq_len, device=device)
        return (idx[None, :] - idx[:, None]).abs() <= self.window_size

    def forward(self, q, k, v, mask=None):
        # Local windowed attention; q, k, v: (batch, seq_len, d_model);
        # mask: optional boolean mask broadcastable to (seq_len, seq_len), True = keep
        batch, seq_len, d = q.shape
        window_mask = self._create_window_mask(seq_len, q.device)
        if mask is not None:
            window_mask = window_mask & mask
        # Standard scaled dot-product attention restricted to the local band
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
        scores = scores.masked_fill(~window_mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, v)
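A quick usage sketch with random tensors (shapes are illustrative; in a full model q, k, and v would come from learned projections):

import torch

attn = EfficientAttention(d_model=64, window_size=32)
q = torch.randn(2, 512, 64)   # (batch, seq_len, d_model)
k = torch.randn(2, 512, 64)
v = torch.randn(2, 512, 64)
out = attn(q, k, v)           # -> (2, 512, 64)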
Conclusion
Efficient attention mechanisms are crucial for scaling transformers to longer contexts. The field is rapidly evolving, with new methods achieving better quality-efficiency trade-offs. The choice of method depends on specific use case requirements.
References
- Beltagy et al., "Longformer: The Long-Document Transformer" (2020)
- Zaheer et al., "Big Bird: Transformers for Longer Sequences" (2020)
- Wang et al., "Linformer: Self-Attention with Linear Complexity" (2020)
- Choromanski et al., "Rethinking Attention with Performers" (2021)
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022)