Simplified HashFormer
A basic HashFormer implementation in PyTorch. The code shows how locality-sensitive hashing (LSH) can restrict attention to tokens whose hashes collide, approximating full attention at reduced computational cost.
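Before the full model, here is a minimal sketch of the core idea: random signed projections send nearby vectors to the same bucket with high probability, while distant vectors usually land apart. The sketch uses Reformer-style bucketing (argmax over concatenated positive and negative projections), the same scheme assumed in the implementation below; the helper name lsh_bucket is illustrative only.

import torch

def lsh_bucket(x: torch.Tensor, n_buckets: int) -> torch.Tensor:
    """Hash vectors of shape [N, d] into integer bucket ids in [0, n_buckets)."""
    proj = torch.randn(x.shape[-1], n_buckets // 2)   # one random projection matrix
    h = x @ proj                                      # [N, n_buckets // 2]
    return torch.cat([h, -h], dim=-1).argmax(dim=-1)  # [N]

v = torch.randn(8)
near, far = v + 0.01 * torch.randn(8), torch.randn(8)
print(lsh_bucket(torch.stack([v, near, far]), n_buckets=16))
# v and near almost always share a bucket; far usually does not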
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple
import math
class LSHAttention(nn.Module):
"""Locality-Sensitive Hashing Attention for HashFormer"""
def __init__(self, d_model: int, n_heads: int, n_buckets: int = 64,
n_rounds: int = 4, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.n_buckets = n_buckets
self.n_rounds = n_rounds
self.d_k = d_model // n_heads
# Linear projections for Q, K, V
self.w_q = nn.Linear(d_model, d_model, bias=False)
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
self.w_o = nn.Linear(d_model, d_model)
# Random projection matrices for LSH
self.register_buffer('random_projections',
torch.randn(n_rounds, self.d_k, n_buckets // 2))
self.dropout = nn.Dropout(dropout)
self.scale = 1.0 / math.sqrt(self.d_k)
    def hash_vectors(self, vecs: torch.Tensor) -> torch.Tensor:
        """Apply LSH so that similar vectors land in the same bucket in [0, n_buckets)"""
        batch_size, seq_len, n_heads, d_k = vecs.shape
        # Reshape for hashing: [batch_size * seq_len * n_heads, d_k]
        vecs_flat = vecs.reshape(-1, d_k)
        # Apply random projections: [B*L*H, n_rounds, n_buckets // 2]
        projections = torch.einsum('nd,rdb->nrb', vecs_flat, self.random_projections)
        # Reformer-style bucketing: the argmax over [proj, -proj] gives an index
        # in [0, n_buckets), so the number of buckets actually matches n_buckets
        bucket_indices = torch.cat([projections, -projections], dim=-1).argmax(dim=-1)
        return bucket_indices.reshape(batch_size, seq_len, n_heads, self.n_rounds)
    def create_attention_mask(self, q_buckets: torch.Tensor,
                              k_buckets: torch.Tensor) -> torch.Tensor:
        """Create an attention mask that keeps query/key pairs whose hashes collide"""
        batch_size, seq_len, n_heads, n_rounds = q_buckets.shape
        # Create a mask for each hashing round
        masks = []
        for r in range(n_rounds):
            # Bucket indices for this round
            q = q_buckets[:, :, :, r]  # [B, L, H]
            k = k_buckets[:, :, :, r]  # [B, L, H]
            # Pairwise comparison: queries attend to keys in the same bucket
            mask = (q.unsqueeze(2) == k.unsqueeze(1)).float()  # [B, L, L, H]
            masks.append(mask)
        # Combine masks from all rounds (union)
        combined_mask = torch.stack(masks, dim=-1).max(dim=-1)[0]
        combined_mask = combined_mask.permute(0, 3, 1, 2)  # [B, H, L, L]
        # Always let a token attend to itself so no row of scores is fully masked
        eye = torch.eye(seq_len, device=q_buckets.device).view(1, 1, seq_len, seq_len)
        return torch.clamp(combined_mask + eye, max=1.0)
def forward(self, x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, seq_len, d_model = x.shape
# Linear projections
Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k)
V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k)
# Apply LSH to queries and keys
q_buckets = self.hash_vectors(Q)
k_buckets = self.hash_vectors(K)
        # Create hash-based attention mask from query and key buckets
        hash_mask = self.create_attention_mask(q_buckets, k_buckets)
# Reshape for attention computation
Q = Q.permute(0, 2, 1, 3) # [B, H, L, d_k]
K = K.permute(0, 2, 1, 3) # [B, H, L, d_k]
V = V.permute(0, 2, 1, 3) # [B, H, L, d_k]
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
# Apply hash mask (only attend within same buckets)
scores = scores.masked_fill(hash_mask == 0, float('-inf'))
        # Apply additional attention mask if provided (e.g. a [B, L] padding mask)
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # Broadcast a [B, L] padding mask over heads and query positions
                attention_mask = attention_mask[:, None, None, :]
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
# Softmax and dropout
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# Apply attention to values
attn_output = torch.matmul(attn_weights, V)
# Reshape and apply output projection
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
attn_output = attn_output.view(batch_size, seq_len, d_model)
output = self.w_o(attn_output)
return output
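# Shape note: for an input x of shape [batch, seq_len, d_model], LSHAttention
# returns a tensor of the same shape. The hash mask restricts every query to
# keys whose bucket matches in at least one of the n_rounds hashing rounds
# (plus itself), so most of the L x L score matrix is masked out.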
class HashFormerBlock(nn.Module):
"""Single HashFormer transformer block"""
def __init__(self, d_model: int, n_heads: int, d_ff: int,
n_buckets: int = 64, dropout: float = 0.1):
super().__init__()
self.attention = LSHAttention(d_model, n_heads, n_buckets, dropout=dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
# Self-attention with residual connection
attn_output = self.attention(self.norm1(x), attention_mask)
x = x + attn_output
# Feed-forward with residual connection
ffn_output = self.ffn(self.norm2(x))
x = x + ffn_output
return x
class HashFormer(nn.Module):
"""Complete HashFormer model"""
def __init__(self, vocab_size: int, d_model: int = 512, n_heads: int = 8,
n_layers: int = 6, d_ff: int = 2048, max_length: int = 512,
n_buckets: int = 64, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.max_length = max_length
# Embeddings
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_length, d_model)
# Transformer blocks
self.blocks = nn.ModuleList([
HashFormerBlock(d_model, n_heads, d_ff, n_buckets, dropout)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, seq_len = input_ids.shape
# Create position indices
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
# Embeddings
token_embeds = self.token_embedding(input_ids)
pos_embeds = self.position_embedding(positions)
x = self.dropout(token_embeds + pos_embeds)
# Apply transformer blocks
for block in self.blocks:
x = block(x, attention_mask)
return self.norm(x)
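# Note: HashFormer returns the final hidden states; for language modeling one
# would typically add a task head on top, e.g. nn.Linear(d_model, vocab_size).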
# Example usage
def main():
    # Model configuration
    model = HashFormer(
        vocab_size=50000,
        d_model=512,
        n_heads=8,
        n_layers=6,
        d_ff=2048,
        max_length=1024,  # must cover the sequence length used below
        n_buckets=32,     # fewer buckets = coarser approximation
        dropout=0.1
    )
    # Example data
    batch_size, seq_len = 4, 1024
    input_ids = torch.randint(0, 50000, (batch_size, seq_len))
    # Forward pass
    output = model(input_ids)
    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {output.shape}")
    # Note: this simplified version still materializes the full score matrix;
    # the savings come from the sparse hash-based attention pattern, which a
    # bucket-sorted implementation exploits to avoid the full L x L computation.
    print(f"\nComplexity comparison for L = {seq_len}:")
    print(f"Standard attention:  O(L^2)     ~ {seq_len**2:,}")
    print(f"LSH attention:       O(L log L) ~ {seq_len * int(np.log2(seq_len)):,}")
    print(f"Theoretical speedup: ~{(seq_len**2) // (seq_len * int(np.log2(seq_len)))}x")
if __name__ == "__main__":
main()
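To see how aggressive the approximation is in practice, it helps to measure how many query/key pairs the hash mask actually keeps. The snippet below is a minimal sketch that reuses the LSHAttention module defined above; the helper name mask_density and the chosen settings are illustrative, not part of HashFormer itself.

def mask_density(d_model: int = 512, n_heads: int = 8, n_buckets: int = 32,
                 seq_len: int = 1024) -> float:
    """Fraction of query/key pairs kept by the hash mask (1.0 = full attention)."""
    attn = LSHAttention(d_model, n_heads, n_buckets)
    x = torch.randn(1, seq_len, d_model)
    with torch.no_grad():
        q = attn.w_q(x).view(1, seq_len, n_heads, attn.d_k)
        k = attn.w_k(x).view(1, seq_len, n_heads, attn.d_k)
        mask = attn.create_attention_mask(attn.hash_vectors(q), attn.hash_vectors(k))
    return mask.mean().item()

print(f"Kept query/key pairs: {mask_density():.1%}")  # typically a small fraction of 100%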