@VictorZhang2014
Last active February 22, 2026 14:59
GraphRAG_LangChain.py
import hashlib
import json
import logging
import os
import re
import sys
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger(__name__)

# ─────────────────────────── Configuration ───────────────────────────
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "")  # required
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "sk-proj-Axxxxxxxxxxx")  # required
EMBED_MODEL = os.getenv("EMBED_MODEL", "text-embedding-3-small")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "800"))
CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "150"))
TOP_K = int(os.getenv("TOP_K", "5"))
# ─────────────────────────── Neo4j ───────────────────────────
from neo4j import GraphDatabase


def get_driver():
    return GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))


def init_schema(driver):
    """Create constraints and the vector index."""
    stmts = [
        "CREATE CONSTRAINT doc_id IF NOT EXISTS FOR (d:Document) REQUIRE d.id IS UNIQUE",
        "CREATE CONSTRAINT chunk_id IF NOT EXISTS FOR (c:Chunk) REQUIRE c.id IS UNIQUE",
        "CREATE CONSTRAINT entity_uniq IF NOT EXISTS FOR (e:Entity) REQUIRE (e.name, e.type) IS UNIQUE",
        # Neo4j / AuraDB 5.x vector index; 1536 dimensions matches text-embedding-3-small
        """CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
           FOR (c:Chunk) ON (c.embedding)
           OPTIONS {indexConfig: {`vector.dimensions`: 1536, `vector.similarity_function`: 'cosine'}}""",
    ]
    with driver.session() as s:
        for q in stmts:
            try:
                s.run(q)
            except Exception as e:
                log.debug("Schema: %s", e)
    # Wait for the vector index to come online
    import time
    for _ in range(20):
        with driver.session() as s:
            result = s.run(
                "SHOW VECTOR INDEXES WHERE name = 'chunk_embedding_index'"
            ).data()
        if result and result[0].get("state") == "ONLINE":
            log.info("✅ Vector index chunk_embedding_index is ready")
            return
        log.info("⏳ Waiting for the vector index to come online...")
        time.sleep(2)
    log.warning("⚠️ Vector index may not be ready yet; continuing anyway")


def run(driver, query, **params):
    with driver.session() as s:
        return list(s.run(query, **params))
# ─────────────────────────── PDF parsing ───────────────────────────
def parse_pdf(path: str) -> str:
    from markitdown import MarkItDown

    result = MarkItDown().convert(path)
    text = result.text_content
    if not text or not text.strip():
        raise ValueError("PDF parsing produced no text")
    log.info("📄 Parsed %d characters", len(text))
    return text
# ─────────────────────────── Chunking ───────────────────────────
def split_text(text: str) -> list[str]:
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=["\n\n", "\n", "。", ".", " ", ""],
    )
    # Drop fragments shorter than 30 characters
    chunks = [c.strip() for c in splitter.split_text(text) if len(c.strip()) >= 30]
    log.info("🧩 Split into %d chunks", len(chunks))
    return chunks
# ─────────────────────────── Embedding ───────────────────────────
def embed_texts(texts: list[str]) -> list[list[float]]:
    from langchain_openai import OpenAIEmbeddings

    embedder = OpenAIEmbeddings(model=EMBED_MODEL, openai_api_key=OPENAI_API_KEY)
    return embedder.embed_documents(texts)


def embed_query(text: str) -> list[float]:
    from langchain_openai import OpenAIEmbeddings

    embedder = OpenAIEmbeddings(model=EMBED_MODEL, openai_api_key=OPENAI_API_KEY)
    return embedder.embed_query(text)
# ─────────────────────────── 实体抽取 ───────────────────────────
EXTRACT_PROMPT = """从以下文本中提取实体和关系,只返回 JSON,不要其他内容。
格式:
{"entities":[{"name":"...","type":"PERSON/ORG/CONCEPT/LOCATION/OTHER","desc":"..."}],
"relations":[{"src":"...","src_type":"...","rel":"...","tgt":"...","tgt_type":"..."}]}
文本:
{text}"""
def extract_kg(text: str) -> dict:
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, openai_api_key=OPENAI_API_KEY)
try:
raw = llm.invoke(EXTRACT_PROMPT.format(text=text[:1500])).content
raw = re.sub(r"```json|```", "", raw).strip()
return json.loads(raw)
except Exception as e:
log.debug("实体抽取失败: %s", e)
return {"entities": [], "relations": []}
# ─────────────────────────── Ingestion pipeline ───────────────────────────
def ingest(file_path: str, driver):
    path = Path(file_path)
    doc_id = hashlib.md5(path.read_bytes()).hexdigest()
    filename = path.name
    # 1. Parse the PDF
    text = parse_pdf(str(path))
    # 2. Write the Document node
    run(driver,
        "MERGE (d:Document {id:$id}) SET d.filename=$fn, d.createdAt=datetime()",
        id=doc_id, fn=filename)
    # 3. Chunk + embed
    chunks = split_text(text)
    log.info("🔢 Generating embeddings...")
    embeddings = embed_texts(chunks)
    # 4. Write Chunk nodes + extract entities
    for i, (chunk, emb) in enumerate(zip(chunks, embeddings)):
        cid = f"{doc_id}_{i}"
        run(driver,
            """MERGE (c:Chunk {id:$cid})
               SET c.text=$text, c.embedding=$emb, c.idx=$i
               WITH c
               MATCH (d:Document {id:$did})
               MERGE (d)-[:HAS_CHUNK]->(c)""",
            cid=cid, text=chunk, emb=emb, i=i, did=doc_id)
        # Only extract entities from the first 40 chunks (to keep API costs down)
        if i < 40:
            kg = extract_kg(chunk)
            for e in kg.get("entities", []):
                if e.get("name"):
                    run(driver,
                        """MERGE (e:Entity {name:$name, type:$type})
                           SET e.desc=$desc
                           WITH e MATCH (c:Chunk {id:$cid})
                           MERGE (c)-[:MENTIONS]->(e)""",
                        name=e["name"], type=e.get("type", "OTHER"),
                        desc=e.get("desc", ""), cid=cid)
            for r in kg.get("relations", []):
                if r.get("src") and r.get("tgt"):
                    run(driver,
                        """MERGE (s:Entity {name:$src, type:$st})
                           MERGE (t:Entity {name:$tgt, type:$tt})
                           MERGE (s)-[:RELATES_TO {rel:$rel}]->(t)""",
                        src=r["src"], st=r.get("src_type", "OTHER"),
                        tgt=r["tgt"], tt=r.get("tgt_type", "OTHER"),
                        rel=r.get("rel", "related"))
        if (i + 1) % 10 == 0:
            log.info("  %d/%d chunks processed", i + 1, len(chunks))
    log.info("✅ Ingestion complete: %s doc_id=%s", filename, doc_id)
    return doc_id
# ─────────────────────────── Retrieval ───────────────────────────
def retrieve(question: str, driver) -> str:
    qvec = embed_query(question)
    # Vector search
    chunks = run(driver,
        """CALL db.index.vector.queryNodes('chunk_embedding_index', $k, $vec)
           YIELD node, score
           MATCH (d:Document)-[:HAS_CHUNK]->(node)
           RETURN node.text AS text, d.filename AS filename, score
           ORDER BY score DESC""",
        k=TOP_K, vec=qvec)
    if not chunks:
        return ""
    chunk_context = "\n\n".join(
        f"[{r['filename']} | relevance: {r['score']:.3f}]\n{r['text']}"
        for r in chunks
    )
    # Graph lookup: entities and relations mentioned by the retrieved chunks
    chunk_texts = [r["text"][:100] for r in chunks]
    graph_facts = run(driver,
        """UNWIND $texts AS t
           MATCH (c:Chunk) WHERE c.text STARTS WITH t
           MATCH (c)-[:MENTIONS]->(e:Entity)
           OPTIONAL MATCH (e)-[r:RELATES_TO]->(e2:Entity)
           RETURN DISTINCT e.name AS src, e.type AS src_type,
                  r.rel AS rel, e2.name AS tgt, e2.type AS tgt_type
           LIMIT 30""",
        texts=chunk_texts)
    graph_context = ""
    if graph_facts:
        lines = []
        seen = set()
        for r in graph_facts:
            if r["rel"] and r["tgt"]:
                key = (r["src"], r["rel"], r["tgt"])
                if key not in seen:
                    seen.add(key)
                    lines.append(f"  [{r['src_type']}]{r['src']} --{r['rel']}--> [{r['tgt_type']}]{r['tgt']}")
            elif r["src"] and (r["src"],) not in seen:
                seen.add((r["src"],))
                lines.append(f"  [{r['src_type']}]{r['src']}")
        if lines:
            graph_context = "\nRelated facts from the knowledge graph:\n" + "\n".join(lines)
    return chunk_context + graph_context
# ─────────────────────────── LLM answering ───────────────────────────
SYSTEM_PROMPT = """You are a professional document Q&A assistant. Answer strictly from the provided context and do not make anything up.
If the context is insufficient, say so honestly. Be clear, accurate, and well organized."""


def ask(question: str, context: str, history: list) -> str:
    from langchain_openai import ChatOpenAI
    from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

    llm = ChatOpenAI(model=LLM_MODEL, temperature=0.1, openai_api_key=OPENAI_API_KEY)
    messages = [SystemMessage(content=SYSTEM_PROMPT)]
    for h, a in history[-4:]:  # last 4 turns of history
        messages += [HumanMessage(content=h), AIMessage(content=a)]
    user_content = f"Context:\n{context}\n\nQuestion: {question}" if context else question
    messages.append(HumanMessage(content=user_content))
    return llm.invoke(messages).content
# ─────────────────────────── CLI ───────────────────────────
def cmd_upload(args):
    path = Path(args.file)
    if not path.exists():
        print(f"❌ File not found: {path}"); sys.exit(1)
    driver = get_driver()
    init_schema(driver)
    try:
        ingest(str(path), driver)
    finally:
        driver.close()


def cmd_list(args):
    driver = get_driver()
    try:
        docs = run(driver,
            """MATCH (d:Document)
               OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c)
               RETURN d.id AS id, d.filename AS name,
                      d.createdAt AS created, count(c) AS chunks
               ORDER BY d.createdAt DESC""")
        if not docs:
            print("📭 No documents yet"); return
        print(f"\n{'ID':<34} {'Filename':<28} {'Chunks':>6}")
        print("-" * 70)
        for d in docs:
            print(f"{d['id']:<34} {d['name'][:26]:<28} {d['chunks']:>6}")
    finally:
        driver.close()


def cmd_query(args):
    driver = get_driver()
    print("\n🤖 GraphRAG PDF Q&A (type exit to quit, clear to reset history)\n")
    history = []
    try:
        while True:
            try:
                q = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\n👋"); break
            if not q: continue
            if q.lower() in ("exit", "quit"): print("👋"); break
            if q.lower() == "clear":
                history.clear(); print("🗑️ History cleared\n"); continue
            context = retrieve(q, driver)
            if not context:
                print("AI: No relevant document content found. Please upload a PDF first.\n")
                continue
            answer = ask(q, context, history)
            print(f"\nAI: {answer}\n")
            history.append((q, answer))
    finally:
        driver.close()


def cmd_delete(args):
    driver = get_driver()
    # Delete by filename or doc_id
    keyword = args.keyword
    docs = run(driver,
        """MATCH (d:Document)
           WHERE d.id = $kw OR d.filename = $kw
           RETURN d.id AS id, d.filename AS name""",
        kw=keyword)
    if not docs:
        print(f"❌ Document not found: {keyword}")
        driver.close(); return
    for d in docs:
        # Entities are left in place since they may be shared across documents
        run(driver,
            """MATCH (d:Document {id:$id})
               OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c:Chunk)
               DETACH DELETE d, c""",
            id=d["id"])
        print(f"🗑️ Deleted: {d['name']} (id={d['id']})")
    driver.close()


def cmd_update(args):
    """Delete the old version of a document, then re-upload it."""
    path = Path(args.file)
    if not path.exists():
        print(f"❌ File not found: {path}"); sys.exit(1)
    driver = get_driver()
    # Find the old document by filename
    docs = run(driver,
        "MATCH (d:Document) WHERE d.filename = $fn RETURN d.id AS id, d.filename AS name",
        fn=path.name)
    if docs:
        for d in docs:
            run(driver,
                """MATCH (d:Document {id:$id})
                   OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c:Chunk)
                   DETACH DELETE d, c""",
                id=d["id"])
            print(f"🗑️ Deleted old version: {d['name']}")
    else:
        print("ℹ️ No old version found; uploading directly")
    driver.close()
    # Re-upload
    driver = get_driver()
    try:
        ingest(str(path), driver)
    finally:
        driver.close()


def cmd_init(args):
    driver = get_driver()
    print("🔧 Initializing schema and vector index...")
    init_schema(driver)
    driver.close()


def main():
    import argparse
    p = argparse.ArgumentParser(description="GraphRAG PDF Q&A")
    sub = p.add_subparsers(dest="cmd", required=True)
    i = sub.add_parser("init", help="Initialize database indexes (run this first)")
    i.set_defaults(func=cmd_init)
    up = sub.add_parser("upload", help="Upload a PDF")
    up.add_argument("file")
    up.set_defaults(func=cmd_upload)
    ls = sub.add_parser("list", help="List documents")
    ls.set_defaults(func=cmd_list)
    q = sub.add_parser("query", help="Interactive Q&A")
    q.set_defaults(func=cmd_query)
    d = sub.add_parser("delete", help="Delete a document (by filename or doc_id)")
    d.add_argument("keyword", help="filename or doc_id")
    d.set_defaults(func=cmd_delete)
    u = sub.add_parser("update", help="Update a document (delete old version, then re-upload)")
    u.add_argument("file", help="path to the new PDF file")
    u.set_defaults(func=cmd_update)
    args = p.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
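
Typical usage, as a sketch: the package names below are the usual PyPI names for the imports above, and report.pdf is a placeholder filename.

pip install neo4j python-dotenv markitdown langchain-openai langchain-text-splitters langchain-core

python GraphRAG_LangChain.py init                 # create constraints + vector index (run once)
python GraphRAG_LangChain.py upload report.pdf    # parse, chunk, embed, extract entities
python GraphRAG_LangChain.py list                 # list ingested documents
python GraphRAG_LangChain.py query                # interactive Q&A loop
python GraphRAG_LangChain.py update report.pdf    # delete old version, then re-ingest
python GraphRAG_LangChain.py delete report.pdf    # delete by filename or doc_id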
@VictorZhang2014 (Author)

Launching a Neo4j instance locally in a Docker environment:

version: "3.9"

services:
  # Neo4j graph database
  neo4j:
    image: neo4j:5.20.0
    container_name: graphrag-neo4j
    ports:
      - "7474:7474"   # Neo4j Browser UI
      - "7687:7687"   # Bolt 协议
    environment:
      NEO4J_AUTH: neo4j/${NEO4J_PASSWORD:-graphrag2024}
      NEO4J_PLUGINS: '["apoc"]'
      NEO4J_dbms_security_procedures_unrestricted: "apoc.*"
      NEO4J_server_memory_heap_initial__size: 512m
      NEO4J_server_memory_heap_max__size: 2G
    volumes:
      - neo4j_data:/data
      - neo4j_logs:/logs
    healthcheck:
      test: ["CMD", "cypher-shell", "-u", "neo4j", "-p", "${NEO4J_PASSWORD:-graphrag2024}", "RETURN 1"]
      interval: 10s
      timeout: 10s
      retries: 10

  # GraphRAG API service
  api:
    build: .
    container_name: graphrag-api
    ports:
      - "8000:8000"
    environment:
      OPENAI_API_KEY: ${OPENAI_API_KEY}
      OPENAI_MODEL: ${OPENAI_MODEL:-gpt-4o}
      OPENAI_EMBEDDING_MODEL: ${OPENAI_EMBEDDING_MODEL:-text-embedding-3-small}
      NEO4J_URI: bolt://neo4j:7687
      NEO4J_USERNAME: neo4j
      NEO4J_PASSWORD: ${NEO4J_PASSWORD:-graphrag2024}
      NEO4J_DATABASE: neo4j
      MAX_UPLOAD_SIZE_MB: ${MAX_UPLOAD_SIZE_MB:-50}
      CHUNK_SIZE: ${CHUNK_SIZE:-800}
      CHUNK_OVERLAP: ${CHUNK_OVERLAP:-150}
    depends_on:
      neo4j:
        condition: service_healthy
    restart: unless-stopped

volumes:
  neo4j_data:
  neo4j_logs:
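
A sample .env consumed by both the script (via python-dotenv) and docker compose might look like the following; all values are placeholders. Note that the script reads NEO4J_USER, EMBED_MODEL, and LLM_MODEL, while the compose file passes NEO4J_USERNAME, OPENAI_EMBEDDING_MODEL, and OPENAI_MODEL to the API container, so set whichever names your entrypoint expects.

NEO4J_URI=bolt://localhost:7687
NEO4J_USER=neo4j
NEO4J_PASSWORD=graphrag2024
OPENAI_API_KEY=sk-...
EMBED_MODEL=text-embedding-3-small
LLM_MODEL=gpt-4o
CHUNK_SIZE=800
CHUNK_OVERLAP=150
TOP_K=5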
