diff --git a/metagpt/actions/skill_action.py b/metagpt/actions/skill_action.py index 301cebaab..b68596809 100644 --- a/metagpt/actions/skill_action.py +++ b/metagpt/actions/skill_action.py @@ -29,9 +29,7 @@ class ArgumentsParingAction(Action): @property def prompt(self): - prompt = "You are a function parser. You can convert spoken words into function parameters.\n" - prompt += "\n---\n" - prompt += f"{self.skill.name} function parameters description:\n" + prompt = f"{self.skill.name} function parameters description:\n" for k, v in self.skill.arguments.items(): prompt += f"parameter `{k}`: {v}\n" prompt += "\n---\n" @@ -49,7 +47,10 @@ class ArgumentsParingAction(Action): async def run(self, with_message=None, **kwargs) -> Message: prompt = self.prompt - rsp = await self.llm.aask(msg=prompt, system_msgs=[]) + rsp = await self.llm.aask( + msg=prompt, + system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."], + ) logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}") self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp) self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self) diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py index 1678bc8dc..834f07006 100644 --- a/metagpt/actions/write_teaching_plan.py +++ b/metagpt/actions/write_teaching_plan.py @@ -8,7 +8,6 @@ from typing import Optional from metagpt.actions import Action -from metagpt.context import CONTEXT from metagpt.logs import logger @@ -24,7 +23,7 @@ class WriteTeachingPlanPart(Action): statement_patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, []) statements = [] for p in statement_patterns: - s = self.format_value(p) + s = self.format_value(p, options=self.context.options) statements.append(s) formatter = ( TeachingPlanBlock.PROMPT_TITLE_TEMPLATE @@ -68,21 +67,20 @@ class WriteTeachingPlanPart(Action): return self.topic @staticmethod - def 
format_value(value): + def format_value(value, options): """Fill parameters inside `value` with `options`.""" if not isinstance(value, str): return value if "{" not in value: return value - # FIXME: 从Context中获取参数,而非从options - merged_opts = CONTEXT.options or {} + opts = {k: v for k, v in options.items() if v is not None} try: - return value.format(**merged_opts) + return value.format(**opts) except KeyError as e: logger.warning(f"Parameter is missing:{e}") - for k, v in merged_opts.items(): + for k, v in opts.items(): value = value.replace("{" + f"{k}" + "}", str(v)) return value diff --git a/metagpt/context.py b/metagpt/context.py index 0ce2f4b40..75dc31ef2 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -7,13 +7,12 @@ """ import os from pathlib import Path -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel, ConfigDict from metagpt.config2 import Config from metagpt.configs.llm_config import LLMConfig -from metagpt.const import OPTIONS from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import create_llm_instance from metagpt.utils.cost_manager import CostManager @@ -41,6 +40,16 @@ class AttrDict(BaseModel): else: raise AttributeError(f"No such attribute: {key}") + def set(self, key, val: Any): + self.__dict__[key] = val + + def get(self, key, default: Any = None): + return self.__dict__.get(key, default) + + def remove(self, key): + if key in self.__dict__: + self.__delattr__(key) + class Context(BaseModel): """Env context for MetaGPT""" @@ -58,7 +67,10 @@ class Context(BaseModel): @property def options(self): """Return all key-values""" - return OPTIONS.get() + opts = self.config.model_dump() + for k, v in self.kwargs: + opts[k] = v # None value is allowed to override and disable the value from config. 
+ return opts def new_environ(self): """Return a new os.environ object""" diff --git a/metagpt/learn/text_to_embedding.py b/metagpt/learn/text_to_embedding.py index 6a4342b06..f859ab638 100644 --- a/metagpt/learn/text_to_embedding.py +++ b/metagpt/learn/text_to_embedding.py @@ -6,16 +6,19 @@ @File : text_to_embedding.py @Desc : Text-to-Embedding skill, which provides text-to-embedding functionality. """ - +import metagpt.config2 +from metagpt.config2 import Config from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding -async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs): +async def text_to_embedding(text, model="text-embedding-ada-002", config: Config = metagpt.config2.config): """Text to embedding :param text: The text used for embedding. :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`. - :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` + :param config: OpenAI config with API key, For more details, checkout: `https://platform.openai.com/account/api-keys` :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. 
""" - return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key) + openai_api_key = config.get_openai_llm().api_key + proxy = config.get_openai_llm().proxy + return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key, proxy=proxy) diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py index 8b2cb4473..e2fac7647 100644 --- a/metagpt/learn/text_to_image.py +++ b/metagpt/learn/text_to_image.py @@ -8,6 +8,7 @@ """ import base64 +import metagpt.config2 from metagpt.config2 import Config from metagpt.const import BASE64_FORMAT from metagpt.llm import LLM @@ -16,27 +17,26 @@ from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image from metagpt.utils.s3 import S3 -async def text_to_image(text, size_type: str = "512x512", model_url="", config: Config = None): +async def text_to_image(text, size_type: str = "512x512", config: Config = metagpt.config2.config): """Text to image :param text: The text used for image conversion. - :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` :param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768']. - :param model_url: MetaGPT model url :param config: Config :return: The image data is returned in Base64 encoding. 
""" image_declaration = "data:image/png;base64," + model_url = config.METAGPT_TEXT_TO_IMAGE_MODEL_URL if model_url: binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url) elif config.get_openai_llm(): - binary_data = await oas3_openai_text_to_image(text, size_type, LLM()) + llm = LLM(llm_config=config.get_openai_llm()) + binary_data = await oas3_openai_text_to_image(text, size_type, llm=llm) else: raise ValueError("Missing necessary parameters.") base64_data = base64.b64encode(binary_data).decode("utf-8") - assert config.s3, "S3 config is required." s3 = S3(config.s3) url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if url: diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py index 8ffafbd0e..37e56eaff 100644 --- a/metagpt/learn/text_to_speech.py +++ b/metagpt/learn/text_to_speech.py @@ -6,8 +6,8 @@ @File : text_to_speech.py @Desc : Text-to-Speech skill, which provides text-to-speech functionality """ - -from metagpt.config2 import config +import metagpt.config2 +from metagpt.config2 import Config from metagpt.const import BASE64_FORMAT from metagpt.tools.azure_tts import oas3_azsure_tts from metagpt.tools.iflytek_tts import oas3_iflytek_tts @@ -20,12 +20,7 @@ async def text_to_speech( voice="zh-CN-XiaomoNeural", style="affectionate", role="Girl", - subscription_key="", - region="", - iflytek_app_id="", - iflytek_api_key="", - iflytek_api_secret="", - **kwargs, + config: Config = metagpt.config2.config, ): """Text to speech For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` @@ -44,6 +39,8 @@ async def text_to_speech( """ + subscription_key = config.AZURE_TTS_SUBSCRIPTION_KEY + region = config.AZURE_TTS_REGION if subscription_key and region: audio_declaration = "data:audio/wav;base64," base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region) @@ -52,6 +49,10 @@ async def text_to_speech( 
if url: return f"[{text}]({url})" return audio_declaration + base64_data if base64_data else base64_data + + iflytek_app_id = config.IFLYTEK_APP_ID + iflytek_api_key = config.IFLYTEK_API_KEY + iflytek_api_secret = config.IFLYTEK_API_SECRET if iflytek_app_id and iflytek_api_key and iflytek_api_secret: audio_declaration = "data:audio/mp3;base64," base64_data = await oas3_iflytek_tts( diff --git a/metagpt/roles/assistant.py b/metagpt/roles/assistant.py index 8939094ed..1c5315eee 100644 --- a/metagpt/roles/assistant.py +++ b/metagpt/roles/assistant.py @@ -65,7 +65,7 @@ class Assistant(Role): prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n" prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n' prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n" - rsp = await self.llm.aask(prompt, []) + rsp = await self.llm.aask(prompt, ["You are an action classifier"]) logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n") return await self._plan(rsp, last_talk=last_talk) @@ -98,9 +98,7 @@ class Assistant(Role): history = self.memory.history_text text = kwargs.get("last_talk") or text self.set_todo( - TalkAction( - context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs - ) + TalkAction(i_context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm) ) return True @@ -110,7 +108,7 @@ class Assistant(Role): if not skill: logger.info(f"skill not found: {text}") return await self.talk_handler(text=last_talk, **kwargs) - action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk, **kwargs) + action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk) await action.run(**kwargs) if action.args is None: return await self.talk_handler(text=last_talk, **kwargs) diff --git a/metagpt/roles/teacher.py 
b/metagpt/roles/teacher.py index d47f4af5b..a40ba69fe 100644 --- a/metagpt/roles/teacher.py +++ b/metagpt/roles/teacher.py @@ -31,11 +31,11 @@ class Teacher(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.name = WriteTeachingPlanPart.format_value(self.name) - self.profile = WriteTeachingPlanPart.format_value(self.profile) - self.goal = WriteTeachingPlanPart.format_value(self.goal) - self.constraints = WriteTeachingPlanPart.format_value(self.constraints) - self.desc = WriteTeachingPlanPart.format_value(self.desc) + self.name = WriteTeachingPlanPart.format_value(self.name, self.context.options) + self.profile = WriteTeachingPlanPart.format_value(self.profile, self.context.options) + self.goal = WriteTeachingPlanPart.format_value(self.goal, self.context.options) + self.constraints = WriteTeachingPlanPart.format_value(self.constraints, self.context.options) + self.desc = WriteTeachingPlanPart.format_value(self.desc, self.context.options) async def _think(self) -> bool: """Everything will be done part by part.""" diff --git a/metagpt/tools/openai_text_to_embedding.py b/metagpt/tools/openai_text_to_embedding.py index 3eb9faac4..e93bfb271 100644 --- a/metagpt/tools/openai_text_to_embedding.py +++ b/metagpt/tools/openai_text_to_embedding.py @@ -13,7 +13,6 @@ import aiohttp import requests from pydantic import BaseModel, Field -from metagpt.config2 import config from metagpt.logs import logger @@ -43,12 +42,12 @@ class ResultEmbedding(BaseModel): class OpenAIText2Embedding: - def __init__(self, openai_api_key): + def __init__(self, api_key: str, proxy: str): """ :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` """ - self.openai_llm = config.get_openai_llm() - self.openai_api_key = openai_api_key or self.openai_llm.api_key + self.api_key = api_key + self.proxy = proxy async def text_2_embedding(self, text, model="text-embedding-ada-002"): """Text to embedding @@ -58,8 +57,8 @@ class 
OpenAIText2Embedding: :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. """ - proxies = {"proxy": self.openai_llm.proxy} if self.openai_llm.proxy else {} - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"} + proxies = {"proxy": self.proxy} if self.proxy else {} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} data = {"input": text, "model": model} url = "https://api.openai.com/v1/embeddings" try: @@ -73,16 +72,14 @@ class OpenAIText2Embedding: # Export -async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", openai_api_key=""): +async def oas3_openai_text_to_embedding(text, openai_api_key: str, model="text-embedding-ada-002", proxy: str = ""): """Text to embedding :param text: The text used for embedding. :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`. - :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` + :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`. 
""" if not text: return "" - if not openai_api_key: - openai_api_key = config.get_openai_llm().api_key - return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model) + return await OpenAIText2Embedding(api_key=openai_api_key, proxy=proxy).text_2_embedding(text, model=model) diff --git a/tests/data/demo_project/dependencies.json b/tests/data/demo_project/dependencies.json index cfcf6c165..738e5d9be 100644 --- a/tests/data/demo_project/dependencies.json +++ b/tests/data/demo_project/dependencies.json @@ -1 +1 @@ -{"docs/system_design/20231221155954.json": ["docs/prds/20231221155954.json"], "docs/tasks/20231221155954.json": ["docs/system_design/20231221155954.json"], "game_2048/game.py": ["docs/tasks/20231221155954.json", "docs/system_design/20231221155954.json"], "game_2048/main.py": ["docs/tasks/20231221155954.json", "docs/system_design/20231221155954.json"], "resources/code_summaries/20231221155954.md": ["docs/tasks/20231221155954.json", "game_2048/game.py", "docs/system_design/20231221155954.json", "game_2048/main.py"], "docs/code_summaries/20231221155954.json": ["docs/tasks/20231221155954.json", "game_2048/game.py", "docs/system_design/20231221155954.json", "game_2048/main.py"], "tests/test_main.py": ["game_2048/main.py"], "tests/test_game.py": ["game_2048/game.py"], "test_outputs/test_main.py.json": ["game_2048/main.py", "tests/test_main.py"], "test_outputs/test_game.py.json": ["game_2048/game.py", "tests/test_game.py"]} \ No newline at end of file +{"docs/system_design/20231221155954.json": ["docs/prd/20231221155954.json"], "docs/task/20231221155954.json": ["docs/system_design/20231221155954.json"], "game_2048/game.py": ["docs/task/20231221155954.json", "docs/system_design/20231221155954.json"], "game_2048/main.py": ["docs/task/20231221155954.json", "docs/system_design/20231221155954.json"], "resources/code_summary/20231221155954.md": ["docs/task/20231221155954.json", "game_2048/game.py", "docs/system_design/20231221155954.json", 
"game_2048/main.py"], "docs/code_summary/20231221155954.json": ["docs/task/20231221155954.json", "game_2048/game.py", "docs/system_design/20231221155954.json", "game_2048/main.py"], "tests/test_main.py": ["game_2048/main.py"], "tests/test_game.py": ["game_2048/game.py"], "test_outputs/test_main.py.json": ["game_2048/main.py", "tests/test_main.py"], "test_outputs/test_game.py.json": ["game_2048/game.py", "tests/test_game.py"]} \ No newline at end of file diff --git a/tests/metagpt/learn/test_text_to_embedding.py b/tests/metagpt/learn/test_text_to_embedding.py index d8a251dc8..8891960c1 100644 --- a/tests/metagpt/learn/test_text_to_embedding.py +++ b/tests/metagpt/learn/test_text_to_embedding.py @@ -28,10 +28,10 @@ async def test_text_to_embedding(mocker): type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy") # Prerequisites - assert config.get_openai_llm() + assert config.get_openai_llm().api_key assert config.get_openai_llm().proxy - v = await text_to_embedding(text="Panda emoji") + v = await text_to_embedding(text="Panda emoji", config=config) assert len(v.data) > 0 diff --git a/tests/metagpt/learn/test_text_to_image.py b/tests/metagpt/learn/test_text_to_image.py index b58ff6580..167a35891 100644 --- a/tests/metagpt/learn/test_text_to_image.py +++ b/tests/metagpt/learn/test_text_to_image.py @@ -29,9 +29,7 @@ async def test_text_to_image(mocker): config = Config.default() assert config.METAGPT_TEXT_TO_IMAGE_MODEL_URL - data = await text_to_image( - "Panda emoji", size_type="512x512", model_url=config.METAGPT_TEXT_TO_IMAGE_MODEL_URL, config=config - ) + data = await text_to_image("Panda emoji", size_type="512x512", config=config) assert "base64" in data or "http" in data @@ -54,6 +52,7 @@ async def test_openai_text_to_image(mocker): mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png") config = Config.default() + config.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None assert config.get_openai_llm() data = await 
text_to_image("Panda emoji", size_type="512x512", config=config) diff --git a/tests/metagpt/learn/test_text_to_speech.py b/tests/metagpt/learn/test_text_to_speech.py index 41611171c..38e051cc6 100644 --- a/tests/metagpt/learn/test_text_to_speech.py +++ b/tests/metagpt/learn/test_text_to_speech.py @@ -8,43 +8,64 @@ """ import pytest +from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.learn.text_to_speech import text_to_speech +from metagpt.tools.iflytek_tts import IFlyTekTTS +from metagpt.utils.s3 import S3 @pytest.mark.asyncio -async def test_text_to_speech(): +async def test_azure_text_to_speech(mocker): + # mock + config = Config.default() + config.IFLYTEK_API_KEY = None + config.IFLYTEK_API_SECRET = None + config.IFLYTEK_APP_ID = None + mock_result = mocker.Mock() + mock_result.audio_data = b"mock audio data" + mock_result.reason = ResultReason.SynthesizingAudioCompleted + mock_data = mocker.Mock() + mock_data.get.return_value = mock_result + mocker.patch.object(SpeechSynthesizer, "speak_ssml_async", return_value=mock_data) + mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/1.wav") + + # Prerequisites + assert not config.IFLYTEK_APP_ID + assert not config.IFLYTEK_API_KEY + assert not config.IFLYTEK_API_SECRET + assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY" + assert config.AZURE_TTS_REGION + + config.copy() + # test azure + data = await text_to_speech("panda emoji", config=config) + assert "base64" in data or "http" in data + + +@pytest.mark.asyncio +async def test_iflytek_text_to_speech(mocker): + # mock + config = Config.default() + config.AZURE_TTS_SUBSCRIPTION_KEY = None + config.AZURE_TTS_REGION = None + mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None) + mock_data = mocker.AsyncMock() + mock_data.read.return_value = b"mock iflytek" + mock_reader = 
mocker.patch("aiofiles.open") + mock_reader.return_value.__aenter__.return_value = mock_data + mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/1.mp3") + # Prerequisites assert config.IFLYTEK_APP_ID assert config.IFLYTEK_API_KEY assert config.IFLYTEK_API_SECRET - assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY" - assert config.AZURE_TTS_REGION + assert not config.AZURE_TTS_SUBSCRIPTION_KEY or config.AZURE_TTS_SUBSCRIPTION_KEY == "YOUR_API_KEY" + assert not config.AZURE_TTS_REGION - i = config.copy() # test azure - data = await text_to_speech( - "panda emoji", - subscription_key=i.AZURE_TTS_SUBSCRIPTION_KEY, - region=i.AZURE_TTS_REGION, - iflytek_api_key=i.IFLYTEK_API_KEY, - iflytek_api_secret=i.IFLYTEK_API_SECRET, - iflytek_app_id=i.IFLYTEK_APP_ID, - ) - assert "base64" in data or "http" in data - - # test iflytek - ## Mock session env - i.AZURE_TTS_SUBSCRIPTION_KEY = "" - data = await text_to_speech( - "panda emoji", - subscription_key=i.AZURE_TTS_SUBSCRIPTION_KEY, - region=i.AZURE_TTS_REGION, - iflytek_api_key=i.IFLYTEK_API_KEY, - iflytek_api_secret=i.IFLYTEK_API_SECRET, - iflytek_app_id=i.IFLYTEK_APP_ID, - ) + data = await text_to_speech("panda emoji", config=config) assert "base64" in data or "http" in data diff --git a/tests/metagpt/roles/test_assistant.py b/tests/metagpt/roles/test_assistant.py index 4ef44d77a..b9740a112 100644 --- a/tests/metagpt/roles/test_assistant.py +++ b/tests/metagpt/roles/test_assistant.py @@ -20,7 +20,10 @@ from metagpt.utils.common import any_to_str @pytest.mark.asyncio -async def test_run(): +async def test_run(mocker): + # mock + mocker.patch("metagpt.learn.text_to_image", return_value="http://mock.com/1.png") + CONTEXT.kwargs.language = "Chinese" class Input(BaseModel): @@ -65,7 +68,7 @@ async def test_run(): "cause_by": any_to_str(SkillAction), }, ] - CONTEXT.kwargs.agent_skills = [ + agent_skills = [ {"id": 1, "name": "text_to_speech", "type": "builtin", "config": 
{}, "enabled": True}, {"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True}, {"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True}, @@ -77,9 +80,11 @@ async def test_run(): for i in inputs: seed = Input(**i) - CONTEXT.kwargs.language = seed.language - CONTEXT.kwargs.agent_description = seed.agent_description role = Assistant(language="Chinese") + role.context.kwargs.language = seed.language + role.context.kwargs.agent_description = seed.agent_description + role.context.kwargs.agent_skills = agent_skills + role.memory = seed.memory # Restore historical conversation content. while True: has_action = await role.think() @@ -112,6 +117,7 @@ async def test_run(): @pytest.mark.asyncio async def test_memory(memory): role = Assistant() + role.context.kwargs.agent_skills = [] role.load_memory(memory) val = role.get_memory() diff --git a/tests/metagpt/roles/test_engineer.py b/tests/metagpt/roles/test_engineer.py index 710e74b8f..17b94828c 100644 --- a/tests/metagpt/roles/test_engineer.py +++ b/tests/metagpt/roles/test_engineer.py @@ -8,23 +8,25 @@ distribution feature for message handling. 
""" import json +import uuid from pathlib import Path import pytest from metagpt.actions import WriteCode, WriteTasks from metagpt.const import ( - PRDS_FILE_REPO, + DEFAULT_WORKSPACE_ROOT, REQUIREMENT_FILENAME, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) -from metagpt.context import CONTEXT +from metagpt.context import CONTEXT, Context from metagpt.logs import logger from metagpt.roles.engineer import Engineer from metagpt.schema import CodingContext, Message from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite -from metagpt.utils.git_repository import ChangeType +from metagpt.utils.git_repository import ChangeType, GitRepository +from metagpt.utils.project_repo import ProjectRepo from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages @@ -32,20 +34,18 @@ from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages async def test_engineer(): # Prerequisites rqno = "20231221155954.json" - await CONTEXT.file_repo.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content) - await CONTEXT.file_repo.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content) - await CONTEXT.file_repo.save_file( - rqno, relative_path=SYSTEM_DESIGN_FILE_REPO, content=MockMessages.system_design.content - ) - await CONTEXT.file_repo.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content) + project_repo = ProjectRepo(CONTEXT.git_repo) + await project_repo.save(REQUIREMENT_FILENAME, content=MockMessages.req.content) + await project_repo.docs.prd.save(rqno, content=MockMessages.prd.content) + await project_repo.docs.system_design.save(rqno, content=MockMessages.system_design.content) + await project_repo.docs.task.save(rqno, content=MockMessages.json_tasks.content) engineer = Engineer() rsp = await engineer.run(Message(content="", cause_by=WriteTasks)) logger.info(rsp) assert rsp.cause_by == any_to_str(WriteCode) - src_file_repo = 
CONTEXT.git_repo.new_file_repository(CONTEXT.src_workspace) - assert src_file_repo.changed_files + assert project_repo.with_src_path(CONTEXT.src_workspace).srcs.changed_files def test_parse_str(): @@ -114,48 +114,50 @@ def test_todo(): @pytest.mark.asyncio async def test_new_coding_context(): # Prerequisites + context = Context() + context.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") demo_path = Path(__file__).parent / "../../data/demo_project" deps = json.loads(await aread(demo_path / "dependencies.json")) - dependency = await CONTEXT.git_repo.get_dependency() + dependency = await context.git_repo.get_dependency() for k, v in deps.items(): await dependency.update(k, set(v)) data = await aread(demo_path / "system_design.json") rqno = "20231221155954.json" - await awrite(CONTEXT.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data) + await awrite(context.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data) data = await aread(demo_path / "tasks.json") - await awrite(CONTEXT.git_repo.workdir / TASK_FILE_REPO / rqno, data) + await awrite(context.git_repo.workdir / TASK_FILE_REPO / rqno, data) - CONTEXT.src_workspace = Path(CONTEXT.git_repo.workdir) / "game_2048" - src_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=CONTEXT.src_workspace) - task_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=TASK_FILE_REPO) - design_file_repo = CONTEXT.git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_FILE_REPO) + context.src_workspace = Path(context.git_repo.workdir) / "game_2048" - filename = "game.py" - ctx_doc = await Engineer._new_coding_doc( - filename=filename, - src_file_repo=src_file_repo, - task_file_repo=task_file_repo, - design_file_repo=design_file_repo, - dependency=dependency, - ) - assert ctx_doc - assert ctx_doc.filename == filename - assert ctx_doc.content - ctx = CodingContext.model_validate_json(ctx_doc.content) - assert ctx.filename == filename - assert ctx.design_doc - 
assert ctx.design_doc.content - assert ctx.task_doc - assert ctx.task_doc.content - assert ctx.code_doc + try: + filename = "game.py" + engineer = Engineer(context=context) + ctx_doc = await engineer._new_coding_doc( + filename=filename, + dependency=dependency, + ) + assert ctx_doc + assert ctx_doc.filename == filename + assert ctx_doc.content + ctx = CodingContext.model_validate_json(ctx_doc.content) + assert ctx.filename == filename + assert ctx.design_doc + assert ctx.design_doc.content + assert ctx.task_doc + assert ctx.task_doc.content + assert ctx.code_doc - CONTEXT.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED}) - CONTEXT.git_repo.commit("mock env") - await src_file_repo.save(filename=filename, content="content") - role = Engineer() - assert not role.code_todos - await role._new_code_actions() - assert role.code_todos + context.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED}) + context.git_repo.commit("mock env") + await ProjectRepo(context.git_repo).with_src_path(context.src_workspace).srcs.save( + filename=filename, content="content" + ) + role = Engineer(context=context) + assert not role.code_todos + await role._new_code_actions() + assert role.code_todos + finally: + context.git_repo.delete_repository() if __name__ == "__main__": diff --git a/tests/metagpt/roles/test_teacher.py b/tests/metagpt/roles/test_teacher.py index 8bd37f482..83a7e382a 100644 --- a/tests/metagpt/roles/test_teacher.py +++ b/tests/metagpt/roles/test_teacher.py @@ -8,15 +8,14 @@ from typing import Dict, Optional import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field -from metagpt.context import CONTEXT +from metagpt.context import Context from metagpt.roles.teacher import Teacher from metagpt.schema import Message @pytest.mark.asyncio -@pytest.mark.skip async def test_init(): class Inputs(BaseModel): name: str @@ -30,6 +29,7 @@ async def test_init(): expect_goal: str expect_constraints: str expect_desc: str 
+ exclude: list = Field(default_factory=list) inputs = [ { @@ -44,6 +44,7 @@ async def test_init(): "kwargs": {}, "desc": "aaa{language}", "expect_desc": "aaa{language}", + "exclude": ["language", "key1", "something_big", "teaching_language"], }, { "name": "Lily{language}", @@ -57,13 +58,21 @@ async def test_init(): "kwargs": {"language": "CN", "key1": "HaHa", "something_big": "sleep", "teaching_language": "EN"}, "desc": "aaa{language}", "expect_desc": "aaaCN", + "language": "CN", + "teaching_language": "EN", }, ] for i in inputs: seed = Inputs(**i) + context = Context() + for k in seed.exclude: + context.kwargs.set(k, None) + for k, v in seed.kwargs.items(): + context.kwargs.set(k, v) teacher = Teacher( + context=context, name=seed.name, profile=seed.profile, goal=seed.goal, @@ -97,8 +106,6 @@ async def test_new_file_name(): @pytest.mark.asyncio async def test_run(): - CONTEXT.kwargs.language = "Chinese" - CONTEXT.kwargs.teaching_language = "English" lesson = """ UNIT 1 Making New Friends TOPIC 1 Welcome to China! @@ -142,7 +149,10 @@ async def test_run(): 3c Match the big letters with the small ones. Then write them on the lines. 
""" - teacher = Teacher() + context = Context() + context.kwargs.language = "Chinese" + context.kwargs.teaching_language = "English" + teacher = Teacher(context=context) rsp = await teacher.run(Message(content=lesson)) assert rsp diff --git a/tests/metagpt/tools/test_iflytek_tts.py b/tests/metagpt/tools/test_iflytek_tts.py index 18af0a723..8e4c0cf54 100644 --- a/tests/metagpt/tools/test_iflytek_tts.py +++ b/tests/metagpt/tools/test_iflytek_tts.py @@ -7,12 +7,22 @@ """ import pytest -from metagpt.config2 import config -from metagpt.tools.iflytek_tts import oas3_iflytek_tts +from metagpt.config2 import Config +from metagpt.tools.iflytek_tts import IFlyTekTTS, oas3_iflytek_tts @pytest.mark.asyncio -async def test_tts(): +async def test_iflytek_tts(mocker): + # mock + config = Config.default() + config.AZURE_TTS_SUBSCRIPTION_KEY = None + config.AZURE_TTS_REGION = None + mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None) + mock_data = mocker.AsyncMock() + mock_data.read.return_value = b"mock iflytek" + mock_reader = mocker.patch("aiofiles.open") + mock_reader.return_value.__aenter__.return_value = mock_data + # Prerequisites assert config.IFLYTEK_APP_ID assert config.IFLYTEK_API_KEY diff --git a/tests/metagpt/tools/test_openai_text_to_embedding.py b/tests/metagpt/tools/test_openai_text_to_embedding.py index b4e9b3383..047206d48 100644 --- a/tests/metagpt/tools/test_openai_text_to_embedding.py +++ b/tests/metagpt/tools/test_openai_text_to_embedding.py @@ -27,10 +27,13 @@ async def test_embedding(mocker): type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy") # Prerequisites - assert config.get_openai_llm() - assert config.get_openai_llm().proxy + llm_config = config.get_openai_llm() + assert llm_config + assert llm_config.proxy - result = await oas3_openai_text_to_embedding("Panda emoji") + result = await oas3_openai_text_to_embedding( + "Panda emoji", openai_api_key=llm_config.api_key, proxy=llm_config.proxy + ) 
assert result assert result.model assert len(result.data) > 0 diff --git a/tests/metagpt/tools/test_openai_text_to_image.py b/tests/metagpt/tools/test_openai_text_to_image.py index 5a6214d17..3f9169ddd 100644 --- a/tests/metagpt/tools/test_openai_text_to_image.py +++ b/tests/metagpt/tools/test_openai_text_to_image.py @@ -39,10 +39,10 @@ async def test_draw(mocker): mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png") # Prerequisites - assert config.get_openai_llm() - assert config.get_openai_llm().proxy + llm_config = config.get_openai_llm() + assert llm_config - binary_data = await oas3_openai_text_to_image("Panda emoji", llm=LLM()) + binary_data = await oas3_openai_text_to_image("Panda emoji", llm=LLM(llm_config=llm_config)) assert binary_data