remove generator para in acompletion_text

This commit is contained in:
geekan 2023-12-26 14:31:26 +08:00
parent 38f1c4f63b
commit 8351c8ec35
2 changed files with 4 additions and 7 deletions

View file

@ -43,7 +43,6 @@ class BaseGPTAPI(BaseChatbot):
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
generator: bool = False,
timeout=3,
stream=True,
) -> str:
@ -54,7 +53,7 @@ class BaseGPTAPI(BaseChatbot):
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
rsp = await self.acompletion_text(message, stream=stream, generator=generator, timeout=timeout)
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
# logger.debug(rsp)
return rsp

View file

@ -12,7 +12,7 @@
import asyncio
import json
import time
from typing import List, Union
from typing import AsyncIterator, List, Union
import openai
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
@ -123,7 +123,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return params
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
**self._cons_kwargs(messages, timeout=timeout), stream=True
)
@ -171,12 +171,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""when streaming, print each token in place."""
if stream:
resp = self._achat_completion_stream(messages, timeout=timeout)
if generator:
return resp
collected_messages = []
async for i in resp: