support streaming. not enable yet.

This commit is contained in:
geekan 2023-07-06 00:01:47 +08:00
parent 504de57b29
commit a69f07e62f
2 changed files with 32 additions and 2 deletions

View file

@ -17,11 +17,14 @@ async def main():
logger.info(await llm.aask('hello world'))
logger.info(await llm.aask_batch(['hi', 'write python hello world.']))
hello_msg = [{'role': 'user', 'content': 'hello'}]
hello_msg = [{'role': 'user', 'content': 'count from 1 to 10. split by newline.'}]
logger.info(await llm.acompletion(hello_msg))
logger.info(await llm.acompletion_batch([hello_msg]))
logger.info(await llm.acompletion_batch_text([hello_msg]))
logger.info(await llm.acompletion_text(hello_msg))
await llm.acompletion_text(hello_msg, stream=True)
if __name__ == '__main__':
asyncio.run(main())

View file

@ -145,6 +145,31 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
openai.api_version = config.openai_api_version
self.rpm = int(config.get("RPM", 10))
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response = await openai.ChatCompletion.acreate(
model=self.model,
messages=messages,
max_tokens=self.config.max_tokens_rsp,
n=1,
stop=None,
temperature=0,
stream=True
)
# create variables to collect the stream of chunks
collected_chunks = []
collected_messages = []
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
chunk_message = chunk['choices'][0]['delta'] # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
return full_reply_content
async def _achat_completion(self, messages: list[dict]) -> dict:
rsp = await self.llm.ChatCompletion.acreate(
model=self.model,
@ -180,7 +205,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# messages = self.messages_to_dict(messages)
return await self._achat_completion(messages)
async def acompletion_text(self, messages: list[dict]) -> str:
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
if stream:
return await self._achat_completion_stream(messages)
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)