feat: upgrade openai 1.x

This commit is contained in:
莘权 马 2023-12-06 10:10:30 +08:00
parent 7022c87008
commit 526d56cf54
7 changed files with 747 additions and 26 deletions

View file

@ -4,23 +4,20 @@
@Time : 2023/5/11 14:45
@Author : alexanderwu
@File : llm.py
@Modified By: mashenquan, 2023-12-4. Upgrade openai to 1.x
"""
from metagpt.config import CONFIG
from metagpt.provider.anthropic_api import Claude2 as Claude
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.spark_api import SparkAPI
# openai v1.x removed the 'api_requestor', making interfaces built on it no longer functional.
# More: https://github.com/openai/openai-python/discussions/742
# from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.human_provider import HumanProvider
_ = HumanProvider() # Avoid pre-commit error
def LLM() -> "BaseGPTAPI":
"""initialize different LLM instance according to the key field existence"""
""" initialize different LLM instance according to the key field existence"""
# TODO a little trick, can use registry to initialize LLM instance further
if CONFIG.openai_api_key:
llm = OpenAIGPTAPI()
@ -28,8 +25,8 @@ def LLM() -> "BaseGPTAPI":
llm = Claude()
elif CONFIG.spark_api_key:
llm = SparkAPI()
# elif CONFIG.zhipuai_api_key: # openai v1.x removed the 'api_requestor'
# llm = ZhiPuAIGPTAPI()
elif CONFIG.zhipuai_api_key:
llm = ZhiPuAIGPTAPI()
else:
raise RuntimeError("You should config a LLM configuration first")

View file

@ -0,0 +1,718 @@
import asyncio
import json
import os
import platform
import re
import sys
import threading
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import (
AsyncGenerator,
AsyncIterator,
Callable,
Dict,
Iterator,
Optional,
Tuple,
Union,
overload,
)
from urllib.parse import urlencode, urlsplit, urlunsplit
import aiohttp
import requests
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
import logging
import openai
from openai import version
# Logger used by the log_* helpers in this module; it reuses the "openai" logger name.
logger = logging.getLogger("openai")

# Default per-request timeout in seconds; polling loops also give up after this long.
TIMEOUT_SECS = 600
# A thread's `requests` session is closed and rebuilt once it is at least this old (seconds).
MAX_SESSION_LIFETIME_SECS = 180
# `max_retries` value given to the HTTPS adapter of every new session.
MAX_CONNECTION_RETRIES = 2

# Per-thread storage: holds the `session` and its `session_create_time`.
_thread_context = threading.local()

# Console log level ("debug" or "info" enable echoing to stderr).
# The environment variable is honoured when set; "debug" remains the default.
# Previously the value read from the environment was unconditionally
# overwritten with "debug", which made the variable impossible to use.
OPENAI_LOG = os.environ.get("OPENAI_LOG", "debug")
class ApiType(Enum):
    """API types recognised by this module."""

    AZURE = 1
    OPEN_AI = 2
    AZURE_AD = 3

    @staticmethod
    def from_str(label):
        """Map an API type name to its ``ApiType`` member, ignoring case.

        Accepted spellings are ``azure``, ``azure_ad`` / ``azuread`` and
        ``open_ai`` / ``openai``.

        Raises:
            openai.OpenAIError: if ``label`` is none of the accepted spellings.
        """
        # Lower-case once instead of once per comparison.
        normalized = label.lower()
        if normalized == "azure":
            return ApiType.AZURE
        if normalized in ("azure_ad", "azuread"):
            return ApiType.AZURE_AD
        if normalized in ("open_ai", "openai"):
            return ApiType.OPEN_AI
        raise openai.OpenAIError(
            "The API type provided is invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'"
        )
def api_key_to_header(api, key):
    """Return the authentication header dict for ``key`` under API type ``api``.

    ``ApiType.OPEN_AI`` and ``ApiType.AZURE_AD`` produce a bearer
    ``Authorization`` header; every other value produces an ``api-key`` header.
    """
    if api in (ApiType.OPEN_AI, ApiType.AZURE_AD):
        return {"Authorization": f"Bearer {key}"}
    return {"api-key": f"{key}"}
def _console_log_level():
    """Return ``OPENAI_LOG`` when it is "debug" or "info"; otherwise ``None``."""
    return OPENAI_LOG if OPENAI_LOG in ["debug", "info"] else None
def log_debug(message, **params):
    """Log ``message`` plus ``params`` as one logfmt line at debug level.

    The line always goes to ``logger.debug``; it is also printed to stderr
    when the console log level is "debug".
    """
    line = logfmt({"message": message, **params})
    echo_to_console = _console_log_level() == "debug"
    if echo_to_console:
        print(line, file=sys.stderr)
    logger.debug(line)
def log_info(message, **params):
    """Log ``message`` plus ``params`` as one logfmt line at info level.

    The line always goes to ``logger.info``; it is also printed to stderr
    when the console log level is "debug" or "info".
    """
    line = logfmt({"message": message, **params})
    echo_to_console = _console_log_level() in ["debug", "info"]
    if echo_to_console:
        print(line, file=sys.stderr)
    logger.info(line)
def log_warn(message, **params):
    """Log ``message`` plus ``params`` as one logfmt line at warning level.

    Unlike ``log_debug`` / ``log_info``, the line is printed to stderr
    unconditionally before being handed to the logger.
    """
    msg = logfmt(dict(message=message, **params))
    print(msg, file=sys.stderr)
    # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
    logger.warning(msg)
def logfmt(props):
    """Render a mapping as space-separated ``key=value`` pairs, sorted by item.

    Values exposing ``decode`` are decoded as UTF-8, other non-string values
    are passed through ``str``. Any key or value containing whitespace is
    replaced by its ``repr`` so that the pair stays a single token.
    """

    def render(name, value):
        # bytes / bytearray (anything with `decode`) become text first.
        text = value.decode("utf-8") if hasattr(value, "decode") else value
        # Only stringify non-strings, so existing text is left untouched.
        if not isinstance(text, str):
            text = str(text)
        if re.search(r"\s", text):
            text = repr(text)
        # The key is expected to be a string already.
        if re.search(r"\s", name):
            name = repr(name)
        return "{key}={val}".format(key=name, val=text)

    ordered = sorted(props.items())
    return " ".join(render(name, value) for name, value in ordered)
class OpenAIResponse:
    """A response payload together with the HTTP headers it arrived with."""

    def __init__(self, data, headers):
        """Store the payload as ``data`` and keep ``headers`` for the properties below."""
        self._headers = headers
        self.data = data

    @property
    def request_id(self) -> Optional[str]:
        """Value of the ``request-id`` header, or ``None`` when absent."""
        return self._headers.get("request-id")

    @property
    def retry_after(self) -> Optional[int]:
        """The ``retry-after`` header as an int, or ``None``.

        ``None`` is returned both when the header is missing and when its
        value is not an integer (the header may also carry an HTTP-date), so
        callers can fall back to their own default delay instead of crashing.
        """
        try:
            return int(self._headers.get("retry-after"))
        except (TypeError, ValueError):
            # TypeError: header missing (int(None)); ValueError: non-integer text.
            return None

    @property
    def operation_location(self) -> Optional[str]:
        """Value of the ``operation-location`` header, or ``None`` when absent."""
        return self._headers.get("operation-location")

    @property
    def organization(self) -> Optional[str]:
        """Value of the ``OpenAI-Organization`` header, or ``None`` when absent."""
        return self._headers.get("OpenAI-Organization")

    @property
    def response_ms(self) -> Optional[int]:
        """The ``Openai-Processing-Ms`` header rounded to an int, or ``None`` when absent."""
        h = self._headers.get("Openai-Processing-Ms")
        return None if h is None else round(float(h))
def _build_api_url(url, query):
    """Return ``url`` with ``query`` as its query string.

    When ``url`` already has a query string, ``query`` is appended to it
    with ``&``; every other URL component is kept as it was.
    """
    parts = urlsplit(url)
    merged_query = "%s&%s" % (parts.query, query) if parts.query else query
    return urlunsplit((parts.scheme, parts.netloc, parts.path, merged_query, parts.fragment))
def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
    """Return a value suitable for the ``proxies`` argument of ``requests.request``.

    ``None`` passes through, a string is used for both "http" and "https",
    and a dict is returned as a shallow copy.

    Raises:
        ValueError: if ``proxy`` is none of the above.
    """
    if proxy is None:
        return None
    if isinstance(proxy, str):
        return {"http": proxy, "https": proxy}
    if isinstance(proxy, dict):
        return proxy.copy()
    raise ValueError(
        "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
    )
def _aiohttp_proxies_arg(proxy) -> Optional[str]:
    """Return a single proxy URL suitable for ``aiohttp.ClientSession.request``.

    ``None`` and strings pass through. For a dict, the "https" entry is
    preferred and "http" is used otherwise (a dict with neither key raises
    ``KeyError``).

    Raises:
        ValueError: if ``proxy`` is not ``None``, a string or a dict.
    """
    if proxy is None:
        return None
    if isinstance(proxy, str):
        return proxy
    if isinstance(proxy, dict):
        if "https" in proxy:
            return proxy["https"]
        return proxy["http"]
    raise ValueError(
        "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
    )
def _make_session() -> requests.Session:
    """Create a ``requests`` session whose HTTPS adapter retries connections.

    Only the ``https://`` prefix is given the adapter configured with
    ``MAX_CONNECTION_RETRIES``.
    """
    session = requests.Session()
    retrying_adapter = requests.adapters.HTTPAdapter(max_retries=MAX_CONNECTION_RETRIES)
    session.mount("https://", retrying_adapter)
    return session
def parse_stream_helper(line: bytes) -> Optional[str]:
    """Extract the payload of one ``data: `` line from a streamed body.

    Returns the UTF-8 decoded text following the ``data: `` prefix. Returns
    ``None`` for empty lines, for the ``data: [DONE]`` terminator and for
    lines that do not start with the prefix.
    """
    prefix = b"data: "
    if not line:
        return None
    if line.strip() == b"data: [DONE]":
        # return here will cause GeneratorExit exception in urllib3
        # and it will close http connection with TCP Reset
        return None
    if not line.startswith(prefix):
        return None
    payload = line[len(prefix) :]
    return payload.decode("utf-8")
def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
    """Lazily yield the payload of every line of ``rbody`` that carries one.

    Lines for which ``parse_stream_helper`` returns ``None`` are skipped.
    """
    payloads = map(parse_stream_helper, rbody)
    yield from filter(lambda payload: payload is not None, payloads)
async def parse_stream_async(rbody: aiohttp.StreamReader):
    """Asynchronously yield the payload of every line of ``rbody`` that carries one.

    Lines for which ``parse_stream_helper`` returns ``None`` are skipped.
    """
    async for raw_line in rbody:
        payload = parse_stream_helper(raw_line)
        if payload is None:
            continue
        yield payload
class APIRequestor:
def __init__(
    self,
    key=None,
    base_url=None,
    api_type=None,
    api_version=None,
    organization=None,
):
    """Capture the connection settings used by later requests.

    Each falsy argument falls back to the attribute of the same purpose on
    the ``openai`` module, except ``api_type``, which falls back to "openai".
    """
    self.api_key = key or openai.api_key
    self.base_url = base_url or openai.base_url
    self.organization = organization or openai.organization
    self.api_version = api_version or openai.api_version
    # A falsy api_type is treated exactly like the string "openai".
    self.api_type = ApiType.from_str(api_type or "openai")
def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
    """Raise ``openai.APIError`` when ``predicate(response)`` is true; otherwise do nothing.

    The error message comes from ``response.data["error"]["message"]``
    (default "Operation failed") and the body carries its ``code``.
    """
    if predicate(response):
        failure = response.data["error"]
        # NOTE(review): openai 1.x `APIError.__init__` appears to also require a
        # `request` argument - confirm this call constructs successfully.
        raise openai.APIError(
            message=failure.get("message", "Operation failed"),
            body=dict(code=failure.get("code")),
        )
def _poll(
    self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
    """Repeat a request until ``until(response)`` is true.

    An optional initial ``delay`` is slept first. Every response is passed
    through ``_check_polling_response`` with ``failed``. Between attempts the
    loop sleeps ``interval``, else the response's ``retry_after``, else 10
    seconds. On success ``response.data`` is replaced by its "result" entry.

    Raises:
        openai.APITimeoutError: once more than ``TIMEOUT_SECS`` have elapsed
            since the first response without ``until`` being satisfied.
    """
    if delay:
        time.sleep(delay)

    response, got_stream, api_key = self.request(method, url, params, headers)
    self._check_polling_response(response, failed)

    started_at = time.time()
    while True:
        if until(response):
            break
        if time.time() - started_at > TIMEOUT_SECS:
            # NOTE(review): openai 1.x `APITimeoutError.__init__` appears to take
            # a request object rather than a message - confirm.
            raise openai.APITimeoutError("Operation polling timed out.")
        pause = interval or response.retry_after or 10
        time.sleep(pause)
        response, got_stream, api_key = self.request(method, url, params, headers)
        self._check_polling_response(response, failed)

    response.data = response.data["result"]
    return response, got_stream, api_key
async def _apoll(
    self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
    """Asynchronously repeat a request until ``until(response)`` is true.

    An optional initial ``delay`` is awaited first. Every response is passed
    through ``_check_polling_response`` with ``failed``. Between attempts the
    loop awaits a sleep of ``interval``, else the response's ``retry_after``,
    else 10 seconds. On success ``response.data`` is replaced by its
    "result" entry.

    Raises:
        openai.APITimeoutError: once more than ``TIMEOUT_SECS`` have elapsed
            since the first response without ``until`` being satisfied.
    """
    if delay:
        await asyncio.sleep(delay)

    response, got_stream, api_key = await self.arequest(method, url, params, headers)
    self._check_polling_response(response, failed)

    started_at = time.time()
    while True:
        if until(response):
            break
        if time.time() - started_at > TIMEOUT_SECS:
            # NOTE(review): openai 1.x `APITimeoutError.__init__` appears to take
            # a request object rather than a message - confirm.
            raise openai.APITimeoutError("Operation polling timed out.")
        pause = interval or response.retry_after or 10
        await asyncio.sleep(pause)
        response, got_stream, api_key = await self.arequest(method, url, params, headers)
        self._check_polling_response(response, failed)

    response.data = response.data["result"]
    return response, got_stream, api_key
# Typing-only overloads: `stream=True` yields an iterator of responses,
# `stream=False` a single response, and a plain bool either of the two.
@overload
def request(
    self,
    method,
    url,
    params,
    headers,
    files,
    stream: Literal[True],
    request_id: Optional[str] = ...,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
    ...


@overload
def request(
    self,
    method,
    url,
    params=...,
    headers=...,
    files=...,
    *,
    stream: Literal[True],
    request_id: Optional[str] = ...,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
    ...


@overload
def request(
    self,
    method,
    url,
    params=...,
    headers=...,
    files=...,
    stream: Literal[False] = ...,
    request_id: Optional[str] = ...,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[OpenAIResponse, bool, str]:
    ...


@overload
def request(
    self,
    method,
    url,
    params=...,
    headers=...,
    files=...,
    stream: bool = ...,
    request_id: Optional[str] = ...,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
    ...


def request(
    self,
    method,
    url,
    params=None,
    headers=None,
    files=None,
    stream: bool = False,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
    """Send a synchronous request and interpret the raw response.

    ``method`` is lower-cased before being handed to ``request_raw``.

    Returns:
        A 3-tuple of the interpreted response (or responses), the flag
        reported by ``_interpret_response`` saying whether it is a stream,
        and ``self.api_key``.
    """
    raw_result = self.request_raw(
        method.lower(),
        url,
        params=params,
        supplied_headers=headers,
        files=files,
        stream=stream,
        request_id=request_id,
        request_timeout=request_timeout,
    )
    interpreted, is_stream = self._interpret_response(raw_result, stream)
    return interpreted, is_stream, self.api_key
# Typing-only overloads: `stream=True` yields an async generator of responses,
# `stream=False` a single response, and a plain bool either of the two.
@overload
async def arequest(
    self,
    method,
    url,
    params,
    headers,
    files,
    stream: Literal[True],
    request_id: Optional[str] = ...,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
    ...


@overload
async def arequest(
    self,
    method,
    url,
    params=...,
    headers=...,
    files=...,
    *,
    stream: Literal[True],
    request_id: Optional[str] = ...,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
    ...


@overload
async def arequest(
    self,
    method,
    url,
    params=...,
    headers=...,
    files=...,
    stream: Literal[False] = ...,
    request_id: Optional[str] = ...,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[OpenAIResponse, bool, str]:
    ...


@overload
async def arequest(
    self,
    method,
    url,
    params=...,
    headers=...,
    files=...,
    stream: bool = ...,
    request_id: Optional[str] = ...,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
    ...


async def arequest(
    self,
    method,
    url,
    params=None,
    headers=None,
    files=None,
    stream: bool = False,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
    """Send an asynchronous request and interpret the raw response.

    The aiohttp session context is entered and exited by hand because, for a
    streamed response, the session has to stay open until the caller has
    finished consuming the returned generator.

    Returns:
        A 3-tuple of the interpreted response (or an async generator of
        responses), the stream flag reported by ``_interpret_async_response``,
        and ``self.api_key``.
    """
    session_ctx = aiohttp_session()
    session = await session_ctx.__aenter__()
    try:
        raw_result = await self.arequest_raw(
            method.lower(),
            url,
            session,
            params=params,
            supplied_headers=headers,
            files=files,
            request_id=request_id,
            request_timeout=request_timeout,
        )
        interpreted, got_stream = await self._interpret_async_response(raw_result, stream)
    except Exception:
        # Do not leave the session open when the request or its parsing failed.
        await session_ctx.__aexit__(None, None, None)
        raise

    if not got_stream:
        # Nothing left to read: the session can be released right away.
        await session_ctx.__aexit__(None, None, None)
        return interpreted, got_stream, self.api_key

    async def relay_then_close():
        assert isinstance(interpreted, AsyncGenerator)
        try:
            async for item in interpreted:
                yield item
        finally:
            # Release the session once the stream is exhausted or abandoned.
            await session_ctx.__aexit__(None, None, None)

    return relay_then_close(), got_stream, self.api_key
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
    """Turn an error payload into the matching ``openai`` exception instance.

    The exception is *returned* for the caller to raise, except when ``resp``
    has no usable "error" entry, in which case ``openai.APIError`` is raised
    here. The exception class is chosen from the HTTP status ``rcode``.

    NOTE(review): the openai 1.x status-error classes appear to require a
    keyword-only ``response`` argument (and ``APIError`` a ``request``) -
    confirm these constructor calls succeed.
    """
    try:
        error_data = resp["error"]
    except (KeyError, TypeError):
        raise openai.APIError(
            "Invalid response object from API: %r (HTTP response code " "was %d)" % (rbody, rcode)
        )

    # Fold the internal message, when present, into the public one.
    if "internal_message" in error_data:
        error_data["message"] += "\n\n" + error_data["internal_message"]

    log_info(
        "OpenAI API error received",
        error_code=error_data.get("code"),
        error_type=error_data.get("type"),
        error_message=error_data.get("message"),
        error_param=error_data.get("param"),
        stream_error=stream_error,
    )

    # Rate limits were previously coded as 400's with code 'rate_limit'
    if rcode == 429:
        return openai.RateLimitError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)

    if rcode in [400, 404, 415]:
        return openai.BadRequestError(
            message=f'{error_data.get("message")}, {error_data.get("param")}, {error_data.get("code")} {rbody} {rcode} {resp} {rheaders}',
            body=rbody,
        )

    if rcode == 401:
        return openai.AuthenticationError(
            f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
        )

    if rcode == 403:
        return openai.PermissionDeniedError(
            f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
        )

    if rcode == 409:
        return openai.ConflictError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)

    if stream_error:
        # TODO: we will soon attach status codes to stream errors
        fragments = [error_data.get("message"), "(Error occurred while streaming.)"]
        message = " ".join(part for part in fragments if part is not None)
        return openai.APIError(f"{message} {rbody} {rcode} {resp} {rheaders}", body=rbody)

    return openai.APIError(
        f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
        body=rbody,
    )
def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
    """Build the full header set for one request.

    Client identification and authentication headers are set first, then the
    optional organization, API version (only for ``ApiType.OPEN_AI``) and
    request id; ``extra`` is applied last and therefore overrides any of
    them. ``method`` is accepted but not used here.
    """
    # Platform description without the host name ("node") field.
    uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node")
    client_info = {
        "bindings_version": version.VERSION,
        "httplib": "requests",
        "lang": "python",
        "lang_version": platform.python_version(),
        "platform": platform.platform(),
        "publisher": "openai",
        "uname": uname_without_node,
    }
    headers = {
        "X-OpenAI-Client-User-Agent": json.dumps(client_info),
        "User-Agent": "OpenAI/v1 PythonBindings/%s" % (version.VERSION,),
        **api_key_to_header(self.api_type, self.api_key),
    }

    if self.organization:
        headers["OpenAI-Organization"] = self.organization

    sends_version = self.api_version is not None and self.api_type == ApiType.OPEN_AI
    if sends_version:
        headers["OpenAI-Version"] = self.api_version

    if request_id is not None:
        headers["X-Request-Id"] = request_id

    headers.update(extra)
    return headers
def _validate_headers(self, supplied_headers: Optional[Dict[str, str]]) -> Dict[str, str]:
    """Return a validated copy of caller-supplied headers.

    ``None`` yields an empty dict. Otherwise the argument must be a dict
    whose keys and values are all strings.

    Raises:
        TypeError: if the argument is not a dict, or a key or value is not a string.
    """
    if supplied_headers is None:
        return {}
    if not isinstance(supplied_headers, dict):
        raise TypeError("Headers must be a dictionary")

    validated: Dict[str, str] = {}
    for name, value in supplied_headers.items():
        if not isinstance(name, str):
            raise TypeError("Header keys must be strings")
        if not isinstance(value, str):
            raise TypeError("Header values must be strings")
        validated[name] = value

    # NOTE: It is possible to do more validation of the headers, but a request could always
    # be made to the API manually with invalid headers, so we need to handle them server side.
    return validated
def _prepare_request_raw(
    self,
    url,
    supplied_headers,
    method,
    params,
    files,
    request_id: Optional[str],
) -> Tuple[str, Dict[str, str], Optional[bytes]]:
    """Compute the absolute URL, headers and body for a request.

    For "get" / "delete", non-``None`` params are URL-encoded into the query
    string. For "post" / "put", params are sent as-is alongside ``files``,
    or JSON-encoded (with a JSON content type) when there are no files.

    Raises:
        openai.APIConnectionError: for any other ``method``.
    """
    abs_url = "%s%s" % (self.base_url, url)
    headers = self._validate_headers(supplied_headers)
    data = None

    if method in ("get", "delete"):
        if params:
            present = [(k, v) for k, v in params.items() if v is not None]
            abs_url = _build_api_url(abs_url, urlencode(present))
    elif method in {"post", "put"}:
        if params:
            if files:
                data = params
            else:
                data = json.dumps(params).encode()
                headers["Content-Type"] = "application/json"
    else:
        # NOTE(review): openai 1.x `APIConnectionError.__init__` appears to accept
        # keyword-only `message` / `request` - confirm this positional call works.
        raise openai.APIConnectionError(
            "Unrecognized HTTP method %r. This may indicate a bug in the "
            "OpenAI bindings. Please contact us through our help center at help.openai.com for "
            "assistance." % (method,)
        )

    headers = self.request_headers(method, headers, request_id)

    log_debug("Request to OpenAI API", method=method, path=abs_url)
    log_debug("Post details", data=data, api_version=self.api_version)

    return abs_url, headers, data
def request_raw(
    self,
    method,
    url,
    *,
    params=None,
    supplied_headers: Optional[Dict[str, str]] = None,
    files=None,
    stream: bool = False,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> requests.Response:
    """Send one synchronous HTTP request through this thread's session.

    The URL, headers and body come from ``_prepare_request_raw``. When
    ``request_timeout`` is falsy, ``TIMEOUT_SECS`` is used instead.

    Returns:
        The raw ``requests`` response, not yet interpreted.

    Raises:
        openai.APITimeoutError: when ``requests`` reports a timeout.
        openai.APIConnectionError: for any other ``requests`` exception.
    """
    abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id)
    # One session per thread: create it on first use, and replace it once it
    # has been alive for MAX_SESSION_LIFETIME_SECS or longer.
    if not hasattr(_thread_context, "session"):
        _thread_context.session = _make_session()
        _thread_context.session_create_time = time.time()
    elif time.time() - getattr(_thread_context, "session_create_time", 0) >= MAX_SESSION_LIFETIME_SECS:
        _thread_context.session.close()
        _thread_context.session = _make_session()
        _thread_context.session_create_time = time.time()
    try:
        result = _thread_context.session.request(
            method,
            abs_url,
            headers=headers,
            data=data,
            files=files,
            stream=stream,
            timeout=request_timeout if request_timeout else TIMEOUT_SECS,
            # The session's own proxy mapping is passed back explicitly.
            proxies=_thread_context.session.proxies,
        )
    # Timeout is handled before the broader RequestException.
    # NOTE(review): openai 1.x `APITimeoutError` / `APIConnectionError` appear to
    # expect a request object / keyword-only arguments - confirm these calls work.
    except requests.exceptions.Timeout as e:
        raise openai.APITimeoutError("Request timed out: {}".format(e)) from e
    except requests.exceptions.RequestException as e:
        raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
    log_debug(
        "OpenAI API response",
        path=abs_url,
        response_code=result.status_code,
        processing_ms=result.headers.get("OpenAI-Processing-Ms"),
        request_id=result.headers.get("X-Request-Id"),
    )
    return result
async def arequest_raw(
    self,
    method,
    url,
    session,
    *,
    params=None,
    supplied_headers: Optional[Dict[str, str]] = None,
    files=None,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> aiohttp.ClientResponse:
    """Send one asynchronous HTTP request through ``session``.

    A tuple ``request_timeout`` is read as ``(connect, total)``; a scalar is
    the total timeout; a falsy value falls back to ``TIMEOUT_SECS``.

    Returns:
        The raw aiohttp response, not yet interpreted.

    Raises:
        openai.APITimeoutError: on an aiohttp server timeout or ``asyncio.TimeoutError``.
        openai.APIConnectionError: on any other ``aiohttp.ClientError``.
    """
    abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id)

    if isinstance(request_timeout, tuple):
        timeout = aiohttp.ClientTimeout(connect=request_timeout[0], total=request_timeout[1])
    else:
        timeout = aiohttp.ClientTimeout(total=request_timeout or TIMEOUT_SECS)

    if files:
        # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
        # For now we use the private `requests` method that is known to have worked so far.
        data, content_type = requests.models.RequestEncodingMixin._encode_files(files, data)  # type: ignore
        headers["Content-Type"] = content_type

    try:
        result = await session.request(
            method=method,
            url=abs_url,
            headers=headers,
            data=data,
            timeout=timeout,
        )
        log_info(
            "OpenAI API response",
            path=abs_url,
            response_code=result.status,
            processing_ms=result.headers.get("OpenAI-Processing-Ms"),
            request_id=result.headers.get("X-Request-Id"),
        )
        return result
    # NOTE(review): openai 1.x `APITimeoutError` / `APIConnectionError` appear to
    # expect a request object / keyword-only arguments - confirm these calls work.
    except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
        raise openai.APITimeoutError("Request timed out") from e
    except aiohttp.ClientError as e:
        raise openai.APIConnectionError("Error communicating with OpenAI") from e
def _interpret_response(
    self, result: requests.Response, stream: bool
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
    """Returns the response(s) and a bool indicating whether it is a stream."""
    # NOTE(review): no statements follow the docstring in this view, so the
    # method returns None, while `request` unpacks two values from its result.
    # Confirm the implementation was not lost.
async def _interpret_async_response(
    self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
    """Returns the response(s) and a bool indicating whether it is a stream."""
    # NOTE(review): no statements follow the docstring in this view, so the
    # coroutine returns None, while `arequest` unpacks two values from its
    # result. Confirm the implementation was not lost.
def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> OpenAIResponse:
    # NOTE(review): placeholder body - as written this returns None rather than
    # the annotated OpenAIResponse. Presumably meant to be provided elsewhere
    # (e.g. by a subclass); confirm.
    ...
@asynccontextmanager
async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
    """Async context manager that yields a new ``aiohttp.ClientSession``.

    The session is created inside an ``async with`` block, so it is released
    when the surrounding context exits.
    """
    async with aiohttp.ClientSession() as session:
        yield session

View file

@ -6,16 +6,16 @@ import asyncio
from typing import AsyncGenerator, Tuple, Union
import aiohttp
from openai.api_requestor import APIRequestor
from metagpt.logs import logger
from metagpt.provider.general_api_base import APIRequestor
class GeneralAPIRequestor(APIRequestor):
"""
usage
# full_url = "{api_base}{url}"
requester = GeneralAPIRequestor(api_base=api_base)
# full_url = "{base_url}{url}"
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
method=method,
url=url,

