mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-10 16:22:37 +02:00
fix ruff check error
This commit is contained in:
parent
192c030281
commit
cb11ec7bc7
19 changed files with 283 additions and 257 deletions
|
|
@ -15,6 +15,7 @@ from metagpt.actions.design_api import WriteDesign
|
|||
from metagpt.actions.design_api_review import DesignReview
|
||||
from metagpt.actions.design_filenames import DesignFilenames
|
||||
from metagpt.actions.project_management import AssignTasks, WriteTasks
|
||||
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
|
||||
from metagpt.actions.run_code import RunCode
|
||||
from metagpt.actions.search_and_summarize import SearchAndSummarize
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
|
|
@ -22,11 +23,11 @@ from metagpt.actions.write_code_review import WriteCodeReview
|
|||
from metagpt.actions.write_prd import WritePRD
|
||||
from metagpt.actions.write_prd_review import WritePRDReview
|
||||
from metagpt.actions.write_test import WriteTest
|
||||
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
"""All types of Actions, used for indexing."""
|
||||
|
||||
ADD_REQUIREMENT = BossRequirement
|
||||
WRITE_PRD = WritePRD
|
||||
WRITE_PRD_REVIEW = WritePRDReview
|
||||
|
|
@ -44,3 +45,10 @@ class ActionType(Enum):
|
|||
COLLECT_LINKS = CollectLinks
|
||||
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
|
||||
CONDUCT_RESEARCH = ConductResearch
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ActionType",
|
||||
"Action",
|
||||
"ActionOutput",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,13 +5,13 @@
|
|||
@Author : alexanderwu
|
||||
@File : run_code.py
|
||||
"""
|
||||
import traceback
|
||||
import os
|
||||
import subprocess
|
||||
from typing import List, Tuple
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.logs import logger
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
Role: You are a senior development and qa engineer, your role is summarize the code running result.
|
||||
|
|
@ -55,6 +55,7 @@ standard output: {outs};
|
|||
standard errors: {errs};
|
||||
"""
|
||||
|
||||
|
||||
class RunCode(Action):
|
||||
def __init__(self, name="RunCode", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
|
|
@ -65,7 +66,7 @@ class RunCode(Action):
|
|||
# We will document_store the result in this dictionary
|
||||
namespace = {}
|
||||
exec(code, namespace)
|
||||
return namespace.get('result', ""), ""
|
||||
return namespace.get("result", ""), ""
|
||||
except Exception:
|
||||
# If there is an error in the code, return the error message
|
||||
return "", traceback.format_exc()
|
||||
|
|
@ -81,10 +82,12 @@ class RunCode(Action):
|
|||
# Modify the PYTHONPATH environment variable
|
||||
additional_python_paths = [working_directory] + additional_python_paths
|
||||
additional_python_paths = ":".join(additional_python_paths)
|
||||
env['PYTHONPATH'] = additional_python_paths + ':' + env.get('PYTHONPATH', '')
|
||||
env["PYTHONPATH"] = additional_python_paths + ":" + env.get("PYTHONPATH", "")
|
||||
|
||||
# Start the subprocess
|
||||
process = subprocess.Popen(command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
|
||||
process = subprocess.Popen(
|
||||
command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
|
||||
)
|
||||
|
||||
try:
|
||||
# Wait for the process to complete, with a timeout
|
||||
|
|
@ -93,7 +96,7 @@ class RunCode(Action):
|
|||
logger.info("The command did not complete within the given timeout.")
|
||||
process.kill() # Kill the process if it times out
|
||||
stdout, stderr = process.communicate()
|
||||
return stdout.decode('utf-8'), stderr.decode('utf-8')
|
||||
return stdout.decode("utf-8"), stderr.decode("utf-8")
|
||||
|
||||
async def run(
|
||||
self, code, mode="script", code_file_name="", test_code="", test_file_name="", command=[], **kwargs
|
||||
|
|
@ -108,11 +111,13 @@ class RunCode(Action):
|
|||
logger.info(f"{errs=}")
|
||||
|
||||
context = CONTEXT.format(
|
||||
code=code, code_file_name=code_file_name,
|
||||
test_code=test_code, test_file_name=test_file_name,
|
||||
code=code,
|
||||
code_file_name=code_file_name,
|
||||
test_code=test_code,
|
||||
test_file_name=test_file_name,
|
||||
command=" ".join(command),
|
||||
outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow
|
||||
errs=errs[:10000] # truncate errors to avoid token overflow
|
||||
outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow
|
||||
errs=errs[:10000], # truncate errors to avoid token overflow
|
||||
)
|
||||
|
||||
prompt = PROMPT_TEMPLATE.format(context=context)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
@Author : alexanderwu
|
||||
@File : write_test.py
|
||||
"""
|
||||
from metagpt.logs import logger
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
|
@ -29,6 +28,7 @@ you should correctly import the necessary classes based on these file locations!
|
|||
## {test_file_name}: Write test code with triple quoto. Do your best to implement THIS ONLY ONE FILE.
|
||||
"""
|
||||
|
||||
|
||||
class WriteTest(Action):
|
||||
def __init__(self, name="WriteTest", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
|
|
@ -43,7 +43,7 @@ class WriteTest(Action):
|
|||
code_to_test=code_to_test,
|
||||
test_file_name=test_file_name,
|
||||
source_file_path=source_file_path,
|
||||
workspace=workspace
|
||||
workspace=workspace,
|
||||
)
|
||||
code = await self.write_code(prompt)
|
||||
return code
|
||||
|
|
|
|||
|
|
@ -7,3 +7,5 @@
|
|||
"""
|
||||
|
||||
from metagpt.document_store.faiss_store import FaissStore
|
||||
|
||||
__all__ = ["FaissStore"]
|
||||
|
|
|
|||
|
|
@ -9,3 +9,8 @@
|
|||
from metagpt.memory.memory import Memory
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Memory",
|
||||
"LongTermMemory",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,12 +2,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the implement of Long-term memory
|
||||
|
||||
from typing import Iterable, Type
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class LongTermMemory(Memory):
|
||||
|
|
@ -27,10 +25,11 @@ class LongTermMemory(Memory):
|
|||
messages = self.memory_storage.recover_memory(role_id)
|
||||
self.rc = rc
|
||||
if not self.memory_storage.is_initialized:
|
||||
logger.warning(f'It may the first time to run Agent {role_id}, the long-term memory is empty')
|
||||
logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
|
||||
else:
|
||||
logger.warning(f'Agent {role_id} has existed memory storage with {len(messages)} messages '
|
||||
f'and has recovered them.')
|
||||
logger.warning(
|
||||
f"Agent {role_id} has existed memory storage with {len(messages)} messages " f"and has recovered them."
|
||||
)
|
||||
self.msg_from_recover = True
|
||||
self.add_batch(messages)
|
||||
self.msg_from_recover = False
|
||||
|
|
|
|||
|
|
@ -7,3 +7,6 @@
|
|||
"""
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
|
||||
|
||||
__all__ = ["OpenAIGPTAPI"]
|
||||
|
|
|
|||
|
|
@ -8,10 +8,21 @@
|
|||
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.roles.architect import Architect
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.roles.qa_engineer import QaEngineer
|
||||
from metagpt.roles.seacher import Searcher
|
||||
from metagpt.roles.sales import Sales
|
||||
from metagpt.roles.customer_service import CustomerService
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Role",
|
||||
"Architect",
|
||||
"ProjectManager",
|
||||
"Engineer",
|
||||
"QaEngineer",
|
||||
"Searcher",
|
||||
"Sales",
|
||||
"CustomerService",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,40 +6,44 @@
|
|||
@File : qa_engineer.py
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
|
||||
from metagpt.actions import WriteTest, WriteCode, WriteDesign, RunCode, DebugError
|
||||
from metagpt.actions import DebugError, RunCode, WriteCode, WriteDesign, WriteTest
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.utils.common import CodeParser, parse_recipient
|
||||
from metagpt.utils.special_tokens import MSG_SEP, FILENAME_CODE_SEP
|
||||
from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP
|
||||
|
||||
|
||||
class QaEngineer(Role):
|
||||
def __init__(self, name="Edward", profile="QaEngineer",
|
||||
goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs",
|
||||
constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain",
|
||||
test_round_allowed=5):
|
||||
def __init__(
|
||||
self,
|
||||
name="Edward",
|
||||
profile="QaEngineer",
|
||||
goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs",
|
||||
constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain",
|
||||
test_round_allowed=5,
|
||||
):
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
self._init_actions([WriteTest]) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates
|
||||
self._init_actions(
|
||||
[WriteTest]
|
||||
) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates
|
||||
self._watch([WriteCode, WriteTest, RunCode, DebugError])
|
||||
self.test_round = 0
|
||||
self.test_round_allowed = test_round_allowed
|
||||
|
||||
|
||||
@classmethod
|
||||
def parse_workspace(cls, system_design_msg: Message) -> str:
|
||||
if not system_design_msg.instruct_content:
|
||||
return system_design_msg.instruct_content.dict().get("Python package name")
|
||||
return CodeParser.parse_str(block="Python package name", text=system_design_msg.content)
|
||||
|
||||
|
||||
def get_workspace(self, return_proj_dir=True) -> Path:
|
||||
msg = self._rc.memory.get_by_action(WriteDesign)[-1]
|
||||
if not msg:
|
||||
return WORKSPACE_ROOT / 'src'
|
||||
return WORKSPACE_ROOT / "src"
|
||||
workspace = self.parse_workspace(msg)
|
||||
# project directory: workspace/{package_name}, which contains package source code folder, tests folder, resources folder, etc.
|
||||
if return_proj_dir:
|
||||
|
|
@ -48,49 +52,52 @@ class QaEngineer(Role):
|
|||
return WORKSPACE_ROOT / workspace / workspace
|
||||
|
||||
def write_file(self, filename: str, code: str):
|
||||
workspace = self.get_workspace() / 'tests'
|
||||
workspace = self.get_workspace() / "tests"
|
||||
file = workspace / filename
|
||||
file.parent.mkdir(parents=True, exist_ok=True)
|
||||
file.write_text(code)
|
||||
|
||||
async def _write_test(self, message: Message) -> None:
|
||||
|
||||
code_msgs = message.content.split(MSG_SEP)
|
||||
result_msg_all = []
|
||||
# result_msg_all = []
|
||||
for code_msg in code_msgs:
|
||||
|
||||
# write tests
|
||||
file_name, file_path = code_msg.split(FILENAME_CODE_SEP)
|
||||
code_to_test = open(file_path, "r").read()
|
||||
if "test" in file_name:
|
||||
continue # Engineer might write some test files, skip testing a test file
|
||||
continue # Engineer might write some test files, skip testing a test file
|
||||
test_file_name = "test_" + file_name
|
||||
test_file_path = self.get_workspace() / "tests" / test_file_name
|
||||
logger.info(f'Writing {test_file_name}..')
|
||||
logger.info(f"Writing {test_file_name}..")
|
||||
test_code = await WriteTest().run(
|
||||
code_to_test=code_to_test,
|
||||
test_file_name=test_file_name,
|
||||
# source_file_name=file_name,
|
||||
source_file_path=file_path,
|
||||
workspace=self.get_workspace()
|
||||
workspace=self.get_workspace(),
|
||||
)
|
||||
self.write_file(test_file_name, test_code)
|
||||
|
||||
# prepare context for run tests in next round
|
||||
command = ['python', f'tests/{test_file_name}']
|
||||
command = ["python", f"tests/{test_file_name}"]
|
||||
file_info = {
|
||||
"file_name": file_name, "file_path": str(file_path),
|
||||
"test_file_name": test_file_name, "test_file_path": str(test_file_path),
|
||||
"command": command
|
||||
"file_name": file_name,
|
||||
"file_path": str(file_path),
|
||||
"test_file_name": test_file_name,
|
||||
"test_file_path": str(test_file_path),
|
||||
"command": command,
|
||||
}
|
||||
msg = Message(
|
||||
content=str(file_info), role=self.profile, cause_by=WriteTest,
|
||||
sent_from=self.profile, send_to=self.profile
|
||||
content=str(file_info),
|
||||
role=self.profile,
|
||||
cause_by=WriteTest,
|
||||
sent_from=self.profile,
|
||||
send_to=self.profile,
|
||||
)
|
||||
self._publish_message(msg)
|
||||
|
||||
logger.info(f'Done {self.get_workspace()}/tests generating.')
|
||||
|
||||
|
||||
logger.info(f"Done {self.get_workspace()}/tests generating.")
|
||||
|
||||
async def _run_code(self, msg):
|
||||
file_info = eval(msg.content)
|
||||
development_file_path = file_info["file_path"]
|
||||
|
|
@ -110,17 +117,14 @@ class QaEngineer(Role):
|
|||
test_code=test_code,
|
||||
test_file_name=file_info["test_file_name"],
|
||||
command=file_info["command"],
|
||||
working_directory=proj_dir, # workspace/package_name, will run tests/test_xxx.py here
|
||||
additional_python_paths=[development_code_dir], # workspace/package_name/package_name,
|
||||
# import statement inside package code needs this
|
||||
working_directory=proj_dir, # workspace/package_name, will run tests/test_xxx.py here
|
||||
additional_python_paths=[development_code_dir], # workspace/package_name/package_name,
|
||||
# import statement inside package code needs this
|
||||
)
|
||||
|
||||
recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself
|
||||
recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself
|
||||
content = str(file_info) + FILENAME_CODE_SEP + result_msg
|
||||
msg = Message(
|
||||
content=content, role=self.profile, cause_by=RunCode,
|
||||
sent_from=self.profile, send_to=recipient
|
||||
)
|
||||
msg = Message(content=content, role=self.profile, cause_by=RunCode, sent_from=self.profile, send_to=recipient)
|
||||
self._publish_message(msg)
|
||||
|
||||
async def _debug_error(self, msg):
|
||||
|
|
@ -128,21 +132,27 @@ class QaEngineer(Role):
|
|||
file_name, code = await DebugError().run(context)
|
||||
if file_name:
|
||||
self.write_file(file_name, code)
|
||||
recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self
|
||||
msg = Message(content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient)
|
||||
recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self
|
||||
msg = Message(
|
||||
content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient
|
||||
)
|
||||
self._publish_message(msg)
|
||||
|
||||
|
||||
async def _observe(self) -> int:
|
||||
await super()._observe()
|
||||
self._rc.news = [msg for msg in self._rc.news \
|
||||
if msg.send_to == self.profile] # only relevant msgs count as observed news
|
||||
self._rc.news = [
|
||||
msg for msg in self._rc.news if msg.send_to == self.profile
|
||||
] # only relevant msgs count as observed news
|
||||
return len(self._rc.news)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
if self.test_round > self.test_round_allowed:
|
||||
result_msg = Message(
|
||||
content=f"Exceeding {self.test_round_allowed} rounds of tests, skip (writing code counts as a round, too)",
|
||||
role=self.profile, cause_by=WriteTest, sent_from=self.profile, send_to=""
|
||||
role=self.profile,
|
||||
cause_by=WriteTest,
|
||||
sent_from=self.profile,
|
||||
send_to="",
|
||||
)
|
||||
return result_msg
|
||||
|
||||
|
|
@ -161,6 +171,9 @@ class QaEngineer(Role):
|
|||
self.test_round += 1
|
||||
result_msg = Message(
|
||||
content=f"Round {self.test_round} of tests done",
|
||||
role=self.profile, cause_by=WriteTest, sent_from=self.profile, send_to=""
|
||||
role=self.profile,
|
||||
cause_by=WriteTest,
|
||||
sent_from=self.profile,
|
||||
send_to="",
|
||||
)
|
||||
return result_msg
|
||||
|
|
|
|||
|
|
@ -2,29 +2,27 @@
|
|||
# @Date : 2023/7/19 16:28
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import os
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
from typing import List
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import Config
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
|
||||
config = Config()
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
|
||||
"override_settings": {
|
||||
"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"
|
||||
},
|
||||
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
|
||||
"seed": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
|
|
@ -36,21 +34,20 @@ payload = {
|
|||
"tiling": False,
|
||||
"do_not_save_samples": False,
|
||||
"do_not_save_grid": False,
|
||||
'enable_hr': False,
|
||||
'hr_scale': 2,
|
||||
'hr_upscaler': 'Latent',
|
||||
'hr_second_pass_steps': 0,
|
||||
'hr_resize_x': 0,
|
||||
'hr_resize_y': 0,
|
||||
'hr_upscale_to_x': 0,
|
||||
'hr_upscale_to_y': 0,
|
||||
'truncate_x': 0,
|
||||
'truncate_y': 0,
|
||||
'applied_old_hires_behavior_to': None,
|
||||
"enable_hr": False,
|
||||
"hr_scale": 2,
|
||||
"hr_upscaler": "Latent",
|
||||
"hr_second_pass_steps": 0,
|
||||
"hr_resize_x": 0,
|
||||
"hr_resize_y": 0,
|
||||
"hr_upscale_to_x": 0,
|
||||
"hr_upscale_to_y": 0,
|
||||
"truncate_x": 0,
|
||||
"truncate_y": 0,
|
||||
"applied_old_hires_behavior_to": None,
|
||||
"eta": None,
|
||||
|
||||
"sampler_index": "DPM++ SDE Karras",
|
||||
"alwayson_scripts": {}
|
||||
"alwayson_scripts": {},
|
||||
}
|
||||
|
||||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
|
@ -60,14 +57,20 @@ class SDEngine:
|
|||
def __init__(self):
|
||||
# Initialize the SDEngine with configuration
|
||||
self.config = Config()
|
||||
self.sd_url = self.config.get('SD_URL')
|
||||
self.sd_url = self.config.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20"):
|
||||
|
||||
def construct_payload(
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
):
|
||||
# Configure the payload with provided inputs
|
||||
self.payload["prompt"] = prompt
|
||||
self.payload["negtive_prompt"] = negtive_prompt
|
||||
|
|
@ -76,13 +79,13 @@ class SDEngine:
|
|||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
save_dir = WORKSPACE_ROOT / "resources"/"SD_Output"
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
|
||||
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
session = ClientSession()
|
||||
|
|
@ -90,25 +93,26 @@ class SDEngine:
|
|||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self._save(results, save_name=f"output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json['images']
|
||||
imgs = rsp_json["images"]
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: 添加图生图接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
async def run_sam(self):
|
||||
# todo:添加SAM接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def decode_base64_to_image(img, save_name):
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
|
@ -124,12 +128,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
engine = SDEngine()
|
||||
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
|
||||
|
||||
|
||||
engine.construct_payload(prompt)
|
||||
|
||||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(engine.run_t2i(prompt))
|
||||
|
|
|
|||
|
|
@ -13,3 +13,12 @@ from metagpt.utils.token_counter import (
|
|||
count_message_tokens,
|
||||
count_string_tokens,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"read_docx",
|
||||
"Singleton",
|
||||
"TOKEN_COSTS",
|
||||
"count_message_tokens",
|
||||
"count_string_tokens",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@
|
|||
@Author : alexanderwu
|
||||
@File : mermaid.py
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PROJECT_ROOT
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -24,25 +24,36 @@ def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height
|
|||
:return: 0 if succed, -1 if failed
|
||||
"""
|
||||
# Write the Mermaid code to a temporary file
|
||||
tmp = Path(f'{output_file_without_suffix}.mmd')
|
||||
tmp.write_text(mermaid_code, encoding='utf-8')
|
||||
tmp = Path(f"{output_file_without_suffix}.mmd")
|
||||
tmp.write_text(mermaid_code, encoding="utf-8")
|
||||
|
||||
if check_cmd_exists('mmdc') != 0:
|
||||
logger.warning(
|
||||
"RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc")
|
||||
if check_cmd_exists("mmdc") != 0:
|
||||
logger.warning("RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc")
|
||||
return -1
|
||||
|
||||
for suffix in ['pdf', 'svg', 'png']:
|
||||
output_file = f'{output_file_without_suffix}.{suffix}'
|
||||
for suffix in ["pdf", "svg", "png"]:
|
||||
output_file = f"{output_file_without_suffix}.{suffix}"
|
||||
# Call the `mmdc` command to convert the Mermaid code to a PNG
|
||||
logger.info(f"Generating {output_file}..")
|
||||
|
||||
if CONFIG.puppeteer_config:
|
||||
subprocess.run([CONFIG.mmdc, '-p', CONFIG.puppeteer_config, '-i', str(tmp), '-o',
|
||||
output_file, '-w', str(width), '-H', str(height)])
|
||||
subprocess.run(
|
||||
[
|
||||
CONFIG.mmdc,
|
||||
"-p",
|
||||
CONFIG.puppeteer_config,
|
||||
"-i",
|
||||
str(tmp),
|
||||
"-o",
|
||||
output_file,
|
||||
"-w",
|
||||
str(width),
|
||||
"-H",
|
||||
str(height),
|
||||
]
|
||||
)
|
||||
else:
|
||||
subprocess.run([CONFIG.mmdc, '-i', str(tmp), '-o',
|
||||
output_file, '-w', str(width), '-H', str(height)])
|
||||
subprocess.run([CONFIG.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)])
|
||||
return 0
|
||||
|
||||
|
||||
|
|
@ -97,7 +108,7 @@ MMC2 = """sequenceDiagram
|
|||
SE-->>M: return summary"""
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# logger.info(print_members(print_members))
|
||||
mermaid_to_file(MMC1, PROJECT_ROOT / 'tmp/1.png')
|
||||
mermaid_to_file(MMC2, PROJECT_ROOT / 'tmp/2.png')
|
||||
mermaid_to_file(MMC1, PROJECT_ROOT / "tmp/1.png")
|
||||
mermaid_to_file(MMC2, PROJECT_ROOT / "tmp/2.png")
|
||||
|
|
|
|||
|
|
@ -3,14 +3,11 @@
|
|||
# @Desc : the implement of serialization and deserialization
|
||||
|
||||
import copy
|
||||
from typing import Tuple, List, Type, Union, Dict
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pydantic import create_model
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
|
||||
|
|
@ -34,12 +31,12 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
|
|||
```
|
||||
"""
|
||||
mapping = dict()
|
||||
for field, property in schema['properties'].items():
|
||||
if property['type'] == 'string':
|
||||
for field, property in schema["properties"].items():
|
||||
if property["type"] == "string":
|
||||
mapping[field] = (str, ...)
|
||||
elif property['type'] == 'array' and property['items']['type'] == 'string':
|
||||
elif property["type"] == "array" and property["items"]["type"] == "string":
|
||||
mapping[field] = (List[str], ...)
|
||||
elif property['type'] == 'array' and property['items']['type'] == 'array':
|
||||
elif property["type"] == "array" and property["items"]["type"] == "array":
|
||||
# here only consider the `Tuple[str, str]` situation
|
||||
mapping[field] = (List[Tuple[str, str]], ...)
|
||||
return mapping
|
||||
|
|
@ -53,11 +50,7 @@ def serialize_message(message: Message):
|
|||
schema = ic.schema()
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
|
||||
message_cp.instruct_content = {
|
||||
'class': schema['title'],
|
||||
'mapping': mapping,
|
||||
'value': ic.dict()
|
||||
}
|
||||
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
|
||||
msg_ser = pickle.dumps(message_cp)
|
||||
|
||||
return msg_ser
|
||||
|
|
@ -67,9 +60,8 @@ def deserialize_message(message_ser: str) -> Message:
|
|||
message = pickle.loads(message_ser)
|
||||
if message.instruct_content:
|
||||
ic = message.instruct_content
|
||||
ic_obj = ActionOutput.create_model_class(class_name=ic['class'],
|
||||
mapping=ic['mapping'])
|
||||
ic_new = ic_obj(**ic['value'])
|
||||
ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
|
||||
ic_new = ic_obj(**ic["value"])
|
||||
message.instruct_content = ic_new
|
||||
|
||||
return message
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ target-version = ['py39']
|
|||
[tool.ruff]
|
||||
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
|
||||
select = ["E", "F"]
|
||||
ignore = ["E501", "E712", "E722", "F821"]
|
||||
ignore = ["E501", "E712", "E722", "F821", "E731"]
|
||||
|
||||
# Allow autofix for all enabled rules (when `--fix`) is provided.
|
||||
fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
|
||||
|
|
|
|||
|
|
@ -6,24 +6,23 @@
|
|||
@File : test_run_code.py
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
from metagpt.actions.run_code import RunCode
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_text():
|
||||
action = RunCode()
|
||||
result, errs = await RunCode.run_text('result = 1 + 1')
|
||||
result, errs = await RunCode.run_text("result = 1 + 1")
|
||||
assert result == 2
|
||||
assert errs == ""
|
||||
|
||||
result, errs = await RunCode.run_text('result = 1 / 0')
|
||||
result, errs = await RunCode.run_text("result = 1 / 0")
|
||||
assert result == ""
|
||||
assert "ZeroDivisionError" in errs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_script():
|
||||
action = RunCode()
|
||||
|
||||
# Successful command
|
||||
out, err = await RunCode.run_script(".", command=["echo", "Hello World"])
|
||||
assert out.strip() == "Hello World"
|
||||
|
|
@ -33,6 +32,7 @@ async def test_run_script():
|
|||
out, err = await RunCode.run_script(".", command=["python", "-c", "print(1/0)"])
|
||||
assert "ZeroDivisionError" in err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
action = RunCode()
|
||||
|
|
@ -47,10 +47,11 @@ async def test_run():
|
|||
test_file_name="",
|
||||
command=["echo", "Hello World"],
|
||||
working_directory=".",
|
||||
additional_python_paths=[]
|
||||
additional_python_paths=[],
|
||||
)
|
||||
assert "PASS" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_failure():
|
||||
action = RunCode()
|
||||
|
|
@ -65,6 +66,6 @@ async def test_run_failure():
|
|||
test_file_name="",
|
||||
command=["python", "-c", "print(1/0)"],
|
||||
working_directory=".",
|
||||
additional_python_paths=[]
|
||||
additional_python_paths=[],
|
||||
)
|
||||
assert "FAIL" in result
|
||||
assert "FAIL" in result
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.write_code_review import WriteCodeReview
|
||||
from metagpt.logs import logger
|
||||
from tests.metagpt.actions.mock import SEARCH_CODE_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -20,11 +18,7 @@ def add(a, b):
|
|||
"""
|
||||
# write_code_review = WriteCodeReview("write_code_review")
|
||||
|
||||
code = await WriteCodeReview().run(
|
||||
context="编写一个从a加b的函数,返回a+b",
|
||||
code=code,
|
||||
filename="math.py"
|
||||
)
|
||||
code = await WriteCodeReview().run(context="编写一个从a加b的函数,返回a+b", code=code, filename="math.py")
|
||||
|
||||
# 我们不能精确地预测生成的代码评审,但我们可以检查返回的是否为字符串
|
||||
assert isinstance(code, str)
|
||||
|
|
@ -33,6 +27,7 @@ def add(a, b):
|
|||
captured = capfd.readouterr()
|
||||
print(f"输出内容: {captured.out}")
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_write_code_review_directly():
|
||||
# code = SEARCH_CODE_SAMPLE
|
||||
|
|
|
|||
|
|
@ -2,22 +2,19 @@
|
|||
# @Date : 2023/7/15 16:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import re
|
||||
import os
|
||||
from importlib import import_module
|
||||
import re
|
||||
from functools import wraps
|
||||
from importlib import import_module
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.roles import ProductManager, Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import Action, ActionOutput, WritePRD
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.software_company import SoftwareCompany
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools.sd_engine import SDEngine
|
||||
|
||||
PROMPT_TEMPLATE = '''
|
||||
PROMPT_TEMPLATE = """
|
||||
# Context
|
||||
{context}
|
||||
|
||||
|
|
@ -34,9 +31,9 @@ Attention: Use '##' to split sections, not '#', and '## <SECTION_NAME>' SHOULD W
|
|||
## CSS Styles (styles.css):Provide as Plain text,use standard css code
|
||||
## Anything UNCLEAR:Provide as Plain text. Make clear here.
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
FORMAT_EXAMPLE = '''
|
||||
FORMAT_EXAMPLE = """
|
||||
|
||||
## UI Design Description
|
||||
```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ```
|
||||
|
|
@ -126,7 +123,7 @@ body {
|
|||
## Anything UNCLEAR
|
||||
There are no unclear points.
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
OUTPUT_MAPPING = {
|
||||
"UI Design Description": (str, ...),
|
||||
|
|
@ -139,25 +136,25 @@ OUTPUT_MAPPING = {
|
|||
|
||||
def load_engine(func):
|
||||
"""Decorator to load an engine by file name and engine name."""
|
||||
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
file_name, engine_name = func(*args, **kwargs)
|
||||
engine_file = import_module(file_name, package='metagpt')
|
||||
engine_file = import_module(file_name, package="metagpt")
|
||||
ip_module_cls = getattr(engine_file, engine_name)
|
||||
try:
|
||||
engine = ip_module_cls()
|
||||
except:
|
||||
engine = None
|
||||
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def parse(func):
|
||||
"""Decorator to parse information using regex pattern."""
|
||||
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
context, pattern = func(*args, **kwargs)
|
||||
|
|
@ -168,30 +165,30 @@ def parse(func):
|
|||
else:
|
||||
text_info = context
|
||||
logger.info("未找到匹配的内容")
|
||||
|
||||
|
||||
return text_info
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class UIDesign(Action):
|
||||
"""Class representing the UI Design action."""
|
||||
|
||||
|
||||
def __init__(self, name, context=None, llm=None):
|
||||
super().__init__(name, context, llm) # 需要调用LLM进一步丰富UI设计的prompt
|
||||
|
||||
|
||||
@parse
|
||||
def parse_requirement(self, context: str):
|
||||
"""Parse UI Design draft from the context using regex."""
|
||||
pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR"
|
||||
return context, pattern
|
||||
|
||||
|
||||
@parse
|
||||
def parse_ui_elements(self, context: str):
|
||||
"""Parse Selected Elements from the context using regex."""
|
||||
pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout"
|
||||
return context, pattern
|
||||
|
||||
|
||||
@parse
|
||||
def parse_css_code(self, context: str):
|
||||
pattern = r"```css.*?\n(.*?)## Anything UNCLEAR"
|
||||
|
|
@ -201,7 +198,7 @@ class UIDesign(Action):
|
|||
def parse_html_code(self, context: str):
|
||||
pattern = r"```html.*?\n(.*?)```"
|
||||
return context, pattern
|
||||
|
||||
|
||||
async def draw_icons(self, context, *args, **kwargs):
|
||||
"""Draw icons using SDEngine."""
|
||||
engine = SDEngine()
|
||||
|
|
@ -215,20 +212,20 @@ class UIDesign(Action):
|
|||
prompts_batch.append(prompt)
|
||||
await engine.run_t2i(prompts_batch)
|
||||
logger.info("Finish icon design using StableDiffusion API")
|
||||
|
||||
|
||||
async def _save(self, css_content, html_content):
|
||||
save_dir = WORKSPACE_ROOT / "resources" / 'codes'
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "codes"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# Save CSS and HTML content to files
|
||||
css_file_path = save_dir / f"ui_design.css"
|
||||
html_file_path = save_dir / f"ui_design.html"
|
||||
|
||||
with open(css_file_path, 'w') as css_file:
|
||||
css_file_path = save_dir / "ui_design.css"
|
||||
html_file_path = save_dir / "ui_design.html"
|
||||
|
||||
with open(css_file_path, "w") as css_file:
|
||||
css_file.write(css_content)
|
||||
with open(html_file_path, 'w') as html_file:
|
||||
with open(html_file_path, "w") as html_file:
|
||||
html_file.write(html_content)
|
||||
|
||||
|
||||
async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput:
|
||||
"""Run the UI Design action."""
|
||||
# fixme: update prompt (根据需求细化prompt)
|
||||
|
|
@ -249,23 +246,27 @@ class UIDesign(Action):
|
|||
|
||||
class UI(Role):
|
||||
"""Class representing the UI Role."""
|
||||
|
||||
def __init__(self, name="Catherine", profile="UI Design",
|
||||
goal="Finish a workable and good User Interface design based on a product design",
|
||||
constraints="Give clear layout description and use standard icons to finish the design",
|
||||
skills=["SD"]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name="Catherine",
|
||||
profile="UI Design",
|
||||
goal="Finish a workable and good User Interface design based on a product design",
|
||||
constraints="Give clear layout description and use standard icons to finish the design",
|
||||
skills=["SD"],
|
||||
):
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
self.load_skills(skills)
|
||||
self._init_actions([UIDesign])
|
||||
self._watch([WritePRD])
|
||||
|
||||
|
||||
@load_engine
|
||||
def load_sd_engine(self):
|
||||
"""Load the SDEngine."""
|
||||
file_name = ".tools.sd_engine"
|
||||
engine_name = "SDEngine"
|
||||
return file_name, engine_name
|
||||
|
||||
|
||||
def load_skills(self, skills):
|
||||
"""Load skills for the UI Role."""
|
||||
# todo: 添加其他出图engine
|
||||
|
|
@ -273,4 +274,3 @@ class UI(Role):
|
|||
if skill == "SD":
|
||||
self.sd_engine = self.load_sd_engine()
|
||||
logger.info(f"load skill engine {self.sd_engine}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
from metagpt.config import Config
|
||||
from metagpt.tools import web_browser_engine, WebBrowserEngineType
|
||||
|
||||
from metagpt.tools import WebBrowserEngineType, web_browser_engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -3,94 +3,64 @@
|
|||
# @Desc : the unittest of serialize
|
||||
|
||||
from typing import List, Tuple
|
||||
import pytest
|
||||
|
||||
from pydantic import create_model
|
||||
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import actionoutout_schema_to_mapping, serialize_message, deserialize_message
|
||||
from metagpt.utils.serialize import (
|
||||
actionoutout_schema_to_mapping,
|
||||
deserialize_message,
|
||||
serialize_message,
|
||||
)
|
||||
|
||||
|
||||
def test_actionoutout_schema_to_mapping():
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'string'
|
||||
}
|
||||
}
|
||||
}
|
||||
schema = {"title": "test", "type": "object", "properties": {"field": {"title": "field", "type": "string"}}}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (str, ...)
|
||||
assert mapping["field"] == (str, ...)
|
||||
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'string'
|
||||
}
|
||||
}
|
||||
}
|
||||
"title": "test",
|
||||
"type": "object",
|
||||
"properties": {"field": {"title": "field", "type": "array", "items": {"type": "string"}}},
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (List[str], ...)
|
||||
assert mapping["field"] == (List[str], ...)
|
||||
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'array',
|
||||
'minItems': 2,
|
||||
'maxItems': 2,
|
||||
'items': [
|
||||
{
|
||||
'type': 'string'
|
||||
},
|
||||
{
|
||||
'type': 'string'
|
||||
}
|
||||
]
|
||||
}
|
||||
"title": "test",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"title": "field",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "array",
|
||||
"minItems": 2,
|
||||
"maxItems": 2,
|
||||
"items": [{"type": "string"}, {"type": "string"}],
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (List[Tuple[str, str]], ...)
|
||||
assert mapping["field"] == (List[Tuple[str, str]], ...)
|
||||
|
||||
assert True, True
|
||||
|
||||
|
||||
def test_serialize_and_deserialize_message():
|
||||
out_mapping = {
|
||||
'field1': (str, ...),
|
||||
'field2': (List[str], ...)
|
||||
}
|
||||
out_data = {
|
||||
'field1': 'field1 value',
|
||||
'field2': ['field2 value1', 'field2 value2']
|
||||
}
|
||||
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
|
||||
|
||||
message = Message(content='prd demand',
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD) # WritePRD as test action
|
||||
message = Message(
|
||||
content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
||||
message_ser = serialize_message(message)
|
||||
|
||||
new_message = deserialize_message(message_ser)
|
||||
assert new_message.content == message.content
|
||||
assert new_message.cause_by == message.cause_by
|
||||
assert new_message.instruct_content.field1 == out_data['field1']
|
||||
assert new_message.instruct_content.field1 == out_data["field1"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue