diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py
index 8e5da8f16..ab733db58 100644
--- a/metagpt/provider/general_api_base.py
+++ b/metagpt/provider/general_api_base.py
@@ -320,49 +320,6 @@ class APIRequestor:
         resp, got_stream = self._interpret_response(result, stream)
         return resp, got_stream, self.api_key
 
-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params,
-        headers,
-        files,
-        stream: Literal[True],
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
-        pass
-
-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params=...,
-        headers=...,
-        files=...,
-        *,
-        stream: Literal[True],
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
-        pass
-
-    @overload
-    async def arequest(
-        self,
-        method,
-        url,
-        params=...,
-        headers=...,
-        files=...,
-        stream: Literal[False] = ...,
-        request_id: Optional[str] = ...,
-        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
-    ) -> Tuple[OpenAIResponse, bool, str]:
-        pass
-
     @overload
     async def arequest(
         self,
diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py
index 501a064e3..b8da1565d 100644
--- a/metagpt/provider/general_api_requestor.py
+++ b/metagpt/provider/general_api_requestor.py
@@ -3,25 +3,24 @@
 # @Desc   : General Async API for http-based LLM model
 
 import asyncio
-from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
+from typing import AsyncGenerator, Iterator, Optional, Tuple, Union
 
 import aiohttp
 import requests
 
 from metagpt.logs import logger
-from metagpt.provider.general_api_base import APIRequestor
+from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse
 
 
-def parse_stream_helper(line: bytes) -> Union[bytes, None]:
+def parse_stream_helper(line: bytes) -> Optional[bytes]:
     if line and line.startswith(b"data:"):
         if line.startswith(b"data: "):
-            # SSE event may be valid when it contain whitespace
+            # SSE event may be valid when it contains whitespace
            line = line[len(b"data: ") :]
        else:
            line = line[len(b"data:") :]
        if line.strip() == b"[DONE]":
-            # return here will cause GeneratorExit exception in urllib3
-            # and it will close http connection with TCP Reset
+            # Return None to signal the end of the stream
            return None
        else:
            return line
@@ -37,7 +36,7 @@ def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:
 
 class GeneralAPIRequestor(APIRequestor):
     """
-    usage
+    Usage example:
        # full_url = "{base_url}{url}"
        requester = GeneralAPIRequestor(base_url=base_url)
        result, _, api_key = await requester.arequest(
@@ -50,26 +49,47 @@ class GeneralAPIRequestor(APIRequestor):
         )
     """
 
-    def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
-        # just do nothing to meet the APIRequestor process and return the raw data
-        # due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
+    def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders: dict, stream: bool) -> OpenAIResponse:
+        """
+        Wrap the raw response data in an OpenAIResponse.
 
-        return rbody
+        Args:
+            rbody (bytes): The response body.
+            rcode (int): The response status code.
+            rheaders (dict): The response headers.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            OpenAIResponse: The response body and headers wrapped in an OpenAIResponse.
+        """
+        return OpenAIResponse(rbody, rheaders)
 
     def _interpret_response(
         self, result: requests.Response, stream: bool
-    ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
-        """Returns the response(s) and a bool indicating whether it is a stream."""
+    ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
+        """
+        Interpret a synchronous response.
+
+        Args:
+            result (requests.Response): The response object.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: The response content and a boolean indicating whether it is a stream.
+        """
         content_type = result.headers.get("Content-Type", "")
         if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
             return (
-                self._interpret_response_line(line, result.status_code, result.headers, stream=True)
-                for line in parse_stream(result.iter_lines())
-            ), True
+                (
+                    self._interpret_response_line(line, result.status_code, result.headers, stream=True)
+                    for line in parse_stream(result.iter_lines())
+                ),
+                True,
+            )
         else:
             return (
                 self._interpret_response_line(
-                    result.content,  # let the caller to decode the msg
+                    result.content,  # let the caller decode the msg
                     result.status_code,
                     result.headers,
                     stream=False,
@@ -79,26 +99,39 @@ class GeneralAPIRequestor(APIRequestor):
 
     async def _interpret_async_response(
         self, result: aiohttp.ClientResponse, stream: bool
-    ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
+    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
+        """
+        Interpret an asynchronous response.
+
+        Args:
+            result (aiohttp.ClientResponse): The response object.
+            stream (bool): Whether the response is a stream.
+
+        Returns:
+            Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: The response content and a boolean indicating whether it is a stream.
+        """
         content_type = result.headers.get("Content-Type", "")
         if stream and (
             "text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
         ):
-            # the `Content-Type` of ollama stream resp is "application/x-ndjson"
             return (
-                self._interpret_response_line(line, result.status, result.headers, stream=True)
-                async for line in result.content
-            ), True
+                (
+                    self._interpret_response_line(line, result.status, result.headers, stream=True)
+                    async for line in result.content
+                ),
+                True,
+            )
         else:
             try:
-                await result.read()
+                response_content = await result.read()
             except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
                 raise TimeoutError("Request timed out") from e
             except aiohttp.ClientError as exp:
-                logger.warning(f"response: {result.content}, exp: {exp}")
+                logger.warning(f"response: {result}, exp: {exp}")
+                response_content = b""
             return (
                 self._interpret_response_line(
-                    await result.read(),  # let the caller to decode the msg
+                    response_content,  # let the caller decode the msg
                     result.status,
                     result.headers,
                     stream=False,
diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py
index 454f0e3ee..39cbe291e 100644
--- a/metagpt/provider/ollama_api.py
+++ b/metagpt/provider/ollama_api.py
@@ -3,12 +3,13 @@
 # @Desc   : self-host open llm model with ollama which isn't openai-api-compatible
 
 import json
+from typing import AsyncGenerator, Tuple
 
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
-from metagpt.provider.general_api_requestor import GeneralAPIRequestor
+from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.utils.cost_manager import TokenCostManager
 
@@ -34,65 +35,68 @@ class OllamaLLM(BaseLLM):
         self.pricing_plan = self.model
 
     def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
-        kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
-        return kwargs
+        return {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
 
     def get_choice_text(self, resp: dict) -> str:
         """get the resp content from llm response"""
         assist_msg = resp.get("message", {})
-        assert assist_msg.get("role", None) == "assistant"
-        return assist_msg.get("content")
+        if assist_msg.get("role", None) == "assistant":  # response from /chat
+            return assist_msg.get("content")
+        else:  # response from /generate (llava)
+            return resp["response"]
 
     def get_usage(self, resp: dict) -> dict:
         return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
 
-    def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
-        chunk = chunk.decode(encoding)
-        return json.loads(chunk)
+    def _decode_and_load(self, openai_resp: OpenAIResponse, encoding: str = "utf-8") -> dict:
+        return json.loads(openai_resp.data.decode(encoding))
 
     async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
-        headers = (
-            None
-            if not self.config.api_key or self.config.api_key == "sk-"
-            else {
-                "Authorization": f"Bearer {self.config.api_key}",
-            }
-        )
+        params, fixed_suffix_url = self._apply_llava(messages)
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=self.suffix_url,
-            headers=headers,
-            params=self._const_kwargs(messages),
+            url=fixed_suffix_url,
+            headers=self._get_headers(),
+            params=params,
             request_timeout=self.get_timeout(timeout),
         )
-        resp = self._decode_and_load(resp)
-        usage = self.get_usage(resp)
-        self._update_costs(usage)
-        return resp
+        if isinstance(resp, AsyncGenerator):
+            return await self._processing_openai_response_async_generator(resp)
+        elif isinstance(resp, OpenAIResponse):
+            return self._processing_openai_response(resp)
+        else:
+            raise NotImplementedError
 
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
     async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
-        headers = (
-            None
-            if not self.config.api_key or self.config.api_key == "sk-"
-            else {
-                "Authorization": f"Bearer {self.config.api_key}",
-            }
-        )
+        params, fixed_suffix_url = self._apply_llava(messages, stream=True)
         stream_resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=self.suffix_url,
-            headers=headers,
+            url=fixed_suffix_url,
+            headers=self._get_headers(),
             stream=True,
-            params=self._const_kwargs(messages, stream=True),
+            params=params,
             request_timeout=self.get_timeout(timeout),
         )
+        if isinstance(stream_resp, AsyncGenerator):
+            return await self._processing_openai_response_async_generator(stream_resp)
+        elif isinstance(stream_resp, OpenAIResponse):
+            return self._processing_openai_response(stream_resp)
+        else:
+            raise NotImplementedError
 
+    def _processing_openai_response(self, openai_resp: OpenAIResponse) -> dict:
+        resp = self._decode_and_load(openai_resp)
+        usage = self.get_usage(resp)
+        self._update_costs(usage)
+        return resp
+
+    async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]) -> str:
         collected_content = []
         usage = {}
-        async for raw_chunk in stream_resp:
+        async for raw_chunk in ag_openai_resp:
             chunk = self._decode_and_load(raw_chunk)
 
             if not chunk.get("done", False):
@@ -107,3 +111,30 @@ class OllamaLLM(BaseLLM):
         self._update_costs(usage)
         full_content = "".join(collected_content)
         return full_content
+
+    def _get_headers(self):
+        return (
+            None
+            if not self.config.api_key or self.config.api_key == "sk-"
+            else {"Authorization": f"Bearer {self.config.api_key}"}
+        )
+
+    def _apply_llava(self, messages: list[dict], stream: bool = False) -> Tuple[dict, str]:
+        llava = False
+        if isinstance(messages[0]["content"], str):
+            return self._const_kwargs(messages, stream), self.suffix_url
+
+        if any(len(msg["content"]) >= 2 for msg in messages):
+            assert all(len(msg["content"]) >= 2 for msg in messages), "all messages should use the same content format"
+            llava = True
+        if not llava:
+            return self._const_kwargs(messages, stream), self.suffix_url
+
+        assert len(messages) <= 1, "batched messages are not supported for llava image calls"
+        contents = messages[0]["content"]
+        return {
+            "model": self.model,
+            "prompt": contents[0]["text"],
+            "images": [i["image_url"]["url"][23:] for i in contents[1:]],
+            "stream": stream,
+        }, "/generate"
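
Note for reviewers: below is a minimal sketch of how a caller consumes the `OpenAIResponse`-based return values that `GeneralAPIRequestor` now yields. The `base_url`, endpoint, model name, and payload are illustrative assumptions, not values taken from this patch.

```python
import asyncio
import json
from typing import AsyncGenerator

from metagpt.provider.general_api_requestor import GeneralAPIRequestor


async def demo() -> None:
    # Assumed local Ollama-style server; adjust base_url/model for your setup.
    requester = GeneralAPIRequestor(base_url="http://127.0.0.1:11434/api")
    result, _, _ = await requester.arequest(
        method="post",
        url="/chat",
        stream=True,
        params={"model": "llama2", "messages": [{"role": "user", "content": "hi"}], "stream": True},
        request_timeout=120,
    )
    if isinstance(result, AsyncGenerator):
        # Streaming: each item is an OpenAIResponse whose .data holds one raw ndjson line.
        async for openai_resp in result:
            chunk = json.loads(openai_resp.data.decode("utf-8"))
            print(chunk.get("message", {}).get("content", ""), end="")
    else:
        # Non-streaming: a single OpenAIResponse wrapping the entire body.
        print(json.loads(result.data.decode("utf-8")))


asyncio.run(demo())
```

This mirrors the `isinstance` dispatch added to `OllamaLLM._achat_completion`: callers no longer decode raw bytes directly but go through `OpenAIResponse.data`.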
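The `_apply_llava` routing is easier to follow with concrete inputs. The shapes below are a best-effort reconstruction from the code above; the base64 payload and prompt are invented for illustration.

```python
# Plain chat: content is a str, so the request keeps the original "/chat"
# path (self.suffix_url) and the _const_kwargs() payload.
chat_messages = [{"role": "user", "content": "Why is the sky blue?"}]

# llava-style multimodal: content is a list with one text part followed by
# one or more image parts; _apply_llava() rewrites it for Ollama's /generate.
llava_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {
                "type": "image_url",
                # The [23:] slice strips a data-URL prefix such as
                # "data:image/jpeg;base64," (exactly 23 characters),
                # leaving the bare base64 payload that Ollama expects.
                "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRg=="},
            },
        ],
    }
]
# _apply_llava(llava_messages) would return roughly:
#   ({"model": "<model>", "prompt": "What is in this image?",
#     "images": ["/9j/4AAQSkZJRg=="], "stream": False}, "/generate")
```

Keeping plain-string content on the untouched `/chat` path means only multimodal calls pay for the `/generate` rewrite. Note that the 23-character slice assumes a JPEG-style prefix; `data:image/png;base64,` is 22 characters, so non-JPEG data URLs would need a more robust strip (for example, splitting on the first comma).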