RAG小结

由于时效问题,该文某些代码、技术可能已经过期,请注意!!!本文最后更新于:1 个月前

优化策略:重排序(rerank)、混合搜索(hybrid search)、为文档生成摘要(summary),并为摘要添加元数据。

数据加载

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import html2text
from bs4 import BeautifulSoup
from tqdm.notebook import tqdm
from llama_index.core import (
Settings, SimpleDirectoryReader, Document
)
from llama_index.core.readers.base import BaseReader

import nest_asyncio

# HTML -> text converter shared by the HTML cleaning helpers below.
converter = html2text.HTML2Text()

# No LLM is used while loading documents; explicitly disable the global default.
Settings.llm = None

# Patch asyncio so that event loops may be nested (e.g. inside an already
# running notebook loop).
nest_asyncio.apply()

def clean_html(html):
    """Extract the main page content and its first navigation text.

    The element with ``id="col-content"`` is located, every ``<a>`` inside it
    loses its ``href`` (so only anchor text survives conversion), and the
    element is converted to text with the module-level ``converter``.

    If the page has no ``id="col-content"`` element, ``find`` returns ``None``
    and the subsequent attribute access raises; callers are expected to catch
    that and fall back.

    Args:
        html: Raw HTML source.

    Returns:
        ``(nav_content, text)``. ``nav_content`` is the space-joined, stripped
        text of the first ``<nav>`` inside the content element, or ``''`` when
        there is none. ``text`` is the converted content element.
    """
    content = BeautifulSoup(html, 'html.parser').find(id="col-content")

    # Drop link targets; the anchor text itself is kept.
    for link in content.find_all('a'):
        del link['href']

    nav = content.find('nav')
    if nav:
        nav_text = nav.get_text(strip=True, separator=' ')
    else:
        nav_text = ''

    return nav_text, converter.handle(str(content))

# Custom HTML file reader.
class HtmlFileReader(BaseReader):
    """LlamaIndex reader that turns one HTML file into a single ``Document``.

    The page is cleaned with ``clean_html``. If cleaning fails for any reason,
    the whole raw HTML is converted with ``converter`` instead, so a single
    odd page does not abort a bulk load.
    """

    def load_data(self, file, extra_info=None):
        """Read one HTML file and return it as a one-element Document list.

        Args:
            file: Path of the HTML file, as a ``pathlib.Path`` or a ``str``.
            extra_info: Optional mapping of additional metadata. It is merged
                into the document metadata; the keys computed here
                (``file_name``, ``file_path``, ``nav_content``) take precedence.

        Returns:
            ``[Document]`` with the cleaned text and the metadata above.
        """
        from pathlib import Path

        # Accept plain string paths as well; .name/.resolve() need a Path.
        file = Path(file)
        with open(file, "r", encoding="utf-8") as f:
            html = f.read()

        try:
            nav_content, clean_data = clean_html(html)
        except Exception as e:
            # Best-effort fallback: keep the document, converted less cleanly.
            print(f"Error processing {file}: {e}")
            clean_data = converter.handle(html)
            nav_content = ''

        # Caller-supplied metadata is merged rather than silently dropped;
        # the per-file keys below are authoritative.
        metadata = {
            **(extra_info or {}),
            "file_name": file.name,
            "file_path": str(file.resolve()),
            "nav_content": nav_content,
        }
        return [Document(text=clean_data, metadata=metadata)]

# SimpleDirectoryReader extended with a progress bar.
class ProgressSimpleDirectoryReader(SimpleDirectoryReader):
    """``SimpleDirectoryReader`` variant that reports progress with tqdm.

    Only the top level of ``self.input_dir`` is scanned (``glob("*")``). Each
    entry is handed to the reader registered in ``self.file_extractor`` under
    its lower-cased suffix; entries without a registered reader are skipped.
    """

    def load_data(self):
        """Load every entry with a registered reader and return all Documents."""
        loaded = []
        candidates = list(self.input_dir.glob("*"))
        for path in tqdm(candidates, desc="Loading files"):
            handler = self.file_extractor.get(path.suffix.lower())
            if not handler:
                continue
            loaded += handler.load_data(path)
        return loaded

# 示例用法
# 指定输入目录和文件读取器
# reader = ProgressSimpleDirectoryReader(
# input_dir="/home/wangjiabin/data/html/",
# file_extractor={".html": HtmlFileReader()}
# )

# # 加载数据
# documents = reader.load_data()

构建向量索引并导入Milvus数据库

对文档生成摘要并添加元数据
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from copy import deepcopy
import json
from pathlib import Path

from llama_index.llms.ollama import Ollama

# Single location for the summary cache: both the write and the read use it.
SUMMARY_PATH = Path('summary/html_summary.json')

# Instruction prepended to each document's text. The two sentences are
# separated by a newline; the document text follows after a blank line.
SUMMARY_PROMPT = (
    "Summarize the following text in a concise manner, ensuring that the summary is less than 500 words.\n"
    "Please make sure to clearly distinguish between undergraduate and graduate information in the summary.:\n\n"
)

# Work on a copy so the original `documents` keep their full text.
documents_cp = deepcopy(documents)
# num_predict caps the number of generated tokens.
# NOTE(review): 500 tokens is fewer than the 500 *words* the prompt allows,
# so long summaries may be cut off -- confirm the intended limit.
llm_qwen = Ollama(base_url='http://localhost:11434', model="qwen2:latest", temperature=0.1, request_timeout=300.0, additional_kwargs={"num_predict": 500})

# Replace each document's text with its LLM-generated summary.
for doc in tqdm(documents_cp, desc="Summarizing"):
    res = llm_qwen.complete(f"{SUMMARY_PROMPT}{doc.text}")
    doc.text = res.text

# Persist summaries keyed by file name.
doc_summary_dict = {d.metadata['file_name']: d.text for d in documents_cp}
SUMMARY_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(SUMMARY_PATH, 'w', encoding='utf-8') as json_file:
    json.dump(doc_summary_dict, json_file, ensure_ascii=False)

with open(SUMMARY_PATH, encoding='utf-8') as json_file:
    summary = json.load(json_file)

# Final indexed text: the page's navigation title plus its summary.
for doc in documents_cp:
    nav_content = doc.metadata['nav_content']
    # Fall back to the in-memory summary if the file lacks this entry,
    # instead of embedding the literal string "None".
    text = summary.get(doc.metadata['file_name'], doc.text)
    doc.text = f"Title: {nav_content}\n\nSummary: {text}"
生成索引
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import nest_asyncio
from typing import List
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.vector_stores.milvus.utils import BaseSparseEmbeddingFunction
from llama_index.vector_stores.milvus import MilvusVectorStore
from FlagEmbedding import BGEM3FlagModel

# Apply nest_asyncio to enable nested event loops (the index below is built
# with use_async=True).
nest_asyncio.apply()

# OpenAI-compatible settings. setdefault keeps values already exported in the
# environment instead of overwriting them with empty placeholders.
os.environ.setdefault('OPENAI_API_KEY', '')
os.environ.setdefault('OPENAI_API_BASE', '')

# Dense embedding model loaded from a local BGE-M3 checkpoint.
embed_model = HuggingFaceEmbedding(
    model_name="/home/wangjiabin/model/embed/bge-m3", max_length=8192
)

# Cross-encoder reranker from a local checkpoint; keeps the top 5 nodes.
rerank = SentenceTransformerRerank(
    top_n=5,
    model="/home/wangjiabin/model/rerank/bge-reranker-v2-m3/"
)

