@hotchpotch
Created January 8, 2026 06:46
nano_beir_ja_eval_cli standalone
#!/usr/bin/env python3
"""Run NanoBEIR-ja evaluation for a SentenceTransformer model (CLI).
Usage example (ndcg@10 only by default):
uv run scripts/nano_beir_ja_eval_cli.py \
--model-path cl-nagoya/ruri-v3-30m \
--batch-size 512 --autocast-dtype bf16 \
--output output/nano_beir_ja_eval_ruri-v3-30m.json
Use --all-metrics to emit the full metric set.
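Evaluate a subset of tasks (names are case-insensitive and the "Nano" prefix is optional):
uv run scripts/nano_beir_ja_eval_cli.py --model-path cl-nagoya/ruri-v3-30m \
--tasks NFCorpus SciFact --all-metrics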
"""
from __future__ import annotations
import argparse
import json
import logging
import time
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.util import is_datasets_available
from torch import Tensor
from tqdm import tqdm
if TYPE_CHECKING:  # pragma: no cover - for type checkers only
    from sentence_transformers import SentenceTransformer
DATASET_ID = "tomaarsen/NanoBEIR-ja"
# Canonical task names (split names in the HF dataset)
TASKS = [
    "ArguAna",
    "ClimateFEVER",
    "DBPedia",
    "FEVER",
    "FiQA2018",
    "HotpotQA",
    "MSMARCO",
    "NFCorpus",
    "NQ",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
]
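# Lowercase lookup table so task names can be matched case-insensitively.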
_TASKS_BY_LOWER = {name.lower(): name for name in TASKS}
logger = logging.getLogger(__name__)
def _normalize_task(name: str) -> str:
    """Map a user-supplied name (case-insensitive, optional "Nano" prefix) to a canonical task name."""
    key = name.lower()
    if key.startswith("nano"):
        key = key[4:]
    if key in _TASKS_BY_LOWER:
        return _TASKS_BY_LOWER[key]
    compact = key.replace(" ", "")
    return _TASKS_BY_LOWER.get(compact, name)


def _split_name(task: str) -> str:
    return f"Nano{task}"


def _human_readable(task: str) -> str:
    return f"NanoBEIR-ja-{task}"
class NanoBeirJaEvaluator(SentenceEvaluator):
    """Evaluate a model on the NanoBEIR-ja collection (all tasks or a subset)."""

    information_retrieval_class = InformationRetrievalEvaluator
    def __init__(
        self,
        dataset_names: list[str] | None = None,
        mrr_at_k: list[int] | None = None,
        ndcg_at_k: list[int] | None = None,
        accuracy_at_k: list[int] | None = None,
        precision_recall_at_k: list[int] | None = None,
        map_at_k: list[int] | None = None,
        show_progress_bar: bool = False,
        batch_size: int = 32,
        write_csv: bool = True,
        truncate_dim: int | None = None,
        score_functions: dict[str, Callable[[Tensor, Tensor], Tensor]] | None = None,
        main_score_function: str | SimilarityFunction | None = None,
        aggregate_fn: Callable[[list[float]], float] = np.mean,
        aggregate_key: str = "mean",
        query_prompts: str | dict[str, str] | None = None,
        corpus_prompts: str | dict[str, str] | None = None,
        write_predictions: bool = False,
        ndcg_only: bool = False,
    ) -> None:
        super().__init__()
        if dataset_names is None:
            dataset_names = TASKS
        self.dataset_names = [_normalize_task(name) for name in dataset_names]
        self.aggregate_fn = aggregate_fn
        self.aggregate_key = aggregate_key
        self.write_csv = write_csv
        self.query_prompts = self._normalize_prompts(query_prompts)
        self.corpus_prompts = self._normalize_prompts(corpus_prompts)
        self.show_progress_bar = show_progress_bar
        self.score_functions = score_functions or {}
        self.score_function_names = sorted(self.score_functions.keys())
        self.main_score_function = main_score_function
        self.truncate_dim = truncate_dim
        self.name = f"NanoBEIR-ja_{aggregate_key}"
        if self.truncate_dim:
            self.name += f"_{self.truncate_dim}"
        self.ndcg_only = ndcg_only
        self.mrr_at_k = mrr_at_k or [10]
        self.ndcg_at_k = ndcg_at_k or [10]
        if ndcg_only:
            self.accuracy_at_k = [10]
            self.precision_recall_at_k = [10]
            self.map_at_k = [10]
        else:
            self.accuracy_at_k = accuracy_at_k or [1, 3, 5, 10]
            self.precision_recall_at_k = precision_recall_at_k or [1, 3, 5, 10]
            self.map_at_k = map_at_k or [100]
        self._validate_dataset_names()
        self._validate_prompts()
        ir_kwargs = {
            "mrr_at_k": self.mrr_at_k,
            "ndcg_at_k": self.ndcg_at_k,
            "accuracy_at_k": self.accuracy_at_k,
            "precision_recall_at_k": self.precision_recall_at_k,
            "map_at_k": self.map_at_k,
            "show_progress_bar": show_progress_bar,
            "batch_size": batch_size,
            "write_csv": write_csv,
            "truncate_dim": truncate_dim,
            "score_functions": score_functions,
            "main_score_function": main_score_function,
            "write_predictions": write_predictions,
        }
        self.evaluators = [
            self._load_dataset(name, **ir_kwargs)
            for name in tqdm(self.dataset_names, desc="Loading NanoBEIR-ja", leave=False)
        ]
        self.csv_file = f"NanoBEIR-ja_evaluation_{aggregate_key}_results.csv"
        self.csv_headers = ["epoch", "steps"]
        self._append_csv_headers(self.score_function_names)
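    # A single prompt string is broadcast to every selected task; dict keys are
    # normalized so "NanoNFCorpus" and "nfcorpus" both resolve to "NFCorpus".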
    def _normalize_prompts(self, prompts: str | dict[str, str] | None) -> dict[str, str] | None:
        if prompts is None:
            return None
        if isinstance(prompts, str):
            return {name: prompts for name in self.dataset_names}
        normalized: dict[str, str] = {}
        for key, value in prompts.items():
            normalized[_normalize_task(key)] = value
        return normalized
    def _append_csv_headers(self, score_function_names):
        for score_name in score_function_names:
            for k in self.accuracy_at_k:
                self.csv_headers.append(f"{score_name}-Accuracy@{k}")
            for k in self.precision_recall_at_k:
                self.csv_headers.append(f"{score_name}-Precision@{k}")
                self.csv_headers.append(f"{score_name}-Recall@{k}")
            for k in self.mrr_at_k:
                self.csv_headers.append(f"{score_name}-MRR@{k}")
            for k in self.ndcg_at_k:
                self.csv_headers.append(f"{score_name}-NDCG@{k}")
            for k in self.map_at_k:
                self.csv_headers.append(f"{score_name}-MAP@{k}")
    def _load_dataset(self, task: str, **ir_kwargs) -> InformationRetrievalEvaluator:
        if not is_datasets_available():
            raise ValueError("datasets is required; install via `pip install datasets`.")
        from datasets import load_dataset

        split_name = _split_name(task)
        t0 = time.perf_counter()
        corpus_ds = load_dataset(DATASET_ID, "corpus", split=split_name)
        queries_ds = load_dataset(DATASET_ID, "queries", split=split_name)
        qrels_ds = load_dataset(DATASET_ID, "qrels", split=split_name)
        logger.info("[NanoBEIR-ja] loaded datasets for %s in %.2fs", task, time.perf_counter() - t0)
        corpus_dict = {}
        t1 = time.perf_counter()
        for sample in corpus_ds:
            row = cast(dict[str, Any], sample)
            text = row.get("text")
            if text:
                corpus_dict[row["_id"]] = text
        queries_dict = {}
        for sample in queries_ds:
            row = cast(dict[str, Any], sample)
            text = row.get("text")
            if text:
                queries_dict[row["_id"]] = text
        qrels_dict: dict[str, set[str]] = {}
        for sample in qrels_ds:
            row = cast(dict[str, Any], sample)
            qid = row["query-id"]
            cids = row["corpus-id"]
            if isinstance(cids, list):
                qrels_dict.setdefault(qid, set()).update(cids)
            else:
                qrels_dict.setdefault(qid, set()).add(cids)
        logger.info(
            "[NanoBEIR-ja] materialized dicts for %s in %.2fs (corpus=%d, queries=%d, qrels=%d)",
            task,
            time.perf_counter() - t1,
            len(corpus_dict),
            len(queries_dict),
            len(qrels_dict),
        )
        if self.query_prompts is not None:
            ir_kwargs["query_prompt"] = self.query_prompts.get(task, None)
        if self.corpus_prompts is not None:
            ir_kwargs["corpus_prompt"] = self.corpus_prompts.get(task, None)
        name = _human_readable(task)
        return self.information_retrieval_class(
            queries=queries_dict,
            corpus=corpus_dict,
            relevant_docs=qrels_dict,
            name=name,
            **ir_kwargs,
        )
    def _validate_dataset_names(self):
        if len(self.dataset_names) == 0:
            raise ValueError("dataset_names cannot be empty. Use None to evaluate all tasks.")
        invalid = [task for task in self.dataset_names if _normalize_task(task) not in TASKS]
        if invalid:
            raise ValueError(f"Unknown tasks: {invalid}. Supported: {TASKS}")
    def _validate_prompts(self):
        error_msg = ""
        if self.query_prompts is not None:
            missing = [task for task in self.dataset_names if task not in self.query_prompts]
            if missing:
                error_msg += f"Missing query prompts for: {missing}\n"
        if self.corpus_prompts is not None:
            missing = [task for task in self.dataset_names if task not in self.corpus_prompts]
            if missing:
                error_msg += f"Missing corpus prompts for: {missing}\n"
        if error_msg:
            raise ValueError(error_msg.strip())
    def store_metrics_in_model_card_data(self, *args, **kwargs):  # pragma: no cover - mirrors NanoBEIR
        if len(self.dataset_names) > 1:
            super().store_metrics_in_model_card_data(*args, **kwargs)
    def get_config_dict(self) -> dict[str, Any]:
        cfg: dict[str, Any] = {
            "dataset_names": self.dataset_names,
            "dataset_id": DATASET_ID,
            "ndcg_only": self.ndcg_only,
        }
        if self.truncate_dim is not None:
            cfg["truncate_dim"] = self.truncate_dim
        if self.query_prompts is not None:
            cfg["query_prompts"] = self.query_prompts
        if self.corpus_prompts is not None:
            cfg["corpus_prompts"] = self.corpus_prompts
        return cfg
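    # Run every per-task evaluator, then aggregate each metric across tasks with
    # aggregate_fn (mean by default) under keys prefixed with self.name.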
    def __call__(
        self,
        model: SentenceTransformer,
        output_path: str | None = None,
        epoch: int = -1,
        steps: int = -1,
        *args,
        **kwargs,
    ) -> dict[str, float]:
        per_metric_agg: dict[str, list[float]] = {}
        per_dataset: dict[str, float] = {}
        if not self.score_functions:
            # __init__ stores an empty dict when no score functions are passed, so check for
            # emptiness (not None) before falling back to the model's own similarity function.
            self.score_functions = {model.similarity_fn_name: model.similarity}
            self.score_function_names = [model.similarity_fn_name]
            self._append_csv_headers(self.score_function_names)
        for evaluator in tqdm(self.evaluators, desc="Evaluating NanoBEIR-ja", disable=not self.show_progress_bar):
            logger.info("Evaluating %s", evaluator.name)
            results = evaluator(model, output_path, epoch, steps)
            for key, value in results.items():
                per_dataset[key] = value
                if "_" in key:
                    _, metric_name = key.split("_", 1)
                else:
                    metric_name = key
                per_metric_agg.setdefault(metric_name, []).append(value)
        agg_results = {
            f"{self.name}_{metric}": self.aggregate_fn(vals)
            for metric, vals in per_metric_agg.items()
        }
        if not self.primary_metric:
            main_score_fn = self.main_score_function
            main = None if main_score_fn is None else str(main_score_fn)
            ndcg_target = f"ndcg@{max(self.ndcg_at_k)}"
            candidates = [k for k in agg_results if k.endswith(ndcg_target)]
            if main:
                preferred = [k for k in candidates if main in k]
                if preferred:
                    self.primary_metric = preferred[0]
            if not self.primary_metric and candidates:
                self.primary_metric = candidates[0]
        if self.primary_metric and self.primary_metric in agg_results:
            logger.info("Primary %s: %.4f", self.primary_metric, agg_results[self.primary_metric])
        self.store_metrics_in_model_card_data(model, agg_results, epoch, steps)
        per_dataset.update(agg_results)
        if self.ndcg_only:
            per_dataset = {k: v for k, v in per_dataset.items() if "ndcg@10" in k}
        return per_dataset
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate a model on NanoBEIR-ja")
    parser.add_argument("--model-path", required=True, help="Path or HF id for SentenceTransformer model")
    parser.add_argument("--tasks", nargs="*", default=None, help="Task names (default: all)")
    parser.add_argument("--batch-size", type=int, default=512, help="Eval batch size")
    parser.add_argument("--output", default=None, help="Optional JSON output path for metrics")
    parser.add_argument("--show-progress", action="store_true", help="Show per-dataset tqdm during eval")
    parser.add_argument(
        "--no-autocast",
        action="store_true",
        help="Disable torch.autocast (default: enabled with bfloat16; CUDA is used when available)",
    )
    parser.add_argument(
        "--autocast-dtype",
        choices=["bf16", "fp16"],
        default="bf16",
        help="autocast dtype (bf16 or fp16)",
    )
    parser.add_argument("--query-prompt", default=None, help="Prefix applied to queries")
    parser.add_argument("--corpus-prompt", default=None, help="Prefix applied to corpus/passages")
    parser.add_argument(
        "--all-metrics",
        action="store_true",
        help="Return all metrics (default: ndcg@10 only)",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Pass trust_remote_code=True to SentenceTransformer (needed for some HF models)",
    )
    return parser.parse_args(argv)
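# CLI driver: builds the evaluator, optionally wraps evaluation in torch.autocast,
# and writes metrics to JSON when --output is given.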
def main(argv: Sequence[str] | None = None) -> None:
    args = parse_args(argv)
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    tasks = args.tasks or TASKS
    # Prompts: only set automatically for e5-style models if the user did NOT provide anything.
    auto_query_prompt = None
    auto_corpus_prompt = None
    lower_name = args.model_path.lower()
    if "e5" in lower_name and args.query_prompt is None and args.corpus_prompt is None:
        auto_query_prompt = "query: "
        auto_corpus_prompt = "passage: "
    query_prompt = args.query_prompt if args.query_prompt is not None else auto_query_prompt
    corpus_prompt = args.corpus_prompt if args.corpus_prompt is not None else auto_corpus_prompt
    logging.info("Using prompts -> query: %s | passage: %s", repr(query_prompt), repr(corpus_prompt))
    model = SentenceTransformer(args.model_path, prompts=None, trust_remote_code=args.trust_remote_code)
    model.eval()
    evaluator = NanoBeirJaEvaluator(
        dataset_names=tasks,
        batch_size=args.batch_size,
        show_progress_bar=args.show_progress,
        write_csv=False,
        query_prompts=query_prompt if query_prompt else None,
        corpus_prompts=corpus_prompt if corpus_prompt else None,
        ndcg_only=not args.all_metrics,
    )
    use_autocast = not args.no_autocast
    autocast_dtype = {"bf16": "bfloat16", "fp16": "float16"}[args.autocast_dtype]
    autocast_ctx = None
    if use_autocast:
        import torch

        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        autocast_ctx = torch.autocast(device_type=device_type, dtype=getattr(torch, autocast_dtype))
    if autocast_ctx is not None:
        with autocast_ctx:
            results = evaluator(model)
    else:
        results = evaluator(model)
    if args.output:
        path = Path(args.output)
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "cli": {
                "model_path": args.model_path,
                "tasks": tasks,
                "batch_size": args.batch_size,
                "autocast": not args.no_autocast,
                "autocast_dtype": args.autocast_dtype,
                "query_prompt": query_prompt,
                "corpus_prompt": corpus_prompt,
                "ndcg_only": not args.all_metrics,
            },
            "metrics": results,
        }
        path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        print(f"Saved metrics to {path}")
    else:
        print(json.dumps(results, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()