From a69f07e62f6f60da7701658877aa7cd199c34bcc Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 6 Jul 2023 00:01:47 +0800 Subject: [PATCH] support streaming. not enable yet. --- examples/llm_hello_world.py | 5 ++++- metagpt/provider/openai_api.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index eb4679b03..e62f8dc3c 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -17,11 +17,14 @@ async def main(): logger.info(await llm.aask('hello world')) logger.info(await llm.aask_batch(['hi', 'write python hello world.'])) - hello_msg = [{'role': 'user', 'content': 'hello'}] + hello_msg = [{'role': 'user', 'content': 'count from 1 to 10. split by newline.'}] logger.info(await llm.acompletion(hello_msg)) logger.info(await llm.acompletion_batch([hello_msg])) logger.info(await llm.acompletion_batch_text([hello_msg])) + logger.info(await llm.acompletion_text(hello_msg)) + await llm.acompletion_text(hello_msg, stream=True) + if __name__ == '__main__': asyncio.run(main()) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index d1401af7e..a19d0cac8 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -145,6 +145,31 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): openai.api_version = config.openai_api_version self.rpm = int(config.get("RPM", 10)) + async def _achat_completion_stream(self, messages: list[dict]) -> str: + response = await openai.ChatCompletion.acreate( + model=self.model, + messages=messages, + max_tokens=self.config.max_tokens_rsp, + n=1, + stop=None, + temperature=0, + stream=True + ) + + # create variables to collect the stream of chunks + collected_chunks = [] + collected_messages = [] + # iterate through the stream of events + async for chunk in response: + collected_chunks.append(chunk) # save the event response + chunk_message = chunk['choices'][0]['delta'] # extract the message + collected_messages.append(chunk_message) # save the message + if "content" in chunk_message: + print(chunk_message["content"], end="") + + full_reply_content = ''.join([m.get('content', '') for m in collected_messages]) + return full_reply_content + async def _achat_completion(self, messages: list[dict]) -> dict: rsp = await self.llm.ChatCompletion.acreate( model=self.model, @@ -180,7 +205,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): # messages = self.messages_to_dict(messages) return await self._achat_completion(messages) - async def acompletion_text(self, messages: list[dict]) -> str: + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + if stream: + return await self._achat_completion_stream(messages) rsp = await self._achat_completion(messages) return self.get_choice_text(rsp)