Merge pull request #14 from BoardWare-Genius/ivan

VLM implementation
IvanWu authored 2024-08-20 18:09:41 +08:00 · committed by GitHub
3 changed files with 377 additions and 290 deletions

View File

@@ -89,4 +89,7 @@ Model:
   batch_size: 3
 blackbox:
   lazyloading: true
+vlms:
+  url: http://10.6.80.87:23333
 ```
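The new `vlms.url` key is read by `VLMConf` (third file below) and is expected to point at an LMDeploy OpenAI-compatible `api_server`. A minimal connectivity check, assuming such a server is already running at that address (the model name shown is only an example):

```python
# Sketch: verify the endpoint configured under vlms.url is a live
# LMDeploy api_server before the service depends on it.
from lmdeploy.serve.openai.api_client import APIClient

client = APIClient('http://10.6.80.87:23333')  # the vlms.url value above
print(client.available_models)                 # e.g. ['llava-llama-3-8b-v1_1']
```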

View File

@@ -1,11 +1,19 @@
 from fastapi import Request, Response, status
 from fastapi.responses import JSONResponse
+from injector import singleton, inject
+from typing import Optional, List
 from .blackbox import Blackbox
-from typing import Optional
+from ..log.logging_time import logging_time
+# from .chroma_query import ChromaQuery
+from ..configuration import VLMConf
 import requests
 import base64
+import io
+from PIL import Image
+from lmdeploy.serve.openai.api_client import APIClient


 def is_base64(value) -> bool:
     try:
@@ -14,9 +22,16 @@ def is_base64(value) -> bool:
     except Exception:
         return False


+@singleton
 class VLMS(Blackbox):
+    @inject
+    def __init__(self, vlm_config: VLMConf):
+        # A Chroma database was initially set up as RAG for the vision model;
+        # it could be extended into a history store.
+        # self.chroma_query = chroma_query
+        self.url = vlm_config.url
+
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
@@ -24,8 +39,18 @@ class VLMS(Blackbox):
         data = args[0]
         return isinstance(data, list)

-    def processing(self, prompt, images, model_name: Optional[str] = None) -> str:
+    def processing(self, prompt: str, images: str | bytes, model_name: Optional[str] = None, user_context: Optional[List[dict]] = None) -> tuple:
+        """
+        Args:
+            prompt: a string query to the model.
+            images: the image as a base64 string, raw bytes, or a URL string.
+            user_context: the conversation history, a list of OpenAI-format messages.
+        Returns:
+            response: the model's answer as a string.
+            history: the updated message list.
+        """
         if model_name == "Qwen-VL-Chat":
             model_name = "infer-qwen-vl"
         elif model_name == "llava-llama-3-8b-v1_1-transformers":
@@ -33,29 +58,79 @@ class VLMS(Blackbox):
         else:
             model_name = "infer-qwen-vl"

-        url = 'http://120.196.116.194:48894/' + model_name + '/'
-        if is_base64(images):
+        # Transform the images into base64, as the OpenAI message format requires.
+        if is_base64(images):  # image as a base64 str
             images_data = images
-        else:
-            with open(images, "rb") as img_file:
-                images_data = str(base64.b64encode(img_file.read()), 'utf-8')
-        data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
-        data = requests.post(url, json=data_input)
-        return data.text
+        elif isinstance(images, bytes):  # image as raw bytes
+            images_data = str(base64.b64encode(images), 'utf-8')
+        else:  # image as a URL string (previously a path-like str)
+            # with open(images, "rb") as img_file:
+            #     images_data = str(base64.b64encode(img_file.read()), 'utf-8')
+            res = requests.get(images)
+            images_data = str(base64.b64encode(res.content), 'utf-8')
+
+        ## AutoLoad Model
+        # url = 'http://10.6.80.87:8000/' + model_name + '/'
+        # data_input = {'model': model_name, 'prompt': prompt, 'img_data': images_data}
+        # data = requests.post(url, json=data_input)
+        # print(data.text)
+        # return data.text
+        # 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
+
+        ## LMDeploy
+        if not user_context:
+            user_context = []
+        # user_context = [{'role': 'user', 'content': 'Hello'},
+        #                 {'role': 'assistant', 'content': 'Hello! Glad to be of help.'}]
+        api_client = APIClient(self.url)
+        model_name = api_client.available_models[0]
+        messages = user_context + [{
+            'role': 'user',
+            'content': [{
+                'type': 'text',
+                'text': prompt,
+            }, {
+                'type': 'image_url',
+                'image_url': {
+                    'url': f"data:image/jpeg;base64,{images_data}",
+                    # './val_data/image_5.jpg',
+                },
+            }]
+        }]
+        responses = ''
+        total_token_usage = 0  # can be used to account for the cost of a query
+        for i, item in enumerate(api_client.chat_completions_v1(model=model_name,
+                                                                messages=messages)):  # , stream=True
+            # print(item["choices"][0]["message"]['content'])
+            responses += item["choices"][0]["message"]['content']
+            # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
+            total_token_usage += item['usage']['total_tokens']
+        user_context = messages + [{'role': 'assistant', 'content': responses}]
+        return responses, user_context

     async def fast_api_handler(self, request: Request) -> Response:
+        json_request = True
         try:
-            data = await request.json()
-        except:
+            content_type = request.headers['content-type']
+            if content_type == 'application/json':
+                data = await request.json()
+            else:
+                data = await request.form()
+                json_request = False
+        except Exception as e:
             return JSONResponse(content={"error": "json parse error"}, status_code=status.HTTP_400_BAD_REQUEST)

         model_name = data.get("model_name")
         prompt = data.get("prompt")
-        img_data = data.get("img_data")
+        if json_request:
+            img_data = data.get("img_data")
+        else:
+            img_data = await data.get("img_data").read()
         if prompt is None:
             return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
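With these changes, `processing` returns both the answer and the updated OpenAI-format message list, so a caller can thread the history back in for multi-turn chat. A usage sketch, assuming an already-constructed `VLMS` instance `vlms` and a local image file (both hypothetical):

```python
# First turn: no history; the image is passed as raw bytes
# (handled by the new elif isinstance(images, bytes) branch).
with open('cat.jpg', 'rb') as f:  # hypothetical image file
    img_bytes = f.read()
answer, history = vlms.processing('What is in this picture?', img_bytes)

# Second turn: feed the returned history back as user_context so the
# model sees the previous question and answer.
follow_up, history = vlms.processing('What color is it?', img_bytes,
                                     user_context=history)
```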
@@ -63,5 +138,7 @@ class VLMS(Blackbox):
         if model_name is None or model_name.isspace():
             model_name = "Qwen-VL-Chat"

-        jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
-        return JSONResponse(content={"response": jsonresp}, status_code=status.HTTP_200_OK)
+        response, history = self.processing(prompt, img_data, model_name)
+        # jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
+        return JSONResponse(content={"response": response, "history": history}, status_code=status.HTTP_200_OK)

View File

@@ -129,3 +129,10 @@ class BlackboxConf():
     @inject
     def __init__(self, config: Configuration) -> None:
         self.lazyloading = bool(config.get("blackbox.lazyloading", default=False))
+
+
+@singleton
+class VLMConf():
+    @inject
+    def __init__(self, config: Configuration) -> None:
+        self.url = config.get("vlms.url")
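`@singleton` and `@inject` let the injector build `VLMConf` from the application `Configuration` and pass it into `VLMS` automatically. The same wiring can be exercised by hand; `StubConfig` below is a hypothetical stand-in for the real `Configuration` object:

```python
# Hypothetical stand-in for the app's Configuration object.
class StubConfig:
    def get(self, key, default=None):
        return {'vlms.url': 'http://10.6.80.87:23333'}.get(key, default)

conf = VLMConf(StubConfig())  # normally constructed by the injector
vlms = VLMS(conf)             # VLMS.__init__ stores vlm_config.url as self.url
print(vlms.url)               # -> http://10.6.80.87:23333
```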