From 96ce036bd4b6fd01402a11e1e9f5abb274e93e77 Mon Sep 17 00:00:00 2001 From: better629 Date: Sat, 18 Nov 2023 21:34:14 +0800 Subject: [PATCH 1/8] add zhipuai api with extra async invoke methods --- metagpt/config.py | 9 +- metagpt/llm.py | 26 ++-- metagpt/provider/zhipuai/__init__.py | 3 + metagpt/provider/zhipuai/async_sse_client.py | 77 ++++++++++ metagpt/provider/zhipuai/zhipu_model_api.py | 76 ++++++++++ metagpt/provider/zhipuai_api.py | 139 +++++++++++++++++++ metagpt/utils/token_counter.py | 6 +- 7 files changed, 323 insertions(+), 13 deletions(-) create mode 100644 metagpt/provider/zhipuai/__init__.py create mode 100644 metagpt/provider/zhipuai/async_sse_client.py create mode 100644 metagpt/provider/zhipuai/zhipu_model_api.py create mode 100644 metagpt/provider/zhipuai_api.py diff --git a/metagpt/config.py b/metagpt/config.py index 27455d38d..3f9e742bd 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -45,10 +45,11 @@ class Config(metaclass=Singleton): self.global_proxy = self._get("GLOBAL_PROXY") self.openai_api_key = self._get("OPENAI_API_KEY") self.anthropic_api_key = self._get("Anthropic_API_KEY") - if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and ( - not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key - ): - raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY first") + self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY") + if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \ + (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \ + (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key): + raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first") self.openai_api_base = self._get("OPENAI_API_BASE") openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy if openai_proxy: diff --git a/metagpt/llm.py b/metagpt/llm.py index e6f815950..1f6a6bb1a 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -6,14 +6,24 @@ @File : llm.py """ +from metagpt.logs import logger +from metagpt.config import CONFIG from metagpt.provider.anthropic_api import Claude2 as Claude -from metagpt.provider.openai_api import OpenAIGPTAPI as LLM +from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.spark_api import SparkAPI -DEFAULT_LLM = LLM() -CLAUDE_LLM = Claude() -async def ai_func(prompt): - """使用LLM进行QA - QA with LLMs - """ - return await DEFAULT_LLM.aask(prompt) +def LLM(): + """ initialize different LLM instance according to the key field existence""" + # TODO a little trick, can use registry to initialize LLM instance further + if CONFIG.openai_api_key and CONFIG.openai_api_key.starswith("sk-"): + llm = OpenAIGPTAPI() + elif CONFIG.claude_api_key: + llm = Claude() + elif CONFIG.spark_api_key: + llm = SparkAPI() + elif CONFIG.zhipuai_api_key: + llm = ZhiPuAIGPTAPI() + + return llm diff --git a/metagpt/provider/zhipuai/__init__.py b/metagpt/provider/zhipuai/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/metagpt/provider/zhipuai/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/provider/zhipuai/async_sse_client.py b/metagpt/provider/zhipuai/async_sse_client.py new file mode 100644 index 000000000..7a4275982 --- /dev/null +++ b/metagpt/provider/zhipuai/async_sse_client.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : async_sse_client to make keep the use of Event to access response + +from zhipuai.utils.sse_client import SSEClient, Event, _FIELD_SEPARATOR + + +class AsyncSSEClient(SSEClient): + + async def _aread(self): + data = b"" + async for chunk in self._event_source: + for line in chunk.splitlines(True): + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data + + async def async_events(self): + async for chunk in self._aread(): + event = Event() + # Split before decoding so splitlines() only uses \r and \n + for line in chunk.splitlines(): + # Decode the line. + line = line.decode(self._char_enc) + + # Lines starting with a separator are comments and are to be + # ignored. + if not line.strip() or line.startswith(_FIELD_SEPARATOR): + continue + + data = line.split(_FIELD_SEPARATOR, 1) + field = data[0] + + # Ignore unknown fields. + if field not in event.__dict__: + self._logger.debug( + "Saw invalid field %s while parsing " "Server Side Event", field + ) + continue + + if len(data) > 1: + # From the spec: + # "If value starts with a single U+0020 SPACE character, + # remove it from value." + if data[1].startswith(" "): + value = data[1][1:] + else: + value = data[1] + else: + # If no value is present after the separator, + # assume an empty value. + value = "" + + # The data field may come over multiple lines and their values + # are concatenated with each other. + if field == "data": + event.__dict__[field] += value + "\n" + else: + event.__dict__[field] = value + + # Events with no data are not dispatched. + if not event.data: + continue + + # If the data field ends with a newline, remove it. + if event.data.endswith("\n"): + event.data = event.data[0:-1] + + # Empty event names default to 'message' + event.event = event.event or "message" + + # Dispatch the event + self._logger.debug("Dispatching %s...", event) + yield event diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py new file mode 100644 index 000000000..f1fd6f2e2 --- /dev/null +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : zhipu model api to support sync & async for invoke & sse_invoke + +import zhipuai +from zhipuai.model_api.api import ModelAPI, InvokeType +from zhipuai.utils.http_client import headers as zhipuai_default_headers +from zhipuai.utils.sse_client import SSEClient + +from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient +from metagpt.provider.general_api_requestor import GeneralAPIRequestor + + +class ZhiPuModelAPI(ModelAPI): + + @classmethod + def get_header(cls) -> dict: + token = cls._generate_token() + zhipuai_default_headers.update({"Authorization": token}) + return zhipuai_default_headers + + @classmethod + def get_sse_header(cls) -> dict: + token = cls._generate_token() + headers = { + "Authorization": token + } + return headers + + @classmethod + def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs): + # use this method to prevent zhipu api upgrading to different version. + zhipu_api_url = cls._build_api_url(kwargs, invoke_type) + # example: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method} + arr = zhipu_api_url.split("/api/") + # ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke") + return f"{arr[0]}/api", f"/{arr[1]}" + + @classmethod + async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs): + # TODO to make the async request to be more generic for models in http mode. + assert method in ["post", "get"] + + api_base, url = cls.split_zhipu_api_url(invoke_type, kwargs) + requester = GeneralAPIRequestor(api_base=api_base) + result, _, api_key = await requester.arequest( + method=method, + url=url, + headers=headers, + stream=stream, + params=kwargs, + request_timeout=zhipuai.api_timeout_seconds + ) + + return result + + @classmethod + async def ainvoke(cls, **kwargs) -> dict: + """ async invoke different from raw method `async_invoke` which get the final result by task_id""" + headers = cls.get_header() + resp = await cls.arequest(invoke_type=InvokeType.SYNC, + stream=False, + method="post", + headers=headers, + kwargs=kwargs) + return resp + + @classmethod + async def asse_invoke(cls, **kwargs) -> AsyncSSEClient: + """ async sse_invoke """ + headers = cls.get_sse_header() + return AsyncSSEClient(await cls.arequest(invoke_type=InvokeType.SSE, + stream=True, + method="post", + headers=headers, + kwargs=kwargs)) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py new file mode 100644 index 000000000..4e8e6b760 --- /dev/null +++ b/metagpt/provider/zhipuai_api.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk + +from enum import Enum +import json +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, +) +from requests import ConnectionError + +import zhipuai + +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.openai_api import CostManager, log_and_reraise +from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI +from metagpt.utils.ahttp_client import astream + + +class ZhiPuEvent(Enum): + ADD = "add" + ERROR = "error" + INTERRUPTED = "interrupted" + FINISH = "finish" + + +class ZhiPuAIGPTAPI(BaseGPTAPI): + """ + Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` + From now, there is only one model named `chatglm_turbo` + """ + + use_system_prompt: bool = False # zhipuai has no system prompt when use api + + def __init__(self): + self.__init_zhipuai(CONFIG) + self.llm = ZhiPuModelAPI + self.model = "chatglm_turbo" # so far only one model, just use it + self._cost_manager = CostManager() + + def __init_zhipuai(self, config: CONFIG): + assert config.zhipuai_api_key + zhipuai.api_key = config.zhipuai_api_key + + def _const_kwargs(self, messages: list[dict]) -> dict: + kwargs = { + "model": self.model, + "prompt": messages, + "temperature": 0.3 + } + return kwargs + + def _update_costs(self, usage: dict): + """ update each request's token cost """ + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error("zhipuai updats costs failed!", e) + + def get_choice_text(self, resp: dict) -> str: + """ get the first text of choice from llm response """ + assist_msg = resp.get("data").get("choices")[-1] + assert assist_msg["role"] == "assistant" + return assist_msg.get("content") + + def completion(self, messages: list[dict]) -> dict: + resp = self.llm.invoke(**self._const_kwargs(messages)) + usage = resp.get("data").get("usage") + self._update_costs(usage) + return resp + + async def _achat_completion(self, messages: list[dict]) -> dict: + resp = await self.llm.ainvoke(**self._const_kwargs(messages)) + usage = resp.get("data").get("usage") + self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + response = await self.llm.asse_invoke(**self._const_kwargs(messages)) + collected_content = [] + usage = {} + async for event in response.async_events(): + if event.event == ZhiPuEvent.ADD.value: + content = event.data + collected_content.append(content) + print(content, end="") + elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value: + content = event.data + logger.error(f"event error: {content}", end="") + collected_content.append([content]) + elif event.event == ZhiPuEvent.FINISH.value: + """ + event.meta + { + "task_status":"SUCCESS", + "usage":{ + "completion_tokens":351, + "prompt_tokens":595, + "total_tokens":946 + }, + "task_id":"xx", + "request_id":"xxx" + } + """ + meta = json.loads(event.meta) + usage = meta.get("usage") + else: + print(f"zhipuapi else event: {event.data}", end="") + + self._update_costs(usage) + full_content = "".join(collected_content) + logger.info(f"full_content: {full_content} !!") + return full_content + + # @retry( + # stop=stop_after_attempt(3), + # wait=wait_fixed(1), + # after=after_log(logger, logger.level("WARNING").name), + # retry=retry_if_exception_type(ConnectionError), + # retry_error_callback=log_and_reraise + # ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """ response in async with stream or non-stream mode """ + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index a5a65803a..1af96f272 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -22,6 +22,7 @@ TOKEN_COSTS = { "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12}, "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, + "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens } @@ -37,6 +38,7 @@ TOKEN_MAX = { "gpt-4-32k-0314": 32768, "gpt-4-0613": 8192, "text-embedding-ada-002": 8192, + "chatglm_turbo": 32768 } @@ -68,7 +70,9 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): return count_message_tokens(messages, model="gpt-4-0613") else: raise NotImplementedError( - f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" + f"num_tokens_from_messages() is not implemented for model {model}. " + f"See https://github.com/openai/openai-python/blob/main/chatml.md " + f"for information on how messages are converted to tokens." ) num_tokens = 0 for message in messages: From 66f27ca2d599c7e6420ac11d4f44c2d4668d5d28 Mon Sep 17 00:00:00 2001 From: better629 Date: Sat, 18 Nov 2023 21:35:41 +0800 Subject: [PATCH 2/8] add General Async API for http-based LLM model --- metagpt/provider/general_api_requestor.py | 64 +++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 metagpt/provider/general_api_requestor.py diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py new file mode 100644 index 000000000..e4e5f0f96 --- /dev/null +++ b/metagpt/provider/general_api_requestor.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : General Async API for http-based LLM model + +from typing import AsyncGenerator, Tuple, Union, Optional, Literal +import aiohttp +import asyncio + +from openai.api_requestor import APIRequestor, aiohttp_session + +from metagpt.logs import logger + + +class GeneralAPIRequestor(APIRequestor): + """ + usage + # full_url = "{api_base}{url}" + requester = GeneralAPIRequestor(api_base=api_base) + result, _, api_key = await requester.arequest( + method=method, + url=url, + headers=headers, + stream=stream, + params=kwargs, + request_timeout=120 + ) + """ + + def _interpret_response_line( + self, rbody: str, rcode: int, rheaders, stream: bool + ) -> str: + # just do nothing to meet the APIRequestor process and return the raw data + # due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases. + + return rbody + + async def _interpret_async_response( + self, result: aiohttp.ClientResponse, stream: bool + ) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]: + if stream and "text/event-stream" in result.headers.get("Content-Type", ""): + logger.warning("stream") + return ( + self._interpret_response_line( + line, result.status, result.headers, stream=True + ) + async for line in result.content + ), True + else: + logger.warning("non stream") + try: + await result.read() + except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e: + raise TimeoutError("Request timed out") from e + except aiohttp.ClientError as exp: + logger.warning(f"response: {result.content}, exp: {exp}") + return ( + self._interpret_response_line( + await result.read(), # let the caller to decode the msg + result.status, + result.headers, + stream=False, + ), + False, + ) From 8e201384bf67020e6cf8efcf248dc2fd71977e57 Mon Sep 17 00:00:00 2001 From: better629 Date: Sat, 18 Nov 2023 21:28:49 +0800 Subject: [PATCH 3/8] add use_system_prompt to judge if need to add system_prompt part --- config/config.yaml | 3 +++ metagpt/provider/base_chatbot.py | 1 + metagpt/provider/base_gpt_api.py | 8 +++++--- requirements.txt | 1 + 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index b2c50991d..fc6961f9e 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -31,6 +31,9 @@ RPM: 10 #DEPLOYMENT_NAME: "YOUR_DEPLOYMENT_NAME" #DEPLOYMENT_ID: "YOUR_DEPLOYMENT_ID" +#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" +ZHIPUAI_API_KEY: "YOUR_API_KEY" + #### for Search ## Supported values: serpapi/google/serper/ddg diff --git a/metagpt/provider/base_chatbot.py b/metagpt/provider/base_chatbot.py index abdf423f4..72e6c94f9 100644 --- a/metagpt/provider/base_chatbot.py +++ b/metagpt/provider/base_chatbot.py @@ -13,6 +13,7 @@ from dataclasses import dataclass class BaseChatbot(ABC): """Abstract GPT class""" mode: str = "API" + use_system_prompt: bool = True @abstractmethod def ask(self, msg: str) -> str: diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index de61167b9..3a157b63e 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -32,15 +32,17 @@ class BaseGPTAPI(BaseChatbot): return self._system_msg(self.system_prompt) def ask(self, msg: str) -> str: - message = [self._default_system_msg(), self._user_msg(msg)] + message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)] rsp = self.completion(message) return self.get_choice_text(rsp) async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str: if system_msgs: - message = self._system_msgs(system_msgs) + [self._user_msg(msg)] + message = self._system_msgs(system_msgs) + [self._user_msg(msg)] if self.use_system_prompt \ + else [self._user_msg(msg)] else: - message = [self._default_system_msg(), self._user_msg(msg)] + message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt \ + else [self._user_msg(msg)] rsp = await self.acompletion_text(message, stream=True) logger.debug(message) # logger.debug(rsp) diff --git a/requirements.txt b/requirements.txt index 093298775..66e8dce02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,3 +44,4 @@ ta==0.10.2 semantic-kernel==0.3.13.dev0 wrapt==1.15.0 websocket-client==0.58.0 +zhipuai==1.0.7 From 2c81cc3e0f976d9b3774761d72036c60aa824866 Mon Sep 17 00:00:00 2001 From: better629 Date: Sat, 18 Nov 2023 22:00:52 +0800 Subject: [PATCH 4/8] add zhipuai_api unittest and remove useless log --- metagpt/provider/general_api_requestor.py | 2 - metagpt/provider/zhipuai/zhipu_model_api.py | 1 - tests/metagpt/provider/test_zhipuai_api.py | 47 +++++++++++++++++++++ 3 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 tests/metagpt/provider/test_zhipuai_api.py diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py index e4e5f0f96..169b7c146 100644 --- a/metagpt/provider/general_api_requestor.py +++ b/metagpt/provider/general_api_requestor.py @@ -38,7 +38,6 @@ class GeneralAPIRequestor(APIRequestor): self, result: aiohttp.ClientResponse, stream: bool ) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]: if stream and "text/event-stream" in result.headers.get("Content-Type", ""): - logger.warning("stream") return ( self._interpret_response_line( line, result.status, result.headers, stream=True @@ -46,7 +45,6 @@ class GeneralAPIRequestor(APIRequestor): async for line in result.content ), True else: - logger.warning("non stream") try: await result.read() except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e: diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index f1fd6f2e2..e1d52061d 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -5,7 +5,6 @@ import zhipuai from zhipuai.model_api.api import ModelAPI, InvokeType from zhipuai.utils.http_client import headers as zhipuai_default_headers -from zhipuai.utils.sse_client import SSEClient from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient from metagpt.provider.general_api_requestor import GeneralAPIRequestor diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py new file mode 100644 index 000000000..6a0c70de5 --- /dev/null +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of ZhiPuAIGPTAPI + +import pytest + +from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI + + +default_resp = { + "code": 200, + "data": { + "choices": [ + {"role": "assistant", "content": "I'm chatglm-turbo"} + ] + } +} + +messages = [ + {"role": "user", "content": "who are you"} +] + + +def mock_llm_ask(self, messages: list[dict]) -> dict: + return default_resp + + +def test_zhipuai_completion(mocker): + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask) + + resp = ZhiPuAIGPTAPI().completion(messages) + assert resp["code"] == 200 + assert "chatglm-turbo" in resp["data"]["choices"][0]["content"] + + +async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict: + return default_resp + + +@pytest.mark.asyncio +async def test_zhipuai_acompletion(mocker): + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask) + + resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False) + + assert resp["code"] == 200 + assert "chatglm-turbo" in resp["data"]["choices"][0]["content"] From 6ef3b213c3504a3d2e5f4f5ffc04b47e70dc78fd Mon Sep 17 00:00:00 2001 From: better629 Date: Sat, 18 Nov 2023 22:17:40 +0800 Subject: [PATCH 5/8] fix small problem --- metagpt/llm.py | 2 +- metagpt/provider/zhipuai/async_sse_client.py | 1 + metagpt/provider/zhipuai/zhipu_model_api.py | 6 +++++- metagpt/provider/zhipuai_api.py | 17 ++++++++--------- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/metagpt/llm.py b/metagpt/llm.py index 1f6a6bb1a..e9b80d7a8 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -14,7 +14,7 @@ from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI from metagpt.provider.spark_api import SparkAPI -def LLM(): +def LLM() -> "BaseGPTAPI": """ initialize different LLM instance according to the key field existence""" # TODO a little trick, can use registry to initialize LLM instance further if CONFIG.openai_api_key and CONFIG.openai_api_key.starswith("sk-"): diff --git a/metagpt/provider/zhipuai/async_sse_client.py b/metagpt/provider/zhipuai/async_sse_client.py index 7a4275982..b819fdc63 100644 --- a/metagpt/provider/zhipuai/async_sse_client.py +++ b/metagpt/provider/zhipuai/async_sse_client.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : async_sse_client to make keep the use of Event to access response +# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py` from zhipuai.utils.sse_client import SSEClient, Event, _FIELD_SEPARATOR diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index e1d52061d..618b2e865 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -29,8 +29,12 @@ class ZhiPuModelAPI(ModelAPI): @classmethod def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs): # use this method to prevent zhipu api upgrading to different version. + # and follow the GeneralAPIRequestor implemented based on openai sdk zhipu_api_url = cls._build_api_url(kwargs, invoke_type) - # example: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method} + """ + example: + zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method} + """ arr = zhipu_api_url.split("/api/") # ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke") return f"{arr[0]}/api", f"/{arr[1]}" diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 4e8e6b760..2ad1944c2 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -68,7 +68,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): def get_choice_text(self, resp: dict) -> str: """ get the first text of choice from llm response """ - assist_msg = resp.get("data").get("choices")[-1] + assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] assert assist_msg["role"] == "assistant" return assist_msg.get("content") @@ -121,16 +121,15 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): self._update_costs(usage) full_content = "".join(collected_content) - logger.info(f"full_content: {full_content} !!") return full_content - # @retry( - # stop=stop_after_attempt(3), - # wait=wait_fixed(1), - # after=after_log(logger, logger.level("WARNING").name), - # retry=retry_if_exception_type(ConnectionError), - # retry_error_callback=log_and_reraise - # ) + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise + ) async def acompletion_text(self, messages: list[dict], stream=False) -> str: """ response in async with stream or non-stream mode """ if stream: From f8f938f333ff781ccc3755c798200df2e1272f87 Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 20 Nov 2023 14:46:31 +0800 Subject: [PATCH 6/8] fix config when open llm model hosts as openai interface --- config/config.yaml | 4 ++-- metagpt/llm.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index fc6961f9e..bed67083c 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -7,7 +7,7 @@ ## Or, you can configure OPENAI_PROXY to access official OPENAI_API_BASE. OPENAI_API_BASE: "https://api.openai.com/v1" #OPENAI_PROXY: "http://127.0.0.1:8118" -#OPENAI_API_KEY: "YOUR_API_KEY" +#OPENAI_API_KEY: "YOUR_API_KEY" # set the value to sk-xxx if you host the openai interface for open llm model OPENAI_API_MODEL: "gpt-4" MAX_TOKENS: 1500 RPM: 10 @@ -32,7 +32,7 @@ RPM: 10 #DEPLOYMENT_ID: "YOUR_DEPLOYMENT_ID" #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" -ZHIPUAI_API_KEY: "YOUR_API_KEY" +# ZHIPUAI_API_KEY: "YOUR_API_KEY" #### for Search diff --git a/metagpt/llm.py b/metagpt/llm.py index e9b80d7a8..13e5a56e0 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -17,13 +17,15 @@ from metagpt.provider.spark_api import SparkAPI def LLM() -> "BaseGPTAPI": """ initialize different LLM instance according to the key field existence""" # TODO a little trick, can use registry to initialize LLM instance further - if CONFIG.openai_api_key and CONFIG.openai_api_key.starswith("sk-"): + if CONFIG.openai_api_key and CONFIG.openai_api_key.startswith("sk-"): llm = OpenAIGPTAPI() elif CONFIG.claude_api_key: llm = Claude() elif CONFIG.spark_api_key: llm = SparkAPI() - elif CONFIG.zhipuai_api_key: + elif CONFIG.zhipuai_api_key and CONFIG.zhipuai_api_key != "YOUR_API_KEY": llm = ZhiPuAIGPTAPI() + else: + raise RuntimeError("You should config a LLM configuration first") return llm From 9d1d8a9fe4c7946c3c36995d6584491dfd17f97c Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 21 Nov 2023 14:10:04 +0800 Subject: [PATCH 7/8] fix --- metagpt/provider/zhipuai_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 2ad1944c2..064ec35ba 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -20,7 +20,6 @@ from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.openai_api import CostManager, log_and_reraise from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI -from metagpt.utils.ahttp_client import astream class ZhiPuEvent(Enum): From ded2044be7eb244ad0549873b5041945a8da674d Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 21 Nov 2023 14:50:31 +0800 Subject: [PATCH 8/8] rm useless func --- metagpt/provider/general_api_requestor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py index 169b7c146..150f2f1e0 100644 --- a/metagpt/provider/general_api_requestor.py +++ b/metagpt/provider/general_api_requestor.py @@ -6,7 +6,7 @@ from typing import AsyncGenerator, Tuple, Union, Optional, Literal import aiohttp import asyncio -from openai.api_requestor import APIRequestor, aiohttp_session +from openai.api_requestor import APIRequestor from metagpt.logs import logger