Router Examples

For quick start instructions, see the Router README. This document provides further examples for using the Dynamo Router, including Python API usage, Kubernetes deployments, and custom routing patterns.

Using KvPushRouter Python API

Instead of launching the KV Router via command line, you can create a KvPushRouter object directly in Python. This allows per-request routing configuration overrides.

Warning: Multiple Routers in Same Process. If you need to run multiple KvPushRouter instances for fault tolerance or load distribution, launch them in separate processes (e.g., using python -m dynamo.frontend with different ports). Creating multiple KvPushRouter objects in the same Python process is not supported: they share the same cancellation token from the component’s primary lease, so dropping one router cancels all routers in that process. For in-process routing, use a single KvPushRouter instance.

Methods

The KvPushRouter provides the following methods:

  • generate(token_ids, model, ...): Route and execute a request, returning an async stream of responses. Automatically handles worker selection, state tracking, and lifecycle management.

  • best_worker(token_ids, router_config_override=None, request_id=None): Query which worker would be selected for given tokens. Returns (worker_id, dp_rank, overlap_blocks).

    • Without request_id: Query-only, doesn’t update router state
    • With request_id: Updates router state to track the request. Note: If used with request_id, you must call mark_prefill_complete() and free() at the appropriate lifecycle points to maintain accurate load tracking
  • get_potential_loads(token_ids): Get detailed load information for all workers, including potential prefill tokens and active decode blocks. Returns a list of load dictionaries.

  • mark_prefill_complete(request_id): Signal that a request has completed its prefill phase. Only used for manual lifecycle management when using best_worker() for manual routing instead of generate().

  • free(request_id): Signal that a request has completed and its resources should be released. Only used for manual lifecycle management when using best_worker() for manual routing instead of generate().

  • dump_events(): Dump all KV cache events from the router’s indexer as a JSON string. Useful for debugging and analysis.

Setup

First, launch your backend engines:

$ python -m dynamo.vllm --model meta-llama/Llama-2-7b-hf

Example Script

import asyncio
from dynamo.llm import DistributedRuntime, KvPushRouter, KvRouterConfig

async def main():
    # Get runtime and create endpoint
    runtime = DistributedRuntime.detached()
    namespace = runtime.namespace("dynamo")
    component = namespace.component("backend")
    endpoint = component.endpoint("generate")

    # Create KV router
    kv_router_config = KvRouterConfig()
    router = KvPushRouter(
        endpoint=endpoint,
        block_size=16,
        kv_router_config=kv_router_config
    )

    # Your input tokens
    token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    # Generate with per-request routing override
    stream = await router.generate(
        token_ids=token_ids,
        model="meta-llama/Llama-2-7b-hf",
        stop_conditions={
            "max_tokens": 20,    # Generate exactly 20 tokens
            "ignore_eos": True,  # Don't stop at EOS token
        },
        sampling_options={
            "temperature": 0.7,
            "top_p": 0.9,
        },
        router_config_override={
            "overlap_score_weight": 2.0,  # Prioritize cache hits for this request
            "router_temperature": 0.5,    # Add routing randomness
        }
    )

    # Collect generated tokens
    generated_tokens = []
    async for response in stream:
        if isinstance(response, dict) and "token_ids" in response:
            generated_tokens.extend(response["token_ids"])

    print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")

if __name__ == "__main__":
    asyncio.run(main())

K8s Examples

For basic Kubernetes deployment with the KV Router, see the Kubernetes Deployment section in the Quick Start guide.

Complete K8s Examples

For A/B Testing and Advanced K8s Setup: See the comprehensive KV Router A/B Benchmarking Guide for step-by-step instructions on deploying, configuring, and benchmarking the KV router in Kubernetes.

Example with Advanced Configuration

apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
  name: my-deployment
spec:
  services:
    Frontend:
      dynamoNamespace: my-namespace
      componentType: frontend
      replicas: 1
      envs:
        - name: DYN_ROUTER_MODE
          value: kv
        - name: DYN_ROUTER_TEMPERATURE
          value: "0.5"  # Add some randomness to prevent worker saturation
        - name: DYN_KV_OVERLAP_SCORE_WEIGHT
          value: "1.5"  # Prioritize TTFT over ITL
        - name: DYN_KV_CACHE_BLOCK_SIZE
          value: "16"
      extraPodSpec:
        mainContainer:
          image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.6.0

Alternative: Using Command Args in K8s

You can also pass CLI arguments directly in the container command:

extraPodSpec:
  mainContainer:
    image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.6.0
    command:
      - /bin/sh
      - -c
    args:
      - "python3 -m dynamo.frontend --router-mode kv --router-temperature 0.5 --http-port 8000"

Recommendation: Use environment variables for easier configuration management and consistency with Dynamo’s K8s patterns.

Routing Patterns

The KvPushRouter supports multiple usage patterns depending on your control requirements:

1. Automatic Routing (Recommended)

Call generate() directly and let the router handle everything:

stream = await router.generate(token_ids=tokens, model="model-name")

  • Best for: Most use cases
  • Router automatically: Selects best worker, updates state, routes request, tracks lifecycle

2. Manual State Management (Advanced)

Use best_worker(request_id=...) to select and track, then manage the request yourself:

worker_id, _dp_rank, overlap = await router.best_worker(tokens, request_id="req-123")
# `client` stands for your own client to the selected worker; it is not part of the router API
response = await client.generate(tokens, request_id="req-123")
# await anext(response)  # Get first token
await router.mark_prefill_complete("req-123")  # After first token
# async for _ in response:  # Continue generating
#     ...
await router.free("req-123")  # After completion

  • Best for: Custom request handling with router state tracking
  • Requires: Calling mark_prefill_complete() and free() at correct lifecycle points
  • Caution: Incorrect lifecycle management degrades load balancing accuracy
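
Because a missed free() leaves the router believing a worker is still busy, it can help to tie the lifecycle calls to a context manager. The sketch below is one possible way to do that, not something the router provides: tracked_request is a hypothetical helper, and StubRouter is a stand-in that only records calls so the example runs without a Dynamo deployment.

```python
import asyncio
from contextlib import asynccontextmanager


class StubRouter:
    """Hypothetical stand-in that only records calls; not the real KvPushRouter."""

    def __init__(self):
        self.calls = []

    async def best_worker(self, token_ids, router_config_override=None, request_id=None):
        self.calls.append(("best_worker", request_id))
        return 7, 0, 2  # (worker_id, dp_rank, overlap_blocks), made-up values

    async def mark_prefill_complete(self, request_id):
        self.calls.append(("mark_prefill_complete", request_id))

    async def free(self, request_id):
        self.calls.append(("free", request_id))


@asynccontextmanager
async def tracked_request(router, token_ids, request_id):
    """Select a worker with state tracking and always release it on exit."""
    worker_id, dp_rank, overlap = await router.best_worker(token_ids, request_id=request_id)
    try:
        yield worker_id, dp_rank, overlap
    finally:
        # Runs on success, error, or cancellation, so load tracking stays accurate.
        await router.free(request_id)


async def demo():
    router = StubRouter()
    try:
        async with tracked_request(router, [1, 2, 3], "req-123") as (worker_id, _, _):
            # ... send the request to worker_id and read the first token here ...
            await router.mark_prefill_complete("req-123")
            raise RuntimeError("simulated failure while decoding")
    except RuntimeError:
        pass
    return router.calls


calls = asyncio.run(demo())
print(calls)
```

Even though the simulated request fails partway through decoding, free() is still the last recorded call.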

3. Hierarchical Router Probing

Query without state updates, then route through a chosen router:

1# Probe multiple routers without updating state
2worker_id_1, dp_rank, overlap_1 = await router_1.best_worker(tokens) # No request_id
3worker_id_2, dp_rank, overlap_2 = await router_2.best_worker(tokens)
4
5# Pick the best router based on results
6chosen_router = router_1 if overlap_1 > overlap_2 else router_2
7stream = await chosen_router.generate(tokens, model="model-name", worker_id=worker_id)
  • Best for: Multi-tier deployments (e.g., Envoy Gateway routing to multiple router groups)
  • Advantage: Query multiple routers before committing to one

4. Custom Load-Based Routing

Use get_potential_loads() to implement custom routing logic:

loads = await router.get_potential_loads(tokens)
# Apply custom logic (e.g., weighted scoring, constraints)
best_worker = min(loads, key=lambda x: custom_cost_fn(x))
stream = await router.generate(tokens, model="model-name", worker_id=best_worker['worker_id'])

  • Best for: Custom optimization strategies beyond the built-in cost function
  • Advantage: Full control over worker selection logic
  • See also: Detailed example below in “Custom Routing Example: Minimizing TTFT”
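
The snippet above leaves custom_cost_fn undefined. One possible definition is a weighted sum over the load fields, sketched below. The weights are arbitrary illustrative values and the load dictionaries are made-up sample data; only the field names (worker_id, potential_prefill_tokens, potential_decode_blocks) come from this document.

```python
def custom_cost_fn(load, prefill_weight=1.0, decode_weight=4.0):
    """Weighted sum of pending prefill tokens and active decode blocks (lower is better)."""
    return (
        prefill_weight * load["potential_prefill_tokens"]
        + decode_weight * load["potential_decode_blocks"]
    )


# Made-up load dictionaries in the shape described for get_potential_loads()
loads = [
    {"worker_id": 1, "potential_prefill_tokens": 128, "potential_decode_blocks": 10},
    {"worker_id": 2, "potential_prefill_tokens": 32, "potential_decode_blocks": 40},
    {"worker_id": 3, "potential_prefill_tokens": 64, "potential_decode_blocks": 12},
]

best_worker = min(loads, key=custom_cost_fn)
print(best_worker["worker_id"])  # 3
```

Worker 2 has the least prefill work but the heaviest decode load, so with these weights the more balanced worker 3 wins. Raising prefill_weight shifts the choice toward lower TTFT.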

All patterns support router_config_override to adjust routing behavior per-request without recreating the router.

Custom Routing Example: Minimizing TTFT

Here’s an example of using get_potential_loads() to implement custom routing that minimizes Time To First Token (TTFT) by selecting the worker with the least prefill work:

import asyncio
from dynamo.llm import DistributedRuntime, KvPushRouter, KvRouterConfig

async def minimize_ttft_routing():
    # Setup router
    runtime = DistributedRuntime.detached()
    namespace = runtime.namespace("dynamo")
    component = namespace.component("backend")
    endpoint = component.endpoint("generate")

    router = KvPushRouter(
        endpoint=endpoint,
        block_size=16,
        kv_router_config=KvRouterConfig()
    )

    # Your input tokens
    token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    # Get potential loads for all workers
    potential_loads = await router.get_potential_loads(token_ids)

    # Find worker with minimum prefill tokens (best for TTFT)
    best_worker = min(potential_loads, key=lambda x: x['potential_prefill_tokens'])

    print(f"Worker loads: {potential_loads}")
    print(f"Selected worker {best_worker['worker_id']} with {best_worker['potential_prefill_tokens']} prefill tokens")

    # Route directly to the selected worker
    stream = await router.generate(
        token_ids=token_ids,
        model="meta-llama/Llama-2-7b-hf",
        worker_id=best_worker['worker_id'],  # Force routing to optimal worker
        stop_conditions={"max_tokens": 20}
    )

    # Process response
    async for response in stream:
        if isinstance(response, dict) and "token_ids" in response:
            print(f"Generated tokens: {response['token_ids']}")

if __name__ == "__main__":
    asyncio.run(minimize_ttft_routing())

This approach gives you complete control over routing decisions, so you can optimize for whichever metric matters most to your workload. For example:

  • Minimize TTFT: Select worker with lowest potential_prefill_tokens
  • Maximize cache reuse: Use best_worker() which considers both prefill and decode loads
  • Balance load: Consider both potential_prefill_tokens and potential_decode_blocks together

See Router Design for architecture details and the cost function algorithm.

KV Event Publishing for Custom Engines

The KV Router relies on real-time events from backend workers to track which KV cache blocks are stored on each worker. When your custom engine allocates or evicts KV cache blocks, it should publish these events so the router can make optimal routing decisions. There are two main publishing pathways:

  • Direct NATS publishing (KvEventPublisher): publishes events directly to NATS. This is the simplest approach for custom engines.
  • ZMQ-based publishing: for engines with ZMQ event output (like vLLM). The engine runs a ZMQ publisher, and ZmqKvEventPublisher forwards the events to NATS.

Event Types

The KV cache supports three event types:

| Event Type | Description | When to Publish |
| --- | --- | --- |
| BlockStored | New blocks added to cache | After KV cache allocation succeeds |
| BlockRemoved | Blocks evicted from cache | When blocks are evicted or freed |
| AllBlocksCleared | All blocks removed | On cache reset or worker restart |
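
To make the effect of the three event types concrete, here is a toy model of a per-worker block index. It is a deliberately simplified illustration, not the router's actual indexer: apply_event and the set-based index are hypothetical, and the event dictionaries are sample data keyed by the type names in the table above.

```python
from collections import defaultdict


def apply_event(index, worker_id, event):
    """Update a toy {worker_id: set(block_hash)} index from one KV cache event."""
    kind = event["type"]
    if kind == "BlockStored":
        index[worker_id].update(event["block_hashes"])
    elif kind == "BlockRemoved":
        index[worker_id].difference_update(event["block_hashes"])
    elif kind == "AllBlocksCleared":
        index[worker_id].clear()
    else:
        raise ValueError(f"unknown event type: {kind}")


index = defaultdict(set)
apply_event(index, 0, {"type": "BlockStored", "block_hashes": [101, 102, 103]})
apply_event(index, 1, {"type": "BlockStored", "block_hashes": [101]})
apply_event(index, 0, {"type": "BlockRemoved", "block_hashes": [103]})
print(sorted(index[0]), sorted(index[1]))  # [101, 102] [101]

apply_event(index, 1, {"type": "AllBlocksCleared"})
print(sorted(index[1]))  # []
```

If an engine skips a BlockRemoved event, an index like this would keep advertising blocks the worker no longer holds, which is why every allocation and eviction should be published.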

Event Structure

Each event contains:

  • event_id: Monotonically increasing identifier per worker
  • dp_rank: Data parallel rank (0 if DP not enabled)
  • data: One of Stored, Removed, or Cleared

For BlockStored events:

  • token_ids: List of token IDs for the stored blocks
  • block_hashes: List of sequence block hashes from the engine’s block manager. These are cumulative hashes that incorporate all tokens from the start of the sequence up to and including the current block (not just the tokens within that block). This enables prefix matching across requests.
  • num_block_tokens: List with the number of tokens in each block (every entry should equal kv_block_size)
  • parent_hash: Hash of the parent block. Required for all blocks except the first block in a sequence (which has no parent).
  • lora_id: LoRA adapter ID (0 if not using LoRA)
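
The cumulative nature of block_hashes can be illustrated by chaining each block's hash to its parent. The sketch below uses SHA-256 truncated to 64 bits purely for illustration; the real hash function is whatever your engine's block manager uses, and sequence_block_hashes is a hypothetical helper.

```python
import hashlib


def sequence_block_hashes(token_ids, block_size):
    """Return (block_hashes, parent_hashes) for every full block of token_ids.

    Each hash covers the parent hash plus the block's own tokens,
    so it depends on the whole prefix up to and including that block.
    """
    block_hashes, parent_hashes = [], []
    parent = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = token_ids[start:start + block_size]
        h = hashlib.sha256()
        h.update(b"" if parent is None else parent.to_bytes(8, "big"))
        for tok in block:
            h.update(tok.to_bytes(4, "big"))
        digest = int.from_bytes(h.digest()[:8], "big")  # truncate to 64 bits
        block_hashes.append(digest)
        parent_hashes.append(parent)  # None for the first block in the sequence
        parent = digest
    return block_hashes, parent_hashes


a, a_parents = sequence_block_hashes([1, 2, 3, 4, 5, 6, 7, 8], block_size=4)
b, _ = sequence_block_hashes([1, 2, 3, 4, 9, 9, 9, 9], block_size=4)
c, _ = sequence_block_hashes([0, 0, 0, 0, 5, 6, 7, 8], block_size=4)

print(a[0] == b[0])  # shared first block gives the same hash
print(a[1] == c[1])  # same tokens in block 2, different prefix
```

Sequences a and b share their first block, so their first hashes match and a router can recognize the common prefix. Sequences a and c have identical tokens in the second block, but the hashes differ because the prefixes differ.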

For BlockRemoved events:

  • block_hashes: List of sequence block hashes being evicted

Option 1: Direct NATS Publishing

The KvEventPublisher class publishes events directly to NATS. This is the simplest approach for custom engines.

When to use:

  • Building a custom inference engine from scratch
  • Your engine doesn’t have a ZMQ-based event system
  • You want the simplest integration path

Basic Setup

from dynamo.llm import KvEventPublisher

class CustomEnginePublisher:
    def __init__(self, component, worker_id: int, block_size: int, dp_rank: int = 0):
        self.block_size = block_size
        self.event_id = 0
        self.kv_publisher = KvEventPublisher(
            component=component,
            worker_id=worker_id,
            kv_block_size=block_size,
            dp_rank=dp_rank,
            enable_local_indexer=False,
        )

    def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int],
                         lora_id: int = 0, parent_hash: int | None = None):
        """Call after KV cache blocks are allocated."""
        self.event_id += 1
        num_block_tokens = [self.block_size] * len(block_hashes)
        self.kv_publisher.publish_stored(
            event_id=self.event_id,
            token_ids=token_ids,
            num_block_tokens=num_block_tokens,
            block_hashes=block_hashes,
            lora_id=lora_id,
            parent_hash=parent_hash,
        )

    def on_blocks_removed(self, block_hashes: list[int]):
        """Call when KV cache blocks are evicted."""
        self.event_id += 1
        self.kv_publisher.publish_removed(event_id=self.event_id, block_hashes=block_hashes)

Integration with Your Engine

from dynamo.llm import register_llm

async def main():
    # Register your engine with Dynamo
    component, endpoint = await register_llm(
        model="my-model",
        generator=my_generate_fn,
    )

    # Initialize publisher
    publisher = CustomEnginePublisher(
        component=component,
        worker_id=endpoint.connection_id(),
        block_size=16,  # Match your engine's block size
    )

    # Hook into your engine's cache events
    def on_prefill_complete(request_id, token_ids, blocks):
        block_hashes = [block.hash for block in blocks]
        publisher.on_blocks_stored(token_ids=token_ids, block_hashes=block_hashes)

    def on_cache_eviction(evicted_blocks):
        block_hashes = [block.hash for block in evicted_blocks]
        publisher.on_blocks_removed(block_hashes=block_hashes)

Option 2: ZMQ-based Publishing

For engines that publish events via ZMQ (like vLLM), this option uses two components that work together:

  1. ZMQ Publisher (in your engine) - Publishes events to a ZMQ socket
  2. ZmqKvEventPublisher (Dynamo binding) - Subscribes to ZMQ and forwards to NATS

When to use:

  • Your engine already has a ZMQ-based event system (like vLLM)
  • You’re integrating with a consolidator (like KVBM)
  • You want to decouple event publishing from your engine’s main loop

Part 1: ZMQ Subscriber (Dynamo Bindings)

If your engine already publishes to ZMQ, use ZmqKvEventPublisher to subscribe and forward to NATS:

from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig

# Configure the ZMQ subscriber
config = ZmqKvEventPublisherConfig(
    worker_id=endpoint.connection_id(),
    kv_block_size=block_size,
    zmq_endpoint="tcp://127.0.0.1:5557",  # Where your engine publishes
    zmq_topic="",  # Subscribe to all topics
    enable_local_indexer=False,
)

# Create publisher - it automatically subscribes to ZMQ and forwards to NATS
kv_publisher = ZmqKvEventPublisher(
    component=component,
    config=config,
)

Part 2: ZMQ Publisher (Pure Python)

If your engine needs to publish to ZMQ (e.g., for consolidator integration), implement the ZMQ protocol:

import zmq
import msgpack
import time

class ZmqKvEventPublisher:
    """Pure Python ZMQ publisher for KV events (vLLM-compatible format)."""

    def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = ""):
        self.kv_block_size = kv_block_size
        self.topic = topic
        self.ctx = zmq.Context()
        self.socket = self.ctx.socket(zmq.PUB)
        self.socket.bind(zmq_endpoint)
        self.sequence = 0
        self.data_parallel_rank = 0

    def _to_signed_i64(self, value: int | None) -> int | None:
        if value is None:
            return None
        return value - 0x10000000000000000 if value > 0x7FFFFFFFFFFFFFFF else value

    def publish_stored(self, event_id: int, token_ids: list[int], num_block_tokens: list[int],
                       block_hashes: list[int], lora_id: int = 0, parent_hash: int | None = None):
        event = {
            "type": "BlockStored",
            "block_hashes": [self._to_signed_i64(h) for h in block_hashes],
            "parent_block_hash": self._to_signed_i64(parent_hash),
            "token_ids": token_ids,
            "block_size": self.kv_block_size,
            "lora_id": lora_id if lora_id != 0 else None,
        }
        self._publish_event(event)

    def publish_removed(self, event_id: int, block_hashes: list[int]):
        event = {"type": "BlockRemoved", "block_hashes": [self._to_signed_i64(h) for h in block_hashes]}
        self._publish_event(event)

    def publish_all_cleared(self):
        self._publish_event({"type": "AllBlocksCleared"})

    def _publish_event(self, event: dict):
        batch = [time.time(), [event], self.data_parallel_rank]
        payload = msgpack.packb(batch, use_bin_type=True)
        sequence_bytes = self.sequence.to_bytes(8, byteorder="big")
        self.sequence += 1
        self.socket.send_multipart([self.topic.encode(), sequence_bytes, payload])

    def shutdown(self):
        self.socket.close()
        self.ctx.term()
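
The _to_signed_i64 helper in the publisher reinterprets an unsigned 64-bit hash as a signed two's-complement value before it is packed. The standalone sketch below shows that conversion together with its inverse. The to_unsigned_u64 function is a hypothetical addition for illustration; whether a given consumer applies such an inverse is an assumption.

```python
def to_signed_i64(value):
    """Reinterpret an unsigned 64-bit integer as a signed one (two's complement)."""
    return value - (1 << 64) if value > 0x7FFFFFFFFFFFFFFF else value


def to_unsigned_u64(value):
    """Inverse mapping: recover the unsigned 64-bit value from the signed form."""
    return value & 0xFFFFFFFFFFFFFFFF


big_hash = 0xFFFFFFFFFFFFFFFE  # too large for a signed 64-bit integer
signed = to_signed_i64(big_hash)
print(signed)                               # -2
print(to_unsigned_u64(signed) == big_hash)  # True
```

Values that already fit in the signed range pass through unchanged, so the mapping is lossless in both directions.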

ZMQ Wire Format

The ZMQ message format (compatible with vLLM):

| Frame | Description |
| --- | --- |
| 1 | Topic (empty string for all topics) |
| 2 | Sequence number (8 bytes, big-endian) |
| 3 | Msgpack payload: [timestamp, [events], dp_rank] |

Each event in the payload is a dictionary with a type field (BlockStored, BlockRemoved, or AllBlocksCleared).
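
Frame 2 of the wire format can be encoded and decoded with the standard library alone. The sketch below also shows one thing a subscriber might do with it: count messages that were dropped in transit. GapDetector and the helper names are hypothetical, and nothing here claims that Dynamo's own subscriber works this way.

```python
def encode_sequence(seq):
    """Frame 2: sequence number as 8 big-endian bytes."""
    return seq.to_bytes(8, byteorder="big")


def decode_sequence(frame):
    """Parse frame 2 back into an integer, rejecting malformed frames."""
    if len(frame) != 8:
        raise ValueError("sequence frame must be exactly 8 bytes")
    return int.from_bytes(frame, byteorder="big")


class GapDetector:
    """Tracks the last sequence number seen and counts skipped messages."""

    def __init__(self):
        self.last_seq = None
        self.missed = 0

    def observe(self, sequence_frame):
        seq = decode_sequence(sequence_frame)
        if self.last_seq is not None and seq > self.last_seq + 1:
            self.missed += seq - self.last_seq - 1
        self.last_seq = seq
        return seq


detector = GapDetector()
for seq in (0, 1, 2, 5, 6):  # messages 3 and 4 never arrive
    detector.observe(encode_sequence(seq))
print(detector.missed)  # 2
```

ZMQ PUB/SUB sockets can drop messages when a subscriber falls behind, so a nonzero count means the subscriber's view of the cache may be stale.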

Best Practices

  1. Event IDs must be monotonically increasing per worker (use a thread-safe counter)

  2. Block size must match your engine’s actual kv_block_size

  3. parent_hash is required for all blocks except the first in a sequence - it links blocks to enable prefix matching
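
For the first practice, a lock-protected counter is one straightforward option when events can originate from several threads. The sketch below assumes a threaded engine; EventIdCounter is a hypothetical helper, not part of the Dynamo bindings.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class EventIdCounter:
    """Hands out strictly increasing event IDs; safe to call from multiple threads."""

    def __init__(self, start=0):
        self._value = start
        self._lock = threading.Lock()

    def next(self):
        with self._lock:
            self._value += 1
            return self._value


counter = EventIdCounter()
with ThreadPoolExecutor(max_workers=8) as pool:
    ids = list(pool.map(lambda _: counter.next(), range(1000)))

print(len(set(ids)), max(ids))  # 1000 1000
```

The lock guarantees unique, gap-free IDs, but it does not by itself guarantee that events are published in ID order. If ordering on the wire matters, one option is to perform the publish call while still holding the same lock.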

See Also