Skip to content

Instantly share code, notes, and snippets.

@surya501
Created March 13, 2026 05:45
Show Gist options
  • Select an option

  • Save surya501/3becd5e9b8aa870ed0afb6b912188841 to your computer and use it in GitHub Desktop.

Select an option

Save surya501/3becd5e9b8aa870ed0afb6b912188841 to your computer and use it in GitHub Desktop.
LanceDB vector store: ingest, search, benchmark & visual search (PEP 723 single-file script)
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "h5py",
# "lancedb",
# "numpy",
# "pillow",
# "tqdm",
# ]
# ///
"""
LanceDB vector store: ingest, search, and benchmark.
Single-file PEP 723 script — run directly with `uv run`, no setup needed.
Dependencies are declared in the inline metadata above and resolved
automatically by uv.
Ingestion is a one-time offline step. Run it locally (or in CI) to produce
the LanceDB store directory, then bake that directory into the Docker image.
The server only needs the reusable API section (marked below) to query it.
Subcommands:
# Ingest 117K vectors from H5 + CSV into LanceDB (one-time):
uv run lancedb_tool.py ingest --copies 1
# Ingest 825K vectors (7 synthetic copies, for benchmarking only):
uv run lancedb_tool.py ingest --copies 7
# Quick ingest test (1000 vectors):
uv run lancedb_tool.py ingest --copies 1 --limit 1000
# Search latency benchmark (sequential + async gather):
uv run lancedb_tool.py bench --n 100 --top-k 10
# Visual search — saves composite PNGs to results/:
uv run lancedb_tool.py visual --n 3 --top-k 10
Expects:
mount/embeddings/embeddings.h5 — H5 file with "keys" and "vectors" datasets
mount/stamp_database.csv — CSV with image, country, shape, aspect_ratio
mount/cropped/ — PNG crops named by UUID (for visual mode)
The reusable API section (marked below) can be copied directly into a
FastAPI/Flask project. Everything else is CLI/benchmark tooling.
"""
from __future__ import annotations
import argparse
import asyncio
import csv
import functools
import gc
import random
import shutil
import sqlite3
import time
from dataclasses import dataclass
from pathlib import Path
import h5py
import lancedb
import numpy as np
from tqdm import tqdm
# ===================================================================
# REUSABLE API — copy this section into your FastAPI / Flask project
#
# Functions:
# get_table() → async; call once at startup, reuse the handle
# search_single() → async; one query, optional .where() pre-filter
# search_batch() → async; N queries via asyncio.gather
# build_filter_params() → build FilterParams from a metadata dict
# build_where_clause() → SQL WHERE string for .where() pre-filter
#
# Data types:
# FilterParams, SearchResult
#
# For Flask (sync), wrap calls with asyncio.run():
# table = asyncio.run(get_table())
# results = asyncio.run(search_single(table, vec))
# ===================================================================
# On-disk location of the LanceDB store directory (produced by ingest,
# baked into the Docker image for serving).
LANCE_PATH = Path("mount/embeddings/lancedb_store_800k")
# Name of the single table inside the store.
TABLE_NAME = "stamp_embeddings"
# Relative tolerance applied around a query's aspect ratio when building
# the pre-filter window (see build_filter_params).
ASPECT_RATIO_TOL = 0.34
@dataclass
class FilterParams:
    """Pre-filter parameters for one vector search.

    Built from request metadata by build_filter_params() and rendered to a
    SQL WHERE string by build_where_clause().
    """

    country: str            # exact-match country
    valid_shapes: set[str]  # shapes compatible with the query shape
    ar_min: float           # lower bound on aspect_ratio (inclusive)
    ar_max: float           # upper bound on aspect_ratio (inclusive)
@dataclass
class SearchResult:
    """One search hit returned by search_single()/search_batch()."""

    uid: str         # row id from the table (includes the "-<copy>" suffix)
    distance: float  # raw _distance value reported by LanceDB
    metadata: dict   # country/shape/aspect_ratio when a filter was used, else {}
