diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index 60118303c..3a13789d9 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -26,7 +26,9 @@ class LLMType(Enum):
     GEMINI = "gemini"
     METAGPT = "metagpt"
     AZURE = "azure"
-    OLLAMA = "ollama"
+    OLLAMA = "ollama"  # /chat at ollama api
+    OLLAMA_GENERATE = "ollama.generate"  # /generate at ollama api
+    OLLAMA_EMBEDDING = "ollama.embeddings"  # /embeddings at ollama api
     QIANFAN = "qianfan"  # Baidu BCE
     DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
     MOONSHOT = "moonshot"
@@ -104,7 +106,8 @@ class LLMConfig(YamlModel):
         root_config_path = CONFIG_ROOT / "config2.yaml"
         if root_config_path.exists():
             raise ValueError(
-                f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
+                f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path},\n"
+                f"the former will overwrite the latter. This may cause unexpected results.\n"
             )
         elif repo_config_path.exists():
             raise ValueError(f"Please set your API key in {repo_config_path}")
diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py
index ab733db58..a4b50af4b 100644
--- a/metagpt/provider/general_api_base.py
+++ b/metagpt/provider/general_api_base.py
@@ -13,6 +13,7 @@ import time
 from contextlib import asynccontextmanager
 from enum import Enum
 from typing import (
+    Any,
     AsyncGenerator,
     AsyncIterator,
     Dict,
@@ -121,7 +122,7 @@ def logfmt(props):


 class OpenAIResponse:
-    def __init__(self, data, headers):
+    def __init__(self, data: Union[bytes, Any], headers: dict):
         self._headers = headers
         self.data = data
diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py
index 39cbe291e..8522f08a1 100644
--- a/metagpt/provider/ollama_api.py
+++ b/metagpt/provider/ollama_api.py
@@ -3,7 +3,8 @@
 # @Desc   : self-host open llm model with ollama which isn't openai-api-compatible

 import json
-from typing import AsyncGenerator, Tuple
+from enum import Enum, auto
+from typing import AsyncGenerator, Optional, Tuple

 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT
@@ -14,6 +15,118 @@ from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.utils.cost_manager import TokenCostManager


+class OllamaMessageAPI(Enum):
+    # default
+    CHAT = auto()
+    GENERATE = auto()
+    EMBED = auto()
+
+
+class OllamaMessageBase:
+    api_type = OllamaMessageAPI.CHAT
+
+    def __init__(self, model: str, **additional_kwargs) -> None:
+        self.model, self.additional_kwargs = model, additional_kwargs
+
+    @property
+    def api_suffix(self) -> str:
+        raise NotImplementedError
+
+    def apply(self, messages: list[dict]) -> dict:
+        raise NotImplementedError
+
+    def decode(self, response: OpenAIResponse) -> dict:
+        return json.loads(response.data.decode("utf-8"))
+
+    def _parse_input_msg(self, msg: dict) -> Tuple[Optional[str], Optional[str]]:
+        if "role" in msg:
+            return msg["content"], None
+        elif "type" in msg:
+            tpe = msg["type"]
+            if tpe == "text":
+                return msg["text"], None
+            elif tpe == "image_url":
+                return None, msg["image_url"]["url"]
+            else:
+                raise ValueError(f"unsupported content type: {tpe}")
+        else:
+            raise ValueError(f"unsupported message: {msg}")
+
+
+class OllamaMessageMeta(type):
+    registered_message = {}
+
+    def __init__(cls, name, bases, attrs):
+        super().__init__(name, bases, attrs)
+        for base in bases:
+            if issubclass(base, OllamaMessageBase):
+                api_type = attrs["api_type"]
+                if isinstance(api_type, list):
+                    for tpe in api_type:
+                        assert tpe not in OllamaMessageMeta.registered_message, "api_type already exists"
+                        assert isinstance(tpe, OllamaMessageAPI), "api_type not supported"
+                        OllamaMessageMeta.registered_message[tpe] = cls
+                else:
+                    assert api_type not in OllamaMessageMeta.registered_message, "api_type already exists"
+                    assert isinstance(api_type, OllamaMessageAPI), "api_type not supported"
+                    OllamaMessageMeta.registered_message[api_type] = cls
+
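+    # NOTE: each OllamaMessageAPI value is asserted above to map to exactly one
+    # message class, so the get_message() lookup stays unambiguous.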
+    @classmethod
+    def get_message(cls, input_type: OllamaMessageAPI) -> type[OllamaMessageBase]:
+        return cls.registered_message[input_type]
+
+
+class OllamaMessageChat(OllamaMessageBase, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.CHAT
+
+    @property
+    def api_suffix(self) -> str:
+        return "/chat"
+
+    def apply(self, messages: list[dict]) -> dict:
+        prompts = []
+        images = []
+        for msg in messages:
+            prompt, image = self._parse_input_msg(msg)
+            if prompt:
+                prompts.append(prompt)
+            if image:
+                images.append(image)
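+        # NOTE: multimodal input is flattened here; text parts are joined into a
+        # single prompt and image payloads are forwarded to the API unchanged.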
+        sends = {"model": self.model, "prompt": "\n".join(prompts), "images": images}
+        sends.update(self.additional_kwargs)
+        return sends
+
+
+class OllamaMessageGenerate(OllamaMessageChat, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.GENERATE
+
+    @property
+    def api_suffix(self) -> str:
+        return "/generate"
+
+
+class OllamaMessageEmbed(OllamaMessageBase, metaclass=OllamaMessageMeta):
+    api_type = OllamaMessageAPI.EMBED
+
+    @property
+    def api_suffix(self) -> str:
+        return "/embeddings"
+
+    def apply(self, messages: list[dict]) -> dict:
+        prompts = []
+        for msg in messages:
+            prompt, _ = self._parse_input_msg(msg)
+            if prompt:
+                prompts.append(prompt)
+        sends = {"model": self.model, "prompt": "\n".join(prompts)}
+        sends.update(self.additional_kwargs)
+        return sends
+
+
 @register_provider(LLMType.OLLAMA)
 class OllamaLLM(BaseLLM):
     """
@@ -21,43 +134,47 @@ class OllamaLLM(BaseLLM):
     """

     def __init__(self, config: LLMConfig):
-        self.__init_ollama(config)
         self.client = GeneralAPIRequestor(base_url=config.base_url)
         self.config = config
-        self.suffix_url = "/chat"
         self.http_method = "post"
         self.use_system_prompt = False
         self.cost_manager = TokenCostManager()
+        self.__init_ollama(config)
+
+    def _get_headers(self):
+        return (
+            None
+            if not self.config.api_key or self.config.api_key == "sk-"
+            else {"Authorization": f"Bearer {self.config.api_key}"}
+        )
+
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.CHAT
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}, "stream": self.config.stream}

     def __init_ollama(self, config: LLMConfig):
         assert config.base_url, "ollama base url is required!"
         self.model = config.model
         self.pricing_plan = self.model
-
-    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
-        return {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
-
-    def get_choice_text(self, resp: dict) -> str:
-        """get the resp content from llm response"""
-        assist_msg = resp.get("message", {})
-        if assist_msg.get("role", None) == "assistant":  # chat
-            return assist_msg.get("content")
-        else:  # llava
-            return resp["response"]
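+        # NOTE: the message adapter is resolved once here; subclasses pick the
+        # endpoint via _llama_api_inuse and payload defaults via _llama_api_kwargs.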
+        ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse)
+        self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs)
+        self.headers = self._get_headers()

     def get_usage(self, resp: dict) -> dict:
         return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}

-    def _decode_and_load(self, openai_resp: OpenAIResponse, encoding: str = "utf-8") -> dict:
-        return json.loads(openai_resp.data.decode(encoding))
-
     async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
-        messages, fixed_suffix_url = self._apply_llava(messages)
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=fixed_suffix_url,
-            headers=self._get_headers(),
-            params=messages,
+            url=self.ollama_message.api_suffix,
+            headers=self.headers,
+            params=self.ollama_message.apply(messages=messages),
             request_timeout=self.get_timeout(timeout),
         )
         if isinstance(resp, AsyncGenerator):
@@ -65,30 +182,29 @@ class OllamaLLM(BaseLLM):
         elif isinstance(resp, OpenAIResponse):
             return self._processing_openai_response(resp)
         else:
-            raise NotImplementedError
+            raise ValueError

     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

     async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
-        messages, fixed_suffix_url = self._apply_llava(messages, stream=True)
-        stream_resp, _, _ = await self.client.arequest(
+        resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=fixed_suffix_url,
-            headers=self._get_headers(),
-            stream=True,
-            params=messages,
+            url=self.ollama_message.api_suffix,
+            headers=self.headers,
+            params=self.ollama_message.apply(messages=messages),
             request_timeout=self.get_timeout(timeout),
+            stream=True,
         )
-        if isinstance(stream_resp, AsyncGenerator):
-            return await self._processing_openai_response_async_generator(stream_resp)
-        elif isinstance(stream_resp, OpenAIResponse):
-            return self._processing_openai_response(stream_resp)
+        if isinstance(resp, AsyncGenerator):
+            return await self._processing_openai_response_async_generator(resp)
+        elif isinstance(resp, OpenAIResponse):
+            return self._processing_openai_response(resp)
         else:
-            raise NotImplementedError
+            raise ValueError

     def _processing_openai_response(self, openai_resp: OpenAIResponse):
-        resp = self._decode_and_load(openai_resp)
+        resp = self.ollama_message.decode(openai_resp)
         usage = self.get_usage(resp)
         self._update_costs(usage)
         return resp
@@ -97,7 +213,7 @@ class OllamaLLM(BaseLLM):
         collected_content = []
         usage = {}
         async for raw_chunk in ag_openai_resp:
-            chunk = self._decode_and_load(raw_chunk)
+            chunk = self.ollama_message.decode(raw_chunk)

             if not chunk.get("done", False):
                 content = self.get_choice_text(chunk)
@@ -112,28 +228,37 @@ class OllamaLLM(BaseLLM):
         full_content = "".join(collected_content)
         return full_content

-    def _get_headers(self):
-        return (
-            None
-            if not self.config.api_key or self.config.api_key == "sk-"
-            else {"Authorization": f"Bearer {self.config.api_key}"}
+
+@register_provider(LLMType.OLLAMA_GENERATE)
+class OllamaGenerate(OllamaLLM):
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.GENERATE
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}, "stream": self.config.stream}
+
+
+@register_provider(LLMType.OLLAMA_EMBEDDING)
+class OllamaEmbeddings(OllamaLLM):
+    @property
+    def _llama_api_inuse(self) -> OllamaMessageAPI:
+        return OllamaMessageAPI.EMBED
+
+    @property
+    def _llama_api_kwargs(self) -> dict:
+        return {"options": {"temperature": 0.3}}
+
+    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
+        resp, _, _ = await self.client.arequest(
+            method=self.http_method,
+            url=self.ollama_message.api_suffix,
+            headers=self.headers,
+            params=self.ollama_message.apply(messages=messages),
+            request_timeout=self.get_timeout(timeout),
         )
+        return self.ollama_message.decode(resp)["embedding"]

-    def _apply_llava(self, messages: list[dict], stream: bool = False) -> Tuple[dict, str]:
-        llava = False
-        if isinstance(messages[0]["content"], str):
-            return self._const_kwargs(messages, stream), self.suffix_url
-
-        if any(len(msg["content"]) >= 2 for msg in messages):
-            assert all(len(msg["content"]) >= 2 for msg in messages), "input should have the same api type"
-            llava = True
-        if not llava:
-            return self._const_kwargs(messages, stream), self.suffix_url
-
-        assert len(messages) <= 1, "not support batch massages in llava calling images"
-        contents = messages[0]["content"]
-        return {
-            "model": self.model,
-            "prompt": contents[0]["text"],
-            "images": [i["image_url"]["url"][23:] for i in contents[1:]],
-        }, "/generate"
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+        return await self._achat_completion(messages, timeout=self.get_timeout(timeout))