support formdata of request

This commit is contained in:
Ivan087
2024-08-20 18:02:44 +08:00
parent 4c3756811d
commit 4d260b3361

View File

@ -1,11 +1,11 @@
from fastapi import Request, Response, status from fastapi import Request, Response, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from injector import singleton,inject from injector import singleton,inject
from typing import Optional from typing import Optional, List
from .blackbox import Blackbox from .blackbox import Blackbox
from ..log.logging_time import logging_time from ..log.logging_time import logging_time
from .chroma_query import ChromaQuery # from .chroma_query import ChromaQuery
from ..configuration import VLMConf from ..configuration import VLMConf
import requests import requests
@ -39,46 +39,53 @@ class VLMS(Blackbox):
data = args[0] data = args[0]
return isinstance(data, list) return isinstance(data, list)
def processing(self, prompt, images, model_name: Optional[str] = None) -> str: def processing(self, prompt:str, images:str | bytes, model_name: Optional[str] = None, user_context: List[dict] = None) -> str:
"""
Args:
prompt: a string query to the model.
images: a base64 string of image data;
user_context: a list of history conversation, should be a list of openai format.
Return:
response: a string
history: a list
"""
if model_name == "Qwen-VL-Chat": if model_name == "Qwen-VL-Chat":
model_name = "infer-qwen-vl" model_name = "infer-qwen-vl"
elif model_name == "llava-llama-3-8b-v1_1-transformers": elif model_name == "llava-llama-3-8b-v1_1-transformers":
model_name = "infer-lav-lam-v1-1" model_name = "infer-lav-lam-v1-1"
else: else:
model_name = "infer-qwen-vl" model_name = "infer-qwen-vl"
# Transform the images into base64 format where openai format need.
if is_base64(images): # image as base64 str
images_data = images
elif isinstance(images,bytes): # image as bytes
images_data = str(base64.b64encode(images),'utf-8')
else: # image as pathLike str
# with open(images, "rb") as img_file:
# images_data = str(base64.b64encode(img_file.read()), 'utf-8')
res = requests.get(images)
images_data = str(base64.b64encode(res.content),'utf-8')
## AutoLoad Model ## AutoLoad Model
# url = 'http://10.6.80.87:8000/' + model_name + '/' # url = 'http://10.6.80.87:8000/' + model_name + '/'
if is_base64(images):
images_data = images
else:
# print("{}Type of image data in form {}".format('#'*20,type(images)))
# print("{}Type of image data in form {}".format('#'*20,type(images.file)))
# byte_stream = io.BytesIO(images.read())
# print("{}Type of image data in form {}".format('#'*20,type(byte_stream)))
# # roiImg = Image.open(byte_stream)
# # print("{}Successful {}".format('#'*20,type(roiImg)))
# return str(type(byte_stream))
# images_data = base64.b64encode(byte_stream)
with open(images, "rb") as img_file:
# images_data = str(base64.b64encode(img_file.read()), 'utf-8')
images_data = base64.b64encode(img_file.read())
# data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data} # data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
# data = requests.post(url, json=data_input) # data = requests.post(url, json=data_input)
# print(data.text) # print(data.text)
# return data.text
# 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' # 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
## Lmdeploy ## Lmdeploy
if not user_context:
user_context = []
# user_context = [{'role':'user','content':'你好'}, {'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'}]
api_client = APIClient(self.url) api_client = APIClient(self.url)
# api_client = APIClient(f'http://10.6.80.87:23333')
model_name = api_client.available_models[0] model_name = api_client.available_models[0]
messages = [{
'role': messages = user_context + [{
'user', 'role': 'user',
'content': [{ 'content': [{
'type': 'text', 'type': 'text',
'text': prompt, 'text': prompt,
@ -93,25 +100,37 @@ class VLMS(Blackbox):
] ]
responses = '' responses = ''
total_token_usage = 0 # which can be used to count the cost of a query
for i,item in enumerate(api_client.chat_completions_v1(model=model_name, for i,item in enumerate(api_client.chat_completions_v1(model=model_name,
messages=messages#,stream = True messages=messages#,stream = True
)): )):
print(item["choices"][0]["message"]['content']) # print(item["choices"][0]["message"]['content'])
responses += item["choices"][0]["message"]['content'] responses += item["choices"][0]["message"]['content']
total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
return responses user_context = messages + [{'role': 'assistant', 'content': responses}]
return responses, user_context
# return data.text
async def fast_api_handler(self, request: Request) -> Response: async def fast_api_handler(self, request: Request) -> Response:
json_request = True
try: try:
data = await request.form() content_type = request.headers['content-type']
except: if content_type == 'application/json':
data = await request.json()
else:
data = await request.form()
json_request = False
except Exception as e:
return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
model_name = data.get("model_name") model_name = data.get("model_name")
prompt = data.get("prompt") prompt = data.get("prompt")
img_data = data.get("img_data")
if json_request:
img_data = data.get("img_data")
else:
img_data = await data.get("img_data").read()
if prompt is None: if prompt is None:
return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST) return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@ -119,6 +138,7 @@ class VLMS(Blackbox):
if model_name is None or model_name.isspace(): if model_name is None or model_name.isspace():
model_name = "Qwen-VL-Chat" model_name = "Qwen-VL-Chat"
response, history = self.processing(prompt, img_data, model_name)
# jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8") # jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
return JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}, status_code=status.HTTP_200_OK) return JSONResponse(content={"response": response, "history": history}, status_code=status.HTTP_200_OK)