SGLang for Agentic Workloads

Priority scheduling, KV cache eviction policies, and cache pinning for multi-turn agentic serving



This guide covers SGLang-specific configuration for agentic serving with Dynamo. It explains which SGLang engine flags to enable, how Dynamo’s agent hints map to SGLang behavior, and how to use experimental cache pinning to protect KV cache for high-value conversations.

Overview

Agentic workloads (tool-calling loops, multi-turn reasoning, code generation pipelines) have different performance characteristics than batch inference:

  • Prefix-heavy: Successive turns share a growing conversation prefix. KV cache reuse is critical for low TTFT.
  • Priority-sensitive: Some requests (user-facing agent turns) matter more than background tasks.
  • Long-lived: Conversations span minutes to hours. Cache eviction under memory pressure can destroy accumulated KV state.

Dynamo’s agent hints give the router per-request metadata. SGLang’s engine flags control how that metadata affects scheduling and eviction on the worker.

SGLang Engine Flags

Priority Scheduling

Enable priority-based scheduling so the engine respects the priority value from nvext.agent_hints.priority:

```bash
python -m dynamo.sglang \
  --model-path <model> \
  --enable-priority-scheduling \
  --schedule-low-priority-values-first \
  ...
```
| Flag | Description |
| --- | --- |
| `--enable-priority-scheduling` | Enables priority-based request scheduling instead of FCFS. |
| `--schedule-low-priority-values-first` | Inverts priority ordering so lower values are scheduled first (matches vLLM convention). Without this flag, higher values = higher priority. |

When priority scheduling is enabled, the engine uses the priority field from nvext.agent_hints to order requests in its internal queue. Requests with higher effective priority are scheduled before lower-priority ones. Ties are broken by arrival time.
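The queue ordering described above can be sketched with a simple heap. This is an illustrative model, not SGLang's implementation; the names (`sort_key`, the request labels) are invented for the example:

```python
import heapq
from itertools import count

arrival = count()  # monotonically increasing arrival stamp

def sort_key(priority, low_values_first=False):
    # Default SGLang convention: higher priority value = scheduled first,
    # so we negate. --schedule-low-priority-values-first flips this
    # (vLLM convention: lower value = scheduled first).
    p = priority if low_values_first else -priority
    return (p, next(arrival))  # the arrival stamp breaks ties FCFS

queue = []
heapq.heappush(queue, (sort_key(1), "background-task"))
heapq.heappush(queue, (sort_key(10), "user-turn-a"))
heapq.heappush(queue, (sort_key(10), "user-turn-b"))

order = [heapq.heappop(queue)[1] for _ in range(3)]
print(order)  # higher priority first; equal-priority requests keep arrival order
```

With `--schedule-low-priority-values-first`, `background-task` (priority 1) would be scheduled first instead.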

Priority-Based KV Cache Eviction

By default, SGLang evicts radix tree nodes using LRU. You can switch to priority-based eviction so that low-priority cache entries are evicted before high-priority ones:

```bash
python -m dynamo.sglang \
  --model-path <model> \
  --radix-eviction-policy priority \
  ...
```
| Flag | Values | Default | Description |
| --- | --- | --- | --- |
| `--radix-eviction-policy` | `lru`, `priority` | `lru` | Eviction strategy for the GPU radix cache. `priority` uses a heap ordered by the request's priority value. |

This does not require HiCache. It controls GPU-only radix tree eviction. When the GPU KV cache is full:

  • lru: Evicts the least recently used leaf nodes first.
  • priority: Evicts lowest-priority leaf nodes first. Nodes with equal priority fall back to LRU ordering.
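The `priority` policy's selection order can be modeled as a min-heap keyed on `(priority, last_access_time)`, so equal-priority leaves fall back to LRU. A minimal sketch (the leaf tuples and token counts are invented for illustration):

```python
import heapq

# Candidate leaf nodes: (priority, last_access_time, node_id, num_tokens)
leaves = [
    (10, 100.0, "agent-prefix", 512),
    (1,  300.0, "batch-job-a",  256),
    (1,  200.0, "batch-job-b",  256),
]

def evict(leaves, tokens_needed):
    """Evict lowest-priority leaves first; equal priority falls back to
    LRU (oldest last_access first). Illustrative only."""
    heap = list(leaves)
    heapq.heapify(heap)
    evicted, freed = [], 0
    while heap and freed < tokens_needed:
        priority, last_access, node_id, num_tokens = heapq.heappop(heap)
        evicted.append(node_id)
        freed += num_tokens
    return evicted

evicted = evict(leaves, tokens_needed=300)
print(evicted)  # the priority-1 batch jobs go first, least recently used first
```

The high-priority `agent-prefix` node survives even though it is larger and was accessed long ago.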

Interaction with HiCache

When both --radix-eviction-policy priority and --enable-hierarchical-cache are enabled, priority affects eviction at both tiers:

| Event | Behavior |
| --- | --- |
| GPU full | Low-priority nodes are evicted (demoted to host) first. With `write_through`, all nodes survive on host; priority only affects demotion order. |
| Host full | Low-priority nodes are deleted from host first. High-priority nodes survive longer. Pinned nodes are skipped entirely. |

The practical impact depends on your write policy. With write_through, GPU eviction is just a demotion — the real deletion happens at host eviction, which is where priority ordering matters most.
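The two-tier victim selection can be sketched as follows. This is a simplified model of the behavior in the table above, not SGLang internals; the node dicts and function name are invented:

```python
def pick_victim(tier, skip_pinned=False):
    """Pick the eviction victim on a tier: lowest priority first,
    LRU among equal priorities; pinned nodes skipped at the host tier."""
    candidates = [n for n in tier if not (skip_pinned and n["pinned"])]
    return min(candidates, key=lambda n: (n["priority"], n["last_access"]))

gpu_tier = [
    {"id": "a", "priority": 1, "last_access": 5.0, "pinned": False},
    {"id": "b", "priority": 9, "last_access": 1.0, "pinned": False},
]
host_tier = [
    {"id": "c", "priority": 1, "last_access": 2.0, "pinned": True},
    {"id": "d", "priority": 3, "last_access": 4.0, "pinned": False},
]

# GPU full: lowest-priority node "a" is demoted to host first.
gpu_victim = pick_victim(gpu_tier)
# Host full: pinned "c" is skipped entirely, so "d" is deleted.
host_victim = pick_victim(host_tier, skip_pinned=True)
print(gpu_victim["id"], host_victim["id"])
```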

How Agent Hints Map to SGLang

Dynamo’s nvext.agent_hints fields are consumed by the router and forwarded to SGLang workers. Here is how each hint interacts with the SGLang engine:

| Agent Hint | Router Behavior | SGLang Engine Behavior |
| --- | --- | --- |
| `priority` | No routing effect (forwarded to engine) | Queue ordering when `--enable-priority-scheduling` is set. Also affects radix cache eviction order when `--radix-eviction-policy priority` is set. |
| `latency_sensitivity` | Shifts request earlier in router queue (requires `--router-queue-threshold`) | No direct engine effect. |
| `osl` | Output block tracking for routing decisions (requires `--router-track-output-blocks`) | No direct engine effect. |
| `speculative_prefill` | After response completes, sends a `max_tokens=1` prefill to warm the KV cache for the predicted next turn. | SGLang processes the prefill request normally, populating the radix cache. |

Example: Agentic Request with Hints

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="Qwen/Qwen3-14B-FP8",
    messages=[
        {"role": "system", "content": "You are a coding assistant."},
        {"role": "user", "content": "Write a Python function to parse CSV files."},
    ],
    stream=True,
    extra_body={
        "nvext": {
            "agent_hints": {
                "priority": 10,
                "latency_sensitivity": 2.0,
                "speculative_prefill": True,
                "osl": 512
            }
        }
    }
)

for chunk in response:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
```

Cache Pinning (Experimental)

Cache pinning is experimental and available on development branches only. The API may change.

Required PRs:

Cache pinning lets you explicitly protect KV cache for high-value conversation prefixes. When a request includes nvext.cache_control, the router fires a pin_prefix call to the SGLang worker after generation completes. Pinned nodes resist eviction for the specified TTL — even under memory pressure, they are retained (demoted to host memory with HiCache rather than deleted).

How It Works

  1. The client includes nvext.cache_control with a TTL in the request.
  2. The Dynamo preprocessor extracts the TTL and attaches it to routing hints.
  3. The router routes the request normally and records the token IDs in a PinState.
  4. After the response stream completes, the router spawns a fire-and-forget pin_prefix RPC to the worker that served the request.
  5. The worker walks the radix tree along the token sequence and pins each node, setting pin_expiry and acquiring a host_ref_counter hold that prevents eviction.
  6. When TTL expires, the pin is cleared and the node becomes eligible for normal eviction.
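Steps 5 and 6 can be sketched with a toy radix tree. This is an illustrative model of the pinning walk, assuming node edges align exactly with the token sequence; the class and field names mirror the description above but the code is not SGLang's implementation:

```python
import time

class RadixNode:
    def __init__(self, token_ids):
        self.token_ids = token_ids   # tokens on the edge into this node
        self.children = {}           # first token -> child node
        self.pin_expiry = 0.0        # 0.0 = not pinned
        self.host_ref_counter = 0    # a positive count blocks host eviction

def pin_prefix(root, token_ids, ttl_s):
    """Walk the tree along token_ids, pinning each fully matched node."""
    expiry = time.time() + ttl_s
    node, i, pinned = root, 0, []
    while i < len(token_ids) and token_ids[i] in node.children:
        node = node.children[token_ids[i]]
        if token_ids[i:i + len(node.token_ids)] != node.token_ids:
            break  # partial edge match: stop pinning here
        node.pin_expiry = expiry        # step 6 clears this when TTL expires
        node.host_ref_counter += 1      # hold prevents eviction until then
        pinned.append(node)
        i += len(node.token_ids)
    return pinned

# Tiny tree: root -> [1, 2, 3] -> [4, 5]
root = RadixNode([])
mid = RadixNode([1, 2, 3])
leaf = RadixNode([4, 5])
root.children[1] = mid
mid.children[4] = leaf

pinned = pin_prefix(root, [1, 2, 3, 4, 5], ttl_s=600)
print([n.token_ids for n in pinned])  # both nodes along the prefix are pinned
```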

Enabling Cache Pinning

Frontend flag:

```bash
python -m dynamo.frontend \
  --router-mode kv \
  --enable-cache-control \
  ...
```
| Flag | Description |
| --- | --- |
| `--enable-cache-control` | Enables cache control (PIN with TTL). Creates a `cache_control` service mesh client and fires `pin_prefix` after generation for requests with `nvext.cache_control`. Requires `--router-mode=kv`. |

SGLang worker: The worker receives PIN requests via its cache_control service mesh endpoint. You must set the SGLANG_HICACHE_MAX_PINNED_RATIO environment variable to a non-zero value — pinning is disabled by default.

| Environment Variable | Type | Default | Description |
| --- | --- | --- | --- |
| `SGLANG_HICACHE_MAX_PINNED_RATIO` | float | `0.0` | Max fraction of cache tokens that can be pinned. Must be in `[0, 1)`. `0` disables pinning entirely. |
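The budget check implied by this ratio can be sketched as a simple admission test. This is a conceptual model, not SGLang's code; the function name and parameters are invented:

```python
def can_pin(num_tokens, pinned_tokens, pool_capacity, max_pinned_ratio):
    """Admit a PIN request only if it keeps total pinned tokens within
    max_pinned_ratio of the cache pool capacity. Illustrative only."""
    if max_pinned_ratio <= 0:
        # Pinning disabled: the default SGLANG_HICACHE_MAX_PINNED_RATIO=0.0
        return False
    return pinned_tokens + num_tokens <= max_pinned_ratio * pool_capacity

# With a 100k-token pool and ratio 0.1, the pin budget is 10k tokens.
print(can_pin(2048, 0, 100_000, 0.1))     # fits within budget
print(can_pin(9000, 2048, 100_000, 0.1))  # would exceed budget -> rejected
print(can_pin(1, 0, 100_000, 0.0))        # pinning disabled -> rejected
```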

HiCache is required (--enable-hierarchical-cache). Without it, the scheduler rejects PIN requests. For best results, use write_through so that pinned nodes demote to host memory instead of being deleted when GPU memory fills:

```bash
SGLANG_HICACHE_MAX_PINNED_RATIO=0.1 python -m dynamo.sglang \
  --model-path Qwen/Qwen3-14B-FP8 \
  --enable-hierarchical-cache \
  --hicache-ratio 2.0 \
  --hicache-write-policy write_through \
  ...
```

Request Format

Include cache_control as a top-level field in nvext:

```json
{
  "model": "Qwen/Qwen3-14B-FP8",
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain quantum computing."}
  ],
  "nvext": {
    "cache_control": {
      "type": "ephemeral",
      "ttl": "1h"
    }
  }
}
```
| Field | Type | Description |
| --- | --- | --- |
| `cache_control.type` | string | Currently only `"ephemeral"` is supported. |
| `cache_control.ttl` | string | TTL as integer seconds (`"600"`) or shorthand (`"5m"`, `"1h"`). Clamped to `[300, 3600]` seconds. Unrecognized strings default to 300s. |
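The documented TTL rules (integer seconds or shorthand, clamping, 300s fallback) can be reproduced in a few lines. A sketch of the behavior; the `"s"` suffix and the exact parsing order are assumptions, not confirmed details of Dynamo's parser:

```python
def parse_ttl(ttl, default_s=300, lo=300, hi=3600):
    """Parse a TTL string like "600", "5m", or "1h" into clamped seconds."""
    units = {"s": 1, "m": 60, "h": 3600}
    try:
        if ttl and ttl[-1] in units:
            seconds = int(ttl[:-1]) * units[ttl[-1]]
        else:
            seconds = int(ttl)
    except (ValueError, TypeError):
        return default_s  # unrecognized strings default to 300s
    return max(lo, min(hi, seconds))  # clamp to [300, 3600]

print(parse_ttl("5m"))       # 300
print(parse_ttl("1h"))       # 3600
print(parse_ttl("600"))      # 600
print(parse_ttl("10s"))      # 300  (clamped up to the 5-minute floor)
print(parse_ttl("2h"))       # 3600 (clamped down to the 1-hour ceiling)
print(parse_ttl("forever"))  # 300  (unrecognized)
```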

Python Example

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

# Placeholder system prompt; substitute your own.
system_prompt = "You are a senior software architect."

# First turn -- pin the conversation prefix for 1 hour
response = client.chat.completions.create(
    model="Qwen/Qwen3-14B-FP8",
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "Analyze this codebase and suggest improvements."},
    ],
    stream=True,
    extra_body={
        "nvext": {
            "cache_control": {
                "type": "ephemeral",
                "ttl": "1h"
            }
        }
    }
)

# Collect the assistant reply
assistant_response = ""
for chunk in response:
    if chunk.choices[0].delta.content:
        assistant_response += chunk.choices[0].delta.content

# Later turns reuse the pinned prefix -- even after heavy load from
# other requests, the KV cache for this conversation is preserved.
response = client.chat.completions.create(
    model="Qwen/Qwen3-14B-FP8",
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "Analyze this codebase and suggest improvements."},
        {"role": "assistant", "content": assistant_response},
        {"role": "user", "content": "Now focus on the database layer."},
    ],
    stream=True,
    extra_body={
        "nvext": {
            "cache_control": {
                "type": "ephemeral",
                "ttl": "1h"
            }
        }
    }
)
```

Verifying Cache Hits

The response includes prompt_tokens_details.cached_tokens in the usage object when --enable-cache-report is set on the SGLang worker:

```json
{
  "usage": {
    "prompt_tokens": 2048,
    "completion_tokens": 150,
    "prompt_tokens_details": {
      "cached_tokens": 1920
    }
  }
}
```

A high cached_tokens / prompt_tokens ratio on subsequent turns confirms that the pinned prefix was preserved.
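A small helper can compute this ratio from the usage object; `cache_hit_ratio` is an illustrative name, not part of any SDK:

```python
def cache_hit_ratio(usage: dict) -> float:
    """Fraction of prompt tokens served from the KV cache, computed from a
    chat-completions usage object (requires --enable-cache-report)."""
    details = usage.get("prompt_tokens_details") or {}
    cached = details.get("cached_tokens", 0)
    prompt = usage.get("prompt_tokens", 0)
    return cached / prompt if prompt else 0.0

usage = {
    "prompt_tokens": 2048,
    "completion_tokens": 150,
    "prompt_tokens_details": {"cached_tokens": 1920},
}
print(cache_hit_ratio(usage))  # 0.9375
```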

Limitations

  • Pinning disabled by default: SGLANG_HICACHE_MAX_PINNED_RATIO defaults to 0.0. You must set it to a non-zero value (e.g., 0.1) or all PIN requests will be rejected.
  • HiCache required: The scheduler rejects PIN requests unless --enable-hierarchical-cache is set.
  • TTL clamping: Values are clamped to [300, 3600] seconds. You cannot pin for less than 5 minutes or more than 1 hour.
  • Pin budget: Pinned tokens consume a budget controlled by SGLANG_HICACHE_MAX_PINNED_RATIO (fraction of host pool capacity). Requests exceeding this budget are rejected.
  • No priority on pinned nodes: pin_prefix does not set a priority on the radix tree nodes. All pinned nodes have equal eviction priority and fall back to LRU ordering among themselves when host memory fills.
  • Requires stack restart for A/B testing: Pins persist in cache across benchmark runs. When comparing pinned vs. unpinned performance, restart the full stack between phases to avoid false cache hits.

See Also