diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..ff6f19aab --- /dev/null +++ b/.coveragerc @@ -0,0 +1,7 @@ +[run] +source = + ./metagpt/ +omit = + */metagpt/environment/android/* + */metagpt/ext/android_assistant/* + */metagpt/ext/werewolf/* \ No newline at end of file diff --git a/.gitattributes b/.gitattributes index 865da2ca2..e6436790e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -14,6 +14,7 @@ *.ico binary *.jpeg binary *.mp3 binary +*.mp4 binary *.zip binary *.bin binary diff --git a/.github/workflows/fulltest.yaml b/.github/workflows/fulltest.yaml index 70c800481..2ab6444fa 100644 --- a/.github/workflows/fulltest.yaml +++ b/.github/workflows/fulltest.yaml @@ -30,7 +30,10 @@ jobs: cache: 'pip' - name: Install dependencies run: | - sh tests/scripts/run_install_deps.sh + python -m pip install --upgrade pip + pip install -e .[test] + npm install -g @mermaid-js/mermaid-cli + playwright install --with-deps - name: Run reverse proxy script for ssh service if: contains(github.ref, '-debugger') continue-on-error: true diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index dc5ae605b..25f82b1e6 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -27,20 +27,57 @@ jobs: cache: 'pip' - name: Install dependencies run: | - sh tests/scripts/run_install_deps.sh + python -m pip install --upgrade pip + pip install -e .[test] + npm install -g @mermaid-js/mermaid-cli + playwright install --with-deps - name: Test with pytest run: | export ALLOW_OPENAI_API_CALL=0 mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml - pytest tests/ --ignore=tests/metagpt/environment/android_env --ignore=tests/metagpt/ext/android_assistant --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt + pytest --continue-on-collection-errors tests/ \ + --ignore=tests/metagpt/environment/android_env \ + --ignore=tests/metagpt/ext/android_assistant \ + --ignore=tests/metagpt/ext/stanford_town \ + --ignore=tests/metagpt/provider/test_bedrock_api.py \ + --ignore=tests/metagpt/rag/factories/test_embedding.py \ + --ignore=tests/metagpt/ext/werewolf/actions/test_experience_operation.py \ + --ignore=tests/metagpt/provider/test_openai.py \ + --ignore=tests/metagpt/planner/test_action_planner.py \ + --ignore=tests/metagpt/planner/test_basic_planner.py \ + --ignore=tests/metagpt/actions/test_project_management.py \ + --ignore=tests/metagpt/actions/test_write_code.py \ + --ignore=tests/metagpt/actions/test_write_code_review.py \ + --ignore=tests/metagpt/actions/test_write_prd.py \ + --ignore=tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py \ + --ignore=tests/metagpt/memory/test_brain_memory.py \ + --ignore=tests/metagpt/roles/test_assistant.py \ + --ignore=tests/metagpt/roles/test_engineer.py \ + --ignore=tests/metagpt/serialize_deserialize/test_write_code_review.py \ + --ignore=tests/metagpt/test_environment.py \ + --ignore=tests/metagpt/test_llm.py \ + --ignore=tests/metagpt/tools/test_metagpt_oas3_api_svc.py \ + --ignore=tests/metagpt/tools/test_moderation.py \ + --ignore=tests/metagpt/tools/test_search_engine.py \ + --ignore=tests/metagpt/tools/test_tool_convert.py \ + --ignore=tests/metagpt/tools/test_web_browser_engine_playwright.py \ + --ignore=tests/metagpt/utils/test_mermaid.py \ + --ignore=tests/metagpt/utils/test_redis.py \ + --ignore=tests/metagpt/utils/test_tree.py \ + --ignore=tests/metagpt/serialize_deserialize/test_sk_agent.py \ + --ignore=tests/metagpt/utils/test_text.py \ + --ignore=tests/metagpt/actions/di/test_write_analysis_code.py \ + --ignore=tests/metagpt/provider/test_ark.py \ + --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov \ + --durations=20 | tee unittest.txt - name: Show coverage report run: | coverage report -m - name: Show failed tests and overall summary run: | grep -E "FAILED tests|ERROR tests|[0-9]+ passed," unittest.txt - failed_count=$(grep -E "FAILED|ERROR" unittest.txt | wc -l) - if [[ "$failed_count" -gt 0 ]]; then + failed_count=$(grep -E "FAILED tests|ERROR tests" unittest.txt | wc -l | tr -d '[:space:]') + if [[ $failed_count -gt 0 ]]; then echo "$failed_count failed lines found! Task failed." exit 1 fi diff --git a/README.md b/README.md index 3410e08fc..6881fec25 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ ## Get Started ### Installation -> Ensure that Python 3.9+ is installed on your system. You can check this by using: `python --version`. +> Ensure that Python 3.9 or later, but less than 3.12, is installed on your system. You can check this by using: `python --version`. > You can use conda like this: `conda create -n metagpt python=3.9 && conda activate metagpt` ```bash diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 0fe11df4e..b82468eed 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -60,6 +60,10 @@ iflytek_api_secret: "YOUR_API_SECRET" metagpt_tti_url: "YOUR_MODEL_URL" +omniparse: + api_key: "YOUR_API_KEY" + base_url: "YOUR_BASE_URL" + models: # "YOUR_MODEL_NAME_1 or YOUR_API_TYPE_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo # api_type: "openai" # or azure / ollama / groq etc. @@ -76,4 +80,6 @@ models: # proxy: "YOUR_PROXY" # for LLM API requests # # timeout: 600 # Optional. If set to 0, default value is 300. # # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ -# pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's \ No newline at end of file +# pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's + +agentops_api_key: "YOUR_AGENTOPS_API_KEY" # get key from https://app.agentops.ai/settings/projects diff --git a/examples/data/omniparse/test01.docx b/examples/data/omniparse/test01.docx new file mode 100644 index 000000000..7b6251799 Binary files /dev/null and b/examples/data/omniparse/test01.docx differ diff --git a/examples/data/omniparse/test02.pdf b/examples/data/omniparse/test02.pdf new file mode 100644 index 000000000..8cd15877f Binary files /dev/null and b/examples/data/omniparse/test02.pdf differ diff --git a/examples/data/omniparse/test03.mp4 b/examples/data/omniparse/test03.mp4 new file mode 100644 index 000000000..54746f45d Binary files /dev/null and b/examples/data/omniparse/test03.mp4 differ diff --git a/examples/data/omniparse/test04.mp3 b/examples/data/omniparse/test04.mp3 new file mode 100644 index 000000000..2c8e149d8 Binary files /dev/null and b/examples/data/omniparse/test04.mp3 differ diff --git a/examples/rag/omniparse.py b/examples/rag/omniparse.py new file mode 100644 index 000000000..b9159dae5 --- /dev/null +++ b/examples/rag/omniparse.py @@ -0,0 +1,64 @@ +import asyncio + +from metagpt.config2 import config +from metagpt.const import EXAMPLE_DATA_PATH +from metagpt.logs import logger +from metagpt.rag.parsers import OmniParse +from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType +from metagpt.utils.omniparse_client import OmniParseClient + +TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx" +TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf" +TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4" +TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3" + + +async def omniparse_client_example(): + client = OmniParseClient(base_url=config.omniparse.base_url) + + # docx + with open(TEST_DOCX, "rb") as f: + file_input = f.read() + document_parse_ret = await client.parse_document(file_input=file_input, bytes_filename="test_01.docx") + logger.info(document_parse_ret) + + # pdf + pdf_parse_ret = await client.parse_pdf(file_input=TEST_PDF) + logger.info(pdf_parse_ret) + + # video + video_parse_ret = await client.parse_video(file_input=TEST_VIDEO) + logger.info(video_parse_ret) + + # audio + audio_parse_ret = await client.parse_audio(file_input=TEST_AUDIO) + logger.info(audio_parse_ret) + + +async def omniparse_example(): + parser = OmniParse( + api_key=config.omniparse.api_key, + base_url=config.omniparse.base_url, + parse_options=OmniParseOptions( + parse_type=OmniParseType.PDF, + result_type=ParseResultType.MD, + max_timeout=120, + num_workers=3, + ), + ) + ret = parser.load_data(file_path=TEST_PDF) + logger.info(ret) + + file_paths = [TEST_DOCX, TEST_PDF] + parser.parse_type = OmniParseType.DOCUMENT + ret = await parser.aload_data(file_path=file_paths) + logger.info(ret) + + +async def main(): + await omniparse_client_example() + await omniparse_example() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/rag_bm.py b/examples/rag/rag_bm.py similarity index 100% rename from examples/rag_bm.py rename to examples/rag/rag_bm.py diff --git a/examples/rag_pipeline.py b/examples/rag/rag_pipeline.py similarity index 100% rename from examples/rag_pipeline.py rename to examples/rag/rag_pipeline.py diff --git a/examples/rag_search.py b/examples/rag/rag_search.py similarity index 88% rename from examples/rag_search.py rename to examples/rag/rag_search.py index 258c5ba60..3b0e047f8 100644 --- a/examples/rag_search.py +++ b/examples/rag/rag_search.py @@ -2,7 +2,7 @@ import asyncio -from examples.rag_pipeline import DOC_PATH, QUESTION +from examples.rag.rag_pipeline import DOC_PATH, QUESTION from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.roles import Sales diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 07638ce42..ad3f0a1d1 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -237,12 +237,19 @@ class ActionNode: """基于pydantic v2的模型动态生成,用来检验结果类型正确性""" def check_fields(cls, values): - required_fields = set(mapping.keys()) + all_fields = set(mapping.keys()) + required_fields = set() + for k, v in mapping.items(): + type_v, field_info = v + if ActionNode.is_optional_type(type_v): + continue + required_fields.add(k) + missing_fields = required_fields - set(values.keys()) if missing_fields: raise ValueError(f"Missing fields: {missing_fields}") - unrecognized_fields = set(values.keys()) - required_fields + unrecognized_fields = set(values.keys()) - all_fields if unrecognized_fields: logger.warning(f"Unrecognized fields: {unrecognized_fields}") return values @@ -717,3 +724,12 @@ class ActionNode: root_node.add_child(child_node) return root_node + + @staticmethod + def is_optional_type(tp) -> bool: + """Return True if `tp` is `typing.Optional[...]`""" + if typing.get_origin(tp) is Union: + args = typing.get_args(tp) + non_none_types = [arg for arg in args if arg is not type(None)] + return len(non_none_types) == 1 and len(args) == 2 + return False diff --git a/metagpt/actions/design_api_an.py b/metagpt/actions/design_api_an.py index 5977cbd95..ca7aea95a 100644 --- a/metagpt/actions/design_api_an.py +++ b/metagpt/actions/design_api_an.py @@ -5,7 +5,7 @@ @Author : alexanderwu @File : design_api_an.py """ -from typing import List +from typing import List, Optional from metagpt.actions.action_node import ActionNode from metagpt.utils.mermaid import MMC1, MMC2 @@ -45,9 +45,10 @@ REFINED_FILE_LIST = ActionNode( example=["main.py", "game.py", "new_feature.py"], ) +# optional,because low success reproduction of class diagram in non py project. DATA_STRUCTURES_AND_INTERFACES = ActionNode( key="Data structures and interfaces", - expected_type=str, + expected_type=Optional[str], instruction="Use mermaid classDiagram code syntax, including classes, method(__init__ etc.) and functions with type" " annotations, CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. " "The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.", @@ -66,7 +67,7 @@ REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode( PROGRAM_CALL_FLOW = ActionNode( key="Program call flow", - expected_type=str, + expected_type=Optional[str], instruction="Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE " "accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.", example=MMC2, diff --git a/metagpt/actions/project_management_an.py b/metagpt/actions/project_management_an.py index db27434a1..f53062433 100644 --- a/metagpt/actions/project_management_an.py +++ b/metagpt/actions/project_management_an.py @@ -5,14 +5,14 @@ @Author : alexanderwu @File : project_management_an.py """ -from typing import List +from typing import List, Optional from metagpt.actions.action_node import ActionNode REQUIRED_PACKAGES = ActionNode( key="Required packages", - expected_type=List[str], - instruction="Provide required packages in requirements.txt format.", + expected_type=Optional[List[str]], + instruction="Provide required third-party packages in requirements.txt format.", example=["flask==1.1.2", "bcrypt==3.2.0"], ) diff --git a/metagpt/actions/write_code_an_draft.py b/metagpt/actions/write_code_an_draft.py index ed6c66cf6..20ed201a3 100644 --- a/metagpt/actions/write_code_an_draft.py +++ b/metagpt/actions/write_code_an_draft.py @@ -139,7 +139,7 @@ Language: Please use the same language as the user requirement, but the title an end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} ## Tasks -{"Required packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} +{"Required packages": ["无需第三方包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} ## Code Files ----- index.html diff --git a/metagpt/config2.py b/metagpt/config2.py index 58a99c920..27b228b33 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator from metagpt.configs.browser_config import BrowserConfig from metagpt.configs.embedding_config import EmbeddingConfig +from metagpt.configs.file_parser_config import OmniParseConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig @@ -51,6 +52,9 @@ class Config(CLIParams, YamlModel): # RAG Embedding embedding: EmbeddingConfig = EmbeddingConfig() + # omniparse + omniparse: OmniParseConfig = OmniParseConfig() + # Global Proxy. Will be used if llm.proxy is not set proxy: str = "" @@ -69,6 +73,7 @@ class Config(CLIParams, YamlModel): workspace: WorkspaceConfig = WorkspaceConfig() enable_longterm_memory: bool = False code_review_k_times: int = 2 + agentops_api_key: str = "" # Will be removed in the future metagpt_tti_url: str = "" diff --git a/metagpt/configs/file_parser_config.py b/metagpt/configs/file_parser_config.py new file mode 100644 index 000000000..39742c8a4 --- /dev/null +++ b/metagpt/configs/file_parser_config.py @@ -0,0 +1,6 @@ +from metagpt.utils.yaml_model import YamlModel + + +class OmniParseConfig(YamlModel): + api_key: str = "" + base_url: str = "" diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 67fb6afdb..7388063aa 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -33,7 +33,7 @@ class LLMType(Enum): YI = "yi" # lingyiwanwu OPENROUTER = "openrouter" BEDROCK = "bedrock" - ARK = "ark" + ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk def __missing__(self, key): return self.OPENAI @@ -90,6 +90,9 @@ class LLMConfig(YamlModel): # Cost Control calc_usage: bool = True + # For Messages Control + use_system_prompt: bool = True + @field_validator("api_key") @classmethod def check_llm_key(cls, v): diff --git a/metagpt/ext/stanford_town/roles/st_role.py b/metagpt/ext/stanford_town/roles/st_role.py index 79f58b07d..e8cb3fb04 100644 --- a/metagpt/ext/stanford_town/roles/st_role.py +++ b/metagpt/ext/stanford_town/roles/st_role.py @@ -266,7 +266,7 @@ class STRole(Role): # We will order our percept based on the distance, with the closest ones # getting priorities. percept_events_list = [] - # First, we put all events that are occuring in the nearby tiles into the + # First, we put all events that are occurring in the nearby tiles into the # percept_events_list for tile in nearby_tiles: tile_details = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile)) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 580361d33..b11b780c3 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -81,7 +81,7 @@ class Memory(BaseModel): return self.storage[-k:] def find_news(self, observed: list[Message], k=0) -> list[Message]: - """find news (previously unseen messages) from the the most recent k memories, from all memories when k=0""" + """find news (previously unseen messages) from the most recent k memories, from all memories when k=0""" already_observed = self.get(k) news: list[Message] = [] for i in observed: diff --git a/metagpt/provider/ark_api.py b/metagpt/provider/ark_api.py index c24bd1ee9..0c5704b91 100644 --- a/metagpt/provider/ark_api.py +++ b/metagpt/provider/ark_api.py @@ -1,12 +1,33 @@ -from openai import AsyncStream -from openai.types import CompletionUsage -from openai.types.chat import ChatCompletion, ChatCompletionChunk +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Provider for volcengine. +See Also: https://console.volcengine.com/ark/region:ark+cn-beijing/model + +config2.yaml example: +```yaml +llm: + base_url: "https://ark.cn-beijing.volces.com/api/v3" + api_type: "ark" + endpoint: "ep-2024080514****-d****" + api_key: "d47****b-****-****-****-d6e****0fd77" + pricing_plan: "doubao-lite" +``` +""" +from typing import Optional, Union + +from pydantic import BaseModel +from volcenginesdkarkruntime import AsyncArk +from volcenginesdkarkruntime._base_client import AsyncHttpxClientWrapper +from volcenginesdkarkruntime._streaming import AsyncStream +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk from metagpt.configs.llm_config import LLMType from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.logs import log_llm_stream from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM +from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS @register_provider(LLMType.ARK) @@ -16,11 +37,45 @@ class ArkLLM(OpenAILLM): 见:https://www.volcengine.com/docs/82379/1263482 """ + aclient: Optional[AsyncArk] = None + + def _init_client(self): + """SDK: https://github.com/openai/openai-python#async-usage""" + self.model = ( + self.config.endpoint or self.config.model + ) # endpoint name, See more: https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint + self.pricing_plan = self.config.pricing_plan or self.model + kwargs = self._make_client_kwargs() + self.aclient = AsyncArk(**kwargs) + + def _make_client_kwargs(self) -> dict: + kvs = { + "ak": self.config.access_key, + "sk": self.config.secret_key, + "api_key": self.config.api_key, + "base_url": self.config.base_url, + } + kwargs = {k: v for k, v in kvs.items() if v} + + # to use proxy, openai v1 needs http_client + if proxy_params := self._get_proxy_params(): + kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params) + + return kwargs + + def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True): + if next(iter(DOUBAO_TOKEN_COSTS)) not in self.cost_manager.token_costs: + self.cost_manager.token_costs.update(DOUBAO_TOKEN_COSTS) + if model in self.cost_manager.token_costs: + self.pricing_plan = model + if self.pricing_plan in self.cost_manager.token_costs: + super()._update_costs(usage, self.pricing_plan, local_calc_usage) + async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create( **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)), stream=True, - extra_body={"stream_options": {"include_usage": True}} # 只有增加这个参数才会在流式时最后返回usage + extra_body={"stream_options": {"include_usage": True}}, # 只有增加这个参数才会在流式时最后返回usage ) usage = None collected_messages = [] @@ -30,7 +85,7 @@ class ArkLLM(OpenAILLM): collected_messages.append(chunk_message) if chunk.usage: # 火山方舟的流式调用会在最后一个chunk中返回usage,最后一个chunk的choices为[] - usage = CompletionUsage(**chunk.usage) + usage = chunk.usage log_llm_stream("\n") full_reply_content = "".join(collected_messages) diff --git a/metagpt/provider/dashscope_api.py b/metagpt/provider/dashscope_api.py index 82224e893..837377edc 100644 --- a/metagpt/provider/dashscope_api.py +++ b/metagpt/provider/dashscope_api.py @@ -48,13 +48,17 @@ def build_api_arequest( request_timeout, form, resources, + base_address, + _, ) = _get_protocol_params(kwargs) task_id = kwargs.pop("task_id", None) if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]: - if not dashscope.base_http_api_url.endswith("/"): - http_url = dashscope.base_http_api_url + "/" + if base_address is None: + base_address = dashscope.base_http_api_url + if not base_address.endswith("/"): + http_url = base_address + "/" else: - http_url = dashscope.base_http_api_url + http_url = base_address if is_service: http_url = http_url + SERVICE_API_PATH + "/" diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py index 4fd2b1978..7f8618590 100644 --- a/metagpt/provider/llm_provider_registry.py +++ b/metagpt/provider/llm_provider_registry.py @@ -37,7 +37,11 @@ def register_provider(keys): def create_llm_instance(config: LLMConfig) -> BaseLLM: """get the default llm provider""" - return LLM_REGISTRY.get_provider(config.api_type)(config) + llm = LLM_REGISTRY.get_provider(config.api_type)(config) + if llm.use_system_prompt and not config.use_system_prompt: + # for models like o1-series, default openai provider.use_system_prompt is True, but it should be False for o1-* + llm.use_system_prompt = config.use_system_prompt + return llm # Registry instance diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 31907d9e8..ce3a06ec8 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -37,7 +37,6 @@ from metagpt.utils.token_counter import ( count_input_tokens, count_output_tokens, get_max_completion_tokens, - get_openrouter_tokens, ) @@ -92,6 +91,7 @@ class OpenAILLM(BaseLLM): ) usage = None collected_messages = [] + has_finished = False async for chunk in response: chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message finish_reason = ( @@ -99,8 +99,13 @@ class OpenAILLM(BaseLLM): ) log_llm_stream(chunk_message) collected_messages.append(chunk_message) + chunk_has_usage = hasattr(chunk, "usage") and chunk.usage + if has_finished: + # for oneapi, there has a usage chunk after finish_reason not none chunk + if chunk_has_usage: + usage = CompletionUsage(**chunk.usage) if finish_reason: - if hasattr(chunk, "usage") and chunk.usage is not None: + if chunk_has_usage: # Some services have usage as an attribute of the chunk, such as Fireworks if isinstance(chunk.usage, CompletionUsage): usage = chunk.usage @@ -109,9 +114,10 @@ class OpenAILLM(BaseLLM): elif hasattr(chunk.choices[0], "usage"): # The usage of some services is an attribute of chunk.choices[0], such as Moonshot usage = CompletionUsage(**chunk.choices[0].usage) - elif "openrouter.ai" in self.config.base_url: + elif "openrouter.ai" in self.config.base_url and chunk_has_usage: # due to it get token cost from api - usage = await get_openrouter_tokens(chunk) + usage = chunk.usage + has_finished = True log_llm_stream("\n") full_reply_content = "".join(collected_messages) @@ -132,6 +138,10 @@ class OpenAILLM(BaseLLM): "model": self.model, "timeout": self.get_timeout(timeout), } + if "o1-" in self.model: + # compatible to openai o1-series + kwargs["temperature"] = 1 + kwargs.pop("max_tokens") if extra_kwargs: kwargs.update(extra_kwargs) return kwargs diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py index 04334f305..3ada7908d 100644 --- a/metagpt/provider/qianfan_api.py +++ b/metagpt/provider/qianfan_api.py @@ -106,13 +106,13 @@ class QianFanLLM(BaseLLM): def get_choice_text(self, resp: JsonBody) -> str: return resp.get("result", "") - def completion(self, messages: list[dict]) -> JsonBody: - resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False)) + def completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody: + resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False), request_timeout=timeout) self._update_costs(resp.body.get("usage", {})) return resp.body async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody: - resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False)) + resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False), request_timeout=timeout) self._update_costs(resp.body.get("usage", {})) return resp.body @@ -120,7 +120,7 @@ class QianFanLLM(BaseLLM): return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: - resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True)) + resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True), request_timeout=timeout) collected_content = [] usage = {} async for chunk in resp: diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index c237dcf69..a03e0149c 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -14,6 +14,7 @@ from llama_index.core.llms import LLM from llama_index.core.node_parser import SentenceSplitter from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core.readers.base import BaseReader from llama_index.core.response_synthesizers import ( BaseSynthesizer, get_response_synthesizer, @@ -28,6 +29,7 @@ from llama_index.core.schema import ( TransformComponent, ) +from metagpt.config2 import config from metagpt.rag.factories import ( get_index, get_rag_embedding, @@ -36,6 +38,7 @@ from metagpt.rag.factories import ( get_retriever, ) from metagpt.rag.interface import NoEmbedding, RAGObject +from metagpt.rag.parsers import OmniParse from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever from metagpt.rag.schema import ( @@ -44,6 +47,9 @@ from metagpt.rag.schema import ( BaseRetrieverConfig, BM25RetrieverConfig, ObjectNode, + OmniParseOptions, + OmniParseType, + ParseResultType, ) from metagpt.utils.common import import_class @@ -100,7 +106,10 @@ class SimpleEngine(RetrieverQueryEngine): if not input_dir and not input_files: raise ValueError("Must provide either `input_dir` or `input_files`.") - documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data() + file_extractor = cls._get_file_extractor() + documents = SimpleDirectoryReader( + input_dir=input_dir, input_files=input_files, file_extractor=file_extractor + ).load_data() cls._fix_document_metadata(documents) transformations = transformations or cls._default_transformations() @@ -301,3 +310,23 @@ class SimpleEngine(RetrieverQueryEngine): @staticmethod def _default_transformations(): return [SentenceSplitter()] + + @staticmethod + def _get_file_extractor() -> dict[str:BaseReader]: + """ + Get the file extractor. + Currently, only PDF use OmniParse. Other document types use the built-in reader from llama_index. + + Returns: + dict[file_type: BaseReader] + """ + file_extractor: dict[str:BaseReader] = {} + if config.omniparse.base_url: + pdf_parser = OmniParse( + api_key=config.omniparse.api_key, + base_url=config.omniparse.base_url, + parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD), + ) + file_extractor[".pdf"] = pdf_parser + + return file_extractor diff --git a/metagpt/rag/parsers/__init__.py b/metagpt/rag/parsers/__init__.py new file mode 100644 index 000000000..03ac0de3a --- /dev/null +++ b/metagpt/rag/parsers/__init__.py @@ -0,0 +1,3 @@ +from metagpt.rag.parsers.omniparse import OmniParse + +__all__ = ["OmniParse"] diff --git a/metagpt/rag/parsers/omniparse.py b/metagpt/rag/parsers/omniparse.py new file mode 100644 index 000000000..ec08e38f1 --- /dev/null +++ b/metagpt/rag/parsers/omniparse.py @@ -0,0 +1,139 @@ +import asyncio +from fileinput import FileInput +from pathlib import Path +from typing import List, Optional, Union + +from llama_index.core import Document +from llama_index.core.async_utils import run_jobs +from llama_index.core.readers.base import BaseReader + +from metagpt.logs import logger +from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType +from metagpt.utils.async_helper import NestAsyncio +from metagpt.utils.omniparse_client import OmniParseClient + + +class OmniParse(BaseReader): + """OmniParse""" + + def __init__( + self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None + ): + """ + Args: + api_key: Default None, can be used for authentication later. + base_url: OmniParse Base URL for the API. + parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values. + """ + self.parse_options = parse_options or OmniParseOptions() + self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout) + + @property + def parse_type(self): + return self.parse_options.parse_type + + @property + def result_type(self): + return self.parse_options.result_type + + @parse_type.setter + def parse_type(self, parse_type: Union[str, OmniParseType]): + if isinstance(parse_type, str): + parse_type = OmniParseType(parse_type) + self.parse_options.parse_type = parse_type + + @result_type.setter + def result_type(self, result_type: Union[str, ParseResultType]): + if isinstance(result_type, str): + result_type = ParseResultType(result_type) + self.parse_options.result_type = result_type + + async def _aload_data( + self, + file_path: Union[str, bytes, Path], + extra_info: Optional[dict] = None, + ) -> List[Document]: + """ + Load data from the input file_path. + + Args: + file_path: File path or file byte data. + extra_info: Optional dictionary containing additional information. + + Returns: + List[Document] + """ + try: + if self.parse_type == OmniParseType.PDF: + # pdf parse + parsed_result = await self.omniparse_client.parse_pdf(file_path) + else: + # other parse use omniparse_client.parse_document + # For compatible byte data, additional filename is required + extra_info = extra_info or {} + filename = extra_info.get("filename") + parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename) + + # Get the specified structured data based on result_type + content = getattr(parsed_result, self.result_type) + docs = [ + Document( + text=content, + metadata=extra_info or {}, + ) + ] + except Exception as e: + logger.error(f"OMNI Parse Error: {e}") + docs = [] + + return docs + + async def aload_data( + self, + file_path: Union[List[FileInput], FileInput], + extra_info: Optional[dict] = None, + ) -> List[Document]: + """ + Load data from the input file_path. + + Args: + file_path: File path or file byte data. + extra_info: Optional dictionary containing additional information. + + Notes: + This method ultimately calls _aload_data for processing. + + Returns: + List[Document] + """ + docs = [] + if isinstance(file_path, (str, bytes, Path)): + # Processing single file + docs = await self._aload_data(file_path, extra_info) + elif isinstance(file_path, list): + # Concurrently process multiple files + parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path] + doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers) + docs = [doc for docs in doc_ret_list for doc in docs] + return docs + + def load_data( + self, + file_path: Union[List[FileInput], FileInput], + extra_info: Optional[dict] = None, + ) -> List[Document]: + """ + Load data from the input file_path. + + Args: + file_path: File path or file byte data. + extra_info: Optional dictionary containing additional information. + + Notes: + This method ultimately calls aload_data for processing. + + Returns: + List[Document] + """ + NestAsyncio.apply_once() # Ensure compatibility with nested async calls + return asyncio.run(self.aload_data(file_path, extra_info)) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 618880a22..a8a10f90e 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,7 +1,7 @@ """RAG schemas.""" - +from enum import Enum from pathlib import Path -from typing import Any, ClassVar, Literal, Optional, Union +from typing import Any, ClassVar, List, Literal, Optional, Union from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding @@ -214,3 +214,51 @@ class ObjectNode(TextNode): ) return metadata.model_dump() + + +class OmniParseType(str, Enum): + """OmniParseType""" + + PDF = "PDF" + DOCUMENT = "DOCUMENT" + + +class ParseResultType(str, Enum): + """The result type for the parser.""" + + TXT = "text" + MD = "markdown" + JSON = "json" + + +class OmniParseOptions(BaseModel): + """OmniParse Options config""" + + result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type") + parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type") + max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests") + num_workers: int = Field( + default=5, + gt=0, + lt=10, + description="Number of concurrent requests for multiple files", + ) + + +class OminParseImage(BaseModel): + image: str = Field(default="", description="image str bytes") + image_name: str = Field(default="", description="image name") + image_info: Optional[dict] = Field(default={}, description="image info") + + +class OmniParsedResult(BaseModel): + markdown: str = Field(default="", description="markdown text") + text: str = Field(default="", description="plain text") + images: Optional[List[OminParseImage]] = Field(default=[], description="images") + metadata: Optional[dict] = Field(default={}, description="metadata") + + @model_validator(mode="before") + def set_markdown(cls, values): + if not values.get("markdown"): + values["markdown"] = values.get("text") + return values diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index 166f8cfd0..69cce5e06 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -6,6 +6,7 @@ @File : architect.py """ + from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign from metagpt.roles.role import Role diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 9db9f7d9e..9a0511e87 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -7,6 +7,7 @@ @Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135. """ + from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.roles.role import Role, RoleReactMode diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 422d2889b..db8ad4558 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -6,6 +6,7 @@ @File : project_manager.py """ + from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign from metagpt.roles.role import Role diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index c73c10ef3..9b3c0afc7 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -15,6 +15,7 @@ of SummarizeCode. """ + from metagpt.actions import DebugError, RunCode, WriteTest from metagpt.actions.summarize_code import SummarizeCode from metagpt.const import MESSAGE_ROUTE_TO_NONE diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 071f060ea..6e2f61f32 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -170,7 +170,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): self._check_actions() self.llm.system_prompt = self._get_prefix() self.llm.cost_manager = self.context.cost_manager - self._watch(kwargs.pop("watch", [UserRequirement])) + if not self.rc.watch: + self._watch(kwargs.pop("watch", [UserRequirement])) if self.latest_observed_msg: self.recovered = True @@ -421,8 +422,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. news = [] - if self.recovered: - news = [self.latest_observed_msg] if self.latest_observed_msg else [] + if self.recovered and self.latest_observed_msg: + news = self.rc.memory.find_news(observed=[self.latest_observed_msg], k=10) if not news: news = self.rc.msg_buffer.pop_all() # Store the read messages in your own memory to prevent duplicate processing. diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 103ac0551..bb35aa016 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -4,6 +4,7 @@ import asyncio from pathlib import Path +import agentops import typer from metagpt.const import CONFIG_ROOT @@ -38,6 +39,9 @@ def generate_repo( ) from metagpt.team import Team + if config.agentops_api_key != "": + agentops.init(config.agentops_api_key, tags=["software_company"]) + config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) ctx = Context(config=config) @@ -68,6 +72,9 @@ def generate_repo( company.run_project(idea) asyncio.run(company.run(n_round=n_round)) + if config.agentops_api_key != "": + agentops.end_session("Success") + return ctx.repo diff --git a/metagpt/team.py b/metagpt/team.py index cf8346259..2288f9748 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -126,6 +126,9 @@ class Team(BaseModel): self.run_project(idea=idea, send_to=send_to) while n_round > 0: + if self.env.is_idle: + logger.debug("All roles are idle.") + break n_round -= 1 self._check_balance() await self.env.run() diff --git a/metagpt/utils/omniparse_client.py b/metagpt/utils/omniparse_client.py new file mode 100644 index 000000000..e7c5a3d44 --- /dev/null +++ b/metagpt/utils/omniparse_client.py @@ -0,0 +1,239 @@ +import mimetypes +import os +from pathlib import Path +from typing import Union + +import httpx + +from metagpt.rag.schema import OmniParsedResult +from metagpt.utils.common import aread_bin + + +class OmniParseClient: + """ + OmniParse Server Client + This client interacts with the OmniParse server to parse different types of media, documents. + + OmniParse API Documentation: https://docs.cognitivelab.in/api + + Attributes: + ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions. + ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions. + ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions. + """ + + ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"} + ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"} + ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"} + + def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120): + """ + Args: + api_key: Default None, can be used for authentication later. + base_url: Base URL for the API. + max_timeout: Maximum request timeout in seconds. + """ + self.api_key = api_key + self.base_url = base_url + self.max_timeout = max_timeout + + self.parse_media_endpoint = "/parse_media" + self.parse_website_endpoint = "/parse_website" + self.parse_document_endpoint = "/parse_document" + + async def _request_parse( + self, + endpoint: str, + method: str = "POST", + files: dict = None, + params: dict = None, + data: dict = None, + json: dict = None, + headers: dict = None, + **kwargs, + ) -> dict: + """ + Request OmniParse API to parse a document. + + Args: + endpoint (str): API endpoint. + method (str, optional): HTTP method to use. Default is "POST". + files (dict, optional): Files to include in the request. + params (dict, optional): Query string parameters. + data (dict, optional): Form data to include in the request body. + json (dict, optional): JSON data to include in the request body. + headers (dict, optional): HTTP headers to include in the request. + **kwargs: Additional keyword arguments for httpx.AsyncClient.request() + + Returns: + dict: JSON response data. + """ + url = f"{self.base_url}{endpoint}" + method = method.upper() + headers = headers or {} + _headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + headers.update(**_headers) + async with httpx.AsyncClient() as client: + response = await client.request( + url=url, + method=method, + files=files, + params=params, + json=json, + data=data, + headers=headers, + timeout=self.max_timeout, + **kwargs, + ) + response.raise_for_status() + return response.json() + + async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult: + """ + Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx"). + + Args: + file_input: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + OmniParsedResult: The result of the document parsing. + """ + self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) + resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info}) + data = OmniParsedResult(**resp) + return data + + async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult: + """ + Parse pdf document. + + Args: + file_input: File path or file byte data. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + OmniParsedResult: The result of the pdf parsing. + """ + self.verify_file_ext(file_input, {".pdf"}) + # parse_pdf supports parsing by accepting only the byte data of the file. + file_info = await self.get_file_info(file_input, only_bytes=True) + endpoint = f"{self.parse_document_endpoint}/pdf" + resp = await self._request_parse(endpoint=endpoint, files={"file": file_info}) + data = OmniParsedResult(**resp) + return data + + async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict: + """ + Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov"). + + Args: + file_input: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + dict: JSON response data. + """ + self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) + return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info}) + + async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict: + """ + Parse audio-type data (supports ".mp3", ".wav", ".aac"). + + Args: + file_input: File path or file byte data. + bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + dict: JSON response data. + """ + self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename) + file_info = await self.get_file_info(file_input, bytes_filename) + return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info}) + + @staticmethod + def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None): + """ + Verify the file extension. + + Args: + file_input: File path or file byte data. + allowed_file_extensions: Set of allowed file extensions. + bytes_filename: Filename to use for verification when `file_input` is byte data. + + Raises: + ValueError: If the file extension is not allowed. + + Returns: + """ + verify_file_path = None + if isinstance(file_input, (str, Path)): + verify_file_path = str(file_input) + elif isinstance(file_input, bytes) and bytes_filename: + verify_file_path = bytes_filename + + if not verify_file_path: + # Do not verify if only byte data is provided + return + + file_ext = os.path.splitext(verify_file_path)[1].lower() + if file_ext not in allowed_file_extensions: + raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}") + + @staticmethod + async def get_file_info( + file_input: Union[str, bytes, Path], + bytes_filename: str = None, + only_bytes: bool = False, + ) -> Union[bytes, tuple]: + """ + Get file information. + + Args: + file_input: File path or file byte data. + bytes_filename: Filename to use when uploading byte data, useful for determining MIME type. + only_bytes: Whether to return only byte data. Default is False, which returns a tuple. + + Raises: + ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type. + + Notes: + Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types, + the MIME type of the file must be specified when uploading. + + Returns: [bytes, tuple] + Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type). + """ + if isinstance(file_input, (str, Path)): + filename = os.path.basename(str(file_input)) + file_bytes = await aread_bin(file_input) + + if only_bytes: + return file_bytes + + mime_type = mimetypes.guess_type(file_input)[0] + return filename, file_bytes, mime_type + elif isinstance(file_input, bytes): + if only_bytes: + return file_input + if not bytes_filename: + raise ValueError("bytes_filename must be set when passing bytes") + + mime_type = mimetypes.guess_type(bytes_filename)[0] + return bytes_filename, file_input, mime_type + else: + raise ValueError("file_input must be a string (file path) or bytes.") diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index fda19cdba..c922f2cb4 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -40,11 +40,20 @@ TOKEN_COSTS = { "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4o": {"prompt": 0.005, "completion": 0.015}, + "gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006}, + "gpt-4o-mini-2024-07-18": {"prompt": 0.00015, "completion": 0.0006}, "gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015}, + "gpt-4o-2024-08-06": {"prompt": 0.0025, "completion": 0.01}, + "o1-preview": {"prompt": 0.015, "completion": 0.06}, + "o1-preview-2024-09-12": {"prompt": 0.015, "completion": 0.06}, + "o1-mini": {"prompt": 0.003, "completion": 0.012}, + "o1-mini-2024-09-12": {"prompt": 0.003, "completion": 0.012}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens "glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens - "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, + "gemini-1.5-flash": {"prompt": 0.000075, "completion": 0.0003}, + "gemini-1.5-pro": {"prompt": 0.0035, "completion": 0.0105}, + "gemini-1.0-pro": {"prompt": 0.0005, "completion": 0.0015}, "moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens "moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024}, "moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06}, @@ -68,15 +77,20 @@ TOKEN_COSTS = { "llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079}, "openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015}, "openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03}, + "openai/o1-preview": {"prompt": 0.015, "completion": 0.06}, + "openai/o1-mini": {"prompt": 0.003, "completion": 0.012}, + "anthropic/claude-3-opus": {"prompt": 0.015, "completion": 0.075}, + "anthropic/claude-3.5-sonnet": {"prompt": 0.003, "completion": 0.015}, + "google/gemini-pro-1.5": {"prompt": 0.0025, "completion": 0.0075}, # for openrouter, end "deepseek-chat": {"prompt": 0.00014, "completion": 0.00028}, "deepseek-coder": {"prompt": 0.00014, "completion": 0.00028}, # For ark model https://www.volcengine.com/docs/82379/1099320 - "doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084}, - "doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084}, - "doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013}, - "doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028}, - "doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028}, - "doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012}, + "doubao-lite-4k-240515": {"prompt": 0.000043, "completion": 0.000086}, + "doubao-lite-32k-240515": {"prompt": 0.000043, "completion": 0.000086}, + "doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00014}, + "doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00029}, + "doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00029}, + "doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0013}, "llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0}, "llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0}, } @@ -137,8 +151,17 @@ QIANFAN_ENDPOINT_TOKEN_COSTS = { """ DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing Different model has different detail page. Attention, some model are free for a limited time. +Some new model published by Alibaba will be prioritized to be released on the Model Studio instead of the Dashscope. +Token price on Model Studio shows on https://help.aliyun.com/zh/model-studio/getting-started/models#ced16cb6cdfsy """ DASHSCOPE_TOKEN_COSTS = { + "qwen2.5-72b-instruct": {"prompt": 0.00057, "completion": 0.0017}, # per 1k tokens + "qwen2.5-32b-instruct": {"prompt": 0.0005, "completion": 0.001}, + "qwen2.5-14b-instruct": {"prompt": 0.00029, "completion": 0.00086}, + "qwen2.5-7b-instruct": {"prompt": 0.00014, "completion": 0.00029}, + "qwen2.5-3b-instruct": {"prompt": 0.0, "completion": 0.0}, + "qwen2.5-1.5b-instruct": {"prompt": 0.0, "completion": 0.0}, + "qwen2.5-0.5b-instruct": {"prompt": 0.0, "completion": 0.0}, "qwen2-72b-instruct": {"prompt": 0.000714, "completion": 0.001428}, "qwen2-57b-a14b-instruct": {"prompt": 0.0005, "completion": 0.001}, "qwen2-7b-instruct": {"prompt": 0.000143, "completion": 0.000286}, @@ -187,10 +210,26 @@ FIREWORKS_GRADE_TOKEN_COSTS = { "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6}, } +# https://console.volcengine.com/ark/region:ark+cn-beijing/model +DOUBAO_TOKEN_COSTS = { + "doubao-lite": {"prompt": 0.000043, "completion": 0.000086}, + "doubao-lite-128k": {"prompt": 0.00011, "completion": 0.00014}, + "doubao-pro": {"prompt": 0.00011, "completion": 0.00029}, + "doubao-pro-128k": {"prompt": 0.00071, "completion": 0.0013}, + "doubao-pro-256k": {"prompt": 0.00071, "completion": 0.0013}, +} + # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo TOKEN_MAX = { - "gpt-4o-2024-05-13": 128000, + "o1-preview": 128000, + "o1-preview-2024-09-12": 128000, + "o1-mini": 128000, + "o1-mini-2024-09-12": 128000, "gpt-4o": 128000, + "gpt-4o-2024-05-13": 128000, + "gpt-4o-2024-08-06": 128000, + "gpt-4o-mini-2024-07-18": 128000, + "gpt-4o-mini": 128000, "gpt-4-turbo-2024-04-09": 128000, "gpt-4-0125-preview": 128000, "gpt-4-turbo-preview": 128000, @@ -212,7 +251,9 @@ TOKEN_MAX = { "text-embedding-ada-002": 8192, "glm-3-turbo": 128000, "glm-4": 128000, - "gemini-pro": 32768, + "gemini-1.5-flash": 1000000, + "gemini-1.5-pro": 2000000, + "gemini-1.0-pro": 32000, "moonshot-v1-8k": 8192, "moonshot-v1-32k": 32768, "moonshot-v1-128k": 128000, @@ -236,6 +277,11 @@ TOKEN_MAX = { "llama3-70b-8192": 8192, "openai/gpt-3.5-turbo-0125": 16385, "openai/gpt-4-turbo-preview": 128000, + "openai/o1-preview": 128000, + "openai/o1-mini": 128000, + "anthropic/claude-3-opus": 200000, + "anthropic/claude-3.5-sonnet": 200000, + "google/gemini-pro-1.5": 4000000, "deepseek-chat": 32768, "deepseek-coder": 16385, "doubao-lite-4k-240515": 4000, @@ -245,6 +291,13 @@ TOKEN_MAX = { "doubao-pro-32k-240515": 32000, "doubao-pro-128k-240515": 128000, # Qwen https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-72b-api-detailes?spm=a2c4g.11186623.0.i20 + "qwen2.5-72b-instruct": 131072, + "qwen2.5-32b-instruct": 131072, + "qwen2.5-14b-instruct": 131072, + "qwen2.5-7b-instruct": 131072, + "qwen2.5-3b-instruct": 32768, + "qwen2.5-1.5b-instruct": 32768, + "qwen2.5-0.5b-instruct": 32768, "qwen2-57b-a14b-instruct": 32768, "qwen2-72b-instruct": 131072, "qwen2-7b-instruct": 32768, @@ -344,11 +397,19 @@ def count_input_tokens(messages, model="gpt-3.5-turbo-0125"): "gpt-4-turbo", "gpt-4-turbo-preview", "gpt-4-0125-preview", + "gpt-4-1106-preview", "gpt-4-turbo", "gpt-4-vision-preview", "gpt-4-1106-vision-preview", - "gpt-4o-2024-05-13", "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "o1-preview", + "o1-preview-2024-09-12", + "o1-mini", + "o1-mini-2024-09-12", }: tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|> tokens_per_name = 1 diff --git a/requirements.txt b/requirements.txt index 4d8d7f32e..8bf0ee399 100644 --- a/requirements.txt +++ b/requirements.txt @@ -68,9 +68,14 @@ anytree ipywidgets==8.1.1 Pillow imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py -qianfan~=0.3.16 +qianfan~=0.4.4 dashscope~=1.19.3 rank-bm25==0.2.2 # for tool recommendation +jieba==0.42.1 # for tool recommendation +volcengine-python-sdk[ark]~=1.0.94 +# llama-index-vector-stores-elasticsearch~=0.2.5 # Used by `metagpt/memory/longterm_memory.py` +# llama-index-vector-stores-chroma~=0.1.10 # Used by `metagpt/memory/longterm_memory.py` gymnasium==0.29.1 boto3~=1.34.69 spark_ai_python~=0.3.30 +agentops diff --git a/setup.py b/setup.py index 8ba4c8a72..9fecfa766 100644 --- a/setup.py +++ b/setup.py @@ -109,7 +109,7 @@ setup( license="MIT", keywords="metagpt multi-agent multi-role programming gpt llm metaprogramming", packages=find_packages(exclude=["contrib", "docs", "examples", "tests*"]), - python_requires=">=3.9", + python_requires=">=3.9, <3.12", install_requires=requirements, extras_require=extras_require, cmdclass={ diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 989e2249c..58a6dd517 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -6,7 +6,7 @@ @File : test_action_node.py """ from pathlib import Path -from typing import List, Tuple +from typing import List, Optional, Tuple import pytest from pydantic import BaseModel, Field, ValidationError @@ -302,6 +302,19 @@ def test_action_node_from_pydantic_and_print_everything(): assert "tasks" in code, "tasks should be in code" +def test_optional(): + mapping = { + "Logic Analysis": (Optional[List[Tuple[str, str]]], Field(default=None)), + "Task list": (Optional[List[str]], None), + "Plan": (Optional[str], ""), + "Anything UNCLEAR": (Optional[str], None), + } + m = {"Anything UNCLEAR": "a"} + t = ActionNode.create_model_class("test_class_1", mapping) + + t1 = t(**m) + assert t1 + + if __name__ == "__main__": - test_create_model_class() - test_create_model_class_with_mapping() + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index 4760a2db2..b9c9e0f93 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -64,7 +64,7 @@ def is_subset(subset, superset) -> bool: superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}} is_subset(subset, superset) ``` - >>>False + """ for key, value in subset.items(): if key not in superset: diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 8c7a15be2..a10fcbe63 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM from llama_index.core.schema import Document, NodeWithScore, TextNode from metagpt.rag.engines import SimpleEngine +from metagpt.rag.parsers import OmniParse from metagpt.rag.retrievers import SimpleHybridRetriever from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode @@ -37,6 +38,10 @@ class TestSimpleEngine: def mock_get_response_synthesizer(self, mocker): return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer") + @pytest.fixture + def mock_get_file_extractor(self, mocker): + return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor") + def test_from_docs( self, mocker, @@ -44,6 +49,7 @@ class TestSimpleEngine: mock_get_retriever, mock_get_rankers, mock_get_response_synthesizer, + mock_get_file_extractor, ): # Mock mock_simple_directory_reader.return_value.load_data.return_value = [ @@ -53,6 +59,8 @@ class TestSimpleEngine: mock_get_retriever.return_value = mocker.MagicMock() mock_get_rankers.return_value = [mocker.MagicMock()] mock_get_response_synthesizer.return_value = mocker.MagicMock() + file_extractor = mocker.MagicMock() + mock_get_file_extractor.return_value = file_extractor # Setup input_dir = "test_dir" @@ -75,7 +83,9 @@ class TestSimpleEngine: ) # Assert - mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) + mock_simple_directory_reader.assert_called_once_with( + input_dir=input_dir, input_files=input_files, file_extractor=file_extractor + ) mock_get_retriever.assert_called_once() mock_get_rankers.assert_called_once() mock_get_response_synthesizer.assert_called_once_with(llm=llm) @@ -298,3 +308,17 @@ class TestSimpleEngine: # Assert assert "obj" in node.node.metadata assert node.node.metadata["obj"] == expected_obj + + def test_get_file_extractor(self, mocker): + # mock no omniparse config + mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True) + mock_omniparse_config.base_url = "" + + file_extractor = SimpleEngine._get_file_extractor() + assert file_extractor == {} + + # mock have omniparse config + mock_omniparse_config.base_url = "http://localhost:8000" + file_extractor = SimpleEngine._get_file_extractor() + assert ".pdf" in file_extractor + assert isinstance(file_extractor[".pdf"], OmniParse) diff --git a/tests/metagpt/rag/parser/test_omniparse.py b/tests/metagpt/rag/parser/test_omniparse.py new file mode 100644 index 000000000..d2b533d06 --- /dev/null +++ b/tests/metagpt/rag/parser/test_omniparse.py @@ -0,0 +1,118 @@ +import pytest +from llama_index.core import Document + +from metagpt.const import EXAMPLE_DATA_PATH +from metagpt.rag.parsers import OmniParse +from metagpt.rag.schema import ( + OmniParsedResult, + OmniParseOptions, + OmniParseType, + ParseResultType, +) +from metagpt.utils.omniparse_client import OmniParseClient + +# test data +TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx" +TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf" +TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4" +TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3" + + +class TestOmniParseClient: + parse_client = OmniParseClient() + + @pytest.fixture + def mock_request_parse(self, mocker): + return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse") + + @pytest.mark.asyncio + async def test_parse_pdf(self, mock_request_parse): + mock_content = "#test title\ntest content" + mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content) + mock_request_parse.return_value = mock_parsed_ret.model_dump() + parse_ret = await self.parse_client.parse_pdf(TEST_PDF) + assert parse_ret == mock_parsed_ret + + @pytest.mark.asyncio + async def test_parse_document(self, mock_request_parse): + mock_content = "#test title\ntest_parse_document" + mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content) + mock_request_parse.return_value = mock_parsed_ret.model_dump() + + with open(TEST_DOCX, "rb") as f: + file_bytes = f.read() + + with pytest.raises(ValueError): + # bytes data must provide bytes_filename + await self.parse_client.parse_document(file_bytes) + + parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx") + assert parse_ret == mock_parsed_ret + + @pytest.mark.asyncio + async def test_parse_video(self, mock_request_parse): + mock_content = "#test title\ntest_parse_video" + mock_request_parse.return_value = { + "text": mock_content, + "metadata": {}, + } + with pytest.raises(ValueError): + # Wrong file extension test + await self.parse_client.parse_video(TEST_DOCX) + + parse_ret = await self.parse_client.parse_video(TEST_VIDEO) + assert "text" in parse_ret and "metadata" in parse_ret + assert parse_ret["text"] == mock_content + + @pytest.mark.asyncio + async def test_parse_audio(self, mock_request_parse): + mock_content = "#test title\ntest_parse_audio" + mock_request_parse.return_value = { + "text": mock_content, + "metadata": {}, + } + parse_ret = await self.parse_client.parse_audio(TEST_AUDIO) + assert "text" in parse_ret and "metadata" in parse_ret + assert parse_ret["text"] == mock_content + + +class TestOmniParse: + @pytest.fixture + def mock_omniparse(self): + parser = OmniParse( + parse_options=OmniParseOptions( + parse_type=OmniParseType.PDF, + result_type=ParseResultType.MD, + max_timeout=120, + num_workers=3, + ) + ) + return parser + + @pytest.fixture + def mock_request_parse(self, mocker): + return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse") + + @pytest.mark.asyncio + async def test_load_data(self, mock_omniparse, mock_request_parse): + # mock + mock_content = "#test title\ntest content" + mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content) + mock_request_parse.return_value = mock_parsed_ret.model_dump() + + # single file + documents = mock_omniparse.load_data(file_path=TEST_PDF) + doc = documents[0] + assert isinstance(doc, Document) + assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown + + # multi files + file_paths = [TEST_DOCX, TEST_PDF] + mock_omniparse.parse_type = OmniParseType.DOCUMENT + documents = await mock_omniparse.aload_data(file_path=file_paths) + doc = documents[0] + + # assert + assert isinstance(doc, Document) + assert len(documents) == len(file_paths) + assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py index 8b11e2d4a..47d1fc6de 100644 --- a/tests/metagpt/roles/test_role.py +++ b/tests/metagpt/roles/test_role.py @@ -5,6 +5,7 @@ import pytest from metagpt.provider.human_provider import HumanProvider from metagpt.roles.role import Role +from metagpt.schema import Message, UserMessage def test_role_desc(): @@ -18,5 +19,15 @@ def test_role_human(context): assert isinstance(role.llm, HumanProvider) +@pytest.mark.asyncio +async def test_recovered(): + role = Role(profile="Tester", desc="Tester", recovered=True) + role.put_message(UserMessage(content="2")) + role.latest_observed_msg = Message(content="1") + await role._observe() + await role._observe() + assert role.rc.msg_buffer.empty() + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py index 4e6ea93b5..3138346d6 100644 --- a/tests/metagpt/serialize_deserialize/test_environment.py +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : - +import pytest from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement @@ -55,6 +55,7 @@ def test_environment_serdeser(context): assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK) assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise + assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch def test_environment_serdeser_v2(context): @@ -69,6 +70,7 @@ def test_environment_serdeser_v2(context): assert isinstance(role, ProjectManager) assert isinstance(role.actions[0], WriteTasks) assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks) + assert list(new_env.roles.values())[0].rc.watch == pm.rc.watch def test_environment_serdeser_save(context): @@ -85,3 +87,8 @@ def test_environment_serdeser_save(context): new_env: Environment = Environment(**env_dict, context=context) assert len(new_env.roles) == 1 assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK + assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index aaf7c1935..807849751 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -28,9 +28,9 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import ( def test_roles(context): role_a = RoleA() - assert len(role_a.rc.watch) == 1 + assert len(role_a.rc.watch) == 2 role_b = RoleB() - assert len(role_a.rc.watch) == 1 + assert len(role_a.rc.watch) == 2 assert len(role_b.rc.watch) == 1 role_d = RoleD(actions=[ActionOK()]) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 62ab26d72..84058925e 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -8,9 +8,9 @@ from typing import Optional from pydantic import BaseModel, Field -from metagpt.actions import Action, ActionOutput +from metagpt.actions import Action, ActionOutput, UserRequirement from metagpt.actions.action_node import ActionNode -from metagpt.actions.add_requirement import UserRequirement +from metagpt.actions.fix_bug import FixBug from metagpt.roles.role import Role, RoleReactMode serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage") @@ -68,7 +68,7 @@ class RoleA(Role): def __init__(self, **kwargs): super(RoleA, self).__init__(**kwargs) self.set_actions([ActionPass]) - self._watch([UserRequirement]) + self._watch([FixBug, UserRequirement]) class RoleB(Role): @@ -93,7 +93,7 @@ class RoleC(Role): def __init__(self, **kwargs): super(RoleC, self).__init__(**kwargs) self.set_actions([ActionOK, ActionRaise]) - self._watch([UserRequirement]) + self._watch([FixBug, UserRequirement]) self.rc.react_mode = RoleReactMode.BY_ORDER self.rc.memory.ignore_id = True diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py index 32a017a97..4ced53ce8 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code_review.py +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -29,3 +29,7 @@ def div(a: int, b: int = 0): assert new_action.name == "WriteCodeReview" await new_action.run() + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_config.py b/tests/metagpt/test_config.py index 7ce5765cf..797daf5dc 100644 --- a/tests/metagpt/test_config.py +++ b/tests/metagpt/test_config.py @@ -14,8 +14,8 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config def test_config_1(): cfg = Config.default() llm = cfg.get_openai_llm() - assert llm is not None - assert llm.api_type == LLMType.OPENAI + if cfg.llm.api_type == LLMType.OPENAI: + assert llm is not None def test_config_from_dict(): diff --git a/tests/metagpt/test_context.py b/tests/metagpt/test_context.py index f8218c44d..a6daf95cd 100644 --- a/tests/metagpt/test_context.py +++ b/tests/metagpt/test_context.py @@ -53,8 +53,8 @@ def test_context_1(): def test_context_2(): ctx = Context() llm = ctx.config.get_openai_llm() - assert llm is not None - assert llm.api_type == LLMType.OPENAI + if ctx.config.llm.api_type == LLMType.OPENAI: + assert llm is not None kwargs = ctx.kwargs assert kwargs is not None diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index c4262e080..a6b0a43ef 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -114,7 +114,6 @@ class MockLLM(OriginalLLM): raise ValueError( "In current test setting, api call is not allowed, you should properly mock your tests, " "or add expected api response in tests/data/rsp_cache.json. " - f"The prompt you want for api call: {msg_key}" ) # Call the original unmocked method rsp = await ask_func(*args, **kwargs)