From 3fefc48efa52b04e231e055463de7efb252f7d91 Mon Sep 17 00:00:00 2001 From: EvensXia Date: Mon, 28 Oct 2024 17:55:50 +0800 Subject: [PATCH 01/13] fix ollama to add llava support --- metagpt/provider/general_api_base.py | 43 ---------- metagpt/provider/general_api_requestor.py | 83 ++++++++++++++------ metagpt/provider/ollama_api.py | 96 +++++++++++++++-------- 3 files changed, 121 insertions(+), 101 deletions(-) diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index 8e5da8f16..ab733db58 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -320,49 +320,6 @@ class APIRequestor: resp, got_stream = self._interpret_response(result, stream) return resp, got_stream, self.api_key - @overload - async def arequest( - self, - method, - url, - params, - headers, - files, - stream: Literal[True], - request_id: Optional[str] = ..., - request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., - ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]: - pass - - @overload - async def arequest( - self, - method, - url, - params=..., - headers=..., - files=..., - *, - stream: Literal[True], - request_id: Optional[str] = ..., - request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., - ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]: - pass - - @overload - async def arequest( - self, - method, - url, - params=..., - headers=..., - files=..., - stream: Literal[False] = ..., - request_id: Optional[str] = ..., - request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., - ) -> Tuple[OpenAIResponse, bool, str]: - pass - @overload async def arequest( self, diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py index 501a064e3..b8da1565d 100644 --- a/metagpt/provider/general_api_requestor.py +++ b/metagpt/provider/general_api_requestor.py @@ -3,25 +3,24 @@ # @Desc : General Async API for http-based LLM model import asyncio -from typing import AsyncGenerator, Generator, Iterator, Tuple, Union +from typing import AsyncGenerator, Iterator, Optional, Tuple, Union import aiohttp import requests from metagpt.logs import logger -from metagpt.provider.general_api_base import APIRequestor +from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse -def parse_stream_helper(line: bytes) -> Union[bytes, None]: +def parse_stream_helper(line: bytes) -> Optional[bytes]: if line and line.startswith(b"data:"): if line.startswith(b"data: "): - # SSE event may be valid when it contain whitespace + # SSE event may be valid when it contains whitespace line = line[len(b"data: ") :] else: line = line[len(b"data:") :] if line.strip() == b"[DONE]": - # return here will cause GeneratorExit exception in urllib3 - # and it will close http connection with TCP Reset + # Returning None to indicate end of stream return None else: return line @@ -37,7 +36,7 @@ def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]: class GeneralAPIRequestor(APIRequestor): """ - usage + Usage example: # full_url = "{base_url}{url}" requester = GeneralAPIRequestor(base_url=base_url) result, _, api_key = await requester.arequest( @@ -50,26 +49,47 @@ class GeneralAPIRequestor(APIRequestor): ) """ - def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes: - # just do nothing to meet the APIRequestor process and return the raw data - # due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases. + def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders: dict, stream: bool) -> OpenAIResponse: + """ + Process and return the response data wrapped in OpenAIResponse. - return rbody + Args: + rbody (bytes): The response body. + rcode (int): The response status code. + rheaders (dict): The response headers. + stream (bool): Whether the response is a stream. + + Returns: + OpenAIResponse: The response data wrapped in OpenAIResponse. + """ + return OpenAIResponse(rbody, rheaders) def _interpret_response( self, result: requests.Response, stream: bool - ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: - """Returns the response(s) and a bool indicating whether it is a stream.""" + ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: + """ + Interpret a synchronous response. + + Args: + result (requests.Response): The response object. + stream (bool): Whether the response is a stream. + + Returns: + Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: A tuple containing the response content and a boolean indicating if it is a stream. + """ content_type = result.headers.get("Content-Type", "") if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type): return ( - self._interpret_response_line(line, result.status_code, result.headers, stream=True) - for line in parse_stream(result.iter_lines()) - ), True + ( + self._interpret_response_line(line, result.status_code, result.headers, stream=True) + for line in parse_stream(result.iter_lines()) + ), + True, + ) else: return ( self._interpret_response_line( - result.content, # let the caller to decode the msg + result.content, # let the caller decode the msg result.status_code, result.headers, stream=False, @@ -79,26 +99,39 @@ class GeneralAPIRequestor(APIRequestor): async def _interpret_async_response( self, result: aiohttp.ClientResponse, stream: bool - ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]: + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: + """ + Interpret an asynchronous response. + + Args: + result (aiohttp.ClientResponse): The response object. + stream (bool): Whether the response is a stream. + + Returns: + Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: A tuple containing the response content and a boolean indicating if it is a stream. + """ content_type = result.headers.get("Content-Type", "") if stream and ( "text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == "" ): - # the `Content-Type` of ollama stream resp is "application/x-ndjson" return ( - self._interpret_response_line(line, result.status, result.headers, stream=True) - async for line in result.content - ), True + ( + self._interpret_response_line(line, result.status, result.headers, stream=True) + async for line in result.content + ), + True, + ) else: try: - await result.read() + response_content = await result.read() except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e: raise TimeoutError("Request timed out") from e except aiohttp.ClientError as exp: - logger.warning(f"response: {result.content}, exp: {exp}") + logger.warning(f"response: {result}, exp: {exp}") + response_content = b"" return ( self._interpret_response_line( - await result.read(), # let the caller to decode the msg + response_content, # let the caller decode the msg result.status, result.headers, stream=False, diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 454f0e3ee..39cbe291e 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -3,12 +3,13 @@ # @Desc : self-host open llm model with ollama which isn't openai-api-compatible import json +from typing import AsyncGenerator, Tuple from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.logs import log_llm_stream from metagpt.provider.base_llm import BaseLLM -from metagpt.provider.general_api_requestor import GeneralAPIRequestor +from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse from metagpt.provider.llm_provider_registry import register_provider from metagpt.utils.cost_manager import TokenCostManager @@ -34,65 +35,68 @@ class OllamaLLM(BaseLLM): self.pricing_plan = self.model def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: - kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} - return kwargs + return {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} def get_choice_text(self, resp: dict) -> str: """get the resp content from llm response""" assist_msg = resp.get("message", {}) - assert assist_msg.get("role", None) == "assistant" - return assist_msg.get("content") + if assist_msg.get("role", None) == "assistant": # chat + return assist_msg.get("content") + else: # llava + return resp["response"] def get_usage(self, resp: dict) -> dict: return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)} - def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict: - chunk = chunk.decode(encoding) - return json.loads(chunk) + def _decode_and_load(self, openai_resp: OpenAIResponse, encoding: str = "utf-8") -> dict: + return json.loads(openai_resp.data.decode(encoding)) async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict: - headers = ( - None - if not self.config.api_key or self.config.api_key == "sk-" - else { - "Authorization": f"Bearer {self.config.api_key}", - } - ) + messages, fixed_suffix_url = self._apply_llava(messages) resp, _, _ = await self.client.arequest( method=self.http_method, - url=self.suffix_url, - headers=headers, - params=self._const_kwargs(messages), + url=fixed_suffix_url, + headers=self._get_headers(), + params=messages, request_timeout=self.get_timeout(timeout), ) - resp = self._decode_and_load(resp) - usage = self.get_usage(resp) - self._update_costs(usage) - return resp + if isinstance(resp, AsyncGenerator): + return await self._processing_openai_response_async_generator(resp) + elif isinstance(resp, OpenAIResponse): + return self._processing_openai_response(resp) + else: + raise NotImplementedError async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: - headers = ( - None - if not self.config.api_key or self.config.api_key == "sk-" - else { - "Authorization": f"Bearer {self.config.api_key}", - } - ) + messages, fixed_suffix_url = self._apply_llava(messages, stream=True) stream_resp, _, _ = await self.client.arequest( method=self.http_method, - url=self.suffix_url, - headers=headers, + url=fixed_suffix_url, + headers=self._get_headers(), stream=True, - params=self._const_kwargs(messages, stream=True), + params=messages, request_timeout=self.get_timeout(timeout), ) + if isinstance(stream_resp, AsyncGenerator): + return await self._processing_openai_response_async_generator(stream_resp) + elif isinstance(stream_resp, OpenAIResponse): + return self._processing_openai_response(stream_resp) + else: + raise NotImplementedError + def _processing_openai_response(self, openai_resp: OpenAIResponse): + resp = self._decode_and_load(openai_resp) + usage = self.get_usage(resp) + self._update_costs(usage) + return resp + + async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]): collected_content = [] usage = {} - async for raw_chunk in stream_resp: + async for raw_chunk in ag_openai_resp: chunk = self._decode_and_load(raw_chunk) if not chunk.get("done", False): @@ -107,3 +111,29 @@ class OllamaLLM(BaseLLM): self._update_costs(usage) full_content = "".join(collected_content) return full_content + + def _get_headers(self): + return ( + None + if not self.config.api_key or self.config.api_key == "sk-" + else {"Authorization": f"Bearer {self.config.api_key}"} + ) + + def _apply_llava(self, messages: list[dict], stream: bool = False) -> Tuple[dict, str]: + llava = False + if isinstance(messages[0]["content"], str): + return self._const_kwargs(messages, stream), self.suffix_url + + if any(len(msg["content"]) >= 2 for msg in messages): + assert all(len(msg["content"]) >= 2 for msg in messages), "input should have the same api type" + llava = True + if not llava: + return self._const_kwargs(messages, stream), self.suffix_url + + assert len(messages) <= 1, "not support batch massages in llava calling images" + contents = messages[0]["content"] + return { + "model": self.model, + "prompt": contents[0]["text"], + "images": [i["image_url"]["url"][23:] for i in contents[1:]], + }, "/generate" From fdb834674dd3df2217c116156049c20582e7b2e2 Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 11:56:11 +0800 Subject: [PATCH 02/13] update ollama --- metagpt/configs/llm_config.py | 7 +- metagpt/provider/general_api_base.py | 3 +- metagpt/provider/ollama_api.py | 235 ++++++++++++++++++++------- 3 files changed, 184 insertions(+), 61 deletions(-) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 60118303c..3a13789d9 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -26,7 +26,9 @@ class LLMType(Enum): GEMINI = "gemini" METAGPT = "metagpt" AZURE = "azure" - OLLAMA = "ollama" + OLLAMA = "ollama" # /chat at ollama api + OLLAMA_GENERATE = "ollama.generate" # /generate at ollama api + OLLAMA_EMBEDDING = "ollama.embeddings" # /embeddings at ollama api QIANFAN = "qianfan" # Baidu BCE DASHSCOPE = "dashscope" # Aliyun LingJi DashScope MOONSHOT = "moonshot" @@ -104,7 +106,8 @@ class LLMConfig(YamlModel): root_config_path = CONFIG_ROOT / "config2.yaml" if root_config_path.exists(): raise ValueError( - f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n" + f"Please set your API key in {root_config_path}. If you also set your config in { + repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n" ) elif repo_config_path.exists(): raise ValueError(f"Please set your API key in {repo_config_path}") diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index ab733db58..a4b50af4b 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -13,6 +13,7 @@ import time from contextlib import asynccontextmanager from enum import Enum from typing import ( + Any, AsyncGenerator, AsyncIterator, Dict, @@ -121,7 +122,7 @@ def logfmt(props): class OpenAIResponse: - def __init__(self, data, headers): + def __init__(self, data: Union[bytes, Any], headers: dict): self._headers = headers self.data = data diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 39cbe291e..8522f08a1 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -3,7 +3,8 @@ # @Desc : self-host open llm model with ollama which isn't openai-api-compatible import json -from typing import AsyncGenerator, Tuple +from enum import Enum, auto +from typing import AsyncGenerator, Optional, Tuple from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.const import USE_CONFIG_TIMEOUT @@ -14,6 +15,114 @@ from metagpt.provider.llm_provider_registry import register_provider from metagpt.utils.cost_manager import TokenCostManager +class OllamaMessageAPI(Enum): + # default + CHAT = auto() + GENERATE = auto() + EMBED = auto() + + +class OllamaMessageBase: + api_type = OllamaMessageAPI.CHAT + + def __init__(self, model: str, **additional_kwargs) -> None: + self.model, self.additional_kwargs = model, additional_kwargs + + @property + def api_suffix(self) -> str: + raise NotImplementedError + + def apply(self, messages: list[dict]) -> dict: + raise NotImplementedError + + def decode(self, response: OpenAIResponse) -> dict: + return json.loads(response.data.decode("utf-8")) + + def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]: + if "role" in msg: + return msg["content"], None + elif "type" in msg: + tpe = msg["type"] + if tpe == "text": + return msg["text"], None + elif tpe == "image_url": + return None, msg["image_url"]["url"] + else: + raise ValueError + else: + raise ValueError + + +class OllamaMessageMeta(type): + registed_message = {} + + def __init__(cls, name, bases, attrs): + super().__init__(name, bases, attrs) + for base in bases: + if issubclass(base, OllamaMessageBase): + api_type = attrs["api_type"] + if isinstance(api_type, list): + for tpe in api_type: + assert tpe not in OllamaMessageMeta.registed_message, "api_type already exist" + assert isinstance(tpe, OllamaMessageAPI), "api_type not support" + OllamaMessageMeta.registed_message[tpe] = cls + else: + assert api_type not in OllamaMessageMeta.registed_message, "api_type already exist" + assert isinstance(api_type, OllamaMessageAPI), "api_type not support" + OllamaMessageMeta.registed_message[api_type] = cls + + @classmethod + def get_message(cls, input_type: OllamaMessageAPI) -> type[OllamaMessageBase]: + return cls.registed_message[input_type] + + +class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta): + api_type = OllamaMessageAPI.CHAT + + @property + def api_suffix(self) -> str: + return "/chat" + + def apply(self, messages: list[dict]) -> dict: + prompts = [] + images = [] + for msg in messages: + prompt, image = self._parse_input_msg(msg) + if prompt: + prompts.append(prompt) + if image: + images.append(image) + sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images} + sends.update(self.additional_kwargs) + return sends + + +class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta): + api_type = OllamaMessageAPI.GENERATE + + @property + def api_suffix(self) -> str: + return "/generate" + + +class OllamaMessageEmbed(OllamaMessageBase, metaclass=OllamaMessageMeta): + api_type = OllamaMessageAPI.EMBED + + @property + def api_suffix(self) -> str: + return "/embeddings" + + def apply(self, messages: list[dict]) -> dict: + prompts = [] + for msg in messages: + prompt, _ = self._parse_input_msg(msg) + if prompt: + prompts.append(prompt) + sends = {"model": self.model, "prompt": "\n".join(prompts)} + sends.update(self.additional_kwargs) + return sends + + @register_provider(LLMType.OLLAMA) class OllamaLLM(BaseLLM): """ @@ -21,43 +130,45 @@ class OllamaLLM(BaseLLM): """ def __init__(self, config: LLMConfig): - self.__init_ollama(config) self.client = GeneralAPIRequestor(base_url=config.base_url) self.config = config - self.suffix_url = "/chat" self.http_method = "post" self.use_system_prompt = False self.cost_manager = TokenCostManager() + self.__init_ollama(config) + + def _get_headers(self): + return ( + None + if not self.config.api_key or self.config.api_key == "sk-" + else {"Authorization": f"Bearer {self.config.api_key}"} + ) + + @property + def _llama_api_inuse(self) -> OllamaMessageAPI: + return OllamaMessageAPI.CHAT + + @property + def _llama_api_kwargs(self) -> dict: + return {"options": {"temperature": 0.3}, "stream": self.config.stream} def __init_ollama(self, config: LLMConfig): assert config.base_url, "ollama base url is required!" self.model = config.model self.pricing_plan = self.model - - def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: - return {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} - - def get_choice_text(self, resp: dict) -> str: - """get the resp content from llm response""" - assist_msg = resp.get("message", {}) - if assist_msg.get("role", None) == "assistant": # chat - return assist_msg.get("content") - else: # llava - return resp["response"] + ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse) + self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs) + self.headers = self._get_headers() def get_usage(self, resp: dict) -> dict: return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)} - def _decode_and_load(self, openai_resp: OpenAIResponse, encoding: str = "utf-8") -> dict: - return json.loads(openai_resp.data.decode(encoding)) - async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict: - messages, fixed_suffix_url = self._apply_llava(messages) resp, _, _ = await self.client.arequest( method=self.http_method, - url=fixed_suffix_url, - headers=self._get_headers(), - params=messages, + url=self.ollama_message.api_suffix, + headers=self.headers, + params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), ) if isinstance(resp, AsyncGenerator): @@ -65,30 +176,29 @@ class OllamaLLM(BaseLLM): elif isinstance(resp, OpenAIResponse): return self._processing_openai_response(resp) else: - raise NotImplementedError + raise ValueError async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: - messages, fixed_suffix_url = self._apply_llava(messages, stream=True) - stream_resp, _, _ = await self.client.arequest( + resp, _, _ = await self.client.arequest( method=self.http_method, - url=fixed_suffix_url, - headers=self._get_headers(), - stream=True, - params=messages, + url=self.ollama_message.api_suffix, + headers=self.headers, + params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), + stream=True, ) - if isinstance(stream_resp, AsyncGenerator): - return await self._processing_openai_response_async_generator(stream_resp) - elif isinstance(stream_resp, OpenAIResponse): - return self._processing_openai_response(stream_resp) + if isinstance(resp, AsyncGenerator): + return await self._processing_openai_response_async_generator(resp) + elif isinstance(resp, OpenAIResponse): + return self._processing_openai_response(resp) else: - raise NotImplementedError + raise ValueError def _processing_openai_response(self, openai_resp: OpenAIResponse): - resp = self._decode_and_load(openai_resp) + resp = self.ollama_message.decode(openai_resp) usage = self.get_usage(resp) self._update_costs(usage) return resp @@ -97,7 +207,7 @@ class OllamaLLM(BaseLLM): collected_content = [] usage = {} async for raw_chunk in ag_openai_resp: - chunk = self._decode_and_load(raw_chunk) + chunk = self.ollama_message.decode(raw_chunk) if not chunk.get("done", False): content = self.get_choice_text(chunk) @@ -112,28 +222,37 @@ class OllamaLLM(BaseLLM): full_content = "".join(collected_content) return full_content - def _get_headers(self): - return ( - None - if not self.config.api_key or self.config.api_key == "sk-" - else {"Authorization": f"Bearer {self.config.api_key}"} + +@register_provider(LLMType.OLLAMA_GENERATE) +class OllamaGenerate(OllamaLLM): + @property + def _llama_api_inuse(self) -> OllamaMessageAPI: + return OllamaMessageAPI.GENERATE + + @property + def _llama_api_kwargs(self) -> dict: + return {"options": {"temperature": 0.3}, "stream": self.config.stream} + + +@register_provider(LLMType.OLLAMA_EMBEDDING) +class OllamaEmbeddings(OllamaLLM): + @property + def _llama_api_inuse(self) -> OllamaMessageAPI: + return OllamaMessageAPI.EMBED + + @property + def _llama_api_kwargs(self) -> dict: + return {"options": {"temperature": 0.3}} + + async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict: + resp, _, _ = await self.client.arequest( + method=self.http_method, + url=self.ollama_message.api_suffix, + headers=self.headers, + params=self.ollama_message.apply(messages=messages), + request_timeout=self.get_timeout(timeout), ) + return self.ollama_message.decode(resp)["embedding"] - def _apply_llava(self, messages: list[dict], stream: bool = False) -> Tuple[dict, str]: - llava = False - if isinstance(messages[0]["content"], str): - return self._const_kwargs(messages, stream), self.suffix_url - - if any(len(msg["content"]) >= 2 for msg in messages): - assert all(len(msg["content"]) >= 2 for msg in messages), "input should have the same api type" - llava = True - if not llava: - return self._const_kwargs(messages, stream), self.suffix_url - - assert len(messages) <= 1, "not support batch massages in llava calling images" - contents = messages[0]["content"] - return { - "model": self.model, - "prompt": contents[0]["text"], - "images": [i["image_url"]["url"][23:] for i in contents[1:]], - }, "/generate" + async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: + return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) From a5f33e2d514e5bc3502e06cc2dc0523d3df6d017 Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 14:51:37 +0800 Subject: [PATCH 03/13] update ollama --- examples/llm_vision.py | 4 +- metagpt/configs/llm_config.py | 4 +- metagpt/provider/general_api_base.py | 4 +- metagpt/provider/ollama_api.py | 76 ++++++++++++++++++++-------- 4 files changed, 60 insertions(+), 28 deletions(-) diff --git a/examples/llm_vision.py b/examples/llm_vision.py index 276decd59..eea6550f6 100644 --- a/examples/llm_vision.py +++ b/examples/llm_vision.py @@ -15,8 +15,8 @@ async def main(): # check if the configured llm supports llm-vision capacity. If not, it will throw a error invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") img_base64 = encode_image(invoice_path) - res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64]) - assert "true" in res.lower() + res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64]) + assert ("true" in res.lower()) or ("invoice" in res.lower()) if __name__ == "__main__": diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 3a13789d9..34b73d2d5 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -106,8 +106,8 @@ class LLMConfig(YamlModel): root_config_path = CONFIG_ROOT / "config2.yaml" if root_config_path.exists(): raise ValueError( - f"Please set your API key in {root_config_path}. If you also set your config in { - repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n" + f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \n" + f"the former will overwrite the latter. This may cause unexpected result.\n" ) elif repo_config_path.exists(): raise ValueError(f"Please set your API key in {repo_config_path}") diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index a4b50af4b..34a39fe6c 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -396,8 +396,8 @@ class APIRequestor: "X-LLM-Client-User-Agent": json.dumps(ua), "User-Agent": user_agent, } - - headers.update(api_key_to_header(self.api_type, self.api_key)) + if self.api_key: + headers.update(api_key_to_header(self.api_type, self.api_key)) if self.organization: headers["LLM-Organization"] = self.organization diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 8522f08a1..ab067340c 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -27,6 +27,7 @@ class OllamaMessageBase: def __init__(self, model: str, **additional_kwargs) -> None: self.model, self.additional_kwargs = model, additional_kwargs + self._image_b64_rms = len("data:image/jpeg;base64,") @property def api_suffix(self) -> str: @@ -38,15 +39,16 @@ class OllamaMessageBase: def decode(self, response: OpenAIResponse) -> dict: return json.loads(response.data.decode("utf-8")) + def get_choice(self, to_choice_dict: dict) -> str: + raise NotImplementedError + def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]: - if "role" in msg: - return msg["content"], None - elif "type" in msg: + if "type" in msg: tpe = msg["type"] if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"] + return None, msg["image_url"]["url"][self._image_b64_rms :] else: raise ValueError else: @@ -84,18 +86,35 @@ class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta): return "/chat" def apply(self, messages: list[dict]) -> dict: + content = messages[0]["content"] prompts = [] images = [] - for msg in messages: - prompt, image = self._parse_input_msg(msg) - if prompt: - prompts.append(prompt) - if image: - images.append(image) - sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images} + if isinstance(content, list): + for msg in content: + prompt, image = self._parse_input_msg(msg) + if prompt: + prompts.append(prompt) + if image: + images.append(image) + else: + prompts.append(content) + messes = [] + for prompt in prompts: + if len(images) > 0: + messes.append({"role": "user", "content": "\n".join(prompts), "images": images}) + else: + messes.append({"role": "user", "content": "\n".join(prompts)}) + sends = {"model": self.model, "messages": messes} sends.update(self.additional_kwargs) return sends + def get_choice(self, to_choice_dict: dict) -> str: + message = to_choice_dict["message"] + if message["role"] == "assistant": + return message["content"] + else: + raise ValueError + class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta): api_type = OllamaMessageAPI.GENERATE @@ -104,6 +123,29 @@ class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta): def api_suffix(self) -> str: return "/generate" + def apply(self, messages: list[dict]) -> dict: + content = messages[0]["content"] + prompts = [] + images = [] + if isinstance(content, list): + for msg in content: + prompt, image = self._parse_input_msg(msg) + if prompt: + prompts.append(prompt) + if image: + images.append(image) + else: + prompts.append(content) + if len(images) > 0: + sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images} + else: + sends = {"model": self.model, "prompt": "\n".join(prompts)} + sends.update(self.additional_kwargs) + return sends + + def get_choice(self, to_choice_dict: dict) -> str: + return to_choice_dict["response"] + class OllamaMessageEmbed(OllamaMessageBase, metaclass=OllamaMessageMeta): api_type = OllamaMessageAPI.EMBED @@ -137,13 +179,6 @@ class OllamaLLM(BaseLLM): self.cost_manager = TokenCostManager() self.__init_ollama(config) - def _get_headers(self): - return ( - None - if not self.config.api_key or self.config.api_key == "sk-" - else {"Authorization": f"Bearer {self.config.api_key}"} - ) - @property def _llama_api_inuse(self) -> OllamaMessageAPI: return OllamaMessageAPI.CHAT @@ -158,7 +193,6 @@ class OllamaLLM(BaseLLM): self.pricing_plan = self.model ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse) self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs) - self.headers = self._get_headers() def get_usage(self, resp: dict) -> dict: return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)} @@ -167,7 +201,6 @@ class OllamaLLM(BaseLLM): resp, _, _ = await self.client.arequest( method=self.http_method, url=self.ollama_message.api_suffix, - headers=self.headers, params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), ) @@ -185,7 +218,6 @@ class OllamaLLM(BaseLLM): resp, _, _ = await self.client.arequest( method=self.http_method, url=self.ollama_message.api_suffix, - headers=self.headers, params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), stream=True, @@ -210,7 +242,7 @@ class OllamaLLM(BaseLLM): chunk = self.ollama_message.decode(raw_chunk) if not chunk.get("done", False): - content = self.get_choice_text(chunk) + content = self.ollama_message.get_choice(chunk) collected_content.append(content) log_llm_stream(content) else: From 322974a917a73325d02ec9593e740e83d8cdf1cd Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 16:10:56 +0800 Subject: [PATCH 04/13] fix for test --- metagpt/provider/ollama_api.py | 4 +++- tests/metagpt/provider/test_ollama_api.py | 16 +++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index ab067340c..3f8467843 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -48,7 +48,7 @@ class OllamaMessageBase: if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"][self._image_b64_rms :] + return None, msg["image_url"]["url"][self._image_b64_rms:] else: raise ValueError else: @@ -211,6 +211,8 @@ class OllamaLLM(BaseLLM): else: raise ValueError + def get_choice_text(self, rsp): return self.ollama_message.get_choice(rsp) + async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index af2e929e9..75cfa86d5 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -3,11 +3,11 @@ # @Desc : the unittest of ollama api import json -from typing import Any, Tuple +from typing import Any, AsyncGenerator, Tuple import pytest -from metagpt.provider.ollama_api import OllamaLLM +from metagpt.provider.ollama_api import OllamaLLM, OpenAIResponse from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( llm_general_chat_funcs_test, @@ -23,21 +23,19 @@ default_resp = {"message": {"role": "assistant", "content": resp_cont}} async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]: if stream: - class Iterator(object): + async def async_event_generator() -> AsyncGenerator[Any, None]: events = [ b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}', b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}', ] + for event in events: + yield OpenAIResponse(event, {}) - async def __aiter__(self): - for event in self.events: - yield event - - return Iterator(), None, None + return async_event_generator(), None, None else: raw_default_resp = default_resp.copy() raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20}) - return json.dumps(raw_default_resp).encode(), None, None + return OpenAIResponse(json.dumps(raw_default_resp).encode(), {}), None, None @pytest.mark.asyncio From 7d3d15a324905ff0688bad8ec5384f41df54cb6e Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 16:19:41 +0800 Subject: [PATCH 05/13] format --- metagpt/provider/ollama_api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 3f8467843..fdc6b41fb 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -48,7 +48,7 @@ class OllamaMessageBase: if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"][self._image_b64_rms:] + return None, msg["image_url"]["url"][self._image_b64_rms :] else: raise ValueError else: @@ -211,7 +211,8 @@ class OllamaLLM(BaseLLM): else: raise ValueError - def get_choice_text(self, rsp): return self.ollama_message.get_choice(rsp) + def get_choice_text(self, rsp): + return self.ollama_message.get_choice(rsp) async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) From c48d16d47504a3a100aa1850a8878843c77043af Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 16:37:23 +0800 Subject: [PATCH 06/13] fix embedding --- metagpt/provider/ollama_api.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index fdc6b41fb..178003742 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -48,7 +48,7 @@ class OllamaMessageBase: if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"][self._image_b64_rms :] + return None, msg["image_url"]["url"][self._image_b64_rms:] else: raise ValueError else: @@ -291,3 +291,6 @@ class OllamaEmbeddings(OllamaLLM): async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) + + def get_choice_text(self, rsp): + return rsp From 685d301a66db047cb0a1d01115369aae2f3455db Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 16:38:20 +0800 Subject: [PATCH 07/13] f --- metagpt/provider/ollama_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 178003742..79494058c 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -48,7 +48,7 @@ class OllamaMessageBase: if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"][self._image_b64_rms:] + return None, msg["image_url"]["url"][self._image_b64_rms :] else: raise ValueError else: From 187e5125472aab0f58bfc7a1feccda79f61a60ab Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 23:09:38 +0800 Subject: [PATCH 08/13] fixup ollama_api chat prompt --- metagpt/provider/ollama_api.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 79494058c..4fe6be0c2 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -63,15 +63,9 @@ class OllamaMessageMeta(type): for base in bases: if issubclass(base, OllamaMessageBase): api_type = attrs["api_type"] - if isinstance(api_type, list): - for tpe in api_type: - assert tpe not in OllamaMessageMeta.registed_message, "api_type already exist" - assert isinstance(tpe, OllamaMessageAPI), "api_type not support" - OllamaMessageMeta.registed_message[tpe] = cls - else: - assert api_type not in OllamaMessageMeta.registed_message, "api_type already exist" - assert isinstance(api_type, OllamaMessageAPI), "api_type not support" - OllamaMessageMeta.registed_message[api_type] = cls + assert api_type not in OllamaMessageMeta.registed_message, "api_type already exist" + assert isinstance(api_type, OllamaMessageAPI), "api_type not support" + OllamaMessageMeta.registed_message[api_type] = cls @classmethod def get_message(cls, input_type: OllamaMessageAPI) -> type[OllamaMessageBase]: @@ -101,9 +95,9 @@ class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta): messes = [] for prompt in prompts: if len(images) > 0: - messes.append({"role": "user", "content": "\n".join(prompts), "images": images}) + messes.append({"role": "user", "content": prompt, "images": images}) else: - messes.append({"role": "user", "content": "\n".join(prompts)}) + messes.append({"role": "user", "content": prompt}) sends = {"model": self.model, "messages": messes} sends.update(self.additional_kwargs) return sends From f2aa4e3f9dc333e83bec574836e5f64a2662b9c3 Mon Sep 17 00:00:00 2001 From: EvensXia Date: Wed, 30 Oct 2024 09:47:10 +0800 Subject: [PATCH 09/13] tested for embeddings/embed --- examples/llm_vision.py | 7 +++-- metagpt/configs/llm_config.py | 3 +- metagpt/provider/ollama_api.py | 53 +++++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 14 deletions(-) diff --git a/examples/llm_vision.py b/examples/llm_vision.py index eea6550f6..eff5c4d52 100644 --- a/examples/llm_vision.py +++ b/examples/llm_vision.py @@ -14,9 +14,10 @@ async def main(): # check if the configured llm supports llm-vision capacity. If not, it will throw a error invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") - img_base64 = encode_image(invoice_path) - res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64]) - assert ("true" in res.lower()) or ("invoice" in res.lower()) + encode_image(invoice_path) + # res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64]) + await llm.aask(msg="hello") + # assert ("true" in res.lower()) or ("invoice" in res.lower()) if __name__ == "__main__": diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 34b73d2d5..dbbf5f5d9 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -28,7 +28,8 @@ class LLMType(Enum): AZURE = "azure" OLLAMA = "ollama" # /chat at ollama api OLLAMA_GENERATE = "ollama.generate" # /generate at ollama api - OLLAMA_EMBEDDING = "ollama.embeddings" # /embeddings at ollama api + OLLAMA_EMBEDDINGS = "ollama.embeddings" # /embeddings at ollama api + OLLAMA_EMBED = "ollama.embed" # /embeddings at ollama api QIANFAN = "qianfan" # Baidu BCE DASHSCOPE = "dashscope" # Aliyun LingJi DashScope MOONSHOT = "moonshot" diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 4fe6be0c2..6a2635b95 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -20,6 +20,7 @@ class OllamaMessageAPI(Enum): CHAT = auto() GENERATE = auto() EMBED = auto() + EMBEDDINGS = auto() class OllamaMessageBase: @@ -141,24 +142,50 @@ class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta): return to_choice_dict["response"] -class OllamaMessageEmbed(OllamaMessageBase, metaclass=OllamaMessageMeta): - api_type = OllamaMessageAPI.EMBED +class OllamaMessageEmbeddings(OllamaMessageBase, metaclass=OllamaMessageMeta): + api_type = OllamaMessageAPI.EMBEDDINGS @property def api_suffix(self) -> str: return "/embeddings" def apply(self, messages: list[dict]) -> dict: - prompts = [] - for msg in messages: - prompt, _ = self._parse_input_msg(msg) - if prompt: - prompts.append(prompt) + content = messages[0]["content"] + prompts = [] # NOTE: not support image to embedding + if isinstance(content, list): + for msg in content: + prompt, _ = self._parse_input_msg(msg) + if prompt: + prompts.append(prompt) + else: + prompts.append(content) sends = {"model": self.model, "prompt": "\n".join(prompts)} sends.update(self.additional_kwargs) return sends +class OllamaMessageEmbed(OllamaMessageEmbeddings, metaclass=OllamaMessageMeta): + api_type = OllamaMessageAPI.EMBED + + @property + def api_suffix(self) -> str: + return "/embed" + + def apply(self, messages: list[dict]) -> dict: + content = messages[0]["content"] + prompts = [] # NOTE: not support image to embedding + if isinstance(content, list): + for msg in content: + prompt, _ = self._parse_input_msg(msg) + if prompt: + prompts.append(prompt) + else: + prompts.append(content) + sends = {"model": self.model, "input": prompts} + sends.update(self.additional_kwargs) + return sends + + @register_provider(LLMType.OLLAMA) class OllamaLLM(BaseLLM): """ @@ -263,11 +290,11 @@ class OllamaGenerate(OllamaLLM): return {"options": {"temperature": 0.3}, "stream": self.config.stream} -@register_provider(LLMType.OLLAMA_EMBEDDING) +@register_provider(LLMType.OLLAMA_EMBEDDINGS) class OllamaEmbeddings(OllamaLLM): @property def _llama_api_inuse(self) -> OllamaMessageAPI: - return OllamaMessageAPI.EMBED + return OllamaMessageAPI.EMBEDDINGS @property def _llama_api_kwargs(self) -> dict: @@ -277,7 +304,6 @@ class OllamaEmbeddings(OllamaLLM): resp, _, _ = await self.client.arequest( method=self.http_method, url=self.ollama_message.api_suffix, - headers=self.headers, params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), ) @@ -288,3 +314,10 @@ class OllamaEmbeddings(OllamaLLM): def get_choice_text(self, rsp): return rsp + + +@register_provider(LLMType.OLLAMA_EMBED) +class OllamaEmbed(OllamaLLM): + @property + def _llama_api_inuse(self) -> OllamaMessageAPI: + return OllamaMessageAPI.EMBED From 062b13fc6cb18ef2681b81c8548e9ae6cbd6344d Mon Sep 17 00:00:00 2001 From: EvensXia Date: Wed, 30 Oct 2024 09:50:21 +0800 Subject: [PATCH 10/13] fix embed --- examples/llm_vision.py | 3 ++- metagpt/provider/ollama_api.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/llm_vision.py b/examples/llm_vision.py index eff5c4d52..1bfbfb230 100644 --- a/examples/llm_vision.py +++ b/examples/llm_vision.py @@ -16,7 +16,8 @@ async def main(): invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") encode_image(invoice_path) # res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64]) - await llm.aask(msg="hello") + res = await llm.aask(msg="hello") + print(res) # assert ("true" in res.lower()) or ("invoice" in res.lower()) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 6a2635b95..4537a8a2c 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -49,7 +49,7 @@ class OllamaMessageBase: if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"][self._image_b64_rms :] + return None, msg["image_url"]["url"][self._image_b64_rms:] else: raise ValueError else: @@ -317,7 +317,7 @@ class OllamaEmbeddings(OllamaLLM): @register_provider(LLMType.OLLAMA_EMBED) -class OllamaEmbed(OllamaLLM): +class OllamaEmbed(OllamaEmbeddings): @property def _llama_api_inuse(self) -> OllamaMessageAPI: return OllamaMessageAPI.EMBED From 2bca5c9d064c4b45bcfa23d1d6961788cc38bdcf Mon Sep 17 00:00:00 2001 From: EvensXia Date: Wed, 30 Oct 2024 09:55:13 +0800 Subject: [PATCH 11/13] fix embedding output --- metagpt/provider/ollama_api.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 4537a8a2c..3f7d20d0a 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -49,7 +49,7 @@ class OllamaMessageBase: if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"][self._image_b64_rms:] + return None, msg["image_url"]["url"][self._image_b64_rms :] else: raise ValueError else: @@ -300,6 +300,10 @@ class OllamaEmbeddings(OllamaLLM): def _llama_api_kwargs(self) -> dict: return {"options": {"temperature": 0.3}} + @property + def _llama_embedding_key(self) -> str: + return "embedding" + async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict: resp, _, _ = await self.client.arequest( method=self.http_method, @@ -307,7 +311,7 @@ class OllamaEmbeddings(OllamaLLM): params=self.ollama_message.apply(messages=messages), request_timeout=self.get_timeout(timeout), ) - return self.ollama_message.decode(resp)["embedding"] + return self.ollama_message.decode(resp)[self._llama_embedding_key] async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) @@ -321,3 +325,7 @@ class OllamaEmbed(OllamaEmbeddings): @property def _llama_api_inuse(self) -> OllamaMessageAPI: return OllamaMessageAPI.EMBED + + @property + def _llama_embedding_key(self) -> str: + return "embeddings" From 82dba4b9bc91ebd4c321e6a76ef4c5d2ffdede8a Mon Sep 17 00:00:00 2001 From: EvensXia Date: Wed, 30 Oct 2024 09:56:35 +0800 Subject: [PATCH 12/13] revert llm_vision --- examples/llm_vision.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/llm_vision.py b/examples/llm_vision.py index 1bfbfb230..eea6550f6 100644 --- a/examples/llm_vision.py +++ b/examples/llm_vision.py @@ -14,11 +14,9 @@ async def main(): # check if the configured llm supports llm-vision capacity. If not, it will throw a error invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") - encode_image(invoice_path) - # res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64]) - res = await llm.aask(msg="hello") - print(res) - # assert ("true" in res.lower()) or ("invoice" in res.lower()) + img_base64 = encode_image(invoice_path) + res = await llm.aask(msg="return `True` if this image might be a invoice, or return `False`", images=[img_base64]) + assert ("true" in res.lower()) or ("invoice" in res.lower()) if __name__ == "__main__": From e209e0ec8d25818e1c2e62dedb21e7b974a9587f Mon Sep 17 00:00:00 2001 From: EvensXia Date: Wed, 30 Oct 2024 10:01:09 +0800 Subject: [PATCH 13/13] fx --- metagpt/configs/llm_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index dbbf5f5d9..ddee34f97 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -29,7 +29,7 @@ class LLMType(Enum): OLLAMA = "ollama" # /chat at ollama api OLLAMA_GENERATE = "ollama.generate" # /generate at ollama api OLLAMA_EMBEDDINGS = "ollama.embeddings" # /embeddings at ollama api - OLLAMA_EMBED = "ollama.embed" # /embeddings at ollama api + OLLAMA_EMBED = "ollama.embed" # /embed at ollama api QIANFAN = "qianfan" # Baidu BCE DASHSCOPE = "dashscope" # Aliyun LingJi DashScope MOONSHOT = "moonshot"