diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py
index af0cf2ec0..7351e6916 100644
--- a/metagpt/provider/base_gpt_api.py
+++ b/metagpt/provider/base_gpt_api.py
@@ -38,13 +38,13 @@ class BaseGPTAPI(BaseChatbot):
         rsp = self.completion(message)
         return self.get_choice_text(rsp)
 
-    async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
+    async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, generator: bool = False) -> str:
         if system_msgs:
             message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
         else:
             message = [self._default_system_msg(), self._user_msg(msg)]
         try:
-            rsp = await self.acompletion_text(message, stream=True)
+            rsp = await self.acompletion_text(message, stream=True, generator=generator)
         except Exception as e:
             logger.exception(f"{e}")
             logger.info(f"ask:{msg}, error:{e}")
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 5c11ed7a6..d0dd5b9d8 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -87,22 +87,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         response = await self.async_retry_call(
             openai.ChatCompletion.acreate, **self._cons_kwargs(messages), stream=True
         )
-        # create variables to collect the stream of chunks
-        collected_chunks = []
-        collected_messages = []
         # iterate through the stream of events
         async for chunk in response:
-            collected_chunks.append(chunk)  # save the event response
             chunk_message = chunk["choices"][0]["delta"]  # extract the message
-            collected_messages.append(chunk_message)  # save the message
             if "content" in chunk_message:
-                print(chunk_message["content"], end="")
-        print()
-
-        full_reply_content = "".join([m.get("content", "") for m in collected_messages])
-        usage = self._calc_usage(messages, full_reply_content)
-        self._update_costs(usage)
-        return full_reply_content
+                yield chunk_message["content"]
 
     def _cons_kwargs(self, messages: list[dict]) -> dict:
         if CONFIG.openai_api_type == "azure":
@@ -157,10 +146,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         retry=retry_if_exception_type(APIConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False) -> str:
         """when streaming, print each token in place."""
         if stream:
-            return await self._achat_completion_stream(messages)
+            resp = self._achat_completion_stream(messages)
+            if generator:
+                return resp
+
+            collected_messages = []
+            async for i in resp:
+                print(i, end="")
+                collected_messages.append(i)
+
+            full_reply_content = "".join(collected_messages)
+            usage = self._calc_usage(messages, full_reply_content)
+            self._update_costs(usage)
+            return full_reply_content
+
         rsp = await self._achat_completion(messages)
         return self.get_choice_text(rsp)
 
@@ -226,13 +228,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         max_count = 100
         while max_count > 0:
             if len(text) < max_token_count:
-                return await self._get_summary(text=text, max_words=max_words,keep_language=keep_language)
+                return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
 
             padding_size = 20 if max_token_count > 20 else 0
             text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
             summaries = []
             for ws in text_windows:
-                response = await self._get_summary(text=ws, max_words=max_words,keep_language=keep_language)
+                response = await self._get_summary(text=ws, max_words=max_words, keep_language=keep_language)
                 summaries.append(response)
             if len(summaries) == 1:
                 return summaries[0]