fixbug: OpenAIGPTAPI:_achat_completion_stream

莘权 马 2023-12-22 17:43:59 +08:00
parent b445c3f4b6
commit 5d97a20e08
4 changed files with 358 additions and 357 deletions


@@ -93,13 +93,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         self._client = AsyncOpenAI(api_key=CONFIG.openai_api_key, base_url=CONFIG.openai_api_base)
         RateLimiter.__init__(self, rpm=self.rpm)
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
-        kwargs = self._cons_kwargs(messages, timeout=timeout)
-        response = await self._client.chat.completions.create(**kwargs, stream=True)
-        # iterate through the stream of events
-        async for chunk in response:
-            chunk_message = chunk.choices[0].delta.content or ""  # extract the message
-            yield chunk_message
+    # async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
+    #     kwargs = self._cons_kwargs(messages, timeout=timeout)
+    #     response = await self._client.chat.completions.create(**kwargs, stream=True)
+    #     # iterate through the stream of events
+    #     async for chunk in response:
+    #         chunk_message = chunk.choices[0].delta.content or ""  # extract the message
+    #         yield chunk_message
 
     def __init_openai(self):
         self.rpm = int(self.config.get("RPM", 10))
@@ -131,9 +131,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         return params
 
-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
-        response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
-            **self._cons_kwargs(messages), stream=True
+    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
+        response: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create(
+            **self._cons_kwargs(messages, timeout=timeout), stream=True
         )
         # create variables to collect the stream of chunks
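The hunk ends before the loop that gathers the chunks, but the direction of the fix is visible: _achat_completion_stream now accepts a timeout, uses the self._client created in __init__ rather than self.async_client, and collects the streamed deltas into a single string instead of yielding them. A minimal, self-contained sketch of that collecting pattern (not the repo's exact code; the function shape and model name are illustrative) could look like this:

# Hedged sketch: collect an OpenAI chat-completion stream into one string,
# assuming the openai>=1.x async client; model name and signature are illustrative.
from openai import AsyncOpenAI


async def achat_completion_stream(client: AsyncOpenAI, messages: list[dict],
                                  model: str = "gpt-3.5-turbo", timeout: float = 3) -> str:
    response = await client.chat.completions.create(
        model=model, messages=messages, stream=True, timeout=timeout
    )
    collected = []  # gather each streamed delta instead of yielding it
    async for chunk in response:
        collected.append(chunk.choices[0].delta.content or "")
    return "".join(collected)

Returning a plain str keeps callers that await a string (such as acompletion_text) working, which is presumably why the yield-based generator version above is commented out.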


@@ -70,22 +70,22 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
         assert assist_msg["role"] == "assistant"
         return assist_msg.get("content")
 
-    def completion(self, messages: list[dict]) -> dict:
+    def completion(self, messages: list[dict], timeout=3) -> dict:
         resp = self.llm.invoke(**self._const_kwargs(messages))
         usage = resp.get("data").get("usage")
         self._update_costs(usage)
         return resp
 
-    async def _achat_completion(self, messages: list[dict]) -> dict:
+    async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
         resp = await self.llm.ainvoke(**self._const_kwargs(messages))
         usage = resp.get("data").get("usage")
         self._update_costs(usage)
         return resp
 
-    async def acompletion(self, messages: list[dict]) -> dict:
-        return await self._achat_completion(messages)
+    async def acompletion(self, messages: list[dict], timeout=3) -> dict:
+        return await self._achat_completion(messages, timeout=timeout)
 
-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
         response = await self.llm.asse_invoke(**self._const_kwargs(messages))
         collected_content = []
         usage = {}
@@ -128,9 +128,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
-            return await self._achat_completion_stream(messages)
+            return await self._achat_completion_stream(messages, timeout=timeout)
         resp = await self._achat_completion(messages)
         return self.get_choice_text(resp)
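In the ZhiPuAI wrapper the change is mostly mechanical: completion, _achat_completion, acompletion, and _achat_completion_stream all gain a timeout=3 parameter, acompletion_text additionally gains a generator flag, and the stream branch forwards the timeout down the chain. A small runnable sketch of that threading (the real zhipuai calls are replaced by a stand-in coroutine, and the response shape here is only a placeholder) could look like this:

import asyncio


async def _achat_completion(messages: list[dict], timeout: float = 3) -> dict:
    await asyncio.sleep(0)  # stand-in for the real self.llm.ainvoke(...) call
    return {"content": "ok", "usage": {}}  # placeholder response shape


async def _achat_completion_stream(messages: list[dict], timeout: float = 3) -> str:
    resp = await _achat_completion(messages, timeout=timeout)
    return resp["content"]


async def acompletion_text(messages: list[dict], stream: bool = False,
                           generator: bool = False, timeout: float = 3) -> str:
    """Async text completion in stream or non-stream mode."""
    # `generator` is accepted but, as in the hunk shown, not used here.
    if stream:
        # the stream branch forwards the timeout, as in the hunk above
        return await _achat_completion_stream(messages, timeout=timeout)
    # the non-stream branch keeps the default timeout, matching the diff
    resp = await _achat_completion(messages)
    return resp["content"]


print(asyncio.run(acompletion_text([{"role": "user", "content": "hi"}], stream=True)))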