
Transformers

Transformer architecture and self-attention mechanisms.

Definition

Transformers are neural architectures based on self-attention: each token attends to all others to compute contextual representations. By avoiding recurrence entirely, they make training parallelizable across sequence positions, which has allowed them to scale to very large datasets and models (BERT, GPT, etc.).

They underpin modern LLMs and have been extended to multimodal and vision models. Encoder-only (BERT) and decoder-only (GPT) variants are the most common today; the encoder-decoder layout is still used for sequence-to-sequence tasks.

The "Attention Is All You Need" paper (2017) introduced the transformer by removing the recurrent loop entirely and replacing it with scaled dot-product attention. This made training fully parallelizable, enabling models to be trained on far larger datasets than RNN-based predecessors. Positional encodings replace the implicit ordering of recurrence; residual connections and layer normalization stabilize gradient flow through many layers. These design choices, combined with the feed-forward sub-layer for per-position computation, form the fundamental building block that has scaled to hundreds of billions of parameters.

How it works

Self-attention mechanism

Attention: The input is projected into Query (Q), Key (K), and Value (V) matrices. Attention weights are computed as softmax(QK^T / sqrt(d_k)) and applied to V, so the full operation is Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V. Each token's output is a weighted combination of all tokens' values, capturing global context in a single step.
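
Written out in PyTorch, a single-head version of this computation is only a few lines. This is a minimal sketch; the full multi-head module appears under Code examples below.

# Single-head scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
import torch
import torch.nn.functional as F
import math

def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    d_k = q.size(-1)
    scores  = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # similarity of every query to every key
    weights = F.softmax(scores, dim=-1)                   # each row sums to 1
    return weights @ v                                    # weighted mix of all value vectors

q = k = v = torch.randn(10, 64)        # 10 tokens, d_model=64; in self-attention all three come from the same input
print(attention(q, k, v).shape)        # (10, 64): every token's output reflects global context in one step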

Multi-head attention

Multi-head attention: Multiple attention heads run in parallel, each learning different relational patterns (syntax, coreference, semantics). Their outputs are concatenated and projected, giving the model richer representational capacity than a single attention head.
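
PyTorch exposes this directly as nn.MultiheadAttention, which handles the per-head split, concatenation, and output projection internally (batch_first=True assumes a reasonably recent PyTorch release; a from-scratch equivalent is under Code examples).

# Multi-head self-attention using PyTorch's built-in layer
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(2, 10, 64)              # batch=2, seq_len=10, d_model=64
out, weights = mha(x, x, x)             # self-attention: query, key, and value are all x
print(out.shape)                        # (2, 10, 64): head outputs concatenated and projected back to d_model
print(weights.shape)                    # (2, 10, 10): attention weights, averaged over heads by default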

Encoder vs. decoder

Encoder-only (e.g. BERT): All tokens attend to all others (bidirectional). Best for understanding tasks. Decoder-only (e.g. GPT): Causal masking ensures each position only attends to past tokens, enabling autoregressive generation. Encoder-decoder: Used for tasks like translation where the input sequence is fully encoded before decoding the output.
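
The causal mask that separates decoder-style from encoder-style attention is just a lower-triangular matrix, as in this small illustration (the module under Code examples applies the same mask when causal=True).

# Causal mask for a 4-token sequence: True = position may be attended to
import torch

mask = torch.tril(torch.ones(4, 4)).bool()
print(mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
# Entries that are False are set to -inf before the softmax, so each position only
# attends to itself and earlier positions; encoder-only models simply omit the mask.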

When to use / When NOT to use

Scenario | Use transformers? | Notes
NLP classification, NER, QA | Yes | Encoder-only (BERT-style) is the default
Text generation, chat, code | Yes | Decoder-only (GPT-style) is the standard
Low-resource edge inference | With caution | Distilled or quantized variants recommended
Short sequences with clear locality | With caution | CNNs or RNNs may be more efficient
Sequence-to-sequence (translation) | Yes | Encoder-decoder transformers excel here
Vision tasks | Yes | Vision Transformer (ViT) patches work well

Comparisons

Aspect | RNN / LSTM | CNN | Transformer
Long-range dependencies | Moderate | Poor | Excellent
Parallelizable training | No | Yes | Yes
Context window | Limited by unrolling | Fixed receptive field | Configurable (up to 1M+ tokens)
Memory cost at inference | Low (fixed state) | Low | High (KV cache grows with context)
State-of-the-art NLP | No | No | Yes

Pros and cons

Pros | Cons
Parallelizable, scalable | High compute and memory
Strong at long-range dependencies | Requires large data
Unified architecture for many tasks | Interpretability challenges
Pretrained models widely available | Quadratic attention cost with sequence length

Code examples

# Self-attention from scratch with PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o   = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.W_qkv(x).split(C, dim=2)
        q, k, v = [t.view(B, T, self.num_heads, self.d_k).transpose(1, 2) for t in qkv]
        scale  = math.sqrt(self.d_k)
        scores = (q @ k.transpose(-2, -1)) / scale         # (B, heads, T, T)
        if causal:
            mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
            scores = scores.masked_fill(~mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)

# Test with a dummy batch
attn  = MultiHeadSelfAttention(d_model=64, num_heads=4)
x     = torch.randn(2, 10, 64)   # batch=2, seq_len=10, d_model=64
print(attn(x).shape)             # (2, 10, 64)
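
The same module can be run as a decoder-style (GPT-like) layer by enabling the causal flag defined above:

print(attn(x, causal=True).shape)   # (2, 10, 64), but each position only attends to itself and earlier positions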
