Created
May 12, 2025 22:34
-
-
Save g-fukurowl/81ccd861ce3d51913b6a2e3bdab597c0 to your computer and use it in GitHub Desktop.
とても簡単なRAG。一問一答形式でクエリを渡すとLLMがベクターストアから関連情報を検索した上で答える。サブコマンドrunで実行。ベクターストアを更新したい時はサブコマンドupdate-vectorで実行。スクリプトと同階層にmodelsディレクトリを作り、ここにgemma-3-1B-it-QAT-Q4_0.ggufを配置しておく必要があるので注意。
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from llama_cpp import Llama | |
| from search import search_faiss | |
| from colorama import Fore, Style, init | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader | |
| from langchain.text_splitter import CharacterTextSplitter | |
| import argparse | |
| import sys | |
| # ダウンロードした GGUF ファイルへのパスを指定。ggufファイルのダウンロードは手動で行う。 | |
| # https://huggingface.co/google/gemma-3-1b-it-qat-q4_0-gguf/resolve/main/gemma-3-1b-it-q4_0.gguf?download=true | |
| MODEL_PATH = "models/gemma-3-1B-it-QAT-Q4_0.gguf" | |
| # Llama インスタンスの生成 | |
| llm = Llama( | |
| model_path=MODEL_PATH, | |
| n_ctx=12000, # コンテキスト長(トークン) | |
| n_threads=4, # 並列スレッド数 | |
| n_gpu_layers=-1, | |
| verbose=False | |
| ) | |
| # HuggingFace上にアップロードされているembeddingモデル名 | |
| EMBEDDING_MODEL_PATH = "intfloat/multilingual-e5-large-instruct" | |
| # ベクトル化したいドキュメントをロードする | |
| def load_documents(file_path: str): | |
| if file_path.lower().endswith(".pdf"): | |
| loader = PyPDFLoader(file_path) | |
| elif file_path.lower().endswith(".txt"): | |
| loader = TextLoader(file_path, encoding="utf-8") | |
| elif file_path.lower().endswith(".csv"): | |
| loader = CSVLoader(file_path, encoding="utf-8") | |
| else: | |
| raise ValueError(f"Unsupported file type: {file_path}") | |
| return loader.load() | |
| # ロードしたドキュメントをチャンク単位に分割 | |
| def split_documents(documents, chunk_size=1000, chunk_overlap=0.1): | |
| splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) | |
| return splitter.split_documents(documents) | |
| # ドキュメントをベクトル化してストアする | |
| def embed_and_store( | |
| raw_docs, | |
| model_name=EMBEDDING_MODEL_PATH, | |
| device="cpu", | |
| persist_path="faiss_index" | |
| ): | |
| # ステップ1: 生ドキュメント数を検証 | |
| if not raw_docs: | |
| raise ValueError("No documents to embed. Check loader output.") | |
| # ステップ2: チャンク生成 | |
| docs = split_documents(raw_docs, 100) | |
| print(raw_docs) | |
| # デバッグ出力 | |
| print(f"Generated {len(docs)} chunks") | |
| # ステップ3: 空チャンクの除外 | |
| docs = [doc for doc in docs if doc.page_content.strip()] | |
| if not docs: | |
| raise ValueError("All chunks are empty after filtering.") | |
| # ステップ4: 埋め込みモデルの準備 | |
| embeddings = HuggingFaceEmbeddings( | |
| model_name=model_name, | |
| model_kwargs={"device": device}, | |
| ) | |
| # ステップ5: FAISS に格納 | |
| vectorstore = FAISS.from_documents(docs, embeddings) | |
| vectorstore.save_local(persist_path) | |
| return vectorstore | |
| def update_vector(): | |
| import glob | |
| file_paths = glob.glob("data/*.*") | |
| raw_docs = [] | |
| for path in file_paths: | |
| raw_docs.extend(load_documents(path)) | |
| # 修正版関数の呼び出し | |
| faiss_store = embed_and_store( | |
| raw_docs, | |
| model_name=EMBEDDING_MODEL_PATH, | |
| device="cpu", | |
| persist_path="faiss_index" | |
| ) | |
| print(f"Indexed into FAISS at 'faiss_index'") | |
| # あらかじめ作っておいたベクターストアをロード | |
| def load_vectorstore(persist_path="faiss_index", | |
| model_name="intfloat/multilingual-e5-large-instruct", | |
| device="cpu"): | |
| """保存済みFAISSインデックスと埋め込みモデルをロード""" | |
| embeddings = HuggingFaceEmbeddings( | |
| model_name=model_name, | |
| model_kwargs={"device": device}, | |
| ) | |
| vectorstore = FAISS.load_local( | |
| persist_path, | |
| embeddings, | |
| allow_dangerous_deserialization=True # ← ここを追加 | |
| ) | |
| return vectorstore | |
| # ベクターストアを検索 | |
| def search_faiss(query: str, k: int = 5): | |
| """FAISSでクエリ検索""" | |
| persist_path = "faiss_index" | |
| vectorstore = load_vectorstore(persist_path) | |
| print(f"🔍 Searching for: {query}") | |
| results = vectorstore.similarity_search(query, k=k) | |
| return results | |
| # LLMへプロンプトを渡してリクエストを得る | |
| def chat(prompt: str, | |
| max_tokens: int = 2048, | |
| temperature: float = 0.8, | |
| top_p: float = 0.95,): | |
| """ | |
| prompt(文字列)を渡して LLM で応答を生成し、 | |
| 生成テキストを返す。 | |
| """ | |
| out = llm( | |
| prompt=prompt, | |
| max_tokens=max_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| min_p=0.05, | |
| top_k=40, | |
| repeat_penalty=1.1, | |
| stop=["\n\n\n\n"] # 必要に応じてストップシーケンスを変更 | |
| ) | |
| return out["choices"][0]["text"].strip() | |
| def run(): | |
| # colorを有効化 | |
| init() | |
| query = input("💬 Query: ") | |
| search_result = search_faiss(query, k=2) | |
| search_result_str = "" | |
| for i, doc in enumerate(search_result, 1): | |
| search_result_str = search_result_str + f"\nResult #{i}" | |
| search_result_str = search_result_str + doc.page_content | |
| search_result_str = search_result_str + f"[Metadata] {doc.metadata}" | |
| prompt = f"### 指示 \nあなたは優秀なアシスタントAIです。常に日本語で応答します。質問「{query}」に簡潔に答えてください。その際、以下の情報を参照してください。\n\n### 情報 \n{search_result_str}\n\n\n\n" | |
| response = Fore.GREEN + chat(prompt) + Style.RESET_ALL | |
| print("🤖Gemma:", response) | |
| def main(): | |
| parser = argparse.ArgumentParser(description="シンプルなRAGツール") | |
| subparsers = parser.add_subparsers(dest='command', required=True) | |
| # run コマンド | |
| subparsers.add_parser('run', help='1問1答式の対話を実行します。Queryに入力した内容を元にセマンティック検索を行い、これについてLLMが要約して回答します。modelsディレクトリに目的の.ggufファイルを配置してください') | |
| # update-vector コマンド | |
| subparsers.add_parser('update-vector', help='ベクトルの更新処理を実行します。dataディレクトリにcsv, pdf, txtなどを配置してください') | |
| # 引数がない場合はヘルプを表示して終了 | |
| if len(sys.argv) == 1: | |
| parser.print_help() | |
| sys.exit(1) | |
| args = parser.parse_args() | |
| print(args) | |
| if args.command == 'run': | |
| run() | |
| elif args.command == 'update-vector': | |
| update_vector() | |
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment