# Python Example
# Included in the static site for reading. Run locally from the source repository when needed.
"""
第26章:Transformer Block 简化实现
Chapter 26: Simplified Transformer Block
"""
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input, split
    into ``num_heads`` heads of width ``d_model // num_heads``, attended
    independently, then concatenated and passed through an output projection.

    Args:
        d_model: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model=64, num_heads=4):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature width
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional tensor broadcastable to the score tensor of shape
                ``(batch, num_heads, seq_len, seq_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``. A query whose keys
            are all masked receives zero attention weights (so its output is
            the output projection applied to a zero vector) instead of NaN.
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))

        # Scale by sqrt(d_k) so the logits do not grow with the head width.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is None:
            attn = torch.softmax(scores, dim=-1)
        else:
            blocked = mask == 0
            # Query rows with every key masked. Filling such a row with -inf
            # would make softmax return NaN (and poison the backward pass),
            # so those rows are left unfilled here and zeroed after softmax.
            dead_rows = blocked.all(dim=-1, keepdim=True)
            scores = scores.masked_fill(blocked & ~dead_rows, float("-inf"))
            attn = torch.softmax(scores, dim=-1)
            attn = attn.masked_fill(dead_rows, 0.0)

        out = torch.matmul(attn, V)
        # (batch, heads, seq, d_k) -> (batch, seq, d_model): concatenate heads.
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(out)
class TransformerBlock(nn.Module):
    """Simplified Transformer block: self-attention followed by a feed-forward net.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer's input (residual connection), and the sum is layer-normalized.

    Args:
        d_model: Size of the input and output feature dimension.
        num_heads: Number of attention heads passed to ``MultiHeadAttention``.
        d_ff: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, d_model=64, num_heads=4, d_ff=256, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``.

        Args:
            x: Input tensor; the attention layer unpacks it as
                ``(batch, seq_len, d_model)``.
            mask: Optional attention mask handed unchanged to the attention
                layer (positions where ``mask == 0`` are not attended to).
                Defaults to ``None``, i.e. unmasked attention.

        Returns:
            The transformed tensor produced by the second LayerNorm.
        """
        # Self-attention + residual + LayerNorm
        attn_out = self.attn(x, mask=mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Feed-forward + residual + LayerNorm
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
# Smoke test: feed one random batch through a block and print the
# input and output shapes for comparison.
batch_size, seq_len, d_model = 2, 10, 64
x = torch.randn(batch_size, seq_len, d_model)
block = TransformerBlock(d_model=d_model, num_heads=4)
out = block(x)
print(f"输入: {x.shape}")
print(f"输出: {out.shape}")
print("Transformer Block = Attention + FFN + Residual + LayerNorm")