From 5d97a20e084b04b1f787fcb098a0c091ff0ac3e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Fri, 22 Dec 2023 17:43:59 +0800 Subject: [PATCH] fixbug: OpenAIGPTAPI:_achat_completion_stream --- metagpt/memory/brain_memory.py | 675 ++++++++++++++++---------------- metagpt/provider/openai_api.py | 20 +- metagpt/provider/zhipuai_api.py | 14 +- tests/metagpt/test_llm.py | 6 +- 4 files changed, 358 insertions(+), 357 deletions(-) diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index 8aa3be2b6..9020c67c1 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -7,341 +7,340 @@ @Desc : Support memory for multiple tasks and multiple mainlines. Obsoleted by `utils/*_repository.py`. @Modified By: mashenquan, 2023/9/4. + redis memory cache. """ -import json -import re -from enum import Enum -from typing import Dict, List, Optional - -import openai -import pydantic - -from metagpt.config import CONFIG -from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS -from metagpt.llm import LLMType -from metagpt.logs import logger -from metagpt.schema import Message, RawMessage -from metagpt.utils.redis import Redis - - -class MessageType(Enum): - Talk = "TALK" - Solution = "SOLUTION" - Problem = "PROBLEM" - Skill = "SKILL" - Answer = "ANSWER" - - -class BrainMemory(pydantic.BaseModel): - history: List[Dict] = [] - stack: List[Dict] = [] - solution: List[Dict] = [] - knowledge: List[Dict] = [] - historical_summary: str = "" - last_history_id: str = "" - is_dirty: bool = False - last_talk: str = None - llm_type: Optional[str] = None - cacheable: bool = True - - def add_talk(self, msg: Message): - msg.role = "user" - self.add_history(msg) - self.is_dirty = True - - def add_answer(self, msg: Message): - msg.role = "assistant" - self.add_history(msg) - self.is_dirty = True - - def get_knowledge(self) -> str: - texts = [Message(**m).content for m in self.knowledge] - return "\n".join(texts) - - @staticmethod - async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory": - redis = Redis(conf=redis_conf) - if not redis.is_valid() or not redis_key: - return BrainMemory(llm_type=CONFIG.LLM_TYPE) - v = await redis.get(key=redis_key) - logger.debug(f"REDIS GET {redis_key} {v}") - if v: - data = json.loads(v) - bm = BrainMemory(**data) - bm.is_dirty = False - return bm - return BrainMemory(llm_type=CONFIG.LLM_TYPE) - - async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None): - if not self.is_dirty: - return - redis = Redis(conf=redis_conf) - if not redis.is_valid() or not redis_key: - return False - v = self.json() - if self.cacheable: - await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec) - logger.debug(f"REDIS SET {redis_key} {v}") - self.is_dirty = False - - @staticmethod - def to_redis_key(prefix: str, user_id: str, chat_id: str): - return f"{prefix}:{user_id}:{chat_id}" - - async def set_history_summary(self, history_summary, redis_key, redis_conf): - if self.historical_summary == history_summary: - if self.is_dirty: - await self.dumps(redis_key=redis_key, redis_conf=redis_conf) - self.is_dirty = False - return - - self.historical_summary = history_summary - self.history = [] - await self.dumps(redis_key=redis_key, redis_conf=redis_conf) - self.is_dirty = False - - def add_history(self, msg: Message): - if msg.id: - if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1): - return - self.history.append(msg.dict()) - self.last_history_id = str(msg.id) - self.is_dirty = True - - def exists(self, text) -> bool: - for m in reversed(self.history): - if m.get("content") == text: - return True - return False - - @staticmethod - def to_int(v, default_value): - try: - return int(v) - except: - return default_value - - def pop_last_talk(self): - v = self.last_talk - self.last_talk = None - return v - - async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): - if self.llm_type == LLMType.METAGPT.value: - return await self._metagpt_summarize(llm=llm, max_words=max_words, keep_language=keep_language, **kwargs) - - return await self._openai_summarize( - llm=llm, max_words=max_words, keep_language=keep_language, limit=limit, **kwargs - ) - - async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): - max_token_count = DEFAULT_MAX_TOKENS - max_count = 100 - texts = [self.historical_summary] - for i in self.history: - m = Message(**i) - texts.append(m.content) - text = "\n".join(texts) - text_length = len(text) - if limit > 0 and text_length < limit: - return text - summary = "" - while max_count > 0: - if text_length < max_token_count: - summary = await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language) - break - - padding_size = 20 if max_token_count > 20 else 0 - text_windows = self.split_texts(text, window_size=max_token_count - padding_size) - part_max_words = min(int(max_words / len(text_windows)) + 1, 100) - summaries = [] - for ws in text_windows: - response = await self._get_summary( - text=ws, llm=llm, max_words=part_max_words, keep_language=keep_language - ) - summaries.append(response) - if len(summaries) == 1: - summary = summaries[0] - break - - # Merged and retry - text = "\n".join(summaries) - text_length = len(text) - - max_count -= 1 # safeguard - if summary: - await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS) - return summary - raise openai.InvalidRequestError(message="text too long", param=None) - - async def _metagpt_summarize(self, max_words=200, **kwargs): - if not self.history: - return "" - - total_length = 0 - msgs = [] - for i in reversed(self.history): - m = Message(**i) - delta = len(m.content) - if total_length + delta > max_words: - left = max_words - total_length - if left == 0: - break - m.content = m.content[0:left] - msgs.append(m.dict()) - break - msgs.append(i) - total_length += delta - msgs.reverse() - self.history = msgs - self.is_dirty = True - await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF) - self.is_dirty = False - - return BrainMemory.to_metagpt_history_format(self.history) - - @staticmethod - def to_metagpt_history_format(history) -> str: - mmsg = [] - for m in history: - msg = Message(**m) - r = RawMessage(role="user" if MessageType.Talk.value in msg.tags else "assistant", content=msg.content) - mmsg.append(r) - return json.dumps(mmsg) - - @staticmethod - async def _get_summary(text: str, llm, max_words=20, keep_language: bool = False): - """Generate text summary""" - if len(text) < max_words: - return text - if keep_language: - command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." - else: - command = f"Translate the above content into a summary of less than {max_words} words." - msg = text + "\n\n" + command - logger.debug(f"summary ask:{msg}") - response = await llm.aask(msg=msg, system_msgs=[]) - logger.debug(f"summary rsp: {response}") - return response - - async def get_title(self, llm, max_words=5, **kwargs) -> str: - """Generate text title""" - if self.llm_type == LLMType.METAGPT.value: - return Message(**self.history[0]).content if self.history else "New" - - summary = await self.summarize(llm=llm, max_words=500) - - language = CONFIG.language or DEFAULT_LANGUAGE - command = f"Translate the above summary into a {language} title of less than {max_words} words." - summaries = [summary, command] - msg = "\n".join(summaries) - logger.debug(f"title ask:{msg}") - response = await llm.aask(msg=msg, system_msgs=[]) - logger.debug(f"title rsp: {response}") - return response - - async def is_related(self, text1, text2, llm): - if self.llm_type == LLMType.METAGPT.value: - return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm) - return await self._openai_is_related(text1=text1, text2=text2, llm=llm) - - @staticmethod - async def _metagpt_is_related(**kwargs): - return False - - @staticmethod - async def _openai_is_related(text1, text2, llm, **kwargs): - # command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]." - command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear." - rsp = await llm.aask(msg=command, system_msgs=[]) - result = True if "TRUE" in rsp else False - p2 = text2.replace("\n", "") - p1 = text1.replace("\n", "") - logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n") - return result - - async def rewrite(self, sentence: str, context: str, llm): - if self.llm_type == LLMType.METAGPT.value: - return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm) - return await self._openai_rewrite(sentence=sentence, context=context, llm=llm) - - async def _metagpt_rewrite(self, sentence: str, **kwargs): - return sentence - - async def _openai_rewrite(self, sentence: str, context: str, llm, **kwargs): - # command = ( - # f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}" - # ) - command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}" - rsp = await llm.aask(msg=command, system_msgs=[]) - logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n") - return rsp - - @staticmethod - def split_texts(text: str, window_size) -> List[str]: - """Splitting long text into sliding windows text""" - if window_size <= 0: - window_size = BrainMemory.DEFAULT_TOKEN_SIZE - total_len = len(text) - if total_len <= window_size: - return [text] - - padding_size = 20 if window_size > 20 else 0 - windows = [] - idx = 0 - data_len = window_size - padding_size - while idx < total_len: - if window_size + idx > total_len: # 不足一个滑窗 - windows.append(text[idx:]) - break - # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] - # window_size=3, padding_size=1: - # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... - # idx=2, | idx=5 | idx=8 | ... - w = text[idx : idx + window_size] - windows.append(w) - idx += data_len - - return windows - - @staticmethod - def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"): - match = re.match(pattern, input_string) - if match: - return match.group(1), match.group(2) - else: - return None, input_string - - def set_llm_type(self, v): - if v and v != self.llm_type: - self.llm_type = v - self.is_dirty = True - - @property - def is_history_available(self): - return bool(self.history or self.historical_summary) - - @property - def history_text(self): - if self.llm_type == LLMType.METAGPT.value: - return self._get_metagpt_history_text() - return self._get_openai_history_text() - - def _get_metagpt_history_text(self): - return BrainMemory.to_metagpt_history_format(self.history) - - def _get_openai_history_text(self): - if len(self.history) == 0 and not self.historical_summary: - return "" - texts = [self.historical_summary] if self.historical_summary else [] - for m in self.history[:-1]: - if isinstance(m, Dict): - t = Message(**m).content - elif isinstance(m, Message): - t = m.content - else: - continue - texts.append(t) - - return "\n".join(texts) - - DEFAULT_TOKEN_SIZE = 500 +# import json +# import re +# from enum import Enum +# from typing import Dict, List, Optional +# +# import openai +# import pydantic +# +# from metagpt.config import CONFIG +# from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS +# from metagpt.logs import logger +# from metagpt.schema import Message, RawMessage +# from metagpt.utils.redis import Redis +# +# +# class MessageType(Enum): +# Talk = "TALK" +# Solution = "SOLUTION" +# Problem = "PROBLEM" +# Skill = "SKILL" +# Answer = "ANSWER" +# +# +# class BrainMemory(pydantic.BaseModel): +# history: List[Dict] = [] +# stack: List[Dict] = [] +# solution: List[Dict] = [] +# knowledge: List[Dict] = [] +# historical_summary: str = "" +# last_history_id: str = "" +# is_dirty: bool = False +# last_talk: str = None +# llm_type: Optional[str] = None +# cacheable: bool = True +# +# def add_talk(self, msg: Message): +# msg.role = "user" +# self.add_history(msg) +# self.is_dirty = True +# +# def add_answer(self, msg: Message): +# msg.role = "assistant" +# self.add_history(msg) +# self.is_dirty = True +# +# def get_knowledge(self) -> str: +# texts = [Message(**m).content for m in self.knowledge] +# return "\n".join(texts) +# +# @staticmethod +# async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory": +# redis = Redis(conf=redis_conf) +# if not redis.is_valid() or not redis_key: +# return BrainMemory(llm_type=CONFIG.LLM_TYPE) +# v = await redis.get(key=redis_key) +# logger.debug(f"REDIS GET {redis_key} {v}") +# if v: +# data = json.loads(v) +# bm = BrainMemory(**data) +# bm.is_dirty = False +# return bm +# return BrainMemory(llm_type=CONFIG.LLM_TYPE) +# +# async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None): +# if not self.is_dirty: +# return +# redis = Redis(conf=redis_conf) +# if not redis.is_valid() or not redis_key: +# return False +# v = self.json() +# if self.cacheable: +# await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec) +# logger.debug(f"REDIS SET {redis_key} {v}") +# self.is_dirty = False +# +# @staticmethod +# def to_redis_key(prefix: str, user_id: str, chat_id: str): +# return f"{prefix}:{user_id}:{chat_id}" +# +# async def set_history_summary(self, history_summary, redis_key, redis_conf): +# if self.historical_summary == history_summary: +# if self.is_dirty: +# await self.dumps(redis_key=redis_key, redis_conf=redis_conf) +# self.is_dirty = False +# return +# +# self.historical_summary = history_summary +# self.history = [] +# await self.dumps(redis_key=redis_key, redis_conf=redis_conf) +# self.is_dirty = False +# +# def add_history(self, msg: Message): +# if msg.id: +# if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1): +# return +# self.history.append(msg.dict()) +# self.last_history_id = str(msg.id) +# self.is_dirty = True +# +# def exists(self, text) -> bool: +# for m in reversed(self.history): +# if m.get("content") == text: +# return True +# return False +# +# @staticmethod +# def to_int(v, default_value): +# try: +# return int(v) +# except: +# return default_value +# +# def pop_last_talk(self): +# v = self.last_talk +# self.last_talk = None +# return v +# +# async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): +# if self.llm_type == LLMType.METAGPT.value: +# return await self._metagpt_summarize(llm=llm, max_words=max_words, keep_language=keep_language, **kwargs) +# +# return await self._openai_summarize( +# llm=llm, max_words=max_words, keep_language=keep_language, limit=limit, **kwargs +# ) +# +# async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): +# max_token_count = DEFAULT_MAX_TOKENS +# max_count = 100 +# texts = [self.historical_summary] +# for i in self.history: +# m = Message(**i) +# texts.append(m.content) +# text = "\n".join(texts) +# text_length = len(text) +# if limit > 0 and text_length < limit: +# return text +# summary = "" +# while max_count > 0: +# if text_length < max_token_count: +# summary = await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language) +# break +# +# padding_size = 20 if max_token_count > 20 else 0 +# text_windows = self.split_texts(text, window_size=max_token_count - padding_size) +# part_max_words = min(int(max_words / len(text_windows)) + 1, 100) +# summaries = [] +# for ws in text_windows: +# response = await self._get_summary( +# text=ws, llm=llm, max_words=part_max_words, keep_language=keep_language +# ) +# summaries.append(response) +# if len(summaries) == 1: +# summary = summaries[0] +# break +# +# # Merged and retry +# text = "\n".join(summaries) +# text_length = len(text) +# +# max_count -= 1 # safeguard +# if summary: +# await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS) +# return summary +# raise openai.InvalidRequestError(message="text too long", param=None) +# +# async def _metagpt_summarize(self, max_words=200, **kwargs): +# if not self.history: +# return "" +# +# total_length = 0 +# msgs = [] +# for i in reversed(self.history): +# m = Message(**i) +# delta = len(m.content) +# if total_length + delta > max_words: +# left = max_words - total_length +# if left == 0: +# break +# m.content = m.content[0:left] +# msgs.append(m.dict()) +# break +# msgs.append(i) +# total_length += delta +# msgs.reverse() +# self.history = msgs +# self.is_dirty = True +# await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF) +# self.is_dirty = False +# +# return BrainMemory.to_metagpt_history_format(self.history) +# +# @staticmethod +# def to_metagpt_history_format(history) -> str: +# mmsg = [] +# for m in history: +# msg = Message(**m) +# r = RawMessage(role="user" if MessageType.Talk.value in msg.tags else "assistant", content=msg.content) +# mmsg.append(r) +# return json.dumps(mmsg) +# +# @staticmethod +# async def _get_summary(text: str, llm, max_words=20, keep_language: bool = False): +# """Generate text summary""" +# if len(text) < max_words: +# return text +# if keep_language: +# command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." +# else: +# command = f"Translate the above content into a summary of less than {max_words} words." +# msg = text + "\n\n" + command +# logger.debug(f"summary ask:{msg}") +# response = await llm.aask(msg=msg, system_msgs=[]) +# logger.debug(f"summary rsp: {response}") +# return response +# +# async def get_title(self, llm, max_words=5, **kwargs) -> str: +# """Generate text title""" +# if self.llm_type == LLMType.METAGPT.value: +# return Message(**self.history[0]).content if self.history else "New" +# +# summary = await self.summarize(llm=llm, max_words=500) +# +# language = CONFIG.language or DEFAULT_LANGUAGE +# command = f"Translate the above summary into a {language} title of less than {max_words} words." +# summaries = [summary, command] +# msg = "\n".join(summaries) +# logger.debug(f"title ask:{msg}") +# response = await llm.aask(msg=msg, system_msgs=[]) +# logger.debug(f"title rsp: {response}") +# return response +# +# async def is_related(self, text1, text2, llm): +# if self.llm_type == LLMType.METAGPT.value: +# return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm) +# return await self._openai_is_related(text1=text1, text2=text2, llm=llm) +# +# @staticmethod +# async def _metagpt_is_related(**kwargs): +# return False +# +# @staticmethod +# async def _openai_is_related(text1, text2, llm, **kwargs): +# # command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]." +# command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear." +# rsp = await llm.aask(msg=command, system_msgs=[]) +# result = True if "TRUE" in rsp else False +# p2 = text2.replace("\n", "") +# p1 = text1.replace("\n", "") +# logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n") +# return result +# +# async def rewrite(self, sentence: str, context: str, llm): +# if self.llm_type == LLMType.METAGPT.value: +# return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm) +# return await self._openai_rewrite(sentence=sentence, context=context, llm=llm) +# +# async def _metagpt_rewrite(self, sentence: str, **kwargs): +# return sentence +# +# async def _openai_rewrite(self, sentence: str, context: str, llm, **kwargs): +# # command = ( +# # f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}" +# # ) +# command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}" +# rsp = await llm.aask(msg=command, system_msgs=[]) +# logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n") +# return rsp +# +# @staticmethod +# def split_texts(text: str, window_size) -> List[str]: +# """Splitting long text into sliding windows text""" +# if window_size <= 0: +# window_size = BrainMemory.DEFAULT_TOKEN_SIZE +# total_len = len(text) +# if total_len <= window_size: +# return [text] +# +# padding_size = 20 if window_size > 20 else 0 +# windows = [] +# idx = 0 +# data_len = window_size - padding_size +# while idx < total_len: +# if window_size + idx > total_len: # 不足一个滑窗 +# windows.append(text[idx:]) +# break +# # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] +# # window_size=3, padding_size=1: +# # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... +# # idx=2, | idx=5 | idx=8 | ... +# w = text[idx : idx + window_size] +# windows.append(w) +# idx += data_len +# +# return windows +# +# @staticmethod +# def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"): +# match = re.match(pattern, input_string) +# if match: +# return match.group(1), match.group(2) +# else: +# return None, input_string +# +# def set_llm_type(self, v): +# if v and v != self.llm_type: +# self.llm_type = v +# self.is_dirty = True +# +# @property +# def is_history_available(self): +# return bool(self.history or self.historical_summary) +# +# @property +# def history_text(self): +# if self.llm_type == LLMType.METAGPT.value: +# return self._get_metagpt_history_text() +# return self._get_openai_history_text() +# +# def _get_metagpt_history_text(self): +# return BrainMemory.to_metagpt_history_format(self.history) +# +# def _get_openai_history_text(self): +# if len(self.history) == 0 and not self.historical_summary: +# return "" +# texts = [self.historical_summary] if self.historical_summary else [] +# for m in self.history[:-1]: +# if isinstance(m, Dict): +# t = Message(**m).content +# elif isinstance(m, Message): +# t = m.content +# else: +# continue +# texts.append(t) +# +# return "\n".join(texts) +# +# DEFAULT_TOKEN_SIZE = 500 diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index ca130ce15..d5d77c5ec 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -93,13 +93,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self._client = AsyncOpenAI(api_key=CONFIG.openai_api_key, base_url=CONFIG.openai_api_base) RateLimiter.__init__(self, rpm=self.rpm) - async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: - kwargs = self._cons_kwargs(messages, timeout=timeout) - response = await self._client.chat.completions.create(**kwargs, stream=True) - # iterate through the stream of events - async for chunk in response: - chunk_message = chunk.choices[0].delta.content or "" # extract the message - yield chunk_message + # async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: + # kwargs = self._cons_kwargs(messages, timeout=timeout) + # response = await self._client.chat.completions.create(**kwargs, stream=True) + # # iterate through the stream of events + # async for chunk in response: + # chunk_message = chunk.choices[0].delta.content or "" # extract the message + # yield chunk_message def __init_openai(self): self.rpm = int(self.config.get("RPM", 10)) @@ -131,9 +131,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return params - async def _achat_completion_stream(self, messages: list[dict]) -> str: - response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( - **self._cons_kwargs(messages), stream=True + async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: + response: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create( + **self._cons_kwargs(messages, timeout=timeout), stream=True ) # create variables to collect the stream of chunks diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 54f0ddcbb..4a2cae51d 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -70,22 +70,22 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): assert assist_msg["role"] == "assistant" return assist_msg.get("content") - def completion(self, messages: list[dict]) -> dict: + def completion(self, messages: list[dict], timeout=3) -> dict: resp = self.llm.invoke(**self._const_kwargs(messages)) usage = resp.get("data").get("usage") self._update_costs(usage) return resp - async def _achat_completion(self, messages: list[dict]) -> dict: + async def _achat_completion(self, messages: list[dict], timeout=3) -> dict: resp = await self.llm.ainvoke(**self._const_kwargs(messages)) usage = resp.get("data").get("usage") self._update_costs(usage) return resp - async def acompletion(self, messages: list[dict]) -> dict: - return await self._achat_completion(messages) + async def acompletion(self, messages: list[dict], timeout=3) -> dict: + return await self._achat_completion(messages, timeout=timeout) - async def _achat_completion_stream(self, messages: list[dict]) -> str: + async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: response = await self.llm.asse_invoke(**self._const_kwargs(messages)) collected_content = [] usage = {} @@ -128,9 +128,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: """response in async with stream or non-stream mode""" if stream: - return await self._achat_completion_stream(messages) + return await self._achat_completion_stream(messages, timeout=timeout) resp = await self._achat_completion(messages) return self.get_choice_text(resp) diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index d972e55c0..31e6c2b24 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -19,7 +19,8 @@ def llm(): @pytest.mark.asyncio async def test_llm_aask(llm): - assert len(await llm.aask("hello world")) > 0 + rsp = await llm.aask("hello world", stream=False) + assert len(rsp) > 0 @pytest.mark.asyncio @@ -30,7 +31,8 @@ async def test_llm_aask_batch(llm): @pytest.mark.asyncio async def test_llm_acompletion(llm): hello_msg = [{"role": "user", "content": "hello"}] - assert len(await llm.acompletion(hello_msg)) > 0 + rsp = await llm.acompletion(hello_msg) + assert len(rsp.choices[0].message.content) > 0 assert len(await llm.acompletion_batch([hello_msg])) > 0 assert len(await llm.acompletion_batch_text([hello_msg])) > 0