Local Knowledge QA Based on ChatGLM-6B

Note: some of the code and techniques in this article may be out of date; it was last updated over a year ago.

langchain + chatglm-6b

Imports

import sys
import os
# Set the CUDA_VISIBLE_DEVICES environment variable (use GPU 0 only)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from llama_index import StorageContext, load_index_from_storage, SimpleDirectoryReader, LangchainEmbedding, GPTListIndex, GPTVectorStoreIndex, PromptHelper, LLMPredictor, ServiceContext
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
from transformers import AutoTokenizer, AutoModel
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

Load ChatGLM-6B

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

class ChatGLM(LLM):
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        response, history = model.chat(tokenizer, prompt, history=[])
        # only return newly generated tokens
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": "chatglm-6b"}

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

Load the local knowledge base. Formats such as txt and pdf are supported; here we use txt as an example, i.e. place the txt documents under the docs directory.

directory_path = "./docs"

max_input_size = 4096     # maximum prompt size sent to the LLM
num_outputs = 2000        # maximum number of output tokens
max_chunk_overlap = 20    # overlap between adjacent text chunks
chunk_size_limit = 600    # maximum size of each text chunk
prompt_helper = PromptHelper(max_input_size, num_outputs, max_chunk_overlap, chunk_size_limit=chunk_size_limit)
documents = SimpleDirectoryReader(directory_path).load_data()
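If the docs directory mixes file types, the reader can be restricted to txt files. The required_exts argument is available in most llama_index releases of this era, so treat this as an optional, version-dependent variant:

# optional: only load .txt files from the docs directory
documents = SimpleDirectoryReader(directory_path, required_exts=[".txt"]).load_data()
print(f"loaded {len(documents)} documents")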

Load the embedding model; here we use the default HuggingFace embedding (sentence-transformers/all-mpnet-base-v2).

embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
llm_predictor = LLMPredictor(llm=ChatGLM())
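Since the knowledge base and questions here may be in Chinese, it can help to swap in a multilingual sentence-transformers model. model_name is the standard HuggingFaceEmbeddings argument; the model below is only one possible choice:

# optional: use a multilingual embedding model instead of the English default
embed_model = LangchainEmbedding(
    HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
)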

Vectorize the local knowledge base and persist it; by default it is stored under ./storage.

service_context = ServiceContext.from_defaults(embed_model=embed_model, llm_predictor=llm_predictor, prompt_helper=prompt_helper)
index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context)
index.storage_context.persist()
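The storage location can also be set explicitly; persist takes a persist_dir argument, and the same path must then be passed to StorageContext.from_defaults when loading:

# persist to an explicit directory (here the same as the default ./storage)
index.storage_context.persist(persist_dir="./storage")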

Load the local knowledge vectors and ask a question

storage_context = StorageContext.from_defaults(persist_dir='./storage')
index = load_index_from_storage(storage_context, service_context=service_context)
query_engine = index.as_query_engine()
response = query_engine.query("What are the diseases associated with macrophage?")
print(response.response)
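To check which chunks of the knowledge base the answer was grounded in, the response object exposes the retrieved nodes. Attribute names vary slightly across llama_index versions, so this is a sketch for the version used here:

# inspect the retrieved chunks behind the answer
for source in response.source_nodes:
    print(source.node.get_text()[:200], "...")
    print("score:", source.score)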
