diff --git a/README.md b/README.md index 9c88c92a1..61d03f692 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,19 @@ # MetaGPT: The Multi-Agent Framework

Software Company Multi-Role Schematic (Gradually Implementing)

## News -🚀 Jan 03: Here comes [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0)! In this version, we added serialization and deserialization of important objects and enabled breakpoint recovery. We upgraded OpenAI package to v1.6.0 and supported Gemini, ZhipuAI, Ollama, OpenLLM, etc. Moreover, we provided extremely simple examples where you need only 7 lines to implement a general election [debate](https://github.com/geekan/MetaGPT/blob/main/examples/debate_simple.py). Check out more details [here](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0)! +🚀 Jan. 16, 2024: [MetaGPT paper](https://arxiv.org/abs/2308.00352) accepted for oral presentation **(top 1.2%)** at ICLR 2024, **ranking #1** in the LLM-based Agent category. +🚀 Jan. 03, 2024: [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0) released, new features include serialization, upgraded OpenAI package and supported multiple LLM etc. -🚀 Dec 15: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) is released! We introduced **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or existing codebase. We also launched a whole collection of important features, including **multilingual support** (experimental), multiple **programming languages support** (experimental), **incremental development** (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism! +🚀 Dec. 15, 2023: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) released, introducing **incremental development**, **multilingual**, **multiple programming languages**, etc. + +🔥 Nov. 08, 2023: MetaGPT is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html). + +🔥 Sep. 01, 2023: MetaGPT tops GitHub Trending Monthly for the **17th time** in August 2023. + +🌟 Jun. 30, 2023: MetaGPT is now open source. + +🌟 Apr. 24, 2023: First line of MetaGPT code committed. ## Install diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index 9bc62f849..4bb530bf2 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -76,9 +76,8 @@ ### Tasks 2. ~~Support Azure asynchronous API~~ 3. Support streaming version of all APIs 4. ~~Make gpt-3.5-turbo available (HARD)~~ - 5. Support 10. Other 1. ~~Clean up existing unused code~~ - 2. Unify all code styles and establish contribution standards + 2. ~~Unify all code styles and establish contribution standards~~ 3. ~~Multi-language support~~ 4. ~~Multi-programming-language support~~ diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 76be1cc90..219a303c8 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -23,6 +23,10 @@ async def main(): # streaming mode, much slower await llm.acompletion_text(hello_msg, stream=True) + # check completion if exist to test llm complete functions + if hasattr(llm, "completion"): + logger.info(llm.completion(hello_msg)) + if __name__ == "__main__": asyncio.run(main()) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index fb086d5c2..7988dd4e8 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -84,7 +84,7 @@ class WriteTasks(Action): async def _update_requirements(self, doc): m = json.loads(doc.content) - packages = set(m.get("Required Python third-party packages", set())) + packages = set(m.get("Required Python packages", set())) requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME) if not requirement_doc: requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="") diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 8b85608ee..ec56afc61 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -160,9 +160,11 @@ class WriteCodeReview(Action): cr_prompt = EXAMPLE_AND_INSTRUCTION.format( format_example=format_example, ) + len1 = len(iterative_code) if iterative_code else 0 + len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0 logger.info( - f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, " - f"{len(self.i_context.code_doc.content)=}" + f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, " + f"len(self.i_context.code_doc.content)={len2}" ) result, rewrited_code = await self.write_code_review_and_rewrite( context_prompt, cr_prompt, self.i_context.code_doc.filename diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 93931f14e..5fe9d1c3a 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -106,6 +106,10 @@ class BaseLLM(ABC): """Required to provide the first text of choice""" return rsp.get("choices")[0]["message"]["content"] + def get_choice_delta_text(self, rsp: dict) -> str: + """Required to provide the first text of stream choice""" + return rsp.get("choices")[0]["delta"]["content"] + def get_choice_function(self, rsp: dict) -> dict: """Required to provide the first function of choice :param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls", diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py index cf31fd629..500cd1426 100644 --- a/metagpt/provider/general_api_requestor.py +++ b/metagpt/provider/general_api_requestor.py @@ -79,10 +79,8 @@ class GeneralAPIRequestor(APIRequestor): async def _interpret_async_response( self, result: aiohttp.ClientResponse, stream: bool ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]: - if stream and ( - "text/event-stream" in result.headers.get("Content-Type", "") - or "application/x-ndjson" in result.headers.get("Content-Type", "") - ): + content_type = result.headers.get("Content-Type", "") + if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type): # the `Content-Type` of ollama stream resp is "application/x-ndjson" return ( self._interpret_response_line(line, result.status, result.headers, stream=True) diff --git a/metagpt/provider/zhipuai/async_sse_client.py b/metagpt/provider/zhipuai/async_sse_client.py index d7168202a..054865652 100644 --- a/metagpt/provider/zhipuai/async_sse_client.py +++ b/metagpt/provider/zhipuai/async_sse_client.py @@ -1,75 +1,31 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : async_sse_client to make keep the use of Event to access response -# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py` +# refs to `zhipuai/core/_sse_client.py` -from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient +import json +from typing import Any, Iterator -class AsyncSSEClient(SSEClient): - async def _aread(self): - data = b"" +class AsyncSSEClient(object): + def __init__(self, event_source: Iterator[Any]): + self._event_source = event_source + + async def stream(self) -> dict: + if isinstance(self._event_source, bytes): + raise RuntimeError( + f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`" + ) async for chunk in self._event_source: - for line in chunk.splitlines(True): - data += line - if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): - yield data - data = b"" - if data: - yield data + line = chunk.decode("utf-8") + if line.startswith(":") or not line: + return - async def async_events(self): - async for chunk in self._aread(): - event = Event() - # Split before decoding so splitlines() only uses \r and \n - for line in chunk.splitlines(): - # Decode the line. - line = line.decode(self._char_enc) - - # Lines starting with a separator are comments and are to be - # ignored. - if not line.strip() or line.startswith(_FIELD_SEPARATOR): - continue - - data = line.split(_FIELD_SEPARATOR, 1) - field = data[0] - - # Ignore unknown fields. - if field not in event.__dict__: - self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field) - continue - - if len(data) > 1: - # From the spec: - # "If value starts with a single U+0020 SPACE character, - # remove it from value." - if data[1].startswith(" "): - value = data[1][1:] - else: - value = data[1] - else: - # If no value is present after the separator, - # assume an empty value. - value = "" - - # The data field may come over multiple lines and their values - # are concatenated with each other. - if field == "data": - event.__dict__[field] += value + "\n" - else: - event.__dict__[field] = value - - # Events with no data are not dispatched. - if not event.data: - continue - - # If the data field ends with a newline, remove it. - if event.data.endswith("\n"): - event.data = event.data[0:-1] - - # Empty event names default to 'message' - event.event = event.event or "message" - - # Dispatch the event - self._logger.debug("Dispatching %s...", event) - yield event + field, _p, value = line.partition(":") + if value.startswith(" "): + value = value[1:] + if field == "data": + if value.startswith("[DONE]"): + break + data = json.loads(value) + yield data diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index 16d4102d4..a7d49623a 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -4,46 +4,27 @@ import json -import zhipuai -from zhipuai.model_api.api import InvokeType, ModelAPI -from zhipuai.utils.http_client import headers as zhipuai_default_headers +from zhipuai import ZhipuAI +from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT from metagpt.provider.general_api_requestor import GeneralAPIRequestor from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient -class ZhiPuModelAPI(ModelAPI): - @classmethod - def get_header(cls) -> dict: - token = cls._generate_token() - zhipuai_default_headers.update({"Authorization": token}) - return zhipuai_default_headers - - @classmethod - def get_sse_header(cls) -> dict: - token = cls._generate_token() - headers = {"Authorization": token} - return headers - - @classmethod - def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs): +class ZhiPuModelAPI(ZhipuAI): + def split_zhipu_api_url(self): # use this method to prevent zhipu api upgrading to different version. # and follow the GeneralAPIRequestor implemented based on openai sdk - zhipu_api_url = cls._build_api_url(kwargs, invoke_type) - """ - example: - zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method} - """ + zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" arr = zhipu_api_url.split("/api/") - # ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke") + # ("https://open.bigmodel.cn/api" , "/paas/v4/chat/completions") return f"{arr[0]}/api", f"/{arr[1]}" - @classmethod - async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs): + async def arequest(self, stream: bool, method: str, headers: dict, kwargs): # TODO to make the async request to be more generic for models in http mode. assert method in ["post", "get"] - base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs) + base_url, url = self.split_zhipu_api_url() requester = GeneralAPIRequestor(base_url=base_url) result, _, api_key = await requester.arequest( method=method, @@ -51,25 +32,23 @@ class ZhiPuModelAPI(ModelAPI): headers=headers, stream=stream, params=kwargs, - request_timeout=zhipuai.api_timeout_seconds, + request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read, ) return result - @classmethod - async def ainvoke(cls, **kwargs) -> dict: + async def acreate(self, **kwargs) -> dict: """async invoke different from raw method `async_invoke` which get the final result by task_id""" - headers = cls.get_header() - resp = await cls.arequest( - invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs - ) + headers = self._default_headers + resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs) resp = resp.decode("utf-8") resp = json.loads(resp) + if "error" in resp: + raise RuntimeError( + f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`" + ) return resp - @classmethod - async def asse_invoke(cls, **kwargs) -> AsyncSSEClient: + async def acreate_stream(self, **kwargs) -> AsyncSSEClient: """async sse_invoke""" - headers = cls.get_sse_header() - return AsyncSSEClient( - await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs) - ) + headers = self._default_headers + return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs)) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 67ec6fb8d..9108a1fba 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- # @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk -import json from enum import Enum import openai @@ -35,7 +34,7 @@ class ZhiPuEvent(Enum): class ZhiPuAILLM(BaseLLM): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` - From now, there is only one model named `chatglm_turbo` + From now, support glm-3-turbo、glm-4, and also system_prompt. """ def __init__(self, config: LLMConfig): @@ -54,8 +53,8 @@ class ZhiPuAILLM(BaseLLM): # FIXME: openai v1.x sdk has no proxy support openai.proxy = config.proxy - def _const_kwargs(self, messages: list[dict]) -> dict: - kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3} + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3} return kwargs def _update_costs(self, usage: dict): @@ -68,21 +67,15 @@ class ZhiPuAILLM(BaseLLM): except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") - def get_choice_text(self, resp: dict) -> str: - """get the first text of choice from llm response""" - assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] - assert assist_msg["role"] == "assistant" - return assist_msg.get("content") - def completion(self, messages: list[dict], timeout=3) -> dict: - resp = self.llm.invoke(**self._const_kwargs(messages)) - usage = resp.get("data").get("usage") + resp = self.llm.chat.completions.create(**self._const_kwargs(messages)) + usage = resp.usage.model_dump() self._update_costs(usage) - return resp + return resp.model_dump() async def _achat_completion(self, messages: list[dict], timeout=3) -> dict: - resp = await self.llm.ainvoke(**self._const_kwargs(messages)) - usage = resp.get("data").get("usage") + resp = await self.llm.acreate(**self._const_kwargs(messages)) + usage = resp.get("usage", {}) self._update_costs(usage) return resp @@ -90,35 +83,18 @@ class ZhiPuAILLM(BaseLLM): return await self._achat_completion(messages, timeout=timeout) async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: - response = await self.llm.asse_invoke(**self._const_kwargs(messages)) + response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True)) collected_content = [] usage = {} - async for event in response.async_events(): - if event.event == ZhiPuEvent.ADD.value: - content = event.data + async for chunk in response.stream(): + finish_reason = chunk.get("choices")[0].get("finish_reason") + if finish_reason == "stop": + usage = chunk.get("usage", {}) + else: + content = self.get_choice_delta_text(chunk) collected_content.append(content) log_llm_stream(content) - elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value: - content = event.data - logger.error(f"event error: {content}", end="") - elif event.event == ZhiPuEvent.FINISH.value: - """ - event.meta - { - "task_status":"SUCCESS", - "usage":{ - "completion_tokens":351, - "prompt_tokens":595, - "total_tokens":946 - }, - "task_id":"xx", - "request_id":"xxx" - } - """ - meta = json.loads(event.meta) - usage = meta.get("usage") - else: - print(f"zhipuapi else event: {event.data}", end="") + log_llm_stream("\n") self._update_costs(usage) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 0666a63db..42faa0cd4 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -29,7 +29,8 @@ class QaEngineer(Role): profile: str = "QaEngineer" goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" constraints: str = ( - "The test code you write should conform to code standard like PEP8, be modular, " "easy to read and maintain" + "The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain." + "Use same language as user requirement" ) test_round_allowed: int = 5 test_round: int = 0 @@ -54,6 +55,8 @@ class QaEngineer(Role): if not filename or "test" in filename: continue code_doc = await src_file_repo.get(filename) + if not code_doc: + continue test_doc = await self.project_repo.tests.get("test_" + code_doc.filename) if not test_doc: test_doc = Document( diff --git a/metagpt/team.py b/metagpt/team.py index 96a27d482..aec72970b 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -83,7 +83,7 @@ class Team(BaseModel): logger.info(f"Investment: ${investment}.") def _check_balance(self): - if self.cost_manager.total_cost > self.cost_manager.max_budget: + if self.cost_manager.total_cost >= self.cost_manager.max_budget: raise NoMoneyException(self.cost_manager.total_cost, f"Insufficient funds: {self.cost_manager.max_budget}") def run_project(self, idea, send_to: str = ""): diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py index 846e811cc..94d6fe76d 100644 --- a/metagpt/utils/file_repository.py +++ b/metagpt/utils/file_repository.py @@ -54,6 +54,7 @@ class FileRepository: """ pathname = self.workdir / filename pathname.parent.mkdir(parents=True, exist_ok=True) + content = content if content else "" # avoid `argument must be str, not None` to make it continue async with aiofiles.open(str(pathname), mode="w") as writer: await writer.write(content) logger.info(f"save to: {str(pathname)}") diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index ec2da53f8..82b2dd5b1 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -120,6 +120,15 @@ def repair_json_format(output: str) -> str: elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" + # remove `#` in output json str, usually appeared in `glm-4` + arr = output.split("\n") + new_arr = [] + for line in arr: + idx = line.find("#") + if idx >= 0: + line = line[:idx] + new_arr.append(line) + output = "\n".join(new_arr) return output @@ -168,15 +177,17 @@ def repair_invalid_json(output: str, error: str) -> str: example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765) example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266) """ - pattern = r"line ([0-9]+)" + pattern = r"line ([0-9]+) column ([0-9]+)" matches = re.findall(pattern, error, re.DOTALL) if len(matches) > 0: - line_no = int(matches[0]) - 1 + line_no = int(matches[0][0]) - 1 + col_no = int(matches[0][1]) - 1 # due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'` output = output.replace('"""', '"').replace("'''", '"') arr = output.split("\n") + rline = arr[line_no] # raw line line = arr[line_no].strip() # different general problems if line.endswith("],"): @@ -187,9 +198,12 @@ def repair_invalid_json(output: str, error: str) -> str: new_line = line.replace("}", "") elif line.endswith("},") and output.endswith("},"): new_line = line[:-1] - elif '",' not in line and "," not in line: + elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: + # problem, `"""` or `'''` without `,` + new_line = f",{line}" + elif '",' not in line and "," not in line and '"' not in line: new_line = f'{line}",' - elif "," not in line: + elif not line.endswith(","): # problem, miss char `,` at the end. new_line = f"{line}," elif "," in line and len(line) == 1: diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index a1b74a074..885eb37d7 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -27,7 +27,8 @@ TOKEN_COSTS = { "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, - "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens + "glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, } diff --git a/requirements.txt b/requirements.txt index 0a54236f0..93ad653dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,7 +50,7 @@ aioredis~=2.0.1 # Used by metagpt/utils/redis.py websocket-client==1.6.2 aiofiles==23.2.1 gitpython==3.1.40 -zhipuai==1.0.7 +zhipuai==2.0.1 socksio~=1.0.0 gitignore-parser==0.1.9 # connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index c4a40c23d..ad2ececa2 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -3,7 +3,6 @@ # @Desc : the unittest of ZhiPuAILLM import pytest -from zhipuai.utils.sse_client import Event from metagpt.provider.zhipuai_api import ZhiPuAILLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu @@ -13,35 +12,16 @@ messages = [{"role": "user", "content": prompt_msg}] resp_content = "I'm chatglm-turbo" default_resp = { - "code": 200, - "data": { - "choices": [{"role": "assistant", "content": resp_content}], - "usage": {"prompt_tokens": 20, "completion_tokens": 20}, - }, + "choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}], + "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}, } -def mock_zhipuai_invoke(**kwargs) -> dict: - return default_resp - - -async def mock_zhipuai_ainvoke(**kwargs) -> dict: - return default_resp - - -async def mock_zhipuai_asse_invoke(**kwargs): +async def mock_zhipuai_acreate_stream(self, **kwargs): class MockResponse(object): async def _aread(self): class Iterator(object): - events = [ - Event(id="xxx", event="add", data=resp_content, retry=0), - Event( - id="xxx", - event="finish", - data="", - meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}', - ), - ] + events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}] async def __aiter__(self): for event in self.events: @@ -50,23 +30,26 @@ async def mock_zhipuai_asse_invoke(**kwargs): async for chunk in Iterator(): yield chunk - async def async_events(self): + async def stream(self): async for chunk in self._aread(): yield chunk return MockResponse() +async def mock_zhipuai_acreate(self, **kwargs) -> dict: + return default_resp + + @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke) - mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke) - mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream) zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu) resp = await zhipu_gpt.acompletion(messages) - assert resp["data"]["choices"][0]["content"] == resp_content + assert resp["choices"][0]["message"]["content"] == resp_content resp = await zhipu_gpt.aask(prompt_msg, stream=False) assert resp == resp_content diff --git a/tests/metagpt/provider/zhipuai/test_async_sse_client.py b/tests/metagpt/provider/zhipuai/test_async_sse_client.py index 2649f595b..31b2d3d64 100644 --- a/tests/metagpt/provider/zhipuai/test_async_sse_client.py +++ b/tests/metagpt/provider/zhipuai/test_async_sse_client.py @@ -11,16 +11,16 @@ from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient async def test_async_sse_client(): class Iterator(object): async def __aiter__(self): - yield b"data: test_value" + yield b'data: {"test_key": "test_value"}' async_sse_client = AsyncSSEClient(event_source=Iterator()) - async for event in async_sse_client.async_events(): - assert event.data, "test_value" + async for chunk in async_sse_client.stream(): + assert "test_value" in chunk.values() class InvalidIterator(object): async def __aiter__(self): yield b"invalid: test_value" async_sse_client = AsyncSSEClient(event_source=InvalidIterator()) - async for event in async_sse_client.async_events(): - assert not event + async for chunk in async_sse_client.stream(): + assert not chunk diff --git a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py index daae65ab7..abaafb402 100644 --- a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py +++ b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py @@ -14,7 +14,7 @@ from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI api_key = "xxx.xxx" zhipuai.api_key = api_key -default_resp = b'{"result": "test response"}' +default_resp = b'{"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": "test response", "role": "assistant"}}]}' async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]: @@ -32,13 +32,13 @@ async def test_zhipu_model_api(mocker): url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"}) assert url_prefix == "https://open.bigmodel.cn/api" - assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke" + assert url_suffix == "/paas/v4/chat/completions" mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest) - result = await ZhiPuModelAPI.arequest( - InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"} + result = await ZhiPuModelAPI(api_key=api_key).arequest( + stream=False, method="get", headers={}, kwargs={"model": "glm-3-turbo"} ) assert result == default_resp - result = await ZhiPuModelAPI.ainvoke() - assert result["result"] == "test response" + result = await ZhiPuModelAPI(api_key=api_key).acreate() + assert result["choices"][0]["message"]["content"] == "test response" diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index bd6169d71..3ccca3e06 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -128,6 +128,19 @@ def test_repair_json_format(): output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) assert output == target_output + raw_output = """ +{ + "Language": "en_us", # define language + "Programming Language": "Python" +} +""" + target_output = """{ + "Language": "en_us", + "Programming Language": "Python" +}""" + output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) + assert output == target_output + def test_repair_invalid_json(): from metagpt.utils.repair_llm_raw_output import repair_invalid_json @@ -204,6 +217,25 @@ def test_retry_parse_json_text(): output = retry_parse_json_text(output=invalid_json_text) assert output == target_json + invalid_json_text = '''{ + "Data structures and interfaces": """ + class UI: + - game_engine: GameEngine + + __init__(engine: GameEngine) -> None + + display_board() -> None + + display_score() -> None + + prompt_move() -> str + + reset_game() -> None + """ + "Anything UNCLEAR": "no" +}''' + target_json = { + "Data structures and interfaces": "\n class UI:\n - game_engine: GameEngine\n + __init__(engine: GameEngine) -> None\n + display_board() -> None\n + display_score() -> None\n + prompt_move() -> str\n + reset_game() -> None\n ", + "Anything UNCLEAR": "no", + } + output = retry_parse_json_text(output=invalid_json_text) + assert output == target_json + def test_extract_content_from_output(): """