diff --git a/config/config.yaml b/config/config.yaml
index e724897ee..a9c764c56 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -48,6 +48,10 @@ RPM: 10
 #FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
 #FIREWORKS_API_MODEL: "YOUR_LLM_MODEL"  # example, accounts/fireworks/models/llama-v2-13b-chat
 
+#### if use self-host open llm model by ollama
+# OLLAMA_API_BASE: http://127.0.0.1:11434/api
+# OLLAMA_API_MODEL: llama2
+
 #### for Search
 
 ## Supported values: serpapi/google/serper/ddg
diff --git a/metagpt/config.py b/metagpt/config.py
index 5176a7677..208b4fd7b 100644
--- a/metagpt/config.py
+++ b/metagpt/config.py
@@ -42,6 +42,7 @@ class LLMProviderEnum(Enum):
     FIREWORKS = "fireworks"
     OPEN_LLM = "open_llm"
     GEMINI = "gemini"
+    OLLAMA = "ollama"
 
 
 class Config(metaclass=Singleton):
@@ -78,7 +79,8 @@ class Config(metaclass=Singleton):
             (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
             (self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
             (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
-            (self.gemini_api_key, LLMProviderEnum.GEMINI),  # reuse logic. but not a key
+            (self.gemini_api_key, LLMProviderEnum.GEMINI),
+            (self.ollama_api_base, LLMProviderEnum.OLLAMA),  # reuse logic. but not a key
         ]:
             if self._is_valid_llm_key(k):
                 # logger.debug(f"Use LLMProvider: {v.value}")
@@ -103,6 +105,8 @@ class Config(metaclass=Singleton):
         self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
         self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
         self.gemini_api_key = self._get("GEMINI_API_KEY")
+        self.ollama_api_base = self._get("OLLAMA_API_BASE")
+        self.ollama_api_model = self._get("OLLAMA_API_MODEL")
         _ = self.get_default_llm_provider_enum()
 
         self.openai_base_url = self._get("OPENAI_BASE_URL")
diff --git a/metagpt/const.py b/metagpt/const.py
index 3b4f2ae4b..1819bbb49 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -102,3 +102,5 @@ CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
 CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
 
 YAPI_URL = "http://yapi.deepwisdomai.com/"
+
+LLM_API_TIMEOUT = 300
diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py
index a9f46eb03..42626a551 100644
--- a/metagpt/provider/__init__.py
+++ b/metagpt/provider/__init__.py
@@ -8,8 +8,9 @@
 
 from metagpt.provider.fireworks_api import FireWorksGPTAPI
 from metagpt.provider.google_gemini_api import GeminiGPTAPI
+from metagpt.provider.ollama_api import OllamaGPTAPI
 from metagpt.provider.open_llm_api import OpenLLMGPTAPI
 from metagpt.provider.openai_api import OpenAIGPTAPI
 from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
 
-__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"]
+__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI", "OllamaGPTAPI"]
diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py
index da16e942d..015e34aeb 100644
--- a/metagpt/provider/general_api_base.py
+++ b/metagpt/provider/general_api_base.py
@@ -1,3 +1,7 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : refs to openai 0.x sdk
+
 import asyncio
 import json
 import os
@@ -43,8 +47,8 @@ MAX_CONNECTION_RETRIES = 2
 # Has one attribute per thread, 'session'.
 _thread_context = threading.local()
 
-OPENAI_LOG = os.environ.get("OPENAI_LOG")
-OPENAI_LOG = "debug"
+LLM_LOG = os.environ.get("LLM_LOG")
+LLM_LOG = "debug"
 
 
 class ApiType(Enum):
@@ -74,8 +78,8 @@ api_key_to_header = (
 
 
 def _console_log_level():
-    if OPENAI_LOG in ["debug", "info"]:
-        return OPENAI_LOG
+    if LLM_LOG in ["debug", "info"]:
+        return LLM_LOG
     else:
         return None
 
@@ -140,7 +144,7 @@ class OpenAIResponse:
 
     @property
     def organization(self) -> Optional[str]:
-        return self._headers.get("OpenAI-Organization")
+        return self._headers.get("LLM-Organization")
 
     @property
     def response_ms(self) -> Optional[int]:
@@ -478,7 +482,7 @@ class APIRequestor:
             error_data["message"] += "\n\n" + error_data["internal_message"]
 
         log_info(
-            "OpenAI API error received",
+            "LLM API error received",
             error_code=error_data.get("code"),
             error_type=error_data.get("type"),
             error_message=error_data.get("message"),
@@ -516,7 +520,7 @@ class APIRequestor:
         )
 
     def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
-        user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,)
+        user_agent = "LLM/v1 PythonBindings/%s" % (version.VERSION,)
         uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node")
         ua = {
@@ -530,17 +534,17 @@ class APIRequestor:
         }
 
         headers = {
-            "X-OpenAI-Client-User-Agent": json.dumps(ua),
+            "X-LLM-Client-User-Agent": json.dumps(ua),
             "User-Agent": user_agent,
         }
 
         headers.update(api_key_to_header(self.api_type, self.api_key))
 
         if self.organization:
-            headers["OpenAI-Organization"] = self.organization
+            headers["LLM-Organization"] = self.organization
 
         if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
-            headers["OpenAI-Version"] = self.api_version
+            headers["LLM-Version"] = self.api_version
 
         if request_id is not None:
             headers["X-Request-Id"] = request_id
         headers.update(extra)
@@ -592,15 +596,14 @@ class APIRequestor:
                 headers["Content-Type"] = "application/json"
         else:
             raise openai.APIConnectionError(
-                "Unrecognized HTTP method %r. This may indicate a bug in the "
-                "OpenAI bindings. Please contact us through our help center at help.openai.com for "
-                "assistance." % (method,)
+                message=f"Unrecognized HTTP method {method}. This may indicate a bug in the LLM bindings.",
+                request=None,
             )
 
         headers = self.request_headers(method, headers, request_id)
 
-        log_debug("Request to OpenAI API", method=method, path=abs_url)
-        log_debug("Post details", data=data, api_version=self.api_version)
+        # log_debug("Request to LLM API", method=method, path=abs_url)
+        # log_debug("Post details", data=data, api_version=self.api_version)
 
         return abs_url, headers, data
@@ -639,14 +642,14 @@ class APIRequestor:
         except requests.exceptions.Timeout as e:
             raise openai.APITimeoutError("Request timed out: {}".format(e)) from e
         except requests.exceptions.RequestException as e:
-            raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
-        log_debug(
-            "OpenAI API response",
-            path=abs_url,
-            response_code=result.status_code,
-            processing_ms=result.headers.get("OpenAI-Processing-Ms"),
-            request_id=result.headers.get("X-Request-Id"),
-        )
+            raise openai.APIConnectionError(message="Error communicating with LLM: {}".format(e), request=None) from e
+        # log_debug(
+        #     "LLM API response",
+        #     path=abs_url,
+        #     response_code=result.status_code,
+        #     processing_ms=result.headers.get("LLM-Processing-Ms"),
+        #     request_id=result.headers.get("X-Request-Id"),
+        # )
         return result
 
     async def arequest_raw(
@@ -685,18 +688,18 @@ class APIRequestor:
         }
         try:
             result = await session.request(**request_kwargs)
-            log_info(
-                "OpenAI API response",
-                path=abs_url,
-                response_code=result.status,
-                processing_ms=result.headers.get("OpenAI-Processing-Ms"),
-                request_id=result.headers.get("X-Request-Id"),
-            )
+            # log_info(
+            #     "LLM API response",
+            #     path=abs_url,
+            #     response_code=result.status,
+            #     processing_ms=result.headers.get("LLM-Processing-Ms"),
+            #     request_id=result.headers.get("X-Request-Id"),
+            # )
             return result
         except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
             raise openai.APITimeoutError("Request timed out") from e
         except aiohttp.ClientError as e:
-            raise openai.APIConnectionError("Error communicating with OpenAI") from e
+            raise openai.APIConnectionError(message="Error communicating with LLM", request=None) from e
 
     def _interpret_response(
         self, result: requests.Response, stream: bool
diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py
index f8321cc6b..8b06b9388 100644
--- a/metagpt/provider/general_api_requestor.py
+++ b/metagpt/provider/general_api_requestor.py
@@ -3,14 +3,38 @@
 # @Desc   : General Async API for http-based LLM model
 
 import asyncio
-from typing import AsyncGenerator, Tuple, Union
+from typing import AsyncGenerator, Generator, Iterator, Optional, Tuple, Union
 
 import aiohttp
+import requests
 
 from metagpt.logs import logger
 from metagpt.provider.general_api_base import APIRequestor
 
 
+def parse_stream_helper(line: bytes) -> Optional[str]:
+    if line and line.startswith(b"data:"):
+        if line.startswith(b"data: "):
+            # an SSE event may still be valid when it contains whitespace
+            line = line[len(b"data: ") :]
+        else:
+            line = line[len(b"data:") :]
+        if line.strip() == b"[DONE]":
+            # returning here would raise GeneratorExit inside urllib3,
+            # which closes the http connection with a TCP Reset
+            return None
+        else:
+            return line.decode("utf-8")
+    return None
+
+
+def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
+    for line in rbody:
+        _line = parse_stream_helper(line)
+        if _line is not None:
+            yield _line
+
+
 class GeneralAPIRequestor(APIRequestor):
     """
     usage
@@ -32,10 +56,34 @@ class GeneralAPIRequestor(APIRequestor):
 
         return rbody
 
+    def _interpret_response(
+        self, result: requests.Response, stream: bool
+    ) -> Tuple[Union[str, Iterator[Generator]], bool]:
+        """Returns the response(s) and a bool indicating whether it is a stream."""
+        if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
+            return (
+                self._interpret_response_line(line, result.status_code, result.headers, stream=True)
+                for line in parse_stream(result.iter_lines())
+            ), True
+        else:
+            return (
+                self._interpret_response_line(
+                    result.content,  # let the caller decode the message
+                    result.status_code,
+                    result.headers,
+                    stream=False,
+                ),
+                False,
+            )
+
     async def _interpret_async_response(
         self, result: aiohttp.ClientResponse, stream: bool
     ) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]:
-        if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
+        if stream and (
+            "text/event-stream" in result.headers.get("Content-Type", "")
+            or "application/x-ndjson" in result.headers.get("Content-Type", "")
+        ):
+            # the `Content-Type` of an ollama stream response is "application/x-ndjson"
             return (
                 self._interpret_response_line(line, result.status, result.headers, stream=True)
                 async for line in result.content
diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py
new file mode 100644
index 000000000..a15c46458
--- /dev/null
+++ b/metagpt/provider/ollama_api.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : self-hosted open llm model served with ollama, which isn't openai-api-compatible
+
+import json
+
+from requests import ConnectionError
+from tenacity import (
+    after_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+from metagpt.config import CONFIG, Config, LLMProviderEnum
+from metagpt.const import LLM_API_TIMEOUT
+from metagpt.logs import logger
+from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.general_api_requestor import GeneralAPIRequestor
+from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.provider.openai_api import CostManager, log_and_reraise
+
+
+class OllamaCostManager(CostManager):
+    def update_cost(self, prompt_tokens, completion_tokens, model):
+        """
+        Update the total cost, prompt tokens, and completion tokens.
+        """
+        self.total_prompt_tokens += prompt_tokens
+        self.total_completion_tokens += completion_tokens
+
+        logger.info(
+            f"Max budget: ${CONFIG.max_budget:.3f} | "
+            f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
+        )
+        CONFIG.total_cost = self.total_cost
+
+
+@register_provider(LLMProviderEnum.OLLAMA)
+class OllamaGPTAPI(BaseGPTAPI):
+    """
+    Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
+    """
+
+    def __init__(self):
+        self.__init_ollama(CONFIG)
+        self.client = GeneralAPIRequestor(base_url=CONFIG.ollama_api_base)
+        self.suffix_url = "/chat"
+        self.http_method = "post"
+        self.use_system_prompt = False
+        self._cost_manager = OllamaCostManager()
+
+    def __init_ollama(self, config: Config):
+        assert config.ollama_api_base
+
+        self.model = config.ollama_api_model
+
+    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
+        kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
+        return kwargs
+
+    def _update_costs(self, usage: dict):
+        """update each request's token cost"""
+        if CONFIG.calc_usage:
+            try:
+                prompt_tokens = int(usage.get("prompt_tokens", 0))
+                completion_tokens = int(usage.get("completion_tokens", 0))
+                self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
+            except Exception as e:
+                logger.error(f"ollama update costs failed! exp: {e}")
+
+    def get_choice_text(self, resp: dict) -> str:
+        """get the resp content from llm response"""
+        assist_msg = resp.get("message", {})
+        assert assist_msg.get("role", None) == "assistant"
+        return assist_msg.get("content")
+
+    def get_usage(self, resp: dict) -> dict:
+        return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
+
+    def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
+        chunk = chunk.decode(encoding)
+        return json.loads(chunk)
+
+    def completion(self, messages: list[dict]) -> dict:
+        resp, _, _ = self.client.request(
+            method=self.http_method,
+            url=self.suffix_url,
+            params=self._const_kwargs(messages),
+            request_timeout=LLM_API_TIMEOUT,
+        )
+        resp = self._decode_and_load(resp)
+        usage = self.get_usage(resp)
+        self._update_costs(usage)
+        return resp
+
+    async def _achat_completion(self, messages: list[dict]) -> dict:
+        resp, _, _ = await self.client.arequest(
+            method=self.http_method,
+            url=self.suffix_url,
+            params=self._const_kwargs(messages),
+            request_timeout=LLM_API_TIMEOUT,
+        )
+        resp = self._decode_and_load(resp)
+        usage = self.get_usage(resp)
+        self._update_costs(usage)
+        return resp
+
+    async def acompletion(self, messages: list[dict]) -> dict:
+        return await self._achat_completion(messages)
+
+    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+        stream_resp, _, _ = await self.client.arequest(
+            method=self.http_method,
+            url=self.suffix_url,
+            stream=True,
+            params=self._const_kwargs(messages, stream=True),
+            request_timeout=LLM_API_TIMEOUT,
+        )
+
+        collected_content = []
+        usage = {}
+        async for raw_chunk in stream_resp:
+            chunk = self._decode_and_load(raw_chunk)
+
+            if not chunk.get("done", False):
+                content = self.get_choice_text(chunk)
+                collected_content.append(content)
+                print(content, end="")
+            else:
+                # stream finished
+                usage = self.get_usage(chunk)
+
+        self._update_costs(usage)
+        full_content = "".join(collected_content)
+        return full_content
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(min=1, max=60),
+        after=after_log(logger, logger.level("WARNING").name),
+        retry=retry_if_exception_type(ConnectionError),
+        retry_error_callback=log_and_reraise,
+    )
+    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+        """async completion, in stream or non-stream mode"""
+        if stream:
+            return await self._achat_completion_stream(messages)
+        resp = await self._achat_completion(messages)
+        return self.get_choice_text(resp)
diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py
index 67ad4e963..87fd0efd0 100644
--- a/metagpt/utils/repair_llm_raw_output.py
+++ b/metagpt/utils/repair_llm_raw_output.py
@@ -196,6 +196,8 @@ def repair_invalid_json(output: str, error: str) -> str:
             new_line = f'"{line}'
         elif '",' in line:
             new_line = line[:-2] + "',"
+        else:
+            new_line = line
 
         arr[line_no] = new_line
     output = "\n".join(arr)
diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py
index 229d9b9a7..9c8cf46c0 100644
--- a/tests/metagpt/provider/test_google_gemini_api.py
+++ b/tests/metagpt/provider/test_google_gemini_api.py
@@ -9,7 +9,7 @@ import pytest
 
 from metagpt.provider.google_gemini_api import GeminiGPTAPI
 
-messages = [{"role": "user", "content": "who are you"}]
+messages = [{"role": "user", "parts": "who are you"}]
 
 
 @dataclass
diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py
new file mode 100644
index 000000000..2798f5cc3
--- /dev/null
+++ b/tests/metagpt/provider/test_ollama_api.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : the unittest of ollama api
+
+import pytest
+
+from metagpt.provider.ollama_api import OllamaGPTAPI
+
+messages = [{"role": "user", "content": "who are you"}]
+
+
+default_resp = {"message": {"role": "assistant", "content": "I'm ollama"}}
+
+
+def mock_llm_ask(self, messages: list[dict]) -> dict:
+    return default_resp
+
+
+def test_ollama_completion(mocker):
+    mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask)
+    resp = OllamaGPTAPI().completion(messages)
+    assert resp["message"]["content"] == default_resp["message"]["content"]
+
+
+async def mock_llm_aask(self, messages: list[dict]) -> dict:
+    return default_resp
+
+
+@pytest.mark.asyncio
+async def test_ollama_acompletion(mocker):
+    mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask)
+    resp = await OllamaGPTAPI().acompletion(messages)
+    assert resp["message"]["content"] == default_resp["message"]["content"]