support new openai package

This commit is contained in:
seehi 2023-12-05 11:26:54 +08:00
parent 5f69878a08
commit eaf531e0ac
9 changed files with 866 additions and 137 deletions

View file

@ -5,7 +5,6 @@ Provide configuration, singleton
"""
import os
import openai
import yaml
from metagpt.const import PROJECT_ROOT
@ -52,11 +51,8 @@ class Config(metaclass=Singleton):
and (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key)
):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first")
self.openai_api_base = self._get("OPENAI_BASE_URL")
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
if openai_proxy:
openai.proxy = openai_proxy
openai.api_base = self.openai_api_base
self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
self.openai_api_type = self._get("OPENAI_API_TYPE")
self.openai_api_version = self._get("OPENAI_API_VERSION")
self.openai_api_rpm = self._get("RPM", 3)

View file

@ -0,0 +1,718 @@
import asyncio
import json
import os
import platform
import re
import sys
import threading
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import (
AsyncGenerator,
AsyncIterator,
Callable,
Dict,
Iterator,
Optional,
Tuple,
Union,
overload,
)
from urllib.parse import urlencode, urlsplit, urlunsplit
import aiohttp
import requests
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
import logging
import openai
from openai import version
logger = logging.getLogger("openai")
TIMEOUT_SECS = 600
MAX_SESSION_LIFETIME_SECS = 180
MAX_CONNECTION_RETRIES = 2
# Has one attribute per thread, 'session'.
_thread_context = threading.local()
OPENAI_LOG = os.environ.get("OPENAI_LOG")
OPENAI_LOG = "debug"
class ApiType(Enum):
AZURE = 1
OPEN_AI = 2
AZURE_AD = 3
@staticmethod
def from_str(label):
if label.lower() == "azure":
return ApiType.AZURE
elif label.lower() in ("azure_ad", "azuread"):
return ApiType.AZURE_AD
elif label.lower() in ("open_ai", "openai"):
return ApiType.OPEN_AI
else:
raise openai.OpenAIError(
"The API type provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'"
)
api_key_to_header = (
lambda api, key: {"Authorization": f"Bearer {key}"}
if api in (ApiType.OPEN_AI, ApiType.AZURE_AD)
else {"api-key": f"{key}"}
)
def _console_log_level():
if OPENAI_LOG in ["debug", "info"]:
return OPENAI_LOG
else:
return None
def log_debug(message, **params):
msg = logfmt(dict(message=message, **params))
if _console_log_level() == "debug":
print(msg, file=sys.stderr)
logger.debug(msg)
def log_info(message, **params):
msg = logfmt(dict(message=message, **params))
if _console_log_level() in ["debug", "info"]:
print(msg, file=sys.stderr)
logger.info(msg)
def log_warn(message, **params):
msg = logfmt(dict(message=message, **params))
print(msg, file=sys.stderr)
logger.warn(msg)
def logfmt(props):
def fmt(key, val):
# Handle case where val is a bytes or bytesarray
if hasattr(val, "decode"):
val = val.decode("utf-8")
# Check if val is already a string to avoid re-encoding into ascii.
if not isinstance(val, str):
val = str(val)
if re.search(r"\s", val):
val = repr(val)
# key should already be a string
if re.search(r"\s", key):
key = repr(key)
return "{key}={val}".format(key=key, val=val)
return " ".join([fmt(key, val) for key, val in sorted(props.items())])
class OpenAIResponse:
def __init__(self, data, headers):
self._headers = headers
self.data = data
@property
def request_id(self) -> Optional[str]:
return self._headers.get("request-id")
@property
def retry_after(self) -> Optional[int]:
try:
return int(self._headers.get("retry-after"))
except TypeError:
return None
@property
def operation_location(self) -> Optional[str]:
return self._headers.get("operation-location")
@property
def organization(self) -> Optional[str]:
return self._headers.get("OpenAI-Organization")
@property
def response_ms(self) -> Optional[int]:
h = self._headers.get("Openai-Processing-Ms")
return None if h is None else round(float(h))
def _build_api_url(url, query):
scheme, netloc, path, base_query, fragment = urlsplit(url)
if base_query:
query = "%s&%s" % (base_query, query)
return urlunsplit((scheme, netloc, path, query, fragment))
def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
"""Returns a value suitable for the 'proxies' argument to 'requests.request."""
if proxy is None:
return None
elif isinstance(proxy, str):
return {"http": proxy, "https": proxy}
elif isinstance(proxy, dict):
return proxy.copy()
else:
raise ValueError(
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
)
def _aiohttp_proxies_arg(proxy) -> Optional[str]:
"""Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request."""
if proxy is None:
return None
elif isinstance(proxy, str):
return proxy
elif isinstance(proxy, dict):
return proxy["https"] if "https" in proxy else proxy["http"]
else:
raise ValueError(
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
)
def _make_session() -> requests.Session:
s = requests.Session()
s.mount(
"https://",
requests.adapters.HTTPAdapter(max_retries=MAX_CONNECTION_RETRIES),
)
return s
def parse_stream_helper(line: bytes) -> Optional[str]:
if line:
if line.strip() == b"data: [DONE]":
# return here will cause GeneratorExit exception in urllib3
# and it will close http connection with TCP Reset
return None
if line.startswith(b"data: "):
line = line[len(b"data: ") :]
return line.decode("utf-8")
else:
return None
return None
def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
for line in rbody:
_line = parse_stream_helper(line)
if _line is not None:
yield _line
async def parse_stream_async(rbody: aiohttp.StreamReader):
async for line in rbody:
_line = parse_stream_helper(line)
if _line is not None:
yield _line
class APIRequestor:
def __init__(
self,
key=None,
base_url=None,
api_type=None,
api_version=None,
organization=None,
):
self.base_url = base_url or openai.base_url
self.api_key = key or openai.api_key
self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str("openai")
self.api_version = api_version or openai.api_version
self.organization = organization or openai.organization
def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
if not predicate(response):
return
error_data = response.data["error"]
message = error_data.get("message", "Operation failed")
code = error_data.get("code")
raise openai.APIError(message=message, body=dict(code=code))
def _poll(
self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
if delay:
time.sleep(delay)
response, b, api_key = self.request(method, url, params, headers)
self._check_polling_response(response, failed)
start_time = time.time()
while not until(response):
if time.time() - start_time > TIMEOUT_SECS:
raise openai.APITimeoutError("Operation polling timed out.")
time.sleep(interval or response.retry_after or 10)
response, b, api_key = self.request(method, url, params, headers)
self._check_polling_response(response, failed)
response.data = response.data["result"]
return response, b, api_key
async def _apoll(
self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
if delay:
await asyncio.sleep(delay)
response, b, api_key = await self.arequest(method, url, params, headers)
self._check_polling_response(response, failed)
start_time = time.time()
while not until(response):
if time.time() - start_time > TIMEOUT_SECS:
raise openai.APITimeoutError("Operation polling timed out.")
await asyncio.sleep(interval or response.retry_after or 10)
response, b, api_key = await self.arequest(method, url, params, headers)
self._check_polling_response(response, failed)
response.data = response.data["result"]
return response, b, api_key
@overload
def request(
self,
method,
url,
params,
headers,
files,
stream: Literal[True],
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
pass
@overload
def request(
self,
method,
url,
params=...,
headers=...,
files=...,
*,
stream: Literal[True],
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
pass
@overload
def request(
self,
method,
url,
params=...,
headers=...,
files=...,
stream: Literal[False] = ...,
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[OpenAIResponse, bool, str]:
pass
@overload
def request(
self,
method,
url,
params=...,
headers=...,
files=...,
stream: bool = ...,
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
pass
def request(
self,
method,
url,
params=None,
headers=None,
files=None,
stream: bool = False,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
result = self.request_raw(
method.lower(),
url,
params=params,
supplied_headers=headers,
files=files,
stream=stream,
request_id=request_id,
request_timeout=request_timeout,
)
resp, got_stream = self._interpret_response(result, stream)
return resp, got_stream, self.api_key
@overload
async def arequest(
self,
method,
url,
params,
headers,
files,
stream: Literal[True],
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
pass
@overload
async def arequest(
self,
method,
url,
params=...,
headers=...,
files=...,
*,
stream: Literal[True],
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
pass
@overload
async def arequest(
self,
method,
url,
params=...,
headers=...,
files=...,
stream: Literal[False] = ...,
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[OpenAIResponse, bool, str]:
pass
@overload
async def arequest(
self,
method,
url,
params=...,
headers=...,
files=...,
stream: bool = ...,
request_id: Optional[str] = ...,
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
pass
async def arequest(
self,
method,
url,
params=None,
headers=None,
files=None,
stream: bool = False,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
ctx = aiohttp_session()
session = await ctx.__aenter__()
try:
result = await self.arequest_raw(
method.lower(),
url,
session,
params=params,
supplied_headers=headers,
files=files,
request_id=request_id,
request_timeout=request_timeout,
)
resp, got_stream = await self._interpret_async_response(result, stream)
except Exception:
await ctx.__aexit__(None, None, None)
raise
if got_stream:
async def wrap_resp():
assert isinstance(resp, AsyncGenerator)
try:
async for r in resp:
yield r
finally:
await ctx.__aexit__(None, None, None)
return wrap_resp(), got_stream, self.api_key
else:
await ctx.__aexit__(None, None, None)
return resp, got_stream, self.api_key
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
try:
error_data = resp["error"]
except (KeyError, TypeError):
raise openai.APIError(
"Invalid response object from API: %r (HTTP response code " "was %d)" % (rbody, rcode)
)
if "internal_message" in error_data:
error_data["message"] += "\n\n" + error_data["internal_message"]
log_info(
"OpenAI API error received",
error_code=error_data.get("code"),
error_type=error_data.get("type"),
error_message=error_data.get("message"),
error_param=error_data.get("param"),
stream_error=stream_error,
)
# Rate limits were previously coded as 400's with code 'rate_limit'
if rcode == 429:
return openai.RateLimitError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)
elif rcode in [400, 404, 415]:
return openai.BadRequestError(
message=f'{error_data.get("message")}, {error_data.get("param")}, {error_data.get("code")} {rbody} {rcode} {resp} {rheaders}',
body=rbody,
)
elif rcode == 401:
return openai.AuthenticationError(
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
)
elif rcode == 403:
return openai.PermissionDeniedError(
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
)
elif rcode == 409:
return openai.ConflictError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)
elif stream_error:
# TODO: we will soon attach status codes to stream errors
parts = [error_data.get("message"), "(Error occurred while streaming.)"]
message = " ".join([p for p in parts if p is not None])
return openai.APIError(f"{message} {rbody} {rcode} {resp} {rheaders}", body=rbody)
else:
return openai.APIError(
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
body=rbody,
)
def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,)
uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node")
ua = {
"bindings_version": version.VERSION,
"httplib": "requests",
"lang": "python",
"lang_version": platform.python_version(),
"platform": platform.platform(),
"publisher": "openai",
"uname": uname_without_node,
}
headers = {
"X-OpenAI-Client-User-Agent": json.dumps(ua),
"User-Agent": user_agent,
}
headers.update(api_key_to_header(self.api_type, self.api_key))
if self.organization:
headers["OpenAI-Organization"] = self.organization
if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
headers["OpenAI-Version"] = self.api_version
if request_id is not None:
headers["X-Request-Id"] = request_id
headers.update(extra)
return headers
def _validate_headers(self, supplied_headers: Optional[Dict[str, str]]) -> Dict[str, str]:
headers: Dict[str, str] = {}
if supplied_headers is None:
return headers
if not isinstance(supplied_headers, dict):
raise TypeError("Headers must be a dictionary")
for k, v in supplied_headers.items():
if not isinstance(k, str):
raise TypeError("Header keys must be strings")
if not isinstance(v, str):
raise TypeError("Header values must be strings")
headers[k] = v
# NOTE: It is possible to do more validation of the headers, but a request could always
# be made to the API manually with invalid headers, so we need to handle them server side.
return headers
def _prepare_request_raw(
self,
url,
supplied_headers,
method,
params,
files,
request_id: Optional[str],
) -> Tuple[str, Dict[str, str], Optional[bytes]]:
abs_url = "%s%s" % (self.base_url, url)
headers = self._validate_headers(supplied_headers)
data = None
if method == "get" or method == "delete":
if params:
encoded_params = urlencode([(k, v) for k, v in params.items() if v is not None])
abs_url = _build_api_url(abs_url, encoded_params)
elif method in {"post", "put"}:
if params and files:
data = params
if params and not files:
data = json.dumps(params).encode()
headers["Content-Type"] = "application/json"
else:
raise openai.APIConnectionError(
"Unrecognized HTTP method %r. This may indicate a bug in the "
"OpenAI bindings. Please contact us through our help center at help.openai.com for "
"assistance." % (method,)
)
headers = self.request_headers(method, headers, request_id)
log_debug("Request to OpenAI API", method=method, path=abs_url)
log_debug("Post details", data=data, api_version=self.api_version)
return abs_url, headers, data
def request_raw(
self,
method,
url,
*,
params=None,
supplied_headers: Optional[Dict[str, str]] = None,
files=None,
stream: bool = False,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> requests.Response:
abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id)
if not hasattr(_thread_context, "session"):
_thread_context.session = _make_session()
_thread_context.session_create_time = time.time()
elif time.time() - getattr(_thread_context, "session_create_time", 0) >= MAX_SESSION_LIFETIME_SECS:
_thread_context.session.close()
_thread_context.session = _make_session()
_thread_context.session_create_time = time.time()
try:
result = _thread_context.session.request(
method,
abs_url,
headers=headers,
data=data,
files=files,
stream=stream,
timeout=request_timeout if request_timeout else TIMEOUT_SECS,
proxies=_thread_context.session.proxies,
)
except requests.exceptions.Timeout as e:
raise openai.APITimeoutError("Request timed out: {}".format(e)) from e
except requests.exceptions.RequestException as e:
raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
log_debug(
"OpenAI API response",
path=abs_url,
response_code=result.status_code,
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
request_id=result.headers.get("X-Request-Id"),
)
return result
async def arequest_raw(
self,
method,
url,
session,
*,
params=None,
supplied_headers: Optional[Dict[str, str]] = None,
files=None,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> aiohttp.ClientResponse:
abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id)
if isinstance(request_timeout, tuple):
timeout = aiohttp.ClientTimeout(
connect=request_timeout[0],
total=request_timeout[1],
)
else:
timeout = aiohttp.ClientTimeout(total=request_timeout if request_timeout else TIMEOUT_SECS)
if files:
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
# For now we use the private `requests` method that is known to have worked so far.
data, content_type = requests.models.RequestEncodingMixin._encode_files(files, data) # type: ignore
headers["Content-Type"] = content_type
request_kwargs = {
"method": method,
"url": abs_url,
"headers": headers,
"data": data,
"timeout": timeout,
}
try:
result = await session.request(**request_kwargs)
log_info(
"OpenAI API response",
path=abs_url,
response_code=result.status,
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
request_id=result.headers.get("X-Request-Id"),
)
return result
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
raise openai.APITimeoutError("Request timed out") from e
except aiohttp.ClientError as e:
raise openai.APIConnectionError("Error communicating with OpenAI") from e
def _interpret_response(
self, result: requests.Response, stream: bool
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
"""Returns the response(s) and a bool indicating whether it is a stream."""
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
"""Returns the response(s) and a bool indicating whether it is a stream."""
def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> OpenAIResponse:
...
@asynccontextmanager
async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
async with aiohttp.ClientSession() as session:
yield session

View file

@ -2,20 +2,20 @@
# -*- coding: utf-8 -*-
# @Desc : General Async API for http-based LLM model
from typing import AsyncGenerator, Tuple, Union, Optional, Literal
import aiohttp
import asyncio
from typing import AsyncGenerator, Tuple, Union
from openai.api_requestor import APIRequestor
import aiohttp
from metagpt.logs import logger
from metagpt.provider.general_api_base import APIRequestor
class GeneralAPIRequestor(APIRequestor):
"""
usage
# full_url = "{api_base}{url}"
requester = GeneralAPIRequestor(api_base=api_base)
# full_url = "{base_url}{url}"
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
method=method,
url=url,
@ -26,9 +26,7 @@ class GeneralAPIRequestor(APIRequestor):
)
"""
def _interpret_response_line(
self, rbody: str, rcode: int, rheaders, stream: bool
) -> str:
def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> str:
# just do nothing to meet the APIRequestor process and return the raw data
# due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
@ -39,11 +37,9 @@ class GeneralAPIRequestor(APIRequestor):
) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]:
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
return (
self._interpret_response_line(
line, result.status, result.headers, stream=True
)
async for line in result.content
), True
self._interpret_response_line(line, result.status, result.headers, stream=True)
async for line in result.content
), True
else:
try:
await result.read()

View file

@ -5,11 +5,14 @@
@File : openai.py
"""
import asyncio
import json
import time
from typing import NamedTuple, Union
import openai
from openai.error import APIConnectionError
import httpx
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from tenacity import (
after_log,
retry,
@ -18,7 +21,7 @@ from tenacity import (
wait_fixed,
)
from metagpt.config import CONFIG
from metagpt.config import CONFIG, Config
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
@ -144,23 +147,40 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def __init__(self):
self.__init_openai(CONFIG)
self.llm = openai
self.model = CONFIG.openai_api_model
self.auto_max_tokens = False
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openai(self, config):
openai.api_key = config.openai_api_key
if config.openai_api_base:
openai.api_base = config.openai_api_base
if config.openai_api_type:
openai.api_type = config.openai_api_type
openai.api_version = config.openai_api_version
def __init_openai(self, config: Config):
client_kwargs, async_client_kwargs = self.__make_client_args(config)
self.client = OpenAI(**client_kwargs)
self.async_client = AsyncOpenAI(**async_client_kwargs)
self.rpm = int(config.get("RPM", 10))
def __make_client_args(self, config: Config):
mapping = {
"api_key": "openai_api_key",
"base_url": "openai_base_url",
}
kwargs = {key: getattr(config, mapping[key]) for key in mapping if getattr(config, mapping[key], None)}
async_kwargs = kwargs.copy()
# need http_client to support proxy
if config.openai_proxy:
httpx_args = dict(base_url=kwargs["base_url"], proxies=config.openai_proxy)
kwargs["http_client"] = httpx.Client(**httpx_args)
async_kwargs["http_client"] = httpx.AsyncClient(**httpx_args)
return kwargs, async_kwargs
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True)
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
**self._cons_kwargs(messages), stream=True
)
# create variables to collect the stream of chunks
collected_chunks = []
@ -168,15 +188,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
choices = chunk["choices"]
if len(choices) > 0:
chunk_message = chunk["choices"][0].get("delta", {}) # extract the message
if chunk.choices:
chunk_message = chunk.choices[0].delta # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
if chunk_message.content:
print(chunk_message.content, end="")
print()
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
full_reply_content = "".join([m.content for m in collected_messages if m.content])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
@ -208,24 +227,20 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
kwargs.update(kwargs_mode)
return kwargs
async def _achat_completion(self, messages: list[dict]) -> dict:
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
self._update_costs(rsp.get("usage"))
async def _achat_completion(self, messages: list[dict]) -> ChatCompletion:
rsp: ChatCompletion = await self.async_client.chat.completions.create(**self._cons_kwargs(messages))
self._update_costs(rsp.usage)
return rsp
def _chat_completion(self, messages: list[dict]) -> dict:
rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages))
self._update_costs(rsp)
def _chat_completion(self, messages: list[dict]) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages))
self._update_costs(rsp.usage)
return rsp
def completion(self, messages: list[dict]) -> dict:
# if isinstance(messages[0], Message):
# messages = self.messages_to_dict(messages)
def completion(self, messages: list[dict]) -> ChatCompletion:
return self._chat_completion(messages)
async def acompletion(self, messages: list[dict]) -> dict:
# if isinstance(messages[0], Message):
# messages = self.messages_to_dict(messages)
async def acompletion(self, messages: list[dict]) -> ChatCompletion:
return await self._achat_completion(messages)
@retry(
@ -255,14 +270,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return self._cons_kwargs(messages, **kwargs)
def _chat_completion_function(self, messages: list[dict], **kwargs) -> dict:
rsp = self.llm.ChatCompletion.create(**self._func_configs(messages, **kwargs))
self._update_costs(rsp.get("usage"))
def _chat_completion_function(self, messages: list[dict], **kwargs) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
self._update_costs(rsp.usage)
return rsp
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> dict:
rsp = await self.llm.ChatCompletion.acreate(**self._func_configs(messages, **chat_configs))
self._update_costs(rsp.get("usage"))
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> ChatCompletion:
rsp: ChatCompletion = await self.async_client.chat.completions.create(
**self._func_configs(messages, **chat_configs)
)
self._update_costs(rsp.usage)
return rsp
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
@ -317,21 +334,34 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
usage = {}
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
"""Required to provide the first function arguments of choice.
:return dict: return the first function arguments of choice, for example,
{'language': 'python', 'code': "print('Hello, World!')"}
"""
try:
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
except json.JSONDecodeError:
return {}
def get_choice_text(self, rsp: ChatCompletion) -> str:
"""Required to provide the first text of choice"""
return rsp.choices[0].message.content if rsp.choices else ""
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if CONFIG.calc_usage:
try:
prompt_tokens = count_message_tokens(messages, self.model)
completion_tokens = count_string_tokens(rsp, self.model)
usage["prompt_tokens"] = prompt_tokens
usage["completion_tokens"] = completion_tokens
usage.prompt_tokens = count_message_tokens(messages, self.model)
usage.completion_tokens = count_string_tokens(rsp, self.model)
return usage
except Exception as e:
logger.error("usage calculation failed!", e)
else:
return usage
async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]:
async def acompletion_batch(self, batch: list[list[dict]]) -> list[ChatCompletion]:
"""Return full JSON"""
split_batches = self.split_batches(batch)
all_results = []
@ -357,12 +387,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
logger.info(f"Result of task {idx}: {result}")
return results
def _update_costs(self, usage: dict):
def _update_costs(self, usage: CompletionUsage):
if CONFIG.calc_usage:
try:
prompt_tokens = int(usage["prompt_tokens"])
completion_tokens = int(usage["completion_tokens"])
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error("updating costs failed!", e)
@ -385,7 +413,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
logger.error(f"moderating failed:{e}")
def _moderation(self, content: Union[str, list[str]]):
rsp = self.llm.Moderation.create(input=content)
rsp = self.client.moderations.create(input=content)
return rsp
async def amoderation(self, content: Union[str, list[str]]):
@ -399,5 +427,5 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
logger.error(f"moderating failed:{e}")
async def _amoderation(self, content: Union[str, list[str]]):
rsp = await self.llm.Moderation.acreate(input=content)
rsp = await self.async_client.moderations.create(input=content)
return rsp

View file

@ -3,15 +3,14 @@
# @Desc : zhipu model api to support sync & async for invoke & sse_invoke
import zhipuai
from zhipuai.model_api.api import ModelAPI, InvokeType
from zhipuai.model_api.api import InvokeType, ModelAPI
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
class ZhiPuModelAPI(ModelAPI):
@classmethod
def get_header(cls) -> dict:
token = cls._generate_token()
@ -21,9 +20,7 @@ class ZhiPuModelAPI(ModelAPI):
@classmethod
def get_sse_header(cls) -> dict:
token = cls._generate_token()
headers = {
"Authorization": token
}
headers = {"Authorization": token}
return headers
@classmethod
@ -44,36 +41,32 @@ class ZhiPuModelAPI(ModelAPI):
# TODO to make the async request to be more generic for models in http mode.
assert method in ["post", "get"]
api_base, url = cls.split_zhipu_api_url(invoke_type, kwargs)
requester = GeneralAPIRequestor(api_base=api_base)
base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
method=method,
url=url,
headers=headers,
stream=stream,
params=kwargs,
request_timeout=zhipuai.api_timeout_seconds
request_timeout=zhipuai.api_timeout_seconds,
)
return result
@classmethod
async def ainvoke(cls, **kwargs) -> dict:
""" async invoke different from raw method `async_invoke` which get the final result by task_id"""
"""async invoke different from raw method `async_invoke` which get the final result by task_id"""
headers = cls.get_header()
resp = await cls.arequest(invoke_type=InvokeType.SYNC,
stream=False,
method="post",
headers=headers,
kwargs=kwargs)
resp = await cls.arequest(
invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
)
return resp
@classmethod
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
""" async sse_invoke """
"""async sse_invoke"""
headers = cls.get_sse_header()
return AsyncSSEClient(await cls.arequest(invoke_type=InvokeType.SSE,
stream=True,
method="post",
headers=headers,
kwargs=kwargs))
return AsyncSSEClient(
await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs)
)

View file

@ -1,22 +1,26 @@
import inspect
import re
from typing import List, Callable, Dict
import textwrap
from pathlib import Path
from typing import Callable, Dict, List
import wrapt
import textwrap
import inspect
from interpreter.core.core import Interpreter
from metagpt.logs import logger
from metagpt.actions.clone_function import (
CloneFunction,
run_function_code,
run_function_script,
)
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.highlight import highlight
from metagpt.actions.clone_function import CloneFunction, run_function_code, run_function_script
def extract_python_code(code: str):
"""Extract code blocks: If the code comments are the same, only the last code block is kept."""
# Use regular expressions to match comment blocks and related code.
pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)'
pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)"
matches = re.findall(pattern, code, re.DOTALL)
# Extract the last code block when encountering the same comment.
@ -25,8 +29,8 @@ def extract_python_code(code: str):
unique_comments[comment] = code_block
# concatenate into functional form
result_code = '\n'.join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
header_code = code[:code.find("#")]
result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
header_code = code[: code.find("#")]
code = header_code + result_code
logger.info(f"Extract python code: \n {highlight(code)}")
@ -36,12 +40,12 @@ def extract_python_code(code: str):
class OpenCodeInterpreter(object):
"""https://github.com/KillianLucas/open-interpreter"""
def __init__(self, auto_run: bool = True) -> None:
interpreter = Interpreter()
interpreter.auto_run = auto_run
interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
interpreter.api_key = CONFIG.openai_api_key
# interpreter.api_base = CONFIG.openai_api_base
self.interpreter = interpreter
def chat(self, query: str, reset: bool = True):
@ -50,15 +54,16 @@ class OpenCodeInterpreter(object):
return self.interpreter.chat(query)
@staticmethod
def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
function_format: str = None) -> str:
def extract_function(
query_respond: List, function_name: str, *, language: str = "python", function_format: str = None
) -> str:
"""create a function from query_respond."""
if language not in ('python'):
if language not in ("python"):
raise NotImplementedError(f"Not support to parse language {language}!")
# set function form
if function_format is None:
assert language == 'python', f"Expect python language for default function_format, but got {language}."
assert language == "python", f"Expect python language for default function_format, but got {language}."
function_format = """def {function_name}():\n{code}"""
# Extract the code module in the open-interpreter respond message.
# The query_respond of open-interpreter before v0.1.4 is:
@ -68,25 +73,29 @@ class OpenCodeInterpreter(object):
# "parsed_arguments": {"language": "python", "code": code of first plan}
# ...]
if "function_call" in query_respond[1]:
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and 'language' in item["function_call"]['parsed_arguments']
and item["function_call"]['parsed_arguments']['language'] == language]
code = [
item["function_call"]["parsed_arguments"]["code"]
for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and "language" in item["function_call"]["parsed_arguments"]
and item["function_call"]["parsed_arguments"]["language"] == language
]
# The query_respond of open-interpreter v0.1.7 is:
# [{'role': 'user', 'message': your query string},
# {'role': 'assistant', 'message': plan from llm, 'language': 'python',
# 'code': code of first plan, 'output': output of first plan code},
# ...]
elif "code" in query_respond[1]:
code = [item['code'] for item in query_respond
if "code" in item
and 'language' in item
and item['language'] == language]
code = [
item["code"]
for item in query_respond
if "code" in item and "language" in item and item["language"] == language
]
else:
raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}")
# add indent.
indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
indented_code_str = textwrap.indent("\n".join(code), " " * 4)
# Return the code after deduplication.
if language == "python":
return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))
@ -115,13 +124,13 @@ class OpenInterpreterDecorator(object):
def _have_code(self, rsp: List[Dict]):
# Is there any code generated?
return 'code' in rsp[1] and rsp[1]['code'] not in ("", None)
return "code" in rsp[1] and rsp[1]["code"] not in ("", None)
def _is_faild_plan(self, rsp: List[Dict]):
# is faild plan?
func_code = OpenCodeInterpreter.extract_function(rsp, 'function')
func_code = OpenCodeInterpreter.extract_function(rsp, "function")
# If there is no more than 1 '\n', the plan execution fails.
if isinstance(func_code, str) and func_code.count('\n') <= 1:
if isinstance(func_code, str) and func_code.count("\n") <= 1:
return True
return False
@ -184,4 +193,5 @@ class OpenInterpreterDecorator(object):
logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
raise Exception("Could not evaluate Python code", e)
return res
return wrapper(wrapped)

View file

@ -21,14 +21,12 @@ def make_sk_kernel():
if CONFIG.openai_api_type == "azure":
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_api_base, CONFIG.openai_api_key),
AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_base_url, CONFIG.openai_api_key),
)
else:
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(
CONFIG.openai_api_model, CONFIG.openai_api_key, org_id=None, endpoint=CONFIG.openai_api_base
),
OpenAIChatCompletion(CONFIG.openai_api_model, CONFIG.openai_api_key),
)
return kernel