remove sync api in openai

This commit is contained in:
geekan 2023-12-26 15:59:11 +08:00
parent e15de55368
commit bb1b9823d0
14 changed files with 52 additions and 207 deletions

View file

@ -10,12 +10,12 @@
"""
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
from openai import AsyncAzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.config import LLMProviderEnum
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
from metagpt.provider.openai_api import OpenAIGPTAPI
@register_provider(LLMProviderEnum.AZURE_OPENAI)
@ -24,46 +24,22 @@ class AzureOpenAIGPTAPI(OpenAIGPTAPI):
Check https://platform.openai.com/examples for examples
"""
def __init__(self):
self.config: Config = CONFIG
self._init_openai()
self.auto_max_tokens = False
RateLimiter.__init__(self, rpm=self.rpm)
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
def _init_client(self):
kwargs = self._make_client_kwargs()
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.client = AzureOpenAI(**kwargs)
self.async_client = AsyncAzureOpenAI(**async_kwargs)
self.async_client = AsyncAzureOpenAI(**kwargs)
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
def _make_client_kwargs(self) -> dict:
kwargs = dict(
api_key=self.config.OPENAI_API_KEY,
api_version=self.config.OPENAI_API_VERSION,
azure_endpoint=self.config.OPENAI_BASE_URL,
)
async_kwargs = kwargs.copy()
# to use proxy, openai v1 needs http_client
proxy_params = self._get_proxy_params()
if proxy_params:
kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs, async_kwargs
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self.get_max_tokens(messages),
"n": 1,
"stop": None,
"temperature": 0.3,
"model": self.model,
}
if configs:
kwargs.update(configs)
kwargs["timeout"] = max(CONFIG.timeout, timeout)
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs

View file

@ -112,7 +112,7 @@ class BaseGPTAPI(BaseChatbot):
"""
@abstractmethod
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""Asynchronous version of completion. Return str. Support stream-print"""
def get_choice_text(self, rsp: dict) -> str:

View file

@ -83,7 +83,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
def __init_fireworks(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
self._init_client()
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
def _make_client_kwargs(self) -> (dict, dict):
@ -103,7 +103,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
return self._cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
)
@ -133,9 +133,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -136,9 +136,7 @@ class GeminiGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -147,9 +147,7 @@ class OllamaGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -46,7 +46,7 @@ class OpenLLMGPTAPI(OpenAIGPTAPI):
def __init_openllm(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
self._init_client()
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
def _make_client_kwargs(self) -> (dict, dict):

View file

@ -14,9 +14,8 @@ import json
import time
from typing import AsyncIterator, Union
import openai
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from tenacity import (
@ -80,9 +79,7 @@ See FAQ 5.8
@register_provider(LLMProviderEnum.OPENAI)
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
Check https://platform.openai.com/examples for examples
"""
"""Check https://platform.openai.com/examples for examples"""
def __init__(self):
self.config: Config = CONFIG
@ -91,27 +88,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
RateLimiter.__init__(self, rpm=self.rpm)
def _init_openai(self):
self.rpm = int(self.config.RPM or 10)
self._make_client()
self.rpm = int(self.config.openai_api_rpm)
self._init_client()
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
def _init_client(self):
kwargs = self._make_client_kwargs()
# https://github.com/openai/openai-python#async-usage
self.client = OpenAI(**kwargs)
self.async_client = AsyncOpenAI(**async_kwargs)
self.aclient = AsyncOpenAI(**kwargs)
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL)
async_kwargs = kwargs.copy()
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
# to use proxy, openai v1 needs http_client
proxy_params = self._get_proxy_params()
if proxy_params:
kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
if proxy_params := self._get_proxy_params():
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs, async_kwargs
return kwargs
def _get_proxy_params(self) -> dict:
params = {}
@ -123,7 +116,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return params
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=timeout), stream=True
)
@ -148,18 +141,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=timeout)
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout))
self._update_costs(rsp.usage)
return rsp
def completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
return self._chat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
return await self._achat_completion(messages, timeout=timeout)
@ -199,14 +184,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
self._update_costs(rsp.usage)
return rsp
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
@ -226,56 +206,28 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
)
return messages
def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
"""Use function of tools to ask a code.
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
Examples:
>>> llm = OpenAIGPTAPI()
>>> llm.ask_code("Write a python hello world code.")
{'language': 'python', 'code': "print('Hello, World!')"}
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
>>> llm.ask_code(msg)
{'language': 'python', 'code': "print('Hello, World!')"}
"""
messages = self._process_message(messages)
rsp = self._chat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
"""Use function of tools to ask a code.
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
Examples:
>>> llm = OpenAIGPTAPI()
>>> rsp = await llm.ask_code("Write a python hello world code.")
>>> rsp
{'language': 'python', 'code': "print('Hello, World!')"}
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
>>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
>>> rsp = await llm.aask_code(msg)
# -> {'language': 'python', 'code': "print('Hello, World!')"}
"""
messages = self._process_message(messages)
try:
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
except openai.BadRequestError as e:
logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}")
raise e
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
@handle_exception
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
"""Required to provide the first function arguments of choice.
:return dict: return the first function arguments of choice, for example,
{'language': 'python', 'code': "print('Hello, World!')"}
"""
try:
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
except json.JSONDecodeError:
return {}
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
def get_choice_text(self, rsp: ChatCompletion) -> str:
"""Required to provide the first text of choice"""
@ -320,12 +272,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
logger.info(f"Result of task {idx}: {result}")
return results
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if CONFIG.calc_usage and usage:
try:
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
return CONFIG.cost_manager.get_costs()
@ -335,18 +285,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
def moderation(self, content: Union[str, list[str]]):
return self.client.moderations.create(input=content)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):
return await self.async_client.moderations.create(input=content)
async def close(self):
"""Close connection"""
if self.client:
self.client.close()
self.client = None
if self.async_client:
await self.async_client.close()
self.async_client = None
return await self.aclient.moderations.create(input=content)

