Last active
April 14, 2025 12:07
-
-
Save goabonga/043766b7dd2cb48862963d880ccf5f5c to your computer and use it in GitHub Desktop.
LLM Query System with ChromaDB and Sentence Transformers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| import sys | |
| import re | |
| from typing import Any, Dict, List, Mapping, Union | |
| import numpy as np | |
| import httpx | |
| from chromadb.api.client import Client | |
| from chromadb.api.models.Collection import Collection | |
| from sentence_transformers import SentenceTransformer | |
def filter_context(text: str, words_to_exclude: List[str]) -> str:
    """Remove every whole-word occurrence of the given words from *text*.

    Matching is case-insensitive and anchored on word boundaries, so
    excluding ``"Python"`` leaves ``"Pythonic"`` untouched. Words are removed
    one after another in the order given. Afterwards all runs of whitespace
    are collapsed to single spaces and the result is stripped.

    Args:
        text: The text to clean.
        words_to_exclude: Literal words (not regex patterns) to delete.

    Returns:
        The cleaned, whitespace-normalized text.
    """
    cleaned = text
    for excluded in words_to_exclude:
        # Escape the word so regex metacharacters in it are matched literally.
        matcher = re.compile(rf"\b{re.escape(excluded)}\b", re.IGNORECASE)
        cleaned = matcher.sub("", cleaned)
    # split()/join collapses the gaps left behind by the removed words.
    return " ".join(cleaned.split())
def create_client() -> Client:
    """Build and return a ChromaDB ``Client`` using its default constructor arguments."""
    chroma_client = Client()
    return chroma_client
def get_or_create_collection(client: Client, collection_name: str) -> Collection:
    """Return the collection called *collection_name* from *client*.

    Delegates to ``client.get_or_create_collection`` and returns whatever
    that call yields.

    Args:
        client: The ChromaDB client to ask.
        collection_name: Name of the collection to fetch or create.

    Returns:
        The collection object handed back by the client.
    """
    collection = client.get_or_create_collection(name=collection_name)
    return collection
def split_text(text: str, max_length: int) -> List[str]:
    """Split *text* into space-joined chunks of at most *max_length* characters.

    The text is tokenized on whitespace and words are packed greedily: a word
    is added to the current chunk as long as the joined chunk (including the
    separating spaces) does not exceed *max_length*. Words are never broken,
    so a single word longer than *max_length* becomes a chunk of its own that
    exceeds the limit.

    Args:
        text: The text to split.
        max_length: Maximum number of characters per chunk.

    Returns:
        The chunks in order. Empty or whitespace-only input yields ``[]``;
        no chunk is ever the empty string.
    """
    chunks: List[str] = []
    current_chunk: List[str] = []
    current_length = 0  # len(" ".join(current_chunk)), maintained incrementally

    for word in text.split():
        # Length of the chunk if this word were appended (plus one separator
        # space when the chunk already holds a word).
        candidate_length = current_length + len(word) + (1 if current_chunk else 0)
        # Only flush a non-empty chunk: an oversized first word must not
        # produce a spurious empty chunk ahead of it.
        if current_chunk and candidate_length > max_length:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = len(word)
        else:
            current_chunk.append(word)
            current_length = candidate_length

    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
def get_embeddings(model: SentenceTransformer, texts: List[str]) -> List[List[float]]:
    """Encode *texts* with *model* and return the vectors as nested lists.

    Args:
        model: Encoder exposing ``encode(texts)``; the result must offer
            ``tolist()``.
        texts: The strings to embed.

    Returns:
        The result of ``model.encode(texts).tolist()``.
    """
    return model.encode(texts).tolist()  # type: ignore
def add_documents(
    collection: Collection,
    documents: List[str],
    embeddings: np.ndarray,
    metadatas: List[Mapping[str, Union[str, int, float, bool]]],
    ids: List[str],
) -> None:
    """Store documents with their embeddings, metadata and ids in *collection*.

    All four sequences are forwarded unchanged, as keyword arguments, to
    ``collection.add``.

    Args:
        collection: Target collection.
        documents: The document texts.
        embeddings: Embedding vectors for the documents.
        metadatas: Metadata mappings for the documents.
        ids: Identifiers for the documents.
    """
    collection.add(
        documents=documents,
        embeddings=embeddings,
        metadatas=metadatas,
        ids=ids,
    )
def query_documents(
    collection: Collection, query_text: str, model: SentenceTransformer, n_results: int
) -> Any:
    """Embed *query_text* and run a nearest-neighbour query on *collection*.

    The query is encoded as a one-element batch, converted to a ``float32``
    NumPy array and passed to ``collection.query``. The query asks for
    metadatas, documents, distances, embeddings and data to be included.

    Args:
        collection: Collection to search.
        query_text: The query string.
        model: Encoder exposing ``encode(texts)``.
        n_results: Number of results to request.

    Returns:
        Whatever ``collection.query`` returns.
    """
    encoded_batch = model.encode([query_text])
    # Round-trip through a plain list so the result is always a float32 ndarray.
    query_vectors = np.asarray(encoded_batch.tolist(), dtype=np.float32)
    requested_fields = ["metadatas", "documents", "distances", "embeddings", "data"]
    return collection.query(
        query_embeddings=query_vectors,
        n_results=n_results,
        include=requested_fields,
    )
def call_llm(context: str, query: str) -> Dict[str, Any]:
    """Ask the local Ollama server to answer *query* using *context*.

    Builds a prompt that restricts the model to Python questions and to the
    supplied documents, posts it to ``http://localhost:11434/api/generate``
    with a JSON schema requiring an ``answer`` string, and parses the model's
    reply as JSON.

    Args:
        context: Document text the answer must be based on.
        query: The user's question.

    Returns:
        The decoded JSON object produced by the model (expected to contain
        an ``"answer"`` key, per the requested schema).

    Raises:
        httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        KeyError: If the response body has no ``"response"`` field.
        json.JSONDecodeError: If the model output is not valid JSON.
    """
    prompt: str = (
        "You are a Python specialist assistant. Provide clear and concise answers exclusively based on the provided documents. "
        "Do not answer if the question is not related to the Python language; instead reply with 'I only answer questions about Python.' "
        "Do not include phrases referring to the documents in your response.\n"
        f"Documents:\n{context}\n\nQuestion: {query}\n"
    )
    payload: Dict[str, Any] = {
        "prompt": prompt,
        "model": "qwen2.5:14b",
        "stream": False,
        "format": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        },
        # Ollama reads sampling parameters from "options"; the token cap is
        # named "num_predict". Top-level "max_tokens"/"temperature" fields
        # are not part of the /api/generate request and have no effect.
        "options": {
            "num_predict": 500,
            "temperature": 0.7,
        },
    }
    # timeout=None: the request is allowed to wait indefinitely for the reply.
    response = httpx.post("http://localhost:11434/api/generate", json=payload, timeout=None)
    # Fail loudly on HTTP errors instead of a confusing KeyError further down.
    response.raise_for_status()
    result: Dict[str, Any] = response.json()
    return json.loads(result["response"].strip())
def main(question: str) -> None:
    """Answer *question* from two built-in sample passages and print the result.

    Steps:
      1. Create a client, the ``python_docs`` collection and the
         ``all-MiniLM-L6-v2`` sentence-transformer model.
      2. Strip the word "Python" from each passage, split it into chunks of
         at most 50 characters, embed the chunks and add them to the
         collection with ``doc_index`` / ``chunk_index`` metadata.
      3. Retrieve the single closest chunk for *question*, then gather every
         chunk stored under the same ``doc_index`` as the context.
      4. Send context and question to the LLM and print its ``answer`` field,
         or ``No answer found.`` when that field is absent.

    Args:
        question: The question to answer.
    """
    client: Client = create_client()
    collection: Collection = get_or_create_collection(client, "python_docs")
    model: SentenceTransformer = SentenceTransformer("all-MiniLM-L6-v2")

    documents: List[str] = [
        (
            "Python is an interpreted, high-level and general-purpose programming language. "
            "Created by Guido van Rossum and first released in 1991, its design philosophy "
            "emphasizes code readability with its use of significant whitespace."
        ),
        (
            "Python supports multiple programming paradigms, including structured, "
            "object-oriented, and functional programming. It features a dynamic type "
            "system and automatic memory management and boasts a large standard library."
        ),
    ]
    words_to_exclude: List[str] = ["Python"]

    # Flatten all passages into parallel lists of chunk text, id and metadata.
    chunks: List[str] = []
    ids: List[str] = []
    metadatas: List[Mapping[str, Union[str, int, float, bool]]] = []
    for doc_pos, passage in enumerate(documents):
        cleaned_passage = filter_context(passage, words_to_exclude)
        for chunk_pos, piece in enumerate(split_text(cleaned_passage, 50)):
            chunks.append(piece)
            ids.append(f"doc{doc_pos}_chunk{chunk_pos}")
            metadatas.append({"doc_index": str(doc_pos), "chunk_index": str(chunk_pos)})

    embeddings: np.ndarray = np.array(get_embeddings(model, chunks), dtype=np.float32)
    add_documents(collection, chunks, embeddings, metadatas, ids)

    # Find the best-matching chunk, then widen the context to its whole passage.
    query_result: Any = query_documents(collection, question, model, n_results=1)
    documents_context: str = ""
    hit_metadatas = query_result.get("metadatas")
    if hit_metadatas and hit_metadatas[0]:
        doc_index: str = hit_metadatas[0][0]["doc_index"]
        same_doc: Any = collection.get(where={"doc_index": doc_index}, include=["documents"])
        documents_context = " ".join(same_doc.get("documents", []))

    llm_answer: Dict[str, Any] = call_llm(documents_context, question)
    print(llm_answer.get("answer", "No answer found."))
if __name__ == "__main__":
    # Join all command-line arguments into the question; fall back to a
    # built-in sample question when the script is run without arguments.
    question: str = (
        " ".join(sys.argv[1:]) if len(sys.argv) > 1 else "Python is oriented object ?"
    )
    main(question)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment