Sampler & Token Generation
After the transformer computes logits (unnormalized log-probabilities over the vocabulary), the sampler converts them into the next token. This step implements temperature, top-p, top-k, penalties, and all other sampling strategies.
Sampling Pipeline
Sampler.forward() — Step by Step
Source: vllm/v1/sample/sampler.py
Runs on GPU, batched over all sequences in one call:
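1. Input: logits of shape [num_decode_tokens, vocab_size].
2. Logprobs (if logprobs > 0): log_softmax(logits) → keep the top-N per token. Stored for user inspection.
3. Repetition penalty: for every token already present in the prompt or output, a positive logit is divided by repetition_penalty and a negative logit is multiplied by it. Values > 1 discourage repeats; values < 1 encourage them.
4. Presence and frequency penalties: logit -= presence_penalty for any token generated at least once; logit -= frequency_penalty × count, proportional to how often the token has been generated.
5. Temperature: logits /= temperature. Low temperature → peaked distribution (more deterministic). High temperature → flat distribution (more creative).
6. Top-K: keep only the top_k highest logits in each row.
7. Top-P: keep the smallest set of tokens whose cumulative probability reaches top_p.
8. Sample: probs = softmax(logits), then torch.multinomial(probs, 1). For greedy decoding (temperature=0), use argmax instead.

Code: Sampler.forward()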
```python
from __future__ import annotations

import torch
from torch import nn

# Simplified sketch: SamplingMetadata / SamplerOutput are vLLM types, and
# _apply_top_k / _apply_top_p are module-level helpers.
class Sampler(nn.Module):
    def forward(
        self,
        logits: torch.Tensor,  # [num_decode_seqs, vocab_size]
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        vocab_size = logits.shape[-1]
        # Gather per-request sampling parameters as batched tensors
        temperatures = sampling_metadata.temperature  # [num_seqs]
        top_ps = sampling_metadata.top_p
        top_ks = sampling_metadata.top_k

        # Steps 3-4: apply repetition, presence, and frequency penalties
        logits = self._apply_penalties(logits, sampling_metadata)

        # Step 5: temperature scaling
        # Zero temperature → greedy (handled separately; those rows stay unscaled)
        is_greedy = temperatures == 0
        safe_temps = torch.where(is_greedy, torch.ones_like(temperatures), temperatures)
        logits = logits / safe_temps.unsqueeze(1)

        # Step 6: top-k
        if (top_ks != vocab_size).any():
            logits = _apply_top_k(logits, top_ks)

        # Step 7: top-p (nucleus)
        if (top_ps < 1.0).any():
            logits = _apply_top_p(logits, top_ps)

        # Step 8: sample
        probs = torch.softmax(logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Greedy override (top-k/top-p always keep the best token, so argmax is unaffected)
        sampled[is_greedy] = logits[is_greedy].argmax(dim=-1)
        return SamplerOutput(sampled_tokens=sampled, logprobs=...)
```
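The `_apply_penalties` helper is not shown above. A minimal sketch of steps 3-4 follows, assuming token histories arrive as `[num_seqs, max_len]` tensors padded with -1; the padding convention and the name `apply_penalties_sketch` are hypothetical, not vLLM's actual code. It applies the repetition penalty to tokens seen in the prompt or output, and the presence/frequency penalties to generated tokens only.

```python
import torch

def apply_penalties_sketch(
    logits: torch.Tensor,              # [num_seqs, vocab_size]
    prompt_tokens: torch.Tensor,       # [num_seqs, P] long, padded with -1 (assumption)
    output_tokens: torch.Tensor,       # [num_seqs, O] long, padded with -1 (assumption)
    repetition_penalty: torch.Tensor,  # [num_seqs]
    presence_penalty: torch.Tensor,    # [num_seqs]
    frequency_penalty: torch.Tensor,   # [num_seqs]
) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape

    def bin_counts(tokens: torch.Tensor) -> torch.Tensor:
        # Route padding (-1) into an extra bucket at index vocab_size, then drop it
        tokens = tokens.masked_fill(tokens < 0, vocab_size)
        counts = torch.zeros(num_seqs, vocab_size + 1, dtype=torch.long, device=tokens.device)
        counts.scatter_add_(1, tokens, torch.ones_like(tokens))
        return counts[:, :vocab_size]

    output_counts = bin_counts(output_tokens)
    seen = (bin_counts(prompt_tokens) > 0) | (output_counts > 0)

    # Step 3: repetition penalty (divide positive logits, multiply negative ones)
    rep = repetition_penalty.unsqueeze(1)
    penalized = torch.where(logits > 0, logits / rep, logits * rep)
    logits = torch.where(seen, penalized, logits)

    # Step 4: frequency penalty scales with count; presence penalty is a flat offset
    logits = logits - frequency_penalty.unsqueeze(1) * output_counts
    logits = logits - presence_penalty.unsqueeze(1) * (output_counts > 0).to(logits.dtype)
    return logits
```

Counting with `scatter_add_` keeps the whole batch on the GPU in one kernel, which is the same reason the sampler batches every other step.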
Sampling Strategies
Temperature, Top-P, Top-K Explained
| Strategy | Config | Behavior |
|---|---|---|
| Greedy | temperature=0 | Always pick the highest-probability token. Deterministic, no creativity. |
| Temperature | temperature=0.8 | Rescale the distribution: 1.0 = unchanged, <1 = sharper (as with 0.8), >1 = flatter. |
| Top-K | top_k=50 | Only consider the 50 most likely tokens at each step. |
| Nucleus (top-p) | top_p=0.9 | Consider the smallest set of tokens covering 90% cumulative probability. |
| Beam search | beam_width=4 (BeamSearchParams, via LLM.beam_search()) | Keep 4 partial sequences and return the completion with the highest joint log-probability. |
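To make the temperature row concrete, this small sketch applies softmax(logits / T) to one arbitrary row of logits at three temperatures; the helper name `temperature_probs` is purely illustrative.

```python
import torch

def temperature_probs(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Divide logits by T, then normalize: T < 1 sharpens, T > 1 flattens
    return torch.softmax(logits / temperature, dim=-1)

logits = torch.tensor([2.0, 1.0, 0.0])  # arbitrary example logits
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in temperature_probs(logits, t).tolist()])
```

The most likely token's share grows as T falls and shrinks as T rises, while the ranking of tokens never changes, which is why greedy decoding is the T → 0 limit.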
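```python
import torch

# Top-P nucleus sampling (implementation sketch)
def _apply_top_p(logits: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    # logits: [num_seqs, vocab_size]; top_p: [num_seqs]
    # Sort logits descending
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Remove a token once the probability mass before it already exceeds top_p
    # (the token that crosses the threshold, and the top token, are always kept)
    sorted_mask = (cumulative_probs - sorted_probs) > top_p.unsqueeze(1)
    sorted_logits = sorted_logits.masked_fill(sorted_mask, float("-inf"))
    # Scatter back to the original vocabulary order
    return logits.scatter(1, sorted_indices, sorted_logits)
```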
SamplingMetadata
Source: vllm/v1/sample/metadata.py
Per-sequence sampling configuration, batched into GPU tensors (token histories are kept as lists). The model runner constructs it from each request's SamplingParams before calling the sampler:
```python
from dataclasses import dataclass

import torch

@dataclass
class SamplingMetadata:
    # Per-sequence sampling config (broadcast from SamplingParams)
    temperature: torch.Tensor  # [num_seqs]
    top_p: torch.Tensor
    top_k: torch.Tensor
    min_p: torch.Tensor
    repetition_penalty: torch.Tensor
    frequency_penalty: torch.Tensor
    presence_penalty: torch.Tensor
    # Token IDs for penalty computation (which tokens appeared so far)
    output_token_ids: list[list[int]]
    prompt_token_ids: list[list[int]]
    # Which token IDs to compute logprobs for (if requested)
    logprob_token_ids: list[list[int]] | None
    # Seeds for reproducible sampling
    generators: list[torch.Generator | None]
```
Structured Output & Constrained Decoding
Logit Processors & Grammar Constraints
vLLM supports structured generation (JSON, grammar, regex) by masking invalid tokens at each step. The integration uses xgrammar or guidance.
```python
# Structured output via SamplingParams
from vllm import SamplingParams
from vllm.sampling_params import GuidedDecodingParams

params = SamplingParams(
    guided_decoding=GuidedDecodingParams(
        json={"type": "object", "properties": {"name": {"type": "string"}}}
    )
)
# At each step, invalid tokens get logit = -inf:
# only tokens that could legally appear next in valid JSON are kept
```
The StructuredOutputManager in the engine core maintains a grammar matcher (a finite-state or pushdown automaton, depending on the grammar) per request. Before sampling, it generates a token bitmask marking which tokens are valid in the current grammar state. Disallowed tokens have their logits set to -inf before temperature scaling, so they can never be sampled.
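A minimal sketch of the masking step, assuming an xgrammar-style packed layout: an int32 tensor of shape [num_seqs, ceil(vocab_size / 32)] in which bit j of word i marks token 32·i + j as allowed. The layout and the name `apply_token_bitmask` are assumptions for illustration.

```python
import torch

def apply_token_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    """Set logits of grammar-disallowed tokens to -inf.

    logits:  [num_seqs, vocab_size]
    bitmask: [num_seqs, ceil(vocab_size / 32)] int32; bit j of word i = token 32*i + j allowed
    """
    vocab_size = logits.shape[-1]
    shifts = torch.arange(32, dtype=torch.int32, device=bitmask.device)
    # Unpack each 32-bit word into 32 flags (the & 1 makes sign extension harmless)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1  # [num_seqs, words, 32]
    allowed = bits.reshape(bitmask.shape[0], -1)[:, :vocab_size].bool()
    return logits.masked_fill(~allowed, float("-inf"))
```

Packing 32 tokens per int32 keeps the mask about 32× smaller than a boolean tensor, which matters when it is rebuilt for every request at every step.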
Speculative Decoding
Rejection Sampler
Source: vllm/v1/sample/rejection_sampler.py
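Speculative decoding uses a small draft model to propose several tokens ahead, then verifies them with the large target model in one forward pass. The rejection-sampling rule guarantees that the output follows the target model's distribution exactly; the speedup comes from accepting several tokens per target forward pass instead of one.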
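```python
import torch

# Speculative decoding flow (single sequence, simplified):
# 1. Draft model proposes K tokens [t1, ..., tK] (e.g. K = 5)
#    with draft distributions q1..qK
# 2. Target model computes distributions p1..p(K+1) for all K+1 positions
#    in one forward pass
# 3. Rejection sampler accepts/rejects each draft token:
def rejection_sample(draft_tokens, draft_probs, target_probs, generator=None):
    # draft_tokens: [K] long; draft_probs: [K, V]; target_probs: [K+1, V]
    K = draft_tokens.shape[0]
    output = []
    for i in range(K):
        t = int(draft_tokens[i])
        accept_prob = min(1.0, (target_probs[i, t] / draft_probs[i, t]).item())
        if torch.rand(1, generator=generator).item() < accept_prob:
            output.append(t)  # accept token t_i
        else:
            # Reject: sample from the corrected distribution max(0, p - q), renormalized
            corrected = torch.clamp(target_probs[i] - draft_probs[i], min=0)
            output.append(int(torch.multinomial(corrected / corrected.sum(), 1, generator=generator)))
            return output  # stop at first rejection
    # All K draft tokens accepted: sample one bonus token from p(K+1)
    output.append(int(torch.multinomial(target_probs[K], 1, generator=generator)))
    return output
```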
When the draft and target models agree (draft_prob ≈ target_prob), most draft tokens are accepted, so each target forward pass can emit up to K + 1 tokens instead of one. vLLM supports n-gram, EAGLE, and DFlash draft models.
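To put numbers on "high acceptance rate", a common back-of-the-envelope model assumes each draft token is accepted independently with the same probability α (real workloads only approximate this). The expected number of tokens emitted per target forward pass is then 1 + α + … + α^K = (1 − α^(K+1)) / (1 − α):

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    # Accepted prefix plus one token from the target (correction or bonus),
    # assuming i.i.d. acceptance with probability alpha
    if alpha == 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)

for alpha in (0.5, 0.8, 0.95):
    print(alpha, round(expected_tokens_per_step(alpha, k=5), 2))
```

This counts tokens per target pass only; it ignores the cost of running the draft model, so the wall-clock speedup is lower.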
Stop Conditions
When to Stop Generating
After each sampled token, vLLM checks the following stop conditions:
| Condition | FinishReason | Trigger |
|---|---|---|
| EOS token sampled | STOP | Sampled token is EOS (unless ignore_eos=True) or is in stop_token_ids |
| Stop string match | STOP | Detokenized output contains a string from stop |
| Max tokens reached | LENGTH | len(output_token_ids) >= max_tokens |
| Max model len | LENGTH | Prompt + output ≥ max_model_len |
| Client abort | ABORT | HTTP disconnect / abort_request() called |
Stop-string detection happens in the output processor (CPU-side), not in the sampler (GPU-side). The token sequence is detokenized and checked against stop strings after each iteration.