add zhipuai api with extra async invoke methods

This commit is contained in:
better629 2023-11-18 21:34:14 +08:00
parent f8ebfa9a74
commit 96ce036bd4
7 changed files with 323 additions and 13 deletions

View file

@ -45,10 +45,11 @@ class Config(metaclass=Singleton):
self.global_proxy = self._get("GLOBAL_PROXY")
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("Anthropic_API_KEY")
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and (
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY first")
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \
(not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \
(not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first")
self.openai_api_base = self._get("OPENAI_API_BASE")
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
if openai_proxy:

View file

@ -6,14 +6,24 @@
@File : llm.py
"""
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.provider.anthropic_api import Claude2 as Claude
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.spark_api import SparkAPI
DEFAULT_LLM = LLM()
CLAUDE_LLM = Claude()
async def ai_func(prompt):
"""使用LLM进行QA
QA with LLMs
"""
return await DEFAULT_LLM.aask(prompt)
def LLM():
""" initialize different LLM instance according to the key field existence"""
# TODO a little trick, can use registry to initialize LLM instance further
if CONFIG.openai_api_key and CONFIG.openai_api_key.starswith("sk-"):
llm = OpenAIGPTAPI()
elif CONFIG.claude_api_key:
llm = Claude()
elif CONFIG.spark_api_key:
llm = SparkAPI()
elif CONFIG.zhipuai_api_key:
llm = ZhiPuAIGPTAPI()
return llm

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,77 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : async_sse_client to make keep the use of Event to access response
from zhipuai.utils.sse_client import SSEClient, Event, _FIELD_SEPARATOR
class AsyncSSEClient(SSEClient):
async def _aread(self):
data = b""
async for chunk in self._event_source:
for line in chunk.splitlines(True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
yield data
data = b""
if data:
yield data
async def async_events(self):
async for chunk in self._aread():
event = Event()
# Split before decoding so splitlines() only uses \r and \n
for line in chunk.splitlines():
# Decode the line.
line = line.decode(self._char_enc)
# Lines starting with a separator are comments and are to be
# ignored.
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
continue
data = line.split(_FIELD_SEPARATOR, 1)
field = data[0]
# Ignore unknown fields.
if field not in event.__dict__:
self._logger.debug(
"Saw invalid field %s while parsing " "Server Side Event", field
)
continue
if len(data) > 1:
# From the spec:
# "If value starts with a single U+0020 SPACE character,
# remove it from value."
if data[1].startswith(" "):
value = data[1][1:]
else:
value = data[1]
else:
# If no value is present after the separator,
# assume an empty value.
value = ""
# The data field may come over multiple lines and their values
# are concatenated with each other.
if field == "data":
event.__dict__[field] += value + "\n"
else:
event.__dict__[field] = value
# Events with no data are not dispatched.
if not event.data:
continue
# If the data field ends with a newline, remove it.
if event.data.endswith("\n"):
event.data = event.data[0:-1]
# Empty event names default to 'message'
event.event = event.event or "message"
# Dispatch the event
self._logger.debug("Dispatching %s...", event)
yield event

View file

@ -0,0 +1,76 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : zhipu model api to support sync & async for invoke & sse_invoke
import zhipuai
from zhipuai.model_api.api import ModelAPI, InvokeType
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from zhipuai.utils.sse_client import SSEClient
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
class ZhiPuModelAPI(ModelAPI):
@classmethod
def get_header(cls) -> dict:
token = cls._generate_token()
zhipuai_default_headers.update({"Authorization": token})
return zhipuai_default_headers
@classmethod
def get_sse_header(cls) -> dict:
token = cls._generate_token()
headers = {
"Authorization": token
}
return headers
@classmethod
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
# use this method to prevent zhipu api upgrading to different version.
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
# example: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
arr = zhipu_api_url.split("/api/")
# ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke")
return f"{arr[0]}/api", f"/{arr[1]}"
@classmethod
async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs):
# TODO to make the async request to be more generic for models in http mode.
assert method in ["post", "get"]
api_base, url = cls.split_zhipu_api_url(invoke_type, kwargs)
requester = GeneralAPIRequestor(api_base=api_base)
result, _, api_key = await requester.arequest(
method=method,
url=url,
headers=headers,
stream=stream,
params=kwargs,
request_timeout=zhipuai.api_timeout_seconds
)
return result
@classmethod
async def ainvoke(cls, **kwargs) -> dict:
""" async invoke different from raw method `async_invoke` which get the final result by task_id"""
headers = cls.get_header()
resp = await cls.arequest(invoke_type=InvokeType.SYNC,
stream=False,
method="post",
headers=headers,
kwargs=kwargs)
return resp
@classmethod
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
""" async sse_invoke """
headers = cls.get_sse_header()
return AsyncSSEClient(await cls.arequest(invoke_type=InvokeType.SSE,
stream=True,
method="post",
headers=headers,
kwargs=kwargs))

View file

@ -0,0 +1,139 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
from enum import Enum
import json
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_fixed,
)
from requests import ConnectionError
import zhipuai
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.openai_api import CostManager, log_and_reraise
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
from metagpt.utils.ahttp_client import astream
class ZhiPuEvent(Enum):
ADD = "add"
ERROR = "error"
INTERRUPTED = "interrupted"
FINISH = "finish"
class ZhiPuAIGPTAPI(BaseGPTAPI):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
From now, there is only one model named `chatglm_turbo`
"""
use_system_prompt: bool = False # zhipuai has no system prompt when use api
def __init__(self):
self.__init_zhipuai(CONFIG)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self._cost_manager = CostManager()
def __init_zhipuai(self, config: CONFIG):
assert config.zhipuai_api_key
zhipuai.api_key = config.zhipuai_api_key
def _const_kwargs(self, messages: list[dict]) -> dict:
kwargs = {
"model": self.model,
"prompt": messages,
"temperature": 0.3
}
return kwargs
def _update_costs(self, usage: dict):
""" update each request's token cost """
if CONFIG.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error("zhipuai updats costs failed!", e)
def get_choice_text(self, resp: dict) -> str:
""" get the first text of choice from llm response """
assist_msg = resp.get("data").get("choices")[-1]
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")
def completion(self, messages: list[dict]) -> dict:
resp = self.llm.invoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
self._update_costs(usage)
return resp
async def _achat_completion(self, messages: list[dict]) -> dict:
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict]) -> dict:
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
collected_content = []
usage = {}
async for event in response.async_events():
if event.event == ZhiPuEvent.ADD.value:
content = event.data
collected_content.append(content)
print(content, end="")
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
content = event.data
logger.error(f"event error: {content}", end="")
collected_content.append([content])
elif event.event == ZhiPuEvent.FINISH.value:
"""
event.meta
{
"task_status":"SUCCESS",
"usage":{
"completion_tokens":351,
"prompt_tokens":595,
"total_tokens":946
},
"task_id":"xx",
"request_id":"xxx"
}
"""
meta = json.loads(event.meta)
usage = meta.get("usage")
else:
print(f"zhipuapi else event: {event.data}", end="")
self._update_costs(usage)
full_content = "".join(collected_content)
logger.info(f"full_content: {full_content} !!")
return full_content
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_fixed(1),
# after=after_log(logger, logger.level("WARNING").name),
# retry=retry_if_exception_type(ConnectionError),
# retry_error_callback=log_and_reraise
# )
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
""" response in async with stream or non-stream mode """
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)

View file

@ -22,6 +22,7 @@ TOKEN_COSTS = {
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens
}
@ -37,6 +38,7 @@ TOKEN_MAX = {
"gpt-4-32k-0314": 32768,
"gpt-4-0613": 8192,
"text-embedding-ada-002": 8192,
"chatglm_turbo": 32768
}
@ -68,7 +70,9 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
return count_message_tokens(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
f"num_tokens_from_messages() is not implemented for model {model}. "
f"See https://github.com/openai/openai-python/blob/main/chatml.md "
f"for information on how messages are converted to tokens."
)
num_tokens = 0
for message in messages: