Attention Layer & Backends
Attention is the hottest path in the model. vLLM separates the logical attention layer (what it computes) from the backend (how it computes it). The backend is chosen based on hardware and phase (prefill vs. decode).
References: FlashAttention | FlashAttention-2 | FlashInfer
Attention Basics
Multi-Head Attention with KV Cache
In standard (non-cached) attention, Q, K, and V are all computed from the full input sequence. In autoregressive decoding, only the newest token is projected: its K and V are written into the KV cache, and its Q attends over all cached K/V (every prior token plus itself).
```python
from torch import nn

# Attention layer in a vLLM model (simplified from e.g. LlamaAttention)
class LlamaAttention(nn.Module):
    def forward(self, positions, hidden_states, kv_cache, attn_metadata):
        # Project to Q, K, V
        q = self.q_proj(hidden_states)  # [num_tokens, num_heads * head_dim]
        k = self.k_proj(hidden_states)  # [num_tokens, num_kv_heads * head_dim]
        v = self.v_proj(hidden_states)  # [num_tokens, num_kv_heads * head_dim]
        # Apply rotary positional embeddings to Q and K (needs token positions)
        q, k = self.rotary_emb(positions, q, k)
        # Paged attention: writes new K/V into kv_cache blocks, reads all cached K/V
        output = self.attn(q, k, v, kv_cache, attn_metadata)
        return self.o_proj(output)  # [num_tokens, hidden_dim]
```
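To make the KV-cache idea concrete, here is a single-head NumPy sketch (not vLLM code: no paging, batching, or RoPE) showing that decoding one token at a time against a growing K/V cache reproduces full causal attention over the whole sequence.

```python
import numpy as np

def causal_attention(q, k, v):
    """Full causal attention over a whole sequence (single head).
    q, k, v: [seq_len, head_dim]."""
    seq_len, head_dim = q.shape
    scores = q @ k.T / np.sqrt(head_dim)                  # [seq_len, seq_len]
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)              # hide future tokens
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def decode_with_cache(q, k, v):
    """Token-by-token decoding: at each step only the new token's q/k/v are
    used; its k/v are appended to the cache and q attends over the cache."""
    head_dim = q.shape[1]
    k_cache, v_cache, outputs = [], [], []
    for t in range(q.shape[0]):
        k_cache.append(k[t])                              # write new K into cache
        v_cache.append(v[t])                              # write new V into cache
        K, V = np.stack(k_cache), np.stack(v_cache)       # [t + 1, head_dim]
        scores = K @ q[t] / np.sqrt(head_dim)             # [t + 1]
        w = np.exp(scores - scores.max())
        outputs.append((w / w.sum()) @ V)
    return np.stack(outputs)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 8)) for _ in range(3))
assert np.allclose(causal_attention(q, k, v), decode_with_cache(q, k, v))
```

The cached version does O(t) work per step instead of recomputing the whole sequence, which is exactly why the cache is worth its memory.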
Grouped Query Attention (GQA)
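Modern models (Llama 3, Gemma, Mistral) use GQA: fewer KV heads than Q heads. For example, Llama-3-8B has 32 Q heads but only 8 KV heads, so each KV head is shared by a group of 4 Q heads. This shrinks the KV cache by 4× relative to full multi-head attention (one KV head per Q head) with little loss in model quality.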
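The 4× figure follows directly from the per-token cache size. The sketch below assumes a Llama-3-8B shape of 32 layers and head_dim 128 with a 2-byte (fp16/bf16) cache, and uses the common convention (as in Hugging Face's repeat_kv) that consecutive groups of Q heads share one KV head; the helper names are hypothetical.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    # Factor 2 because both K and V are stored for every layer.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

def kv_head_for_q_head(q_head, num_q_heads, num_kv_heads):
    # Consecutive groups of Q heads share one KV head.
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

# Assumed Llama-3-8B shape: 32 layers, head_dim 128, 2-byte cache entries.
gqa = kv_cache_bytes_per_token(32, num_kv_heads=8, head_dim=128)
mha = kv_cache_bytes_per_token(32, num_kv_heads=32, head_dim=128)
print(gqa // 1024, "KiB/token with GQA vs", mha // 1024, "KiB/token with MHA")
# 128 KiB/token with GQA vs 512 KiB/token with MHA
print([kv_head_for_q_head(h, 32, 8) for h in range(8)])
# [0, 0, 0, 0, 1, 1, 1, 1]
```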
Attention Class — Logical Layer
Source: vllm/model_executor/layers/attention/attention.py
The Attention class is platform-agnostic. It delegates to a backend via the AttentionBackend interface:
```python
import torch
from torch import nn

class Attention(nn.Module, AttentionLayerBase):
    def __init__(self, num_heads, head_size, scale, ...):  # signature abbreviated
        super().__init__()
        # Pick a backend class for this platform / model config
        self.backend = get_attn_backend(
            num_heads, head_size, ..., vllm_config
        )
        # Instantiate the backend's kernel implementation
        impl_cls = self.backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, ...)

    def forward(
        self,
        query: torch.Tensor,     # [num_tokens, num_heads * head_dim]
        key: torch.Tensor,       # [num_tokens, num_kv_heads * head_dim]
        value: torch.Tensor,     # [num_tokens, num_kv_heads * head_dim]
        kv_cache: torch.Tensor,  # the physical GPU cache tensor
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # The backend impl writes the new K, V into kv_cache at the slots given
        # by attn_metadata, then computes attention over all cached positions.
        return self.impl.forward(query, key, value, kv_cache, attn_metadata)
```
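The split between "what" and "how" is a plain strategy pattern. Below is a minimal, self-contained sketch of that delegation; the class names echo the text, but EchoImpl and EchoBackend are hypothetical toys, and vLLM's real classes take many more constructor arguments.

```python
from abc import ABC, abstractmethod

class AttentionImpl(ABC):
    @abstractmethod
    def forward(self, query, key, value, kv_cache, attn_metadata): ...

class AttentionBackend(ABC):
    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type[AttentionImpl]: ...

class EchoImpl(AttentionImpl):
    """Toy 'kernel' that just returns the query, to show the call path."""
    def __init__(self, num_heads, head_size, scale):
        self.num_heads, self.head_size, self.scale = num_heads, head_size, scale

    def forward(self, query, key, value, kv_cache, attn_metadata):
        return query

class EchoBackend(AttentionBackend):
    @staticmethod
    def get_impl_cls():
        return EchoImpl

class Attention:
    """Logical layer: knows *what* to compute, delegates *how* to the backend."""
    def __init__(self, num_heads, head_size, scale, backend: type[AttentionBackend]):
        self.impl = backend.get_impl_cls()(num_heads, head_size, scale)

    def forward(self, query, key, value, kv_cache=None, attn_metadata=None):
        return self.impl.forward(query, key, value, kv_cache, attn_metadata)

layer = Attention(num_heads=32, head_size=128, scale=128 ** -0.5, backend=EchoBackend)
print(layer.forward("q", "k", "v"))  # q
```

Because models only ever talk to Attention, swapping EchoBackend for a real kernel changes nothing in the model code.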
Attention Backends
Backend Selection
Source: vllm/v1/attention/backend.py
| Backend | Best For | Notes |
|---|---|---|
| FlashInferBackend | NVIDIA A100/H100, decode | Fastest decode via paged KV; supports ragged batches natively |
| FlashAttentionBackend | NVIDIA, prefill | Optimal for long-context prefill; uses tiling to stay in SRAM |
| XFormersBackend | Older NVIDIA, AMD | Fallback; uses xformers memory-efficient attention |
| TritonBackend | AMD ROCm | Pure Triton kernel; portable across GPU architectures |
| TRTLLMBackend | NVIDIA, TensorRT | TensorRT-LLM kernels for highest throughput on Hopper |
vLLM can use different backends for prefill and decode within the same model: for example, FlashAttention for prefill (better suited to long sequences) and FlashInfer for decode (better suited to short queries over a long KV history).
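As an illustration only, the "Best For" column above can be read as a lookup keyed on hardware vendor and phase. This is a hypothetical toy, not vLLM's selector, which weighs many more factors (GPU generation, dtype, head size, model features); treating XFormersBackend as the catch-all default is likewise an assumption drawn from its "Fallback" note.

```python
# Toy mapping derived from the table's "Best For" column (assumption, not vLLM logic).
PREFERRED = {
    ("nvidia", "prefill"): "FlashAttentionBackend",
    ("nvidia", "decode"): "FlashInferBackend",
    ("amd", "prefill"): "TritonBackend",
    ("amd", "decode"): "TritonBackend",
}
FALLBACK = "XFormersBackend"  # assumed default for unmatched hardware

def pick_backend(vendor: str, phase: str) -> str:
    if phase not in ("prefill", "decode"):
        raise ValueError(f"unknown phase: {phase!r}")
    return PREFERRED.get((vendor.lower(), phase), FALLBACK)

print(pick_backend("NVIDIA", "prefill"))  # FlashAttentionBackend
print(pick_backend("NVIDIA", "decode"))   # FlashInferBackend
```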
FlashInfer — Paged Decode
FlashInfer is purpose-built for paged KV cache inference. Its key advantage over standard FlashAttention is native support for non-contiguous (paged) K/V memory via block tables.
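```python
# FlashInfer-style paged decode (conceptual; one head shown):
# for each decoding request i (one query token q_i):
#     m, l, acc = -inf, 0, 0            # running max, softmax denominator, output
#     for j in range(num_blocks[i]):    # j = logical block index
#         phys_block = block_table[i][j]
#         k_blk = k_cache[phys_block]   # [block_size, head_dim]; only the valid
#         v_blk = v_cache[phys_block]   #   slots of the last block are used
#         s = (q_i · k_blk^T) * scale   # scores for the tokens in this block
#         m_new = max(m, max(s))
#         l   = l * exp(m - m_new) + sum(exp(s - m_new))
#         acc = acc * exp(m - m_new) + exp(s - m_new) · v_blk
#         m = m_new
#     output[i] = acc / l               # online softmax: exact, single pass
```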
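Here is a runnable NumPy rendering of that loop for a single request and a single head, checked against ordinary softmax attention over the same tokens gathered contiguously. Separate k_cache/v_cache arrays of shape [num_blocks, block_size, head_dim] are an assumption for readability; real backends use their own layouts.

```python
import numpy as np

def paged_decode(q, k_cache, v_cache, block_table, seq_len, block_size):
    """One decode query (single head) over a paged KV cache.
    q: [head_dim]; k_cache, v_cache: [num_blocks, block_size, head_dim];
    block_table: physical block IDs for this request, in logical order."""
    scale = 1.0 / np.sqrt(q.shape[0])
    m = -np.inf                      # running max of scores
    denom = 0.0                      # running sum of exp(score - m)
    acc = np.zeros_like(q)           # running sum of exp(score - m) * v
    num_blocks = -(-seq_len // block_size)             # ceil division
    for j in range(num_blocks):
        phys = block_table[j]                          # logical -> physical block
        n = min(block_size, seq_len - j * block_size)  # valid tokens in block
        k_blk, v_blk = k_cache[phys, :n], v_cache[phys, :n]
        s = k_blk @ q * scale                          # [n]
        m_new = max(m, s.max())
        correction = np.exp(m - m_new)                 # rescale earlier partials
        p = np.exp(s - m_new)
        denom = denom * correction + p.sum()
        acc = acc * correction + p @ v_blk
        m = m_new
    return acc / denom

rng = np.random.default_rng(0)
block_size, head_dim, seq_len = 4, 8, 10
k_cache = rng.standard_normal((16, block_size, head_dim))
v_cache = rng.standard_normal((16, block_size, head_dim))
block_table = [7, 2, 11]   # 3 non-contiguous blocks hold 10 tokens (last has 2)
q = rng.standard_normal(head_dim)

# Reference: gather the tokens contiguously, then plain softmax attention.
K = np.concatenate([k_cache[b] for b in block_table])[:seq_len]
V = np.concatenate([v_cache[b] for b in block_table])[:seq_len]
s = K @ q / np.sqrt(head_dim)
w = np.exp(s - s.max())
w /= w.sum()
assert np.allclose(paged_decode(q, k_cache, v_cache, block_table, seq_len, block_size), w @ V)
```

The rescaling by exp(m - m_new) is what makes the block-by-block result exact; simply summing per-block softmaxes would not normalize correctly.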
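FlashInfer uses a workspace-based plan/run API: the batch's block tables, page counts, and sequence lengths are passed once to BatchDecodeWithPagedKVCacheWrapper.plan(), which precomputes how the work is partitioned and stores that state in a workspace buffer. run() can then be invoked for every attention layer in that step with low overhead.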
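```python
# vLLM's use of FlashInfer (simplified)
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

# Scratch space FlashInfer uses for its plan and intermediate results
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer)

# seq_indptr, block_tables_flat, last_page_lengths, query, kv_cache:
# prepared by the model runner for the current batch.
wrapper.plan(
    indptr=seq_indptr,                # [batch_size + 1] offsets into indices
    indices=block_tables_flat,        # physical page IDs of all requests, concatenated
    last_page_len=last_page_lengths,  # [batch_size] valid tokens in each last page
    num_qo_heads=num_q_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    page_size=block_size,
)
# Per layer: query is [batch_size, num_q_heads, head_dim]
output = wrapper.run(query, kv_cache)
```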
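The three page-table arguments form a CSR-style layout. A minimal sketch of building them from per-request block lists follows; build_paged_kv_args is a hypothetical helper, it assumes every seq_len is at least 1 (always true in decode), and in real use the lists would become int32 GPU tensors.

```python
def build_paged_kv_args(block_tables, seq_lens, page_size):
    """Flatten per-request block lists into (indptr, indices, last_page_len)."""
    indptr, indices, last_page_len = [0], [], []
    for blocks, seq_len in zip(block_tables, seq_lens):
        num_pages = -(-seq_len // page_size)        # ceil(seq_len / page_size)
        indices.extend(blocks[:num_pages])          # pages actually in use
        indptr.append(len(indices))                 # end offset for this request
        last_page_len.append(seq_len - (num_pages - 1) * page_size)
    return indptr, indices, last_page_len

# Two requests, page_size 16: 20 tokens in pages [3, 9]; 16 tokens in page [5].
print(build_paged_kv_args([[3, 9], [5]], [20, 16], page_size=16))
# ([0, 2, 3], [3, 9, 5], [4, 16])
```

Request r's pages are indices[indptr[r]:indptr[r + 1]], and only the first last_page_len[r] slots of its final page hold valid tokens.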
AttentionMetadata
Per-Iteration Attention State
Source: vllm/v1/attention/backend.py
Constructed by _prepare_inputs() in the model runner. Carries all per-batch metadata needed by the attention kernel:
```python
from dataclasses import dataclass
import torch

@dataclass
class AttentionMetadata:
    # How many tokens in the batch are in the prefill vs. decode phase
    num_prefill_tokens: int
    num_decode_tokens: int
    # Sequence lengths (including already-cached tokens)
    seq_lens: list[int]       # total context length per sequence
    query_lens: list[int]     # number of new query tokens per sequence
    # Block table: [num_seqs, max_blocks_per_seq]
    # Maps logical block index → physical GPU block ID
    block_tables: torch.Tensor
    # Slot mapping: where to write each new token's K/V in the cache
    slot_mapping: torch.Tensor  # [num_tokens] → flat cache index
    # Backend-specific fields (not shown): e.g. causal-mask / query offsets
    # for prefill, and page-size / indptr arrays for FlashInfer decode
```
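The scalar and list fields can be derived from per-request scheduling state. In the sketch below the Req type and summarize helper are hypothetical, and the rule "decode = exactly one new token with cached context" is a simplifying assumption (speculative decoding, for instance, gives decode requests more than one query token).

```python
from dataclasses import dataclass

@dataclass
class Req:
    num_computed: int    # tokens whose K/V are already in the cache
    num_scheduled: int   # new tokens fed through the model this step

def summarize(reqs):
    query_lens = [r.num_scheduled for r in reqs]
    seq_lens = [r.num_computed + r.num_scheduled for r in reqs]
    # Assumed rule: one new token on top of cached context means decode.
    is_decode = [r.num_scheduled == 1 and r.num_computed > 0 for r in reqs]
    num_decode_tokens = sum(q for q, d in zip(query_lens, is_decode) if d)
    num_prefill_tokens = sum(query_lens) - num_decode_tokens
    return query_lens, seq_lens, num_prefill_tokens, num_decode_tokens

# One fresh 4-token prompt and two requests decoding after 9 and 15 tokens.
print(summarize([Req(0, 4), Req(9, 1), Req(15, 1)]))
# ([4, 1, 1], [4, 10, 16], 4, 2)
```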
Slot Mapping
When new tokens arrive (prefill or decode), their K and V values must be written into the physical KV cache at the right position. The slot mapping translates each input token's logical position to a flat index into the physical cache tensor:
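```python
# Example: block_size = 4; the new token is at logical position 6 of a
# request whose block table maps logical block 1 → physical block 7.
# logical position 6 → logical block 6 // 4 = 1 (positions 4-7), offset 6 % 4 = 2
# physical block ID = block_table[req][1] = 7
# slot = 7 * block_size + 2 = 7 * 4 + 2 = 30
slot_mapping[token_idx] = 30  # write this token's K/V here in the flat cache
```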
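Generalizing the worked example, a per-request slot mapping is just a divmod per token. The helper name is hypothetical, and physical block 3 for logical block 0 is an assumed value (the example only fixes logical block 1 → physical block 7).

```python
def compute_slot_mapping(block_table, start_pos, num_new_tokens, block_size):
    """Flat cache slots for one request's tokens at logical positions
    start_pos .. start_pos + num_new_tokens - 1."""
    slots = []
    for pos in range(start_pos, start_pos + num_new_tokens):
        block_idx, offset = divmod(pos, block_size)  # logical block, in-block offset
        slots.append(block_table[block_idx] * block_size + offset)
    return slots

# The worked example: decode of the token at position 6 lands in slot 30.
print(compute_slot_mapping([3, 7], start_pos=6, num_new_tokens=1, block_size=4))
# [30]
# A 7-token prefill of the same request fills slots 12-15, then 28-30.
print(compute_slot_mapping([3, 7], start_pos=0, num_new_tokens=7, block_size=4))
# [12, 13, 14, 15, 28, 29, 30]
```

The same function covers both phases: prefill maps many consecutive positions, decode maps one.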
Prefill vs. Decode Paths
The attention kernel behaves differently depending on the phase:
Prefill: Q, K, V all come from the prompt. Causal mask (a token can't attend to future tokens). Compute-bound. Uses FlashAttention (tiled SRAM computation). Sequence length can be thousands of tokens.

Decode: Q is a single new token per sequence; K and V come from the cache. Bandwidth-bound (all cached K/V must be read). Uses FlashInfer (paged gather). query_len=1 per sequence.
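A rough arithmetic-intensity estimate shows why the two phases land on opposite sides of the compute/bandwidth divide. This back-of-the-envelope sketch (hypothetical helper) counts only the QKᵀ and PV matmuls for one KV head and the GQA group of Q heads sharing it, assumes each K/V element is read once from memory, and ignores softmax, causal masking (which roughly halves prefill FLOPs), and Q/output traffic.

```python
def attention_intensity(query_len, kv_len, head_dim, group_size, dtype_bytes=2):
    """Approximate FLOPs per byte of K/V read for one KV head."""
    flops = group_size * query_len * kv_len * 4 * head_dim  # QK^T and PV: 2*d each
    kv_bytes = 2 * kv_len * head_dim * dtype_bytes          # K and V
    return flops / kv_bytes

print(attention_intensity(1, 4096, 128, group_size=4))     # decode: 4.0
print(attention_intensity(4096, 4096, 128, group_size=4))  # prefill: 16384.0
```

Decode performs only a few FLOPs per byte loaded, far too few to keep a GPU's matrix units busy, so its speed is set by memory bandwidth; prefill reuses every loaded K/V across thousands of queries and is limited by compute.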
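```python
# In the model runner, one batch can contain BOTH:
#   - prefill tokens (a contiguous run at the front of input_ids)
#   - decode tokens (one token per decoding request, at the end)
#
#   input_ids = [pref0, pref1, pref2, pref3, dec0, dec1]
#               ←      prefill (4 tok)     → ← decode (2 tok) →
#
# Attention splits the flattened token dimension at num_prefill_tokens
# (flash_attn / flashinfer_paged_decode stand in for the backend calls):
n = num_prefill_tokens
q_prefill, k_prefill, v_prefill = query[:n], key[:n], value[:n]
q_decode = query[n:]
outputs = []
if num_prefill_tokens > 0:
    outputs.append(flash_attn(q_prefill, k_prefill, v_prefill, ...))
if num_decode_tokens > 0:
    outputs.append(flashinfer_paged_decode(q_decode, kv_cache, block_tables, ...))
output = torch.cat(outputs, dim=0)  # back in input_ids token order
```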