Created
March 13, 2026 05:45
-
-
Save surya501/3becd5e9b8aa870ed0afb6b912188841 to your computer and use it in GitHub Desktop.
LanceDB vector store: ingest, search, benchmark & visual search (PEP 723 single-file script)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "h5py", | |
| # "lancedb", | |
| # "numpy", | |
| # "pillow", | |
| # "tqdm", | |
| # ] | |
| # /// | |
| """ | |
| LanceDB vector store: ingest, search, and benchmark. | |
| Single-file PEP 723 script — run directly with `uv run`, no setup needed. | |
| Dependencies are declared in the inline metadata above and resolved | |
| automatically by uv. | |
| Ingestion is a one-time offline step. Run it locally (or in CI) to produce | |
| the LanceDB store directory, then bake that directory into the Docker image. | |
| The server only needs the reusable API section (marked below) to query it. | |
| Subcommands: | |
| # Ingest 117K vectors from H5 + CSV into LanceDB (one-time): | |
| uv run lancedb_tool.py ingest --copies 1 | |
| # Ingest 825K vectors (7 synthetic copies, for benchmarking only): | |
| uv run lancedb_tool.py ingest --copies 7 | |
| # Quick ingest test (1000 vectors): | |
| uv run lancedb_tool.py ingest --copies 1 --limit 1000 | |
| # Search latency benchmark (sequential + async gather): | |
| uv run lancedb_tool.py bench --n 100 --top-k 10 | |
| # Visual search — saves composite PNGs to results/: | |
| uv run lancedb_tool.py visual --n 3 --top-k 10 | |
| Expects: | |
| mount/embeddings/embeddings.h5 — H5 file with "keys" and "vectors" datasets | |
| mount/stamp_database.csv — CSV with image, country, shape, aspect_ratio | |
| mount/cropped/ — PNG crops named by UUID (for visual mode) | |
| The reusable API section (marked below) can be copied directly into a | |
| FastAPI/Flask project. Everything else is CLI/benchmark tooling. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import csv | |
| import functools | |
| import gc | |
| import random | |
| import shutil | |
| import sqlite3 | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import h5py | |
| import lancedb | |
| import numpy as np | |
| from tqdm import tqdm | |
| # =================================================================== | |
| # REUSABLE API — copy this section into your FastAPI / Flask project | |
| # | |
| # Functions: | |
| # get_table() → async; call once at startup, reuse the handle | |
| # search_single() → async; one query, optional .where() pre-filter | |
| # search_batch() → async; N queries via asyncio.gather | |
| # build_filter_params() → build FilterParams from a metadata dict | |
| # build_where_clause() → SQL WHERE string for .where() pre-filter | |
| # | |
| # Data types: | |
| # FilterParams, SearchResult | |
| # | |
| # For Flask (sync), wrap calls with asyncio.run(): | |
| # table = asyncio.run(get_table()) | |
| # results = asyncio.run(search_single(table, vec)) | |
| # =================================================================== | |
# Directory of the baked LanceDB store (created by the `ingest` subcommand).
LANCE_PATH = Path("mount/embeddings/lancedb_store_800k")
# Single table name inside the store.
TABLE_NAME = "stamp_embeddings"
# Relative tolerance applied around the query's aspect ratio when pre-filtering.
ASPECT_RATIO_TOL = 0.34
@dataclass
class FilterParams:
    """Metadata pre-filter parameters for a vector search."""
    country: str            # exact country to match
    valid_shapes: set[str]  # allowed shape codes (see _valid_shapes_for)
    ar_min: float           # inclusive lower bound on aspect_ratio
    ar_max: float           # inclusive upper bound on aspect_ratio
@dataclass
class SearchResult:
    """One search hit returned by search_single / search_batch."""
    uid: str        # row id from the LanceDB table
    distance: float # LanceDB `_distance` for the configured metric
    metadata: dict  # country/shape/aspect_ratio when a filter was used, else {}
async def get_table(path: Path = LANCE_PATH) -> lancedb.table.AsyncTable:
    """Connect to the LanceDB store at *path* and open the stamp table.

    Typical startup cost: 2-15ms. Store the returned table in app state
    (e.g. FastAPI lifespan / Flask global) and reuse across requests.
    """
    connection = await lancedb.connect_async(str(path))
    table = await connection.open_table(TABLE_NAME)
    return table
| def _valid_shapes_for(shape: str) -> set[str]: | |
| """Map a query shape to the set of compatible target shapes.""" | |
| shape = shape.upper() | |
| if shape == "HR": | |
| return {"HR", "S"} | |
| if shape == "VR": | |
| return {"VR", "S"} | |
| return {"HR", "VR", "S"} | |
| def _sql_escape(s: str) -> str: | |
| """Escape single quotes for SQL string literals.""" | |
| return s.replace("'", "''") | |
def build_where_clause(params: FilterParams) -> str:
    """Build a SQL WHERE clause for LanceDB .where() pre-filter."""
    shape_literals = ", ".join(
        f"'{_sql_escape(shape)}'" for shape in params.valid_shapes
    )
    conditions = [
        f"country = '{_sql_escape(params.country)}'",
        f"shape IN ({shape_literals})",
        f"aspect_ratio >= {params.ar_min}",
        f"aspect_ratio <= {params.ar_max}",
    ]
    return " AND ".join(conditions)
def build_filter_params(meta: dict) -> FilterParams:
    """Build FilterParams from a metadata dict (country, shape, aspect_ratio).

    Use in API handlers to construct filters from request data without a
    DB round-trip.
    """
    aspect_ratio = meta.get("aspect_ratio", 0.0)
    return FilterParams(
        country=meta.get("country", ""),
        valid_shapes=_valid_shapes_for(meta.get("shape", "")),
        # Symmetric tolerance band around the query's aspect ratio.
        ar_min=aspect_ratio * (1 - ASPECT_RATIO_TOL),
        ar_max=aspect_ratio * (1 + ASPECT_RATIO_TOL),
    )
async def search_single(
    table: lancedb.table.AsyncTable,
    query_vec: list[float],
    params: FilterParams | None = None,
    top_k: int = 10,
) -> list[SearchResult]:
    """Single vector search, optionally pre-filtered via .where().

    Uses vector_search() which is synchronous-to-build, async-to-execute,
    avoiding an extra await on query construction.
    """
    builder = table.vector_search(query_vec).limit(top_k)
    if params is not None:
        builder = builder.where(build_where_clause(params))
    rows = (await builder.to_arrow()).to_pylist()

    def _to_result(row: dict) -> SearchResult:
        # Metadata is only surfaced for filtered searches, matching the
        # original contract; unfiltered results carry an empty dict.
        if params is not None:
            meta = {
                "country": row["country"],
                "shape": row["shape"],
                "aspect_ratio": row["aspect_ratio"],
            }
        else:
            meta = {}
        return SearchResult(uid=row["id"], distance=row["_distance"], metadata=meta)

    return [_to_result(row) for row in rows]
async def search_batch(
    table: lancedb.table.AsyncTable,
    query_vecs: list[list[float]],
    params_list: list[FilterParams | None],
    top_k: int = 10,
) -> list[list[SearchResult]]:
    """Run N queries concurrently via asyncio.gather.

    The Rust engine parallelises internally, so gather mainly keeps the
    event loop responsive rather than adding raw throughput.
    """
    pending = [
        search_single(table, vec, params, top_k)
        for vec, params in zip(query_vecs, params_list)
    ]
    return await asyncio.gather(*pending)
| # =================================================================== | |
| # END REUSABLE API | |
| # =================================================================== | |
| # =================================================================== | |
| # TIMING DECORATOR — for ad-hoc profiling | |
| # =================================================================== | |
| def timed(label: str | None = None): | |
| """Decorator that prints wall-clock time for sync or async functions.""" | |
| def decorator(fn): # type: ignore[return] | |
| name = label or fn.__name__ | |
| if asyncio.iscoroutinefunction(fn): | |
| @functools.wraps(fn) | |
| async def async_wrapper(*args, **kwargs): # type: ignore[no-untyped-def] | |
| t0 = time.perf_counter() | |
| result = await fn(*args, **kwargs) | |
| print(f"[{name}] {(time.perf_counter()-t0)*1000:.0f}ms") | |
| return result | |
| return async_wrapper | |
| @functools.wraps(fn) | |
| def sync_wrapper(*args, **kwargs): # type: ignore[no-untyped-def] | |
| t0 = time.perf_counter() | |
| result = fn(*args, **kwargs) | |
| print(f"[{name}] {(time.perf_counter()-t0)*1000:.0f}ms") | |
| return result | |
| return sync_wrapper | |
| return decorator | |
| # =================================================================== | |
| # INGEST | |
| # =================================================================== | |
@dataclass(frozen=True)
class IngestConfig:
    """Paths and tuning knobs for the one-time ingest step."""
    h5_path: Path = Path("mount/embeddings/embeddings.h5")  # source vectors
    csv_path: Path = Path("mount/stamp_database.csv")       # source metadata
    lance_path: Path = LANCE_PATH                           # output store dir
    table_name: str = TABLE_NAME
    batch_size: int = 500    # vectors per LanceDB write batch
    offset_block: int = 300  # sliding-window size used by synthetic copies
# Module-level config singleton used by the ingest helpers below.
CFG = IngestConfig()
| def _setup_sqlite(csv_path: Path) -> sqlite3.Connection: | |
| """Build an in-memory SQLite DB from the stamp CSV for fast batch lookups.""" | |
| conn = sqlite3.connect(":memory:") | |
| conn.row_factory = sqlite3.Row | |
| conn.execute( | |
| "CREATE TABLE stamps (" | |
| " image TEXT PRIMARY KEY," | |
| " country TEXT," | |
| " shape TEXT," | |
| " aspect_ratio REAL" | |
| ")" | |
| ) | |
| with open(csv_path, newline="", encoding="utf-8") as fh: | |
| reader = csv.DictReader(fh) | |
| conn.executemany( | |
| "INSERT OR IGNORE INTO stamps (image, country, shape, aspect_ratio) " | |
| "VALUES (:image, :country, :shape, :aspect_ratio)", | |
| reader, | |
| ) | |
| conn.commit() | |
| return conn | |
| def _get_metadata_batch(conn: sqlite3.Connection, uids: list[str]) -> list[dict]: | |
| """Fetch metadata for a batch of UIDs from the in-memory SQLite cache.""" | |
| placeholders = ",".join("?" for _ in uids) | |
| rows = conn.execute( | |
| f"SELECT image, country, shape, aspect_ratio FROM stamps " | |
| f"WHERE image IN ({placeholders})", | |
| uids, | |
| ).fetchall() | |
| lookup = {r["image"]: dict(r) for r in rows} | |
| return [ | |
| lookup.get(uid, {"country": "", "shape": "", "aspect_ratio": 0.0}) | |
| for uid in uids | |
| ] | |
| def _apply_offset( | |
| raw: np.ndarray, start_idx: int, offset_start: int, offset_end: int, | |
| ) -> int: | |
| """Apply mean/2 offset to vectors in the overlap window. Returns count modified.""" | |
| lo = max(offset_start - start_idx, 0) | |
| hi = min(offset_end - start_idx, len(raw)) | |
| if lo >= hi: | |
| return 0 | |
| means = raw[lo:hi].mean(axis=1, keepdims=True) | |
| raw[lo:hi] += means / 2.0 | |
| return hi - lo | |
def ingest(num_copies: int = 7, limit: int | None = None) -> None:
    """Stream H5 vectors + CSV metadata into LanceDB.

    The --copies flag creates synthetic duplicates for benchmarking at scale.
    Copy 0 = original vectors (normalized). Copies 1-N apply a small offset
    to a sliding window of 300 vectors, then re-normalize — producing near-
    duplicates that stress the ANN index. IDs are suffixed with -0 through -N.
    For production use, set --copies 1 (no duplicates).
    """
    if CFG.lance_path.exists():
        import subprocess
        # Best-effort: strip extended attributes before rmtree (output is
        # captured and errors ignored; presumably a macOS workaround —
        # NOTE(review): confirm `xattr` availability on target platforms).
        subprocess.run(["xattr", "-rc", str(CFG.lance_path)],
                       capture_output=True)
        shutil.rmtree(CFG.lance_path)
        print(f"Removed existing {CFG.lance_path}")
    db = lancedb.connect(str(CFG.lance_path))
    print("Building in-memory SQLite metadata cache...")
    conn = _setup_sqlite(CFG.csv_path)
    # Created lazily on the first batch so the schema comes from real data.
    table: lancedb.table.Table | None = None
    with h5py.File(str(CFG.h5_path), "r") as f:
        keys_ds: h5py.Dataset = f["keys"]  # type: ignore[assignment]
        vecs_ds: h5py.Dataset = f["vectors"]  # type: ignore[assignment]
        n = keys_ds.shape[0]
        if limit is not None:
            n = min(n, limit)
        dim = vecs_ds.shape[1]
        total = n * num_copies
        print(f"Ingesting {n:,} x {num_copies} = {total:,} vectors ({dim}-dim)")
        t0 = time.perf_counter()
        pbar = tqdm(total=total, unit="vec", desc="Ingesting", dynamic_ncols=True)
        for copy in range(num_copies):
            # Copy 0 keeps originals; each later copy perturbs a different
            # offset_block-sized window (sentinel -1 disables the offset).
            if copy == 0:
                offset_start, offset_end = -1, -1
            else:
                offset_start = (copy - 1) * CFG.offset_block
                offset_end = offset_start + CFG.offset_block
            offset_label = (f"offset [{offset_start}:{offset_end}]"
                            if offset_start >= 0 else "original")
            pbar.set_postfix_str(f"copy {copy}/{num_copies-1} ({offset_label})")
            for start in range(0, n, CFG.batch_size):
                end = min(start + CFG.batch_size, n)
                batch_len = end - start
                # H5 keys are bytes; decode to the UUID strings used by the CSV.
                orig_ids = [k.decode("utf-8") for k in keys_ds[start:end]]
                ids = [f"{uid}-{copy}" for uid in orig_ids]
                raw: np.ndarray = vecs_ds[start:end].astype(np.float32)
                if offset_start >= 0:
                    _apply_offset(raw, start, offset_start, offset_end)
                # L2-normalize (zero vectors divide by 1 instead of 0).
                norms = np.linalg.norm(raw, axis=1, keepdims=True)
                norms = np.where(norms > 0, norms, 1.0)
                vectors = raw / norms
                metas = _get_metadata_batch(conn, orig_ids)
                records = [
                    {
                        "id": ids[i],
                        "vector": vectors[i].tolist(),
                        "country": metas[i].get("country", ""),
                        "shape": metas[i].get("shape", ""),
                        "aspect_ratio": metas[i].get("aspect_ratio", 0.0),
                    }
                    for i in range(batch_len)
                ]
                if table is None:
                    table = db.create_table(
                        CFG.table_name, data=records, mode="overwrite",
                    )
                else:
                    table.add(records)
                pbar.update(batch_len)
                # Free batch buffers eagerly to bound peak memory.
                del ids, orig_ids, raw, vectors, metas, records
                gc.collect()
    pbar.close()
    conn.close()
    total_time = time.perf_counter() - t0
    print(f"Done: {total:,} vectors in {total_time:.1f}s "
          f"({total / total_time:.0f} vec/s)")
    # Create indexes
    assert table is not None, "No data ingested"
    print("\nCreating vector index (IVF_HNSW_SQ, dot product)...")
    t0 = time.perf_counter()
    table.create_index(metric="dot", index_type="IVF_HNSW_SQ")
    print(f"  Vector index created in {time.perf_counter()-t0:.1f}s")
    # Scalar indexes accelerate the .where() pre-filters used at query time.
    for col in ("country", "shape", "aspect_ratio"):
        t0 = time.perf_counter()
        table.create_scalar_index(col)
        print(f"  Scalar index on '{col}' created in {time.perf_counter()-t0:.1f}s")
    print("All indexes created.")
| # =================================================================== | |
| # BENCHMARK | |
| # =================================================================== | |
# Source of benchmark query vectors (same H5 file the ingest step reads).
H5_PATH = Path("mount/embeddings/embeddings.h5")
async def _lookup_filter_params(
    table: lancedb.table.AsyncTable, key: str,
) -> FilterParams:
    """Look up metadata for a key and build filter params (benchmark helper)."""
    # Row ids carry a copy suffix; "-0" always exists (the original copy).
    builder = table.query().where(f"id = '{_sql_escape(key)}-0'").limit(1)
    rows = (await builder.to_arrow()).to_pylist()
    if rows:
        meta = rows[0]
    else:
        meta = {"country": "", "shape": "", "aspect_ratio": 0.0}
    return build_filter_params(meta)
async def benchmark(n_queries: int = 100, top_k: int = 10) -> None:
    """Latency benchmark: unfiltered vs native-filtered, all async."""
    t_start = time.perf_counter()
    t0 = time.perf_counter()
    table = await get_table()
    count = await table.count_rows()
    if count == 0:
        print("Table empty. Run ingest first.")
        return
    print(f"Table: {count:,} vectors (opened in {(time.perf_counter()-t0)*1000:.0f}ms)")
    print(f"Running {n_queries} queries (top_k={top_k}, async)...\n")
    # Load random query vectors from H5
    t0 = time.perf_counter()
    with h5py.File(str(H5_PATH), "r") as f:
        vecs_ds_: h5py.Dataset = f["vectors"]  # type: ignore[assignment]
        keys_ds_: h5py.Dataset = f["keys"]  # type: ignore[assignment]
        total = vecs_ds_.shape[0]
        # sorted: h5py fancy indexing expects increasing selection indices.
        indices = sorted(random.sample(range(total), n_queries))
        raw = vecs_ds_[indices].astype(np.float32)
        # Normalize the queries the same way ingest normalized stored vectors.
        norms = np.linalg.norm(raw, axis=1, keepdims=True)
        query_vecs = raw / np.where(norms > 0, norms, 1.0)
        query_keys = [keys_ds_[i].decode("utf-8") for i in indices]
    print(f"Loaded {n_queries} query vectors in {(time.perf_counter()-t0)*1000:.0f}ms")
    vecs_as_lists = [query_vecs[i].tolist() for i in range(n_queries)]
    # Warm up (first queries pay cache/JIT costs; excluded from stats)
    n_warmup = min(10, n_queries)
    t0 = time.perf_counter()
    print(f"Warming up ({n_warmup} queries)...", end=" ", flush=True)
    for i in range(n_warmup):
        await search_single(table, vecs_as_lists[i], top_k=top_k)
    print(f"done ({(time.perf_counter()-t0)*1000:.0f}ms)")
    # Phase 1: Unfiltered (sequential)
    print("Phase 1/3: Unfiltered search...", end=" ", flush=True)
    lats_unfilt: list[float] = []
    for i in range(n_queries):
        t0 = time.perf_counter()
        await search_single(table, vecs_as_lists[i], top_k=top_k)
        lats_unfilt.append((time.perf_counter() - t0) * 1000)
        if (i + 1) % max(1, n_queries // 5) == 0:
            print(f"{i+1}/{n_queries}", end=" ", flush=True)
    print("done")
    # Phase 2: Native-filtered (sequential)
    print("Phase 2/3: Native-filtered search...", end=" ", flush=True)
    lats_filtered: list[float] = []
    params_list: list[FilterParams] = []
    for i in range(n_queries):
        # Metadata lookup happens outside the timed window — only the
        # filtered search itself is measured.
        params = await _lookup_filter_params(table, query_keys[i])
        params_list.append(params)
        t0 = time.perf_counter()
        await search_single(table, vecs_as_lists[i], params, top_k)
        lats_filtered.append((time.perf_counter() - t0) * 1000)
        if (i + 1) % max(1, n_queries // 5) == 0:
            print(f"{i+1}/{n_queries}", end=" ", flush=True)
    print("done")
    # Phase 3: Concurrent batch via asyncio.gather
    print("Phase 3/3: Concurrent gather (all queries at once)...", end=" ", flush=True)
    t0 = time.perf_counter()
    await search_batch(table, vecs_as_lists, [None] * n_queries, top_k)
    gather_unfilt = (time.perf_counter() - t0) * 1000
    t0 = time.perf_counter()
    await search_batch(table, vecs_as_lists, params_list, top_k)
    gather_filt = (time.perf_counter() - t0) * 1000
    print("done")
    def stats(lats: list[float]) -> str:
        # Summarise a latency list in ms; p95 via index into the sorted list.
        s = sorted(lats)
        return (f"mean={np.mean(s):>8.1f} "
                f"median={np.median(s):>8.1f} "
                f"p95={s[int(0.95 * len(s))]:>8.1f}")
    print(f"\n{'Unfiltered top-' + str(top_k):<30} {stats(lats_unfilt)}")
    print(f"{'Native .where() filter':<30} {stats(lats_filtered)}")
    print(f"\n{'Gather unfiltered':<30} "
          f"total={gather_unfilt:>8.1f}ms "
          f"per-query={gather_unfilt / n_queries:>8.1f}ms")
    print(f"{'Gather filtered':<30} "
          f"total={gather_filt:>8.1f}ms "
          f"per-query={gather_filt / n_queries:>8.1f}ms")
    print(f"\nTotal wall time: {(time.perf_counter()-t_start)*1000:.0f}ms")
| # =================================================================== | |
| # VISUAL — filtered search with composite image output | |
| # =================================================================== | |
# Directory of PNG crops named <uuid>.png (visual mode input).
CROPS_DIR = Path("mount/cropped")
# Output directory for composite search-result images.
RESULTS_DIR = Path("results")
# Result-thumbnail height in pixels (query image is rendered at 2x this).
THUMB_H = 200
def _load_crop(uid: str) -> "Image.Image | None":
    """Load a crop image by UUID, or None if not found."""
    from PIL import Image
    candidate = CROPS_DIR / f"{uid}.png"
    if not candidate.exists():
        return None
    return Image.open(candidate)
def _make_composite(
    query_uid: str,
    query_meta: dict,
    results: list[dict],
    out_path: Path,
) -> None:
    """Build a composite PNG: query on left, top-k results on right.

    Silently returns (after a console note) when the query image or all
    result images are missing from CROPS_DIR.
    """
    from PIL import Image, ImageDraw
    q_img = _load_crop(query_uid)
    if q_img is None:
        print(f"  Query image not found: {query_uid}")
        return
    r_imgs = []
    for r in results:
        # Strip the "-<copy>" suffix to recover the original crop UUID.
        rid = r["id"].rsplit("-", 1)[0]
        img = _load_crop(rid)
        if img:
            r_imgs.append((img, r))
    if not r_imgs:
        print("  No result images found")
        return
    def scale(img: "Image.Image", h: int) -> "Image.Image":
        # Resize to height h, preserving aspect ratio.
        w = int(img.width * h / img.height)
        return img.resize((w, h), Image.LANCZOS)
    # Query rendered at double the result-thumbnail height for emphasis.
    q_thumb = scale(q_img, THUMB_H * 2)
    r_thumbs = [(scale(img, THUMB_H), r) for img, r in r_imgs]
    gap = 10      # padding between cells, px
    label_h = 20  # vertical space reserved for text labels, px
    # Results laid out in a grid of up to 5 columns.
    cols = min(5, len(r_thumbs))
    rows_n = (len(r_thumbs) + cols - 1) // cols
    max_r_w = max(t[0].width for t in r_thumbs)
    grid_w = cols * (max_r_w + gap) - gap
    grid_h = rows_n * (THUMB_H + label_h + gap) - gap
    total_w = q_thumb.width + gap * 3 + grid_w + gap
    total_h = max(q_thumb.height + label_h + gap, grid_h + label_h + gap) + gap * 2
    canvas = Image.new("RGB", (total_w, total_h), (255, 255, 255))
    draw = ImageDraw.Draw(canvas)
    # Query image (left)
    qx, qy = gap, gap + label_h
    canvas.paste(q_thumb, (qx, qy))
    draw.text((qx, gap), f"QUERY: {query_uid[:8]}...", fill=(0, 0, 200))
    meta_label = (f"{query_meta.get('country', '')[:30]} | "
                  f"{query_meta.get('shape', '')} | "
                  f"ar={query_meta.get('aspect_ratio', 0):.2f}")
    draw.text((qx, qy + q_thumb.height + 2), meta_label, fill=(100, 100, 100))
    # Results grid (right)
    rx_start = q_thumb.width + gap * 3
    for idx, (thumb, r) in enumerate(r_thumbs):
        col = idx % cols
        row = idx // cols
        x = rx_start + col * (max_r_w + gap)
        y = gap + label_h + row * (THUMB_H + label_h + gap)
        canvas.paste(thumb, (x, y))
        # Rank and raw LanceDB distance above each thumbnail.
        draw.text((x, y - label_h + 4),
                  f"#{idx+1} d={r['_distance']:.3f}", fill=(200, 0, 0))
        r_meta = (f"{r.get('country', '')[:20]} | "
                  f"{r.get('shape', '')} | "
                  f"ar={r.get('aspect_ratio', 0):.2f}")
        draw.text((x, y + THUMB_H + 2), r_meta, fill=(100, 100, 100))
    canvas.save(out_path)
    print(f"  Saved: {out_path} ({canvas.width}x{canvas.height})")
async def visual(n_queries: int = 3, top_k: int = 10) -> None:
    """Run filtered searches and save composite result images.

    Writes one PNG per query into RESULTS_DIR.
    """
    RESULTS_DIR.mkdir(exist_ok=True)
    table = await get_table()
    count = await table.count_rows()
    print(f"Table: {count:,} vectors")
    # Pick queries with distinct countries (sampled from the first 500 rows)
    sample = (await table.query().limit(500).to_arrow()).to_pylist()
    random.shuffle(sample)
    seen: set[str] = set()
    queries: list[dict] = []
    for r in sample:
        c = r["country"]
        if c not in seen:
            seen.add(c)
            queries.append(r)
        if len(queries) >= n_queries:
            break
    for qi, q in enumerate(queries):
        # Strip the "-<copy>" suffix to recover the original crop UUID.
        uid = q["id"].rsplit("-", 1)[0]
        print(f"\nQuery {qi+1}: {uid} ({q['country']})")
        params = build_filter_params(q)
        where = build_where_clause(params)
        results = (
            await table.vector_search(q["vector"])
            .where(where).limit(top_k).to_arrow()
        ).to_pylist()
        print(f"  {len(results)} filtered results")
        out_path = RESULTS_DIR / f"search_q{qi+1}_{uid[:8]}.png"
        _make_composite(uid, q, results, out_path)
| # =================================================================== | |
| # CLI | |
| # =================================================================== | |
def main() -> None:
    """CLI entry point: parse args and dispatch to ingest / bench / visual."""
    parser = argparse.ArgumentParser(
        description="LanceDB: ingest vectors and run search benchmarks",
    )
    sub = parser.add_subparsers(dest="command", required=True)

    ingest_parser = sub.add_parser("ingest", help="Ingest vectors from H5 + CSV")
    ingest_parser.add_argument("--copies", type=int, default=7,
                               help="Number of copies (1 = original only)")
    ingest_parser.add_argument("--limit", type=int, default=None,
                               help="Max vectors per copy (for testing)")

    bench_parser = sub.add_parser("bench", help="Run search benchmark")
    bench_parser.add_argument("--n", type=int, default=100,
                              help="Number of queries")
    bench_parser.add_argument("--top-k", type=int, default=10,
                              help="Results per query")

    visual_parser = sub.add_parser("visual",
                                   help="Visual search with composite images")
    visual_parser.add_argument("--n", type=int, default=3,
                               help="Number of queries (distinct countries)")
    visual_parser.add_argument("--top-k", type=int, default=10,
                               help="Results per query")

    args = parser.parse_args()
    if args.command == "ingest":
        ingest(num_copies=args.copies, limit=args.limit)
    elif args.command == "bench":
        asyncio.run(benchmark(n_queries=args.n, top_k=args.top_k))
    elif args.command == "visual":
        asyncio.run(visual(n_queries=args.n, top_k=args.top_k))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment