From 8351c8ec3511dcfba6667aa5413bc895f42593ed Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 26 Dec 2023 14:31:26 +0800 Subject: [PATCH] remove generator para in acompletion_text --- metagpt/provider/base_gpt_api.py | 3 +-- metagpt/provider/openai_api.py | 8 +++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index a5541324f..c7417af90 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -43,7 +43,6 @@ class BaseGPTAPI(BaseChatbot): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, - generator: bool = False, timeout=3, stream=True, ) -> str: @@ -54,7 +53,7 @@ class BaseGPTAPI(BaseChatbot): if format_msgs: message.extend(format_msgs) message.append(self._user_msg(msg)) - rsp = await self.acompletion_text(message, stream=stream, generator=generator, timeout=timeout) + rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) # logger.debug(rsp) return rsp diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 195d2ea16..405d523e5 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -12,7 +12,7 @@ import asyncio import json import time -from typing import List, Union +from typing import AsyncIterator, List, Union import openai from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI @@ -123,7 +123,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return params - async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: + async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]: response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( **self._cons_kwargs(messages, timeout=timeout), stream=True ) @@ -171,12 +171,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): retry=retry_if_exception_type(APIConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: """when streaming, print each token in place.""" if stream: resp = self._achat_completion_stream(messages, timeout=timeout) - if generator: - return resp collected_messages = [] async for i in resp: