Avoiding hash collisions in prefix cache

Prefix caching uses Python's built-in hash() to hash blocks of tokens for reuse. A hash collision can cause one request's cached context to be served to another, which is a serious information leak.
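
To make the risk concrete: Python's built-in hash() is not collision-resistant, and in CPython you can even produce a collision by hand, because -1 is reserved internally as an error sentinel (this is a CPython implementation detail, shown here purely for illustration):

```python
# CPython maps hash(-1) to -2, because -1 is reserved as an
# internal error code -- so -1 and -2 collide by design.
assert hash(-1) == hash(-2)
print("hash(-1) =", hash(-1))
print("hash(-2) =", hash(-2))
```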

I want to replace hash() with hashlib.sha256() to remove this vulnerability. As performance is key, I compared the two and see an increase of about 3 ms for hashing 50k input tokens (block size 16, 3125 blocks, Apple M2, Python 3.12).

  • hash(): 0.83 ms (mean error 0.011ms)
  • sha256(): 3.64 ms (mean error 0.001ms)
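
A minimal sketch of the proposed replacement (function and parameter names are illustrative, not vLLM's actual API): serialize the block key, hash the bytes with hashlib.sha256, and convert the digest to an int so it can be stored wherever hash() output was stored before.

```python
import hashlib
import pickle
from typing import Optional, Tuple

def sha256_block_hash(parent_hash: Optional[int],
                      token_ids: Tuple[int, ...],
                      extra_keys: Tuple[str, ...] = ()) -> int:
    """Collision-resistant replacement for hash() on a block key.

    Serializes (parent_hash, token_ids, extra_keys) with pickle and
    hashes the bytes with SHA-256.  The 256-bit digest is returned as
    an int, so callers can keep treating block hashes as integers.
    """
    key = (parent_hash, token_ids, extra_keys)
    data = pickle.dumps(key, protocol=pickle.HIGHEST_PROTOCOL)
    return int.from_bytes(hashlib.sha256(data).digest(), byteorder="big")

# Chain block hashes the way prefix caching does:
h1 = sha256_block_hash(None, tuple(range(16)))
h2 = sha256_block_hash(h1, tuple(range(16, 32)))
```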

The impact on output tokens is even smaller, as a single generation step creates only a few new tokens (e.g., for a 1000-token response, total latency would increase by about 0.1 ms).

So overall a minor increase in latency.

wdyt?

Thanks for bringing this up; the proposed change sounds reasonable. As you pointed out, performance is key here, so it would be good to back it up with end-to-end benchmarking results.

Also cc @ywang96 @Woosuk

Try xxhash for high-performance hashing (it can hash at close to RAM bandwidth).

Link to the PR: Use SHA256 instead of hash() in prefix caching by dr75 · Pull Request #15297 · vllm-project/vllm · GitHub

I did some benchmarking (excluding xxhash, because it is not cryptographically secure):

Benchmark hashing 3125 blocks x 16 tokens = 50000 tokens

Baseline
hash(): 0.33636ms

Hash functions without serialization
sha256: 2.01066ms
blake2b: 1.81043ms
blake3: 2.42824ms

Hash functions with serialization
sha256 with orjson: 3.02108ms
sha256 with pickle: 3.99483ms
sha256 with msgpack: 4.71474ms
sha256 with str: 5.76562ms
sha256 with json: 10.25009ms
==================================================
blake2b with orjson: 2.64448ms
blake2b with pickle: 3.72862ms
blake2b with msgpack: 4.55445ms
blake2b with str: 5.69848ms
blake2b with json: 9.73685ms
==================================================
blake3 with orjson: 3.33751ms
blake3 with pickle: 4.37640ms
blake3 with msgpack: 5.01523ms
blake3 with str: 6.37801ms
blake3 with json: 12.58258ms
==================================================

Looks like hashlib.blake2b + orjson is the fastest safe solution.
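
One blake2b detail worth noting if it replaces sha256: its default digest is 64 bytes, twice the size of SHA-256's, so the resulting int is 512-bit. Passing digest_size=32 keeps the output SHA-256-sized:

```python
import hashlib

data = b"example block key bytes"

# blake2b defaults to a 64-byte (512-bit) digest ...
assert hashlib.blake2b(data).digest_size == 64
# ... but accepts digest_size to match SHA-256's 32 bytes.
assert hashlib.blake2b(data, digest_size=32).digest_size == 32
assert hashlib.sha256(data).digest_size == 32
```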

Benchmark script:

import hashlib
import pickle
import json
import msgpack
import orjson
from blake3 import blake3
import time
from typing import Callable, Any

def timer(fn: Callable, key: str, n_block: int):
    tic = time.perf_counter()
    for _ in range(n_block):
        fn()
    toc = time.perf_counter()
    print(f"{key}: {(toc - tic) * 1000:.5f}ms")

def serialize_pickle(data: Any) -> bytes:
    return pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)

def serialize_orjson(data: Any) -> bytes:
    return orjson.dumps(data, option=orjson.OPT_SORT_KEYS)

def serialize_msgpack(data: Any) -> bytes:
    return msgpack.packb(data)

def serialize_str(data: Any) -> bytes:
    return str(data).encode("utf-8")

def serialize_json(data: Any) -> bytes:
    return json.dumps(data).encode("utf-8")

def get_hash_fn(method: str = "builtin", serialize_method: str = "none"):
    # Fixed sample input resembling a prefix-cache block key.
    parent_hash = hash("None")
    token_ids = tuple(range(16))
    extra_keys = ("hash1", "hash2")
    data = (parent_hash, token_ids, extra_keys)
    # Pre-serialized bytes, so "none" measures hashing cost only.
    bytes_data = serialize_pickle(data)

    if method == "builtin":
        return lambda: hash(data)
    
    # Serialization methods
    serializers = {
        "pickle": serialize_pickle,
        "orjson": serialize_orjson,
        "msgpack": serialize_msgpack,
        "str": serialize_str,
        "json": serialize_json,
        "none": lambda _: bytes_data,
    }
    
    # Hash methods
    hash_functions = {
        "sha256": hashlib.sha256,
        "blake2b": hashlib.blake2b,
        "blake3": blake3,
    }
    
    if serialize_method not in serializers:
        raise ValueError(f"Unknown serialization method: {serialize_method}")
    if method not in hash_functions:
        raise ValueError(f"Unknown hash method: {method}")
    
    serialize_fn = serializers[serialize_method]
    hash_fn = hash_functions[method]

    def _hash_fn():
        inp_bytes = serialize_fn(data)
        return int.from_bytes(hash_fn(inp_bytes).digest(), byteorder="big")
    return _hash_fn

def main():
    n_block = 3125
    print(f"Benchmark hashing {n_block} blocks x 16 tokens = {n_block * 16} tokens")
    
    print("\nBaseline")
    timer(get_hash_fn(method="builtin"), "hash()", n_block)
    
    print("\nHash functions without serialization")
    for method in ["sha256", "blake2b", "blake3"]:
        timer(get_hash_fn(method=method, serialize_method="none"), method, n_block)

    print("\nHash functions with serialization")
    for method in ["sha256", "blake2b", "blake3"]:
        for serialize_method in ["orjson", "pickle", "msgpack", "str", "json"]:
            timer(get_hash_fn(method=method, serialize_method=serialize_method), f"{method} with {serialize_method}", n_block)
        print("=" * 50)

if __name__ == "__main__":
    main()

Actually, orjson does not work with integers larger than 64 bits (it raises a TypeError on them), so pickle seems like the best alternative to me.
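
To illustrate why this matters here: the chained block hash value is itself a 256-bit int, which pickle round-trips without loss, while orjson only accepts integers that fit in 64 bits. A stdlib-only check (orjson itself is left out to keep the snippet dependency-free):

```python
import hashlib
import pickle

# A chained block hash is a 256-bit int -- well beyond 64 bits.
big = int.from_bytes(hashlib.sha256(b"block").digest(), byteorder="big")
assert big > 2 ** 64

# pickle round-trips arbitrary-precision ints without loss ...
assert pickle.loads(pickle.dumps(big)) == big
# ... whereas orjson.dumps(big) raises TypeError for ints outside
# the 64-bit range.
```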

Also attaching the code I used (results in PR), which hashes the same data types as in vLLM.

Also added blake2b

import hashlib
import json
import marshal
import math
import pickle
import random
import time
from typing import List, Optional

import blake3

def sha256_pickle(input) -> int:
    input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    return int.from_bytes(hashlib.sha256(input_bytes).digest(), byteorder="big")

def sha256_json(input) -> int:
    input_str = json.dumps(input, sort_keys=True)
    input_bytes = input_str.encode("utf-8")
    return int.from_bytes(hashlib.sha256(input_bytes).digest(), byteorder="big")

def sha256_str(input) -> int:
    input_bytes = str(input).encode("utf-8")
    return int.from_bytes(hashlib.sha256(input_bytes).digest(), byteorder="big")

def sha256_repr(input) -> int:
    input_bytes = repr(input).encode("utf-8")
    return int.from_bytes(hashlib.sha256(input_bytes).digest(), byteorder="big")

def sha256_marshal(input) -> int:
    input_bytes = marshal.dumps(input)
    return int.from_bytes(hashlib.sha256(input_bytes).digest(), byteorder="big")

def blake3_pickle(input) -> int:
    input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    return int.from_bytes(blake3.blake3(input_bytes).digest(), byteorder="big")

def blake2b_pickle(input) -> int:
    input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    return int.from_bytes(hashlib.blake2b(input_bytes).digest(), byteorder="big")

# for performance comparison only
def hash_pickle(input) -> int:
    key = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    return hash(key)

hash_strategies = {
    "sha256_pickle": sha256_pickle,
    "sha256_json": sha256_json,
    "sha256_str": sha256_str,
    "sha256_repr": sha256_repr,
    "sha256_marshal": sha256_marshal,
    "hash": hash,
    "hash_pickle": hash_pickle,
    "blake3_pickle": blake3_pickle,
    "blake2b_pickle": blake2b_pickle,
}

_block_hash = None
_none_hash = None

def init_hash(hash_name: str):
    global _block_hash
    global _none_hash

    if hash_name in hash_strategies:
        _block_hash = hash_strategies[hash_name]
    else:
        raise ValueError(f"Unknown hash function: {hash_name}")

    none_key = tuple(int(ord(c)) for c in 'None')
    _none_hash = _block_hash(none_key)

def compute_hash(is_first_block: bool,
                 prev_block_hash: Optional[int],
                 cur_block_token_ids: List[int],
                 extra_hash: Optional[int] = None) -> int:
    if is_first_block and prev_block_hash is None:
        prev_block_hash = _none_hash
    return _block_hash((int(is_first_block), prev_block_hash,
                        *cur_block_token_ids, int(extra_hash or 0)))

def split_tokens(tokens: List[int], block_size: int) -> List[List[int]]:
    return [tokens[i:i + block_size] for i in range(0, len(tokens), block_size)]

def test(hash_name: str):
    init_hash(hash_name)
    random.seed(0)
    num_tokens = 50_000

    # warmup
    h = compute_hash(False, _none_hash, (1, 2, 3, 4), 123)

    t = []
    h = _none_hash
    for i in range(50):
        token_ids = [random.randint(0, 100_000) for _ in range(num_tokens)]
        chain = split_tokens(token_ids, 16)
        is_first_block = i == 0

        t0 = time.time()
        for chain_block in chain:
            h = compute_hash(is_first_block, h, chain_block, 123)
            assert h is not None

        t.append((time.time() - t0) * 1000*1000)

    t_mean = sum(t) / len(t)
    t_std = math.sqrt(sum((x - t_mean) ** 2 for x in t) / len(t))
    t_mean_error = t_std / math.sqrt(len(t))

    print(f"{hash_name}, {len(chain)} blocks")
    print(f"\tmean: {t_mean / 1000:.2f}ms")
    print(f"\tstd: {t_std / 1000:.2f}ms")
    print(f"\tmean error: {t_mean_error:.2f}us")
    print(f"\tper block: {t_mean / len(chain):.3f}us")
    return t_mean/1000

functions = hash_strategies.keys()
t = {f:test(f) for f in functions}

print(f"| Method           | Time (ms) | Overhead (ms) |")
print(f"|------------------|-----------|---------------|")
print(f"| hash             | {t['hash']:9.2f} | {t['hash'] - t['hash']:12.2f}  |")
print(f"| hash+pickle      | {t['hash_pickle']:9.2f} | {t['hash_pickle'] - t['hash']:12.2f}  |")
print(f"| blake2b+pickle   | {t['blake2b_pickle']:9.2f} | {t['blake2b_pickle'] - t['hash']:12.2f}  |")
print(f"| blake3+pickle    | {t['blake3_pickle']:9.2f} | {t['blake3_pickle'] - t['hash']:12.2f}  |")
print(f"| sha256+pickle    | {t['sha256_pickle']:9.2f} | {t['sha256_pickle'] - t['hash']:12.2f}  |")
print(f"| sha256+json      | {t['sha256_json']:9.2f} | {t['sha256_json'] - t['hash']:12.2f}  |")
print(f"| sha256+str       | {t['sha256_str']:9.2f} | {t['sha256_str'] - t['hash']:12.2f}  |")
print(f"| sha256+repr      | {t['sha256_repr']:9.2f} | {t['sha256_repr'] - t['hash']:12.2f}  |")
print(f"| sha256+marshal   | {t['sha256_marshal']:9.2f} | {t['sha256_marshal'] - t['hash']:12.2f}  |")
print(f"|(sha256 vs. hash) | ({t['sha256_pickle']:8.2f})| ({t['sha256_pickle'] - t['hash_pickle']:11.2f}) |")

Apple M2

| Method            | Time (ms) | Overhead (ms) |
|-------------------|-----------|---------------|
| hash              | 0.68      | 0.00          |
| hash+pickle       | 1.94      | 1.26          |
| blake2b+pickle    | 3.83      | 3.15          |
| blake3+pickle     | 4.14      | 3.46          |
| sha256+pickle     | 3.58      | 2.90          |
| sha256+json       | 8.85      | 8.17          |
| sha256+str        | 5.62      | 4.94          |
| sha256+repr       | 5.64      | 4.96          |
| sha256+marshal    | 4.90      | 4.21          |
| (sha256 vs. hash) | (3.58)    | (1.64)        |

Intel Xeon @ 2.20 GHz

| Method            | Time (ms) | Overhead (ms) |
|-------------------|-----------|---------------|
| hash              | 1.53      | 0.00          |
| hash+pickle       | 4.42      | 2.89          |
| blake2b+pickle    | 7.30      | 5.77          |
| blake3+pickle     | 7.53      | 6.00          |
| sha256+pickle     | 7.78      | 6.25          |
| sha256+json       | 20.79     | 19.25         |
| sha256+str        | 12.81     | 11.28         |
| sha256+repr       | 12.90     | 11.37         |
| sha256+marshal    | 11.12     | 9.59          |
| (sha256 vs. hash) | (7.78)    | (3.36)        |

Attaching end-to-end measurements for SHA-256 (v1_sha256) vs. hash() (v1_main).

  • hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4
  • Intel Xeon @ 2.2 GHz
  • Nvidia L4
  • Cache miss: initial request with document
  • Cache hit: follow up request with same document and additional query