diff --git a/src/blackbox/vlms.py b/src/blackbox/vlms.py index 3e07bfa..baff1da 100644 --- a/src/blackbox/vlms.py +++ b/src/blackbox/vlms.py @@ -1,11 +1,11 @@ from fastapi import Request, Response, status from fastapi.responses import JSONResponse from injector import singleton,inject -from typing import Optional +from typing import Optional, List from .blackbox import Blackbox from ..log.logging_time import logging_time -from .chroma_query import ChromaQuery +# from .chroma_query import ChromaQuery from ..configuration import VLMConf import requests @@ -39,46 +39,53 @@ class VLMS(Blackbox): data = args[0] return isinstance(data, list) - def processing(self, prompt, images, model_name: Optional[str] = None) -> str: + def processing(self, prompt:str, images:str | bytes, model_name: Optional[str] = None, user_context: List[dict] = None) -> str: + """ + Args: + prompt: a string query to the model. + images: a base64 string of image data; + user_context: a list of history conversation, should be a list of openai format. + + Return: + response: a string + history: a list + """ if model_name == "Qwen-VL-Chat": model_name = "infer-qwen-vl" elif model_name == "llava-llama-3-8b-v1_1-transformers": model_name = "infer-lav-lam-v1-1" else: model_name = "infer-qwen-vl" + - + # Transform the images into base64 format where openai format need. + if is_base64(images): # image as base64 str + images_data = images + elif isinstance(images,bytes): # image as bytes + images_data = str(base64.b64encode(images),'utf-8') + else: # image as pathLike str + # with open(images, "rb") as img_file: + # images_data = str(base64.b64encode(img_file.read()), 'utf-8') + res = requests.get(images) + images_data = str(base64.b64encode(res.content),'utf-8') ## AutoLoad Model # url = 'http://10.6.80.87:8000/' + model_name + '/' - - if is_base64(images): - images_data = images - else: - # print("{}Type of image data in form {}".format('#'*20,type(images))) - # print("{}Type of image data in form {}".format('#'*20,type(images.file))) - # byte_stream = io.BytesIO(images.read()) - # print("{}Type of image data in form {}".format('#'*20,type(byte_stream))) - # # roiImg = Image.open(byte_stream) - # # print("{}Successful {}".format('#'*20,type(roiImg))) - # return str(type(byte_stream)) - # images_data = base64.b64encode(byte_stream) - with open(images, "rb") as img_file: - # images_data = str(base64.b64encode(img_file.read()), 'utf-8') - images_data = base64.b64encode(img_file.read()) - # data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data} - # data = requests.post(url, json=data_input) # print(data.text) + # return data.text + # 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' ## Lmdeploy + if not user_context: + user_context = [] + # user_context = [{'role':'user','content':'你好'}, {'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'}] api_client = APIClient(self.url) - # api_client = APIClient(f'http://10.6.80.87:23333') model_name = api_client.available_models[0] - messages = [{ - 'role': - 'user', + + messages = user_context + [{ + 'role': 'user', 'content': [{ 'type': 'text', 'text': prompt, @@ -93,25 +100,37 @@ class VLMS(Blackbox): ] responses = '' + total_token_usage = 0 # which can be used to count the cost of a query for i,item in enumerate(api_client.chat_completions_v1(model=model_name, messages=messages#,stream = True )): - print(item["choices"][0]["message"]['content']) + # print(item["choices"][0]["message"]['content']) responses += item["choices"][0]["message"]['content'] - - return responses + total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *} + user_context = messages + [{'role': 'assistant', 'content': responses}] + return responses, user_context - # return data.text + async def fast_api_handler(self, request: Request) -> Response: + json_request = True try: - data = await request.form() - except: + content_type = request.headers['content-type'] + if content_type == 'application/json': + data = await request.json() + else: + data = await request.form() + json_request = False + except Exception as e: return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST) - + model_name = data.get("model_name") prompt = data.get("prompt") - img_data = data.get("img_data") + + if json_request: + img_data = data.get("img_data") + else: + img_data = await data.get("img_data").read() if prompt is None: return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST) @@ -119,6 +138,7 @@ class VLMS(Blackbox): if model_name is None or model_name.isspace(): model_name = "Qwen-VL-Chat" + response, history = self.processing(prompt, img_data, model_name) # jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8") - return JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}, status_code=status.HTTP_200_OK) \ No newline at end of file + return JSONResponse(content={"response": response, "history": history}, status_code=status.HTTP_200_OK) \ No newline at end of file