mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
merge main
This commit is contained in:
commit
b89746b9c7
20 changed files with 167 additions and 206 deletions
13
README.md
13
README.md
|
|
@ -34,10 +34,19 @@ # MetaGPT: The Multi-Agent Framework
|
|||
<p align="center">Software Company Multi-Role Schematic (Gradually Implementing)</p>
|
||||
|
||||
## News
|
||||
🚀 Jan 03: Here comes [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0)! In this version, we added serialization and deserialization of important objects and enabled breakpoint recovery. We upgraded OpenAI package to v1.6.0 and supported Gemini, ZhipuAI, Ollama, OpenLLM, etc. Moreover, we provided extremely simple examples where you need only 7 lines to implement a general election [debate](https://github.com/geekan/MetaGPT/blob/main/examples/debate_simple.py). Check out more details [here](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0)!
|
||||
🚀 Jan. 16, 2024: [MetaGPT paper](https://arxiv.org/abs/2308.00352) accepted for oral presentation **(top 1.2%)** at ICLR 2024, **ranking #1** in the LLM-based Agent category.
|
||||
|
||||
🚀 Jan. 03, 2024: [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0) released, new features include serialization, upgraded OpenAI package and supported multiple LLM etc.
|
||||
|
||||
🚀 Dec 15: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) is released! We introduced **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or existing codebase. We also launched a whole collection of important features, including **multilingual support** (experimental), multiple **programming languages support** (experimental), **incremental development** (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism!
|
||||
🚀 Dec. 15, 2023: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) released, introducing **incremental development**, **multilingual**, **multiple programming languages**, etc.
|
||||
|
||||
🔥 Nov. 08, 2023: MetaGPT is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html).
|
||||
|
||||
🔥 Sep. 01, 2023: MetaGPT tops GitHub Trending Monthly for the **17th time** in August 2023.
|
||||
|
||||
🌟 Jun. 30, 2023: MetaGPT is now open source.
|
||||
|
||||
🌟 Apr. 24, 2023: First line of MetaGPT code committed.
|
||||
|
||||
## Install
|
||||
|
||||
|
|
|
|||
|
|
@ -76,9 +76,8 @@ ### Tasks
|
|||
2. ~~Support Azure asynchronous API~~
|
||||
3. Support streaming version of all APIs
|
||||
4. ~~Make gpt-3.5-turbo available (HARD)~~
|
||||
5. Support
|
||||
10. Other
|
||||
1. ~~Clean up existing unused code~~
|
||||
2. Unify all code styles and establish contribution standards
|
||||
2. ~~Unify all code styles and establish contribution standards~~
|
||||
3. ~~Multi-language support~~
|
||||
4. ~~Multi-programming-language support~~
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ async def main():
|
|||
# streaming mode, much slower
|
||||
await llm.acompletion_text(hello_msg, stream=True)
|
||||
|
||||
# check completion if exist to test llm complete functions
|
||||
if hasattr(llm, "completion"):
|
||||
logger.info(llm.completion(hello_msg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class WriteTasks(Action):
|
|||
|
||||
async def _update_requirements(self, doc):
|
||||
m = json.loads(doc.content)
|
||||
packages = set(m.get("Required Python third-party packages", set()))
|
||||
packages = set(m.get("Required Python packages", set()))
|
||||
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
|
||||
if not requirement_doc:
|
||||
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
|
||||
|
|
|
|||
|
|
@ -160,9 +160,11 @@ class WriteCodeReview(Action):
|
|||
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
|
||||
format_example=format_example,
|
||||
)
|
||||
len1 = len(iterative_code) if iterative_code else 0
|
||||
len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0
|
||||
logger.info(
|
||||
f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, "
|
||||
f"{len(self.i_context.code_doc.content)=}"
|
||||
f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
|
||||
f"len(self.i_context.code_doc.content)={len2}"
|
||||
)
|
||||
result, rewrited_code = await self.write_code_review_and_rewrite(
|
||||
context_prompt, cr_prompt, self.i_context.code_doc.filename
|
||||
|
|
|
|||
|
|
@ -106,6 +106,10 @@ class BaseLLM(ABC):
|
|||
"""Required to provide the first text of choice"""
|
||||
return rsp.get("choices")[0]["message"]["content"]
|
||||
|
||||
def get_choice_delta_text(self, rsp: dict) -> str:
|
||||
"""Required to provide the first text of stream choice"""
|
||||
return rsp.get("choices")[0]["delta"]["content"]
|
||||
|
||||
def get_choice_function(self, rsp: dict) -> dict:
|
||||
"""Required to provide the first function of choice
|
||||
:param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls",
|
||||
|
|
|
|||
|
|
@ -79,10 +79,8 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
async def _interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
|
||||
if stream and (
|
||||
"text/event-stream" in result.headers.get("Content-Type", "")
|
||||
or "application/x-ndjson" in result.headers.get("Content-Type", "")
|
||||
):
|
||||
content_type = result.headers.get("Content-Type", "")
|
||||
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
|
||||
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
|
||||
return (
|
||||
self._interpret_response_line(line, result.status, result.headers, stream=True)
|
||||
|
|
|
|||
|
|
@ -1,75 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : async_sse_client to make keep the use of Event to access response
|
||||
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
|
||||
# refs to `zhipuai/core/_sse_client.py`
|
||||
|
||||
from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient
|
||||
import json
|
||||
from typing import Any, Iterator
|
||||
|
||||
|
||||
class AsyncSSEClient(SSEClient):
|
||||
async def _aread(self):
|
||||
data = b""
|
||||
class AsyncSSEClient(object):
|
||||
def __init__(self, event_source: Iterator[Any]):
|
||||
self._event_source = event_source
|
||||
|
||||
async def stream(self) -> dict:
|
||||
if isinstance(self._event_source, bytes):
|
||||
raise RuntimeError(
|
||||
f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
|
||||
)
|
||||
async for chunk in self._event_source:
|
||||
for line in chunk.splitlines(True):
|
||||
data += line
|
||||
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
||||
yield data
|
||||
data = b""
|
||||
if data:
|
||||
yield data
|
||||
line = chunk.decode("utf-8")
|
||||
if line.startswith(":") or not line:
|
||||
return
|
||||
|
||||
async def async_events(self):
|
||||
async for chunk in self._aread():
|
||||
event = Event()
|
||||
# Split before decoding so splitlines() only uses \r and \n
|
||||
for line in chunk.splitlines():
|
||||
# Decode the line.
|
||||
line = line.decode(self._char_enc)
|
||||
|
||||
# Lines starting with a separator are comments and are to be
|
||||
# ignored.
|
||||
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
|
||||
continue
|
||||
|
||||
data = line.split(_FIELD_SEPARATOR, 1)
|
||||
field = data[0]
|
||||
|
||||
# Ignore unknown fields.
|
||||
if field not in event.__dict__:
|
||||
self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field)
|
||||
continue
|
||||
|
||||
if len(data) > 1:
|
||||
# From the spec:
|
||||
# "If value starts with a single U+0020 SPACE character,
|
||||
# remove it from value."
|
||||
if data[1].startswith(" "):
|
||||
value = data[1][1:]
|
||||
else:
|
||||
value = data[1]
|
||||
else:
|
||||
# If no value is present after the separator,
|
||||
# assume an empty value.
|
||||
value = ""
|
||||
|
||||
# The data field may come over multiple lines and their values
|
||||
# are concatenated with each other.
|
||||
if field == "data":
|
||||
event.__dict__[field] += value + "\n"
|
||||
else:
|
||||
event.__dict__[field] = value
|
||||
|
||||
# Events with no data are not dispatched.
|
||||
if not event.data:
|
||||
continue
|
||||
|
||||
# If the data field ends with a newline, remove it.
|
||||
if event.data.endswith("\n"):
|
||||
event.data = event.data[0:-1]
|
||||
|
||||
# Empty event names default to 'message'
|
||||
event.event = event.event or "message"
|
||||
|
||||
# Dispatch the event
|
||||
self._logger.debug("Dispatching %s...", event)
|
||||
yield event
|
||||
field, _p, value = line.partition(":")
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if field == "data":
|
||||
if value.startswith("[DONE]"):
|
||||
break
|
||||
data = json.loads(value)
|
||||
yield data
|
||||
|
|
|
|||
|
|
@ -4,46 +4,27 @@
|
|||
|
||||
import json
|
||||
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import InvokeType, ModelAPI
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
from zhipuai import ZhipuAI
|
||||
from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT
|
||||
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
||||
|
||||
|
||||
class ZhiPuModelAPI(ModelAPI):
|
||||
@classmethod
|
||||
def get_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
zhipuai_default_headers.update({"Authorization": token})
|
||||
return zhipuai_default_headers
|
||||
|
||||
@classmethod
|
||||
def get_sse_header(cls) -> dict:
|
||||
token = cls._generate_token()
|
||||
headers = {"Authorization": token}
|
||||
return headers
|
||||
|
||||
@classmethod
|
||||
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
|
||||
class ZhiPuModelAPI(ZhipuAI):
|
||||
def split_zhipu_api_url(self):
|
||||
# use this method to prevent zhipu api upgrading to different version.
|
||||
# and follow the GeneralAPIRequestor implemented based on openai sdk
|
||||
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
|
||||
"""
|
||||
example:
|
||||
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
|
||||
"""
|
||||
zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
|
||||
arr = zhipu_api_url.split("/api/")
|
||||
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
|
||||
# ("https://open.bigmodel.cn/api" , "/paas/v4/chat/completions")
|
||||
return f"{arr[0]}/api", f"/{arr[1]}"
|
||||
|
||||
@classmethod
|
||||
async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs):
|
||||
async def arequest(self, stream: bool, method: str, headers: dict, kwargs):
|
||||
# TODO to make the async request to be more generic for models in http mode.
|
||||
assert method in ["post", "get"]
|
||||
|
||||
base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
|
||||
base_url, url = self.split_zhipu_api_url()
|
||||
requester = GeneralAPIRequestor(base_url=base_url)
|
||||
result, _, api_key = await requester.arequest(
|
||||
method=method,
|
||||
|
|
@ -51,25 +32,23 @@ class ZhiPuModelAPI(ModelAPI):
|
|||
headers=headers,
|
||||
stream=stream,
|
||||
params=kwargs,
|
||||
request_timeout=zhipuai.api_timeout_seconds,
|
||||
request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read,
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def ainvoke(cls, **kwargs) -> dict:
|
||||
async def acreate(self, **kwargs) -> dict:
|
||||
"""async invoke different from raw method `async_invoke` which get the final result by task_id"""
|
||||
headers = cls.get_header()
|
||||
resp = await cls.arequest(
|
||||
invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
|
||||
)
|
||||
headers = self._default_headers
|
||||
resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs)
|
||||
resp = resp.decode("utf-8")
|
||||
resp = json.loads(resp)
|
||||
if "error" in resp:
|
||||
raise RuntimeError(
|
||||
f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
|
||||
)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
|
||||
async def acreate_stream(self, **kwargs) -> AsyncSSEClient:
|
||||
"""async sse_invoke"""
|
||||
headers = cls.get_sse_header()
|
||||
return AsyncSSEClient(
|
||||
await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs)
|
||||
)
|
||||
headers = self._default_headers
|
||||
return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs))
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
|
|
@ -35,7 +34,7 @@ class ZhiPuEvent(Enum):
|
|||
class ZhiPuAILLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
From now, there is only one model named `chatglm_turbo`
|
||||
From now, support glm-3-turbo、glm-4, and also system_prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
|
|
@ -54,8 +53,8 @@ class ZhiPuAILLM(BaseLLM):
|
|||
# FIXME: openai v1.x sdk has no proxy support
|
||||
openai.proxy = config.proxy
|
||||
|
||||
def _const_kwargs(self, messages: list[dict]) -> dict:
|
||||
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
|
|
@ -68,21 +67,15 @@ class ZhiPuAILLM(BaseLLM):
|
|||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the first text of choice from llm response"""
|
||||
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
|
||||
assert assist_msg["role"] == "assistant"
|
||||
return assist_msg.get("content")
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = self.llm.invoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
return resp.model_dump()
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
resp = await self.llm.acreate(**self._const_kwargs(messages))
|
||||
usage = resp.get("usage", {})
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
|
|
@ -90,35 +83,18 @@ class ZhiPuAILLM(BaseLLM):
|
|||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
|
||||
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for event in response.async_events():
|
||||
if event.event == ZhiPuEvent.ADD.value:
|
||||
content = event.data
|
||||
async for chunk in response.stream():
|
||||
finish_reason = chunk.get("choices")[0].get("finish_reason")
|
||||
if finish_reason == "stop":
|
||||
usage = chunk.get("usage", {})
|
||||
else:
|
||||
content = self.get_choice_delta_text(chunk)
|
||||
collected_content.append(content)
|
||||
log_llm_stream(content)
|
||||
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
|
||||
content = event.data
|
||||
logger.error(f"event error: {content}", end="")
|
||||
elif event.event == ZhiPuEvent.FINISH.value:
|
||||
"""
|
||||
event.meta
|
||||
{
|
||||
"task_status":"SUCCESS",
|
||||
"usage":{
|
||||
"completion_tokens":351,
|
||||
"prompt_tokens":595,
|
||||
"total_tokens":946
|
||||
},
|
||||
"task_id":"xx",
|
||||
"request_id":"xxx"
|
||||
}
|
||||
"""
|
||||
meta = json.loads(event.meta)
|
||||
usage = meta.get("usage")
|
||||
else:
|
||||
print(f"zhipuapi else event: {event.data}", end="")
|
||||
|
||||
log_llm_stream("\n")
|
||||
|
||||
self._update_costs(usage)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ class QaEngineer(Role):
|
|||
profile: str = "QaEngineer"
|
||||
goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs"
|
||||
constraints: str = (
|
||||
"The test code you write should conform to code standard like PEP8, be modular, " "easy to read and maintain"
|
||||
"The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain."
|
||||
"Use same language as user requirement"
|
||||
)
|
||||
test_round_allowed: int = 5
|
||||
test_round: int = 0
|
||||
|
|
@ -54,6 +55,8 @@ class QaEngineer(Role):
|
|||
if not filename or "test" in filename:
|
||||
continue
|
||||
code_doc = await src_file_repo.get(filename)
|
||||
if not code_doc:
|
||||
continue
|
||||
test_doc = await self.project_repo.tests.get("test_" + code_doc.filename)
|
||||
if not test_doc:
|
||||
test_doc = Document(
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class Team(BaseModel):
|
|||
logger.info(f"Investment: ${investment}.")
|
||||
|
||||
def _check_balance(self):
|
||||
if self.cost_manager.total_cost > self.cost_manager.max_budget:
|
||||
if self.cost_manager.total_cost >= self.cost_manager.max_budget:
|
||||
raise NoMoneyException(self.cost_manager.total_cost, f"Insufficient funds: {self.cost_manager.max_budget}")
|
||||
|
||||
def run_project(self, idea, send_to: str = ""):
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class FileRepository:
|
|||
"""
|
||||
pathname = self.workdir / filename
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = content if content else "" # avoid `argument must be str, not None` to make it continue
|
||||
async with aiofiles.open(str(pathname), mode="w") as writer:
|
||||
await writer.write(content)
|
||||
logger.info(f"save to: {str(pathname)}")
|
||||
|
|
|
|||
|
|
@ -120,6 +120,15 @@ def repair_json_format(output: str) -> str:
|
|||
elif output.startswith("{") and output.endswith("]"):
|
||||
output = output[:-1] + "}"
|
||||
|
||||
# remove `#` in output json str, usually appeared in `glm-4`
|
||||
arr = output.split("\n")
|
||||
new_arr = []
|
||||
for line in arr:
|
||||
idx = line.find("#")
|
||||
if idx >= 0:
|
||||
line = line[:idx]
|
||||
new_arr.append(line)
|
||||
output = "\n".join(new_arr)
|
||||
return output
|
||||
|
||||
|
||||
|
|
@ -168,15 +177,17 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
|
||||
example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)
|
||||
"""
|
||||
pattern = r"line ([0-9]+)"
|
||||
pattern = r"line ([0-9]+) column ([0-9]+)"
|
||||
|
||||
matches = re.findall(pattern, error, re.DOTALL)
|
||||
if len(matches) > 0:
|
||||
line_no = int(matches[0]) - 1
|
||||
line_no = int(matches[0][0]) - 1
|
||||
col_no = int(matches[0][1]) - 1
|
||||
|
||||
# due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'`
|
||||
output = output.replace('"""', '"').replace("'''", '"')
|
||||
arr = output.split("\n")
|
||||
rline = arr[line_no] # raw line
|
||||
line = arr[line_no].strip()
|
||||
# different general problems
|
||||
if line.endswith("],"):
|
||||
|
|
@ -187,9 +198,12 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
new_line = line.replace("}", "")
|
||||
elif line.endswith("},") and output.endswith("},"):
|
||||
new_line = line[:-1]
|
||||
elif '",' not in line and "," not in line:
|
||||
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
|
||||
# problem, `"""` or `'''` without `,`
|
||||
new_line = f",{line}"
|
||||
elif '",' not in line and "," not in line and '"' not in line:
|
||||
new_line = f'{line}",'
|
||||
elif "," not in line:
|
||||
elif not line.endswith(","):
|
||||
# problem, miss char `,` at the end.
|
||||
new_line = f"{line},"
|
||||
elif "," in line and len(line) == 1:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ TOKEN_COSTS = {
|
|||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
|
||||
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ aioredis~=2.0.1 # Used by metagpt/utils/redis.py
|
|||
websocket-client==1.6.2
|
||||
aiofiles==23.2.1
|
||||
gitpython==3.1.40
|
||||
zhipuai==1.0.7
|
||||
zhipuai==2.0.1
|
||||
socksio~=1.0.0
|
||||
gitignore-parser==0.1.9
|
||||
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
# @Desc : the unittest of ZhiPuAILLM
|
||||
|
||||
import pytest
|
||||
from zhipuai.utils.sse_client import Event
|
||||
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
|
||||
|
|
@ -13,35 +12,16 @@ messages = [{"role": "user", "content": prompt_msg}]
|
|||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"choices": [{"role": "assistant", "content": resp_content}],
|
||||
"usage": {"prompt_tokens": 20, "completion_tokens": 20},
|
||||
},
|
||||
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
|
||||
|
||||
def mock_zhipuai_invoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_ainvoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_asse_invoke(**kwargs):
|
||||
async def mock_zhipuai_acreate_stream(self, **kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [
|
||||
Event(id="xxx", event="add", data=resp_content, retry=0),
|
||||
Event(
|
||||
id="xxx",
|
||||
event="finish",
|
||||
data="",
|
||||
meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}',
|
||||
),
|
||||
]
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
|
|
@ -50,23 +30,26 @@ async def mock_zhipuai_asse_invoke(**kwargs):
|
|||
async for chunk in Iterator():
|
||||
yield chunk
|
||||
|
||||
async def async_events(self):
|
||||
async def stream(self):
|
||||
async for chunk in self._aread():
|
||||
yield chunk
|
||||
|
||||
return MockResponse()
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate(self, **kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
|
||||
|
||||
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
assert resp["choices"][0]["message"]["content"] == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -11,16 +11,16 @@ from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
|||
async def test_async_sse_client():
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"data: test_value"
|
||||
yield b'data: {"test_key": "test_value"}'
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
async for chunk in async_sse_client.stream():
|
||||
assert "test_value" in chunk.values()
|
||||
|
||||
class InvalidIterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"invalid: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert not event
|
||||
async for chunk in async_sse_client.stream():
|
||||
assert not chunk
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
|||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = b'{"result": "test response"}'
|
||||
default_resp = b'{"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": "test response", "role": "assistant"}}]}'
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
|
|
@ -32,13 +32,13 @@ async def test_zhipu_model_api(mocker):
|
|||
|
||||
url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"})
|
||||
assert url_prefix == "https://open.bigmodel.cn/api"
|
||||
assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
|
||||
assert url_suffix == "/paas/v4/chat/completions"
|
||||
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
|
||||
result = await ZhiPuModelAPI.arequest(
|
||||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
result = await ZhiPuModelAPI(api_key=api_key).arequest(
|
||||
stream=False, method="get", headers={}, kwargs={"model": "glm-3-turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
||||
result = await ZhiPuModelAPI.ainvoke()
|
||||
assert result["result"] == "test response"
|
||||
result = await ZhiPuModelAPI(api_key=api_key).acreate()
|
||||
assert result["choices"][0]["message"]["content"] == "test response"
|
||||
|
|
|
|||
|
|
@ -128,6 +128,19 @@ def test_repair_json_format():
|
|||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
raw_output = """
|
||||
{
|
||||
"Language": "en_us", # define language
|
||||
"Programming Language": "Python"
|
||||
}
|
||||
"""
|
||||
target_output = """{
|
||||
"Language": "en_us",
|
||||
"Programming Language": "Python"
|
||||
}"""
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_repair_invalid_json():
|
||||
from metagpt.utils.repair_llm_raw_output import repair_invalid_json
|
||||
|
|
@ -204,6 +217,25 @@ def test_retry_parse_json_text():
|
|||
output = retry_parse_json_text(output=invalid_json_text)
|
||||
assert output == target_json
|
||||
|
||||
invalid_json_text = '''{
|
||||
"Data structures and interfaces": """
|
||||
class UI:
|
||||
- game_engine: GameEngine
|
||||
+ __init__(engine: GameEngine) -> None
|
||||
+ display_board() -> None
|
||||
+ display_score() -> None
|
||||
+ prompt_move() -> str
|
||||
+ reset_game() -> None
|
||||
"""
|
||||
"Anything UNCLEAR": "no"
|
||||
}'''
|
||||
target_json = {
|
||||
"Data structures and interfaces": "\n class UI:\n - game_engine: GameEngine\n + __init__(engine: GameEngine) -> None\n + display_board() -> None\n + display_score() -> None\n + prompt_move() -> str\n + reset_game() -> None\n ",
|
||||
"Anything UNCLEAR": "no",
|
||||
}
|
||||
output = retry_parse_json_text(output=invalid_json_text)
|
||||
assert output == target_json
|
||||
|
||||
|
||||
def test_extract_content_from_output():
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue