GraphRAG_LangChain.py
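
"""GraphRAG over PDFs with LangChain, OpenAI, and Neo4j.

Pipeline: parse a PDF to text (markitdown) -> split into chunks -> embed each
chunk with OpenAI -> store Document/Chunk nodes plus a vector index in Neo4j ->
extract entities/relations per chunk with an LLM -> answer questions by combining
vector search over chunks with related facts from the knowledge graph.

Typical CLI usage (subcommands are defined in main() below; the PDF name is just
an example):
    python GraphRAG_LangChain.py init
    python GraphRAG_LangChain.py upload report.pdf
    python GraphRAG_LangChain.py list
    python GraphRAG_LangChain.py query
    python GraphRAG_LangChain.py delete report.pdf
    python GraphRAG_LangChain.py update report.pdf
"""
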
import hashlib
import json
import logging
import os
import re
import sys
import tempfile
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger(__name__)

# ─────────────────────────── Configuration ───────────────────────────
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "")  # required
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "sk-proj-Axxxxxxxxxxx")  # required
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "800"))
CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "150"))
TOP_K = int(os.getenv("TOP_K", "5"))
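
# Example .env for the settings above (a sketch -- every value is a placeholder,
# mirroring the defaults where they exist; only NEO4J_PASSWORD and OPENAI_API_KEY
# are strictly required):
#   NEO4J_URI=bolt://localhost:7687
#   NEO4J_USER=neo4j
#   NEO4J_PASSWORD=your-password
#   OPENAI_API_KEY=sk-...
#   EMBED_MODEL=text-embedding-3-small
#   LLM_MODEL=gpt-4o
#   CHUNK_SIZE=800
#   CHUNK_OVERLAP=150
#   TOP_K=5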

# ─────────────────────────── Neo4j ───────────────────────────
from neo4j import GraphDatabase

def get_driver():
    return GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

def init_schema(driver):
    """Create constraints and the vector index."""
    stmts = [
        "CREATE CONSTRAINT doc_id IF NOT EXISTS FOR (d:Document) REQUIRE d.id IS UNIQUE",
        "CREATE CONSTRAINT chunk_id IF NOT EXISTS FOR (c:Chunk) REQUIRE c.id IS UNIQUE",
        "CREATE CONSTRAINT entity_uniq IF NOT EXISTS FOR (e:Entity) REQUIRE (e.name, e.type) IS UNIQUE",
        # Neo4j / AuraDB 5.x vector index (1536 dimensions matches text-embedding-3-small)
        """CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
           FOR (c:Chunk) ON c.embedding
           OPTIONS {indexConfig: {`vector.dimensions`:1536, `vector.similarity_function`:'cosine'}}""",
    ]
    with driver.session() as s:
        for q in stmts:
            try:
                s.run(q)
            except Exception as e:
                log.debug("Schema: %s", e)
    # Wait for the vector index to come online
    import time
    for _ in range(20):
        with driver.session() as s:
            result = s.run(
                "SHOW VECTOR INDEXES WHERE name = 'chunk_embedding_index'"
            ).data()
            if result and result[0].get("state") == "ONLINE":
                log.info("✅ Vector index chunk_embedding_index is ready")
                return
        log.info("⏳ Waiting for the vector index to come online...")
        time.sleep(2)
    log.warning("⚠️ Vector index may not be ready yet; continuing anyway")

def run(driver, query, **params):
    with driver.session() as s:
        return list(s.run(query, **params))

# ─────────────────────────── PDF parsing ───────────────────────────
def parse_pdf(path: str) -> str:
    from markitdown import MarkItDown
    result = MarkItDown().convert(path)
    text = result.text_content
    if not text or not text.strip():
        raise ValueError("PDF parsing produced no text")
    log.info("📄 Parsed %d characters", len(text))
    return text

# ─────────────────────────── Chunking ───────────────────────────
def split_text(text: str) -> list[str]:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=["\n\n", "\n", "。", ".", " ", ""],
    )
    chunks = [c.strip() for c in splitter.split_text(text) if len(c.strip()) >= 30]
    log.info("🧩 Split into %d chunks", len(chunks))
    return chunks

# ─────────────────────────── Embeddings ───────────────────────────
def embed_texts(texts: list[str]) -> list[list[float]]:
    from langchain_openai import OpenAIEmbeddings
    embedder = OpenAIEmbeddings(model=EMBED_MODEL, openai_api_key=OPENAI_API_KEY)
    return embedder.embed_documents(texts)

def embed_query(text: str) -> list[float]:
    from langchain_openai import OpenAIEmbeddings
    embedder = OpenAIEmbeddings(model=EMBED_MODEL, openai_api_key=OPENAI_API_KEY)
    return embedder.embed_query(text)

# ─────────────────────────── Entity extraction ───────────────────────────
# JSON braces are doubled ({{ }}) so that str.format() only substitutes {text}.
EXTRACT_PROMPT = """Extract the entities and relations from the text below. Return only JSON, nothing else.
Format:
{{"entities":[{{"name":"...","type":"PERSON/ORG/CONCEPT/LOCATION/OTHER","desc":"..."}}],
"relations":[{{"src":"...","src_type":"...","rel":"...","tgt":"...","tgt_type":"..."}}]}}
Text:
{text}"""

def extract_kg(text: str) -> dict:
    from langchain_openai import ChatOpenAI
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, openai_api_key=OPENAI_API_KEY)
    try:
        raw = llm.invoke(EXTRACT_PROMPT.format(text=text[:1500])).content
        raw = re.sub(r"```json|```", "", raw).strip()
        return json.loads(raw)
    except Exception as e:
        log.debug("Entity extraction failed: %s", e)
        return {"entities": [], "relations": []}

# ─────────────────────────── Ingestion pipeline ───────────────────────────
def ingest(file_path: str, driver):
    path = Path(file_path)
    doc_id = hashlib.md5(path.read_bytes()).hexdigest()
    filename = path.name
    # 1. Parse the PDF
    text = parse_pdf(str(path))
    # 2. Write the Document node
    run(driver,
        "MERGE (d:Document {id:$id}) SET d.filename=$fn, d.createdAt=datetime()",
        id=doc_id, fn=filename)
    # 3. Chunk + embed
    chunks = split_text(text)
    log.info("🔢 Generating embeddings...")
    embeddings = embed_texts(chunks)
    # 4. Write Chunk nodes + extract entities
    for i, (chunk, emb) in enumerate(zip(chunks, embeddings)):
        cid = f"{doc_id}_{i}"
        run(driver,
            """MERGE (c:Chunk {id:$cid})
               SET c.text=$text, c.embedding=$emb, c.idx=$i
               WITH c
               MATCH (d:Document {id:$did})
               MERGE (d)-[:HAS_CHUNK]->(c)""",
            cid=cid, text=chunk, emb=emb, i=i, did=doc_id)
        # Only extract entities from the first 40 chunks (keeps LLM cost down)
        if i < 40:
            kg = extract_kg(chunk)
            for e in kg.get("entities", []):
                if e.get("name"):
                    run(driver,
                        """MERGE (e:Entity {name:$name, type:$type})
                           SET e.desc=$desc
                           WITH e MATCH (c:Chunk {id:$cid})
                           MERGE (c)-[:MENTIONS]->(e)""",
                        name=e["name"], type=e.get("type", "OTHER"),
                        desc=e.get("desc", ""), cid=cid)
            for r in kg.get("relations", []):
                if r.get("src") and r.get("tgt"):
                    run(driver,
                        """MERGE (s:Entity {name:$src, type:$st})
                           MERGE (t:Entity {name:$tgt, type:$tt})
                           MERGE (s)-[:RELATES_TO {rel:$rel}]->(t)""",
                        src=r["src"], st=r.get("src_type", "OTHER"),
                        tgt=r["tgt"], tt=r.get("tgt_type", "OTHER"),
                        rel=r.get("rel", "related"))
        if (i + 1) % 10 == 0:
            log.info("  %d/%d chunks processed", i + 1, len(chunks))
    log.info("✅ Ingestion complete: %s doc_id=%s", filename, doc_id)
    return doc_id
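
# Graph shape produced by ingest() (per the Cypher above):
#   (:Document {id, filename, createdAt})
#     -[:HAS_CHUNK]-> (:Chunk {id, text, embedding, idx})
#     -[:MENTIONS]->  (:Entity {name, type, desc})
#   (:Entity)-[:RELATES_TO {rel}]->(:Entity)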

# ─────────────────────────── Retrieval ───────────────────────────
def retrieve(question: str, driver) -> str:
    qvec = embed_query(question)
    # Vector search over chunks
    chunks = run(driver,
        """CALL db.index.vector.queryNodes('chunk_embedding_index', $k, $vec)
           YIELD node, score
           MATCH (d:Document)-[:HAS_CHUNK]->(node)
           RETURN node.text AS text, d.filename AS filename, score
           ORDER BY score DESC""",
        k=TOP_K, vec=qvec)
    if not chunks:
        return ""
    chunk_context = "\n\n".join(
        f"[{r['filename']} | relevance: {r['score']:.3f}]\n{r['text']}"
        for r in chunks
    )
    # Graph lookup: entities mentioned by the retrieved chunks and their relations
    chunk_texts = [r["text"][:100] for r in chunks]
    graph_facts = run(driver,
        """UNWIND $texts AS t
           MATCH (c:Chunk) WHERE c.text STARTS WITH t
           MATCH (c)-[:MENTIONS]->(e:Entity)
           OPTIONAL MATCH (e)-[r:RELATES_TO]->(e2:Entity)
           RETURN DISTINCT e.name AS src, e.type AS src_type,
                  r.rel AS rel, e2.name AS tgt, e2.type AS tgt_type
           LIMIT 30""",
        texts=chunk_texts)
    graph_context = ""
    if graph_facts:
        lines = []
        seen = set()
        for r in graph_facts:
            if r["rel"] and r["tgt"]:
                key = (r["src"], r["rel"], r["tgt"])
                if key not in seen:
                    seen.add(key)
                    lines.append(f"  [{r['src_type']}]{r['src']} --{r['rel']}--> [{r['tgt_type']}]{r['tgt']}")
            elif r["src"] and (r["src"],) not in seen:
                seen.add((r["src"],))
                lines.append(f"  [{r['src_type']}]{r['src']}")
        if lines:
            graph_context = "\nRelated facts from the knowledge graph:\n" + "\n".join(lines)
    return chunk_context + graph_context

# ─────────────────────────── LLM answering ───────────────────────────
SYSTEM_PROMPT = """You are a professional document Q&A assistant. Answer strictly from the provided context and do not make anything up.
If the context is insufficient, say so honestly. Keep answers clear, accurate, and well organized."""

def ask(question: str, context: str, history: list) -> str:
    from langchain_openai import ChatOpenAI
    from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
    llm = ChatOpenAI(model=LLM_MODEL, temperature=0.1, openai_api_key=OPENAI_API_KEY)
    messages = [SystemMessage(content=SYSTEM_PROMPT)]
    for h, a in history[-4:]:  # last 4 turns of history
        messages += [HumanMessage(content=h), AIMessage(content=a)]
    user_content = f"Context:\n{context}\n\nQuestion: {question}" if context else question
    messages.append(HumanMessage(content=user_content))
    return llm.invoke(messages).content

# ─────────────────────────── CLI ───────────────────────────
def cmd_upload(args):
    path = Path(args.file)
    if not path.exists():
        print(f"❌ File not found: {path}"); sys.exit(1)
    driver = get_driver()
    init_schema(driver)
    try:
        ingest(str(path), driver)
    finally:
        driver.close()

def cmd_list(args):
    driver = get_driver()
    try:
        docs = run(driver,
            """MATCH (d:Document)
               OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c)
               RETURN d.id AS id, d.filename AS name,
                      d.createdAt AS created, count(c) AS chunks
               ORDER BY d.createdAt DESC""")
        if not docs:
            print("📭 No documents yet"); return
        print(f"\n{'ID':<34} {'Filename':<28} {'Chunks':>6}")
        print("-" * 70)
        for d in docs:
            print(f"{d['id']:<34} {d['name'][:26]:<28} {d['chunks']:>6}")
    finally:
        driver.close()

def cmd_query(args):
    driver = get_driver()
    print("\n🤖 GraphRAG PDF Q&A (type 'exit' to quit, 'clear' to reset history)\n")
    history = []
    try:
        while True:
            try:
                q = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\n👋"); break
            if not q: continue
            if q.lower() in ("exit", "quit"): print("👋"); break
            if q.lower() == "clear":
                history.clear(); print("🗑️ History cleared\n"); continue
            context = retrieve(q, driver)
            if not context:
                print("AI: No relevant document content found. Please upload a PDF first.\n")
                continue
            answer = ask(q, context, history)
            print(f"\nAI: {answer}\n")
            history.append((q, answer))
    finally:
        driver.close()

def cmd_delete(args):
    driver = get_driver()
    # Delete by filename or doc_id
    keyword = args.keyword
    docs = run(driver,
        """MATCH (d:Document)
           WHERE d.id = $kw OR d.filename = $kw
           RETURN d.id AS id, d.filename AS name""",
        kw=keyword)
    if not docs:
        print(f"❌ Document not found: {keyword}")
        driver.close(); return
    for d in docs:
        # Only the Document and its Chunks are deleted; Entity nodes are left in place
        run(driver,
            """MATCH (d:Document {id:$id})
               OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c:Chunk)
               OPTIONAL MATCH (c)-[:MENTIONS]->(e:Entity)
               DETACH DELETE d, c""",
            id=d["id"])
        print(f"🗑️ Deleted: {d['name']} (id={d['id']})")
    driver.close()

def cmd_update(args):
    """Delete the old version of a document, then re-upload it."""
    path = Path(args.file)
    if not path.exists():
        print(f"❌ File not found: {path}"); sys.exit(1)
    driver = get_driver()
    # Find the old document by filename
    docs = run(driver,
        "MATCH (d:Document) WHERE d.filename = $fn RETURN d.id AS id, d.filename AS name",
        fn=path.name)
    if docs:
        for d in docs:
            run(driver,
                """MATCH (d:Document {id:$id})
                   OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c:Chunk)
                   DETACH DELETE d, c""",
                id=d["id"])
            print(f"🗑️ Deleted old version: {d['name']}")
    else:
        print("ℹ️ No old version found; uploading as a new document")
    driver.close()
    # Re-upload
    driver = get_driver()
    try:
        ingest(str(path), driver)
    finally:
        driver.close()

def cmd_init(args):
    driver = get_driver()
    print("🔧 Initializing schema and vector index...")
    init_schema(driver)
    driver.close()

def main():
    import argparse
    p = argparse.ArgumentParser(description="GraphRAG PDF Q&A")
    sub = p.add_subparsers(dest="cmd", required=True)
    i = sub.add_parser("init", help="Initialize the database schema and vector index (run this first)")
    i.set_defaults(func=cmd_init)
    up = sub.add_parser("upload", help="Upload a PDF")
    up.add_argument("file")
    up.set_defaults(func=cmd_upload)
    ls = sub.add_parser("list", help="List documents")
    ls.set_defaults(func=cmd_list)
    q = sub.add_parser("query", help="Interactive Q&A")
    q.set_defaults(func=cmd_query)
    d = sub.add_parser("delete", help="Delete a document (by filename or doc_id)")
    d.add_argument("keyword", help="Filename or doc_id")
    d.set_defaults(func=cmd_delete)
    u = sub.add_parser("update", help="Update a document (delete the old version, then re-upload)")
    u.add_argument("file", help="Path to the new PDF file")
    u.set_defaults(func=cmd_update)
    args = p.parse_args()
    args.func(args)

if __name__ == "__main__":
    main()
Launching a Neo4j instance locally in a Docker environment.