diff --git a/metagpt/llm.py b/metagpt/llm.py index dce33b9db..2ad40cb1c 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -4,23 +4,20 @@ @Time : 2023/5/11 14:45 @Author : alexanderwu @File : llm.py -@Modified By: mashenquan, 2023-12-4. Upgrade openai to 1.x """ from metagpt.config import CONFIG from metagpt.provider.anthropic_api import Claude2 as Claude -from metagpt.provider.human_provider import HumanProvider from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI from metagpt.provider.spark_api import SparkAPI -# openai v1.x removed the 'api_requestor', making interfaces built on it no longer functional. -# More: https://github.com/openai/openai-python/discussions/742 -# from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.human_provider import HumanProvider _ = HumanProvider() # Avoid pre-commit error def LLM() -> "BaseGPTAPI": - """initialize different LLM instance according to the key field existence""" + """ initialize different LLM instance according to the key field existence""" # TODO a little trick, can use registry to initialize LLM instance further if CONFIG.openai_api_key: llm = OpenAIGPTAPI() @@ -28,8 +25,8 @@ def LLM() -> "BaseGPTAPI": llm = Claude() elif CONFIG.spark_api_key: llm = SparkAPI() - # elif CONFIG.zhipuai_api_key: # openai v1.x removed the 'api_requestor' - # llm = ZhiPuAIGPTAPI() + elif CONFIG.zhipuai_api_key: + llm = ZhiPuAIGPTAPI() else: raise RuntimeError("You should config a LLM configuration first") diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py new file mode 100644 index 000000000..da16e942d --- /dev/null +++ b/metagpt/provider/general_api_base.py @@ -0,0 +1,718 @@ +import asyncio +import json +import os +import platform +import re +import sys +import threading +import time +from contextlib import asynccontextmanager +from enum import Enum +from typing import ( + AsyncGenerator, + AsyncIterator, + Callable, + Dict, + Iterator, + Optional, + Tuple, + Union, + overload, +) +from urllib.parse import urlencode, urlsplit, urlunsplit + +import aiohttp +import requests + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +import logging + +import openai +from openai import version + +logger = logging.getLogger("openai") + +TIMEOUT_SECS = 600 +MAX_SESSION_LIFETIME_SECS = 180 +MAX_CONNECTION_RETRIES = 2 + +# Has one attribute per thread, 'session'. +_thread_context = threading.local() + +OPENAI_LOG = os.environ.get("OPENAI_LOG") +OPENAI_LOG = "debug" + + +class ApiType(Enum): + AZURE = 1 + OPEN_AI = 2 + AZURE_AD = 3 + + @staticmethod + def from_str(label): + if label.lower() == "azure": + return ApiType.AZURE + elif label.lower() in ("azure_ad", "azuread"): + return ApiType.AZURE_AD + elif label.lower() in ("open_ai", "openai"): + return ApiType.OPEN_AI + else: + raise openai.OpenAIError( + "The API type provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'" + ) + + +api_key_to_header = ( + lambda api, key: {"Authorization": f"Bearer {key}"} + if api in (ApiType.OPEN_AI, ApiType.AZURE_AD) + else {"api-key": f"{key}"} +) + + +def _console_log_level(): + if OPENAI_LOG in ["debug", "info"]: + return OPENAI_LOG + else: + return None + + +def log_debug(message, **params): + msg = logfmt(dict(message=message, **params)) + if _console_log_level() == "debug": + print(msg, file=sys.stderr) + logger.debug(msg) + + +def log_info(message, **params): + msg = logfmt(dict(message=message, **params)) + if _console_log_level() in ["debug", "info"]: + print(msg, file=sys.stderr) + logger.info(msg) + + +def log_warn(message, **params): + msg = logfmt(dict(message=message, **params)) + print(msg, file=sys.stderr) + logger.warn(msg) + + +def logfmt(props): + def fmt(key, val): + # Handle case where val is a bytes or bytesarray + if hasattr(val, "decode"): + val = val.decode("utf-8") + # Check if val is already a string to avoid re-encoding into ascii. + if not isinstance(val, str): + val = str(val) + if re.search(r"\s", val): + val = repr(val) + # key should already be a string + if re.search(r"\s", key): + key = repr(key) + return "{key}={val}".format(key=key, val=val) + + return " ".join([fmt(key, val) for key, val in sorted(props.items())]) + + +class OpenAIResponse: + def __init__(self, data, headers): + self._headers = headers + self.data = data + + @property + def request_id(self) -> Optional[str]: + return self._headers.get("request-id") + + @property + def retry_after(self) -> Optional[int]: + try: + return int(self._headers.get("retry-after")) + except TypeError: + return None + + @property + def operation_location(self) -> Optional[str]: + return self._headers.get("operation-location") + + @property + def organization(self) -> Optional[str]: + return self._headers.get("OpenAI-Organization") + + @property + def response_ms(self) -> Optional[int]: + h = self._headers.get("Openai-Processing-Ms") + return None if h is None else round(float(h)) + + +def _build_api_url(url, query): + scheme, netloc, path, base_query, fragment = urlsplit(url) + + if base_query: + query = "%s&%s" % (base_query, query) + + return urlunsplit((scheme, netloc, path, query, fragment)) + + +def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]: + """Returns a value suitable for the 'proxies' argument to 'requests.request.""" + if proxy is None: + return None + elif isinstance(proxy, str): + return {"http": proxy, "https": proxy} + elif isinstance(proxy, dict): + return proxy.copy() + else: + raise ValueError( + "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys." + ) + + +def _aiohttp_proxies_arg(proxy) -> Optional[str]: + """Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request.""" + if proxy is None: + return None + elif isinstance(proxy, str): + return proxy + elif isinstance(proxy, dict): + return proxy["https"] if "https" in proxy else proxy["http"] + else: + raise ValueError( + "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys." + ) + + +def _make_session() -> requests.Session: + s = requests.Session() + s.mount( + "https://", + requests.adapters.HTTPAdapter(max_retries=MAX_CONNECTION_RETRIES), + ) + return s + + +def parse_stream_helper(line: bytes) -> Optional[str]: + if line: + if line.strip() == b"data: [DONE]": + # return here will cause GeneratorExit exception in urllib3 + # and it will close http connection with TCP Reset + return None + if line.startswith(b"data: "): + line = line[len(b"data: ") :] + return line.decode("utf-8") + else: + return None + return None + + +def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]: + for line in rbody: + _line = parse_stream_helper(line) + if _line is not None: + yield _line + + +async def parse_stream_async(rbody: aiohttp.StreamReader): + async for line in rbody: + _line = parse_stream_helper(line) + if _line is not None: + yield _line + + +class APIRequestor: + def __init__( + self, + key=None, + base_url=None, + api_type=None, + api_version=None, + organization=None, + ): + self.base_url = base_url or openai.base_url + self.api_key = key or openai.api_key + self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str("openai") + self.api_version = api_version or openai.api_version + self.organization = organization or openai.organization + + def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]): + if not predicate(response): + return + error_data = response.data["error"] + message = error_data.get("message", "Operation failed") + code = error_data.get("code") + raise openai.APIError(message=message, body=dict(code=code)) + + def _poll( + self, method, url, until, failed, params=None, headers=None, interval=None, delay=None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + time.sleep(delay) + + response, b, api_key = self.request(method, url, params, headers) + self._check_polling_response(response, failed) + start_time = time.time() + while not until(response): + if time.time() - start_time > TIMEOUT_SECS: + raise openai.APITimeoutError("Operation polling timed out.") + + time.sleep(interval or response.retry_after or 10) + response, b, api_key = self.request(method, url, params, headers) + self._check_polling_response(response, failed) + + response.data = response.data["result"] + return response, b, api_key + + async def _apoll( + self, method, url, until, failed, params=None, headers=None, interval=None, delay=None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + await asyncio.sleep(delay) + + response, b, api_key = await self.arequest(method, url, params, headers) + self._check_polling_response(response, failed) + start_time = time.time() + while not until(response): + if time.time() - start_time > TIMEOUT_SECS: + raise openai.APITimeoutError("Operation polling timed out.") + + await asyncio.sleep(interval or response.retry_after or 10) + response, b, api_key = await self.arequest(method, url, params, headers) + self._check_polling_response(response, failed) + + response.data = response.data["result"] + return response, b, api_key + + @overload + def request( + self, + method, + url, + params, + headers, + files, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + *, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: Literal[False] = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[OpenAIResponse, bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: bool = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]: + pass + + def request( + self, + method, + url, + params=None, + headers=None, + files=None, + stream: bool = False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]: + result = self.request_raw( + method.lower(), + url, + params=params, + supplied_headers=headers, + files=files, + stream=stream, + request_id=request_id, + request_timeout=request_timeout, + ) + resp, got_stream = self._interpret_response(result, stream) + return resp, got_stream, self.api_key + + @overload + async def arequest( + self, + method, + url, + params, + headers, + files, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + headers=..., + files=..., + *, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: Literal[False] = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[OpenAIResponse, bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: bool = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: + pass + + async def arequest( + self, + method, + url, + params=None, + headers=None, + files=None, + stream: bool = False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: + ctx = aiohttp_session() + session = await ctx.__aenter__() + try: + result = await self.arequest_raw( + method.lower(), + url, + session, + params=params, + supplied_headers=headers, + files=files, + request_id=request_id, + request_timeout=request_timeout, + ) + resp, got_stream = await self._interpret_async_response(result, stream) + except Exception: + await ctx.__aexit__(None, None, None) + raise + if got_stream: + + async def wrap_resp(): + assert isinstance(resp, AsyncGenerator) + try: + async for r in resp: + yield r + finally: + await ctx.__aexit__(None, None, None) + + return wrap_resp(), got_stream, self.api_key + else: + await ctx.__aexit__(None, None, None) + return resp, got_stream, self.api_key + + def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): + try: + error_data = resp["error"] + except (KeyError, TypeError): + raise openai.APIError( + "Invalid response object from API: %r (HTTP response code " "was %d)" % (rbody, rcode) + ) + + if "internal_message" in error_data: + error_data["message"] += "\n\n" + error_data["internal_message"] + + log_info( + "OpenAI API error received", + error_code=error_data.get("code"), + error_type=error_data.get("type"), + error_message=error_data.get("message"), + error_param=error_data.get("param"), + stream_error=stream_error, + ) + + # Rate limits were previously coded as 400's with code 'rate_limit' + if rcode == 429: + return openai.RateLimitError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody) + elif rcode in [400, 404, 415]: + return openai.BadRequestError( + message=f'{error_data.get("message")}, {error_data.get("param")}, {error_data.get("code")} {rbody} {rcode} {resp} {rheaders}', + body=rbody, + ) + elif rcode == 401: + return openai.AuthenticationError( + f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody + ) + elif rcode == 403: + return openai.PermissionDeniedError( + f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody + ) + elif rcode == 409: + return openai.ConflictError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody) + elif stream_error: + # TODO: we will soon attach status codes to stream errors + parts = [error_data.get("message"), "(Error occurred while streaming.)"] + message = " ".join([p for p in parts if p is not None]) + return openai.APIError(f"{message} {rbody} {rcode} {resp} {rheaders}", body=rbody) + else: + return openai.APIError( + f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", + body=rbody, + ) + + def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]: + user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,) + + uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node") + ua = { + "bindings_version": version.VERSION, + "httplib": "requests", + "lang": "python", + "lang_version": platform.python_version(), + "platform": platform.platform(), + "publisher": "openai", + "uname": uname_without_node, + } + + headers = { + "X-OpenAI-Client-User-Agent": json.dumps(ua), + "User-Agent": user_agent, + } + + headers.update(api_key_to_header(self.api_type, self.api_key)) + + if self.organization: + headers["OpenAI-Organization"] = self.organization + + if self.api_version is not None and self.api_type == ApiType.OPEN_AI: + headers["OpenAI-Version"] = self.api_version + if request_id is not None: + headers["X-Request-Id"] = request_id + headers.update(extra) + + return headers + + def _validate_headers(self, supplied_headers: Optional[Dict[str, str]]) -> Dict[str, str]: + headers: Dict[str, str] = {} + if supplied_headers is None: + return headers + + if not isinstance(supplied_headers, dict): + raise TypeError("Headers must be a dictionary") + + for k, v in supplied_headers.items(): + if not isinstance(k, str): + raise TypeError("Header keys must be strings") + if not isinstance(v, str): + raise TypeError("Header values must be strings") + headers[k] = v + + # NOTE: It is possible to do more validation of the headers, but a request could always + # be made to the API manually with invalid headers, so we need to handle them server side. + + return headers + + def _prepare_request_raw( + self, + url, + supplied_headers, + method, + params, + files, + request_id: Optional[str], + ) -> Tuple[str, Dict[str, str], Optional[bytes]]: + abs_url = "%s%s" % (self.base_url, url) + headers = self._validate_headers(supplied_headers) + + data = None + if method == "get" or method == "delete": + if params: + encoded_params = urlencode([(k, v) for k, v in params.items() if v is not None]) + abs_url = _build_api_url(abs_url, encoded_params) + elif method in {"post", "put"}: + if params and files: + data = params + if params and not files: + data = json.dumps(params).encode() + headers["Content-Type"] = "application/json" + else: + raise openai.APIConnectionError( + "Unrecognized HTTP method %r. This may indicate a bug in the " + "OpenAI bindings. Please contact us through our help center at help.openai.com for " + "assistance." % (method,) + ) + + headers = self.request_headers(method, headers, request_id) + + log_debug("Request to OpenAI API", method=method, path=abs_url) + log_debug("Post details", data=data, api_version=self.api_version) + + return abs_url, headers, data + + def request_raw( + self, + method, + url, + *, + params=None, + supplied_headers: Optional[Dict[str, str]] = None, + files=None, + stream: bool = False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> requests.Response: + abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id) + + if not hasattr(_thread_context, "session"): + _thread_context.session = _make_session() + _thread_context.session_create_time = time.time() + elif time.time() - getattr(_thread_context, "session_create_time", 0) >= MAX_SESSION_LIFETIME_SECS: + _thread_context.session.close() + _thread_context.session = _make_session() + _thread_context.session_create_time = time.time() + try: + result = _thread_context.session.request( + method, + abs_url, + headers=headers, + data=data, + files=files, + stream=stream, + timeout=request_timeout if request_timeout else TIMEOUT_SECS, + proxies=_thread_context.session.proxies, + ) + except requests.exceptions.Timeout as e: + raise openai.APITimeoutError("Request timed out: {}".format(e)) from e + except requests.exceptions.RequestException as e: + raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e + log_debug( + "OpenAI API response", + path=abs_url, + response_code=result.status_code, + processing_ms=result.headers.get("OpenAI-Processing-Ms"), + request_id=result.headers.get("X-Request-Id"), + ) + return result + + async def arequest_raw( + self, + method, + url, + session, + *, + params=None, + supplied_headers: Optional[Dict[str, str]] = None, + files=None, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> aiohttp.ClientResponse: + abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id) + + if isinstance(request_timeout, tuple): + timeout = aiohttp.ClientTimeout( + connect=request_timeout[0], + total=request_timeout[1], + ) + else: + timeout = aiohttp.ClientTimeout(total=request_timeout if request_timeout else TIMEOUT_SECS) + + if files: + # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here. + # For now we use the private `requests` method that is known to have worked so far. + data, content_type = requests.models.RequestEncodingMixin._encode_files(files, data) # type: ignore + headers["Content-Type"] = content_type + request_kwargs = { + "method": method, + "url": abs_url, + "headers": headers, + "data": data, + "timeout": timeout, + } + try: + result = await session.request(**request_kwargs) + log_info( + "OpenAI API response", + path=abs_url, + response_code=result.status, + processing_ms=result.headers.get("OpenAI-Processing-Ms"), + request_id=result.headers.get("X-Request-Id"), + ) + return result + except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e: + raise openai.APITimeoutError("Request timed out") from e + except aiohttp.ClientError as e: + raise openai.APIConnectionError("Error communicating with OpenAI") from e + + def _interpret_response( + self, result: requests.Response, stream: bool + ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: + """Returns the response(s) and a bool indicating whether it is a stream.""" + + async def _interpret_async_response( + self, result: aiohttp.ClientResponse, stream: bool + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: + """Returns the response(s) and a bool indicating whether it is a stream.""" + + def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> OpenAIResponse: + ... + + +@asynccontextmanager +async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]: + async with aiohttp.ClientSession() as session: + yield session diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py index 875122e8b..f8321cc6b 100644 --- a/metagpt/provider/general_api_requestor.py +++ b/metagpt/provider/general_api_requestor.py @@ -6,16 +6,16 @@ import asyncio from typing import AsyncGenerator, Tuple, Union import aiohttp -from openai.api_requestor import APIRequestor from metagpt.logs import logger +from metagpt.provider.general_api_base import APIRequestor class GeneralAPIRequestor(APIRequestor): """ usage - # full_url = "{api_base}{url}" - requester = GeneralAPIRequestor(api_base=api_base) + # full_url = "{base_url}{url}" + requester = GeneralAPIRequestor(base_url=base_url) result, _, api_key = await requester.arequest( method=method, url=url, diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 45fc763be..2d4b1583a 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -179,7 +179,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): "n": 1, "stop": None, "temperature": 0.3, - "timeout": 3, } if configs: kwargs.update(configs) diff --git a/metagpt/provider/zhipuai/async_sse_client.py b/metagpt/provider/zhipuai/async_sse_client.py index d7168202a..b819fdc63 100644 --- a/metagpt/provider/zhipuai/async_sse_client.py +++ b/metagpt/provider/zhipuai/async_sse_client.py @@ -3,10 +3,11 @@ # @Desc : async_sse_client to make keep the use of Event to access response # refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py` -from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient +from zhipuai.utils.sse_client import SSEClient, Event, _FIELD_SEPARATOR class AsyncSSEClient(SSEClient): + async def _aread(self): data = b"" async for chunk in self._event_source: @@ -36,7 +37,9 @@ class AsyncSSEClient(SSEClient): # Ignore unknown fields. if field not in event.__dict__: - self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field) + self._logger.debug( + "Saw invalid field %s while parsing " "Server Side Event", field + ) continue if len(data) > 1: diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index 23dd7229d..19eb52530 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -41,8 +41,8 @@ class ZhiPuModelAPI(ModelAPI): # TODO to make the async request to be more generic for models in http mode. assert method in ["post", "get"] - api_base, url = cls.split_zhipu_api_url(invoke_type, kwargs) - requester = GeneralAPIRequestor(api_base=api_base) + base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs) + requester = GeneralAPIRequestor(base_url=base_url) result, _, api_key = await requester.arequest( method=method, url=url, diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index edd9084e3..3161c0e88 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -2,12 +2,8 @@ # -*- coding: utf-8 -*- # @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk -import json from enum import Enum - -import openai -import zhipuai -from requests import ConnectionError +import json from tenacity import ( after_log, retry, @@ -15,6 +11,10 @@ from tenacity import ( stop_after_attempt, wait_fixed, ) +from requests import ConnectionError + +import openai +import zhipuai from metagpt.config import CONFIG from metagpt.logs import logger @@ -50,11 +50,15 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used. def _const_kwargs(self, messages: list[dict]) -> dict: - kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3} + kwargs = { + "model": self.model, + "prompt": messages, + "temperature": 0.3 + } return kwargs def _update_costs(self, usage: dict): - """update each request's token cost""" + """ update each request's token cost """ if CONFIG.calc_usage: try: prompt_tokens = int(usage.get("prompt_tokens", 0)) @@ -64,7 +68,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): logger.error("zhipuai updats costs failed!", e) def get_choice_text(self, resp: dict) -> str: - """get the first text of choice from llm response""" + """ get the first text of choice from llm response """ assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] assert assist_msg["role"] == "assistant" return assist_msg.get("content") @@ -125,10 +129,10 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): wait=wait_fixed(1), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), - retry_error_callback=log_and_reraise, + retry_error_callback=log_and_reraise ) async def acompletion_text(self, messages: list[dict], stream=False) -> str: - """response in async with stream or non-stream mode""" + """ response in async with stream or non-stream mode """ if stream: return await self._achat_completion_stream(messages) resp = await self._achat_completion(messages)