async def get_table(path: Path = LANCE_PATH) -> lancedb.table.AsyncTable:
    """Open the stamp-embeddings table through LanceDB's async API.

    Typical startup cost: 2-15ms. Store the returned table in app state
    (e.g. FastAPI lifespan / Flask global) and reuse across requests.
    """
    connection = await lancedb.connect_async(str(path))
    table = await connection.open_table(TABLE_NAME)
    return table
def _valid_shapes_for(shape: str) -> set[str]:
"""Map a query shape to the set of compatible target shapes."""
shape = shape.upper()
if shape == "HR":
return {"HR", "S"}
if shape == "VR":
return {"VR", "S"}
return {"HR", "VR", "S"}
def _sql_escape(s: str) -> str:
"""Escape single quotes for SQL string literals."""
return s.replace("'", "''")
def build_where_clause(params: FilterParams) -> str:
    """Render a SQL WHERE clause for a LanceDB .where() pre-filter."""
    shape_literals = ", ".join(
        "'{}'".format(_sql_escape(s)) for s in params.valid_shapes
    )
    conditions = [
        f"country = '{_sql_escape(params.country)}'",
        f"shape IN ({shape_literals})",
        f"aspect_ratio >= {params.ar_min}",
        f"aspect_ratio <= {params.ar_max}",
    ]
    return " AND ".join(conditions)
def build_filter_params(meta: dict) -> FilterParams:
    """Construct FilterParams from a metadata dict (country, shape, aspect_ratio).

    Use in API handlers to construct filters from request data without a
    DB round-trip.
    """
    aspect = meta.get("aspect_ratio", 0.0)
    lo = aspect * (1 - ASPECT_RATIO_TOL)
    hi = aspect * (1 + ASPECT_RATIO_TOL)
    return FilterParams(
        country=meta.get("country", ""),
        valid_shapes=_valid_shapes_for(meta.get("shape", "")),
        ar_min=lo,
        ar_max=hi,
    )
async def search_single(
    table: lancedb.table.AsyncTable,
    query_vec: list[float],
    params: FilterParams | None = None,
    top_k: int = 10,
) -> list[SearchResult]:
    """Single vector search, optionally pre-filtered via .where().

    vector_search() builds the query synchronously and executes on await,
    which avoids an extra await during query construction.
    """
    builder = table.vector_search(query_vec).limit(top_k)
    if params is not None:
        builder = builder.where(build_where_clause(params))
    rows = (await builder.to_arrow()).to_pylist()
    hits: list[SearchResult] = []
    for row in rows:
        if params is not None:
            meta = {
                "country": row["country"],
                "shape": row["shape"],
                "aspect_ratio": row["aspect_ratio"],
            }
        else:
            meta = {}
        hits.append(
            SearchResult(uid=row["id"], distance=row["_distance"], metadata=meta)
        )
    return hits
async def search_batch(
    table: lancedb.table.AsyncTable,
    query_vecs: list[list[float]],
    params_list: list[FilterParams | None],
    top_k: int = 10,
) -> list[list[SearchResult]]:
    """Run N queries concurrently via asyncio.gather.

    The Rust engine parallelises internally, so gather mainly keeps the
    event loop responsive rather than adding raw throughput.

    Raises:
        ValueError: if query_vecs and params_list differ in length.
            (Previously the extra entries were silently dropped by zip.)
    """
    # strict=True surfaces a caller bug instead of silently truncating the
    # shorter list; requires Python >= 3.10 (script pins >= 3.11).
    tasks = [
        search_single(table, vec, params, top_k)
        for vec, params in zip(query_vecs, params_list, strict=True)
    ]
    return await asyncio.gather(*tasks)
