Mirror of https://github.com/FoundationAgents/MetaGPT.git, synced 2026-05-15 11:02:36 +02:00

Commit bb1b9823d0: remove sync api in openai
Parent: e15de55368
14 changed files with 52 additions and 207 deletions

@@ -10,12 +10,12 @@
 """
 
-from openai import AsyncAzureOpenAI, AzureOpenAI
-from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
+from openai import AsyncAzureOpenAI
+from openai._base_client import AsyncHttpxClientWrapper
 
-from metagpt.config import CONFIG, Config, LLMProviderEnum
+from metagpt.config import LLMProviderEnum
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
+from metagpt.provider.openai_api import OpenAIGPTAPI
 
 
 @register_provider(LLMProviderEnum.AZURE_OPENAI)

@@ -24,46 +24,22 @@ class AzureOpenAIGPTAPI(OpenAIGPTAPI):
     Check https://platform.openai.com/examples for examples
     """
 
-    def __init__(self):
-        self.config: Config = CONFIG
-        self._init_openai()
-        self.auto_max_tokens = False
-        RateLimiter.__init__(self, rpm=self.rpm)
-
-    def _make_client(self):
-        kwargs, async_kwargs = self._make_client_kwargs()
+    def _init_client(self):
+        kwargs = self._make_client_kwargs()
         # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
-        self.client = AzureOpenAI(**kwargs)
-        self.async_client = AsyncAzureOpenAI(**async_kwargs)
+        self.async_client = AsyncAzureOpenAI(**kwargs)
         self.model = self.config.DEPLOYMENT_NAME  # Used in _calc_usage & _cons_kwargs
 
-    def _make_client_kwargs(self) -> (dict, dict):
+    def _make_client_kwargs(self) -> dict:
         kwargs = dict(
             api_key=self.config.OPENAI_API_KEY,
             api_version=self.config.OPENAI_API_VERSION,
             azure_endpoint=self.config.OPENAI_BASE_URL,
         )
-        async_kwargs = kwargs.copy()
 
         # to use proxy, openai v1 needs http_client
         proxy_params = self._get_proxy_params()
         if proxy_params:
-            kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
-            async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
 
-        return kwargs, async_kwargs
-
-    def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
-        kwargs = {
-            "messages": messages,
-            "max_tokens": self.get_max_tokens(messages),
-            "n": 1,
-            "stop": None,
-            "temperature": 0.3,
-            "model": self.model,
-        }
-        if configs:
-            kwargs.update(configs)
-        kwargs["timeout"] = max(CONFIG.timeout, timeout)
-
-        return kwargs
+        return kwargs

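With the sync `AzureOpenAI` client gone, `_make_client_kwargs` returns a single dict and the proxy wrapper is async-only. A minimal sketch of the resulting construction, outside the class (config field names follow the diff; the `proxies` keyword for the httpx wrapper is an assumption, not taken from the commit):

```python
from openai import AsyncAzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper

def make_async_azure_client(api_key: str, api_version: str, endpoint: str,
                            proxy: str = "") -> AsyncAzureOpenAI:
    kwargs = dict(api_key=api_key, api_version=api_version, azure_endpoint=endpoint)
    if proxy:
        # openai v1 takes an explicit http_client when traffic must go through a proxy;
        # the `proxies` kwarg is an httpx-style assumption, not shown in the diff
        kwargs["http_client"] = AsyncHttpxClientWrapper(proxies=proxy)
    return AsyncAzureOpenAI(**kwargs)
```
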
@@ -112,7 +112,7 @@ class BaseGPTAPI(BaseChatbot):
     """
 
     @abstractmethod
-    async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """Asynchronous version of completion. Return str. Support stream-print"""
 
     def get_choice_text(self, rsp: dict) -> str:

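The `generator` flag disappears from the abstract interface; streaming is now governed by `stream` alone. A call-site sketch against the narrowed signature (the provider class and its import path are taken from this diff; the message content is illustrative):

```python
import asyncio

from metagpt.provider.openai_api import OpenAIGPTAPI  # import path as in this diff

async def main() -> None:
    llm = OpenAIGPTAPI()
    # New signature: no `generator` argument; `stream` alone controls streaming.
    text = await llm.acompletion_text([{"role": "user", "content": "ping"}], stream=False)
    print(text)

asyncio.run(main())
```
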
@@ -83,7 +83,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
     def __init_fireworks(self):
         self.is_azure = False
         self.rpm = int(self.config.get("RPM", 10))
-        self._make_client()
+        self._init_client()
         self.model = self.config.fireworks_api_model  # `self.model` should after `_make_client` to rewrite it
 
     def _make_client_kwargs(self) -> (dict, dict):

@@ -103,7 +103,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
         return self._cost_manager.get_costs()
 
     async def _achat_completion_stream(self, messages: list[dict]) -> str:
-        response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
+        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages), stream=True
         )
 

@@ -133,9 +133,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
         retry=retry_if_exception_type(APIConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(
-        self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
-    ) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         """when streaming, print each token in place."""
         if stream:
             return await self._achat_completion_stream(messages)

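`_achat_completion_stream` now goes through the renamed `self.aclient`, whose `create(..., stream=True)` call yields an `AsyncStream[ChatCompletionChunk]`. A sketch of consuming such a stream with the openai v1 SDK (model name and messages are placeholders):

```python
from openai import AsyncOpenAI

async def stream_text(aclient: AsyncOpenAI, messages: list[dict], model: str) -> str:
    response = await aclient.chat.completions.create(model=model, messages=messages, stream=True)
    collected = []
    async for chunk in response:
        delta = chunk.choices[0].delta.content  # None for role-only and final chunks
        if delta:
            print(delta, end="")  # stream-print each token in place
            collected.append(delta)
    return "".join(collected)
```
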
@@ -136,9 +136,7 @@ class GeminiGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(
-        self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
-    ) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)

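The tenacity decorator stack above `acompletion_text` stays in place while the signature is narrowed; the same one-line change repeats for the Ollama, Spark, and ZhiPu providers below. The retry pattern itself, in isolation (attempt count and wait are illustrative; MetaGPT's `log_and_reraise` callback is omitted here):

```python
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

@retry(
    stop=stop_after_attempt(3),
    wait=wait_fixed(1),
    retry=retry_if_exception_type(ConnectionError),  # retry only on this exception type
)
async def fragile_call() -> str:
    raise ConnectionError("transient")  # placeholder body; tenacity retries coroutines natively
```
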
@@ -147,9 +147,7 @@ class OllamaGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(
-        self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
-    ) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)

@@ -46,7 +46,7 @@ class OpenLLMGPTAPI(OpenAIGPTAPI):
     def __init_openllm(self):
         self.is_azure = False
         self.rpm = int(self.config.get("RPM", 10))
-        self._make_client()
+        self._init_client()
         self.model = self.config.open_llm_api_model  # `self.model` should after `_make_client` to rewrite it
 
     def _make_client_kwargs(self) -> (dict, dict):

@@ -14,9 +14,8 @@ import json
 import time
 from typing import AsyncIterator, Union
 
-import openai
-from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
-from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
+from openai import APIConnectionError, AsyncOpenAI, AsyncStream
+from openai._base_client import AsyncHttpxClientWrapper
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from tenacity import (

@@ -80,9 +79,7 @@ See FAQ 5.8
 
 @register_provider(LLMProviderEnum.OPENAI)
 class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
-    """
-    Check https://platform.openai.com/examples for examples
-    """
+    """Check https://platform.openai.com/examples for examples"""
 
     def __init__(self):
         self.config: Config = CONFIG

@@ -91,27 +88,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         RateLimiter.__init__(self, rpm=self.rpm)
 
     def _init_openai(self):
-        self.rpm = int(self.config.RPM or 10)
-        self._make_client()
+        self.rpm = int(self.config.openai_api_rpm)
+        self._init_client()
 
-    def _make_client(self):
-        kwargs, async_kwargs = self._make_client_kwargs()
+    def _init_client(self):
+        kwargs = self._make_client_kwargs()
         # https://github.com/openai/openai-python#async-usage
-        self.client = OpenAI(**kwargs)
-        self.async_client = AsyncOpenAI(**async_kwargs)
+        self.aclient = AsyncOpenAI(**kwargs)
         self.model = self.config.OPENAI_API_MODEL  # Used in _calc_usage & _cons_kwargs
 
-    def _make_client_kwargs(self) -> (dict, dict):
-        kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL)
-        async_kwargs = kwargs.copy()
+    def _make_client_kwargs(self) -> dict:
+        kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
 
         # to use proxy, openai v1 needs http_client
-        proxy_params = self._get_proxy_params()
-        if proxy_params:
-            kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
-            async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
+        if proxy_params := self._get_proxy_params():
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
 
-        return kwargs, async_kwargs
+        return kwargs
 
     def _get_proxy_params(self) -> dict:
         params = {}

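`_init_client` now builds only `self.aclient`, and proxy configuration rides on the same kwargs via the walrus binding. With no sync `OpenAI` client left, one-off callers reach the API through an event loop; a sketch against the openai v1 SDK (key and model are placeholders):

```python
import asyncio

from openai import AsyncOpenAI

async def main() -> None:
    aclient = AsyncOpenAI(api_key="sk-...", base_url="https://api.openai.com/v1")
    rsp = await aclient.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": "ping"}],
    )
    print(rsp.choices[0].message.content)

asyncio.run(main())
```
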
@@ -123,7 +116,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         return params
 
     async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
-        response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
+        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=timeout), stream=True
         )
 

@@ -148,18 +141,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
 
     async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
         kwargs = self._cons_kwargs(messages, timeout=timeout)
-        rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
+        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
         self._update_costs(rsp.usage)
         return rsp
 
-    def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
-        rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout))
-        self._update_costs(rsp.usage)
-        return rsp
-
-    def completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
-        return self._chat_completion(messages, timeout=timeout)
-
     async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
         return await self._achat_completion(messages, timeout=timeout)
 

@@ -199,14 +184,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
 
         return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
 
-    def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion:
-        rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
-        self._update_costs(rsp.usage)
-        return rsp
-
     async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
         kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
-        rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
+        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
         self._update_costs(rsp.usage)
         return rsp
 

@@ -226,56 +206,28 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         )
         return messages
 
-    def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
-        """Use function of tools to ask a code.
-
-        Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
-
-        Examples:
-
-        >>> llm = OpenAIGPTAPI()
-        >>> llm.ask_code("Write a python hello world code.")
-        {'language': 'python', 'code': "print('Hello, World!')"}
-        >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
-        >>> llm.ask_code(msg)
-        {'language': 'python', 'code': "print('Hello, World!')"}
-        """
-        messages = self._process_message(messages)
-        rsp = self._chat_completion_function(messages, **kwargs)
-        return self.get_choice_function_arguments(rsp)
-
     async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
         """Use function of tools to ask a code.
 
-        Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
+        Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
 
         Examples:
-
-        >>> llm = OpenAIGPTAPI()
-        >>> rsp = await llm.ask_code("Write a python hello world code.")
-        >>> rsp
-        {'language': 'python', 'code': "print('Hello, World!')"}
         >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
-        >>> rsp = await llm.aask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
+        >>> rsp = await llm.aask_code(msg)
+        # -> {'language': 'python', 'code': "print('Hello, World!')"}
         """
         messages = self._process_message(messages)
-        try:
-            rsp = await self._achat_completion_function(messages, **kwargs)
-            return self.get_choice_function_arguments(rsp)
-        except openai.BadRequestError as e:
-            logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}")
-            raise e
+        rsp = await self._achat_completion_function(messages, **kwargs)
+        return self.get_choice_function_arguments(rsp)
 
+    @handle_exception
     def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
         """Required to provide the first function arguments of choice.
 
         :return dict: return the first function arguments of choice, for example,
             {'language': 'python', 'code': "print('Hello, World!')"}
         """
-        try:
-            return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
-        except json.JSONDecodeError:
-            return {}
+        return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
 
     def get_choice_text(self, rsp: ChatCompletion) -> str:
         """Required to provide the first text of choice"""

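The synchronous `ask_code` is deleted, and `aask_code` sheds its inline try/except in favor of the `@handle_exception` decorator on `get_choice_function_arguments`. Usage now always awaits, per the updated docstring; a sketch (assumes a configured provider; import path as in this diff):

```python
import asyncio

from metagpt.provider.openai_api import OpenAIGPTAPI  # import path as in this diff

async def demo() -> None:
    llm = OpenAIGPTAPI()
    rsp = await llm.aask_code("Write a python hello world code.")
    # expected shape, per the docstring: {'language': 'python', 'code': "print('Hello, World!')"}
    print(rsp["language"], rsp["code"])

asyncio.run(demo())
```
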
@@ -320,12 +272,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
             logger.info(f"Result of task {idx}: {result}")
         return results
 
+    @handle_exception
     def _update_costs(self, usage: CompletionUsage):
         if CONFIG.calc_usage and usage:
-            try:
-                CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
-            except Exception as e:
-                logger.error(f"updating costs failed!, exp: {e}")
+            CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
 
     def get_costs(self) -> Costs:
         return CONFIG.cost_manager.get_costs()

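`_update_costs` swaps its inline try/except for `@handle_exception`, matching `get_choice_function_arguments` above. The real decorator comes from MetaGPT's utilities; a minimal stand-in showing the log-and-continue shape it enables (purely illustrative, not the project's implementation):

```python
import functools
import logging

logger = logging.getLogger(__name__)

def handle_exception(fn):
    """Illustrative stand-in: log the failure and return None instead of raising."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logger.error(f"{fn.__name__} failed: {e}")
            return None
    return wrapper
```
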
@@ -335,18 +285,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
             return CONFIG.max_tokens_rsp
         return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
 
-    def moderation(self, content: Union[str, list[str]]):
-        return self.client.moderations.create(input=content)
-
     @handle_exception
     async def amoderation(self, content: Union[str, list[str]]):
-        return await self.async_client.moderations.create(input=content)
-
-    async def close(self):
-        """Close connection"""
-        if self.client:
-            self.client.close()
-            self.client = None
-        if self.async_client:
-            await self.async_client.close()
-            self.async_client = None
+        return await self.aclient.moderations.create(input=content)

@@ -50,9 +50,7 @@ class SparkGPTAPI(BaseGPTAPI):
     def get_choice_text(self, rsp: dict) -> str:
         return rsp["payload"]["choices"]["text"][-1]["content"]
 
-    async def acompletion_text(
-        self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
-    ) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         # not supported
         logger.error("This feature is disabled.")
         w = GetMessageFromWeb(messages)

@@ -131,7 +131,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)

@@ -25,10 +25,6 @@ class OpenAIText2Image:
         self._llm = LLM()
         self._client = self._llm.async_client
 
-    def __del__(self):
-        if self._llm:
-            self._llm.close()
-
     async def text_2_image(self, text, size_type="1024x1024"):
         """Text to image
 

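`__del__` relied on the provider's sync `close()`, which this commit deletes; a destructor cannot await, so an async client cannot be closed from `__del__` anyway. Explicit shutdown becomes the caller's job (sketch; the removed `close()` method above shows that the async client's `close()` is awaitable in openai v1):

```python
async def shutdown(aclient) -> None:
    await aclient.close()  # close the async HTTP transport explicitly
```
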
@@ -278,11 +278,11 @@ class UTGenerator:
         question += self.build_api_doc(node, path, method)
         self.ask_gpt_and_save(question, tag, summary)
 
-    def gpt_msgs_to_code(self, messages: list) -> str:
+    async def gpt_msgs_to_code(self, messages: list) -> str:
         """Choose based on different calling methods"""
         result = ""
         if self.chatgpt_method == "API":
-            result = GPTAPI().ask_code(msgs=messages)
+            result = await GPTAPI().aask_code(msgs=messages)
 
         return result
 

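Turning `gpt_msgs_to_code` into a coroutine propagates upward: every caller must now await it from async code. A sketch (the `ut_generator` instance and `messages` are illustrative):

```python
# Inside some coroutine that holds a UTGenerator instance:
result = await ut_generator.gpt_msgs_to_code(messages)
```
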
@@ -34,7 +34,7 @@ class MockBaseGPTAPI(BaseGPTAPI):
     async def acompletion(self, messages: list[dict], timeout=3):
         return default_chat_resp
 
-    async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         return resp_content
 
     async def close(self):

@@ -87,14 +87,14 @@ def test_base_gpt_api():
     choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
     assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
 
-    resp = base_gpt_api.ask(prompt_msg)
-    assert resp == resp_content
+    # resp = base_gpt_api.ask(prompt_msg)
+    # assert resp == resp_content
 
-    resp = base_gpt_api.ask_batch([prompt_msg])
-    assert resp == resp_content
+    # resp = base_gpt_api.ask_batch([prompt_msg])
+    # assert resp == resp_content
 
-    resp = base_gpt_api.ask_code([prompt_msg])
-    assert resp == resp_content
+    # resp = base_gpt_api.ask_code([prompt_msg])
+    # assert resp == resp_content
 
 
 @pytest.mark.asyncio

@@ -55,17 +55,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
     return default_resp.choices[0].message.content
 
 
-def test_fireworks_completion(mocker):
-    mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
-    fireworks_gpt = FireWorksGPTAPI()
-
-    resp = fireworks_gpt.completion(messages)
-    assert resp.choices[0].message.content == resp_content
-
-    resp = fireworks_gpt.ask(prompt_msg)
-    assert resp == resp_content
-
-
 @pytest.mark.asyncio
 async def test_fireworks_acompletion(mocker):
     mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)

@@ -36,52 +36,6 @@ async def test_aask_code_Message():
     assert len(rsp["code"]) > 0
 
 
-def test_ask_code():
-    llm = OpenAIGPTAPI()
-    msg = [{"role": "user", "content": "Write a python hello world code."}]
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
-def test_ask_code_str():
-    llm = OpenAIGPTAPI()
-    msg = "Write a python hello world code."
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
-def test_ask_code_Message():
-    llm = OpenAIGPTAPI()
-    msg = UserMessage("Write a python hello world code.")
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
-def test_ask_code_list_Message():
-    llm = OpenAIGPTAPI()
-    msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
-def test_ask_code_list_str():
-    llm = OpenAIGPTAPI()
-    msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
-    print(rsp)
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
 class TestOpenAI:
     @pytest.fixture
     def config(self):