feat: chat2tts stream

verachen
2025-01-09 11:29:34 +08:00
parent ec3b4b143a
commit 37174413fe
12 changed files with 643 additions and 67 deletions

@@ -10,64 +10,64 @@ import time
 # chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
 # loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
-loader = TextLoader("/Workspace/jarvis-models/sample/RAG_boss.txt")
+loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
 documents = loader.load()
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
 docs = text_splitter.split_documents(documents)
 print("len(docs)", len(docs))
 ids = ["粤语语料"+str(i) for i in range(len(docs))]
-embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"})
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)
-id = "boss"
-client.delete_collection(id)
+id = "kiki"
+# client.delete_collection(id)
 # Insert the vectors (if the ids already exist, their vectors are updated)
 db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
-embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:1")
+embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
-client = chromadb.HttpClient(host='192.168.0.200', port=7000)
+client = chromadb.HttpClient(host='10.6.44.141', port=7000)
 collection = client.get_collection(id, embedding_function=embedding_model)
-reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1")
+reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
-while True:
-    usr_question = input("\n 请输入问题: ")
-    # query it
-    time1 = time.time()
-    results = collection.query(
-        query_texts=[usr_question],
-        n_results=10,
-    )
-    time2 = time.time()
-    print("query time: ", time2 - time1)
+# while True:
+#     usr_question = input("\n 请输入问题: ")
+#     # query it
+#     time1 = time.time()
+#     results = collection.query(
+#         query_texts=[usr_question],
+#         n_results=10,
+#     )
+#     time2 = time.time()
+#     print("query time: ", time2 - time1)
-    # print("query: ",usr_question)
-    # print("results: ",print(results["documents"][0]))
+#     # print("query: ",usr_question)
+#     # print("results: ",print(results["documents"][0]))
-    pairs = [[usr_question, doc] for doc in results["documents"][0]]
-    # print('\n',pairs)
-    scores = reranker_model.predict(pairs)
+#     pairs = [[usr_question, doc] for doc in results["documents"][0]]
+#     # print('\n',pairs)
+#     scores = reranker_model.predict(pairs)
-    # Reorder the documents:
-    print("New Ordering:")
-    i = 0
-    final_result = ''
-    for o in np.argsort(scores)[::-1]:
-        if i == 3 or scores[o] < 0.5:
-            break
-        i += 1
-        print(o+1)
-        print("Scores:", scores[o])
-        print(results["documents"][0][o],'\n')
-        final_result += results["documents"][0][o] + '\n'
+#     # Reorder the documents:
+#     print("New Ordering:")
+#     i = 0
+#     final_result = ''
+#     for o in np.argsort(scores)[::-1]:
+#         if i == 3 or scores[o] < 0.5:
+#             break
+#         i += 1
+#         print(o+1)
+#         print("Scores:", scores[o])
+#         print(results["documents"][0][o],'\n')
+#         final_result += results["documents"][0][o] + '\n'
-    print("\n final_result: ", final_result)
-    time3 = time.time()
-    print("rerank time: ", time3 - time2)
+#     print("\n final_result: ", final_result)
+#     time3 = time.time()
+#     print("rerank time: ", time3 - time2)