fixbug: OpenAIGPTAPI:_achat_completion_stream

This commit is contained in:
莘权 马 2023-12-22 17:43:59 +08:00
parent b445c3f4b6
commit 5d97a20e08
4 changed files with 358 additions and 357 deletions

View file

@ -7,341 +7,340 @@
@Desc : Support memory for multiple tasks and multiple mainlines. Obsoleted by `utils/*_repository.py`.
@Modified By: mashenquan, 2023/9/4. + redis memory cache.
"""
import json
import re
from enum import Enum
from typing import ClassVar, Dict, List, Optional

import openai
import pydantic

from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS
from metagpt.llm import LLMType
from metagpt.logs import logger
from metagpt.schema import Message, RawMessage
from metagpt.utils.redis import Redis
class MessageType(Enum):
    """Tags used to label messages; each member's value is its upper-case name."""

    # `Talk` is the tag `BrainMemory.to_metagpt_history_format` looks for
    # when deciding whether a message is attributed to the user.
    Talk = "TALK"
    Solution = "SOLUTION"
    Problem = "PROBLEM"
    Skill = "SKILL"
    Answer = "ANSWER"
class BrainMemory(pydantic.BaseModel):
    """Conversation memory: message history, a rolling summary and knowledge.

    The model can be restored from and persisted to Redis as JSON (see
    `loads` / `dumps`). `is_dirty` records whether there are changes that
    have not been written back yet.
    """

    # Fallback window size used by `split_texts` when the caller passes a
    # non-positive size. Declared as a ClassVar so pydantic does not treat it
    # as a model field and it stays readable as `BrainMemory.DEFAULT_TOKEN_SIZE`.
    DEFAULT_TOKEN_SIZE: ClassVar[int] = 500

    history: List[Dict] = []  # messages stored as `Message.dict()`, in append order
    stack: List[Dict] = []
    solution: List[Dict] = []
    knowledge: List[Dict] = []  # message dicts whose contents `get_knowledge` joins
    historical_summary: str = ""  # summary that stands in for cleared history
    last_history_id: str = ""  # id of the most recently appended message
    is_dirty: bool = False  # True while there are changes not yet dumped
    last_talk: Optional[str] = None  # read and reset by `pop_last_talk`
    llm_type: Optional[str] = None  # compared with `LLMType.METAGPT.value` to pick a code path
    cacheable: bool = True  # when False, `dumps` skips the Redis write

    def add_talk(self, msg: Message):
        """Record `msg` in the history with role "user"."""
        msg.role = "user"
        self.add_history(msg)
        self.is_dirty = True

    def add_answer(self, msg: Message):
        """Record `msg` in the history with role "assistant"."""
        msg.role = "assistant"
        self.add_history(msg)
        self.is_dirty = True

    def get_knowledge(self) -> str:
        """Return the contents of all knowledge messages joined by newlines."""
        return "\n".join(Message(**m).content for m in self.knowledge)

    @staticmethod
    async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory":
        """Load a memory from Redis.

        Returns a fresh memory tagged with `CONFIG.LLM_TYPE` when Redis is not
        usable, `redis_key` is empty, or nothing is stored under the key.
        """
        redis = Redis(conf=redis_conf)
        if not redis.is_valid() or not redis_key:
            return BrainMemory(llm_type=CONFIG.LLM_TYPE)
        v = await redis.get(key=redis_key)
        logger.debug(f"REDIS GET {redis_key} {v}")
        if v:
            bm = BrainMemory(**json.loads(v))
            # Freshly loaded state matches what is stored.
            bm.is_dirty = False
            return bm
        return BrainMemory(llm_type=CONFIG.LLM_TYPE)

    async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None):
        """Write this memory to Redis as JSON if it has unsaved changes.

        Returns False when Redis is not usable or `redis_key` is empty, and
        None otherwise (nothing to save, or the save path ran).
        """
        if not self.is_dirty:
            return
        redis = Redis(conf=redis_conf)
        if not redis.is_valid() or not redis_key:
            return False
        v = self.json()
        if self.cacheable:
            await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec)
            logger.debug(f"REDIS SET {redis_key} {v}")
        self.is_dirty = False

    @staticmethod
    def to_redis_key(prefix: str, user_id: str, chat_id: str):
        """Build the Redis key `prefix:user_id:chat_id`."""
        return f"{prefix}:{user_id}:{chat_id}"

    async def set_history_summary(self, history_summary, redis_key, redis_conf):
        """Replace the history with `history_summary` and persist the result.

        When the summary is unchanged, only pending changes (if any) are
        flushed and the history is left untouched.
        """
        if self.historical_summary == history_summary:
            if self.is_dirty:
                await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
                self.is_dirty = False
            return

        self.historical_summary = history_summary
        self.history = []
        # Mark the change before dumping: `dumps` is a no-op on a clean object,
        # so without this a new summary on a clean memory was never stored.
        self.is_dirty = True
        await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
        self.is_dirty = False

    def add_history(self, msg: Message):
        """Append `msg` to the history unless its id is not newer than the last one.

        Ids are compared numerically via `to_int`; a message with a falsy id is
        always appended.
        """
        if msg.id and self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
            return
        self.history.append(msg.dict())
        self.last_history_id = str(msg.id)
        self.is_dirty = True

    def exists(self, text) -> bool:
        """Return True if any history entry has exactly `text` as its content."""
        return any(m.get("content") == text for m in self.history)

    @staticmethod
    def to_int(v, default_value):
        """Convert `v` with `int()`, returning `default_value` if conversion fails."""
        try:
            return int(v)
        except (TypeError, ValueError, OverflowError):
            return default_value

    def pop_last_talk(self):
        """Return `last_talk` and reset it to None."""
        v = self.last_talk
        self.last_talk = None
        return v

    async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
        """Summarize the memory, dispatching on `llm_type`.

        `limit` is only forwarded to the OpenAI path, where a positive value
        lets text shorter than `limit` be returned without summarizing.
        """
        if self.llm_type == LLMType.METAGPT.value:
            return await self._metagpt_summarize(llm=llm, max_words=max_words, keep_language=keep_language, **kwargs)

        return await self._openai_summarize(
            llm=llm, max_words=max_words, keep_language=keep_language, limit=limit, **kwargs
        )

    async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
        """Summarize summary + history with the LLM, windowing long text.

        Text of at least `DEFAULT_MAX_TOKENS` characters is split into
        overlapping windows, each window is summarized, and the partial
        summaries are merged and retried. A non-empty result is stored through
        `set_history_summary` and returned; an empty result raises.
        """
        max_token_count = DEFAULT_MAX_TOKENS
        max_count = 100  # upper bound on merge-and-retry rounds
        texts = [self.historical_summary]
        texts.extend(Message(**i).content for i in self.history)
        text = "\n".join(texts)
        text_length = len(text)
        if limit > 0 and text_length < limit:
            return text
        summary = ""
        while max_count > 0:
            if text_length < max_token_count:
                summary = await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language)
                break

            padding_size = 20 if max_token_count > 20 else 0
            text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
            # Share the word budget between windows, capped at 100 per window.
            part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
            summaries = []
            for ws in text_windows:
                response = await self._get_summary(
                    text=ws, llm=llm, max_words=part_max_words, keep_language=keep_language
                )
                summaries.append(response)
            if len(summaries) == 1:
                summary = summaries[0]
                break

            # Merged and retry
            text = "\n".join(summaries)
            text_length = len(text)

            max_count -= 1  # safeguard
        if summary:
            # NOTE(review): this passes CONFIG.REDIS while `_metagpt_summarize`
            # passes CONFIG.REDIS_CONF — confirm which attribute is intended.
            await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
            return summary
        # NOTE(review): `InvalidRequestError` may not exist in the installed
        # openai package version — confirm before relying on this exception.
        raise openai.InvalidRequestError(message="text too long", param=None)

    async def _metagpt_summarize(self, max_words=200, **kwargs):
        """Trim the history to the newest messages within `max_words` characters.

        Walks the history from newest to oldest, truncating the first message
        that would exceed the budget, stores the trimmed history, and returns
        it in MetaGPT history format. Returns "" when there is no history.
        """
        if not self.history:
            return ""

        total_length = 0
        msgs = []
        for i in reversed(self.history):
            m = Message(**i)
            delta = len(m.content)
            if total_length + delta > max_words:
                left = max_words - total_length
                if left == 0:
                    break
                # Keep only as much of this message as the budget allows.
                m.content = m.content[0:left]
                msgs.append(m.dict())
                break
            msgs.append(i)
            total_length += delta
        msgs.reverse()
        self.history = msgs
        self.is_dirty = True
        await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF)
        self.is_dirty = False

        return BrainMemory.to_metagpt_history_format(self.history)

    @staticmethod
    def to_metagpt_history_format(history) -> str:
        """Serialize `history` to a JSON list of role/content messages.

        A message is given role "user" when its tags contain
        `MessageType.Talk.value`, and "assistant" otherwise.
        """
        mmsg = []
        for m in history:
            msg = Message(**m)
            r = RawMessage(role="user" if MessageType.Talk.value in msg.tags else "assistant", content=msg.content)
            mmsg.append(r)
        return json.dumps(mmsg)

    @staticmethod
    async def _get_summary(text: str, llm, max_words=20, keep_language: bool = False):
        """Generate text summary"""
        # Text already shorter than the budget is returned unchanged.
        if len(text) < max_words:
            return text
        if keep_language:
            command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
        else:
            command = f"Translate the above content into a summary of less than {max_words} words."
        msg = text + "\n\n" + command
        logger.debug(f"summary ask:{msg}")
        response = await llm.aask(msg=msg, system_msgs=[])
        logger.debug(f"summary rsp: {response}")
        return response

    async def get_title(self, llm, max_words=5, **kwargs) -> str:
        """Generate text title"""
        if self.llm_type == LLMType.METAGPT.value:
            # No LLM call on this path: use the first message, or "New" when empty.
            return Message(**self.history[0]).content if self.history else "New"

        summary = await self.summarize(llm=llm, max_words=500)

        language = CONFIG.language or DEFAULT_LANGUAGE
        command = f"Translate the above summary into a {language} title of less than {max_words} words."
        msg = "\n".join([summary, command])
        logger.debug(f"title ask:{msg}")
        response = await llm.aask(msg=msg, system_msgs=[])
        logger.debug(f"title rsp: {response}")
        return response

    async def is_related(self, text1, text2, llm):
        """Decide whether `text1` relates to `text2`, dispatching on `llm_type`."""
        if self.llm_type == LLMType.METAGPT.value:
            return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm)
        return await self._openai_is_related(text1=text1, text2=text2, llm=llm)

    @staticmethod
    async def _metagpt_is_related(**kwargs):
        """Always report "not related" on the MetaGPT path."""
        return False

    @staticmethod
    async def _openai_is_related(text1, text2, llm, **kwargs):
        """Ask the LLM whether `text1` relates to `text2`; True if the reply contains "TRUE"."""
        # Previous prompt, kept for reference:
        # command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]."
        command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
        rsp = await llm.aask(msg=command, system_msgs=[])
        result = "TRUE" in rsp
        p2 = text2.replace("\n", "")
        p1 = text1.replace("\n", "")
        logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
        return result

    async def rewrite(self, sentence: str, context: str, llm):
        """Rewrite `sentence` in light of `context`, dispatching on `llm_type`."""
        if self.llm_type == LLMType.METAGPT.value:
            return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm)
        return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)

    async def _metagpt_rewrite(self, sentence: str, **kwargs):
        """Return `sentence` unchanged on the MetaGPT path."""
        return sentence

    async def _openai_rewrite(self, sentence: str, context: str, llm, **kwargs):
        """Ask the LLM to supplement or rewrite `sentence` using `context`."""
        # Previous prompt, kept for reference:
        # command = (
        #    f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}"
        # )
        command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}"
        rsp = await llm.aask(msg=command, system_msgs=[])
        logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
        return rsp

    @staticmethod
    def split_texts(text: str, window_size) -> List[str]:
        """Splitting long text into sliding windows text"""
        if window_size <= 0:
            window_size = BrainMemory.DEFAULT_TOKEN_SIZE
        total_len = len(text)
        if total_len <= window_size:
            return [text]

        # Consecutive windows overlap by `padding_size` characters.
        padding_size = 20 if window_size > 20 else 0
        windows = []
        idx = 0
        data_len = window_size - padding_size
        while idx < total_len:
            if window_size + idx > total_len:  # less than one full window remains
                windows.append(text[idx:])
                break
            # Advancing by `padding_size` less than the window length is what
            # produces the overlap. Illustration for a window of 3 with an
            # overlap of 1 over [1, 2, 3, 4, 5, 6, 7, ...]:
            #   [1, 2, 3], [3, 4, 5], [5, 6, 7], ...
            w = text[idx : idx + window_size]
            windows.append(w)
            idx += data_len

        return windows

    @staticmethod
    def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
        """Split a leading `[TAG]: text` into `(TAG, text)`.

        Returns `(None, input_string)` when `pattern` does not match at the
        start of the string.
        """
        match = re.match(pattern, input_string)
        if match:
            return match.group(1), match.group(2)
        return None, input_string

    def set_llm_type(self, v):
        """Set `llm_type` to a truthy, different value and mark the memory dirty."""
        if v and v != self.llm_type:
            self.llm_type = v
            self.is_dirty = True

    @property
    def is_history_available(self):
        """True when there is any history or a stored summary."""
        return bool(self.history or self.historical_summary)

    @property
    def history_text(self):
        """History rendered as text in the format matching `llm_type`."""
        if self.llm_type == LLMType.METAGPT.value:
            return self._get_metagpt_history_text()
        return self._get_openai_history_text()

    def _get_metagpt_history_text(self):
        """Return the history as a MetaGPT-format JSON string."""
        return BrainMemory.to_metagpt_history_format(self.history)

    def _get_openai_history_text(self):
        """Join the summary and every history entry except the last with newlines."""
        if len(self.history) == 0 and not self.historical_summary:
            return ""
        texts = [self.historical_summary] if self.historical_summary else []
        # The most recent entry is deliberately left out of the text.
        for m in self.history[:-1]:
            if isinstance(m, dict):
                t = Message(**m).content
            elif isinstance(m, Message):
                t = m.content
            else:
                continue
            texts.append(t)

        return "\n".join(texts)