# ===================================================================
# END REUSABLE API
# ===================================================================
# ===================================================================
# TIMING DECORATOR — for ad-hoc profiling
# ===================================================================
def timed(label: str | None = None):
"""Decorator that prints wall-clock time for sync or async functions."""
def decorator(fn): # type: ignore[return]
name = label or fn.__name__
if asyncio.iscoroutinefunction(fn):
@functools.wraps(fn)
async def async_wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
t0 = time.perf_counter()
result = await fn(*args, **kwargs)
print(f"[{name}] {(time.perf_counter()-t0)*1000:.0f}ms")
return result
return async_wrapper
@functools.wraps(fn)
def sync_wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
t0 = time.perf_counter()
result = fn(*args, **kwargs)
print(f"[{name}] {(time.perf_counter()-t0)*1000:.0f}ms")
return result
return sync_wrapper
return decorator
# ===================================================================
# INGEST
# ===================================================================
@dataclass(frozen=True)
class IngestConfig:
    """Static paths and tuning knobs for the ingest subcommand."""

    h5_path: Path = Path("mount/embeddings/embeddings.h5")  # "keys"/"vectors" datasets
    csv_path: Path = Path("mount/stamp_database.csv")       # metadata source
    lance_path: Path = LANCE_PATH                           # output store directory
    table_name: str = TABLE_NAME
    batch_size: int = 500     # vectors per LanceDB add() / H5 read
    offset_block: int = 300   # width of the synthetic-offset window per copy


# Module-level config used by ingest(); paths are relative to the CWD.
CFG = IngestConfig()
def _setup_sqlite(csv_path: Path) -> sqlite3.Connection:
"""Build an in-memory SQLite DB from the stamp CSV for fast batch lookups."""
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute(
"CREATE TABLE stamps ("
" image TEXT PRIMARY KEY,"
" country TEXT,"
" shape TEXT,"
" aspect_ratio REAL"
")"
)
with open(csv_path, newline="", encoding="utf-8") as fh:
reader = csv.DictReader(fh)
conn.executemany(
"INSERT OR IGNORE INTO stamps (image, country, shape, aspect_ratio) "
"VALUES (:image, :country, :shape, :aspect_ratio)",
reader,
)
conn.commit()
return conn
def _get_metadata_batch(conn: sqlite3.Connection, uids: list[str]) -> list[dict]:
"""Fetch metadata for a batch of UIDs from the in-memory SQLite cache."""
placeholders = ",".join("?" for _ in uids)
rows = conn.execute(
f"SELECT image, country, shape, aspect_ratio FROM stamps "
f"WHERE image IN ({placeholders})",
uids,
).fetchall()
lookup = {r["image"]: dict(r) for r in rows}
return [
lookup.get(uid, {"country": "", "shape": "", "aspect_ratio": 0.0})
for uid in uids
]
def _apply_offset(
raw: np.ndarray, start_idx: int, offset_start: int, offset_end: int,
) -> int:
"""Apply mean/2 offset to vectors in the overlap window. Returns count modified."""
lo = max(offset_start - start_idx, 0)
hi = min(offset_end - start_idx, len(raw))
if lo >= hi:
return 0
means = raw[lo:hi].mean(axis=1, keepdims=True)
raw[lo:hi] += means / 2.0
return hi - lo
def ingest(num_copies: int = 7, limit: int | None = None) -> None:
    """Stream H5 vectors + CSV metadata into LanceDB.

    The --copies flag creates synthetic duplicates for benchmarking at scale.
    Copy 0 = original vectors (normalized). Copies 1-N apply a small offset
    to a sliding window of 300 vectors, then re-normalize — producing near-
    duplicates that stress the ANN index. IDs are suffixed with -0 through -N.
    For production use, set --copies 1 (no duplicates).
    """
    if CFG.lance_path.exists():
        import subprocess
        # Best-effort strip of extended attributes before rmtree (output is
        # discarded and a nonzero exit is ignored — no check=True).
        # NOTE(review): `xattr` is a macOS tool; presumably a no-op elsewhere.
        subprocess.run(["xattr", "-rc", str(CFG.lance_path)],
                       capture_output=True)
        shutil.rmtree(CFG.lance_path)
        print(f"Removed existing {CFG.lance_path}")
    db = lancedb.connect(str(CFG.lance_path))
    print("Building in-memory SQLite metadata cache...")
    conn = _setup_sqlite(CFG.csv_path)
    table: lancedb.table.Table | None = None
    with h5py.File(str(CFG.h5_path), "r") as f:
        keys_ds: h5py.Dataset = f["keys"]  # type: ignore[assignment]
        vecs_ds: h5py.Dataset = f["vectors"]  # type: ignore[assignment]
        n = keys_ds.shape[0]
        if limit is not None:
            n = min(n, limit)
        dim = vecs_ds.shape[1]
        total = n * num_copies
        print(f"Ingesting {n:,} x {num_copies} = {total:,} vectors ({dim}-dim)")
        t0 = time.perf_counter()
        pbar = tqdm(total=total, unit="vec", desc="Ingesting", dynamic_ncols=True)
        for copy in range(num_copies):
            # Copy 0 keeps vectors unperturbed; copies >= 1 perturb a
            # 300-vector window that slides forward by offset_block per copy.
            if copy == 0:
                offset_start, offset_end = -1, -1
            else:
                offset_start = (copy - 1) * CFG.offset_block
                offset_end = offset_start + CFG.offset_block
            offset_label = (f"offset [{offset_start}:{offset_end}]"
                            if offset_start >= 0 else "original")
            pbar.set_postfix_str(f"copy {copy}/{num_copies-1} ({offset_label})")
            for start in range(0, n, CFG.batch_size):
                end = min(start + CFG.batch_size, n)
                batch_len = end - start
                # H5 keys are bytes; decode to str UUIDs.
                orig_ids = [k.decode("utf-8") for k in keys_ds[start:end]]
                # Row ids get a "-<copy>" suffix so copies don't collide.
                ids = [f"{uid}-{copy}" for uid in orig_ids]
                raw: np.ndarray = vecs_ds[start:end].astype(np.float32)
                if offset_start >= 0:
                    _apply_offset(raw, start, offset_start, offset_end)
                # L2-normalize; zero-norm rows divide by 1 instead of 0.
                norms = np.linalg.norm(raw, axis=1, keepdims=True)
                norms = np.where(norms > 0, norms, 1.0)
                vectors = raw / norms
                metas = _get_metadata_batch(conn, orig_ids)
                records = [
                    {
                        "id": ids[i],
                        "vector": vectors[i].tolist(),
                        "country": metas[i].get("country", ""),
                        "shape": metas[i].get("shape", ""),
                        "aspect_ratio": metas[i].get("aspect_ratio", 0.0),
                    }
                    for i in range(batch_len)
                ]
                # First batch creates the table (schema inferred from the
                # records); subsequent batches append.
                if table is None:
                    table = db.create_table(
                        CFG.table_name, data=records, mode="overwrite",
                    )
                else:
                    table.add(records)
                pbar.update(batch_len)
                # Drop per-batch buffers eagerly to keep peak RSS down.
                del ids, orig_ids, raw, vectors, metas, records
                gc.collect()
        pbar.close()
    conn.close()
    total_time = time.perf_counter() - t0
    print(f"Done: {total:,} vectors in {total_time:.1f}s "
          f"({total / total_time:.0f} vec/s)")
    # Create indexes
    assert table is not None, "No data ingested"
    print("\nCreating vector index (IVF_HNSW_SQ, dot product)...")
    t0 = time.perf_counter()
    # Dot metric is appropriate because every vector was unit-normalized.
    table.create_index(metric="dot", index_type="IVF_HNSW_SQ")
    print(f" Vector index created in {time.perf_counter()-t0:.1f}s")
    # Scalar indexes accelerate the .where() pre-filters used at query time.
    for col in ("country", "shape", "aspect_ratio"):
        t0 = time.perf_counter()
        table.create_scalar_index(col)
        print(f" Scalar index on '{col}' created in {time.perf_counter()-t0:.1f}s")
    print("All indexes created.")
# ===================================================================
# BENCHMARK
# ===================================================================
H5_PATH = Path("mount/embeddings/embeddings.h5")
async def _lookup_filter_params(
    table: lancedb.table.AsyncTable, key: str,
) -> FilterParams:
    """Look up metadata for a key and build filter params (benchmark helper)."""
    builder = (
        table.query()
        .where(f"id = '{_sql_escape(key)}-0'")
        .limit(1)
    )
    rows = (await builder.to_arrow()).to_pylist()
    if rows:
        meta = rows[0]
    else:
        meta = {"country": "", "shape": "", "aspect_ratio": 0.0}
    return build_filter_params(meta)
