end-to-end pipeline for hard-negative mining, Sentence-Transformers training, and evaluation
# /// script
# dependencies = [
#     "sentence-transformers>=3.1.0",
#     "datasets",
#     "rank-bm25",
#     "faiss-cpu",
#     "tqdm",
#     "torch",
# ]
# ///
| """ | |
| Improved end-to-end pipeline for hard-negative mining, Sentence-Transformers training, | |
| and evaluation (including chain-recall) for multi-hop retrieval tasks. | |
| High-level features implemented: | |
| - Three stages implemented: (1) hard negative mining, (2) training, (3) evaluation. | |
| - Search API definition (async-friendly) that your baseline retrieval system must | |
| implement to provide prioritized baseline hard-negatives. | |
| - Baseline hard negatives are given highest priority when merging candidates. | |
| - BM25 margin-based mining (Lexical mining) is performed and merged with baseline | |
| candidates. The relative margin filtering follows the Hugging Face guidance | |
| (mitigating false-negatives by keeping high-scoring BM25 candidates close to golds). | |
| - Disk-caching of mining output keyed by corpus+queries+gold_map+mining-params | |
| (idempotent: identical inputs/params will reuse previous results). | |
| - Training supports both MultipleNegativesRankingLoss and GISTEmbedLoss (guide model). | |
| - Option to avoid duplicate examples from same original query inside a batch via a | |
| PyTorch Sampler (implemented as NoSameQuerySampler using torch.utils.data.Sampler). | |
| - Evaluation: standard recall@k and CHAIN recall@k (the fraction of queries where *all* | |
| gold docs for that query appear in top-k) and a small progressive-hop simulator | |
| helper (optional extension stub). | |
| - Improved, idiomatic usage of SentenceTransformers encoding APIs and careful batching. | |
| Design notes, rationale and best-practices (brief): | |
| - With tiny supervised sets (e.g. ~50 queries) hard negatives are THE most important | |
| signal: prioritize baseline semantic candidates re-scored by your best available | |
| reranker, then add lexical candidates from BM25 inside a margin threshold. | |
| - In-batch negatives are very effective. If you know you may have false negatives | |
| inside a batch (other positives from same original query), either use a | |
| guide model + GISTEmbedLoss (preferred) or the batch-sampler that minimizes | |
| same-query collisions. GISTEmbedLoss requires a guide model available during | |
| training; it masks false negatives dynamically. | |
| - Cache mining results: mining can be expensive; store a cache keyed by a hash | |
| of (corpus ids, query ids, gold pairs, mining params). This guarantees idempotence | |
| and reproducibility across identical runs. | |
| - Monitor recall@20 and CHAIN recall@20 as primary metrics for multi-hop retrieval. | |
| - Use cross-encoder rescoring if available to pick the hardest negatives out of | |
| the candidate set before training. This script provides an adapter point to do so. | |
| The script is inspired by the following Hugging Face blog post: | |
| https://huggingface.co/blog/dragonkue/mitigating-false-negatives-in-retriever-training | |
| NOTE: This file intentionally leaves the integration point with your async retrieval | |
| system abstract: implement `SearchAPI` (see below) to plug-in the baseline. A simple | |
| local FAISS-based example is included for reference. | |
| Usage example:: | |
| python retriever_finetune_pipeline_improved.py \ | |
| --corpus_path data/corpus.jsonl \ | |
| --queries_path data/queries.jsonl \ | |
| --gold_path data/gold_pairs.jsonl \ | |
| --output_dir out/ \ | |
| --student_model 'sentence-transformers/all-MiniLM-L6-v2' \ | |
| --guide_model 'sentence-transformers/paraphrase-mpnet-base-v2' \ | |
| --train --evaluate | |
| Requirements: | |
| pip install sentence-transformers datasets rank_bm25 faiss-cpu tqdm pyarrow | |
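
Expected JSONL input schemas (inferred from the loaders and key accesses below;
the values shown are illustrative):
    corpus.jsonl:     {"id": "doc-1", "text": "passage text ..."}
    queries.jsonl:    {"id": "q-1", "text": "question text ..."}
    gold_pairs.jsonl: {"query_id": "q-1", "doc_id": "doc-1"}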
| """ | |
from __future__ import annotations

import argparse
import asyncio
import hashlib
import json
import logging
import os
import pathlib
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, losses

# Optional alternative miner from sentence-transformers (>= 3.1); kept as an
# extension point for `use_sentence_transformers_mine` (not wired in yet).
from sentence_transformers.util import mine_hard_negatives as st_mine_hard_negatives
from tqdm.auto import tqdm

# Optional: import faiss if available for fast retrieval
try:
    import faiss

    _HAS_FAISS = True
except Exception:
    _HAS_FAISS = False

# Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# ----------------------------- Search API (async-friendly) -----------------------------
class SearchAPI:
    """Abstract Search API definition that your baseline retrieval system must
    implement to plug into the mining step.

    The API returns a ranked list of tuples of the form:
        (score: float, document: Dict[str, Any])
    where `document` MUST at least contain a stable document identifier under
    the key `"id"`, and MAY contain additional metadata (title, source, fields,
    precomputed embeddings, etc.). This design allows downstream components
    (hard-negative mining, analysis, logging) to access richer information than
    just raw text.

    Implementations must be async and return results sorted by decreasing score.
    """

    async def search(self, query_text: str, top_k: int = 100) -> List[Tuple[float, Dict[str, Any]]]:
        raise NotImplementedError
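
# A minimal sketch of a concrete implementation, assuming a hypothetical HTTP
# retrieval service that responds with {"results": [{"score": ..., "doc": {...}}]};
# the endpoint, payload shape, and the aiohttp session are illustrative, not part
# of this script:
#
#     class HttpSearchAPI(SearchAPI):
#         def __init__(self, session, endpoint_url):
#             self.session = session  # e.g. an aiohttp.ClientSession
#             self.endpoint_url = endpoint_url
#
#         async def search(self, query_text, top_k=100):
#             async with self.session.post(
#                 self.endpoint_url, json={"query": query_text, "top_k": top_k}
#             ) as resp:
#                 payload = await resp.json()
#             hits = [(hit["score"], hit["doc"]) for hit in payload["results"]]
#             return sorted(hits, key=lambda pair: pair[0], reverse=True)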
@dataclass
class SimpleFaissBaseline(SearchAPI):
    """Simple synchronous FAISS-backed baseline wrapped as an async adapter.
    This is intended as a local reference implementation.

    It returns results as:
        (score, {"id": doc_id, "text": doc_text})
    """

    model: SentenceTransformer  # kept so queries can be encoded at search time
    corpus_ids: List[str]
    corpus_texts: List[str]
    corpus_embeddings: np.ndarray
    index: Any

    @classmethod
    def from_sentence_transformer(
        cls,
        model: SentenceTransformer,
        corpus_texts: List[str],
        corpus_ids: List[str],
        batch_size: int = 128,
    ):
        if not _HAS_FAISS:
            raise RuntimeError("FAISS is required for SimpleFaissBaseline")
        emb = model.encode(
            corpus_texts,
            convert_to_numpy=True,
            show_progress_bar=True,
            batch_size=batch_size,
        )
        d = emb.shape[1]
        index = faiss.IndexFlatIP(d)
        faiss.normalize_L2(emb)
        index.add(emb)
        return cls(model=model, corpus_ids=corpus_ids, corpus_texts=corpus_texts, corpus_embeddings=emb, index=index)
    async def search(self, query_text: str, top_k: int = 100) -> List[Tuple[float, Dict[str, Any]]]:
        loop = asyncio.get_event_loop()
        # Encode the query in an executor so the event loop is not blocked;
        # the FAISS lookup itself is fast enough to run inline.
        q_emb = await loop.run_in_executor(
            None,
            lambda: self.model.encode([query_text], convert_to_numpy=True),
        )
        faiss.normalize_L2(q_emb)
        scores, indices = self.index.search(q_emb, top_k)
        return [
            (float(score), {"id": self.corpus_ids[i], "text": self.corpus_texts[i]})
            for score, i in zip(scores[0], indices[0])
            if i != -1  # FAISS pads with -1 when fewer than top_k docs exist
        ]
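
# Usage sketch for the reference baseline (the model name is just an example):
#
#     model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#     baseline = SimpleFaissBaseline.from_sentence_transformer(model, texts, ids)
#     hits = await baseline.search("who wrote the paper cited by ...", top_k=10)
#     # -> [(0.83, {"id": "doc-42", "text": "..."}), ...]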
# ----------------------------- Utilities & caching -----------------------------
def load_jsonl(path: str) -> List[Dict[str, Any]]:
    items = []
    with open(path, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            items.append(json.loads(line))
    return items


def save_jsonl(items: Iterable[Dict[str, Any]], path: str) -> None:
    with open(path, "w", encoding="utf-8") as fh:
        for it in items:
            fh.write(json.dumps(it, ensure_ascii=False) + "\n")


def params_hash(obj: Any) -> str:
    """Stable hash for dictionary-like input. Used to key caches for mining.

    Only JSON-serializable content should be passed; order-insensitive for dicts.
    """
    j = json.dumps(obj, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(j.encode("utf-8")).hexdigest()
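
# The JSON canonicalization (sort_keys=True) makes the hash order-insensitive,
# which is what makes the mining cache idempotent across runs, e.g.:
#     params_hash({"a": 1, "b": 2}) == params_hash({"b": 2, "a": 1})  # True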
# ----------------------------- BM25 helpers -----------------------------
class BM25Index:
    def __init__(self, corpus_texts: List[str]):
        tokenized = [doc.split() for doc in corpus_texts]
        self.bm25 = BM25Okapi(tokenized)
        self.tokenized = tokenized

    def scores(self, query: str) -> np.ndarray:
        qtok = query.split()
        return np.array(self.bm25.get_scores(qtok), dtype=float)

    def topk(self, query: str, top_k: int = 200) -> Tuple[np.ndarray, np.ndarray]:
        scores = self.scores(query)
        ranked_idx = np.argsort(scores)[::-1][:top_k]
        return ranked_idx, scores[ranked_idx]
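
# Usage sketch (note the whitespace tokenization above; swap in a proper
# tokenizer for languages that are not whitespace-delimited):
#     bm25 = BM25Index(["the cat sat", "dogs bark loudly"])
#     idx, scores = bm25.topk("cat", top_k=2)  # corpus indices and scores, best first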
# ----------------------------- Hard negative mining -----------------------------
def _prepare_corpus_maps(corpus: List[Dict[str, Any]]) -> Tuple[List[str], List[str], Dict[str, int]]:
    corpus_texts = [c["text"] for c in corpus]
    corpus_ids = [c["id"] for c in corpus]
    id2idx = {cid: i for i, cid in enumerate(corpus_ids)}
    return corpus_texts, corpus_ids, id2idx


async def mine_hard_negatives_mixed(
    queries: List[Dict[str, Any]],
    corpus: List[Dict[str, Any]],
    gold_map: Dict[str, List[str]],
    baseline_search_api: Optional[SearchAPI],
    baseline_model: Optional[SentenceTransformer],
    top_n_baseline: int = 100,
    top_n_bm25: int = 200,
    bm25_margin_ratio: float = 0.10,
    max_negatives_per_record: int = 50,
    max_negs_per_example: int = 8,
    cache_dir: str = "./cache",
    cache_name_prefix: str = "mined",
    device: str = "cpu",
    use_sentence_transformers_mine: bool = False,
) -> List[Dict[str, Any]]:
    """
    Mixed hard-negative mining:
      1. Obtain baseline (semantic) candidates via `baseline_search_api` (async) or
         by encoding with `baseline_model` + FAISS if the search API is not provided.
      2. Obtain lexical candidates via BM25 and filter them by a margin relative to
         the best gold BM25 score (or the top candidate if golds are missing).
      3. Merge baseline candidates first (priority), then BM25 candidates, dedupe,
         and truncate to `max_negatives_per_record`.

    The final output is an exploded list of records where each record corresponds to
    one (query, gold_doc) pair and contains the prioritized `hard_negatives` list
    (document ids) together with the matching `hard_negative_texts`.

    The step is cached to disk using a hash of (corpus ids, query ids, gold pairs,
    and mining params). If an identical cache exists it is loaded.
    """
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)
    corpus_texts, corpus_ids, id2idx = _prepare_corpus_maps(corpus)
    mining_params = {
        "top_n_baseline": top_n_baseline,
        "top_n_bm25": top_n_bm25,
        "bm25_margin_ratio": bm25_margin_ratio,
        "max_negatives_per_record": max_negatives_per_record,
        "max_negs_per_example": max_negs_per_example,
        "use_sentence_transformers_mine": use_sentence_transformers_mine,
    }
    cache_key = params_hash(
        {
            "corpus_ids": corpus_ids,
            "query_ids": [q["id"] for q in queries],
            "gold_pairs": gold_map,
            "mining_params": mining_params,
        }
    )
    cache_path = os.path.join(cache_dir, f"{cache_name_prefix}_{cache_key}.jsonl")
    if os.path.exists(cache_path):
        logger.info("Found mining cache: %s — loading mined records", cache_path)
        return load_jsonl(cache_path)

    # BM25 index
    bm25 = BM25Index(corpus_texts)

    # Optionally pre-encode the corpus for the baseline model if provided
    corpus_emb = None
    faiss_index = None
    if baseline_model is not None:
        logger.info("Encoding corpus with baseline_model for baseline retrieval (FAISS)...")
        corpus_emb = baseline_model.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True, device=device)
        if _HAS_FAISS:
            d = corpus_emb.shape[1]
            faiss_index = faiss.IndexFlatIP(d)
            faiss.normalize_L2(corpus_emb)
            faiss_index.add(corpus_emb)
    async def _get_baseline_candidates(qtext: str) -> List[str]:
        # Try the search API first
        if baseline_search_api is not None:
            try:
                res = await baseline_search_api.search(qtext, top_k=top_n_baseline)
                # res: List[(score, document)]
                return [doc["id"] for score, doc in res]
            except Exception as exc:
                logger.warning("Baseline search API failed for query: %s (%s)", qtext, exc)
        # Fall back to baseline_model + FAISS
        if baseline_model is not None and corpus_emb is not None and _HAS_FAISS and faiss_index is not None:
            q_emb = baseline_model.encode([qtext], convert_to_numpy=True, device=device)
            faiss.normalize_L2(q_emb)
            D, I = faiss_index.search(q_emb, top_n_baseline)
            return [corpus_ids[i] for i in I[0]]
        return []
    results: List[Dict[str, Any]] = []
    # Sequential loop; the baseline candidates are awaited per query inside this coroutine
    for q in tqdm(queries, desc="Mining queries"):
        qid = q["id"]
        qtext = q["text"]
        golds = set(gold_map.get(qid, []))

        # 1) baseline candidates (semantic)
        baseline_candidates = await _get_baseline_candidates(qtext)
        baseline_candidates = [cid for cid in baseline_candidates if cid not in golds]
        # 2) BM25 candidates with margin-based filtering
        bm25_idx, bm25_scores_arr = bm25.topk(qtext, top_k=top_n_bm25)
        # Compute the best gold BM25 score if golds are present
        best_gold_bm25 = 0.0
        gold_indices = [id2idx[gid] for gid in golds if gid in id2idx]
        if gold_indices:
            all_scores = bm25.scores(qtext)
            best_gold_bm25 = max(all_scores[i] for i in gold_indices)
        if best_gold_bm25 == 0.0 and len(bm25_scores_arr) > 0:
            threshold = float(bm25_scores_arr[0]) * (1.0 - bm25_margin_ratio)
        else:
            threshold = float(best_gold_bm25) * (1.0 - bm25_margin_ratio)
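        # Worked example: with bm25_margin_ratio=0.10 and a best gold BM25 score
        # of 12.0, threshold = 12.0 * 0.9 = 10.8, so only candidates scoring at
        # least 10.8 survive; they are lexically close enough to the gold to be hard.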
        bm25_candidates: List[str] = []
        for idx, score in zip(bm25_idx, bm25_scores_arr):
            cand_id = corpus_ids[int(idx)]
            if cand_id in golds:
                continue
            if float(score) >= threshold:
                bm25_candidates.append(cand_id)
            if len(bm25_candidates) >= max_negatives_per_record:
                break
        # 3) merge with priority to baseline candidates
        merged: List[str] = []
        seen = set()
        for cid in baseline_candidates:
            if cid not in seen:
                merged.append(cid)
                seen.add(cid)
            if len(merged) >= max_negatives_per_record:
                break
        for cid in bm25_candidates:
            if cid not in seen:
                merged.append(cid)
                seen.add(cid)
            if len(merged) >= max_negatives_per_record:
                break

        # 4) explode per gold doc
        for gold_id in golds:
            neg_ids = merged[:max_negs_per_example]
            rec = {
                "query_id": qid,
                "query_text": qtext,
                "gold_doc_id": gold_id,
                "gold_text": corpus[id2idx[gold_id]]["text"] if gold_id in id2idx else "",
                "hard_negatives": neg_ids,
                # Training consumes texts, so store them alongside the ids
                # (the ids are kept for inspection/debugging).
                "hard_negative_texts": [corpus_texts[id2idx[nid]] for nid in neg_ids],
            }
            results.append(rec)

    # Persist the cache
    save_jsonl(results, cache_path)
    logger.info("Saved mined records to cache: %s", cache_path)
    return results
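
# Shape of one mined record (values are illustrative):
#     {"query_id": "q-1", "query_text": "...", "gold_doc_id": "doc-7",
#      "gold_text": "...", "hard_negatives": ["doc-13", "doc-2"],
#      "hard_negative_texts": ["...", "..."]}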
# ----------------------------- Dataset -> Training utilities -----------------------------
from torch.utils.data import DataLoader, Sampler
from torch.utils.data import Dataset as TorchDataset


class ExplodedRecordsDataset(TorchDataset):
    """Torch dataset wrapper over exploded records. Each item is a list of texts
    [query_text, pos_text, neg1, neg2, ...] compatible with InputExample-like training.
    """

    def __init__(self, records: List[Dict[str, Any]]):
        self.records = records

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        r = self.records[idx]
        # Note: `hard_negatives` holds document ids; the texts live in
        # `hard_negative_texts` (see mine_hard_negatives_mixed).
        texts = [r["query_text"], r["gold_text"]] + r.get("hard_negative_texts", [])
        return texts
class NoSameQuerySampler(Sampler):
    """Sampler that tries to avoid placing multiple examples from the same original
    query into the same mini-batch. It yields shuffled indices grouped by query_id.
    This is a lightweight heuristic, not a strict guarantee for every batch size.
    """

    def __init__(self, records: List[Dict[str, Any]], generator=None):
        # Build mapping: query_id -> list of indices
        self.by_query = {}
        for i, r in enumerate(records):
            self.by_query.setdefault(r["query_id"], []).append(i)
        self.indices = [(qid, idxs.copy()) for qid, idxs in self.by_query.items()]
        # `generator` may be a numpy random Generator; a fresh default_rng()
        # is used when none is given.
        self.generator = generator

    def __iter__(self):
        rng = self.generator if self.generator is not None else np.random.default_rng()
        # Copy the buckets so repeated iteration (one pass per epoch) starts
        # fresh instead of draining the stored index lists.
        buckets = [idxs.copy() for (_, idxs) in self.indices]
        for b in buckets:
            rng.shuffle(b)
        # Greedy round-robin over queries: popping one index per query at a time
        # keeps examples from the same query far apart in the stream.
        out = []
        while any(buckets):
            for b in buckets:
                if b:
                    out.append(b.pop())
        return iter(out)

    def __len__(self):
        return sum(len(idxs) for (_, idxs) in self.indices)
def collate_texts(batch: List[List[str]]) -> List[List[str]]:
    # Simply return the batch as-is: a SentenceTransformer DataLoader expects a list of lists
    return batch
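
# Wiring sketch for these optional utilities (train_student below uses the
# Hugging Face trainer instead, so this path is only needed for custom loops):
#     ds = ExplodedRecordsDataset(records)
#     loader = DataLoader(ds, batch_size=32,
#                         sampler=NoSameQuerySampler(records),
#                         collate_fn=collate_texts)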
def train_student(
    student_model_name: str,
    records: List[Dict[str, Any]],
    output_dir: str,
    epochs: int = 3,
    batch_size: int = 64,
    lr: float = 2e-5,
    use_gist: bool = False,
    guide_model_name: Optional[str] = None,
    device: str = "cpu",
    avoid_same_query_in_batch: bool = False,
) -> SentenceTransformer:
    """
    Train the student model using SentenceTransformerTrainer and
    SentenceTransformerTrainingArguments for idiomatic training, logging
    and checkpointing.
    """
    from datasets import Dataset
    from sentence_transformers.trainer import SentenceTransformerTrainer
    from sentence_transformers.training_args import (
        BatchSamplers,
        SentenceTransformerTrainingArguments,
    )

    model = SentenceTransformer(student_model_name, device=device)

    # Convert records to Hugging Face Dataset format
    dataset_dict = {
        "anchor": [r["query_text"] for r in records],
        "positive": [r["gold_text"] for r in records],
    }
    # Add negative *texts* if present. Dataset columns must all have the same
    # length, so records with fewer negatives are padded with empty strings
    # (degenerate negatives; ideally mine an equal count per record).
    max_negs = max(len(r.get("hard_negative_texts", [])) for r in records) if records else 0
    for i in range(max_negs):
        dataset_dict[f"negative_{i}"] = [
            r.get("hard_negative_texts", [])[i] if i < len(r.get("hard_negative_texts", [])) else ""
            for r in records
        ]
    train_dataset = Dataset.from_dict(dataset_dict)

    if use_gist:
        if guide_model_name is None:
            raise ValueError("guide_model_name is required when use_gist is True")
        guide = SentenceTransformer(guide_model_name, device=device)
        train_loss = losses.GISTEmbedLoss(model=model, guide=guide, temperature=0.05)
    else:
        train_loss = losses.MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=lr,
        warmup_ratio=0.1,
        # fp16 is only valid on CUDA; `device != "cpu"` would also trip on e.g. MPS
        fp16=device.startswith("cuda"),
        batch_sampler=(BatchSamplers.NO_DUPLICATES if avoid_same_query_in_batch else BatchSamplers.BATCH_SAMPLER),
        logging_steps=50,
        save_strategy="epoch",
        save_total_limit=2,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=train_loss,
    )
    trainer.train()
    model.save(output_dir)
    return model
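
# Example call, assuming `records` came from mine_hard_negatives_mixed (model
# names are illustrative):
#     model = train_student(
#         "sentence-transformers/all-MiniLM-L6-v2", records, "out/",
#         use_gist=True,
#         guide_model_name="sentence-transformers/paraphrase-mpnet-base-v2",
#     )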
# ----------------------------- Evaluation -----------------------------
def evaluate_recall_and_chain(
    model: SentenceTransformer,
    queries: List[Dict[str, Any]],
    corpus: List[Dict[str, Any]],
    gold_map: Dict[str, List[str]],
    top_k_list: Sequence[int] = (1, 5, 10, 20),
    device: str = "cpu",
) -> Dict[str, Any]:
    """
    Compute standard recall@k and CHAIN recall@k, where CHAIN recall@k measures
    the fraction of queries for which *all* gold docs are present in the top-k
    retrieved documents (for multi-hop evaluation).
    """
    corpus_texts, corpus_ids, id2idx = _prepare_corpus_maps(corpus)
    logger.info("Encoding corpus for evaluation (model=%s)", model.__class__.__name__)
    corpus_emb = model.encode(corpus_texts, convert_to_numpy=True, show_progress_bar=True, device=device)
    if _HAS_FAISS:
        d = corpus_emb.shape[1]
        index = faiss.IndexFlatIP(d)
        faiss.normalize_L2(corpus_emb)
        index.add(corpus_emb)
    else:
        index = None

    recall_at_k = {k: 0 for k in top_k_list}
    chain_recall_at_k = {k: 0 for k in top_k_list}
    total = 0
    for q in tqdm(queries, desc="Evaluation queries"):
        total += 1
        qid = q["id"]
        qtext = q["text"]
        golds = set(gold_map.get(qid, []))
        q_emb = model.encode([qtext], convert_to_numpy=True, device=device)
        if index is not None:
            faiss.normalize_L2(q_emb)
            D, I = index.search(q_emb, max(top_k_list))
            retrieved = [corpus_ids[i] for i in I[0]]
        else:
            # Fallback: brute-force cosine similarity
            q_emb_n = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-12)
            corpus_n = corpus_emb / (np.linalg.norm(corpus_emb, axis=1, keepdims=True) + 1e-12)
            sims = (corpus_n @ q_emb_n.T).squeeze(-1)
            ranked_idx = np.argsort(sims)[::-1][: max(top_k_list)]
            retrieved = [corpus_ids[i] for i in ranked_idx]
        for k in top_k_list:
            topk = set(retrieved[:k])
            if golds & topk:
                recall_at_k[k] += 1
            # Chain recall: check whether all golds are included in the top-k
            if golds and golds.issubset(topk):
                chain_recall_at_k[k] += 1

    recall_at_k = {k: recall_at_k[k] / total for k in recall_at_k}
    chain_recall_at_k = {k: chain_recall_at_k[k] / total for k in chain_recall_at_k}
    return {
        "recall_at_k": recall_at_k,
        "chain_recall_at_k": chain_recall_at_k,
        "total_queries": total,
    }
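
# Return value shape (the numbers are purely illustrative):
#     {"recall_at_k": {1: 0.42, 5: 0.70, 10: 0.81, 20: 0.88},
#      "chain_recall_at_k": {1: 0.10, 5: 0.33, 10: 0.47, 20: 0.58},
#      "total_queries": 50}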
# ----------------------------- CLI / main -----------------------------
def build_gold_map(gold_pairs: List[Dict[str, Any]]) -> Dict[str, List[str]]:
    m: Dict[str, List[str]] = {}
    for gp in gold_pairs:
        m.setdefault(gp["query_id"], []).append(gp["doc_id"])
    return m
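
# e.g. [{"query_id": "q1", "doc_id": "d3"}, {"query_id": "q1", "doc_id": "d9"}]
#      -> {"q1": ["d3", "d9"]}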
async def main(argv: Optional[List[str]] = None) -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--corpus_path", required=True)
    parser.add_argument("--queries_path", required=True)
    parser.add_argument("--gold_path", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--student_model", default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument("--baseline_model", default=None)
    parser.add_argument("--guide_model", default=None)
    parser.add_argument("--top_n_baseline", type=int, default=100)
    parser.add_argument("--top_n_bm25", type=int, default=200)
    parser.add_argument("--bm25_margin_ratio", type=float, default=0.10)
    parser.add_argument("--max_negatives", type=int, default=50)
    parser.add_argument("--max_negs_per_example", type=int, default=8)
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--evaluate", action="store_true")
    parser.add_argument("--use_gist", action="store_true")
    parser.add_argument("--avoid_same_query_in_batch", action="store_true")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--cache_dir", type=str, default="./cache")
    args = parser.parse_args(argv)

    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    corpus = load_jsonl(args.corpus_path)
    queries = load_jsonl(args.queries_path)
    gold_pairs = load_jsonl(args.gold_path)
    gold_map = build_gold_map(gold_pairs)
    baseline_model = SentenceTransformer(args.baseline_model, device=args.device) if args.baseline_model else None

    # 1) Hard negative mining
    records = await mine_hard_negatives_mixed(
        queries=queries,
        corpus=corpus,
        gold_map=gold_map,
        baseline_search_api=None,  # <-- Replace with your SearchAPI implementation if you have one
        baseline_model=baseline_model,
        top_n_baseline=args.top_n_baseline,
        top_n_bm25=args.top_n_bm25,
        bm25_margin_ratio=args.bm25_margin_ratio,
        max_negatives_per_record=args.max_negatives,
        max_negs_per_example=args.max_negs_per_example,
        cache_dir=args.cache_dir,
        cache_name_prefix="mined",
        device=args.device,
    )
    # Save exploded records for inspection
    save_jsonl(records, os.path.join(args.output_dir, "exploded_records.jsonl"))

    # 2) Training
    if args.train:
        model = train_student(
            student_model_name=args.student_model,
            records=records,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            use_gist=args.use_gist,
            guide_model_name=args.guide_model,
            device=args.device,
            avoid_same_query_in_batch=args.avoid_same_query_in_batch,
        )
    else:
        model = SentenceTransformer(args.student_model, device=args.device)

    # 3) Evaluation
    if args.evaluate:
        metrics = evaluate_recall_and_chain(
            model, queries, corpus, gold_map, top_k_list=(1, 5, 10, 20), device=args.device
        )
        logger.info("Evaluation results: %s", json.dumps(metrics, indent=2))
        with open(os.path.join(args.output_dir, "evaluation.json"), "w", encoding="utf-8") as fh:
            json.dump(metrics, fh, indent=2)


if __name__ == "__main__":
    asyncio.run(main())
@copilot st_mine_hard_negatives
Copilot updates for async, dataset + Ruff format: https://gist.github.com/ashikns/960b22034c7afb9cce7a451b43b599e6
@copilot CodeCarbonCallback and TrackioCallback