# import json
# import re
# from enum import Enum
# from typing import Dict, List, Optional
#
# import openai
# import pydantic
#
# from metagpt.config import CONFIG
# from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS
# from metagpt.logs import logger
# from metagpt.schema import Message, RawMessage
# from metagpt.utils.redis import Redis
#
#
# class MessageType(Enum):
# Talk = "TALK"
# Solution = "SOLUTION"
# Problem = "PROBLEM"
# Skill = "SKILL"
# Answer = "ANSWER"
#
#
# class BrainMemory(pydantic.BaseModel):
# history: List[Dict] = []
# stack: List[Dict] = []
# solution: List[Dict] = []
# knowledge: List[Dict] = []
# historical_summary: str = ""
# last_history_id: str = ""
# is_dirty: bool = False
# last_talk: str = None
# llm_type: Optional[str] = None
# cacheable: bool = True
#
# def add_talk(self, msg: Message):
# msg.role = "user"
# self.add_history(msg)
# self.is_dirty = True
#
# def add_answer(self, msg: Message):
# msg.role = "assistant"
# self.add_history(msg)
# self.is_dirty = True
#
# def get_knowledge(self) -> str:
# texts = [Message(**m).content for m in self.knowledge]
# return "\n".join(texts)
#
# @staticmethod
# async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory":
# redis = Redis(conf=redis_conf)
# if not redis.is_valid() or not redis_key:
# return BrainMemory(llm_type=CONFIG.LLM_TYPE)
# v = await redis.get(key=redis_key)
# logger.debug(f"REDIS GET {redis_key} {v}")
# if v:
# data = json.loads(v)
# bm = BrainMemory(**data)
# bm.is_dirty = False
# return bm
# return BrainMemory(llm_type=CONFIG.LLM_TYPE)
#
# async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None):
# if not self.is_dirty:
# return
# redis = Redis(conf=redis_conf)
# if not redis.is_valid() or not redis_key:
# return False
# v = self.json()
# if self.cacheable:
# await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec)
# logger.debug(f"REDIS SET {redis_key} {v}")
# self.is_dirty = False
#
# @staticmethod
# def to_redis_key(prefix: str, user_id: str, chat_id: str):
# return f"{prefix}:{user_id}:{chat_id}"
#
# async def set_history_summary(self, history_summary, redis_key, redis_conf):
# if self.historical_summary == history_summary:
# if self.is_dirty:
# await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
# self.is_dirty = False
# return
#
# self.historical_summary = history_summary
# self.history = []
# await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
# self.is_dirty = False
#
# def add_history(self, msg: Message):
# if msg.id:
# if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
# return
# self.history.append(msg.dict())
# self.last_history_id = str(msg.id)
# self.is_dirty = True
#
# def exists(self, text) -> bool:
# for m in reversed(self.history):
# if m.get("content") == text:
# return True
# return False
#
# @staticmethod
# def to_int(v, default_value):
# try:
# return int(v)
# except:
# return default_value
#
# def pop_last_talk(self):
# v = self.last_talk
# self.last_talk = None
# return v
#
# async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
# if self.llm_type == LLMType.METAGPT.value:
# return await self._metagpt_summarize(llm=llm, max_words=max_words, keep_language=keep_language, **kwargs)
#
# return await self._openai_summarize(
# llm=llm, max_words=max_words, keep_language=keep_language, limit=limit, **kwargs
# )
#
# async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
# max_token_count = DEFAULT_MAX_TOKENS
# max_count = 100
# texts = [self.historical_summary]
# for i in self.history:
# m = Message(**i)
# texts.append(m.content)
# text = "\n".join(texts)
# text_length = len(text)
# if limit > 0 and text_length < limit:
# return text
# summary = ""
# while max_count > 0:
# if text_length < max_token_count:
# summary = await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language)
# break
#
# padding_size = 20 if max_token_count > 20 else 0
# text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
# part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
# summaries = []
# for ws in text_windows:
# response = await self._get_summary(
# text=ws, llm=llm, max_words=part_max_words, keep_language=keep_language
# )
# summaries.append(response)
# if len(summaries) == 1:
# summary = summaries[0]
# break
#
# # Merged and retry
# text = "\n".join(summaries)
# text_length = len(text)
#
# max_count -= 1 # safeguard
# if summary:
# await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
# return summary
# raise openai.InvalidRequestError(message="text too long", param=None)
#
# async def _metagpt_summarize(self, max_words=200, **kwargs):
# if not self.history:
# return ""
#
# total_length = 0
# msgs = []
# for i in reversed(self.history):
# m = Message(**i)
# delta = len(m.content)
# if total_length + delta > max_words:
# left = max_words - total_length
# if left == 0:
# break
# m.content = m.content[0:left]
# msgs.append(m.dict())
# break
# msgs.append(i)
# total_length += delta
# msgs.reverse()
# self.history = msgs
# self.is_dirty = True
# await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF)
# self.is_dirty = False
#
# return BrainMemory.to_metagpt_history_format(self.history)
#
# @staticmethod
# def to_metagpt_history_format(history) -> str:
# mmsg = []
# for m in history:
# msg = Message(**m)
# r = RawMessage(role="user" if MessageType.Talk.value in msg.tags else "assistant", content=msg.content)
# mmsg.append(r)
# return json.dumps(mmsg)
#
# @staticmethod
# async def _get_summary(text: str, llm, max_words=20, keep_language: bool = False):
# """Generate text summary"""
# if len(text) < max_words:
# return text
# if keep_language:
# command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
# else:
# command = f"Translate the above content into a summary of less than {max_words} words."
# msg = text + "\n\n" + command
# logger.debug(f"summary ask:{msg}")
# response = await llm.aask(msg=msg, system_msgs=[])
# logger.debug(f"summary rsp: {response}")
# return response
#
# async def get_title(self, llm, max_words=5, **kwargs) -> str:
# """Generate text title"""
# if self.llm_type == LLMType.METAGPT.value:
# return Message(**self.history[0]).content if self.history else "New"
#
# summary = await self.summarize(llm=llm, max_words=500)
#
# language = CONFIG.language or DEFAULT_LANGUAGE
# command = f"Translate the above summary into a {language} title of less than {max_words} words."
# summaries = [summary, command]
# msg = "\n".join(summaries)
# logger.debug(f"title ask:{msg}")
# response = await llm.aask(msg=msg, system_msgs=[])
# logger.debug(f"title rsp: {response}")
# return response
#
# async def is_related(self, text1, text2, llm):
# if self.llm_type == LLMType.METAGPT.value:
# return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm)
# return await self._openai_is_related(text1=text1, text2=text2, llm=llm)
#
# @staticmethod
# async def _metagpt_is_related(**kwargs):
# return False
#
# @staticmethod
# async def _openai_is_related(text1, text2, llm, **kwargs):
# # command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]."
# command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
# rsp = await llm.aask(msg=command, system_msgs=[])
# result = True if "TRUE" in rsp else False
# p2 = text2.replace("\n", "")
# p1 = text1.replace("\n", "")
# logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
# return result
#
# async def rewrite(self, sentence: str, context: str, llm):
# if self.llm_type == LLMType.METAGPT.value:
# return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm)
# return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)
#
# async def _metagpt_rewrite(self, sentence: str, **kwargs):
# return sentence
#
# async def _openai_rewrite(self, sentence: str, context: str, llm, **kwargs):
# # command = (
# # f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}"
# # )
# command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}"
# rsp = await llm.aask(msg=command, system_msgs=[])
# logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
# return rsp
#
# @staticmethod
# def split_texts(text: str, window_size) -> List[str]:
# """Splitting long text into sliding windows text"""
# if window_size <= 0:
# window_size = BrainMemory.DEFAULT_TOKEN_SIZE
# total_len = len(text)
# if total_len <= window_size:
# return [text]
#
# padding_size = 20 if window_size > 20 else 0
# windows = []
# idx = 0
# data_len = window_size - padding_size
# while idx < total_len:
# if window_size + idx > total_len: # 不足一个滑窗
# windows.append(text[idx:])
# break
# # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
# # window_size=3, padding_size=1
# # [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
# # idx=2, | idx=5 | idx=8 | ...
# w = text[idx : idx + window_size]
# windows.append(w)
# idx += data_len
#
# return windows
#
# @staticmethod
# def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
# match = re.match(pattern, input_string)
# if match:
# return match.group(1), match.group(2)
# else:
# return None, input_string
#
# def set_llm_type(self, v):
# if v and v != self.llm_type:
# self.llm_type = v
# self.is_dirty = True
#
# @property
# def is_history_available(self):
# return bool(self.history or self.historical_summary)
#
# @property
# def history_text(self):
# if self.llm_type == LLMType.METAGPT.value:
# return self._get_metagpt_history_text()
# return self._get_openai_history_text()
#
# def _get_metagpt_history_text(self):
# return BrainMemory.to_metagpt_history_format(self.history)
#
# def _get_openai_history_text(self):
# if len(self.history) == 0 and not self.historical_summary:
# return ""
# texts = [self.historical_summary] if self.historical_summary else []
# for m in self.history[:-1]:
# if isinstance(m, Dict):
# t = Message(**m).content
# elif isinstance(m, Message):
# t = m.content
# else:
# continue
# texts.append(t)
#
# return "\n".join(texts)
#
# DEFAULT_TOKEN_SIZE = 500