async def benchmark(n_queries: int = 100, top_k: int = 10) -> None:
    """Latency benchmark: unfiltered vs native-filtered, all async.

    Three phases: sequential unfiltered, sequential filtered (metadata
    lookup excluded from the timed window), and one asyncio.gather burst.
    Prints mean/median/p95 per phase.
    """
    t_start = time.perf_counter()
    t0 = time.perf_counter()
    table = await get_table()
    count = await table.count_rows()
    if count == 0:
        print("Table empty. Run ingest first.")
        return
    print(f"Table: {count:,} vectors (opened in {(time.perf_counter()-t0)*1000:.0f}ms)")
    print(f"Running {n_queries} queries (top_k={top_k}, async)...\n")
    # Load random query vectors from H5
    t0 = time.perf_counter()
    with h5py.File(str(H5_PATH), "r") as f:
        vecs_ds_: h5py.Dataset = f["vectors"]  # type: ignore[assignment]
        keys_ds_: h5py.Dataset = f["keys"]  # type: ignore[assignment]
        total = vecs_ds_.shape[0]
        # Sorted because h5py fancy indexing wants increasing indices.
        indices = sorted(random.sample(range(total), n_queries))
        raw = vecs_ds_[indices].astype(np.float32)
        # L2-normalize, guarding zero-norm rows (divide by 1 instead).
        norms = np.linalg.norm(raw, axis=1, keepdims=True)
        query_vecs = raw / np.where(norms > 0, norms, 1.0)
        query_keys = [keys_ds_[i].decode("utf-8") for i in indices]
    print(f"Loaded {n_queries} query vectors in {(time.perf_counter()-t0)*1000:.0f}ms")
    vecs_as_lists = [query_vecs[i].tolist() for i in range(n_queries)]
    # Warm up
    n_warmup = min(10, n_queries)
    t0 = time.perf_counter()
    print(f"Warming up ({n_warmup} queries)...", end=" ", flush=True)
    for i in range(n_warmup):
        await search_single(table, vecs_as_lists[i], top_k=top_k)
    print(f"done ({(time.perf_counter()-t0)*1000:.0f}ms)")
    # Phase 1: Unfiltered (sequential)
    print("Phase 1/3: Unfiltered search...", end=" ", flush=True)
    lats_unfilt: list[float] = []
    for i in range(n_queries):
        t0 = time.perf_counter()
        await search_single(table, vecs_as_lists[i], top_k=top_k)
        lats_unfilt.append((time.perf_counter() - t0) * 1000)
        # Progress tick roughly every 20% of queries.
        if (i + 1) % max(1, n_queries // 5) == 0:
            print(f"{i+1}/{n_queries}", end=" ", flush=True)
    print("done")
    # Phase 2: Native-filtered (sequential)
    print("Phase 2/3: Native-filtered search...", end=" ", flush=True)
    lats_filtered: list[float] = []
    params_list: list[FilterParams] = []
    for i in range(n_queries):
        # Metadata lookup runs OUTSIDE the timed window so only the
        # filtered vector search itself is measured.
        params = await _lookup_filter_params(table, query_keys[i])
        params_list.append(params)
        t0 = time.perf_counter()
        await search_single(table, vecs_as_lists[i], params, top_k)
        lats_filtered.append((time.perf_counter() - t0) * 1000)
        if (i + 1) % max(1, n_queries // 5) == 0:
            print(f"{i+1}/{n_queries}", end=" ", flush=True)
    print("done")
    # Phase 3: Concurrent batch via asyncio.gather
    print("Phase 3/3: Concurrent gather (all queries at once)...", end=" ", flush=True)
    t0 = time.perf_counter()
    await search_batch(table, vecs_as_lists, [None] * n_queries, top_k)
    gather_unfilt = (time.perf_counter() - t0) * 1000
    t0 = time.perf_counter()
    # Reuses the params built in phase 2.
    await search_batch(table, vecs_as_lists, params_list, top_k)
    gather_filt = (time.perf_counter() - t0) * 1000
    print("done")

    def stats(lats: list[float]) -> str:
        # Summarize a latency list (ms) as mean / median / p95.
        s = sorted(lats)
        return (f"mean={np.mean(s):>8.1f} "
                f"median={np.median(s):>8.1f} "
                f"p95={s[int(0.95 * len(s))]:>8.1f}")

    print(f"\n{'Unfiltered top-' + str(top_k):<30} {stats(lats_unfilt)}")
    print(f"{'Native .where() filter':<30} {stats(lats_filtered)}")
    print(f"\n{'Gather unfiltered':<30} "
          f"total={gather_unfilt:>8.1f}ms "
          f"per-query={gather_unfilt / n_queries:>8.1f}ms")
    print(f"{'Gather filtered':<30} "
          f"total={gather_filt:>8.1f}ms "
          f"per-query={gather_filt / n_queries:>8.1f}ms")
    print(f"\nTotal wall time: {(time.perf_counter()-t_start)*1000:.0f}ms")
# ===================================================================
# VISUAL — filtered search with composite image output
# ===================================================================
# Directory of PNG crops named <uuid>.png (visual-mode input).
CROPS_DIR = Path("mount/cropped")
# Where composite search images are written.
RESULTS_DIR = Path("results")
# Result-thumbnail height in pixels (the query thumb is drawn at 2x).
THUMB_H = 200
def _load_crop(uid: str) -> "Image.Image | None":
    """Return the crop image for *uid*, or None when no PNG exists."""
    from PIL import Image
    crop_path = CROPS_DIR / f"{uid}.png"
    if not crop_path.exists():
        return None
    return Image.open(crop_path)
def _make_composite(
    query_uid: str,
    query_meta: dict,
    results: list[dict],
    out_path: Path,
) -> None:
    """Build a composite PNG: query on left, top-k results on right.

    Bails out (with a console note) when the query crop or all result crops
    are missing from CROPS_DIR.
    """
    from PIL import Image, ImageDraw
    q_img = _load_crop(query_uid)
    if q_img is None:
        print(f" Query image not found: {query_uid}")
        return
    r_imgs = []
    for r in results:
        # Row ids carry a "-<copy>" suffix; strip it to recover the crop UUID.
        rid = r["id"].rsplit("-", 1)[0]
        img = _load_crop(rid)
        if img:
            r_imgs.append((img, r))
    if not r_imgs:
        print(" No result images found")
        return

    def scale(img: "Image.Image", h: int) -> "Image.Image":
        # Resize to height h, preserving aspect ratio.
        w = int(img.width * h / img.height)
        return img.resize((w, h), Image.LANCZOS)

    # Query thumb at double height; result thumbs at THUMB_H.
    q_thumb = scale(q_img, THUMB_H * 2)
    r_thumbs = [(scale(img, THUMB_H), r) for img, r in r_imgs]
    gap = 10      # padding between cells, px
    label_h = 20  # vertical space reserved for a text label, px
    # Results grid: up to 5 columns, as many rows as needed (ceil division).
    cols = min(5, len(r_thumbs))
    rows_n = (len(r_thumbs) + cols - 1) // cols
    max_r_w = max(t[0].width for t in r_thumbs)
    grid_w = cols * (max_r_w + gap) - gap
    grid_h = rows_n * (THUMB_H + label_h + gap) - gap
    total_w = q_thumb.width + gap * 3 + grid_w + gap
    total_h = max(q_thumb.height + label_h + gap, grid_h + label_h + gap) + gap * 2
    canvas = Image.new("RGB", (total_w, total_h), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    # Query image (left)
    qx, qy = gap, gap + label_h
    canvas.paste(q_thumb, (qx, qy))
    # Blue header with the truncated query UID.
    draw.text((qx, gap), f"QUERY: {query_uid[:8]}...", fill=(0, 0, 200))
    meta_label = (f"{query_meta.get('country', '')[:30]} | "
                  f"{query_meta.get('shape', '')} | "
                  f"ar={query_meta.get('aspect_ratio', 0):.2f}")
    # Grey metadata line just under the query thumb.
    draw.text((qx, qy + q_thumb.height + 2), meta_label, fill=(100, 100, 100))
    # Results grid (right)
    rx_start = q_thumb.width + gap * 3
    for idx, (thumb, r) in enumerate(r_thumbs):
        col = idx % cols
        row = idx // cols
        x = rx_start + col * (max_r_w + gap)
        y = gap + label_h + row * (THUMB_H + label_h + gap)
        canvas.paste(thumb, (x, y))
        # Red rank + distance label above each thumb.
        draw.text((x, y - label_h + 4),
                  f"#{idx+1} d={r['_distance']:.3f}", fill=(200, 0, 0))
        r_meta = (f"{r.get('country', '')[:20]} | "
                  f"{r.get('shape', '')} | "
                  f"ar={r.get('aspect_ratio', 0):.2f}")
        draw.text((x, y + THUMB_H + 2), r_meta, fill=(100, 100, 100))
    canvas.save(out_path)
    print(f" Saved: {out_path} ({canvas.width}x{canvas.height})")
async def visual(n_queries: int = 3, top_k: int = 10) -> None:
    """Run filtered searches and save composite result images to results/."""
    RESULTS_DIR.mkdir(exist_ok=True)
    table = await get_table()
    count = await table.count_rows()
    print(f"Table: {count:,} vectors")
    # Sample rows and keep the first hit per country (distinct countries).
    sample = (await table.query().limit(500).to_arrow()).to_pylist()
    random.shuffle(sample)
    picked: list[dict] = []
    used_countries: set[str] = set()
    for row in sample:
        if row["country"] in used_countries:
            continue
        used_countries.add(row["country"])
        picked.append(row)
        if len(picked) >= n_queries:
            break
    for qi, q in enumerate(picked):
        uid = q["id"].rsplit("-", 1)[0]
        print(f"\nQuery {qi+1}: {uid} ({q['country']})")
        where = build_where_clause(build_filter_params(q))
        results = (
            await table.vector_search(q["vector"])
            .where(where).limit(top_k).to_arrow()
        ).to_pylist()
        print(f" {len(results)} filtered results")
        out_path = RESULTS_DIR / f"search_q{qi+1}_{uid[:8]}.png"
        _make_composite(uid, q, results, out_path)
# ===================================================================
# CLI
# ===================================================================
def main() -> None:
    """CLI entry point: dispatch to the ingest / bench / visual subcommands."""
    parser = argparse.ArgumentParser(
        description="LanceDB: ingest vectors and run search benchmarks",
    )
    sub = parser.add_subparsers(dest="command", required=True)

    ingest_parser = sub.add_parser("ingest", help="Ingest vectors from H5 + CSV")
    ingest_parser.add_argument("--copies", type=int, default=7,
                               help="Number of copies (1 = original only)")
    ingest_parser.add_argument("--limit", type=int, default=None,
                               help="Max vectors per copy (for testing)")

    bench_parser = sub.add_parser("bench", help="Run search benchmark")
    bench_parser.add_argument("--n", type=int, default=100,
                              help="Number of queries")
    bench_parser.add_argument("--top-k", type=int, default=10,
                              help="Results per query")

    visual_parser = sub.add_parser("visual",
                                   help="Visual search with composite images")
    visual_parser.add_argument("--n", type=int, default=3,
                               help="Number of queries (distinct countries)")
    visual_parser.add_argument("--top-k", type=int, default=10,
                               help="Results per query")

    args = parser.parse_args()
    if args.command == "ingest":
        ingest(num_copies=args.copies, limit=args.limit)
    elif args.command == "bench":
        asyncio.run(benchmark(n_queries=args.n, top_k=args.top_k))
    elif args.command == "visual":
        asyncio.run(visual(n_queries=args.n, top_k=args.top_k))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment