
Infrastructure

Hardware and systems for training and serving AI: GPUs, TPUs, clusters.

Definition

AI infrastructure encompasses the hardware, networking, and software systems required to train and deploy large machine learning models at scale. On the hardware side, this means GPUs (NVIDIA H100/A100, consumer RTX series), TPUs (Google's custom AI accelerators), and emerging inference-specific chips (AWS Inferentia, Groq LPU). On the software side, it includes distributed training frameworks, job schedulers, model serving stacks, and observability tooling.

The scale of infrastructure needed is driven primarily by LLMs and large vision models. Training a frontier model may require thousands of GPUs running for weeks, demanding careful attention to inter-node networking (NVLink, InfiniBand), storage I/O (parallel file systems like Lustre, cloud object storage with high-bandwidth connectors), and fault tolerance (automatic checkpointing, preemption handling). Serving these trained models efficiently requires different hardware optimizations: quantization, speculative decoding, and continuous batching reduce per-token costs and latency.
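As a concrete illustration of the checkpoint-and-resume pattern behind fault tolerance, the sketch below saves model and optimizer state at a fixed interval and resumes from the latest checkpoint if an earlier run was preempted. The toy model, checkpoint path, and interval are illustrative assumptions; production jobs write checkpoints to shared or parallel storage.

# Checkpoint/resume sketch for fault tolerance — model, path, and interval are illustrative
import os
import torch
from torch import nn

CKPT_PATH = "checkpoint.pt"          # real jobs write to shared or parallel storage

model = nn.Linear(32, 1)             # toy stand-in for a large model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
start_step = 0

# If an earlier run was preempted, pick up from its last checkpoint instead of step 0
if os.path.exists(CKPT_PATH):
    state = torch.load(CKPT_PATH)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    start_step = state["step"] + 1

for step in range(start_step, 1000):
    x, y = torch.randn(64, 32), torch.randn(64, 1)      # synthetic batch
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if step % 100 == 0:                                 # checkpoint interval is a tunable
        torch.save({"model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step}, CKPT_PATH)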

Frameworks such as PyTorch, JAX, and TensorFlow provide the programming model for expressing neural network computations; infrastructure provides the substrate. Cloud providers (AWS, GCP, Azure) offer managed AI infrastructure (SageMaker, Vertex AI, Azure ML) that handles cluster provisioning, job scheduling, and experiment tracking, while on-premises deployments use orchestrators like SLURM or Kubernetes with GPU device plugins.
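For the on-premises path, here is a hedged sketch of how a GPU is requested on a Kubernetes cluster running the NVIDIA device plugin, using the official kubernetes Python client; the pod name, image tag, and namespace are assumptions for illustration.

# Requesting a GPU pod on Kubernetes with the NVIDIA device plugin — illustrative sketch
from kubernetes import client, config

config.load_kube_config()                    # or config.load_incluster_config() inside a pod

pod = client.V1Pod(
    metadata=client.V1ObjectMeta(name="gpu-smoke-test"),
    spec=client.V1PodSpec(
        restart_policy="Never",
        containers=[
            client.V1Container(
                name="cuda",
                image="nvidia/cuda:12.4.1-base-ubuntu22.04",   # assumed image tag
                command=["nvidia-smi"],
                resources=client.V1ResourceRequirements(
                    limits={"nvidia.com/gpu": "1"}             # GPU exposed by the device plugin
                ),
            )
        ],
    ),
)

client.CoreV1Api().create_namespaced_pod(namespace="default", body=pod)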

How it works

Training pipeline
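A training pipeline shards the dataset across workers, replicates (or partitions) the model on each GPU, runs forward and backward passes locally, synchronizes gradients across devices, and periodically checkpoints state so failures do not lose progress; the DistributedDataParallel example under Code examples shows the full loop. Below is a minimal sketch of the gradient-synchronization step itself, i.e. the all-reduce that DDP performs automatically after backward(); it assumes a torch.distributed process group has already been initialized.

# What "synchronize gradients" means in practice: average each parameter's gradient
# across all ranks — DDP does this automatically; shown here only to expose the mechanism.
import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module, world_size: int) -> None:
    # Assumes dist.init_process_group(...) has been called on every rank.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)   # sum gradients from all ranks
            param.grad /= world_size                            # then take the mean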

Serving pipeline
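A serving pipeline loads trained weights (often quantized), groups concurrent requests into batches, and generates tokens incrementally while reusing cached attention keys and values so each new token only pays for its own computation. The toy decode loop below illustrates the KV-cache idea with a single attention head; the dimensions and random weights are made up for illustration.

# Toy single-head attention decode loop with a KV cache — shows why caching past
# keys/values avoids recomputing attention over the whole prefix at every step.
import torch

d_model = 64
w_q = torch.randn(d_model, d_model) / d_model ** 0.5
w_k = torch.randn(d_model, d_model) / d_model ** 0.5
w_v = torch.randn(d_model, d_model) / d_model ** 0.5

k_cache, v_cache = [], []          # grows by one entry per generated token

def decode_step(x_t: torch.Tensor) -> torch.Tensor:
    """x_t: (d_model,) embedding of the newest token; returns the attention output."""
    q = x_t @ w_q
    k_cache.append(x_t @ w_k)      # only the new token's K/V are computed...
    v_cache.append(x_t @ w_v)      # ...earlier positions are reused from the cache
    K = torch.stack(k_cache)       # (seq_len, d_model)
    V = torch.stack(v_cache)
    scores = torch.softmax(K @ q / d_model ** 0.5, dim=0)   # attention over cached positions
    return scores @ V

for _ in range(8):                 # generate 8 tokens
    out = decode_step(torch.randn(d_model))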

Key concepts

Data parallelism — replicate the model on every device, split the data across devices, and synchronize gradients each step.
Model parallelism — split the model's layers across devices; necessary when the model does not fit on a single GPU (see the sketch after this list).
Pipeline parallelism — split the model into stages across devices and overlap computation with communication.
Continuous batching — dynamically group concurrent inference requests to maximize GPU utilization.
KV cache — cache attention key/value tensors between tokens to avoid recomputation.
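To make model parallelism concrete, the sketch below splits a two-layer model across two GPUs and moves activations between them. The device ids and layer sizes are assumptions, and real systems layer pipeline scheduling on top of this basic split.

# Minimal model-parallel module: layers live on different GPUs and activations move
# between them — assumes at least two visible CUDA devices.
import torch
from torch import nn

class TwoStageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Linear(1024, 4096).to("cuda:0")   # first half of the model on GPU 0
        self.stage2 = nn.Linear(4096, 1024).to("cuda:1")   # second half on GPU 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.stage1(x.to("cuda:0")))
        return self.stage2(x.to("cuda:1"))                 # activations cross the device boundary

model = TwoStageModel()
out = model(torch.randn(8, 1024))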

When to use / When NOT to use

| Scenario | Invest in dedicated infrastructure | Use cloud / managed services |
|---|---|---|
| Training proprietary frontier models | Yes — cost and control at scale | |
| Regulated environments (data sovereignty) | Yes — on-prem guarantees data residency | |
| Occasional fine-tuning or inference | | Cloud spot instances or managed APIs are cheaper |
| Serving public-facing models at variable load | | Autoscaling cloud serving is easier to manage |
| Research with frequent GPU needs | | Cloud reserved instances or academic clusters suffice |

Pros and cons

| Pros | Cons |
|---|---|
| Full control over hardware, data, and security | High capital and operational cost for on-prem clusters |
| Predictable cost at high utilization | Requires expertise in distributed systems and MLOps |
| Lowest latency when co-located with services | Over-provisioning risk if workloads fluctuate |
| No egress costs or API rate limits | GPU supply constraints and long procurement lead times |

Code examples

# PyTorch DistributedDataParallel (DDP) training — minimal example
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def train(rank: int, world_size: int):
    os.environ.setdefault("MASTER_ADDR", "localhost")  # single-node rendezvous defaults
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Toy model, data, and loss for a runnable example; replace with your own
    model = nn.Linear(32, 1).to(rank)
    model = DDP(model, device_ids=[rank])          # wrap for distributed gradient sync
    dataset = TensorDataset(torch.randn(8192, 32), torch.randn(8192, 1))
    loss_fn = nn.MSELoss()

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for epoch in range(10):
        sampler.set_epoch(epoch)                   # ensure a different shuffle each epoch
        for x, y in loader:
            x, y = x.to(rank), y.to(rank)
            loss = loss_fn(model(x), y)
            loss.backward()                        # gradients are all-reduced across ranks here
            optimizer.step()
            optimizer.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size)
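For multi-node jobs, the same script is usually launched with torchrun (one process per GPU on each node); in that case the rank and world size come from environment variables that torchrun sets rather than from mp.spawn.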

Practical resources

See also