fix small problem

This commit is contained in:
better629 2023-11-18 22:17:40 +08:00
parent 2c81cc3e0f
commit 6ef3b213c3
4 changed files with 15 additions and 11 deletions

View file

@ -14,7 +14,7 @@ from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.spark_api import SparkAPI
def LLM():
def LLM() -> "BaseGPTAPI":
""" initialize different LLM instance according to the key field existence"""
# TODO a little trick, can use registry to initialize LLM instance further
if CONFIG.openai_api_key and CONFIG.openai_api_key.starswith("sk-"):

View file

@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : async_sse_client to make keep the use of Event to access response
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
from zhipuai.utils.sse_client import SSEClient, Event, _FIELD_SEPARATOR

View file

@ -29,8 +29,12 @@ class ZhiPuModelAPI(ModelAPI):
@classmethod
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
# use this method to prevent zhipu api upgrading to different version.
# and follow the GeneralAPIRequestor implemented based on openai sdk
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
# example: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
"""
example:
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
"""
arr = zhipu_api_url.split("/api/")
# ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke")
return f"{arr[0]}/api", f"/{arr[1]}"

View file

@ -68,7 +68,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
def get_choice_text(self, resp: dict) -> str:
""" get the first text of choice from llm response """
assist_msg = resp.get("data").get("choices")[-1]
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")
@ -121,16 +121,15 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
self._update_costs(usage)
full_content = "".join(collected_content)
logger.info(f"full_content: {full_content} !!")
return full_content
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_fixed(1),
# after=after_log(logger, logger.level("WARNING").name),
# retry=retry_if_exception_type(ConnectionError),
# retry_error_callback=log_and_reraise
# )
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(1),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
""" response in async with stream or non-stream mode """
if stream: