From 44561de2c5d9fed66361b258682ee9baa6c92815 Mon Sep 17 00:00:00 2001
From: Ivan087
Date: Wed, 21 Aug 2024 14:38:08 +0800
Subject: [PATCH] support generation config

---
 src/blackbox/vlms.py | 89 ++++++++++++++++++++++++++++++++++++--------
 1 file changed, 73 insertions(+), 16 deletions(-)

diff --git a/src/blackbox/vlms.py b/src/blackbox/vlms.py
index baff1da..5ae5123 100644
--- a/src/blackbox/vlms.py
+++ b/src/blackbox/vlms.py
@@ -10,6 +10,8 @@ from ..configuration import VLMConf
 
 import requests
 import base64
+import copy
+import ast
 import io
 from PIL import Image
 
@@ -27,11 +29,48 @@ class VLMS(Blackbox):
 
     @inject
     def __init__(self, vlm_config: VLMConf):
-        # Chroma database initially set up for RAG for vision model.
-        # It could be expended to history store.
-        # self.chroma_query = chroma_query
+        """
+        Initialize the endpoint url and the generation config.
+        - temperature (float): modulates the next-token probabilities.
+        - top_p (float): if set to a float < 1, only the smallest set of the
+          most probable tokens whose probabilities add up to top_p or higher
+          is kept for generation.
+        - max_tokens (int | None): maximum number of output tokens. Defaults to None.
+        - repetition_penalty (float): the parameter for repetition penalty;
+          1.0 means no penalty.
+        - stop (str | List[str] | None): words at which to stop generating further
+          tokens. Only stop words that encode to a single token index are accepted.
+
+        Additional arguments supported by LMDeploy:
+        - top_k (int): the number of highest-probability vocabulary
+          tokens to keep for top-k filtering.
+        - ignore_eos (bool): whether to ignore the eos token.
+        - skip_special_tokens (bool): whether to remove special tokens
+          during decoding. Defaults to True."""
         self.url = vlm_config.url
 
+        self.temperature: float = 0.7
+        self.top_p: float = 1
+        self.max_tokens: int | None = 512
+        self.repetition_penalty: float = 1
+        self.stop: str | List[str] | None = ['<|endoftext|>', '<|im_end|>']
+
+        self.top_k: int | None = None
+        self.ignore_eos: bool = False
+        self.skip_special_tokens: bool = True
+
+        self.settings: dict = {
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "max_tokens": self.max_tokens,
+            "repetition_penalty": self.repetition_penalty,
+            "stop": self.stop,
+            "top_k": self.top_k,
+            "ignore_eos": self.ignore_eos,
+            "skip_special_tokens": self.skip_special_tokens,
+        }
+
+
     def __call__(self, *args, **kwargs):
         return self.processing(*args, **kwargs)
 
@@ -39,25 +78,34 @@ class VLMS(Blackbox):
         data = args[0]
         return isinstance(data, list)
 
-    def processing(self, prompt:str, images:str | bytes, model_name: Optional[str] = None, user_context: List[dict] = None) -> str:
+    def processing(self, prompt:str, images:str | bytes, settings: dict, model_name: Optional[str] = None, user_context: List[dict] = None) -> str:
         """
         Args:
             prompt: a string query to the model.
            images: a base64 string of image data;
             user_context: a list of history conversation, should be a list of openai format.
-
+            settings: a dictionary set by the user with the fields described in __init__.
         Return:
            response: a string
            history: a list
        """
-        if model_name == "Qwen-VL-Chat":
-            model_name = "infer-qwen-vl"
-        elif model_name == "llava-llama-3-8b-v1_1-transformers":
-            model_name = "infer-lav-lam-v1-1"
+        # if model_name == "Qwen-VL-Chat":
+        #     model_name = "infer-qwen-vl"
+        # elif model_name == "llava-llama-3-8b-v1_1-transformers":
+        #     model_name = "infer-lav-lam-v1-1"
+        # else:
+        #     model_name = "infer-qwen-vl"
+        if settings:
+            for k in list(settings):  # iterate over a copy so unknown keys can be dropped safely
+                if k not in self.settings:
+                    print("Warning: '{}' is not a supported argument and will be ignored; supported arguments: {}".format(k, self.settings.keys()))
+                    settings.pop(k)
+            tmp = copy.deepcopy(self.settings)
+            tmp.update(settings)
+            settings = tmp
         else:
-            model_name = "infer-qwen-vl"
-
+            settings = {}
         # Transform the images into base64 format where openai format need.
         if is_base64(images):  # image as base64 str
@@ -102,11 +150,18 @@ class VLMS(Blackbox):
         responses = ''
         total_token_usage = 0 # which can be used to count the cost of a query
         for i,item in enumerate(api_client.chat_completions_v1(model=model_name,
-                                                               messages=messages#,stream = True
+                                                               messages=messages, stream=True,
+                                                               **settings,
+                                                               # session_id=,
                                                                )):
+            # Stream output
+            print(item["choices"][0]["delta"]['content'], end='')
+            responses += item["choices"][0]["delta"]['content']
+
             # print(item["choices"][0]["message"]['content'])
-            responses += item["choices"][0]["message"]['content']
-            total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
+            # responses += item["choices"][0]["message"]['content']
+            # total_token_usage += item['usage']['total_tokens'] # 'usage': {'prompt_tokens': *, 'total_tokens': *, 'completion_tokens': *}
+
         user_context = messages + [{'role': 'assistant', 'content': responses}]
         return responses, user_context
 
@@ -126,11 +181,13 @@ class VLMS(Blackbox):
 
         model_name = data.get("model_name")
         prompt = data.get("prompt")
-
+
         if json_request:
             img_data = data.get("img_data")
+            settings: dict = data.get('settings')
         else:
             img_data = await data.get("img_data").read()
+            settings: dict = ast.literal_eval(data.get('settings')) if data.get('settings') else {}
 
         if prompt is None:
             return JSONResponse(content={'error': "Question is required"}, status_code=status.HTTP_400_BAD_REQUEST)
@@ -138,7 +195,7 @@ class VLMS(Blackbox):
         if model_name is None or model_name.isspace():
             model_name = "Qwen-VL-Chat"
 
-        response, history = self.processing(prompt, img_data, model_name)
+        response, history = self.processing(prompt, img_data, settings, model_name)
         # jsonresp = str(JSONResponse(content={"response": self.processing(prompt, img_data, model_name)}).body, "utf-8")
 
         return JSONResponse(content={"response": response, "history": history}, status_code=status.HTTP_200_OK)
\ No newline at end of file
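
Usage sketch: with this patch applied, a caller can forward a generation-config
dict through the new `settings` field alongside the prompt and image. The route
and host below are assumptions for illustration (the patch shows the handler,
not how it is mounted); the field names match the handler above, and unknown
settings keys are dropped with a warning while known keys override the defaults
set in __init__.

    import base64
    import requests

    # Hypothetical endpoint; substitute whatever route exposes VLMS.handle() in your deployment.
    URL = "http://localhost:8000/vlm"

    # Read an image and encode it as a base64 string, as expected by is_base64().
    with open("example.png", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "prompt": "Describe this image.",
        "model_name": "Qwen-VL-Chat",
        "img_data": img_b64,
        # Only keys present in VLMS.settings are honored; others are ignored with a warning.
        "settings": {"temperature": 0.2, "top_p": 0.9, "max_tokens": 256},
    }

    resp = requests.post(URL, json=payload)
    print(resp.json()["response"])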