mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-18 13:55:17 +02:00
support new openai package
This commit is contained in:
parent
5f69878a08
commit
eaf531e0ac
9 changed files with 866 additions and 137 deletions
|
|
@ -5,7 +5,6 @@ Provide configuration, singleton
|
|||
"""
|
||||
import os
|
||||
|
||||
import openai
|
||||
import yaml
|
||||
|
||||
from metagpt.const import PROJECT_ROOT
|
||||
|
|
@ -52,11 +51,8 @@ class Config(metaclass=Singleton):
|
|||
and (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key)
|
||||
):
|
||||
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first")
|
||||
self.openai_api_base = self._get("OPENAI_BASE_URL")
|
||||
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
if openai_proxy:
|
||||
openai.proxy = openai_proxy
|
||||
openai.api_base = self.openai_api_base
|
||||
self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
self.openai_api_type = self._get("OPENAI_API_TYPE")
|
||||
self.openai_api_version = self._get("OPENAI_API_VERSION")
|
||||
self.openai_api_rpm = self._get("RPM", 3)
|
||||
|
|
|
|||
718
metagpt/provider/general_api_base.py
Normal file
718
metagpt/provider/general_api_base.py
Normal file
|
|
@ -0,0 +1,718 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
from urllib.parse import urlencode, urlsplit, urlunsplit
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
from typing_extensions import Literal
|
||||
|
||||
import logging
|
||||
|
||||
import openai
|
||||
from openai import version
|
||||
|
||||
logger = logging.getLogger("openai")
|
||||
|
||||
TIMEOUT_SECS = 600
|
||||
MAX_SESSION_LIFETIME_SECS = 180
|
||||
MAX_CONNECTION_RETRIES = 2
|
||||
|
||||
# Has one attribute per thread, 'session'.
|
||||
_thread_context = threading.local()
|
||||
|
||||
OPENAI_LOG = os.environ.get("OPENAI_LOG")
|
||||
OPENAI_LOG = "debug"
|
||||
|
||||
|
||||
class ApiType(Enum):
|
||||
AZURE = 1
|
||||
OPEN_AI = 2
|
||||
AZURE_AD = 3
|
||||
|
||||
@staticmethod
|
||||
def from_str(label):
|
||||
if label.lower() == "azure":
|
||||
return ApiType.AZURE
|
||||
elif label.lower() in ("azure_ad", "azuread"):
|
||||
return ApiType.AZURE_AD
|
||||
elif label.lower() in ("open_ai", "openai"):
|
||||
return ApiType.OPEN_AI
|
||||
else:
|
||||
raise openai.OpenAIError(
|
||||
"The API type provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'"
|
||||
)
|
||||
|
||||
|
||||
api_key_to_header = (
|
||||
lambda api, key: {"Authorization": f"Bearer {key}"}
|
||||
if api in (ApiType.OPEN_AI, ApiType.AZURE_AD)
|
||||
else {"api-key": f"{key}"}
|
||||
)
|
||||
|
||||
|
||||
def _console_log_level():
|
||||
if OPENAI_LOG in ["debug", "info"]:
|
||||
return OPENAI_LOG
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def log_debug(message, **params):
|
||||
msg = logfmt(dict(message=message, **params))
|
||||
if _console_log_level() == "debug":
|
||||
print(msg, file=sys.stderr)
|
||||
logger.debug(msg)
|
||||
|
||||
|
||||
def log_info(message, **params):
|
||||
msg = logfmt(dict(message=message, **params))
|
||||
if _console_log_level() in ["debug", "info"]:
|
||||
print(msg, file=sys.stderr)
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
def log_warn(message, **params):
|
||||
msg = logfmt(dict(message=message, **params))
|
||||
print(msg, file=sys.stderr)
|
||||
logger.warn(msg)
|
||||
|
||||
|
||||
def logfmt(props):
|
||||
def fmt(key, val):
|
||||
# Handle case where val is a bytes or bytesarray
|
||||
if hasattr(val, "decode"):
|
||||
val = val.decode("utf-8")
|
||||
# Check if val is already a string to avoid re-encoding into ascii.
|
||||
if not isinstance(val, str):
|
||||
val = str(val)
|
||||
if re.search(r"\s", val):
|
||||
val = repr(val)
|
||||
# key should already be a string
|
||||
if re.search(r"\s", key):
|
||||
key = repr(key)
|
||||
return "{key}={val}".format(key=key, val=val)
|
||||
|
||||
return " ".join([fmt(key, val) for key, val in sorted(props.items())])
|
||||
|
||||
|
||||
class OpenAIResponse:
|
||||
def __init__(self, data, headers):
|
||||
self._headers = headers
|
||||
self.data = data
|
||||
|
||||
@property
|
||||
def request_id(self) -> Optional[str]:
|
||||
return self._headers.get("request-id")
|
||||
|
||||
@property
|
||||
def retry_after(self) -> Optional[int]:
|
||||
try:
|
||||
return int(self._headers.get("retry-after"))
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def operation_location(self) -> Optional[str]:
|
||||
return self._headers.get("operation-location")
|
||||
|
||||
@property
|
||||
def organization(self) -> Optional[str]:
|
||||
return self._headers.get("OpenAI-Organization")
|
||||
|
||||
@property
|
||||
def response_ms(self) -> Optional[int]:
|
||||
h = self._headers.get("Openai-Processing-Ms")
|
||||
return None if h is None else round(float(h))
|
||||
|
||||
|
||||
def _build_api_url(url, query):
|
||||
scheme, netloc, path, base_query, fragment = urlsplit(url)
|
||||
|
||||
if base_query:
|
||||
query = "%s&%s" % (base_query, query)
|
||||
|
||||
return urlunsplit((scheme, netloc, path, query, fragment))
|
||||
|
||||
|
||||
def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
|
||||
"""Returns a value suitable for the 'proxies' argument to 'requests.request."""
|
||||
if proxy is None:
|
||||
return None
|
||||
elif isinstance(proxy, str):
|
||||
return {"http": proxy, "https": proxy}
|
||||
elif isinstance(proxy, dict):
|
||||
return proxy.copy()
|
||||
else:
|
||||
raise ValueError(
|
||||
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
|
||||
)
|
||||
|
||||
|
||||
def _aiohttp_proxies_arg(proxy) -> Optional[str]:
|
||||
"""Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request."""
|
||||
if proxy is None:
|
||||
return None
|
||||
elif isinstance(proxy, str):
|
||||
return proxy
|
||||
elif isinstance(proxy, dict):
|
||||
return proxy["https"] if "https" in proxy else proxy["http"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
|
||||
)
|
||||
|
||||
|
||||
def _make_session() -> requests.Session:
|
||||
s = requests.Session()
|
||||
s.mount(
|
||||
"https://",
|
||||
requests.adapters.HTTPAdapter(max_retries=MAX_CONNECTION_RETRIES),
|
||||
)
|
||||
return s
|
||||
|
||||
|
||||
def parse_stream_helper(line: bytes) -> Optional[str]:
|
||||
if line:
|
||||
if line.strip() == b"data: [DONE]":
|
||||
# return here will cause GeneratorExit exception in urllib3
|
||||
# and it will close http connection with TCP Reset
|
||||
return None
|
||||
if line.startswith(b"data: "):
|
||||
line = line[len(b"data: ") :]
|
||||
return line.decode("utf-8")
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
|
||||
for line in rbody:
|
||||
_line = parse_stream_helper(line)
|
||||
if _line is not None:
|
||||
yield _line
|
||||
|
||||
|
||||
async def parse_stream_async(rbody: aiohttp.StreamReader):
|
||||
async for line in rbody:
|
||||
_line = parse_stream_helper(line)
|
||||
if _line is not None:
|
||||
yield _line
|
||||
|
||||
|
||||
class APIRequestor:
|
||||
def __init__(
|
||||
self,
|
||||
key=None,
|
||||
base_url=None,
|
||||
api_type=None,
|
||||
api_version=None,
|
||||
organization=None,
|
||||
):
|
||||
self.base_url = base_url or openai.base_url
|
||||
self.api_key = key or openai.api_key
|
||||
self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str("openai")
|
||||
self.api_version = api_version or openai.api_version
|
||||
self.organization = organization or openai.organization
|
||||
|
||||
def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
|
||||
if not predicate(response):
|
||||
return
|
||||
error_data = response.data["error"]
|
||||
message = error_data.get("message", "Operation failed")
|
||||
code = error_data.get("code")
|
||||
raise openai.APIError(message=message, body=dict(code=code))
|
||||
|
||||
def _poll(
|
||||
self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
|
||||
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
|
||||
if delay:
|
||||
time.sleep(delay)
|
||||
|
||||
response, b, api_key = self.request(method, url, params, headers)
|
||||
self._check_polling_response(response, failed)
|
||||
start_time = time.time()
|
||||
while not until(response):
|
||||
if time.time() - start_time > TIMEOUT_SECS:
|
||||
raise openai.APITimeoutError("Operation polling timed out.")
|
||||
|
||||
time.sleep(interval or response.retry_after or 10)
|
||||
response, b, api_key = self.request(method, url, params, headers)
|
||||
self._check_polling_response(response, failed)
|
||||
|
||||
response.data = response.data["result"]
|
||||
return response, b, api_key
|
||||
|
||||
async def _apoll(
|
||||
self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
|
||||
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
|
||||
if delay:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
response, b, api_key = await self.arequest(method, url, params, headers)
|
||||
self._check_polling_response(response, failed)
|
||||
start_time = time.time()
|
||||
while not until(response):
|
||||
if time.time() - start_time > TIMEOUT_SECS:
|
||||
raise openai.APITimeoutError("Operation polling timed out.")
|
||||
|
||||
await asyncio.sleep(interval or response.retry_after or 10)
|
||||
response, b, api_key = await self.arequest(method, url, params, headers)
|
||||
self._check_polling_response(response, failed)
|
||||
|
||||
response.data = response.data["result"]
|
||||
return response, b, api_key
|
||||
|
||||
@overload
|
||||
def request(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params,
|
||||
headers,
|
||||
files,
|
||||
stream: Literal[True],
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def request(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=...,
|
||||
headers=...,
|
||||
files=...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def request(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=...,
|
||||
headers=...,
|
||||
files=...,
|
||||
stream: Literal[False] = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[OpenAIResponse, bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def request(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=...,
|
||||
headers=...,
|
||||
files=...,
|
||||
stream: bool = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
|
||||
pass
|
||||
|
||||
def request(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=None,
|
||||
headers=None,
|
||||
files=None,
|
||||
stream: bool = False,
|
||||
request_id: Optional[str] = None,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
|
||||
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
|
||||
result = self.request_raw(
|
||||
method.lower(),
|
||||
url,
|
||||
params=params,
|
||||
supplied_headers=headers,
|
||||
files=files,
|
||||
stream=stream,
|
||||
request_id=request_id,
|
||||
request_timeout=request_timeout,
|
||||
)
|
||||
resp, got_stream = self._interpret_response(result, stream)
|
||||
return resp, got_stream, self.api_key
|
||||
|
||||
@overload
|
||||
async def arequest(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params,
|
||||
headers,
|
||||
files,
|
||||
stream: Literal[True],
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
async def arequest(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=...,
|
||||
headers=...,
|
||||
files=...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
async def arequest(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=...,
|
||||
headers=...,
|
||||
files=...,
|
||||
stream: Literal[False] = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[OpenAIResponse, bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
async def arequest(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=...,
|
||||
headers=...,
|
||||
files=...,
|
||||
stream: bool = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
|
||||
pass
|
||||
|
||||
async def arequest(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=None,
|
||||
headers=None,
|
||||
files=None,
|
||||
stream: bool = False,
|
||||
request_id: Optional[str] = None,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
|
||||
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
|
||||
ctx = aiohttp_session()
|
||||
session = await ctx.__aenter__()
|
||||
try:
|
||||
result = await self.arequest_raw(
|
||||
method.lower(),
|
||||
url,
|
||||
session,
|
||||
params=params,
|
||||
supplied_headers=headers,
|
||||
files=files,
|
||||
request_id=request_id,
|
||||
request_timeout=request_timeout,
|
||||
)
|
||||
resp, got_stream = await self._interpret_async_response(result, stream)
|
||||
except Exception:
|
||||
await ctx.__aexit__(None, None, None)
|
||||
raise
|
||||
if got_stream:
|
||||
|
||||
async def wrap_resp():
|
||||
assert isinstance(resp, AsyncGenerator)
|
||||
try:
|
||||
async for r in resp:
|
||||
yield r
|
||||
finally:
|
||||
await ctx.__aexit__(None, None, None)
|
||||
|
||||
return wrap_resp(), got_stream, self.api_key
|
||||
else:
|
||||
await ctx.__aexit__(None, None, None)
|
||||
return resp, got_stream, self.api_key
|
||||
|
||||
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
|
||||
try:
|
||||
error_data = resp["error"]
|
||||
except (KeyError, TypeError):
|
||||
raise openai.APIError(
|
||||
"Invalid response object from API: %r (HTTP response code " "was %d)" % (rbody, rcode)
|
||||
)
|
||||
|
||||
if "internal_message" in error_data:
|
||||
error_data["message"] += "\n\n" + error_data["internal_message"]
|
||||
|
||||
log_info(
|
||||
"OpenAI API error received",
|
||||
error_code=error_data.get("code"),
|
||||
error_type=error_data.get("type"),
|
||||
error_message=error_data.get("message"),
|
||||
error_param=error_data.get("param"),
|
||||
stream_error=stream_error,
|
||||
)
|
||||
|
||||
# Rate limits were previously coded as 400's with code 'rate_limit'
|
||||
if rcode == 429:
|
||||
return openai.RateLimitError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)
|
||||
elif rcode in [400, 404, 415]:
|
||||
return openai.BadRequestError(
|
||||
message=f'{error_data.get("message")}, {error_data.get("param")}, {error_data.get("code")} {rbody} {rcode} {resp} {rheaders}',
|
||||
body=rbody,
|
||||
)
|
||||
elif rcode == 401:
|
||||
return openai.AuthenticationError(
|
||||
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
|
||||
)
|
||||
elif rcode == 403:
|
||||
return openai.PermissionDeniedError(
|
||||
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
|
||||
)
|
||||
elif rcode == 409:
|
||||
return openai.ConflictError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)
|
||||
elif stream_error:
|
||||
# TODO: we will soon attach status codes to stream errors
|
||||
parts = [error_data.get("message"), "(Error occurred while streaming.)"]
|
||||
message = " ".join([p for p in parts if p is not None])
|
||||
return openai.APIError(f"{message} {rbody} {rcode} {resp} {rheaders}", body=rbody)
|
||||
else:
|
||||
return openai.APIError(
|
||||
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
|
||||
body=rbody,
|
||||
)
|
||||
|
||||
def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
|
||||
user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,)
|
||||
|
||||
uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node")
|
||||
ua = {
|
||||
"bindings_version": version.VERSION,
|
||||
"httplib": "requests",
|
||||
"lang": "python",
|
||||
"lang_version": platform.python_version(),
|
||||
"platform": platform.platform(),
|
||||
"publisher": "openai",
|
||||
"uname": uname_without_node,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"X-OpenAI-Client-User-Agent": json.dumps(ua),
|
||||
"User-Agent": user_agent,
|
||||
}
|
||||
|
||||
headers.update(api_key_to_header(self.api_type, self.api_key))
|
||||
|
||||
if self.organization:
|
||||
headers["OpenAI-Organization"] = self.organization
|
||||
|
||||
if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
|
||||
headers["OpenAI-Version"] = self.api_version
|
||||
if request_id is not None:
|
||||
headers["X-Request-Id"] = request_id
|
||||
headers.update(extra)
|
||||
|
||||
return headers
|
||||
|
||||
def _validate_headers(self, supplied_headers: Optional[Dict[str, str]]) -> Dict[str, str]:
|
||||
headers: Dict[str, str] = {}
|
||||
if supplied_headers is None:
|
||||
return headers
|
||||
|
||||
if not isinstance(supplied_headers, dict):
|
||||
raise TypeError("Headers must be a dictionary")
|
||||
|
||||
for k, v in supplied_headers.items():
|
||||
if not isinstance(k, str):
|
||||
raise TypeError("Header keys must be strings")
|
||||
if not isinstance(v, str):
|
||||
raise TypeError("Header values must be strings")
|
||||
headers[k] = v
|
||||
|
||||
# NOTE: It is possible to do more validation of the headers, but a request could always
|
||||
# be made to the API manually with invalid headers, so we need to handle them server side.
|
||||
|
||||
return headers
|
||||
|
||||
def _prepare_request_raw(
|
||||
self,
|
||||
url,
|
||||
supplied_headers,
|
||||
method,
|
||||
params,
|
||||
files,
|
||||
request_id: Optional[str],
|
||||
) -> Tuple[str, Dict[str, str], Optional[bytes]]:
|
||||
abs_url = "%s%s" % (self.base_url, url)
|
||||
headers = self._validate_headers(supplied_headers)
|
||||
|
||||
data = None
|
||||
if method == "get" or method == "delete":
|
||||
if params:
|
||||
encoded_params = urlencode([(k, v) for k, v in params.items() if v is not None])
|
||||
abs_url = _build_api_url(abs_url, encoded_params)
|
||||
elif method in {"post", "put"}:
|
||||
if params and files:
|
||||
data = params
|
||||
if params and not files:
|
||||
data = json.dumps(params).encode()
|
||||
headers["Content-Type"] = "application/json"
|
||||
else:
|
||||
raise openai.APIConnectionError(
|
||||
"Unrecognized HTTP method %r. This may indicate a bug in the "
|
||||
"OpenAI bindings. Please contact us through our help center at help.openai.com for "
|
||||
"assistance." % (method,)
|
||||
)
|
||||
|
||||
headers = self.request_headers(method, headers, request_id)
|
||||
|
||||
log_debug("Request to OpenAI API", method=method, path=abs_url)
|
||||
log_debug("Post details", data=data, api_version=self.api_version)
|
||||
|
||||
return abs_url, headers, data
|
||||
|
||||
def request_raw(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
*,
|
||||
params=None,
|
||||
supplied_headers: Optional[Dict[str, str]] = None,
|
||||
files=None,
|
||||
stream: bool = False,
|
||||
request_id: Optional[str] = None,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
|
||||
) -> requests.Response:
|
||||
abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id)
|
||||
|
||||
if not hasattr(_thread_context, "session"):
|
||||
_thread_context.session = _make_session()
|
||||
_thread_context.session_create_time = time.time()
|
||||
elif time.time() - getattr(_thread_context, "session_create_time", 0) >= MAX_SESSION_LIFETIME_SECS:
|
||||
_thread_context.session.close()
|
||||
_thread_context.session = _make_session()
|
||||
_thread_context.session_create_time = time.time()
|
||||
try:
|
||||
result = _thread_context.session.request(
|
||||
method,
|
||||
abs_url,
|
||||
headers=headers,
|
||||
data=data,
|
||||
files=files,
|
||||
stream=stream,
|
||||
timeout=request_timeout if request_timeout else TIMEOUT_SECS,
|
||||
proxies=_thread_context.session.proxies,
|
||||
)
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise openai.APITimeoutError("Request timed out: {}".format(e)) from e
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
|
||||
log_debug(
|
||||
"OpenAI API response",
|
||||
path=abs_url,
|
||||
response_code=result.status_code,
|
||||
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
|
||||
request_id=result.headers.get("X-Request-Id"),
|
||||
)
|
||||
return result
|
||||
|
||||
async def arequest_raw(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
session,
|
||||
*,
|
||||
params=None,
|
||||
supplied_headers: Optional[Dict[str, str]] = None,
|
||||
files=None,
|
||||
request_id: Optional[str] = None,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
|
||||
) -> aiohttp.ClientResponse:
|
||||
abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id)
|
||||
|
||||
if isinstance(request_timeout, tuple):
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
connect=request_timeout[0],
|
||||
total=request_timeout[1],
|
||||
)
|
||||
else:
|
||||
timeout = aiohttp.ClientTimeout(total=request_timeout if request_timeout else TIMEOUT_SECS)
|
||||
|
||||
if files:
|
||||
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
|
||||
# For now we use the private `requests` method that is known to have worked so far.
|
||||
data, content_type = requests.models.RequestEncodingMixin._encode_files(files, data) # type: ignore
|
||||
headers["Content-Type"] = content_type
|
||||
request_kwargs = {
|
||||
"method": method,
|
||||
"url": abs_url,
|
||||
"headers": headers,
|
||||
"data": data,
|
||||
"timeout": timeout,
|
||||
}
|
||||
try:
|
||||
result = await session.request(**request_kwargs)
|
||||
log_info(
|
||||
"OpenAI API response",
|
||||
path=abs_url,
|
||||
response_code=result.status,
|
||||
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
|
||||
request_id=result.headers.get("X-Request-Id"),
|
||||
)
|
||||
return result
|
||||
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
|
||||
raise openai.APITimeoutError("Request timed out") from e
|
||||
except aiohttp.ClientError as e:
|
||||
raise openai.APIConnectionError("Error communicating with OpenAI") from e
|
||||
|
||||
def _interpret_response(
|
||||
self, result: requests.Response, stream: bool
|
||||
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
|
||||
"""Returns the response(s) and a bool indicating whether it is a stream."""
|
||||
|
||||
async def _interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
|
||||
"""Returns the response(s) and a bool indicating whether it is a stream."""
|
||||
|
||||
def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> OpenAIResponse:
|
||||
...
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
yield session
|
||||
|
|
@ -2,20 +2,20 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : General Async API for http-based LLM model
|
||||
|
||||
from typing import AsyncGenerator, Tuple, Union, Optional, Literal
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Tuple, Union
|
||||
|
||||
from openai.api_requestor import APIRequestor
|
||||
import aiohttp
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.general_api_base import APIRequestor
|
||||
|
||||
|
||||
class GeneralAPIRequestor(APIRequestor):
|
||||
"""
|
||||
usage
|
||||
# full_url = "{api_base}{url}"
|
||||
requester = GeneralAPIRequestor(api_base=api_base)
|
||||
# full_url = "{base_url}{url}"
|
||||
requester = GeneralAPIRequestor(base_url=base_url)
|
||||
result, _, api_key = await requester.arequest(
|
||||
method=method,
|
||||
url=url,
|
||||
|
|
@ -26,9 +26,7 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
)
|
||||
"""
|
||||
|
||||
def _interpret_response_line(
|
||||
self, rbody: str, rcode: int, rheaders, stream: bool
|
||||
) -> str:
|
||||
def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> str:
|
||||
# just do nothing to meet the APIRequestor process and return the raw data
|
||||
# due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
|
||||
|
||||
|
|
@ -39,11 +37,9 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]:
|
||||
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
|
||||
return (
|
||||
self._interpret_response_line(
|
||||
line, result.status, result.headers, stream=True
|
||||
)
|
||||
async for line in result.content
|
||||
), True
|
||||
self._interpret_response_line(line, result.status, result.headers, stream=True)
|
||||
async for line in result.content
|
||||
), True
|
||||
else:
|
||||
try:
|
||||
await result.read()
|
||||
|
|
|
|||
|
|
@ -5,11 +5,14 @@
|
|||
@File : openai.py
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import NamedTuple, Union
|
||||
|
||||
import openai
|
||||
from openai.error import APIConnectionError
|
||||
import httpx
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
|
|
@ -18,7 +21,7 @@ from tenacity import (
|
|||
wait_fixed,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
|
|
@ -144,23 +147,40 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
def __init__(self):
|
||||
self.__init_openai(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.openai_api_model
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = CostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openai(self, config):
|
||||
openai.api_key = config.openai_api_key
|
||||
if config.openai_api_base:
|
||||
openai.api_base = config.openai_api_base
|
||||
if config.openai_api_type:
|
||||
openai.api_type = config.openai_api_type
|
||||
openai.api_version = config.openai_api_version
|
||||
def __init_openai(self, config: Config):
|
||||
client_kwargs, async_client_kwargs = self.__make_client_args(config)
|
||||
|
||||
self.client = OpenAI(**client_kwargs)
|
||||
self.async_client = AsyncOpenAI(**async_client_kwargs)
|
||||
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
|
||||
def __make_client_args(self, config: Config):
|
||||
mapping = {
|
||||
"api_key": "openai_api_key",
|
||||
"base_url": "openai_base_url",
|
||||
}
|
||||
|
||||
kwargs = {key: getattr(config, mapping[key]) for key in mapping if getattr(config, mapping[key], None)}
|
||||
async_kwargs = kwargs.copy()
|
||||
|
||||
# need http_client to support proxy
|
||||
if config.openai_proxy:
|
||||
httpx_args = dict(base_url=kwargs["base_url"], proxies=config.openai_proxy)
|
||||
kwargs["http_client"] = httpx.Client(**httpx_args)
|
||||
async_kwargs["http_client"] = httpx.AsyncClient(**httpx_args)
|
||||
|
||||
return kwargs, async_kwargs
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True)
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
)
|
||||
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
|
|
@ -168,15 +188,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
choices = chunk["choices"]
|
||||
if len(choices) > 0:
|
||||
chunk_message = chunk["choices"][0].get("delta", {}) # extract the message
|
||||
if chunk.choices:
|
||||
chunk_message = chunk.choices[0].delta # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if "content" in chunk_message:
|
||||
print(chunk_message["content"], end="")
|
||||
if chunk_message.content:
|
||||
print(chunk_message.content, end="")
|
||||
print()
|
||||
|
||||
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
|
||||
full_reply_content = "".join([m.content for m in collected_messages if m.content])
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
|
@ -208,24 +227,20 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
kwargs.update(kwargs_mode)
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
async def _achat_completion(self, messages: list[dict]) -> ChatCompletion:
|
||||
rsp: ChatCompletion = await self.async_client.chat.completions.create(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def _chat_completion(self, messages: list[dict]) -> dict:
|
||||
rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp)
|
||||
def _chat_completion(self, messages: list[dict]) -> ChatCompletion:
|
||||
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def completion(self, messages: list[dict]) -> dict:
|
||||
# if isinstance(messages[0], Message):
|
||||
# messages = self.messages_to_dict(messages)
|
||||
def completion(self, messages: list[dict]) -> ChatCompletion:
|
||||
return self._chat_completion(messages)
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
# if isinstance(messages[0], Message):
|
||||
# messages = self.messages_to_dict(messages)
|
||||
async def acompletion(self, messages: list[dict]) -> ChatCompletion:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
@retry(
|
||||
|
|
@ -255,14 +270,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
return self._cons_kwargs(messages, **kwargs)
|
||||
|
||||
def _chat_completion_function(self, messages: list[dict], **kwargs) -> dict:
|
||||
rsp = self.llm.ChatCompletion.create(**self._func_configs(messages, **kwargs))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
def _chat_completion_function(self, messages: list[dict], **kwargs) -> ChatCompletion:
|
||||
rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._func_configs(messages, **chat_configs))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> ChatCompletion:
|
||||
rsp: ChatCompletion = await self.async_client.chat.completions.create(
|
||||
**self._func_configs(messages, **chat_configs)
|
||||
)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
|
|
@ -317,21 +334,34 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
|
||||
usage = {}
|
||||
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
|
||||
"""Required to provide the first function arguments of choice.
|
||||
|
||||
:return dict: return the first function arguments of choice, for example,
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
try:
|
||||
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
def get_choice_text(self, rsp: ChatCompletion) -> str:
|
||||
"""Required to provide the first text of choice"""
|
||||
return rsp.choices[0].message.content if rsp.choices else ""
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
if CONFIG.calc_usage:
|
||||
try:
|
||||
prompt_tokens = count_message_tokens(messages, self.model)
|
||||
completion_tokens = count_string_tokens(rsp, self.model)
|
||||
usage["prompt_tokens"] = prompt_tokens
|
||||
usage["completion_tokens"] = completion_tokens
|
||||
usage.prompt_tokens = count_message_tokens(messages, self.model)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.model)
|
||||
return usage
|
||||
except Exception as e:
|
||||
logger.error("usage calculation failed!", e)
|
||||
else:
|
||||
return usage
|
||||
|
||||
async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]:
|
||||
async def acompletion_batch(self, batch: list[list[dict]]) -> list[ChatCompletion]:
|
||||
"""Return full JSON"""
|
||||
split_batches = self.split_batches(batch)
|
||||
all_results = []
|
||||
|
|
@ -357,12 +387,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
logger.info(f"Result of task {idx}: {result}")
|
||||
return results
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if CONFIG.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage["prompt_tokens"])
|
||||
completion_tokens = int(usage["completion_tokens"])
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error("updating costs failed!", e)
|
||||
|
||||
|
|
@ -385,7 +413,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
logger.error(f"moderating failed:{e}")
|
||||
|
||||
def _moderation(self, content: Union[str, list[str]]):
|
||||
rsp = self.llm.Moderation.create(input=content)
|
||||
rsp = self.client.moderations.create(input=content)
|
||||
return rsp
|
||||
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
|
|
@ -399,5 +427,5 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
logger.error(f"moderating failed:{e}")
|
||||
|
||||
async def _amoderation(self, content: Union[str, list[str]]):
|
||||
rsp = await self.llm.Moderation.acreate(input=content)
|
||||
rsp = await self.async_client.moderations.create(input=content)
|
||||
return rsp
|
||||
|
|
|
|||
|
|
@ -3,15 +3,14 @@
|
|||
# @Desc : zhipu model api to support sync & async for invoke & sse_invoke
|
||||
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import ModelAPI, InvokeType
|
||||
from zhipuai.model_api.api import InvokeType, ModelAPI
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
|
||||
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
||||
|
||||
|
||||
class ZhiPuModelAPI(ModelAPI):
|
||||
|
||||
@classmethod
|
||||
def get_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
|
|
@ -21,9 +20,7 @@ class ZhiPuModelAPI(ModelAPI):
|
|||
@classmethod
|
||||
def get_sse_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
headers = {
|
||||
"Authorization": token
|
||||
}
|
||||
headers = {"Authorization": token}
|
||||
return headers
|
||||
|
||||
@classmethod
|
||||
|
|
@ -44,36 +41,32 @@ class ZhiPuModelAPI(ModelAPI):
|
|||
# TODO to make the async request to be more generic for models in http mode.
|
||||
assert method in ["post", "get"]
|
||||
|
||||
api_base, url = cls.split_zhipu_api_url(invoke_type, kwargs)
|
||||
requester = GeneralAPIRequestor(api_base=api_base)
|
||||
base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
|
||||
requester = GeneralAPIRequestor(base_url=base_url)
|
||||
result, _, api_key = await requester.arequest(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
params=kwargs,
|
||||
request_timeout=zhipuai.api_timeout_seconds
|
||||
request_timeout=zhipuai.api_timeout_seconds,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def ainvoke(cls, **kwargs) -> dict:
|
||||
""" async invoke different from raw method `async_invoke` which get the final result by task_id"""
|
||||
"""async invoke different from raw method `async_invoke` which get the final result by task_id"""
|
||||
headers = cls.get_header()
|
||||
resp = await cls.arequest(invoke_type=InvokeType.SYNC,
|
||||
stream=False,
|
||||
method="post",
|
||||
headers=headers,
|
||||
kwargs=kwargs)
|
||||
resp = await cls.arequest(
|
||||
invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
|
||||
)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
|
||||
""" async sse_invoke """
|
||||
"""async sse_invoke"""
|
||||
headers = cls.get_sse_header()
|
||||
return AsyncSSEClient(await cls.arequest(invoke_type=InvokeType.SSE,
|
||||
stream=True,
|
||||
method="post",
|
||||
headers=headers,
|
||||
kwargs=kwargs))
|
||||
return AsyncSSEClient(
|
||||
await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,22 +1,26 @@
|
|||
import inspect
|
||||
import re
|
||||
from typing import List, Callable, Dict
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
import wrapt
|
||||
import textwrap
|
||||
import inspect
|
||||
from interpreter.core.core import Interpreter
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.actions.clone_function import (
|
||||
CloneFunction,
|
||||
run_function_code,
|
||||
run_function_script,
|
||||
)
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.highlight import highlight
|
||||
from metagpt.actions.clone_function import CloneFunction, run_function_code, run_function_script
|
||||
|
||||
|
||||
def extract_python_code(code: str):
|
||||
"""Extract code blocks: If the code comments are the same, only the last code block is kept."""
|
||||
# Use regular expressions to match comment blocks and related code.
|
||||
pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)'
|
||||
pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)"
|
||||
matches = re.findall(pattern, code, re.DOTALL)
|
||||
|
||||
# Extract the last code block when encountering the same comment.
|
||||
|
|
@ -25,8 +29,8 @@ def extract_python_code(code: str):
|
|||
unique_comments[comment] = code_block
|
||||
|
||||
# concatenate into functional form
|
||||
result_code = '\n'.join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
|
||||
header_code = code[:code.find("#")]
|
||||
result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
|
||||
header_code = code[: code.find("#")]
|
||||
code = header_code + result_code
|
||||
|
||||
logger.info(f"Extract python code: \n {highlight(code)}")
|
||||
|
|
@ -36,12 +40,12 @@ def extract_python_code(code: str):
|
|||
|
||||
class OpenCodeInterpreter(object):
|
||||
"""https://github.com/KillianLucas/open-interpreter"""
|
||||
|
||||
def __init__(self, auto_run: bool = True) -> None:
|
||||
interpreter = Interpreter()
|
||||
interpreter.auto_run = auto_run
|
||||
interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
|
||||
interpreter.api_key = CONFIG.openai_api_key
|
||||
# interpreter.api_base = CONFIG.openai_api_base
|
||||
self.interpreter = interpreter
|
||||
|
||||
def chat(self, query: str, reset: bool = True):
|
||||
|
|
@ -50,15 +54,16 @@ class OpenCodeInterpreter(object):
|
|||
return self.interpreter.chat(query)
|
||||
|
||||
@staticmethod
|
||||
def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
|
||||
function_format: str = None) -> str:
|
||||
def extract_function(
|
||||
query_respond: List, function_name: str, *, language: str = "python", function_format: str = None
|
||||
) -> str:
|
||||
"""create a function from query_respond."""
|
||||
if language not in ('python'):
|
||||
if language not in ("python"):
|
||||
raise NotImplementedError(f"Not support to parse language {language}!")
|
||||
|
||||
# set function form
|
||||
if function_format is None:
|
||||
assert language == 'python', f"Expect python language for default function_format, but got {language}."
|
||||
assert language == "python", f"Expect python language for default function_format, but got {language}."
|
||||
function_format = """def {function_name}():\n{code}"""
|
||||
# Extract the code module in the open-interpreter respond message.
|
||||
# The query_respond of open-interpreter before v0.1.4 is:
|
||||
|
|
@ -68,25 +73,29 @@ class OpenCodeInterpreter(object):
|
|||
# "parsed_arguments": {"language": "python", "code": code of first plan}
|
||||
# ...]
|
||||
if "function_call" in query_respond[1]:
|
||||
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
|
||||
if "function_call" in item
|
||||
and "parsed_arguments" in item["function_call"]
|
||||
and 'language' in item["function_call"]['parsed_arguments']
|
||||
and item["function_call"]['parsed_arguments']['language'] == language]
|
||||
code = [
|
||||
item["function_call"]["parsed_arguments"]["code"]
|
||||
for item in query_respond
|
||||
if "function_call" in item
|
||||
and "parsed_arguments" in item["function_call"]
|
||||
and "language" in item["function_call"]["parsed_arguments"]
|
||||
and item["function_call"]["parsed_arguments"]["language"] == language
|
||||
]
|
||||
# The query_respond of open-interpreter v0.1.7 is:
|
||||
# [{'role': 'user', 'message': your query string},
|
||||
# {'role': 'assistant', 'message': plan from llm, 'language': 'python',
|
||||
# 'code': code of first plan, 'output': output of first plan code},
|
||||
# ...]
|
||||
elif "code" in query_respond[1]:
|
||||
code = [item['code'] for item in query_respond
|
||||
if "code" in item
|
||||
and 'language' in item
|
||||
and item['language'] == language]
|
||||
code = [
|
||||
item["code"]
|
||||
for item in query_respond
|
||||
if "code" in item and "language" in item and item["language"] == language
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}")
|
||||
# add indent.
|
||||
indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
|
||||
indented_code_str = textwrap.indent("\n".join(code), " " * 4)
|
||||
# Return the code after deduplication.
|
||||
if language == "python":
|
||||
return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))
|
||||
|
|
@ -115,13 +124,13 @@ class OpenInterpreterDecorator(object):
|
|||
|
||||
def _have_code(self, rsp: List[Dict]):
|
||||
# Is there any code generated?
|
||||
return 'code' in rsp[1] and rsp[1]['code'] not in ("", None)
|
||||
return "code" in rsp[1] and rsp[1]["code"] not in ("", None)
|
||||
|
||||
def _is_faild_plan(self, rsp: List[Dict]):
|
||||
# is faild plan?
|
||||
func_code = OpenCodeInterpreter.extract_function(rsp, 'function')
|
||||
func_code = OpenCodeInterpreter.extract_function(rsp, "function")
|
||||
# If there is no more than 1 '\n', the plan execution fails.
|
||||
if isinstance(func_code, str) and func_code.count('\n') <= 1:
|
||||
if isinstance(func_code, str) and func_code.count("\n") <= 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
|
@ -184,4 +193,5 @@ class OpenInterpreterDecorator(object):
|
|||
logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
|
||||
raise Exception("Could not evaluate Python code", e)
|
||||
return res
|
||||
|
||||
return wrapper(wrapped)
|
||||
|
|
|
|||
|
|
@ -21,14 +21,12 @@ def make_sk_kernel():
|
|||
if CONFIG.openai_api_type == "azure":
|
||||
kernel.add_chat_service(
|
||||
"chat_completion",
|
||||
AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_api_base, CONFIG.openai_api_key),
|
||||
AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_base_url, CONFIG.openai_api_key),
|
||||
)
|
||||
else:
|
||||
kernel.add_chat_service(
|
||||
"chat_completion",
|
||||
OpenAIChatCompletion(
|
||||
CONFIG.openai_api_model, CONFIG.openai_api_key, org_id=None, endpoint=CONFIG.openai_api_base
|
||||
),
|
||||
OpenAIChatCompletion(CONFIG.openai_api_model, CONFIG.openai_api_key),
|
||||
)
|
||||
|
||||
return kernel
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue