mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 01:06:27 +02:00
update zhipu api due to new model and api; repair extra invalid generate output; update its unittest
This commit is contained in:
parent
75cbf9f087
commit
4e13eaca6e
17 changed files with 156 additions and 214 deletions
|
|
@ -36,6 +36,7 @@ TIMEOUT: 60 # Timeout for llm invocation
|
|||
|
||||
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
|
||||
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
|
||||
# ZHIPUAI_API_MODEL: "glm-4"
|
||||
|
||||
#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`.
|
||||
#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY"
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ async def main():
|
|||
# streaming mode, much slower
|
||||
await llm.acompletion_text(hello_msg, stream=True)
|
||||
|
||||
# check completion if exist to test llm complete functions
|
||||
if hasattr(llm, "completion"):
|
||||
logger.info(llm.completion(hello_msg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -157,9 +157,11 @@ class WriteCodeReview(Action):
|
|||
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
|
||||
format_example=format_example,
|
||||
)
|
||||
len1 = len(iterative_code) if iterative_code else 0
|
||||
len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0
|
||||
logger.info(
|
||||
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, "
|
||||
f"{len(self.context.code_doc.content)=}"
|
||||
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
|
||||
f"len(self.context.code_doc.content)={len2}"
|
||||
)
|
||||
result, rewrited_code = await self.write_code_review_and_rewrite(
|
||||
context_prompt, cr_prompt, self.context.code_doc.filename
|
||||
|
|
|
|||
|
|
@ -144,6 +144,7 @@ class Config(metaclass=Singleton):
|
|||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = self._get("ANTHROPIC_API_KEY")
|
||||
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
|
||||
self.zhipuai_api_model = self._get("ZHIPUAI_API_MODEL")
|
||||
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
|
||||
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
|
||||
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
|
||||
|
|
|
|||
|
|
@ -89,6 +89,10 @@ class BaseLLM(ABC):
|
|||
"""Required to provide the first text of choice"""
|
||||
return rsp.get("choices")[0]["message"]["content"]
|
||||
|
||||
def get_choice_delta_text(self, rsp: dict) -> str:
|
||||
"""Required to provide the first text of stream choice"""
|
||||
return rsp.get("choices")[0]["delta"]["content"]
|
||||
|
||||
def get_choice_function(self, rsp: dict) -> dict:
|
||||
"""Required to provide the first function of choice
|
||||
:param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls",
|
||||
|
|
|
|||
|
|
@ -79,10 +79,8 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
async def _interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
|
||||
if stream and (
|
||||
"text/event-stream" in result.headers.get("Content-Type", "")
|
||||
or "application/x-ndjson" in result.headers.get("Content-Type", "")
|
||||
):
|
||||
content_type = result.headers.get("Content-Type", "")
|
||||
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
|
||||
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
|
||||
return (
|
||||
self._interpret_response_line(line, result.status, result.headers, stream=True)
|
||||
|
|
|
|||
|
|
@ -1,75 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : async_sse_client to make keep the use of Event to access response
|
||||
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
|
||||
# refs to `zhipuai/core/_sse_client.py`
|
||||
|
||||
from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient
|
||||
import json
|
||||
from typing import Any, Iterator
|
||||
|
||||
|
||||
class AsyncSSEClient(SSEClient):
|
||||
async def _aread(self):
|
||||
data = b""
|
||||
class AsyncSSEClient(object):
|
||||
def __init__(self, event_source: Iterator[Any]):
|
||||
self._event_source = event_source
|
||||
|
||||
async def stream(self) -> dict:
|
||||
if isinstance(self._event_source, bytes):
|
||||
raise RuntimeError(
|
||||
f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
|
||||
)
|
||||
async for chunk in self._event_source:
|
||||
for line in chunk.splitlines(True):
|
||||
data += line
|
||||
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
||||
yield data
|
||||
data = b""
|
||||
if data:
|
||||
yield data
|
||||
line = chunk.decode("utf-8")
|
||||
if line.startswith(":") or not line:
|
||||
return
|
||||
|
||||
async def async_events(self):
|
||||
async for chunk in self._aread():
|
||||
event = Event()
|
||||
# Split before decoding so splitlines() only uses \r and \n
|
||||
for line in chunk.splitlines():
|
||||
# Decode the line.
|
||||
line = line.decode(self._char_enc)
|
||||
|
||||
# Lines starting with a separator are comments and are to be
|
||||
# ignored.
|
||||
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
|
||||
continue
|
||||
|
||||
data = line.split(_FIELD_SEPARATOR, 1)
|
||||
field = data[0]
|
||||
|
||||
# Ignore unknown fields.
|
||||
if field not in event.__dict__:
|
||||
self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field)
|
||||
continue
|
||||
|
||||
if len(data) > 1:
|
||||
# From the spec:
|
||||
# "If value starts with a single U+0020 SPACE character,
|
||||
# remove it from value."
|
||||
if data[1].startswith(" "):
|
||||
value = data[1][1:]
|
||||
else:
|
||||
value = data[1]
|
||||
else:
|
||||
# If no value is present after the separator,
|
||||
# assume an empty value.
|
||||
value = ""
|
||||
|
||||
# The data field may come over multiple lines and their values
|
||||
# are concatenated with each other.
|
||||
if field == "data":
|
||||
event.__dict__[field] += value + "\n"
|
||||
else:
|
||||
event.__dict__[field] = value
|
||||
|
||||
# Events with no data are not dispatched.
|
||||
if not event.data:
|
||||
continue
|
||||
|
||||
# If the data field ends with a newline, remove it.
|
||||
if event.data.endswith("\n"):
|
||||
event.data = event.data[0:-1]
|
||||
|
||||
# Empty event names default to 'message'
|
||||
event.event = event.event or "message"
|
||||
|
||||
# Dispatch the event
|
||||
self._logger.debug("Dispatching %s...", event)
|
||||
yield event
|
||||
field, _p, value = line.partition(":")
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if field == "data":
|
||||
if value.startswith("[DONE]"):
|
||||
break
|
||||
data = json.loads(value)
|
||||
yield data
|
||||
|
|
|
|||
|
|
@ -4,46 +4,27 @@
|
|||
|
||||
import json
|
||||
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import InvokeType, ModelAPI
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT
|
||||
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
||||
|
||||
|
||||
class ZhiPuModelAPI(ModelAPI):
|
||||
@classmethod
|
||||
def get_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
zhipuai_default_headers.update({"Authorization": token})
|
||||
return zhipuai_default_headers
|
||||
|
||||
@classmethod
|
||||
def get_sse_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
headers = {"Authorization": token}
|
||||
return headers
|
||||
|
||||
@classmethod
|
||||
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
|
||||
class ZhiPuModelAPI(ZhipuAI):
|
||||
def split_zhipu_api_url(self):
|
||||
# use this method to prevent zhipu api upgrading to different version.
|
||||
# and follow the GeneralAPIRequestor implemented based on openai sdk
|
||||
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
|
||||
"""
|
||||
example:
|
||||
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
|
||||
"""
|
||||
zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
|
||||
arr = zhipu_api_url.split("/api/")
|
||||
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
|
||||
# ("https://open.bigmodel.cn/api" , "/paas/v4/chat/completions")
|
||||
return f"{arr[0]}/api", f"/{arr[1]}"
|
||||
|
||||
@classmethod
|
||||
async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs):
|
||||
async def arequest(self, stream: bool, method: str, headers: dict, kwargs):
|
||||
# TODO to make the async request to be more generic for models in http mode.
|
||||
assert method in ["post", "get"]
|
||||
|
||||
base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
|
||||
base_url, url = self.split_zhipu_api_url()
|
||||
requester = GeneralAPIRequestor(base_url=base_url)
|
||||
result, _, api_key = await requester.arequest(
|
||||
method=method,
|
||||
|
|
@ -51,25 +32,23 @@ class ZhiPuModelAPI(ModelAPI):
|
|||
headers=headers,
|
||||
stream=stream,
|
||||
params=kwargs,
|
||||
request_timeout=zhipuai.api_timeout_seconds,
|
||||
request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read,
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def ainvoke(cls, **kwargs) -> dict:
|
||||
async def acreate(self, **kwargs) -> dict:
|
||||
"""async invoke different from raw method `async_invoke` which get the final result by task_id"""
|
||||
headers = cls.get_header()
|
||||
resp = await cls.arequest(
|
||||
invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
|
||||
)
|
||||
headers = self._default_headers
|
||||
resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs)
|
||||
resp = resp.decode("utf-8")
|
||||
resp = json.loads(resp)
|
||||
if "error" in resp:
|
||||
raise RuntimeError(
|
||||
f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
|
||||
)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
|
||||
async def acreate_stream(self, **kwargs) -> AsyncSSEClient:
|
||||
"""async sse_invoke"""
|
||||
headers = cls.get_sse_header()
|
||||
return AsyncSSEClient(
|
||||
await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs)
|
||||
)
|
||||
headers = self._default_headers
|
||||
return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs))
|
||||
|
|
|
|||
|
|
@ -2,11 +2,9 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
import zhipuai
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
after_log,
|
||||
|
|
@ -15,6 +13,7 @@ from tenacity import (
|
|||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
from zhipuai.types.chat.chat_completion import Completion
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
|
|
@ -35,26 +34,25 @@ class ZhiPuEvent(Enum):
|
|||
class ZhiPuAILLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
From now, there is only one model named `chatglm_turbo`
|
||||
From now, support glm-3-turbo、glm-4, and also system_prompt.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_zhipuai(CONFIG)
|
||||
self.llm = ZhiPuModelAPI
|
||||
self.model = "chatglm_turbo" # so far only one model, just use it
|
||||
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
|
||||
self.llm = ZhiPuModelAPI(api_key=self.api_key)
|
||||
|
||||
def __init_zhipuai(self, config: CONFIG):
|
||||
assert config.zhipuai_api_key
|
||||
zhipuai.api_key = config.zhipuai_api_key
|
||||
self.api_key = config.zhipuai_api_key
|
||||
self.model = config.zhipuai_api_model # so far, it support glm-3-turbo、glm-4
|
||||
# due to use openai sdk, set the api_key but it will't be used.
|
||||
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
|
||||
if config.openai_proxy:
|
||||
# FIXME: openai v1.x sdk has no proxy support
|
||||
openai.proxy = config.openai_proxy
|
||||
|
||||
def _const_kwargs(self, messages: list[dict]) -> dict:
|
||||
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
|
|
@ -67,21 +65,15 @@ class ZhiPuAILLM(BaseLLM):
|
|||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the first text of choice from llm response"""
|
||||
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
|
||||
assert assist_msg["role"] == "assistant"
|
||||
return assist_msg.get("content")
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = self.llm.invoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
return resp.model_dump()
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
resp = await self.llm.acreate(**self._const_kwargs(messages))
|
||||
usage = resp.get("usage", {})
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
|
|
@ -89,35 +81,18 @@ class ZhiPuAILLM(BaseLLM):
|
|||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
|
||||
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for event in response.async_events():
|
||||
if event.event == ZhiPuEvent.ADD.value:
|
||||
content = event.data
|
||||
async for chunk in response.stream():
|
||||
finish_reason = chunk.get("choices")[0].get("finish_reason")
|
||||
if finish_reason == "stop":
|
||||
usage = chunk.get("usage", {})
|
||||
else:
|
||||
content = self.get_choice_delta_text(chunk)
|
||||
collected_content.append(content)
|
||||
log_llm_stream(content)
|
||||
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
|
||||
content = event.data
|
||||
logger.error(f"event error: {content}", end="")
|
||||
elif event.event == ZhiPuEvent.FINISH.value:
|
||||
"""
|
||||
event.meta
|
||||
{
|
||||
"task_status":"SUCCESS",
|
||||
"usage":{
|
||||
"completion_tokens":351,
|
||||
"prompt_tokens":595,
|
||||
"total_tokens":946
|
||||
},
|
||||
"task_id":"xx",
|
||||
"request_id":"xxx"
|
||||
}
|
||||
"""
|
||||
meta = json.loads(event.meta)
|
||||
usage = meta.get("usage")
|
||||
else:
|
||||
print(f"zhipuapi else event: {event.data}", end="")
|
||||
|
||||
log_llm_stream("\n")
|
||||
|
||||
self._update_costs(usage)
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ class FileRepository:
|
|||
"""
|
||||
pathname = self.workdir / filename
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = content if content else "" # avoid `argument must be str, not None` to make it continue
|
||||
async with aiofiles.open(str(pathname), mode="w") as writer:
|
||||
await writer.write(content)
|
||||
logger.info(f"save to: {str(pathname)}")
|
||||
|
|
|
|||
|
|
@ -120,6 +120,15 @@ def repair_json_format(output: str) -> str:
|
|||
elif output.startswith("{") and output.endswith("]"):
|
||||
output = output[:-1] + "}"
|
||||
|
||||
# remove `#` in output json str, usually appeared in `glm-4`
|
||||
arr = output.split("\n")
|
||||
new_arr = []
|
||||
for line in arr:
|
||||
idx = line.find("#")
|
||||
if idx >= 0:
|
||||
line = line[:idx]
|
||||
new_arr.append(line)
|
||||
output = "\n".join(new_arr)
|
||||
return output
|
||||
|
||||
|
||||
|
|
@ -168,15 +177,17 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
|
||||
example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)
|
||||
"""
|
||||
pattern = r"line ([0-9]+)"
|
||||
pattern = r"line ([0-9]+) column ([0-9]+)"
|
||||
|
||||
matches = re.findall(pattern, error, re.DOTALL)
|
||||
if len(matches) > 0:
|
||||
line_no = int(matches[0]) - 1
|
||||
line_no = int(matches[0][0]) - 1
|
||||
col_no = int(matches[0][1]) - 1
|
||||
|
||||
# due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'`
|
||||
output = output.replace('"""', '"').replace("'''", '"')
|
||||
arr = output.split("\n")
|
||||
rline = arr[line_no] # raw line
|
||||
line = arr[line_no].strip()
|
||||
# different general problems
|
||||
if line.endswith("],"):
|
||||
|
|
@ -187,9 +198,12 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
new_line = line.replace("}", "")
|
||||
elif line.endswith("},") and output.endswith("},"):
|
||||
new_line = line[:-1]
|
||||
elif '",' not in line and "," not in line:
|
||||
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
|
||||
# problem, `"""` or `'''` without `,`
|
||||
new_line = f",{line}"
|
||||
elif '",' not in line and "," not in line and '"' not in line:
|
||||
new_line = f'{line}",'
|
||||
elif "," not in line:
|
||||
elif not line.endswith(","):
|
||||
# problem, miss char `,` at the end.
|
||||
new_line = f"{line},"
|
||||
elif "," in line and len(line) == 1:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ TOKEN_COSTS = {
|
|||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
|
||||
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ aioredis~=2.0.1 # Used by metagpt/utils/redis.py
|
|||
websocket-client==1.6.2
|
||||
aiofiles==23.2.1
|
||||
gitpython==3.1.40
|
||||
zhipuai==1.0.7
|
||||
zhipuai==2.0.1
|
||||
socksio~=1.0.0
|
||||
gitignore-parser==0.1.9
|
||||
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
# @Desc : the unittest of ZhiPuAILLM
|
||||
|
||||
import pytest
|
||||
from zhipuai.utils.sse_client import Event
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
|
|
@ -15,35 +14,16 @@ messages = [{"role": "user", "content": prompt_msg}]
|
|||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"choices": [{"role": "assistant", "content": resp_content}],
|
||||
"usage": {"prompt_tokens": 20, "completion_tokens": 20},
|
||||
},
|
||||
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
|
||||
|
||||
def mock_zhipuai_invoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_ainvoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_asse_invoke(**kwargs):
|
||||
async def mock_zhipuai_acreate_stream(self, **kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [
|
||||
Event(id="xxx", event="add", data=resp_content, retry=0),
|
||||
Event(
|
||||
id="xxx",
|
||||
event="finish",
|
||||
data="",
|
||||
meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}',
|
||||
),
|
||||
]
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
|
|
@ -52,23 +32,26 @@ async def mock_zhipuai_asse_invoke(**kwargs):
|
|||
async for chunk in Iterator():
|
||||
yield chunk
|
||||
|
||||
async def async_events(self):
|
||||
async def stream(self):
|
||||
async for chunk in self._aread():
|
||||
yield chunk
|
||||
|
||||
return MockResponse()
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate(self, **kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
|
||||
|
||||
zhipu_gpt = ZhiPuAILLM()
|
||||
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
assert resp["choices"][0]["message"]["content"] == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -11,16 +11,16 @@ from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
|||
async def test_async_sse_client():
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"data: test_value"
|
||||
yield b'data: {"test_key": "test_value"}'
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
async for chunk in async_sse_client.stream():
|
||||
assert "test_value" in chunk.values()
|
||||
|
||||
class InvalidIterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"invalid: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert not event
|
||||
async for chunk in async_sse_client.stream():
|
||||
assert not chunk
|
||||
|
|
|
|||
|
|
@ -6,15 +6,13 @@ from typing import Any, Tuple
|
|||
|
||||
import pytest
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
||||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = b'{"result": "test response"}'
|
||||
default_resp = b'{"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": "test response", "role": "assistant"}}]}'
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
|
|
@ -23,22 +21,15 @@ async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipu_model_api(mocker):
|
||||
header = ZhiPuModelAPI.get_header()
|
||||
zhipuai_default_headers.update({"Authorization": api_key})
|
||||
assert header == zhipuai_default_headers
|
||||
|
||||
sse_header = ZhiPuModelAPI.get_sse_header()
|
||||
assert len(sse_header["Authorization"]) == 191
|
||||
|
||||
url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"})
|
||||
url_prefix, url_suffix = ZhiPuModelAPI(api_key=api_key).split_zhipu_api_url()
|
||||
assert url_prefix == "https://open.bigmodel.cn/api"
|
||||
assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
|
||||
assert url_suffix == "/paas/v4/chat/completions"
|
||||
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
|
||||
result = await ZhiPuModelAPI.arequest(
|
||||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
result = await ZhiPuModelAPI(api_key=api_key).arequest(
|
||||
stream=False, method="get", headers={}, kwargs={"model": "glm-3-turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
||||
result = await ZhiPuModelAPI.ainvoke()
|
||||
assert result["result"] == "test response"
|
||||
result = await ZhiPuModelAPI(api_key=api_key).acreate()
|
||||
assert result["choices"][0]["message"]["content"] == "test response"
|
||||
|
|
|
|||
|
|
@ -128,6 +128,19 @@ def test_repair_json_format():
|
|||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = """
|
||||
{
|
||||
"Language": "en_us", # define language
|
||||
"Programming Language": "Python"
|
||||
}
|
||||
"""
|
||||
target_output = """{
|
||||
"Language": "en_us",
|
||||
"Programming Language": "Python"
|
||||
}"""
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_repair_invalid_json():
|
||||
from metagpt.utils.repair_llm_raw_output import repair_invalid_json
|
||||
|
|
@ -204,6 +217,25 @@ def test_retry_parse_json_text():
|
|||
output = retry_parse_json_text(output=invalid_json_text)
|
||||
assert output == target_json
|
||||
|
||||
invalid_json_text = '''{
|
||||
"Data structures and interfaces": """
|
||||
class UI:
|
||||
- game_engine: GameEngine
|
||||
+ __init__(engine: GameEngine) -> None
|
||||
+ display_board() -> None
|
||||
+ display_score() -> None
|
||||
+ prompt_move() -> str
|
||||
+ reset_game() -> None
|
||||
"""
|
||||
"Anything UNCLEAR": "no"
|
||||
}'''
|
||||
target_json = {
|
||||
"Data structures and interfaces": "\n class UI:\n - game_engine: GameEngine\n + __init__(engine: GameEngine) -> None\n + display_board() -> None\n + display_score() -> None\n + prompt_move() -> str\n + reset_game() -> None\n ",
|
||||
"Anything UNCLEAR": "no",
|
||||
}
|
||||
output = retry_parse_json_text(output=invalid_json_text)
|
||||
assert output == target_json
|
||||
|
||||
|
||||
def test_extract_content_from_output():
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue