mirror of
https://github.com/BoardWare-Genius/jarvis-models.git
synced 2025-12-13 16:53:24 +00:00
feat: chat2tts stream
This commit is contained in:
@ -10,64 +10,64 @@ import time
|
||||
|
||||
# chroma run --path chroma_db/ --port 8000 --host 0.0.0.0
|
||||
# loader = TextLoader("/Workspace/chroma_data/粤语语料.txt",encoding="utf-8")
|
||||
loader = TextLoader("/Workspace/jarvis-models/sample/RAG_boss.txt")
|
||||
loader = TextLoader("/Workspace/jarvis-models/sample/RAG_zh_kiki.txt")
|
||||
documents = loader.load()
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=0, length_function=len, is_separator_regex=True,separators=['\n', '\n\n'])
|
||||
docs = text_splitter.split_documents(documents)
|
||||
print("len(docs)", len(docs))
|
||||
ids = ["粤语语料"+str(i) for i in range(len(docs))]
|
||||
|
||||
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:1"})
|
||||
client = chromadb.HttpClient(host='192.168.0.200', port=7000)
|
||||
embedding_model = SentenceTransformerEmbeddings(model_name='/Workspace/Models/BAAI/bge-m3', model_kwargs={"device": "cuda:0"})
|
||||
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
|
||||
|
||||
id = "boss"
|
||||
client.delete_collection(id)
|
||||
id = "kiki"
|
||||
# client.delete_collection(id)
|
||||
# 插入向量(如果ids已存在,则会更新向量)
|
||||
db = Chroma.from_documents(documents=docs, embedding=embedding_model, ids=ids, collection_name=id, client=client)
|
||||
|
||||
|
||||
|
||||
|
||||
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:1")
|
||||
embedding_model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="/Workspace/Models/BAAI/bge-m3", device = "cuda:0")
|
||||
|
||||
client = chromadb.HttpClient(host='192.168.0.200', port=7000)
|
||||
client = chromadb.HttpClient(host='10.6.44.141', port=7000)
|
||||
|
||||
collection = client.get_collection(id, embedding_function=embedding_model)
|
||||
|
||||
reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:1")
|
||||
reranker_model = CrossEncoder("/Workspace/Models/BAAI/bge-reranker-v2-m3", max_length=512, device = "cuda:0")
|
||||
|
||||
while True:
|
||||
usr_question = input("\n 请输入问题: ")
|
||||
# query it
|
||||
time1 = time.time()
|
||||
results = collection.query(
|
||||
query_texts=[usr_question],
|
||||
n_results=10,
|
||||
)
|
||||
time2 = time.time()
|
||||
print("query time: ", time2 - time1)
|
||||
# while True:
|
||||
# usr_question = input("\n 请输入问题: ")
|
||||
# # query it
|
||||
# time1 = time.time()
|
||||
# results = collection.query(
|
||||
# query_texts=[usr_question],
|
||||
# n_results=10,
|
||||
# )
|
||||
# time2 = time.time()
|
||||
# print("query time: ", time2 - time1)
|
||||
|
||||
# print("query: ",usr_question)
|
||||
# print("results: ",print(results["documents"][0]))
|
||||
# # print("query: ",usr_question)
|
||||
# # print("results: ",print(results["documents"][0]))
|
||||
|
||||
|
||||
pairs = [[usr_question, doc] for doc in results["documents"][0]]
|
||||
# print('\n',pairs)
|
||||
scores = reranker_model.predict(pairs)
|
||||
# pairs = [[usr_question, doc] for doc in results["documents"][0]]
|
||||
# # print('\n',pairs)
|
||||
# scores = reranker_model.predict(pairs)
|
||||
|
||||
#重新排列文件顺序:
|
||||
print("New Ordering:")
|
||||
i = 0
|
||||
final_result = ''
|
||||
for o in np.argsort(scores)[::-1]:
|
||||
if i == 3 or scores[o] < 0.5:
|
||||
break
|
||||
i += 1
|
||||
print(o+1)
|
||||
print("Scores:", scores[o])
|
||||
print(results["documents"][0][o],'\n')
|
||||
final_result += results["documents"][0][o] + '\n'
|
||||
# #重新排列文件顺序:
|
||||
# print("New Ordering:")
|
||||
# i = 0
|
||||
# final_result = ''
|
||||
# for o in np.argsort(scores)[::-1]:
|
||||
# if i == 3 or scores[o] < 0.5:
|
||||
# break
|
||||
# i += 1
|
||||
# print(o+1)
|
||||
# print("Scores:", scores[o])
|
||||
# print(results["documents"][0][o],'\n')
|
||||
# final_result += results["documents"][0][o] + '\n'
|
||||
|
||||
print("\n final_result: ", final_result)
|
||||
time3 = time.time()
|
||||
print("rerank time: ", time3 - time2)
|
||||
# print("\n final_result: ", final_result)
|
||||
# time3 = time.time()
|
||||
# print("rerank time: ", time3 - time2)
|
||||
Reference in New Issue
Block a user