nano_beir_ja_eval_cli standalone
#!/usr/bin/env python3
"""Run NanoBEIR-ja evaluation for a SentenceTransformer model (CLI).

Usage example (ndcg@10 only by default):

    uv run scripts/nano_beir_ja_eval_cli.py \
        --model-path cl-nagoya/ruri-v3-30m \
        --batch-size 512 --autocast-dtype bf16 \
        --output output/nano_beir_ja_eval_ruri-v3-30m.json

Use --all-metrics to emit the full metric set.
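
Programmatic use (a minimal sketch, not part of the CLI flow; assumes this file is
importable as a module so NanoBeirJaEvaluator can be constructed directly):

    from sentence_transformers import SentenceTransformer
    evaluator = NanoBeirJaEvaluator(dataset_names=["NFCorpus", "SciFact"], ndcg_only=True)
    results = evaluator(SentenceTransformer("cl-nagoya/ruri-v3-30m"))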
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import time | |
| from collections.abc import Callable, Sequence | |
| from typing import TYPE_CHECKING, Any, cast | |

import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.util import is_datasets_available
from torch import Tensor
from tqdm import tqdm

if TYPE_CHECKING:  # pragma: no cover - for type checkers only
    from sentence_transformers import SentenceTransformer

DATASET_ID = "tomaarsen/NanoBEIR-ja"

# Canonical task names (split names in the HF dataset)
TASKS = [
    "ArguAna",
    "ClimateFEVER",
    "DBPedia",
    "FEVER",
    "FiQA2018",
    "HotpotQA",
    "MSMARCO",
    "NFCorpus",
    "NQ",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
]
_TASKS_BY_LOWER = {name.lower(): name for name in TASKS}

logger = logging.getLogger(__name__)


def _normalize_task(name: str) -> str:
    key = name.lower()
    if key.startswith("nano"):
        key = key[4:]
    if key in _TASKS_BY_LOWER:
        return _TASKS_BY_LOWER[key]
    compact = key.replace(" ", "")
    return _TASKS_BY_LOWER.get(compact, name)


def _split_name(task: str) -> str:
    return f"Nano{task}"


def _human_readable(task: str) -> str:
    return f"NanoBEIR-ja-{task}"


class NanoBeirJaEvaluator(SentenceEvaluator):
    """Evaluate a model on the NanoBEIR-ja collection (all tasks or subset)."""

    information_retrieval_class = InformationRetrievalEvaluator

    def __init__(
        self,
        dataset_names: list[str] | None = None,
        mrr_at_k: list[int] | None = None,
        ndcg_at_k: list[int] | None = None,
        accuracy_at_k: list[int] | None = None,
        precision_recall_at_k: list[int] | None = None,
        map_at_k: list[int] | None = None,
        show_progress_bar: bool = False,
        batch_size: int = 32,
        write_csv: bool = True,
        truncate_dim: int | None = None,
        score_functions: dict[str, Callable[[Tensor, Tensor], Tensor]] | None = None,
        main_score_function: str | SimilarityFunction | None = None,
        aggregate_fn: Callable[[list[float]], float] = np.mean,
        aggregate_key: str = "mean",
        query_prompts: str | dict[str, str] | None = None,
        corpus_prompts: str | dict[str, str] | None = None,
        write_predictions: bool = False,
        ndcg_only: bool = False,
    ) -> None:
        super().__init__()
        if dataset_names is None:
            dataset_names = TASKS
        self.dataset_names = [_normalize_task(name) for name in dataset_names]
        self.aggregate_fn = aggregate_fn
        self.aggregate_key = aggregate_key
        self.write_csv = write_csv
        self.query_prompts = self._normalize_prompts(query_prompts)
        self.corpus_prompts = self._normalize_prompts(corpus_prompts)
        self.show_progress_bar = show_progress_bar
        self.score_functions = score_functions or {}
        self.score_function_names = sorted(self.score_functions.keys())
        self.main_score_function = main_score_function
        self.truncate_dim = truncate_dim
        self.name = f"NanoBEIR-ja_{aggregate_key}"
        if self.truncate_dim:
            self.name += f"_{self.truncate_dim}"
        self.ndcg_only = ndcg_only
        self.mrr_at_k = mrr_at_k or [10]
        self.ndcg_at_k = ndcg_at_k or [10]
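        # InformationRetrievalEvaluator computes every metric family regardless, so ndcg_only
        # still sets the other @k lists to a single cheap value; the extra metrics are dropped
        # from the returned dict in __call__.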
        if ndcg_only:
            self.accuracy_at_k = [10]
            self.precision_recall_at_k = [10]
            self.map_at_k = [10]
        else:
            self.accuracy_at_k = accuracy_at_k or [1, 3, 5, 10]
            self.precision_recall_at_k = precision_recall_at_k or [1, 3, 5, 10]
            self.map_at_k = map_at_k or [100]
        self._validate_dataset_names()
        self._validate_prompts()
        ir_kwargs = {
            "mrr_at_k": self.mrr_at_k,
            "ndcg_at_k": self.ndcg_at_k,
            "accuracy_at_k": self.accuracy_at_k,
            "precision_recall_at_k": self.precision_recall_at_k,
            "map_at_k": self.map_at_k,
            "show_progress_bar": show_progress_bar,
            "batch_size": batch_size,
            "write_csv": write_csv,
            "truncate_dim": truncate_dim,
            "score_functions": score_functions,
            "main_score_function": main_score_function,
            "write_predictions": write_predictions,
        }
        self.evaluators = [
            self._load_dataset(name, **ir_kwargs)
            for name in tqdm(self.dataset_names, desc="Loading NanoBEIR-ja", leave=False)
        ]
        self.csv_file = f"NanoBEIR-ja_evaluation_{aggregate_key}_results.csv"
        self.csv_headers = ["epoch", "steps"]
        self._append_csv_headers(self.score_function_names)

    def _normalize_prompts(self, prompts: str | dict[str, str] | None) -> dict[str, str] | None:
        if prompts is None:
            return None
        if isinstance(prompts, str):
            return {name: prompts for name in self.dataset_names}
        normalized: dict[str, str] = {}
        for key, value in prompts.items():
            normalized[_normalize_task(key)] = value
        return normalized

    def _append_csv_headers(self, score_function_names):
        for score_name in score_function_names:
            for k in self.accuracy_at_k:
                self.csv_headers.append(f"{score_name}-Accuracy@{k}")
            for k in self.precision_recall_at_k:
                self.csv_headers.append(f"{score_name}-Precision@{k}")
                self.csv_headers.append(f"{score_name}-Recall@{k}")
            for k in self.mrr_at_k:
                self.csv_headers.append(f"{score_name}-MRR@{k}")
            for k in self.ndcg_at_k:
                self.csv_headers.append(f"{score_name}-NDCG@{k}")
            for k in self.map_at_k:
                self.csv_headers.append(f"{score_name}-MAP@{k}")

    def _load_dataset(self, task: str, **ir_kwargs) -> InformationRetrievalEvaluator:
        if not is_datasets_available():
            raise ValueError("datasets is required; install via `pip install datasets`.")
        from datasets import load_dataset

        split_name = _split_name(task)
        t0 = time.perf_counter()
        corpus_ds = load_dataset(DATASET_ID, "corpus", split=split_name)
        queries_ds = load_dataset(DATASET_ID, "queries", split=split_name)
        qrels_ds = load_dataset(DATASET_ID, "qrels", split=split_name)
        logger.info("[NanoBEIR-ja] loaded datasets for %s in %.2fs", task, time.perf_counter() - t0)
        corpus_dict = {}
        t1 = time.perf_counter()
        for sample in corpus_ds:
            row = cast(dict[str, Any], sample)
            text = row.get("text")
            if text:
                corpus_dict[row["_id"]] = text
        queries_dict = {}
        for sample in queries_ds:
            row = cast(dict[str, Any], sample)
            text = row.get("text")
            if text:
                queries_dict[row["_id"]] = text
        qrels_dict: dict[str, set[str]] = {}
        for sample in qrels_ds:
            row = cast(dict[str, Any], sample)
            qid = row["query-id"]
            cids = row["corpus-id"]
            if isinstance(cids, list):
                qrels_dict.setdefault(qid, set()).update(cids)
            else:
                qrels_dict.setdefault(qid, set()).add(cids)
        logger.info(
            "[NanoBEIR-ja] materialized dicts for %s in %.2fs (corpus=%d, queries=%d, qrels=%d)",
            task,
            time.perf_counter() - t1,
            len(corpus_dict),
            len(queries_dict),
            len(qrels_dict),
        )
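        # At this point the dicts have the shapes InformationRetrievalEvaluator expects:
        # queries/corpus map id -> text, and relevant_docs maps each query id to a set of
        # relevant corpus ids.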
        if self.query_prompts is not None:
            ir_kwargs["query_prompt"] = self.query_prompts.get(task, None)
        if self.corpus_prompts is not None:
            ir_kwargs["corpus_prompt"] = self.corpus_prompts.get(task, None)
        name = _human_readable(task)
        return self.information_retrieval_class(
            queries=queries_dict,
            corpus=corpus_dict,
            relevant_docs=qrels_dict,
            name=name,
            **ir_kwargs,
        )

    def _validate_dataset_names(self):
        if len(self.dataset_names) == 0:
            raise ValueError("dataset_names cannot be empty. Use None to evaluate all tasks.")
        invalid = [task for task in self.dataset_names if _normalize_task(task) not in TASKS]
        if invalid:
            raise ValueError(f"Unknown tasks: {invalid}. Supported: {TASKS}")

    def _validate_prompts(self):
        error_msg = ""
        if self.query_prompts is not None:
            missing = [task for task in self.dataset_names if task not in self.query_prompts]
            if missing:
                error_msg += f"Missing query prompts for: {missing}\n"
        if self.corpus_prompts is not None:
            missing = [task for task in self.dataset_names if task not in self.corpus_prompts]
            if missing:
                error_msg += f"Missing corpus prompts for: {missing}\n"
        if error_msg:
            raise ValueError(error_msg.strip())

    def store_metrics_in_model_card_data(self, *args, **kwargs):  # pragma: no cover - mirrors NanoBEIR
        if len(self.dataset_names) > 1:
            super().store_metrics_in_model_card_data(*args, **kwargs)

    def get_config_dict(self) -> dict[str, Any]:
        cfg: dict[str, Any] = {
            "dataset_names": self.dataset_names,
            "dataset_id": DATASET_ID,
            "ndcg_only": self.ndcg_only,
        }
        if self.truncate_dim is not None:
            cfg["truncate_dim"] = self.truncate_dim
        if self.query_prompts is not None:
            cfg["query_prompts"] = self.query_prompts
        if self.corpus_prompts is not None:
            cfg["corpus_prompts"] = self.corpus_prompts
        return cfg

    def __call__(
        self,
        model: SentenceTransformer,
        output_path: str | None = None,
        epoch: int = -1,
        steps: int = -1,
        *args,
        **kwargs,
    ) -> dict[str, float]:
        per_metric_agg: dict[str, list[float]] = {}
        per_dataset: dict[str, float] = {}
        # __init__ stores an empty dict when no score functions are given; fall back to the
        # model's own similarity function in that case.
        if not self.score_functions:
            self.score_functions = {model.similarity_fn_name: model.similarity}
            self.score_function_names = [model.similarity_fn_name]
            self._append_csv_headers(self.score_function_names)
        for evaluator in tqdm(self.evaluators, desc="Evaluating NanoBEIR-ja", disable=not self.show_progress_bar):
            logger.info("Evaluating %s", evaluator.name)
            results = evaluator(model, output_path, epoch, steps)
            for key, value in results.items():
                per_dataset[key] = value
                if "_" in key:
                    _, metric_name = key.split("_", 1)
                else:
                    metric_name = key
                per_metric_agg.setdefault(metric_name, []).append(value)
        agg_results = {
            f"{self.name}_{metric}": self.aggregate_fn(vals)
            for metric, vals in per_metric_agg.items()
        }
        if not self.primary_metric:
            main_score_fn = self.main_score_function
            main = None if main_score_fn is None else str(main_score_fn)
            ndcg_target = f"ndcg@{max(self.ndcg_at_k)}"
            candidates = [k for k in agg_results if k.endswith(ndcg_target)]
            if main:
                preferred = [k for k in candidates if main in k]
                if preferred:
                    self.primary_metric = preferred[0]
            if not self.primary_metric and candidates:
                self.primary_metric = candidates[0]
        if self.primary_metric and self.primary_metric in agg_results:
            logger.info("Primary %s: %.4f", self.primary_metric, agg_results[self.primary_metric])
        self.store_metrics_in_model_card_data(model, agg_results, epoch, steps)
        per_dataset.update(agg_results)
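        # Resulting keys look like "NanoBEIR-ja-<Task>_<score_fn>_<metric>" per dataset plus
        # aggregated "NanoBEIR-ja_mean_<score_fn>_<metric>" entries, e.g.
        # "NanoBEIR-ja_mean_cosine_ndcg@10" for a model whose similarity_fn_name is "cosine".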
        if self.ndcg_only:
            per_dataset = {k: v for k, v in per_dataset.items() if "ndcg@10" in k}
        return per_dataset


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate a model on NanoBEIR-ja")
    parser.add_argument("--model-path", required=True, help="Path or HF id for SentenceTransformer model")
    parser.add_argument("--tasks", nargs="*", default=None, help="Task names (default: all)")
    parser.add_argument("--batch-size", type=int, default=512, help="Eval batch size")
    parser.add_argument("--output", default=None, help="Optional JSON output path for metrics")
    parser.add_argument("--show-progress", action="store_true", help="Show per-dataset tqdm during eval")
    parser.add_argument(
        "--no-autocast",
        action="store_true",
        help="Disable torch.autocast (default: enabled on CUDA with bfloat16 if available)",
    )
    parser.add_argument(
        "--autocast-dtype",
        choices=["bf16", "fp16"],
        default="bf16",
        help="autocast dtype (bf16 or fp16)",
    )
    parser.add_argument("--query-prompt", default=None, help="Prefix applied to queries")
    parser.add_argument("--corpus-prompt", default=None, help="Prefix applied to corpus/passages")
    parser.add_argument(
        "--all-metrics",
        action="store_true",
        help="Return all metrics (default: ndcg@10 only)",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Pass trust_remote_code=True to SentenceTransformer (needed for some HF models)",
    )
    return parser.parse_args(argv)


def main(argv: Sequence[str] | None = None) -> None:
    args = parse_args(argv)
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    tasks = args.tasks or TASKS
    # prompts: only set automatically for e5 models if the user did NOT provide anything
    auto_query_prompt = None
    auto_corpus_prompt = None
    lower_name = args.model_path.lower()
    if "e5" in lower_name and args.query_prompt is None and args.corpus_prompt is None:
        auto_query_prompt = "query: "
        auto_corpus_prompt = "passage: "
    query_prompt = args.query_prompt if args.query_prompt is not None else auto_query_prompt
    corpus_prompt = args.corpus_prompt if args.corpus_prompt is not None else auto_corpus_prompt
    logging.info("Using prompts -> query: %s | passage: %s", repr(query_prompt), repr(corpus_prompt))
    model = SentenceTransformer(args.model_path, prompts=None, trust_remote_code=args.trust_remote_code)
    model.eval()
    evaluator = NanoBeirJaEvaluator(
        dataset_names=tasks,
        batch_size=args.batch_size,
        show_progress_bar=args.show_progress,
        write_csv=False,
        query_prompts=query_prompt if query_prompt else None,
        corpus_prompts=corpus_prompt if corpus_prompt else None,
        ndcg_only=not args.all_metrics,
    )
    use_autocast = not args.no_autocast
    autocast_dtype = {"bf16": "bfloat16", "fp16": "float16"}[args.autocast_dtype]
    autocast_ctx = None
    if use_autocast:
        import torch

        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        autocast_ctx = torch.autocast(device_type=device_type, dtype=getattr(torch, autocast_dtype))
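        # If CUDA is unavailable this falls back to CPU autocast; bf16 autocast is generally
        # supported on CPU, while fp16 autocast on CPU may not be, depending on the torch build.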
    if autocast_ctx:
        with autocast_ctx:
            results = evaluator(model)
    else:
        results = evaluator(model)
    if args.output:
        path = Path(args.output)
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "cli": {
                "model_path": args.model_path,
                "tasks": tasks,
                "batch_size": args.batch_size,
                "autocast": not args.no_autocast,
                "autocast_dtype": args.autocast_dtype,
                "query_prompt": query_prompt,
                "corpus_prompt": corpus_prompt,
                "ndcg_only": not args.all_metrics,
            },
            "metrics": results,
        }
        path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        print(f"Saved metrics to {path}")
    else:
        print(json.dumps(results, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
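
# Example invocation (model id and task subset are illustrative):
#   python nano_beir_ja_eval_cli.py --model-path intfloat/multilingual-e5-small \
#       --tasks NFCorpus SciFact --batch-size 256
# With an e5 model and no explicit prompts, "query: " / "passage: " prefixes are applied
# automatically by main().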