View file

@ -179,7 +179,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"n": 1,
"stop": None,
"temperature": 0.3,
"timeout": 3,
}
if configs:
kwargs.update(configs)

View file

@ -3,10 +3,11 @@
# @Desc : async_sse_client that keeps using Event to access the response
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient
from zhipuai.utils.sse_client import SSEClient, Event, _FIELD_SEPARATOR
class AsyncSSEClient(SSEClient):
async def _aread(self):
data = b""
async for chunk in self._event_source:
@ -36,7 +37,9 @@ class AsyncSSEClient(SSEClient):
# Ignore unknown fields.
if field not in event.__dict__:
self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field)
self._logger.debug(
"Saw invalid field %s while parsing " "Server Side Event", field
)
continue
if len(data) > 1:

View file

@ -41,8 +41,8 @@ class ZhiPuModelAPI(ModelAPI):
# TODO to make the async request to be more generic for models in http mode.
assert method in ["post", "get"]
api_base, url = cls.split_zhipu_api_url(invoke_type, kwargs)
requester = GeneralAPIRequestor(api_base=api_base)
base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
method=method,
url=url,

View file

@ -2,12 +2,8 @@
# -*- coding: utf-8 -*-
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
import json
from enum import Enum
import openai
import zhipuai
from requests import ConnectionError
import json
from tenacity import (
after_log,
retry,
@ -15,6 +11,10 @@ from tenacity import (
stop_after_attempt,
wait_fixed,
)
from requests import ConnectionError
import openai
import zhipuai
from metagpt.config import CONFIG
from metagpt.logs import logger
@ -50,11 +50,15 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
def _const_kwargs(self, messages: list[dict]) -> dict:
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
kwargs = {
"model": self.model,
"prompt": messages,
"temperature": 0.3
}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
""" update each request's token cost """
if CONFIG.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
@ -64,7 +68,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
logger.error("zhipuai updats costs failed!", e)
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
""" get the first text of choice from llm response """
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")
@ -125,10 +129,10 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
wait=wait_fixed(1),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
retry_error_callback=log_and_reraise
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
"""response in async with stream or non-stream mode"""
""" response in async with stream or non-stream mode """
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)