Merge branch 'main' into veraGDI

Author: headbigsile
Date: 2025-03-19 16:07:30 +08:00
Committed by: GitHub

6 changed files with 395 additions and 25 deletions

@@ -16,6 +16,7 @@ import re
from injector import singleton,inject
from datetime import datetime
+from .websearch import WebSearch

# Path of the file used to save chat inputs
file_path = "chat_inputs_log.json"
@@ -23,8 +24,9 @@ file_path = "chat_inputs_log.json"
class Chat(Blackbox):
    @inject
-    def __init__(self, chroma_query: ChromaQuery):
+    def __init__(self, chroma_query: ChromaQuery, websearch: WebSearch):
        self.chroma_query = chroma_query
+        self.websearch = websearch

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)
@@ -60,6 +62,7 @@ class Chat(Blackbox):
        system_prompt = settings.get('system_prompt')
        user_prompt_template = settings.get('user_prompt_template')
        user_stream = settings.get('stream')
+        user_websearch = settings.get('websearch')

        llm_model = "vllm"
@@ -135,6 +138,8 @@ class Chat(Blackbox):
        if user_stream in [None, ""]:
            user_stream = False
+        if user_websearch in [None, ""]:
+            user_websearch = False

        # ERNIE's request format differs from OpenAI's and must be handled separately
        if re.search(r"ernie", user_model_name):
@@ -201,6 +206,32 @@ class Chat(Blackbox):
            {"role": "system", "content": system_prompt}
        ]

+        if user_websearch:
+            search_answer_zh_template = \
+                '''# The following contents are search results based on the user's message:
+{search_results}
+In the search results I provide, each result is in the format ["title"]...["position": X], where X is the numeric index of each article. Where appropriate, cite the context at the end of the relevant sentence, using the format [citation:X] in the corresponding part of the answer. If a sentence draws on several contexts, list every relevant citation number, e.g. [citation:3][citation:5]; do not cluster the citations at the end, but place them in the corresponding parts of the answer.
+When answering, note the following points:
+- Today is {cur_date}.
+- Not all of the search results are closely relevant to the user's question; evaluate and filter them against the question.
+- For enumeration questions (e.g. listing all flight information), try to keep the answer within 10 key points and tell the user they can consult the search sources for complete information. Prefer the most relevant and complete items; do not volunteer content the search results do not provide unless necessary.
+- For creative questions (e.g. writing an essay), be sure to cite the corresponding references inside the body paragraphs, e.g. [citation:3][citation:5], not only at the end. Interpret and summarize the user's requirements, choose a suitable format, make full use of the search results, extract the key information, and produce an answer that is insightful, creative, and professional. Make the answer as long as possible: for every point, infer the user's intent, cover as many angles as possible, and be information-dense and thorough.
+- If the answer is long, structure it and summarize by paragraph. If answering in points, try to keep it within 5 points and merge related content.
+- For objective Q&A, if the answer is very short, you may add one or two related sentences to enrich the content.
+- Choose an appropriate, readable format for the answer based on the user's requirements and the content of the answer.
+- Synthesize multiple relevant web pages in your answer; do not cite a single page repeatedly.
+- Unless the user requests otherwise, answer in the same language the question was asked in.
+
+# The user's message is:
+{question}'''
+            websearch_response = self.websearch(prompt, settings)
+            print("2. Websearch_response: \n", websearch_response)
+            today = datetime.today().strftime("%Y-%m-%d")
+            user_question = search_answer_zh_template.format(question=user_question, cur_date=today, search_results=websearch_response["organic"])

        if llm_model != "vllm":
            chat_inputs={
                "model": user_model_name,
@@ -222,7 +253,12 @@ class Chat(Blackbox):
        else:
            chat_inputs={
                "model": user_model_name,
-                "prompt": user_question,
+                "messages": prompt_template + user_context + [
+                    {
+                        "role": "user",
+                        "content": user_question
+                    }
+                ],
                "temperature": float(user_temperature),
                "top_p": float(user_top_p),
                "n": float(user_n),


@@ -0,0 +1,93 @@
import io
import time
import requests
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from injector import inject
from injector import singleton
from ..log.logging_time import logging_time
from ..configuration import CosyVoiceConf
from .blackbox import Blackbox
import soundfile
import pyloudnorm as pyln
import sys
sys.path.append('/home/gpu/Workspace/CosyVoice')
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
import torchaudio
import os
import logging

logger = logging.getLogger(__name__)


@singleton
class CosyVoiceTTS(Blackbox):
    mode: str
    url: str
    speed: int
    device: str
    language: str
    speaker: str

    @logging_time(logger=logger)
    def model_init(self, cosyvoice_config: CosyVoiceConf) -> None:
        self.speed = cosyvoice_config.speed
        self.device = cosyvoice_config.device
        self.language = cosyvoice_config.language
        self.speaker = cosyvoice_config.speaker
        self.url = ''
        self.mode = cosyvoice_config.mode
        self.cosyvoicetts = None
        self.speaker_ids = None
        os.environ['CUDA_VISIBLE_DEVICES'] = str(cosyvoice_config.device)
        if self.mode == 'local':
            self.cosyvoicetts = CosyVoice('/home/gpu/Workspace/Models/CosyVoice/pretrained_models/CosyVoice-300M')
        else:
            self.url = cosyvoice_config.url
        logger.info('#### Initializing CosyVoiceTTS Service in cuda:' + str(cosyvoice_config.device) + ' mode...')

    @inject
    def __init__(self, cosyvoice_config: CosyVoiceConf) -> None:
        self.model_init(cosyvoice_config)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        text = args[0]
        return isinstance(text, str)

    @logging_time(logger=logger)
    def processing(self, *args, **kwargs) -> io.BytesIO | bytes:
        text = args[0]
        current_time = time.time()
        if self.mode == 'local':
            audio = self.cosyvoicetts.inference_sft(text, self.language)
            f = io.BytesIO()
            soundfile.write(f, audio['tts_speech'].cpu().numpy().squeeze(0), 22050, format='wav')
            f.seek(0)
            print("#### CosyVoiceTTS Service consume - local : ", (time.time() - current_time))
            return f.read()
        else:
            message = {
                "text": text
            }
            response = requests.post(self.url, json=message)
            print("#### CosyVoiceTTS Service consume - docker : ", (time.time() - current_time))
            return response.content

    async def fast_api_handler(self, request: Request) -> Response:
        try:
            data = await request.json()
        except Exception:
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
        text = data.get("text")
        if text is None:
            return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
        return Response(content=self.processing(text), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
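
Review note: a hedged client-side sketch for exercising this handler; the host and route are assumptions, only the {"text": ...} request body and the WAV response come from the code above.

import requests

resp = requests.post("http://localhost:8000/tts", json={"text": "Hello there"})  # route is hypothetical
resp.raise_for_status()
with open("audio.wav", "wb") as fh:
    fh.write(resp.content)  # the handler returns raw WAV bytes with an attachment header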

src/blackbox/melotts.py (new file, 108 lines)

@@ -0,0 +1,108 @@
import io
import time
import requests
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from injector import inject
from injector import singleton
from ..log.logging_time import logging_time
from ..configuration import MeloConf
from .blackbox import Blackbox
import soundfile
import pyloudnorm as pyln
from melo.api import TTS
import logging

logger = logging.getLogger(__name__)


@singleton
class MeloTTS(Blackbox):
    mode: str
    url: str
    speed: int
    device: str
    language: str
    speaker: str

    @logging_time(logger=logger)
    def model_init(self, melo_config: MeloConf) -> None:
        self.speed = melo_config.speed
        self.device = melo_config.device
        self.language = melo_config.language
        self.speaker = melo_config.speaker
        self.url = ''
        self.mode = melo_config.mode
        self.melotts = None
        self.speaker_ids = None
        if self.mode == 'local':
            self.melotts = TTS(language=self.language, device=self.device)
            self.speaker_ids = self.melotts.hps.data.spk2id
        else:
            self.url = melo_config.url
        logger.info('#### Initializing MeloTTS Service in ' + self.device + ' mode...')

    @inject
    def __init__(self, melo_config: MeloConf) -> None:
        self.model_init(melo_config)

    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        text = args[0]
        return isinstance(text, str)

    @logging_time(logger=logger)
    def processing(self, *args, **kwargs) -> io.BytesIO | bytes:
        text = args[0]
        current_time = time.time()
        if self.mode == 'local':
            audio = self.melotts.tts_to_file(text, self.speaker_ids[self.speaker], speed=self.speed)
            f = io.BytesIO()
            soundfile.write(f, audio, 44100, format='wav')
            f.seek(0)
            # print("#### MeloTTS Service consume - local : ", (time.time() - current_time))
            # return f.read()
            # Read the audio data from the buffer
            data, rate = soundfile.read(f, dtype='float32')
            # Peak normalization
            peak_normalized_audio = pyln.normalize.peak(data, -1.0)
            # Integrated loudness normalization
            meter = pyln.Meter(rate)
            loudness = meter.integrated_loudness(peak_normalized_audio)
            loudness_normalized_audio = pyln.normalize.loudness(peak_normalized_audio, loudness, -12.0)
            # Write the loudness-normalized audio to an in-memory buffer
            normalized_audio_buffer = io.BytesIO()
            soundfile.write(normalized_audio_buffer, loudness_normalized_audio, rate, format='wav')
            normalized_audio_buffer.seek(0)
            print("#### MeloTTS Service consume - local : ", (time.time() - current_time))
            return normalized_audio_buffer.read()
        else:
            message = {
                "text": text
            }
            response = requests.post(self.url, json=message)
            print("#### MeloTTS Service consume - docker : ", (time.time() - current_time))
            return response.content

    async def fast_api_handler(self, request: Request) -> Response:
        try:
            data = await request.json()
        except Exception:
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
        text = data.get("text")
        if text is None:
            return JSONResponse(content={"error": "text is required"}, status_code=status.HTTP_400_BAD_REQUEST)
        return Response(content=self.processing(text), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=audio.wav"})
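
Review note: the local branch peak-normalizes to -1 dBFS and then normalizes integrated loudness to -12 LUFS. A standalone sketch of the same pyloudnorm pipeline on a synthetic tone, assuming nothing beyond the numpy and pyloudnorm APIs:

import numpy as np
import pyloudnorm as pyln

rate = 44100
t = np.linspace(0, 2.0, 2 * rate, endpoint=False)
data = (0.3 * np.sin(2 * np.pi * 1000 * t)).astype('float32')  # 2 s, 1 kHz tone

peak_normalized = pyln.normalize.peak(data, -1.0)            # scale so the peak sits at -1 dBFS
meter = pyln.Meter(rate)                                     # ITU-R BS.1770 loudness meter
loudness = meter.integrated_loudness(peak_normalized)        # measured loudness in LUFS
normalized = pyln.normalize.loudness(peak_normalized, loudness, -12.0)  # hit the -12 LUFS target
print(f"{loudness:.1f} LUFS -> -12.0 LUFS")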


@@ -1,5 +1,6 @@
from fastapi import Request, Response, status
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
+from sse_starlette.sse import EventSourceResponse
from injector import singleton,inject
from typing import Optional, List
@@ -12,14 +13,17 @@ import requests
import base64
import copy
import ast
import json
+import random
+from time import time
import io
from PIL import Image
from lmdeploy.serve.openai.api_client import APIClient

def is_base64(value) -> bool:
    try:
@@ -51,8 +55,8 @@ class VLMS(Blackbox):
        - ignore_eos (bool): indicator for ignoring eos
        - skip_special_tokens (bool): Whether or not to remove special tokens
          in the decoding. Default to be True."""
-        self.url = vlm_config.url
+        self.model_dict = vlm_config.urls
+        self.model_url = None
        self.temperature: float = 0.7
        self.top_p: float = 1
        self.max_tokens: (int | None) = 512
@@ -82,7 +86,7 @@ class VLMS(Blackbox):
        data = args[0]
        return isinstance(data, list)

-    def processing(self, prompt: str, images: str | bytes, settings: dict, model_name: Optional[str] = None, user_context: List[dict] = None) -> str:
+    def processing(self, prompt: str | None, images: str | bytes | None, settings: dict, model_name: Optional[str] = None, user_context: List[dict] = None) -> str:
        """
        Args:
            prompt: a string query to the model.
@@ -105,6 +109,9 @@ class VLMS(Blackbox):
        else:
            settings = {}

+        if not prompt:
+            prompt = 'You are an assistant robot. Please give a brief, general description of this image, covering the main objects and their state, in no more than 50 characters.' if images else 'Hello'
+
        # Transform the images into base64 format where the OpenAI format requires it.
        if images:
            if is_base64(images):  # image as base64 str
@@ -148,7 +155,11 @@ class VLMS(Blackbox):
        #         'content': 'The image mainly shows a tiger resting on green grass. The grass offers many places to sit and looks quite lush. The background is blurry, perhaps because of the tiger, leaving the rest of the image less clear.'
        #     }
        # ]
-        api_client = APIClient(self.url)
+        user_context = self.keep_last_k_images(user_context, k=1)
+        if self.model_url is None:
+            self.model_url = self._get_model_url(model_name)
+        api_client = APIClient(self.model_url)
        # api_client = APIClient("http://10.6.80.91:23333")
        model_name = api_client.available_models[0]
        # Reformat input into openai format to request.
@@ -187,20 +198,39 @@ class VLMS(Blackbox):
        responses = ''
        total_token_usage = 0  # which can be used to count the cost of a query
        for i, item in enumerate(api_client.chat_completions_v1(model=model_name,
-                                                                messages=messages,  # stream = True,
+                                                                messages=messages, stream=True,
                                                                **settings,
                                                                # session_id=,
                                                                )):
            # Stream output
-            # print(item["choices"][0]["delta"]['content'], end='')
-            # responses += item["choices"][0]["delta"]['content']
+            print(item["choices"][0]["delta"]['content'], end='\n')
+            yield item["choices"][0]["delta"]['content']
+            responses += item["choices"][0]["delta"]['content']

-        print(item["choices"][0]["message"]['content'])
-        responses += item["choices"][0]["message"]['content']
+        # print(item["choices"][0]["message"]['content'])
+        # responses += item["choices"][0]["message"]['content']
        # total_token_usage += item['usage']['total_tokens']  # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}

        user_context = messages + [{'role': 'assistant', 'content': responses}]
-        return responses, user_context
+        self.custom_print(user_context)
+        # return responses, user_context
+    def _get_model_url(self, model_name: str | None):
+        # Probe every configured endpoint and keep the ones that answer.
+        available_models = {}
+        for model, url in self.model_dict.items():
+            try:
+                response = requests.get(url, timeout=3)
+                if response.status_code == 200:
+                    available_models[model] = url
+            except Exception:
+                # endpoint unreachable; skip it
+                pass
+        if not available_models:
+            raise RuntimeError("There are no available running models; please check your endpoint urls.")
+        if model_name and model_name in available_models:
+            return available_models[model_name]
+        model = random.choice(list(available_models.keys()))
+        if model_name:
+            print(f"No such model {model_name}, using {model} instead.")
+        else:
+            print(f"Using random model {model}.")
+        return available_models[model]
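
Review note: a short usage sketch of the fallback behavior, under stated assumptions (the calls below assume an already-constructed VLMS instance named vlms):

url = vlms._get_model_url("Qwen-VL-Chat")  # exact match, if that endpoint answers HTTP 200 within 3 s
url = vlms._get_model_url(None)            # no name given: a random live endpoint is chosen
# With zero live endpoints the method raises a clear error instead of failing
# later with an opaque IndexError from random.choice on an empty list.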
    def _into_openai_format(self, context: List[list]) -> List[dict]:
        """

@@ -255,7 +285,35 @@ class VLMS(Blackbox):
        return user_context
+    def keep_last_k_images(self, user_context: list, k: int = 2):
+        # Walk the history from newest to oldest and keep only the last k images,
+        # rebuilding each content list rather than popping while iterating.
+        count = 0
+        result = []
+        for item in user_context[::-1]:
+            if item['role'] == 'user' and len(item['content']) > 1:
+                kept = []
+                for info in item['content']:
+                    if info['type'] in ('image_url', 'image'):
+                        if count < k:
+                            count += 1
+                            kept.append(info)
+                        # images beyond the most recent k are dropped
+                        # (alternatively: insert a {'type': 'text', 'text': '<IMAGE>'} stub)
+                    else:
+                        kept.append(info)
+                item['content'] = kept
+            result.append(item)
+        return result[::-1]
+    def custom_print(self, user_context: list):
+        # Print the context with raw image payloads masked by a placeholder.
+        result = []
+        for item in user_context:
+            if item['role'] == 'user':
+                for idx, info in enumerate(item['content']):
+                    if info['type'] in ('image_url', 'image'):
+                        item['content'].pop(idx)
+                        item['content'].insert(idx, {'type': 'image', 'image': '##<IMAGE>##'})
+                    else:
+                        continue
+            result.append(item)
+        print(result)
    async def fast_api_handler(self, request: Request) -> Response:
        ## TODO: add support for multiple images and support image in form-data format
        json_request = True
@@ -278,7 +336,6 @@ class VLMS(Blackbox):
        prompt = data.get("prompt")
        settings: dict = data.get('settings')
        context = data.get("context")
-
        if not context:
            user_context = []
        elif isinstance(context[0], list):
@@ -297,10 +354,13 @@ class VLMS(Blackbox):
        if prompt is None:
            return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
-        if model_name is None or model_name.isspace():
-            model_name = "Qwen-VL-Chat"
+        # if model_name is None or model_name.isspace():
+        #     model_name = "Qwen-VL-Chat"
+        # response, _ = self.processing(prompt, img_data, settings, model_name, user_context=user_context)
+        # return StreamingResponse(self.processing(prompt, img_data, settings, model_name, user_context=user_context), status_code=status.HTTP_200_OK)
+        return EventSourceResponse(self.processing(prompt, img_data, settings, model_name, user_context=user_context), status_code=status.HTTP_200_OK)
+        # HTTP JsonResponse
        response, history = self.processing(prompt, img_data, settings, model_name, user_context=user_context)
-        # jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
+        # return JSONResponse(content={"response": response}, status_code=status.HTTP_200_OK)
        return JSONResponse(content={"response": response}, status_code=status.HTTP_200_OK)
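
Review note: the handler now streams tokens as server-sent events. A hedged client sketch; the host and route are assumptions, while the "data:" line framing is standard SSE as produced by sse_starlette:

import requests

payload = {"prompt": "Describe this image", "images": None, "settings": {}}
with requests.post("http://localhost:8000/vlms", json=payload, stream=True) as resp:  # route is hypothetical
    for raw in resp.iter_lines(decode_unicode=True):
        if raw and raw.startswith("data:"):
            print(raw[len("data:"):].strip(), end="", flush=True)  # one delta chunk per event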

src/blackbox/websearch.py (new file, 73 lines)

@@ -0,0 +1,73 @@
import datetime
from typing import Any, Coroutine
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from openai import OpenAI
from injector import singleton
from .blackbox import Blackbox
import logging
from ..log.logging_time import logging_time
import requests
import json

logger = logging.getLogger(__name__)

DEFAULT_COLLECTION_ID = "123"


@singleton
class WebSearch(Blackbox):
    def __call__(self, *args, **kwargs):
        return self.processing(*args, **kwargs)

    def valid(self, *args, **kwargs) -> bool:
        data = args[0]
        return isinstance(data, list)

    # @logging_time(logger=logger)
    def processing(self, question: str, settings: dict) -> dict:
        if settings is None:
            settings = {}
        # from googlesearch import search
        # question = "keyword to search"
        # for url in search(question, num_results=10):
        #     print(url)
        url = "https://google.serper.dev/search"
        payload = json.dumps({
            "q": question,
            "location": "China",  # restrict the search location to China
            "gl": "cn",           # restrict the country to China
            "hl": "zh-cn",        # restrict results to Chinese
            "tbs": "qdr:y"        # restrict results to the past year
        })
        headers = {
            'X-API-KEY': '00c0f5144e44721bd0cfed219e2b3256bb3dd5fc',
            'Content-Type': 'application/json'
        }
        response = requests.request("POST", url, headers=headers, data=payload)
        print("web search results:", response.json())
        return response.json()

    async def fast_api_handler(self, request: Request) -> Response:
        try:
            data = await request.json()
        except Exception:
            return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
        user_question = data.get("question")
        setting = data.get("settings")
        return JSONResponse(
            content={"response": self.processing(user_question, setting)},
            status_code=status.HTTP_200_OK)
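
Review note: a hedged sketch of exercising this handler directly; the host and route are assumptions, while the {"question", "settings"} body and the nested "organic" field come from the code above and from its use in chat.py.

import requests

resp = requests.post("http://localhost:8000/websearch",  # hypothetical route
                     json={"question": "latest AI news", "settings": {}})
print(resp.json()["response"]["organic"][:3])  # top organic hits, as consumed by chat.py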


@@ -179,4 +179,4 @@ class VLMConf():
    @inject
    def __init__(self, config: Configuration) -> None:
-        self.url = config.get("vlms.url")
+        self.urls = config.get("vlms.urls")
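
Review note: with this change, vlms.urls is read as a {model_name: endpoint} mapping consumed by VLMS._get_model_url. A hedged sketch of the expected shape; the model names and the second host are placeholders, not from this repo:

config = {
    "vlms": {
        "urls": {
            "Qwen-VL-Chat": "http://10.6.80.91:23333",    # endpoint referenced in vlms.py comments
            "some-other-vlm": "http://10.6.80.92:23333",  # hypothetical second endpoint
        }
    }
}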