KV Events for Custom Engines

KV Event Publishing for Custom Engines

This document explains how to implement KV event publishing for custom inference engines, enabling them to participate in Dynamo’s KV cache-aware routing.

Overview

The KV Router relies on real-time events from backend workers to track which KV cache blocks are stored on each worker. When your custom engine allocates or evicts KV cache blocks, it should publish these events so the router can make optimal routing decisions.

Events are published over the Dynamo event plane, a transport-agnostic pub/sub layer that supports both NATS and ZMQ backends (see Event Plane for details). The KvEventPublisher binding handles all transport concerns — your engine code does not interact with the event plane directly.

KvEventPublisher supports two publishing modes:

  1. Direct publishing — Your engine calls publish_stored() / publish_removed() to push events directly over the event plane. Simplest approach for custom engines.
  2. ZMQ relay — For engines that emit raw KV events over a ZMQ socket (like vLLM and SGLang). The publisher subscribes to the ZMQ endpoint and relays events to the event plane automatically.

Event Types

The KV cache supports three event types:

| Event Type | Description | When to Publish |
| --- | --- | --- |
| BlockStored | New blocks added to cache | After KV cache allocation succeeds |
| BlockRemoved | Blocks evicted from cache | When blocks are evicted or freed |
| AllBlocksCleared | All blocks removed | On cache reset or worker restart |

Event Structure

Each event contains:

  • event_id: Monotonically increasing identifier per worker (managed internally by the publisher)
  • dp_rank: Data parallel rank (0 if DP not enabled)
  • data: One of Stored, Removed, or Cleared

For BlockStored events:

  • token_ids: List of token IDs for the stored blocks
  • block_hashes: List of sequence block hashes from the engine’s block manager. These are cumulative hashes that incorporate all tokens from the start of the sequence up to and including the current block (not just the tokens within that block). This enables prefix matching across requests.
  • num_block_tokens: List with the number of tokens in each block (every entry should equal kv_block_size)
  • parent_hash: Hash of the parent block. Required for all blocks except the first block in a sequence (which has no parent).
  • lora_name: LoRA adapter name string (omit or None for base model). When set, the adapter name is incorporated into block hash computation so that blocks for different LoRA adapters (or the base model) are never conflated.
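
To make the idea of cumulative block hashes concrete, here is a minimal sketch that chains each block's hash to its parent's digest. The function name `chained_block_hashes` and the choice of BLAKE2b are assumptions made for illustration only; this is not the hashing scheme used by Dynamo or by any particular engine, and your block manager's own hashes are what should be published.

```python
import hashlib
import struct


def chained_block_hashes(token_ids: list[int], block_size: int) -> list[int]:
    """Return one cumulative hash per complete block of token_ids."""
    hashes: list[int] = []
    parent = b""  # the first block has no parent
    full = len(token_ids) - len(token_ids) % block_size  # ignore a partial tail
    for start in range(0, full, block_size):
        block = token_ids[start:start + block_size]
        # Hash the parent digest together with this block's tokens, so the
        # result depends on every token from the start of the sequence.
        payload = parent + struct.pack(f"<{len(block)}q", *block)
        digest = hashlib.blake2b(payload, digest_size=8).digest()
        # Interpret the 8-byte digest as a signed 64-bit integer.
        hashes.append(int.from_bytes(digest, "little", signed=True))
        parent = digest
    return hashes


a = chained_block_hashes([1, 2, 3, 4, 5, 6, 7, 8], block_size=4)
b = chained_block_hashes([1, 2, 3, 4, 9, 9, 9, 9], block_size=4)
c = chained_block_hashes([0, 0, 0, 0, 5, 6, 7, 8], block_size=4)
```

Sequences `a` and `b` share their first block hash because they share a prefix, while `a` and `c` differ in the second hash even though the second block holds the same tokens, because the prefix differs.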

For BlockRemoved events:

  • block_hashes: List of sequence block hashes being evicted

Direct Publishing

Call publish_stored() and publish_removed() directly from your engine code. The publisher handles event IDs, serialization, and transport.

When to use:

  • Building a custom inference engine from scratch
  • Your engine doesn’t have a ZMQ-based event system
  • You want the simplest integration path

Basic Setup

    from dynamo.llm import KvEventPublisher


    class CustomEnginePublisher:
        def __init__(self, component, block_size: int, dp_rank: int = 0):
            self.block_size = block_size
            self.kv_publisher = KvEventPublisher(
                component=component,
                kv_block_size=block_size,
                dp_rank=dp_rank,
            )

        def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int],
                             parent_hash: int | None = None,
                             lora_name: str | None = None):
            """Call after KV cache blocks are allocated."""
            # One entry per block; every block holds block_size tokens.
            num_block_tokens = [self.block_size] * len(block_hashes)
            self.kv_publisher.publish_stored(
                token_ids=token_ids,
                num_block_tokens=num_block_tokens,
                block_hashes=block_hashes,
                parent_hash=parent_hash,
                lora_name=lora_name,
            )

        def on_blocks_removed(self, block_hashes: list[int]):
            """Call when KV cache blocks are evicted."""
            self.kv_publisher.publish_removed(block_hashes=block_hashes)

Integration with Your Engine

    from dynamo.llm import register_model


    async def main():
        # Register the engine with Dynamo; my_generate_fn is your generate handler.
        component, endpoint = await register_model(
            model="my-model",
            generator=my_generate_fn,
        )

        publisher = CustomEnginePublisher(
            component=component,
            block_size=16,  # Match your engine's block size
        )

        # Hook these callbacks into your engine's cache events.
        def on_prefill_complete(request_id, token_ids, blocks):
            block_hashes = [block.hash for block in blocks]
            publisher.on_blocks_stored(token_ids=token_ids, block_hashes=block_hashes)

        def on_cache_eviction(evicted_blocks):
            block_hashes = [block.hash for block in evicted_blocks]
            publisher.on_blocks_removed(block_hashes=block_hashes)
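
The store and evict hooks can be exercised without a running Dynamo deployment by swapping in a stand-in publisher. The sketch below is only an illustration: `RecordingPublisher` and `ToyBlockCache` are hypothetical names, the stand-in merely mirrors the `publish_stored()` / `publish_removed()` method names, and the oldest-first eviction policy is a simplification that a real engine is unlikely to share.

```python
from collections import OrderedDict


class RecordingPublisher:
    """Hypothetical test double: records calls instead of publishing them."""

    def __init__(self):
        self.calls = []

    def publish_stored(self, token_ids, num_block_tokens, block_hashes,
                       parent_hash=None, lora_name=None):
        self.calls.append(("stored", list(block_hashes), parent_hash))

    def publish_removed(self, block_hashes):
        self.calls.append(("removed", list(block_hashes)))


class ToyBlockCache:
    """Hypothetical block cache that reports every change to a publisher."""

    def __init__(self, publisher, block_size: int, capacity: int):
        self.publisher = publisher
        self.block_size = block_size
        self.capacity = capacity
        self.blocks = OrderedDict()  # block hash -> tokens, oldest first

    def store(self, token_ids, block_hashes, parent_hash=None):
        for i, h in enumerate(block_hashes):
            start = i * self.block_size
            self.blocks[h] = token_ids[start:start + self.block_size]
        # Publish only after the blocks are actually in the cache.
        self.publisher.publish_stored(
            token_ids=token_ids,
            num_block_tokens=[self.block_size] * len(block_hashes),
            block_hashes=block_hashes,
            parent_hash=parent_hash,
        )
        self._evict_to_capacity()

    def _evict_to_capacity(self):
        evicted = []
        while len(self.blocks) > self.capacity:
            h, _ = self.blocks.popitem(last=False)  # drop the oldest block
            evicted.append(h)
        if evicted:
            self.publisher.publish_removed(block_hashes=evicted)


pub = RecordingPublisher()
cache = ToyBlockCache(pub, block_size=2, capacity=2)
cache.store([1, 2, 3, 4], block_hashes=[101, 102])
# Extending the sequence: the parent is the last block already stored.
cache.store([5, 6], block_hashes=[103], parent_hash=102)
```

After the second call, `pub.calls` holds two stored events followed by one removed event for block 101, which is the order a router would need to see them in.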

ZMQ Relay (For Engines with Raw KV Events)

For engines that already publish raw KV events over a ZMQ socket (like vLLM and SGLang), use the same KvEventPublisher with a zmq_endpoint. The publisher subscribes to the ZMQ socket and relays events to the event plane automatically.

When to use:

  • Your engine already publishes KV events via ZMQ (like vLLM or SGLang)
  • You want to decouple event publishing from your engine’s main loop

Setup

Pass zmq_endpoint (and, optionally, zmq_topic) to the same KvEventPublisher:

    from dynamo.llm import KvEventPublisher

    kv_publisher = KvEventPublisher(
        component=component,
        kv_block_size=block_size,
        zmq_endpoint="tcp://127.0.0.1:5557",  # Where your engine publishes
        zmq_topic="",  # Subscribe to all topics
    )

No further calls to publish_stored() / publish_removed() are needed — the publisher reads events from the ZMQ socket and forwards them automatically.

ZMQ Wire Format

Each ZMQ message consists of three frames, in a format compatible with vLLM and SGLang:

| Frame | Description |
| --- | --- |
| 1 | Topic (empty string for all topics) |
| 2 | Sequence number (8 bytes, big-endian) |
| 3 | Msgpack payload: [timestamp, [events], dp_rank] |
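
As a rough sketch of the framing in the table, the three frames can be assembled with the standard library alone. The helper names `build_frames` and `parse_seq` are hypothetical, treating the sequence number as unsigned is an assumption, and the payload is passed in as already-encoded bytes so the example does not depend on a msgpack library.

```python
import struct


def build_frames(topic: bytes, seq: int, payload: bytes) -> list[bytes]:
    """Assemble the three frames of one message."""
    # ">Q" packs an unsigned 64-bit integer in big-endian byte order.
    return [topic, struct.pack(">Q", seq), payload]


def parse_seq(frames: list[bytes]) -> int:
    """Read the sequence number back out of frame 2."""
    (seq,) = struct.unpack(">Q", frames[1])
    return seq


frames = build_frames(b"", 7, b"<msgpack bytes>")
```

A frame list like this is what a multipart send on the engine's ZMQ socket would carry; a gap in the sequence numbers seen by a subscriber indicates dropped messages.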

Each event in the payload is a dictionary with a type field (BlockStored, BlockRemoved, or AllBlocksCleared).

For BlockStored:

    {
        "type": "BlockStored",
        "block_hashes": [signed_i64, ...],       # Sequence block hashes
        "parent_block_hash": signed_i64 | None,  # Parent hash
        "token_ids": [int, ...],                 # Token IDs
        "block_size": int,                       # Tokens per block
        "lora_name": str | None,                 # LoRA adapter name
    }

For BlockRemoved:

    {
        "type": "BlockRemoved",
        "block_hashes": [signed_i64, ...],
    }

For AllBlocksCleared:

    {"type": "AllBlocksCleared"}
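
The event dictionaries above can be built with a few small helpers before the payload is msgpack-encoded. This is a sketch under stated assumptions: the helper names (`block_stored`, `block_removed`, `make_payload`) are hypothetical, the timestamp is assumed to be seconds as a float, and the msgpack encoding step itself is left out so the example runs with the standard library only.

```python
import time

I64_MIN, I64_MAX = -(2**63), 2**63 - 1


def _check_i64(values):
    """Reject hashes that cannot be represented as signed 64-bit integers."""
    for v in values:
        if not I64_MIN <= v <= I64_MAX:
            raise ValueError(f"hash {v} does not fit in a signed 64-bit integer")


def block_stored(block_hashes, token_ids, block_size,
                 parent_block_hash=None, lora_name=None) -> dict:
    _check_i64(block_hashes)
    return {
        "type": "BlockStored",
        "block_hashes": list(block_hashes),
        "parent_block_hash": parent_block_hash,
        "token_ids": list(token_ids),
        "block_size": block_size,
        "lora_name": lora_name,
    }


def block_removed(block_hashes) -> dict:
    _check_i64(block_hashes)
    return {"type": "BlockRemoved", "block_hashes": list(block_hashes)}


def make_payload(events, dp_rank=0, timestamp=None) -> list:
    """Build [timestamp, [events], dp_rank], ready for msgpack encoding."""
    ts = time.time() if timestamp is None else timestamp
    return [ts, list(events), dp_rank]


payload = make_payload(
    [
        block_stored([11, -12], [1, 2, 3, 4], block_size=2),
        block_removed([11]),
        {"type": "AllBlocksCleared"},
    ],
    timestamp=0.0,
)
```

Validating the hash range at construction time surfaces an out-of-range value in the engine rather than as a decoding failure on the subscriber side.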

API Reference

KvEventPublisher

    KvEventPublisher(
        component: Component,
        kv_block_size: int,
        dp_rank: int = 0,
        enable_local_indexer: bool = False,
        zmq_endpoint: str | None = None,  # Set for relay mode
        zmq_topic: str | None = None,     # Defaults to "" when zmq_endpoint is set
    )

| Parameter | Description |
| --- | --- |
| component | The Dynamo component this publisher belongs to |
| kv_block_size | Number of tokens per block (must be > 0, must match your engine) |
| dp_rank | Data parallel rank (defaults to 0) |
| enable_local_indexer | Enable a worker-local KV indexer for direct overlap queries |
| zmq_endpoint | ZMQ endpoint to subscribe to for relay mode (e.g. "tcp://127.0.0.1:5557") |
| zmq_topic | ZMQ topic filter (defaults to "" = all topics) |

publish_stored()

    publish_stored(
        token_ids: list[int],
        num_block_tokens: list[int],
        block_hashes: list[int],
        parent_hash: int | None = None,
        block_mm_infos: list[dict | None] | None = None,
        lora_name: str | None = None,
    )

Publish a block-stored event. Event IDs are managed internally. When lora_name is provided, the adapter name is mixed into block hash computation so blocks cached under different adapters produce distinct hashes.

publish_removed()

    publish_removed(block_hashes: list[int])

Publish a block-removed event. Event IDs are managed internally.

shutdown()

    shutdown()

Stop background tasks (ZMQ listener, event forwarding).
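
One way to make sure shutdown() runs even when the engine loop raises is to wrap the publisher in a small context manager. The sketch below is an assumption about usage, not part of the Dynamo API: `publisher_scope` is a hypothetical helper, and `FakePublisher` is a stand-in that exists only so the example runs without Dynamo installed.

```python
from contextlib import contextmanager


@contextmanager
def publisher_scope(publisher):
    """Yield the publisher and call shutdown() on exit, even after an error."""
    try:
        yield publisher
    finally:
        publisher.shutdown()


class FakePublisher:
    """Hypothetical stand-in exposing only the shutdown() method."""

    def __init__(self):
        self.stopped = False

    def shutdown(self):
        self.stopped = True


fake = FakePublisher()
try:
    with publisher_scope(fake):
        raise RuntimeError("engine crashed")
except RuntimeError:
    pass  # the error still propagates; shutdown() has already run
```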

Best Practices

  1. kv_block_size must match your engine’s actual block size.

  2. parent_hash is required for all blocks except the first in a sequence — it links blocks to enable prefix matching.

  3. Block hashes are signed 64-bit integers in the Python API. The publisher handles conversion internally.

  4. Event ordering is automatic — the publisher assigns monotonically increasing event IDs. You do not need to track event IDs yourself.
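
If your engine's block manager produces unsigned 64-bit hashes (an assumption about your engine, not something every engine does), they can be reinterpreted as signed 64-bit integers before being passed to the Python API. The helper names below are hypothetical; the sketch only shows the two's-complement reinterpretation.

```python
def to_signed_i64(h: int) -> int:
    """Reinterpret an unsigned 64-bit hash as a signed 64-bit integer."""
    if not 0 <= h < 2**64:
        raise ValueError("expected an unsigned 64-bit value")
    # Values with the top bit set map to the negative half of the range.
    return h - 2**64 if h >= 2**63 else h


def to_unsigned_u64(h: int) -> int:
    """Inverse of to_signed_i64 for values in the signed 64-bit range."""
    return h % 2**64


signed = [to_signed_i64(h) for h in (5, 2**63, 2**64 - 1)]
```

The mapping is a bijection, so distinct unsigned hashes stay distinct after conversion and prefix matching is unaffected.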

See Also

  • Event Plane: Transport options (NATS, ZMQ) and configuration
  • Router Guide: Configuration, tuning, and production setup
  • Router Design: Architecture details and event transport modes