# Custom sparse embedding function backed by BGEM3FlagModel.
class ExampleEmbeddingFunction(BaseSparseEmbeddingFunction):
    """Sparse (lexical) embedding function using a local BGE-M3 checkpoint.

    Passed to ``MilvusVectorStore`` as its ``sparse_embedding_function``.
    Both encode methods return one ``{int_token_id: weight}`` dict per input
    string.
    """

    def __init__(self):
        self.model = BGEM3FlagModel("/home/wangjiabin/model/embed/bge-m3", use_fp16=False)

    def _sparse_weights(self, texts: List[str]):
        # Only the lexical (sparse) output is requested; dense and ColBERT
        # outputs are switched off.
        encoded = self.model.encode(
            texts,
            return_dense=False,
            return_sparse=True,
            return_colbert_vecs=False,
        )
        return [self._to_standard_dict(weights) for weights in encoded["lexical_weights"]]

    def encode_queries(self, queries: List[str]):
        """Return sparse vectors for search queries."""
        return self._sparse_weights(queries)

    def encode_documents(self, documents: List[str]):
        """Return sparse vectors for documents being indexed."""
        return self._sparse_weights(documents)

    def _to_standard_dict(self, raw_output):
        # Normalize the token-id keys to int, keeping the weights unchanged.
        return {int(token): raw_output[token] for token in raw_output}

# Milvus collection storing dense (dim=1024) and sparse vectors. Hybrid
# queries fuse the two result lists with an RRF ranker (k=60).
# overwrite=True replaces an existing collection with the same name.
vector_store = MilvusVectorStore(
    dim=1024,
    uri="http://localhost:19530",
    collection_name='summary_html_index',
    overwrite=True,
    enable_sparse=True,
    sparse_embedding_function=ExampleEmbeddingFunction(),
    hybrid_ranker="RRFRanker",
    hybrid_ranker_params={"k": 60}
)

# Create a storage context using the vector store
storage_context = StorageContext.from_defaults(vector_store=vector_store)

# VectorStoreIndex's constructor takes `nodes`, not `documents`; passing
# `documents=` lands in **kwargs and the index then has nothing to build from.
# from_documents converts the Documents into nodes and builds the index.
# `documents_cp` (the summarized documents) must be defined beforehand.
index = VectorStoreIndex.from_documents(
    documents_cp,
    storage_context=storage_context,
    embed_model=embed_model,
    use_async=True,
    show_progress=True,
)

加载索引并提问

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from llama_index.vector_stores.milvus import MilvusVectorStore
from pymilvus import MilvusClient

# Re-open the existing collection (overwrite=False keeps the stored data).
# Query-time sparse vectors must come from the same encoder used at indexing
# time; without an explicit sparse_embedding_function the store would use a
# library default rather than the local BGE-M3 checkpoint.
vector_store = MilvusVectorStore(
    dim=1024,
    uri="http://localhost:19530",
    collection_name='summary_html_index',
    overwrite=False,
    enable_sparse=True,
    sparse_embedding_function=ExampleEmbeddingFunction(),
    hybrid_ranker="RRFRanker",
    hybrid_ranker_params={"k": 60},
)
index = VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embed_model)


PROMPT = '''
You are an expert Q&A system that is trusted around the world for your factual accuracy.
Always answer the query using the provided context information, and not prior knowledge. Ensure your answers are fact-based and accurately reflect the context provided.
Some rules to follow:
1. Never directly reference the given context in your answer.
2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines.
3. Focus on succinct answers that provide only the facts necessary, do not be verbose.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {query_str}
Answer:
'''

# Retrieve more candidates than the reranker keeps (top_n=5); with equal
# numbers the reranker could only reorder the same 5 nodes, never select.
query_engine_rerank = index.as_query_engine(
    similarity_top_k=10,
    # text_qa_template=PromptTemplate(PROMPT),
    node_postprocessors=[rerank],
    embed_model=embed_model,
    # llm=llm_qwen,
    vector_store_query_mode="hybrid"
)
response = query_engine_rerank.query('Undergraduate Robert F. Wagner Graduate School of Public Service Programs')
# Top-ranked source node, or None when nothing was retrieved.
response.source_nodes[0] if response.source_nodes else None

参考:https://mp.weixin.qq.com/s/ciw_vSpwe7ryx3ktv1Yhdg