From 4007fc87d6ce2d016445fe0d675284ebcbec33ca Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 26 Dec 2023 16:33:15 +0800 Subject: [PATCH] remove sync api in openai --- metagpt/provider/base_gpt_api.py | 6 -- metagpt/provider/openai_api.py | 57 ++++++------------- tests/metagpt/test_llm.py | 7 --- .../metagpt/utils/test_custom_aio_session.py | 21 ------- 4 files changed, 16 insertions(+), 75 deletions(-) delete mode 100644 tests/metagpt/utils/test_custom_aio_session.py diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index 90cf59fd4..cae55431f 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -90,7 +90,6 @@ class BaseGPTAPI(BaseChatbot): rsp_text = await self.aask_batch(msgs, timeout=timeout) return rsp_text - @abstractmethod def completion(self, messages: list[dict], timeout=3): """All GPTAPIs are required to provide the standard OpenAI completion interface [ @@ -166,8 +165,3 @@ class BaseGPTAPI(BaseChatbot): def messages_to_dict(self, messages): """objects to [{"role": "user", "content": msg}] etc.""" return [i.to_dict() for i in messages] - - @abstractmethod - async def close(self): - """Close connection""" - pass diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index ea58f690b..bfd6c7917 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -84,20 +84,21 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def __init__(self): self.config: Config = CONFIG self._init_openai() + self._init_client() self.auto_max_tokens = False RateLimiter.__init__(self, rpm=self.rpm) + super().__init__() def _init_openai(self): self.rpm = int(self.config.openai_api_rpm) - self._init_client() - - def _init_client(self): - kwargs = self._make_client_kwargs() - # https://github.com/openai/openai-python#async-usage - self.aclient = AsyncOpenAI(**kwargs) self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs - def _make_client_kwargs(self) -> (dict, dict): + def _init_client(self): + """https://github.com/openai/openai-python#async-usage""" + kwargs = self._make_client_kwargs() + self.aclient = AsyncOpenAI(**kwargs) + + def _make_client_kwargs(self) -> dict: kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL} # to use proxy, openai v1 needs http_client @@ -124,19 +125,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message yield chunk_message - def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict: + def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict: kwargs = { "messages": messages, - "max_tokens": self.get_max_tokens(messages), + "max_tokens": self._get_max_tokens(messages), "n": 1, "stop": None, "temperature": 0.3, "model": self.model, + "timeout": max(CONFIG.timeout, timeout), } - if configs: - kwargs.update(configs) - kwargs["timeout"] = max(CONFIG.timeout, timeout) - + if extra_kwargs: + kwargs.update(extra_kwargs) return kwargs async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion: @@ -242,36 +242,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): usage.prompt_tokens = count_message_tokens(messages, self.model) usage.completion_tokens = count_string_tokens(rsp, self.model) except Exception as e: - logger.error(f"usage calculation failed!: {e}") + logger.error(f"usage calculation failed: {e}") return usage - async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]: - """Return full JSON""" - split_batches = self.split_batches(batch) - all_results = [] - - for small_batch in split_batches: - logger.info(small_batch) - await self.wait_if_needed(len(small_batch)) - - future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch] - results = await asyncio.gather(*future) - logger.info(results) - all_results.extend(results) - - return all_results - - async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]: - """Only return plain text""" - raw_results = await self.acompletion_batch(batch, timeout=timeout) - results = [] - for idx, raw_result in enumerate(raw_results, start=1): - result = self.get_choice_text(raw_result) - results.append(result) - logger.info(f"Result of task {idx}: {result}") - return results - @handle_exception def _update_costs(self, usage: CompletionUsage): if CONFIG.calc_usage and usage: @@ -280,11 +254,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def get_costs(self) -> Costs: return CONFIG.cost_manager.get_costs() - def get_max_tokens(self, messages: list[dict]): + def _get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: return CONFIG.max_tokens_rsp return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) @handle_exception async def amoderation(self, content: Union[str, list[str]]): + """Moderate content.""" return await self.aclient.moderations.create(input=content) diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index 31e6c2b24..bc685ed8b 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -23,18 +23,11 @@ async def test_llm_aask(llm): assert len(rsp) > 0 -@pytest.mark.asyncio -async def test_llm_aask_batch(llm): - assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0 - - @pytest.mark.asyncio async def test_llm_acompletion(llm): hello_msg = [{"role": "user", "content": "hello"}] rsp = await llm.acompletion(hello_msg) assert len(rsp.choices[0].message.content) > 0 - assert len(await llm.acompletion_batch([hello_msg])) > 0 - assert len(await llm.acompletion_batch_text([hello_msg])) > 0 if __name__ == "__main__": diff --git a/tests/metagpt/utils/test_custom_aio_session.py b/tests/metagpt/utils/test_custom_aio_session.py deleted file mode 100644 index e2876e4b8..000000000 --- a/tests/metagpt/utils/test_custom_aio_session.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/7 17:23 -@Author : alexanderwu -@File : test_custom_aio_session.py -""" -from metagpt.logs import logger -from metagpt.provider.openai_api import OpenAIGPTAPI - - -async def try_hello(api): - batch = [[{"role": "user", "content": "hello"}]] - results = await api.acompletion_batch_text(batch) - return results - - -async def aask_batch(api: OpenAIGPTAPI): - results = await api.aask_batch(["hi", "write python hello world."]) - logger.info(results) - return results