View file

@ -93,13 +93,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
self._client = AsyncOpenAI(api_key=CONFIG.openai_api_key, base_url=CONFIG.openai_api_base)
RateLimiter.__init__(self, rpm=self.rpm)
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
kwargs = self._cons_kwargs(messages, timeout=timeout)
response = await self._client.chat.completions.create(**kwargs, stream=True)
# iterate through the stream of events
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or "" # extract the message
yield chunk_message
# async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
# kwargs = self._cons_kwargs(messages, timeout=timeout)
# response = await self._client.chat.completions.create(**kwargs, stream=True)
# # iterate through the stream of events
# async for chunk in response:
# chunk_message = chunk.choices[0].delta.content or "" # extract the message
# yield chunk_message
def __init_openai(self):
self.rpm = int(self.config.get("RPM", 10))
@ -131,9 +131,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return params
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
**self._cons_kwargs(messages), stream=True
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create(
**self._cons_kwargs(messages, timeout=timeout), stream=True
)
# create variables to collect the stream of chunks

View file

@ -70,22 +70,22 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")
def completion(self, messages: list[dict]) -> dict:
def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.invoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
self._update_costs(usage)
return resp
async def _achat_completion(self, messages: list[dict]) -> dict:
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict]) -> dict:
return await self._achat_completion(messages)
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
collected_content = []
usage = {}
@ -128,9 +128,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)
return await self._achat_completion_stream(messages, timeout=timeout)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)

View file

@ -19,7 +19,8 @@ def llm():
@pytest.mark.asyncio
async def test_llm_aask(llm):
    """`aask` gives a non-empty reply with default arguments and with `stream=False`."""
    default_reply = await llm.aask("hello world")
    assert len(default_reply) > 0

    non_stream_reply = await llm.aask("hello world", stream=False)
    assert len(non_stream_reply) > 0
@pytest.mark.asyncio
@ -30,7 +31,8 @@ async def test_llm_aask_batch(llm):
@pytest.mark.asyncio
async def test_llm_acompletion(llm):
    """Check that single and batched completion calls return non-empty results."""
    hello_msg = [{"role": "user", "content": "hello"}]

    # Issue the single completion request once and check the reply text
    # itself, rather than taking the length of the response object.
    rsp = await llm.acompletion(hello_msg)
    assert len(rsp.choices[0].message.content) > 0

    assert len(await llm.acompletion_batch([hello_msg])) > 0
    assert len(await llm.acompletion_batch_text([hello_msg])) > 0