From b16315f6c7f7171372975fc76b870caab23f9002 Mon Sep 17 00:00:00 2001 From: geekan Date: Mon, 8 Jan 2024 17:01:51 +0800 Subject: [PATCH] refine code --- metagpt/roles/teacher.py | 10 ++-------- tests/metagpt/provider/test_spark_api.py | 10 ++-------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/metagpt/roles/teacher.py b/metagpt/roles/teacher.py index f9583d49b..fb547f56b 100644 --- a/metagpt/roles/teacher.py +++ b/metagpt/roles/teacher.py @@ -11,14 +11,12 @@ import re -import aiofiles - from metagpt.actions import UserRequirement from metagpt.actions.write_teaching_plan import TeachingPlanBlock, WriteTeachingPlanPart from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_str, awrite class Teacher(Role): @@ -83,11 +81,7 @@ class Teacher(Role): pathname = self.config.workspace.path / "teaching_plan" pathname.mkdir(exist_ok=True) pathname = pathname / filename - try: - async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer: - await writer.write(content) - except Exception as e: - logger.error(f"Save failed:{e}") + await awrite(pathname, content) logger.info(f"Save to:{pathname}") @staticmethod diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index ee2d02c97..8c6218ac4 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,15 +4,9 @@ import pytest -from metagpt.config import CONFIG +from metagpt.config2 import config from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM -CONFIG.spark_appid = "xxx" -CONFIG.spark_api_secret = "xxx" -CONFIG.spark_api_key = "xxx" -CONFIG.domain = "xxxxxx" -CONFIG.spark_url = "xxxx" - prompt_msg = "who are you" resp_content = "I'm Spark" @@ -28,7 +22,7 @@ class MockWebSocketApp(object): def test_get_msg_from_web(mocker): mocker.patch("websocket.WebSocketApp", MockWebSocketApp) - get_msg_from_web = GetMessageFromWeb(text=prompt_msg) + get_msg_from_web = GetMessageFromWeb(prompt_msg, config) assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx" ret = get_msg_from_web.run()