View file

@ -50,9 +50,7 @@ class SparkGPTAPI(BaseGPTAPI):
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
# 不支持
logger.error("该功能禁用。")
w = GetMessageFromWeb(messages)

View file

@ -131,7 +131,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -25,10 +25,6 @@ class OpenAIText2Image:
self._llm = LLM()
self._client = self._llm.async_client
def __del__(self):
if self._llm:
self._llm.close()
async def text_2_image(self, text, size_type="1024x1024"):
"""Text to image

View file

@ -278,11 +278,11 @@ class UTGenerator:
question += self.build_api_doc(node, path, method)
self.ask_gpt_and_save(question, tag, summary)
def gpt_msgs_to_code(self, messages: list) -> str:
async def gpt_msgs_to_code(self, messages: list) -> str:
"""Choose based on different calling methods"""
result = ""
if self.chatgpt_method == "API":
result = GPTAPI().ask_code(msgs=messages)
result = await GPTAPI().aask_code(msgs=messages)
return result

View file

@ -34,7 +34,7 @@ class MockBaseGPTAPI(BaseGPTAPI):
async def acompletion(self, messages: list[dict], timeout=3):
return default_chat_resp
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
return resp_content
async def close(self):
@ -87,14 +87,14 @@ def test_base_gpt_api():
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
resp = base_gpt_api.ask(prompt_msg)
assert resp == resp_content
# resp = base_gpt_api.ask(prompt_msg)
# assert resp == resp_content
resp = base_gpt_api.ask_batch([prompt_msg])
assert resp == resp_content
# resp = base_gpt_api.ask_batch([prompt_msg])
# assert resp == resp_content
resp = base_gpt_api.ask_code([prompt_msg])
assert resp == resp_content
# resp = base_gpt_api.ask_code([prompt_msg])
# assert resp == resp_content
@pytest.mark.asyncio

View file

@ -55,17 +55,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return default_resp.choices[0].message.content
def test_fireworks_completion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
fireworks_gpt = FireWorksGPTAPI()
resp = fireworks_gpt.completion(messages)
assert resp.choices[0].message.content == resp_content
resp = fireworks_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_fireworks_acompletion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)

View file

@ -36,52 +36,6 @@ async def test_aask_code_Message():
assert len(rsp["code"]) > 0
def test_ask_code():
llm = OpenAIGPTAPI()
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_str():
llm = OpenAIGPTAPI()
msg = "Write a python hello world code."
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_Message():
llm = OpenAIGPTAPI()
msg = UserMessage("Write a python hello world code.")
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_list_Message():
llm = OpenAIGPTAPI()
msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_list_str():
llm = OpenAIGPTAPI()
msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
print(rsp)
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
class TestOpenAI:
@pytest.fixture
def config(self):