KV Event Publishing for Custom Engines


This document explains how to implement KV event publishing for custom inference engines, enabling them to participate in Dynamo’s KV cache-aware routing.

Overview

The KV Router relies on real-time events from backend workers to track which KV cache blocks each worker currently holds. When your custom engine stores or evicts KV cache blocks, it should publish these events so the router can take each worker's cached prefixes into account when choosing where to send a request.
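To make the router's side of this concrete, here is a toy, in-memory sketch of how stored, removed, and cleared events could be folded into a per-worker view and then used to score how much of a request's prefix each worker already caches. `ToyPrefixIndex` and its methods are hypothetical illustrations; Dynamo's actual indexer is a separate and more elaborate implementation.

```python
from collections import defaultdict


class ToyPrefixIndex:
    """Hypothetical router-side view of KV events (not Dynamo's implementation).

    Maps each cumulative block hash to the set of workers that hold it.
    """

    def __init__(self) -> None:
        self.owners: dict[int, set[int]] = defaultdict(set)

    def apply_stored(self, worker_id: int, block_hashes: list[int]) -> None:
        for h in block_hashes:
            self.owners[h].add(worker_id)

    def apply_removed(self, worker_id: int, block_hashes: list[int]) -> None:
        for h in block_hashes:
            workers = self.owners.get(h)  # .get() does not create missing keys
            if workers is not None:
                workers.discard(worker_id)
                if not workers:
                    del self.owners[h]

    def apply_cleared(self, worker_id: int) -> None:
        # Drop this worker from every block it was recorded as holding
        for h in list(self.owners):
            self.apply_removed(worker_id, [h])

    def overlap(self, request_hashes: list[int]) -> dict[int, int]:
        """Return, per worker, how many leading blocks of the request it holds."""
        scores: dict[int, int] = {}
        candidates = None
        for depth, h in enumerate(request_hashes, start=1):
            holders = self.owners.get(h, set())
            # A worker only counts while it holds every block of the prefix so far
            candidates = set(holders) if candidates is None else candidates & holders
            if not candidates:
                break
            for w in candidates:
                scores[w] = depth
        return scores
```

Because the hashes are cumulative, matching block by block from the start of the request is enough to measure prefix overlap per worker.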

There are two main publishing pathways:

  1. Direct NATS publishing (KvEventPublisher) - Publishes events directly to NATS. Simplest approach for custom engines.
  2. ZMQ-based publishing - For engines with ZMQ event output (like vLLM). Uses a ZMQ publisher in the engine and ZmqKvEventPublisher to forward events to NATS.

Event Types

The KV cache supports three event types:

Event Type | Description | When to Publish
BlockStored | New blocks added to the cache | After KV cache allocation succeeds
BlockRemoved | Blocks evicted from the cache | When blocks are evicted or freed
AllBlocksCleared | All blocks removed | On cache reset or worker restart

Event Structure

Each event contains:

  • event_id: Monotonically increasing identifier per worker
  • dp_rank: Data parallel rank (0 if DP not enabled)
  • data: One of Stored, Removed, or Cleared

For BlockStored events:

  • token_ids: List of token IDs for the stored blocks
  • block_hashes: List of sequence block hashes from the engine’s block manager. These are cumulative hashes that incorporate all tokens from the start of the sequence up to and including the current block (not just the tokens within that block). This enables prefix matching across requests.
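  • num_block_tokens: Number of tokens in each stored block, one entry per block hash (every entry should equal kv_block_size)
  • parent_hash: Hash of the block that immediately precedes the first block in block_hashes. Set it whenever the stored blocks extend an existing prefix; leave it unset (None) only when the first stored block is the first block of the sequence, which has no parent.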
  • lora_id: LoRA adapter ID (0 if not using LoRA)

For BlockRemoved events:

  • block_hashes: List of sequence block hashes being evicted

Option 1: Direct NATS Publishing

The KvEventPublisher class publishes events directly to NATS. This is the simplest approach for custom engines.

When to use:

  • Building a custom inference engine from scratch
  • Your engine doesn’t have a ZMQ-based event system
  • You want the simplest integration path

Basic Setup

from threading import Lock

from dynamo.llm import KvEventPublisher


class CustomEnginePublisher:
    def __init__(self, component, worker_id: int, block_size: int, dp_rank: int = 0):
        self.block_size = block_size
        self.event_id = 0
        # Guards event_id so IDs stay monotonically increasing across threads
        self._lock = Lock()
        self.kv_publisher = KvEventPublisher(
            component=component,
            worker_id=worker_id,
            kv_block_size=block_size,
            dp_rank=dp_rank,
            enable_local_indexer=False,
        )

    def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int],
                         lora_id: int = 0, parent_hash: int | None = None):
        """Call after KV cache blocks are allocated.

        parent_hash is the hash of the block preceding block_hashes[0];
        pass None only if the first stored block starts the sequence.
        """
        num_block_tokens = [self.block_size] * len(block_hashes)
        with self._lock:
            self.event_id += 1
            self.kv_publisher.publish_stored(
                event_id=self.event_id,
                token_ids=token_ids,
                num_block_tokens=num_block_tokens,
                block_hashes=block_hashes,
                lora_id=lora_id,
                parent_hash=parent_hash,
            )

    def on_blocks_removed(self, block_hashes: list[int]):
        """Call when KV cache blocks are evicted."""
        with self._lock:
            self.event_id += 1
            self.kv_publisher.publish_removed(event_id=self.event_id, block_hashes=block_hashes)

Integration with Your Engine

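import asyncio

from dynamo.llm import register_llm


async def main():
    # Register your engine with Dynamo
    component, endpoint = await register_llm(
        model="my-model",
        generator=my_generate_fn,  # your engine's generate function (defined elsewhere)
    )

    # Initialize the publisher
    publisher = CustomEnginePublisher(
        component=component,
        worker_id=endpoint.connection_id(),
        block_size=16,  # Must match your engine's KV block size
    )

    # Hook into your engine's cache events. How these callbacks are registered is
    # engine-specific; `blocks` and `block.hash` stand for your engine's own types.
    def on_prefill_complete(request_id, token_ids, blocks):
        # Assumes `blocks` are newly stored and start the sequence, so no parent_hash
        # is passed; otherwise pass the hash of the block preceding blocks[0].
        block_hashes = [block.hash for block in blocks]
        publisher.on_blocks_stored(token_ids=token_ids, block_hashes=block_hashes)

    def on_cache_eviction(evicted_blocks):
        block_hashes = [block.hash for block in evicted_blocks]
        publisher.on_blocks_removed(block_hashes=block_hashes)

    # Register the callbacks with your engine, then run its serving loop here


if __name__ == "__main__":
    asyncio.run(main())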
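Where exactly the two hooks fire depends on your engine. The sketch below uses a hypothetical `ToyBlockPool` with LRU eviction to show one reasonable placement: publish stored events only for blocks that were not already cached, emit one event per contiguous run of new blocks so that each event has a single well-defined parent_hash, and publish removals only after the stored events that created those blocks. The callback names and the hash are illustrative; a real engine would publish its own block hashes.

```python
from collections import OrderedDict
from typing import Callable


class ToyBlockPool:
    """Hypothetical engine-side block pool with LRU eviction.

    Calls on_stored(token_ids, block_hashes, parent_hash) for newly cached
    blocks and on_removed(block_hashes) for evicted ones.
    """

    def __init__(self, capacity: int, block_size: int,
                 on_stored: Callable[..., None], on_removed: Callable[..., None]) -> None:
        self.capacity = capacity
        self.block_size = block_size
        self.on_stored = on_stored
        self.on_removed = on_removed
        self.blocks: OrderedDict[int, None] = OrderedDict()  # hash -> None, in LRU order

    def prefill(self, token_ids: list[int]) -> None:
        parent = None
        run_parent = None
        run_hashes: list[int] = []
        run_tokens: list[int] = []

        def flush() -> None:
            # Publish the current run of new blocks as one stored event
            if run_hashes:
                self.on_stored(list(run_tokens), list(run_hashes), run_parent)
                run_hashes.clear()
                run_tokens.clear()

        for i in range(len(token_ids) // self.block_size):
            block = token_ids[i * self.block_size:(i + 1) * self.block_size]
            h = hash((parent, tuple(block)))  # toy cumulative hash, not your engine's
            if h in self.blocks:
                flush()  # a cache hit ends the current run of new blocks
                self.blocks.move_to_end(h)  # mark as recently used; no event needed
            else:
                if not run_hashes:
                    run_parent = parent  # None only at the start of the sequence
                run_hashes.append(h)
                run_tokens.extend(block)
                self.blocks[h] = None
            parent = h
        flush()
        self._evict()

    def _evict(self) -> None:
        evicted = []
        while len(self.blocks) > self.capacity:
            h, _ = self.blocks.popitem(last=False)  # least recently used first
            evicted.append(h)
        if evicted:
            self.on_removed(evicted)
```

In a real integration, on_stored would forward to publisher.on_blocks_stored(token_ids=..., block_hashes=..., parent_hash=...) and on_removed to publisher.on_blocks_removed(block_hashes=...).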

Option 2: ZMQ-based Publishing

For engines that publish events via ZMQ (like vLLM), this option uses two components that work together:

  1. ZMQ Publisher (in your engine) - Publishes events to a ZMQ socket
  2. ZmqKvEventPublisher (Dynamo binding) - Subscribes to ZMQ and forwards to NATS

When to use:

  • Your engine already has a ZMQ-based event system (like vLLM)
  • You’re integrating with a consolidator (like KVBM)
  • You want to decouple event publishing from your engine’s main loop

Part 1: ZMQ Subscriber (Dynamo Bindings)

If your engine already publishes to ZMQ, use ZmqKvEventPublisher to subscribe and forward to NATS:

from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig

# `component`, `endpoint`, and `block_size` come from your worker setup
# (see the integration example above)

# Configure the ZMQ subscriber
config = ZmqKvEventPublisherConfig(
    worker_id=endpoint.connection_id(),
    kv_block_size=block_size,
    zmq_endpoint="tcp://127.0.0.1:5557",  # Where your engine publishes
    zmq_topic="",  # Empty topic subscribes to all topics
    enable_local_indexer=False,
)

# Create the publisher; it subscribes to ZMQ and forwards events to NATS
kv_publisher = ZmqKvEventPublisher(
    component=component,
    config=config,
)

Part 2: ZMQ Publisher (Pure Python)

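If your engine does not yet publish to ZMQ (for example, because you want to integrate with a consolidator), you can implement the publishing side yourself. The class below emits vLLM-compatible messages: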

import time

import msgpack
import zmq


class ZmqKvEventPublisher:
    """Pure Python ZMQ publisher for KV events (vLLM-compatible format).

    Not the same class as dynamo.llm.ZmqKvEventPublisher, which subscribes
    to this socket and forwards events to NATS. Use from a single thread:
    ZMQ sockets are not thread-safe.
    """

    def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = "",
                 dp_rank: int = 0):
        self.kv_block_size = kv_block_size
        self.topic = topic
        self.ctx = zmq.Context()
        self.socket = self.ctx.socket(zmq.PUB)
        self.socket.bind(zmq_endpoint)
        self.sequence = 0
        self.data_parallel_rank = dp_rank

    def _to_signed_i64(self, value: int | None) -> int | None:
        # Reinterpret an unsigned 64-bit hash as signed i64 (same bits)
        if value is None:
            return None
        return value - 0x10000000000000000 if value > 0x7FFFFFFFFFFFFFFF else value

    def publish_stored(self, event_id: int, token_ids: list[int], num_block_tokens: list[int],
                       block_hashes: list[int], lora_id: int = 0, parent_hash: int | None = None):
        # event_id and num_block_tokens are accepted for signature parity with
        # KvEventPublisher; the vLLM-style event carries block_size instead.
        event = {
            "type": "BlockStored",
            "block_hashes": [self._to_signed_i64(h) for h in block_hashes],
            "parent_block_hash": self._to_signed_i64(parent_hash),
            "token_ids": token_ids,
            "block_size": self.kv_block_size,
            "lora_id": lora_id if lora_id != 0 else None,
        }
        self._publish_event(event)

    def publish_removed(self, event_id: int, block_hashes: list[int]):
        event = {"type": "BlockRemoved", "block_hashes": [self._to_signed_i64(h) for h in block_hashes]}
        self._publish_event(event)

    def publish_all_cleared(self):
        self._publish_event({"type": "AllBlocksCleared"})

    def _publish_event(self, event: dict):
        # Payload: [timestamp, [events], dp_rank]
        batch = [time.time(), [event], self.data_parallel_rank]
        payload = msgpack.packb(batch, use_bin_type=True)
        sequence_bytes = self.sequence.to_bytes(8, byteorder="big")
        self.sequence += 1
        # Frames: topic, 8-byte big-endian sequence number, msgpack payload
        self.socket.send_multipart([self.topic.encode(), sequence_bytes, payload])

    def shutdown(self):
        self.socket.close(linger=0)  # discard unsent messages so term() does not block
        self.ctx.term()
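The `_to_signed_i64` helper above maps an unsigned 64-bit hash onto the signed 64-bit range without changing its bits, presumably so that every hash fits a signed 64-bit field on the receiving side (that rationale is an assumption). The sketch below checks the bit-level claim with the standard `struct` module: packing the original value as `Q` (unsigned) and reading it back as `q` (signed) gives the same result as the arithmetic form. The function names are hypothetical.

```python
import struct

U64_MOD = 1 << 64
I64_MAX = (1 << 63) - 1


def to_signed_i64(value: int) -> int:
    # Arithmetic form, as used by the publisher above
    return value - U64_MOD if value > I64_MAX else value


def to_signed_i64_struct(value: int) -> int:
    # Bit-level form: pack as unsigned 64-bit, unpack the same 8 bytes as signed
    return struct.unpack("<q", struct.pack("<Q", value))[0]


def to_unsigned_u64(value: int) -> int:
    # Inverse mapping, e.g. for a consumer that wants the original unsigned hash
    return value + U64_MOD if value < 0 else value
```

Values up to 2^63 - 1 pass through unchanged; larger ones become negative, and `to_unsigned_u64` restores the original.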

ZMQ Wire Format

The ZMQ message format (compatible with vLLM):

Frame | Description
1 | Topic (empty string for all topics)
2 | Sequence number (8 bytes, big-endian)
3 | Msgpack payload: [timestamp, [events], dp_rank]

Each event in the payload is a dictionary whose type field is BlockStored, BlockRemoved, or AllBlocksCleared.
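Because frame 2 carries a per-publisher sequence number, a subscriber can notice dropped messages (ZMQ PUB/SUB silently drops messages, for example for subscribers that connect late or fall behind). The sketch below parses the three-frame layout from the table using only the standard library and leaves the payload as raw bytes, since decoding it requires a msgpack library. `parse_frames`, `KvEventFrame`, and `SequenceTracker` are hypothetical helpers; this page does not say whether Dynamo's subscriber performs gap detection.

```python
from dataclasses import dataclass


@dataclass
class KvEventFrame:
    topic: str
    sequence: int
    payload: bytes  # msgpack-encoded [timestamp, [events], dp_rank]


def parse_frames(frames: list[bytes]) -> KvEventFrame:
    """Split a multipart message into topic, sequence number, and payload."""
    if len(frames) != 3:
        raise ValueError(f"expected 3 frames, got {len(frames)}")
    topic, seq_bytes, payload = frames
    if len(seq_bytes) != 8:
        raise ValueError("sequence frame must be 8 bytes")
    return KvEventFrame(topic.decode(), int.from_bytes(seq_bytes, "big"), payload)


class SequenceTracker:
    """Counts messages missed between consecutive sequence numbers."""

    def __init__(self) -> None:
        self.last: int | None = None
        self.missed = 0

    def observe(self, sequence: int) -> int:
        """Record a sequence number and return how many messages were skipped."""
        gap = 0
        if self.last is not None and sequence > self.last + 1:
            gap = sequence - self.last - 1
            self.missed += gap
        # A lower number (e.g. a restarted publisher) simply resets the baseline
        self.last = sequence
        return gap
```

On a detected gap, a subscriber cannot know which blocks it missed, so treating the worker's state as stale until the next reset is one conservative option.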

Best Practices

  1. Event IDs must be monotonically increasing per worker (use a thread-safe counter)

  2. Block size must match your engine’s actual kv_block_size

  3. parent_hash links a batch of stored blocks to the block that precedes them, which is what enables prefix matching. Set it on every stored event except one whose first block is the first block of the sequence.
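For item 1, a minimal thread-safe counter needs only the standard library. `EventIdCounter` is a hypothetical name. Note that a locked increment guarantees unique, increasing IDs, but if events must also reach the publisher in ID order, hold the same lock across the publish call rather than only around the increment.

```python
import threading


class EventIdCounter:
    """Thread-safe, strictly increasing event IDs starting at 1."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._next = 1

    def next_id(self) -> int:
        # The lock makes read-and-increment atomic across threads
        with self._lock:
            value = self._next
            self._next += 1
            return value
```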

See Also