From a5f33e2d514e5bc3502e06cc2dc0523d3df6d017 Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 14:51:37 +0800 Subject: [PATCH] update ollama --- examples/llm_vision.py | 4 +- metagpt/configs/llm_config.py | 4 +- metagpt/provider/general_api_base.py | 4 +- metagpt/provider/ollama_api.py | 76 ++++++++++++++++++++-------- 4 files changed, 60 insertions(+), 28 deletions(-) diff --git a/examples/llm_vision.py b/examples/llm_vision.py index 276decd59..eea6550f6 100644 --- a/examples/llm_vision.py +++ b/examples/llm_vision.py @@ -15,8 +15,8 @@ async def main(): # check if the configured llm supports llm-vision capacity. If not, it will throw a error invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") img_base64 = encode_image(invoice_path) - res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64]) - assert "true" in res.lower() + res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64]) + assert ("true" in res.lower()) or ("invoice" in res.lower()) if __name__ == "__main__": diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 3a13789d9..34b73d2d5 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -106,8 +106,8 @@ class LLMConfig(YamlModel): root_config_path = CONFIG_ROOT / "config2.yaml" if root_config_path.exists(): raise ValueError( - f"Please set your API key in {root_config_path}. If you also set your config in { - repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n" + f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \n" + f"the former will overwrite the latter. This may cause unexpected result.\n" ) elif repo_config_path.exists(): raise ValueError(f"Please set your API key in {repo_config_path}") diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index a4b50af4b..34a39fe6c 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -396,8 +396,8 @@ class APIRequestor: "X-LLM-Client-User-Agent": json.dumps(ua), "User-Agent": user_agent, } - - headers.update(api_key_to_header(self.api_type, self.api_key)) + if self.api_key: + headers.update(api_key_to_header(self.api_type, self.api_key)) if self.organization: headers["LLM-Organization"] = self.organization diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 8522f08a1..ab067340c 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -27,6 +27,7 @@ class OllamaMessageBase: def __init__(self, model: str, **additional_kwargs) -> None: self.model, self.additional_kwargs = model, additional_kwargs + self._image_b64_rms = len("data:image/jpeg;base64,") @property def api_suffix(self) -> str: @@ -38,15 +39,16 @@ class OllamaMessageBase: def decode(self, response: OpenAIResponse) -> dict: return json.loads(response.data.decode("utf-8")) + def get_choice(self, to_choice_dict: dict) -> str: + raise NotImplementedError + def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]: - if "role" in msg: - return msg["content"], None - elif "type" in msg: + if "type" in msg: tpe = msg["type"] if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"] + return None, msg["image_url"]["url"][self._image_b64_rms :] else: raise ValueError else: @@ -84,18 +86,35 @@ class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta): return "/chat" def apply(self, messages: list[dict]) -> dict: + content = messages[0]["content"] prompts = [] images = [] - for msg in messages: - prompt, image = self._parse_input_msg(msg) - if prompt: - prompts.append(prompt) - if image: - images.append(image) - sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images} + if isinstance(content, list): + for msg in content: + prompt, image = self._parse_input_msg(msg) + if prompt: + prompts.append(prompt) + if image: + images.append(image) + else: + prompts.append(content) + messes = [] + for prompt in prompts: + if len(images) > 0: + messes.append({"role": "user", "content": "\n".join(prompts), "images": images}) + else: + messes.append({"role": "user", "content": "\n".join(prompts)}) + sends = {"model": self.model, "messages": messes} sends.update(self.additional_kwargs) return sends + def get_choice(self, to_choice_dict: dict) -> str: + message = to_choice_dict["message"] + if message["role"] == "assistant": + return message["content"] + else: + raise ValueError + class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta): api_type = OllamaMessageAPI.GENERATE @@ -104,6 +123,29 @@ class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta): def api_suffix(self) -> str: return "/generate" + def apply(self, messages: list[dict]) -> dict: + content = messages[0]["content"] + prompts = [] + images = [] + if isinstance(content, list): + for msg in content: + prompt, image = self._parse_input_msg(msg) + if prompt: + prompts.append(prompt) + if image: + images.append(image) + else: + prompts.append(content) + if len(images) > 0: + sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images} + else: + sends = {"model": self.model, "prompt": "\n".join(prompts)} + sends.update(self.additional_kwargs) + return sends + + def get_choice(self, to_choice_dict: dict) -> str: + return to_choice_dict["response"] + class OllamaMessageEmbed(OllamaMessageBase, metaclass=OllamaMessageMeta): api_type = OllamaMessageAPI.EMBED @@ -137,13 +179,6 @@ class OllamaLLM(BaseLLM): self.cost_manager = TokenCostManager() self.__init_ollama(config) - def _get_headers(self): - return ( - None - if not self.config.api_key or self.config.api_key == "sk-" - else {"Authorization": f"Bearer {self.config.api_key}"} - ) - @property def _llama_api_inuse(self) -> OllamaMessageAPI: return OllamaMessageAPI.CHAT @@ -158,7 +193,6 @@ class OllamaLLM(BaseLLM): self.pricing_plan = self.model ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse) self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs) - self.headers = self._get_headers() def get_usage(self, resp: dict) -> dict: return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)} @@ -167,7 +201,6 @@ class OllamaLLM(BaseLLM): resp, _, _ = await self.client.arequest( method=self.http_method, url=self.ollama_message.api_suffix, - headers=self.headers, params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), ) @@ -185,7 +218,6 @@ class OllamaLLM(BaseLLM): resp, _, _ = await self.client.arequest( method=self.http_method, url=self.ollama_message.api_suffix, - headers=self.headers, params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), stream=True, @@ -210,7 +242,7 @@ class OllamaLLM(BaseLLM): chunk = self.ollama_message.decode(raw_chunk) if not chunk.get("done", False): - content = self.get_choice_text(chunk) + content = self.ollama_message.get_choice(chunk) collected_content.append(content) log_llm_stream(content) else: