# jarvis-models/src/blackbox/vlms.py
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
from sse_starlette.sse import EventSourceResponse
from injector import singleton, inject
from typing import Iterator, List, Optional
from .blackbox import Blackbox
from ..log.logging_time import logging_time
# from .chroma_query import ChromaQuery
from ..configuration import VLMConf
import requests
import base64
import copy
import ast
import json
import random
from time import time
import io
from PIL import Image
# from lmdeploy.serve.openai.api_client import APIClient
from openai import OpenAI
def is_base64(value) -> bool:
    """Round-trip check: True if `value` decodes and re-encodes to itself."""
    try:
        return base64.b64encode(base64.b64decode(value)) == value.encode()
    except Exception:
        return False
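# e.g. is_base64(base64.b64encode(b"\xff\xd8").decode()) -> True,
# while a raw URL such as "http://host/img.jpg" -> False.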
@singleton
class VLMS(Blackbox):
@inject
def __init__(self, vlm_config: VLMConf):
"""
Initialization for endpoint url and generation config.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- max_tokens (int | None): output token nums. Default to None.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that's encoded to one token idex.
Additional arguments supported by LMDeploy:
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True."""
        self.model_dict = vlm_config.urls
        self.available_models = {}
        self.temperature: float = 0.7
        self.top_p: float = 1.0
        self.max_tokens: int | None = 512
        self.repetition_penalty: float = 1.0
        self.stop: str | List[str] | None = ['<|endoftext|>', '<|im_end|>']
        self.top_k: int = 40
        self.ignore_eos: bool = False
        self.skip_special_tokens: bool = True
        self.settings: dict = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
            "repetition_penalty": self.repetition_penalty,
            "stop": self.stop,
            "top_k": self.top_k,
            "ignore_eos": self.ignore_eos,
            "skip_special_tokens": self.skip_special_tokens,
        }
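        # Callers may override any of these per request, e.g. (illustrative):
        #   settings = {"temperature": 0.2, "max_tokens": 256}
        # Keys outside this dict (e.g. 'stream', 'vlm_model_name', 'system_prompt')
        # are split off into a per-request config inside processing().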
        for model, url in self.model_dict.items():
            try:
                response = requests.get(url + '/health', timeout=3)
                if response.status_code == 200:
                    self.available_models[model] = url
            except Exception:
                # Endpoint unreachable; skip it rather than fail startup.
                pass
def __call__(self, *args, **kwargs):
return self.processing(*args, **kwargs)
def valid(self, *args, **kwargs) -> bool:
data = args[0]
return isinstance(data, list)
    def processing(self, prompt: str | None, images: str | bytes | None, settings: dict, user_context: Optional[List[dict]] = None) -> Iterator[str]:
        """
        Args:
            prompt: a string query to the model.
            images: image data as a base64 string, raw bytes, or a URL.
            settings: a dictionary set by the user with the fields stated in __init__.
            user_context: the conversation history, a list of OpenAI-format messages.
        Yields:
            response chunks as strings (this method is a generator).
        """
config: dict = {
"lmdeploy_infer":True,
"system_prompt":"",
"vlm_model_name":"",
}
        if settings:
            for k in list(settings.keys()):
                if k not in self.settings:
                    print("Warning: '{}' is not a generation argument; moving it to the "
                          "request config. Generation arguments: {}".format(k, list(self.settings.keys())))
                    config[k] = settings.pop(k)
            tmp = copy.deepcopy(self.settings)
            tmp.update(settings)
            settings = tmp
        else:
            settings = {}
config['lmdeploy_infer'] = str(config['lmdeploy_infer']).strip().lower() == 'true'
        if not prompt:
            # Default prompts (Chinese): "You are an assistant robot; give a brief summary of
            # this image, covering the main objects and their state, in at most 50 characters." / "Hello".
            prompt = '你是一个辅助机器人,请就此图做一个简短的概括性描述,包括图中的主体物品及状态,不超过50字。' if images else '你好'
        # Normalize the image input to a base64 string (for the OpenAI data-URL format).
        if images:
            if is_base64(images):  # image as a base64 str
                images_data = images
            elif isinstance(images, bytes):  # image as raw bytes
                images_data = str(base64.b64encode(images), 'utf-8')
            else:  # image as a URL
                res = requests.get(images, timeout=10)
                images_data = str(base64.b64encode(res.content), 'utf-8')
        else:
            images_data = None
        ## Predefined user_context, only for testing:
        # user_context = [{'role': 'user', 'content': '你好,我叫康康,你是谁?'}, {'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'}]
        if not user_context:
            user_context = [{'role': 'system', 'content': config['system_prompt']}] if config['system_prompt'] else []
        user_context = self.keep_last_k_images(user_context, k=2)
        # Reformat the input into OpenAI format for the request.
        if images_data:
            messages = user_context + [{
                'role': 'user',
                'content': [{
                    'type': 'text',
                    'text': prompt,
                }, {
                    'type': 'image_url',
                    'image_url': {
                        'url': f"data:image/jpeg;base64,{images_data}",
                    },
                    # Additional images could be appended here as extra 'image_url' entries.
                }],
            }]
        else:
            messages = user_context + [{
                'role': 'user',
                'content': [{
                    'type': 'text',
                    'text': prompt,
                }],
            }]
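        # Resulting shape (illustrative):
        #   [... history ..., {'role': 'user', 'content': [{'type': 'text', ...}, {'type': 'image_url', ...}]}]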
responses = ''
total_token_usage = 0 # which can be used to count the cost of a query
model_url = self._get_model_url(config['vlm_model_name'])
# if config['lmdeploy_infer']:
# # api_client = APIClient(model_url)
# # model_name = api_client.available_models[0]
# for i,item in enumerate(api_client.chat_completions_v1(model=model_name,
# messages=messages,stream = True,
# **settings,
# # session_id=,
# )):
# # Stream output
# yield item["choices"][0]["delta"]['content']
# responses += item["choices"][0]["delta"]['content']
# # print(item["choices"][0]["message"]['content'])
# # responses += item["choices"][0]["message"]['content']
# # total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
# else:
        api_key = "EMPTY_API_KEY"
        api_client = OpenAI(api_key=api_key, base_url=model_url + '/v1')
        model_name = api_client.models.list().data[0].id
        # The OpenAI client rejects unknown keyword arguments, so route the
        # LMDeploy-specific options through extra_body.
        extra = {k: settings.pop(k) for k in
                 ('top_k', 'ignore_eos', 'skip_special_tokens', 'repetition_penalty')
                 if k in settings}
        for item in api_client.chat.completions.create(
                model=model_name,
                messages=messages,
                extra_body=extra,
                stream=True,
                **settings):
            chunk = item.choices[0].delta.content
            if chunk is None:  # role-only and final chunks carry no content
                continue
            yield chunk
            responses += chunk
user_context = messages + [{'role': 'assistant', 'content': responses}]
self.custom_print(user_context)
# return responses
    def _get_model_url(self, model_name: str | None):
        if not self.available_models:
            raise RuntimeError("There are no available running models; please check your endpoint URLs.")
        if model_name and model_name in self.available_models:
            return self.available_models[model_name]
        model = random.choice(list(self.available_models.keys()))
        if model_name:
            print(f"No such model {model_name}, using {model} instead.")
        else:
            print(f"Using random model {model}.")
        return self.available_models[model]
    def _into_openai_format(self, context: List[list]) -> List[dict]:
        """
        Convert the data into OpenAI format.
        context: a list of [user_input, response] pairs. user_input is either a plain
        string, or a list of the form [text], [image, text], or [[imgs], text].
        #TODO: add support for multiple images
        """
user_context = []
        for item in context:
user_content = item[0]
if isinstance(user_content, list):
if len(user_content) == 1:
user_content = [{
'type': 'text',
'text': user_content[0]
}]
elif is_base64(user_content[0]):
user_content = [{
'type': 'image_url',
'image_url': {
'url': f"data:image/jpeg;base64,{user_content[0]}"
},
},{
'type': 'text',
'text': user_content[1]
}]
else:
user_content = [{
'type': 'image_url',
'image_url': {
'url': user_content[0]
},
},{
'type': 'text',
'text': user_content[1]
}]
else:
user_content = [{
'type': 'text',
'text': user_content
}]
user_context.append({
'role': 'user',
'content': user_content
})
user_context.append({
'role': 'assistant',
'content': item[1]
})
return user_context
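    # Example (illustrative): [['hello', 'hi there']] becomes
    #   [{'role': 'user', 'content': [{'type': 'text', 'text': 'hello'}]},
    #    {'role': 'assistant', 'content': 'hi there'}]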
    def keep_last_k_images(self, user_context: list, k: int = 2):
        """Keep only the most recent k images in the history; older images are dropped."""
        count = 0
        result = []
        for item in user_context[::-1]:
            if item['role'] == 'user' and isinstance(item['content'], list):
                kept = []
                for info in item['content']:
                    if info.get('type') in ('image_url', 'image'):
                        if count < k:
                            count += 1
                            kept.append(info)
                        # else: drop this older image
                    else:
                        kept.append(info)
                item['content'] = kept
            result.append(item)
        return result[::-1]
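    # With k=2, a history holding images A, B, C (oldest to newest) retains only B and C;
    # the text parts of every message are preserved.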
    def custom_print(self, user_context: list):
        """Print the conversation with image payloads masked, to keep logs readable."""
        result = []
        for item in user_context:
            if item['role'] == 'user' and isinstance(item['content'], list):
                for idx, info in enumerate(item['content']):
                    if info.get('type') in ('image_url', 'image'):
                        # Replace the base64 payload with a placeholder.
                        item['content'][idx] = {'type': 'image', 'image': '##<IMAGE>##'}
            result.append(item)
        print(result)
async def fast_api_handler(self, request: Request) -> Response:
## TODO: add support for multiple images and support image in form-data format
json_request = True
try:
content_type = request.headers.get('content-type', '')
            if 'application/json' in content_type:
data = await request.json()
elif 'multipart/form-data' in content_type:
data = await request.form()
json_request = False
else:
body = await request.body()
data = json.loads(body.decode("utf-8"))
        except Exception:
            return JSONResponse(content={"error": "request body parse error"}, status_code=status.HTTP_400_BAD_REQUEST)
prompt = data.get("prompt")
settings: dict = data.get('settings')
context = data.get("context")
if not context:
user_context = []
elif isinstance(context[0], list):
user_context = self._into_openai_format(context)
elif isinstance(context[0], dict):
user_context = context
else:
return JSONResponse(content={"error": "context format error, should be in format of list or Openai_format"}, status_code=status.HTTP_400_BAD_REQUEST)
        if json_request:
            img_data = data.get("img_data")
        else:
            upload = data.get("img_data")
            img_data = await upload.read() if upload else None
        # settings may arrive as a string (e.g. from form-data); parse it into a dict.
        if settings and isinstance(settings, str):
            settings = ast.literal_eval(settings)
if prompt is None:
return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
streaming_output = str(settings.get('stream',False)).strip().lower() == 'true' if settings else False
if streaming_output:
# return StreamingResponse(self.processing(prompt, img_data,settings, user_context=user_context), status_code=status.HTTP_200_OK)
            return EventSourceResponse(self.processing(prompt, img_data, settings, user_context=user_context), status_code=status.HTTP_200_OK)
else:
# HTTP JsonResponse
            output = self.processing(prompt, img_data, settings, user_context=user_context)
            response = ''.join(output)
return JSONResponse(content={"response": response}, status_code=status.HTTP_200_OK)
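# Illustrative client call (the mount path '/vlms' is hypothetical; use your app's route):
#   requests.post("http://<host>/vlms", json={
#       "prompt": "Describe this image",
#       "img_data": "<base64-encoded JPEG>",
#       "context": [["hello", "hi there"]],
#       "settings": {"temperature": 0.2, "stream": False},
#   })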