Merge pull request #29 from send18/feature-llm-aask-stream

add llm.aask generator
This commit is contained in:
send18 2023-09-04 17:48:51 +08:00 committed by GitHub
commit 230b1afb83
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 18 deletions

View file

@@ -38,13 +38,13 @@ class BaseGPTAPI(BaseChatbot):
rsp = self.completion(message)
return self.get_choice_text(rsp)
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, generator: bool = False) -> str:
if system_msgs:
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
else:
message = [self._default_system_msg(), self._user_msg(msg)]
try:
rsp = await self.acompletion_text(message, stream=True)
rsp = await self.acompletion_text(message, stream=True, generator=generator)
except Exception as e:
logger.exception(f"{e}")
logger.info(f"ask:{msg}, error:{e}")

View file

@@ -87,22 +87,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
response = await self.async_retry_call(
openai.ChatCompletion.acreate, **self._cons_kwargs(messages), stream=True
)
# create variables to collect the stream of chunks
collected_chunks = []
collected_messages = []
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
chunk_message = chunk["choices"][0]["delta"] # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
print()
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
yield chunk_message["content"]
def _cons_kwargs(self, messages: list[dict]) -> dict:
if CONFIG.openai_api_type == "azure":
@@ -157,10 +146,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)
resp = self._achat_completion_stream(messages)
if generator:
return resp
collected_messages = []
async for i in resp:
print(i, end="")
collected_messages.append(i)
full_reply_content = "".join(collected_messages)
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)
@@ -226,13 +228,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
max_count = 100
while max_count > 0:
if len(text) < max_token_count:
return await self._get_summary(text=text, max_words=max_words,keep_language=keep_language)
return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
summaries = []
for ws in text_windows:
response = await self._get_summary(text=ws, max_words=max_words,keep_language=keep_language)
response = await self._get_summary(text=ws, max_words=max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
return summaries[0]