Merge pull request #502 from iorisa/refactor/pre-commit_shenquan

refactor: pre-commit run --all-files
This commit is contained in:
Sirui Hong 2023-11-22 16:41:52 +08:00 committed by GitHub
commit d269d18ed0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
130 changed files with 813 additions and 831 deletions

View file

@ -1,22 +1,22 @@
'''
"""
Filename: MetaGPT/examples/agent_creator.py
Created Date: Tuesday, September 12th 2023, 3:28:37 pm
Author: garylin2099
'''
"""
import re
from metagpt.const import PROJECT_ROOT, WORKSPACE_ROOT
from metagpt.actions import Action
from metagpt.const import PROJECT_ROOT, WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.logs import logger
with open(PROJECT_ROOT / "examples/build_customized_agent.py", "r") as f:
# use official example script to guide AgentCreator
MULTI_ACTION_AGENT_CODE_EXAMPLE = f.read()
class CreateAgent(Action):
class CreateAgent(Action):
PROMPT_TEMPLATE = """
### BACKGROUND
You are using an agent framework called metagpt to write agents capable of different actions,
@ -34,7 +34,6 @@ class CreateAgent(Action):
"""
async def run(self, example: str, instruction: str):
prompt = self.PROMPT_TEMPLATE.format(example=example, instruction=instruction)
# logger.info(prompt)
@ -46,13 +45,14 @@ class CreateAgent(Action):
@staticmethod
def parse_code(rsp):
pattern = r'```python(.*)```'
pattern = r"```python(.*)```"
match = re.search(pattern, rsp, re.DOTALL)
code_text = match.group(1) if match else ""
with open(WORKSPACE_ROOT / "agent_created_agent.py", "w") as f:
f.write(code_text)
return code_text
class AgentCreator(Role):
def __init__(
self,
@ -76,11 +76,11 @@ class AgentCreator(Role):
return msg
if __name__ == "__main__":
import asyncio
async def main():
agent_template = MULTI_ACTION_AGENT_CODE_EXAMPLE
creator = AgentCreator(agent_template=agent_template)

View file

@ -1,21 +1,21 @@
'''
"""
Filename: MetaGPT/examples/build_customized_agent.py
Created Date: Tuesday, September 19th 2023, 6:52:25 pm
Author: garylin2099
'''
"""
import asyncio
import re
import subprocess
import asyncio
import fire
from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.logs import logger
class SimpleWriteCode(Action):
PROMPT_TEMPLATE = """
Write a python function that can {instruction} and provide two runnnable test cases.
Return ```python your_code_here ``` with NO other texts,
@ -35,7 +35,6 @@ class SimpleWriteCode(Action):
super().__init__(name, context, llm)
async def run(self, instruction: str):
prompt = self.PROMPT_TEMPLATE.format(instruction=instruction)
rsp = await self._aask(prompt)
@ -46,11 +45,12 @@ class SimpleWriteCode(Action):
@staticmethod
def parse_code(rsp):
pattern = r'```python(.*)```'
pattern = r"```python(.*)```"
match = re.search(pattern, rsp, re.DOTALL)
code_text = match.group(1) if match else rsp
return code_text
class SimpleRunCode(Action):
def __init__(self, name="SimpleRunCode", context=None, llm=None):
super().__init__(name, context, llm)
@ -61,6 +61,7 @@ class SimpleRunCode(Action):
logger.info(f"{code_result=}")
return code_result
class SimpleCoder(Role):
def __init__(
self,
@ -75,7 +76,7 @@ class SimpleCoder(Role):
logger.info(f"{self._setting}: ready to {self._rc.todo}")
todo = self._rc.todo
msg = self._rc.memory.get()[-1] # retrieve the latest memory
msg = self._rc.memory.get()[-1] # retrieve the latest memory
instruction = msg.content
code_text = await SimpleWriteCode().run(instruction)
@ -83,6 +84,7 @@ class SimpleCoder(Role):
return msg
class RunnableCoder(Role):
def __init__(
self,
@ -128,6 +130,7 @@ class RunnableCoder(Role):
await self._act()
return Message(content="All job done", role=self.profile)
def main(msg="write a function that calculates the sum of a list"):
# role = SimpleCoder()
role = RunnableCoder()
@ -135,5 +138,6 @@ def main(msg="write a function that calculates the sum of a list"):
result = asyncio.run(role.run(msg))
logger.info(result)
if __name__ == '__main__':
if __name__ == "__main__":
fire.Fire(main)

View file

@ -1,17 +1,19 @@
'''
"""
Filename: MetaGPT/examples/debate.py
Created Date: Tuesday, September 19th 2023, 6:52:25 pm
Author: garylin2099
'''
"""
import asyncio
import platform
import fire
from metagpt.software_company import SoftwareCompany
from metagpt.actions import Action, BossRequirement
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.logs import logger
from metagpt.software_company import SoftwareCompany
class ShoutOut(Action):
"""Action: Shout out loudly in a debate (quarrel)"""
@ -31,7 +33,6 @@ class ShoutOut(Action):
super().__init__(name, context, llm)
async def run(self, context: str, name: str, opponent_name: str):
prompt = self.PROMPT_TEMPLATE.format(context=context, name=name, opponent_name=opponent_name)
# logger.info(prompt)
@ -39,6 +40,7 @@ class ShoutOut(Action):
return rsp
class Trump(Role):
def __init__(
self,
@ -55,7 +57,7 @@ class Trump(Role):
async def _observe(self) -> int:
await super()._observe()
# accept messages sent (from opponent) to self, disregard own messages from the last round
self._rc.news = [msg for msg in self._rc.news if msg.send_to == self.name]
self._rc.news = [msg for msg in self._rc.news if msg.send_to == self.name]
return len(self._rc.news)
async def _act(self) -> Message:
@ -79,6 +81,7 @@ class Trump(Role):
return msg
class Biden(Role):
def __init__(
self,
@ -120,10 +123,12 @@ class Biden(Role):
return msg
async def startup(idea: str, investment: float = 3.0, n_round: int = 5,
code_review: bool = False, run_tests: bool = False):
async def startup(
idea: str, investment: float = 3.0, n_round: int = 5, code_review: bool = False, run_tests: bool = False
):
"""We reuse the startup paradigm for roles to interact with each other.
Now we run a startup of presidents and watch they quarrel. :) """
Now we run a startup of presidents and watch they quarrel. :)"""
company = SoftwareCompany()
company.hire([Biden(), Trump()])
company.invest(investment)
@ -133,7 +138,7 @@ async def startup(idea: str, investment: float = 3.0, n_round: int = 5,
def main(idea: str, investment: float = 3.0, n_round: int = 10):
"""
:param idea: Debate topic, such as "Topic: The U.S. should commit more in climate change fighting"
:param idea: Debate topic, such as "Topic: The U.S. should commit more in climate change fighting"
or "Trump: Climate change is a hoax"
:param investment: contribute a certain dollar amount to watch the debate
:param n_round: maximum rounds of the debate
@ -144,5 +149,5 @@ def main(idea: str, investment: float = 3.0, n_round: int = 10):
asyncio.run(startup(idea, investment, n_round))
if __name__ == '__main__':
if __name__ == "__main__":
fire.Fire(main)

View file

@ -19,19 +19,15 @@ async def main():
Path("../tests/data/invoices/invoice-1.pdf"),
Path("../tests/data/invoices/invoice-2.png"),
Path("../tests/data/invoices/invoice-3.jpg"),
Path("../tests/data/invoices/invoice-4.zip")
Path("../tests/data/invoices/invoice-4.zip"),
]
# The absolute path of the file
absolute_file_paths = [Path.cwd() / path for path in relative_paths]
for path in absolute_file_paths:
role = InvoiceOCRAssistant()
await role.run(Message(
content="Invoicing date",
instruct_content={"file_path": path}
))
await role.run(Message(content="Invoicing date", instruct_content={"file_path": path}))
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())

View file

@ -14,11 +14,11 @@ from metagpt.logs import logger
async def main():
llm = LLM()
claude = Claude()
logger.info(await claude.aask('你好,请进行自我介绍'))
logger.info(await llm.aask('hello world'))
logger.info(await llm.aask_batch(['hi', 'write python hello world.']))
logger.info(await claude.aask("你好,请进行自我介绍"))
logger.info(await llm.aask("hello world"))
logger.info(await llm.aask_batch(["hi", "write python hello world."]))
hello_msg = [{'role': 'user', 'content': 'count from 1 to 10. split by newline.'}]
hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]
logger.info(await llm.acompletion(hello_msg))
logger.info(await llm.acompletion_batch([hello_msg]))
logger.info(await llm.acompletion_batch_text([hello_msg]))
@ -27,5 +27,5 @@ async def main():
await llm.acompletion_text(hello_msg, stream=True)
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())

View file

@ -12,5 +12,5 @@ async def main():
print(f"save report to {RESEARCH_PATH / f'{topic}.md'}.")
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())

View file

@ -15,5 +15,5 @@ async def main():
await Searcher().run("What are some good sun protection products?")
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())

View file

@ -12,7 +12,7 @@ from metagpt.roles import Sales
async def search():
store = FaissStore(DATA_PATH / 'example.json')
store = FaissStore(DATA_PATH / "example.json")
role = Sales(profile="Sales", store=store)
queries = ["Which facial cleanser is good for oily skin?", "Is L'Oreal good to use?"]
@ -22,5 +22,5 @@ async def search():
logger.info(result)
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(search())

View file

@ -6,11 +6,12 @@ from metagpt.tools import SearchEngineType
async def main():
# Serper API
#await Searcher(engine = SearchEngineType.SERPER_GOOGLE).run(["What are some good sun protection products?","What are some of the best beaches?"])
# await Searcher(engine = SearchEngineType.SERPER_GOOGLE).run(["What are some good sun protection products?","What are some of the best beaches?"])
# SerpAPI
#await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run("What are the best ski brands for skiers?")
# await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run("What are the best ski brands for skiers?")
# Google API
await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run("What are the most interesting human facts?")
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())

View file

@ -1,12 +1,13 @@
'''
"""
Filename: MetaGPT/examples/use_off_the_shelf_agent.py
Created Date: Tuesday, September 19th 2023, 6:52:25 pm
Author: garylin2099
'''
"""
import asyncio
from metagpt.roles.product_manager import ProductManager
from metagpt.logs import logger
from metagpt.roles.product_manager import ProductManager
async def main():
msg = "Write a PRD for a snake game"
@ -14,5 +15,6 @@ async def main():
result = await role.run(msg)
logger.info(result.content[:100])
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())

View file

@ -16,6 +16,5 @@ async def main():
await role.run(topic)
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main())

View file

@ -23,10 +23,10 @@ class ActionOutput:
def create_model_class(cls, class_name: str, mapping: Dict[str, Type]):
new_class = create_model(class_name, **mapping)
@validator('*', allow_reuse=True)
@validator("*", allow_reuse=True)
def check_name(v, field):
if field.name not in mapping.keys():
raise ValueError(f'Unrecognized block: {field.name}')
raise ValueError(f"Unrecognized block: {field.name}")
return v
@root_validator(pre=True, allow_reuse=True)
@ -34,10 +34,9 @@ class ActionOutput:
required_fields = set(mapping.keys())
missing_fields = required_fields - set(values.keys())
if missing_fields:
raise ValueError(f'Missing fields: {missing_fields}')
raise ValueError(f"Missing fields: {missing_fields}")
return values
new_class.__validator_check_name = classmethod(check_name)
new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields)
return new_class

View file

@ -10,5 +10,6 @@ from metagpt.actions import Action
class BossRequirement(Action):
"""Boss Requirement without any implementation details"""
async def run(self, *args, **kwargs):
raise NotImplementedError

View file

@ -18,16 +18,13 @@ class AzureTTS(Action):
# Parameters reference: https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles
def synthesize_speech(self, lang, voice, role, text, output_file):
subscription_key = self.config.get('AZURE_TTS_SUBSCRIPTION_KEY')
region = self.config.get('AZURE_TTS_REGION')
speech_config = SpeechConfig(
subscription=subscription_key, region=region)
subscription_key = self.config.get("AZURE_TTS_SUBSCRIPTION_KEY")
region = self.config.get("AZURE_TTS_REGION")
speech_config = SpeechConfig(subscription=subscription_key, region=region)
speech_config.speech_synthesis_voice_name = voice
audio_config = AudioConfig(filename=output_file)
synthesizer = SpeechSynthesizer(
speech_config=speech_config,
audio_config=audio_config)
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
# if voice=="zh-CN-YunxiNeural":
ssml_string = f"""
@ -45,9 +42,4 @@ class AzureTTS(Action):
if __name__ == "__main__":
azure_tts = AzureTTS("azure_tts")
azure_tts.synthesize_speech(
"zh-CN",
"zh-CN-YunxiNeural",
"Boy",
"Hello, I am Kaka",
"output.wav")
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav")

View file

@ -1,5 +1,5 @@
from pathlib import Path
import traceback
from pathlib import Path
from metagpt.actions.write_code import WriteCode
from metagpt.logs import logger
@ -42,7 +42,7 @@ class CloneFunction(WriteCode):
prompt = CLONE_PROMPT.format(source_code=source_code, template_func=template_func)
logger.info(f"query for CloneFunction: \n {prompt}")
code = await self.write_code(prompt)
logger.info(f'CloneFunction code is \n {highlight(code)}')
logger.info(f"CloneFunction code is \n {highlight(code)}")
return code
@ -61,5 +61,5 @@ def run_function_script(code_script_path: str, func_name: str, *args, **kwargs):
"""Run function code from script."""
if isinstance(code_script_path, str):
code_path = Path(code_script_path)
code = code_path.read_text(encoding='utf-8')
code = code_path.read_text(encoding="utf-8")
return run_function_code(code, func_name, *args, **kwargs)

View file

@ -7,8 +7,8 @@
"""
import re
from metagpt.logs import logger
from metagpt.actions.action import Action
from metagpt.logs import logger
from metagpt.utils.common import CodeParser
PROMPT_TEMPLATE = """
@ -24,6 +24,8 @@ The message is as follows:
Now you should start rewriting the code:
## file name of the code to rewrite: Write code with triple quoto. Do your best to implement THIS IN ONLY ONE FILE.
"""
class DebugError(Action):
def __init__(self, name="DebugError", context=None, llm=None):
super().__init__(name, context, llm)
@ -33,17 +35,17 @@ class DebugError(Action):
# f"\n\n{error}\n\nPlease try to fix the error in this code."
# fixed_code = await self._aask(prompt)
# return fixed_code
async def run(self, context):
if "PASS" in context:
return "", "the original code works fine, no need to debug"
file_name = re.search("## File To Rewrite:\s*(.+\\.py)", context).group(1)
logger.info(f"Debug and rewrite {file_name}")
prompt = PROMPT_TEMPLATE.format(context=context)
rsp = await self._aask(prompt)
code = CodeParser.parse_code(block="", text=rsp)

View file

@ -13,10 +13,11 @@ class DesignReview(Action):
super().__init__(name, context, llm)
async def run(self, prd, api_design):
prompt = f"Here is the Product Requirement Document (PRD):\n\n{prd}\n\nHere is the list of APIs designed " \
f"based on this PRD:\n\n{api_design}\n\nPlease review whether this API design meets the requirements" \
f" of the PRD, and whether it complies with good design practices."
prompt = (
f"Here is the Product Requirement Document (PRD):\n\n{prd}\n\nHere is the list of APIs designed "
f"based on this PRD:\n\n{api_design}\n\nPlease review whether this API design meets the requirements"
f" of the PRD, and whether it complies with good design practices."
)
api_review = await self._aask(prompt)
return api_review

View file

@ -17,8 +17,10 @@ Do not add any other explanations, just return a Python string list."""
class DesignFilenames(Action):
def __init__(self, name, context=None, llm=None):
super().__init__(name, context, llm)
self.desc = "Based on the PRD, consider system design, and carry out the basic design of the corresponding " \
"APIs, data structures, and database tables. Please give your design, feedback clearly and in detail."
self.desc = (
"Based on the PRD, consider system design, and carry out the basic design of the corresponding "
"APIs, data structures, and database tables. Please give your design, feedback clearly and in detail."
)
async def run(self, prd):
prompt = f"The following is the Product Requirement Document (PRD):\n\n{prd}\n\n{PROMPT}"
@ -26,4 +28,3 @@ class DesignFilenames(Action):
logger.debug(prompt)
logger.debug(design_filenames)
return design_filenames

View file

@ -6,7 +6,6 @@
@File : detail_mining.py
"""
from metagpt.actions import Action, ActionOutput
from metagpt.logs import logger
PROMPT_TEMPLATE = """
##TOPIC
@ -41,8 +40,8 @@ OUTPUT_MAPPING = {
class DetailMining(Action):
"""This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion.
"""
"""This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion."""
def __init__(self, name="", context=None, llm=None):
super().__init__(name, context, llm)

View file

@ -10,8 +10,8 @@
import os
import zipfile
from pathlib import Path
from datetime import datetime
from pathlib import Path
import pandas as pd
from paddleocr import PaddleOCR
@ -19,7 +19,10 @@ from paddleocr import PaddleOCR
from metagpt.actions import Action
from metagpt.const import INVOICE_OCR_TABLE_PATH
from metagpt.logs import logger
from metagpt.prompts.invoice_ocr import EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT
from metagpt.prompts.invoice_ocr import (
EXTRACT_OCR_MAIN_INFO_PROMPT,
REPLY_OCR_QUESTION_PROMPT,
)
from metagpt.utils.common import OutputParser
from metagpt.utils.file import File
@ -183,4 +186,3 @@ class ReplyQuestion(Action):
prompt = REPLY_OCR_QUESTION_PROMPT.format(query=query, ocr_result=ocr_result, language=self.language)
resp = await self._aask(prompt=prompt)
return resp

View file

@ -38,4 +38,3 @@ class PrepareInterview(Action):
prompt = PROMPT_TEMPLATE.format(context=context)
question_list = await self._aask_v1(prompt)
return question_list

View file

@ -3,7 +3,6 @@
from __future__ import annotations
import asyncio
import json
from typing import Callable
from pydantic import parse_obj_as
@ -49,7 +48,7 @@ based on the link credibility. If two results have equal credibility, prioritize
ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words.
"""
WEB_BROWSE_AND_SUMMARIZE_PROMPT = '''### Requirements
WEB_BROWSE_AND_SUMMARIZE_PROMPT = """### Requirements
1. Utilize the text in the "Reference Information" section to respond to the question "{query}".
2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \
a comprehensive summary of the text.
@ -58,10 +57,10 @@ a comprehensive summary of the text.
### Reference Information
{content}
'''
"""
CONDUCT_RESEARCH_PROMPT = '''### Reference Information
CONDUCT_RESEARCH_PROMPT = """### Reference Information
{content}
### Requirements
@ -73,11 +72,12 @@ above. The report must meet the following requirements:
- Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable.
- The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines.
- Include all source URLs in APA format at the end of the report.
'''
"""
class CollectLinks(Action):
"""Action class to collect links from a search engine."""
def __init__(
self,
name: str = "",
@ -114,19 +114,24 @@ class CollectLinks(Action):
keywords = OutputParser.extract_struct(keywords, list)
keywords = parse_obj_as(list[str], keywords)
except Exception as e:
logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}")
logger.exception(f'fail to get keywords related to the research topic "{topic}" for {e}')
keywords = [topic]
results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords))
def gen_msg():
while True:
search_results = "\n".join(f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results))
prompt = SUMMARIZE_SEARCH_PROMPT.format(decomposition_nums=decomposition_nums, search_results=search_results)
search_results = "\n".join(
f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results)
)
prompt = SUMMARIZE_SEARCH_PROMPT.format(
decomposition_nums=decomposition_nums, search_results=search_results
)
yield prompt
remove = max(results, key=len)
remove.pop()
if len(remove) == 0:
break
prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
@ -172,6 +177,7 @@ class CollectLinks(Action):
class WebBrowseAndSummarize(Action):
"""Action class to explore the web and provide summaries of articles and webpages."""
def __init__(
self,
*args,
@ -214,7 +220,9 @@ class WebBrowseAndSummarize(Action):
for u, content in zip([url, *urls], contents):
content = content.inner_text
chunk_summaries = []
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp):
for prompt in generate_prompt_chunk(
content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp
):
logger.debug(prompt)
summary = await self._aask(prompt, [system_text])
if summary == "Not relevant.":
@ -238,6 +246,7 @@ class WebBrowseAndSummarize(Action):
class ConductResearch(Action):
"""Action class to conduct research and generate a research report."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if CONFIG.model_for_researcher_report:

View file

@ -140,4 +140,3 @@ class SearchAndSummarize(Action):
logger.debug(prompt)
logger.debug(result)
return result

View file

@ -5,13 +5,14 @@
@Author : alexanderwu
@File : write_code.py
"""
from tenacity import retry, stop_after_attempt, wait_fixed
from metagpt.actions import WriteDesign
from metagpt.actions.action import Action
from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import CodeParser
from tenacity import retry, stop_after_attempt, wait_fixed
PROMPT_TEMPLATE = """
NOTICE
@ -74,9 +75,8 @@ class WriteCode(Action):
async def run(self, context, filename):
prompt = PROMPT_TEMPLATE.format(context=context, filename=filename)
logger.info(f'Writing {filename}..')
logger.info(f"Writing {filename}..")
code = await self.write_code(prompt)
# code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING)
# self._save(context, filename, code)
return code

View file

@ -6,11 +6,12 @@
@File : write_code_review.py
"""
from tenacity import retry, stop_after_attempt, wait_fixed
from metagpt.actions.action import Action
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import CodeParser
from tenacity import retry, stop_after_attempt, wait_fixed
PROMPT_TEMPLATE = """
NOTICE
@ -74,9 +75,8 @@ class WriteCodeReview(Action):
async def run(self, context, code, filename):
format_example = FORMAT_EXAMPLE.format(filename=filename)
prompt = PROMPT_TEMPLATE.format(context=context, code=code, filename=filename, format_example=format_example)
logger.info(f'Code review {filename}..')
logger.info(f"Code review {filename}..")
code = await self.write_code(prompt)
# code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING)
# self._save(context, filename, code)
return code

View file

@ -28,7 +28,7 @@ from metagpt.actions.action import Action
from metagpt.utils.common import OutputParser
from metagpt.utils.pycst import merge_docstring
PYTHON_DOCSTRING_SYSTEM = '''### Requirements
PYTHON_DOCSTRING_SYSTEM = """### Requirements
1. Add docstrings to the given code following the {style} style.
2. Replace the function body with an Ellipsis object(...) to reduce output.
3. If the types are already annotated, there is no need to include them in the docstring.
@ -48,7 +48,7 @@ class ExampleError(Exception):
```python
{example}
```
'''
"""
# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
@ -162,7 +162,8 @@ class WriteDocstring(Action):
self.desc = "Write docstring for code."
async def run(
self, code: str,
self,
code: str,
system_text: str = PYTHON_DOCSTRING_SYSTEM,
style: Literal["google", "numpy", "sphinx"] = "google",
) -> str:

View file

@ -25,4 +25,3 @@ class WritePRDReview(Action):
prompt = self.prd_review_prompt_template.format(prd=self.prd)
review = await self._aask(prompt)
return review

View file

@ -10,7 +10,7 @@
from typing import Dict
from metagpt.actions import Action
from metagpt.prompts.tutorial_assistant import DIRECTORY_PROMPT, CONTENT_PROMPT
from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT
from metagpt.utils.common import OutputParser
@ -65,4 +65,3 @@ class WriteContent(Action):
"""
prompt = CONTENT_PROMPT.format(topic=topic, language=self.language, directory=self.directory)
return await self._aask(prompt=prompt)

View file

@ -46,7 +46,7 @@ class Config(metaclass=Singleton):
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("Anthropic_API_KEY")
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and (
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY first")
self.openai_api_base = self._get("OPENAI_API_BASE")

View file

@ -41,7 +41,7 @@ class LocalStore(BaseStore, ABC):
self.store = self.write()
def _get_index_and_store_fname(self):
fname = self.raw_data.name.split('.')[0]
fname = self.raw_data.name.split(".")[0]
index_file = self.cache_dir / f"{fname}.index"
store_file = self.cache_dir / f"{fname}.pkl"
return index_file, store_file
@ -53,4 +53,3 @@ class LocalStore(BaseStore, ABC):
@abstractmethod
def _write(self, docs, metadatas):
raise NotImplementedError

View file

@ -10,6 +10,7 @@ import chromadb
class ChromaStore:
"""If inherited from BaseStore, or importing other modules from metagpt, a Python exception occurs, which is strange."""
def __init__(self, name):
client = chromadb.Client()
collection = client.create_collection(name)
@ -22,7 +23,7 @@ class ChromaStore:
query_texts=[query],
n_results=n_results,
where=metadata_filter, # optional filter
where_document=document_filter # optional filter
where_document=document_filter, # optional filter
)
return results

View file

@ -24,20 +24,20 @@ def validate_cols(content_col: str, df: pd.DataFrame):
def read_data(data_path: Path):
suffix = data_path.suffix
if '.xlsx' == suffix:
if ".xlsx" == suffix:
data = pd.read_excel(data_path)
elif '.csv' == suffix:
elif ".csv" == suffix:
data = pd.read_csv(data_path)
elif '.json' == suffix:
elif ".json" == suffix:
data = pd.read_json(data_path)
elif suffix in ('.docx', '.doc'):
data = UnstructuredWordDocumentLoader(str(data_path), mode='elements').load()
elif '.txt' == suffix:
elif suffix in (".docx", ".doc"):
data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
elif ".txt" == suffix:
data = TextLoader(str(data_path)).load()
text_splitter = CharacterTextSplitter(separator='\n', chunk_size=256, chunk_overlap=0)
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
texts = text_splitter.split_documents(data)
data = texts
elif '.pdf' == suffix:
elif ".pdf" == suffix:
data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
else:
raise NotImplementedError
@ -45,8 +45,7 @@ def read_data(data_path: Path):
class Document:
def __init__(self, data_path, content_col='content', meta_col='metadata'):
def __init__(self, data_path, content_col="content", meta_col="metadata"):
self.data = read_data(data_path)
if isinstance(self.data, pd.DataFrame):
validate_cols(content_col, self.data)
@ -79,4 +78,3 @@ class Document:
return self._get_docs_and_metadatas_by_langchain()
else:
raise NotImplementedError

View file

@ -20,7 +20,7 @@ from metagpt.logs import logger
class FaissStore(LocalStore):
def __init__(self, raw_data: Path, cache_dir=None, meta_col='source', content_col='output'):
def __init__(self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output"):
self.meta_col = meta_col
self.content_col = content_col
super().__init__(raw_data, cache_dir)
@ -50,7 +50,7 @@ class FaissStore(LocalStore):
pickle.dump(store, f)
store.index = index
def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs):
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
rsp = self.store.similarity_search(query, k=k, **kwargs)
logger.debug(rsp)
if expand_cols:
@ -78,8 +78,8 @@ class FaissStore(LocalStore):
raise NotImplementedError
if __name__ == '__main__':
faiss_store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
logger.info(faiss_store.search('Oily Skin Facial Cleanser'))
faiss_store.add([f'Oily Skin Facial Cleanser-{i}' for i in range(3)])
logger.info(faiss_store.search('Oily Skin Facial Cleanser'))
if __name__ == "__main__":
faiss_store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
logger.info(faiss_store.search("Oily Skin Facial Cleanser"))
faiss_store.add([f"Oily Skin Facial Cleanser-{i}" for i in range(3)])
logger.info(faiss_store.search("Oily Skin Facial Cleanser"))

View file

@ -12,12 +12,7 @@ from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connec
from metagpt.document_store.base_store import BaseStore
type_mapping = {
int: DataType.INT64,
str: DataType.VARCHAR,
float: DataType.DOUBLE,
np.ndarray: DataType.FLOAT_VECTOR
}
type_mapping = {int: DataType.INT64, str: DataType.VARCHAR, float: DataType.DOUBLE, np.ndarray: DataType.FLOAT_VECTOR}
def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""):
@ -52,17 +47,11 @@ class MilvusStore(BaseStore):
self.collection = None
def _create_collection(self, name, schema):
collection = Collection(
name=name,
schema=schema,
using='default',
shards_num=2,
consistency_level="Strong"
)
collection = Collection(name=name, schema=schema, using="default", shards_num=2, consistency_level="Strong")
return collection
def create_collection(self, name, columns):
schema = columns_to_milvus_schema(columns, 'idx')
schema = columns_to_milvus_schema(columns, "idx")
self.collection = self._create_collection(name, schema)
return self.collection
@ -72,7 +61,7 @@ class MilvusStore(BaseStore):
def load_collection(self):
self.collection.load()
def build_index(self, field='emb'):
def build_index(self, field="emb"):
self.collection.create_index(field, {"index_type": "FLAT", "metric_type": "L2", "params": {}})
def search(self, query: list[list[float]], *args, **kwargs):
@ -85,11 +74,11 @@ class MilvusStore(BaseStore):
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
results = self.collection.search(
data=query,
anns_field=kwargs.get('field', 'emb'),
anns_field=kwargs.get("field", "emb"),
param=search_params,
limit=10,
expr=None,
consistency_level="Strong"
consistency_level="Strong",
)
# FIXME: results contain id, but to get the actual value from the id, we still need to call the query interface
return results

View file

@ -10,13 +10,14 @@ from metagpt.document_store.base_store import BaseStore
@dataclass
class QdrantConnection:
"""
Args:
url: qdrant url
host: qdrant host
port: qdrant port
memory: qdrant service use memory mode
api_key: qdrant cloud api_key
"""
Args:
url: qdrant url
host: qdrant host
port: qdrant port
memory: qdrant service use memory mode
api_key: qdrant cloud api_key
"""
url: str = None
host: str = None
port: int = None
@ -31,9 +32,7 @@ class QdrantStore(BaseStore):
elif connect.url:
self.client = QdrantClient(url=connect.url, api_key=connect.api_key)
elif connect.host and connect.port:
self.client = QdrantClient(
host=connect.host, port=connect.port, api_key=connect.api_key
)
self.client = QdrantClient(host=connect.host, port=connect.port, api_key=connect.api_key)
else:
raise Exception("please check QdrantConnection.")
@ -58,15 +57,11 @@ class QdrantStore(BaseStore):
try:
self.client.get_collection(collection_name)
if force_recreate:
res = self.client.recreate_collection(
collection_name, vectors_config=vectors_config, **kwargs
)
res = self.client.recreate_collection(collection_name, vectors_config=vectors_config, **kwargs)
return res
return True
except: # noqa: E722
return self.client.recreate_collection(
collection_name, vectors_config=vectors_config, **kwargs
)
return self.client.recreate_collection(collection_name, vectors_config=vectors_config, **kwargs)
def has_collection(self, collection_name: str):
try:

View file

@ -17,34 +17,34 @@ from metagpt.schema import Message
class Environment(BaseModel):
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
"""
roles: dict[str, Role] = Field(default_factory=dict)
memory: Memory = Field(default_factory=Memory)
history: str = Field(default='')
history: str = Field(default="")
class Config:
arbitrary_types_allowed = True
def add_role(self, role: Role):
"""增加一个在当前环境的角色
Add a role in the current environment
Add a role in the current environment
"""
role.set_env(self)
self.roles[role.profile] = role
def add_roles(self, roles: Iterable[Role]):
"""增加一批在当前环境的角色
Add a batch of characters in the current environment
Add a batch of characters in the current environment
"""
for role in roles:
self.add_role(role)
def publish_message(self, message: Message):
"""向当前环境发布信息
Post information to the current environment
Post information to the current environment
"""
# self.message_queue.put(message)
self.memory.add(message)
@ -68,12 +68,12 @@ class Environment(BaseModel):
def get_roles(self) -> dict[str, Role]:
"""获得环境内的所有角色
Process all Role runs at once
Process all Role runs at once
"""
return self.roles
def get_role(self, name: str) -> Role:
"""获得环境内的指定角色
get all the environment roles
get all the environment roles
"""
return self.roles.get(name, None)

View file

@ -12,17 +12,17 @@ import metagpt # replace with your module
def print_classes_and_functions(module):
"""FIXME: NOT WORK.. """
"""FIXME: NOT WORK.."""
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj):
print(f'Class: {name}')
print(f"Class: {name}")
elif inspect.isfunction(obj):
print(f'Function: {name}')
print(f"Function: {name}")
else:
print(name)
print(dir(module))
if __name__ == '__main__':
print_classes_and_functions(metagpt)
if __name__ == "__main__":
print_classes_and_functions(metagpt)

View file

@ -12,8 +12,9 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
DEFAULT_LLM = LLM()
CLAUDE_LLM = Claude()
async def ai_func(prompt):
"""使用LLM进行QA
QA with LLMs
"""
QA with LLMs
"""
return await DEFAULT_LLM.aask(prompt)

View file

@ -12,13 +12,15 @@ from loguru import logger as _logger
from metagpt.const import PROJECT_ROOT
def define_log_level(print_level="INFO", logfile_level="DEBUG"):
"""调整日志级别到level之上
Adjust the log level to above level
Adjust the log level to above level
"""
_logger.remove()
_logger.add(sys.stderr, level=print_level)
_logger.add(PROJECT_ROOT / 'logs/log.txt', level=logfile_level)
_logger.add(PROJECT_ROOT / "logs/log.txt", level=logfile_level)
return _logger
logger = define_log_level()

View file

@ -19,8 +19,8 @@ class SkillManager:
def __init__(self):
self._llm = LLM()
self._store = ChromaStore('skill_manager')
self._skills: dict[str: Skill] = {}
self._store = ChromaStore("skill_manager")
self._skills: dict[str:Skill] = {}
def add_skill(self, skill: Skill):
"""
@ -54,7 +54,7 @@ class SkillManager:
:param desc: Skill description
:return: Multiple skills
"""
return self._store.search(desc, n_results=n_results)['ids'][0]
return self._store.search(desc, n_results=n_results)["ids"][0]
def retrieve_skill_scored(self, desc: str, n_results: int = 2) -> dict:
"""
@ -75,6 +75,6 @@ class SkillManager:
logger.info(text)
if __name__ == '__main__':
if __name__ == "__main__":
manager = SkillManager()
manager.generate_skill_desc(Action())

View file

@ -18,7 +18,7 @@ class Manager:
"Product Manager": "Architect",
"Architect": "Engineer",
"Engineer": "QA Engineer",
"QA Engineer": "Product Manager"
"QA Engineer": "Product Manager",
}
self.prompt_template = """
Given the following message:
@ -51,7 +51,7 @@ class Manager:
# chosen_role_name = self.llm.ask(self.prompt_template.format(context))
# FIXME: 现在通过简单的字典决定流向,但之后还是应该有思考过程
#The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards
# The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards
next_role_profile = self.role_directions[message.role]
# logger.debug(f"{next_role_profile}")
for _, role in roles.items():

View file

@ -68,4 +68,3 @@ class LongTermMemory(Memory):
def clear(self):
super(LongTermMemory, self).clear()
self.memory_storage.clean()

View file

@ -85,4 +85,3 @@ class Memory:
continue
rsp += self.index[action]
return rsp

View file

@ -2,16 +2,16 @@
# -*- coding: utf-8 -*-
# @Desc : the implement of memory storage
from typing import List
from pathlib import Path
from typing import List
from langchain.vectorstores.faiss import FAISS
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.serialize import serialize_message, deserialize_message
from metagpt.document_store.faiss_store import FaissStore
from metagpt.utils.serialize import deserialize_message, serialize_message
class MemoryStorage(FaissStore):
@ -34,7 +34,7 @@ class MemoryStorage(FaissStore):
def recover_memory(self, role_id: str) -> List[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f'role_mem/{self.role_id}/')
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
self.store = self._load()
@ -51,18 +51,18 @@ class MemoryStorage(FaissStore):
def _get_index_and_store_fname(self):
if not self.role_mem_path:
logger.error(f'You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory')
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f'{self.role_id}.index')
storage_fpath = Path(self.role_mem_path / f'{self.role_id}.pkl')
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
return index_fpath, storage_fpath
def persist(self):
super(MemoryStorage, self).persist()
logger.debug(f'Agent {self.role_id} persist memory into local')
logger.debug(f"Agent {self.role_id} persist memory into local")
def add(self, message: Message) -> bool:
""" add message into memory storage"""
"""add message into memory storage"""
docs = [message.content]
metadatas = [{"message_ser": serialize_message(message)}]
if not self.store:
@ -79,10 +79,7 @@ class MemoryStorage(FaissStore):
if not self.store:
return []
resp = self.store.similarity_search_with_score(
query=message.content,
k=k
)
resp = self.store.similarity_search_with_score(query=message.content, k=k)
# filter the result which score is smaller than the threshold
filtered_resp = []
for item, score in resp:
@ -104,4 +101,3 @@ class MemoryStorage(FaissStore):
self.store = None
self._initialized = False

View file

@ -10,7 +10,9 @@
COMMON_PROMPT = "Now I will provide you with the OCR text recognition results for the invoice."
EXTRACT_OCR_MAIN_INFO_PROMPT = COMMON_PROMPT + """
EXTRACT_OCR_MAIN_INFO_PROMPT = (
COMMON_PROMPT
+ """
Please extract the payee, city, total cost, and invoicing date of the invoice.
The OCR data of the invoice are as follows:
@ -22,8 +24,11 @@ Mandatory restrictions are returned according to the following requirements:
2. The returned JSON dictionary must be returned in {language}
3. Mandatory requirement to output in JSON format: {{"收款人":"x","城市":"x","总费用/元":"","开票日期":""}}.
"""
)
REPLY_OCR_QUESTION_PROMPT = COMMON_PROMPT + """
REPLY_OCR_QUESTION_PROMPT = (
COMMON_PROMPT
+ """
Please answer the question: {query}
The OCR data of the invoice are as follows:
@ -34,6 +39,6 @@ Mandatory restrictions are returned according to the following requirements:
2. Enforce restrictions on not returning OCR data sent to you.
3. Return with markdown syntax layout.
"""
)
INVOICE_OCR_SUCCESS = "Successfully completed OCR text recognition invoice."

View file

@ -54,10 +54,12 @@ Conversation history:
{salesperson_name}:
"""
conversation_stages = {'1' : "Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are contacting the prospect.",
'2': "Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.",
'3': "Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.",
'4': "Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.",
'5': "Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.",
'6': "Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.",
'7': "Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits."}
conversation_stages = {
"1": "Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are contacting the prospect.",
"2": "Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.",
"3": "Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.",
"4": "Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.",
"5": "Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.",
"6": "Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.",
"7": "Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits.",
}

View file

@ -12,7 +12,9 @@ You are now a seasoned technical professional in the field of the internet.
We need you to write a technical tutorial with the topic "{topic}".
"""
DIRECTORY_PROMPT = COMMON_PROMPT + """
DIRECTORY_PROMPT = (
COMMON_PROMPT
+ """
Please provide the specific table of contents for this tutorial, strictly following the following requirements:
1. The output must be strictly in the specified language, {language}.
2. Answer strictly in the dictionary format like {{"title": "xxx", "directory": [{{"dir 1": ["sub dir 1", "sub dir 2"]}}, {{"dir 2": ["sub dir 3", "sub dir 4"]}}]}}.
@ -20,8 +22,11 @@ Please provide the specific table of contents for this tutorial, strictly follow
4. Do not have extra spaces or line breaks.
5. Each directory title has practical significance.
"""
)
CONTENT_PROMPT = COMMON_PROMPT + """
CONTENT_PROMPT = (
COMMON_PROMPT
+ """
Now I will give you the module directory titles for the topic.
Please output the detailed principle content of this title in detail.
If there are code examples, please provide them according to standard code specifications.
@ -36,4 +41,5 @@ Strictly limit output according to the following requirements:
3. The output must be strictly in the specified language, {language}.
4. Do not have redundant output, including concluding remarks.
5. Strict requirement not to output the topic "{topic}".
"""
"""
)

View file

@ -32,4 +32,3 @@ class Claude2:
max_tokens_to_sample=1000,
)
return res.completion

View file

@ -12,6 +12,7 @@ from dataclasses import dataclass
@dataclass
class BaseChatbot(ABC):
"""Abstract GPT class"""
mode: str = "API"
@abstractmethod
@ -25,4 +26,3 @@ class BaseChatbot(ABC):
@abstractmethod
def ask_code(self, msgs: list) -> str:
"""Ask GPT multiple questions and get a piece of code"""

View file

@ -14,7 +14,8 @@ from metagpt.provider.base_chatbot import BaseChatbot
class BaseGPTAPI(BaseChatbot):
"""GPT API abstract class, requiring all inheritors to provide a series of standard capabilities"""
system_prompt = 'You are a helpful assistant.'
system_prompt = "You are a helpful assistant."
def _user_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "content": msg}
@ -110,9 +111,8 @@ class BaseGPTAPI(BaseChatbot):
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return '\n'.join([f"{i['role']}: {i['content']}" for i in messages])
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
def messages_to_dict(self, messages):
"""objects to [{"role": "user", "content": msg}] etc."""
return [i.to_dict() for i in messages]

View file

@ -110,7 +110,6 @@ class CostManager(metaclass=Singleton):
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
@ -120,7 +119,6 @@ class CostManager(metaclass=Singleton):
"""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)

View file

@ -14,8 +14,7 @@ import json
import ssl
from time import mktime
from typing import Optional
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
@ -26,9 +25,8 @@ from metagpt.provider.base_gpt_api import BaseGPTAPI
class SparkAPI(BaseGPTAPI):
def __init__(self):
logger.warning('当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。')
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
def ask(self, msg: str) -> str:
message = [self._default_system_msg(), self._user_msg(msg)]
@ -49,7 +47,7 @@ class SparkAPI(BaseGPTAPI):
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
# 不支持
logger.error('该功能禁用。')
logger.error("该功能禁用。")
w = GetMessageFromWeb(messages)
return w.run()
@ -93,29 +91,26 @@ class GetMessageFromWeb:
signature_origin += "GET " + self.path + " HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = hmac.new(
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
v = {"authorization": authorization, "date": date, "host": self.host}
# 拼接鉴权参数生成url
url = self.spark_url + '?' + urlencode(v)
url = self.spark_url + "?" + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
def __init__(self, text):
self.text = text
self.ret = ''
self.ret = ""
self.spark_appid = CONFIG.spark_appid
self.spark_api_secret = CONFIG.spark_api_secret
self.spark_api_key = CONFIG.spark_api_key
@ -124,15 +119,15 @@ class GetMessageFromWeb:
def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
code = data["header"]["code"]
if code != 0:
ws.close() # 请求错误则关闭socket
logger.critical(f'回答获取失败,响应信息反序列化之后为: {data}')
logger.critical(f"回答获取失败,响应信息反序列化之后为: {data}")
return
else:
choices = data["payload"]["choices"]
seq = choices["seq"] # 服务端是流式返回seq为返回的数据序号
# seq = choices["seq"] # 服务端是流式返回seq为返回的数据序号
status = choices["status"] # 服务端是流式返回status用于判断信息是否传送完毕
content = choices["text"][0]["content"] # 本次接收到的回答文本
self.ret += content
@ -142,7 +137,7 @@ class GetMessageFromWeb:
# 收到websocket错误的处理
def on_error(self, ws, error):
# on_message方法处理接收到的信息出现任何错误都会调用这个方法
logger.critical(f'通讯连接出错,【错误提示: {error}')
logger.critical(f"通讯连接出错,【错误提示: {error}")
# 收到websocket关闭的处理
def on_close(self, ws, one, two):
@ -150,17 +145,12 @@ class GetMessageFromWeb:
# 处理请求数据
def gen_params(self):
data = {
"header": {
"app_id": self.spark_appid,
"uid": "1234"
},
"header": {"app_id": self.spark_appid, "uid": "1234"},
"parameter": {
"chat": {
# domain为必传参数
"domain": self.domain,
# 以下为可微调,非必传参数
# 注意官方建议temperature和top_k修改一个即可
"max_tokens": 2048, # 默认2048模型回答的tokens的最大长度即允许它输出文本的最长字数
@ -168,11 +158,7 @@ class GetMessageFromWeb:
"top_k": 4, # 取值为[16],默认为4。从k个候选中随机选择一个非等概率
}
},
"payload": {
"message": {
"text": self.text
}
}
"payload": {"message": {"text": self.text}},
}
return data
@ -189,17 +175,12 @@ class GetMessageFromWeb:
return self._run(self.text)
def _run(self, text_list):
ws_param = self.WsParam(
self.spark_appid,
self.spark_api_key,
self.spark_api_secret,
self.spark_url,
text_list)
ws_param = self.WsParam(self.spark_appid, self.spark_api_key, self.spark_api_secret, self.spark_url, text_list)
ws_url = ws_param.create_url()
websocket.enableTrace(False) # 默认禁用 WebSocket 的跟踪功能
ws = websocket.WebSocketApp(ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close,
on_open=self.on_open)
ws = websocket.WebSocketApp(
ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open
)
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
return self.ret

View file

@ -24,12 +24,5 @@ DESC = """
class CustomerService(Sales):
def __init__(
self,
name="Xiaomei",
profile="Human customer service",
desc=DESC,
store=None
):
def __init__(self, name="Xiaomei", profile="Human customer service", desc=DESC, store=None):
super().__init__(name, profile, desc=desc, store=store)

View file

@ -9,7 +9,7 @@
import pandas as pd
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
from metagpt.prompts.invoice_ocr import INVOICE_OCR_SUCCESS
from metagpt.roles import Role
from metagpt.schema import Message
@ -107,4 +107,3 @@ class InvoiceOCRAssistant(Role):
break
msg = await self._act()
return msg

View file

@ -23,6 +23,7 @@ SUFFIX = """Let's begin!
Question: {input}
Thoughts: {agent_scratchpad}"""
class PromptString(Enum):
REFLECTION_QUESTIONS = "Here are some statements:\n{memory_descriptions}\n\nBased solely on the information above, what are the 3 most prominent high-level questions we can answer about the topic in the statements?\n\n{format_instructions}"
@ -32,7 +33,7 @@ class PromptString(Enum):
RECENT_ACTIVITY = "Based on the following memory, produce a brief summary of what {full_name} has been up to recently. Do not invent details not explicitly stated in the memory. For any conversation, be sure to mention whether the conversation has concluded or is still ongoing.\n\nMemory: {memory_descriptions}"
MAKE_PLANS = "You are a plan-generating AI. Your job is to assist the character in formulating new plans based on new information. Given the character's information (profile, objectives, recent activities, current plans, and location context) and their current thought process, produce a new set of plans for them. The final plan should comprise at least {time_window} of activities and no more than 5 individual plans. List the plans in the order they should be executed, with each plan detailing its description, location, start time, stop criteria, and maximum duration.\n\nSample plan: {{\"index\": 1, \"description\": \"Cook dinner\", \"location_id\": \"0a3bc22b-36aa-48ab-adb0-18616004caed\",\"start_time\": \"2022-12-12T20:00:00+00:00\",\"max_duration_hrs\": 1.5, \"stop_condition\": \"Dinner is fully prepared\"}}\'\n\nFor each plan, choose the most appropriate location name from this list: {allowed_location_descriptions}\n\n{format_instructions}\n\nAlways prioritize completing any unfinished conversations.\n\nLet's begin!\n\nName: {full_name}\nProfile: {private_bio}\nObjectives: {directives}\nLocation Context: {location_context}\nCurrent Plans: {current_plans}\nRecent Activities: {recent_activity}\nThought Process: {thought_process}\nIt's essential to encourage the character to collaborate with other characters in their plans.\n\n"
MAKE_PLANS = 'You are a plan-generating AI. Your job is to assist the character in formulating new plans based on new information. Given the character\'s information (profile, objectives, recent activities, current plans, and location context) and their current thought process, produce a new set of plans for them. The final plan should comprise at least {time_window} of activities and no more than 5 individual plans. List the plans in the order they should be executed, with each plan detailing its description, location, start time, stop criteria, and maximum duration.\n\nSample plan: {{"index": 1, "description": "Cook dinner", "location_id": "0a3bc22b-36aa-48ab-adb0-18616004caed","start_time": "2022-12-12T20:00:00+00:00","max_duration_hrs": 1.5, "stop_condition": "Dinner is fully prepared"}}\'\n\nFor each plan, choose the most appropriate location name from this list: {allowed_location_descriptions}\n\n{format_instructions}\n\nAlways prioritize completing any unfinished conversations.\n\nLet\'s begin!\n\nName: {full_name}\nProfile: {private_bio}\nObjectives: {directives}\nLocation Context: {location_context}\nCurrent Plans: {current_plans}\nRecent Activities: {recent_activity}\nThought Process: {thought_process}\nIt\'s essential to encourage the character to collaborate with other characters in their plans.\n\n'
EXECUTE_PLAN = "You are a role-playing AI, playing the role of {your_name}, in front of a live audience. Every word you say can be observed by the audience, so make sure you speak often and make it interesting. You cannot interact directly with the audience.\n\nBased on the following context and tools, proceed as if you are {your_name}. Your primary task is to complete the task given below, however, if you are currently in a conversation with another character, you should always complete the conversation first and then proceed with the task. Do not start working while you are engaged in an unfinished conversation. Use your best judgment to determine whether a conversation involves you and whether it is unfinished. You don't need to respond to every message you receive.\n\nThis is a personal profile of your character:\n{your_private_bio}\n\nThis is a description of your location and other nearby characters you can converse with:\n\n{location_context}\n\nThese are some relevant memories:\n```\n{relevant_memories}\n```\n\nThis is some relevant conversation history:\n```\n{conversation_history}\n```\n\nThese are the tools you can use:\n{tools}\n\nYour response should follow the format below:\n\nTask: The task you must complete\nThoughts: What you should think about doing\nAction: The action to take, must be one of these words: [{tool_names}]\nAction Input: Input for the action\nObservation: The result of the action\n... (This Thoughts/Action/Action Input/Observation can be repeated N times)\nThoughts: 'I have completed the task'\nFinal Response: The final response to the task\n\nIf you are not ready with a final response, then you must take an action.\n\nIf you are certain that you cannot complete the task with the tools provided, return 'Final Response: Need Help', however, if you are in conversation with another character, responses like 'I don't know' are valid. During conversations, you should never break character or admit you are an AI.\nIf the task is completed and no specific response is required, return 'Final Response: Done'\nLet's begin!\n\nTask: {input}\n\n{agent_scratchpad}"
@ -40,6 +41,6 @@ class PromptString(Enum):
GOSSIP = "You are {full_name}. \n{memory_descriptions}\n\nBased on the statements above, say a thing or two of interest to others at your location: {other_agent_names}.\nAlways specify their names when referring to others."
HAS_HAPPENED = "Given the descriptions of the observations of the following characters and the events they are awaiting, indicate whether the character has witnessed the event.\n{format_instructions}\n\nExample:\n\nObservations:\nJoe entered the office at 2023-05-04 08:00:00+00:00\nJoe said hi to Sally at 2023-05-04 08:05:00+00:00\nSally said hello to Joe at 2023-05-04 08:05:30+00:00\nRebecca started working at 2023-05-04 08:10:00+00:00\nJoe made some breakfast at 2023-05-04 08:15:00+00:00\n\nAwaiting: Sally responded to Joe\n\nYour response: '{{\"has_happened\": true, \"date_occured\": 2023-05-04 08:05:30+00:00}}'\n\nLet's begin!\n\nObservations:\n{memory_descriptions}\n\nAwaiting: {event_description}\n"
HAS_HAPPENED = 'Given the descriptions of the observations of the following characters and the events they are awaiting, indicate whether the character has witnessed the event.\n{format_instructions}\n\nExample:\n\nObservations:\nJoe entered the office at 2023-05-04 08:00:00+00:00\nJoe said hi to Sally at 2023-05-04 08:05:00+00:00\nSally said hello to Joe at 2023-05-04 08:05:30+00:00\nRebecca started working at 2023-05-04 08:10:00+00:00\nJoe made some breakfast at 2023-05-04 08:15:00+00:00\n\nAwaiting: Sally responded to Joe\n\nYour response: \'{{"has_happened": true, "date_occured": 2023-05-04 08:05:30+00:00}}\'\n\nLet\'s begin!\n\nObservations:\n{memory_descriptions}\n\nAwaiting: {event_description}\n'
OUTPUT_FORMAT = "\n\n(Remember! Make sure your output always adheres to one of the following two formats:\n\nA. If you have completed the task:\nThoughts: 'I have completed the task'\nFinal Response: <str>\n\nB. If you haven't completed the task:\nThoughts: <str>\nAction: <str>\nAction Input: <str>\nObservation: <str>)\n"

View file

@ -11,12 +11,13 @@ from typing import Iterable, Type
from pydantic import BaseModel, Field
from metagpt.actions import Action, ActionOutput
# from metagpt.environment import Environment
from metagpt.config import CONFIG
from metagpt.actions import Action, ActionOutput
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.memory import Memory, LongTermMemory
from metagpt.memory import LongTermMemory, Memory
from metagpt.schema import Message
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """
@ -49,6 +50,7 @@ ROLE_TEMPLATE = """Your response should be based on the previous conversation hi
class RoleSetting(BaseModel):
"""Role Settings"""
name: str
profile: str
goal: str
@ -64,7 +66,8 @@ class RoleSetting(BaseModel):
class RoleContext(BaseModel):
"""Role Runtime Context"""
env: 'Environment' = Field(default=None)
env: "Environment" = Field(default=None)
memory: Memory = Field(default_factory=Memory)
long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory)
state: int = Field(default=0)
@ -128,7 +131,7 @@ class Role:
logger.debug(self._actions)
self._rc.todo = self._actions[self._rc.state]
def set_env(self, env: 'Environment'):
def set_env(self, env: "Environment"):
"""Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing."""
self._rc.env = env
@ -150,12 +153,13 @@ class Role:
self._set_state(0)
return
prompt = self._get_prefix()
prompt += STATE_TEMPLATE.format(history=self._rc.history, states="\n".join(self._states),
n_states=len(self._states) - 1)
prompt += STATE_TEMPLATE.format(
history=self._rc.history, states="\n".join(self._states), n_states=len(self._states) - 1
)
next_state = await self._llm.aask(prompt)
logger.debug(f"{prompt=}")
if not next_state.isdigit() or int(next_state) not in range(len(self._states)):
logger.warning(f'Invalid answer of state, {next_state=}')
logger.warning(f"Invalid answer of state, {next_state=}")
next_state = "0"
self._set_state(int(next_state))
@ -168,8 +172,12 @@ class Role:
response = await self._rc.todo.run(self._rc.important_memory)
# logger.info(response)
if isinstance(response, ActionOutput):
msg = Message(content=response.content, instruct_content=response.instruct_content,
role=self.profile, cause_by=type(self._rc.todo))
msg = Message(
content=response.content,
instruct_content=response.instruct_content,
role=self.profile,
cause_by=type(self._rc.todo),
)
else:
msg = Message(content=response, role=self.profile, cause_by=type(self._rc.todo))
self._rc.memory.add(msg)
@ -184,15 +192,17 @@ class Role:
env_msgs = self._rc.env.memory.get()
observed = self._rc.env.memory.get_by_actions(self._rc.watch)
self._rc.news = self._rc.memory.find_news(observed) # find news (previously unseen messages) from observed messages
self._rc.news = self._rc.memory.find_news(
observed
) # find news (previously unseen messages) from observed messages
for i in env_msgs:
self.recv(i)
news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]
if news_text:
logger.debug(f'{self._setting} observed: {news_text}')
logger.debug(f"{self._setting} observed: {news_text}")
return len(self._rc.news)
def _publish_message(self, msg):

View file

@ -12,16 +12,16 @@ from metagpt.tools import SearchEngineType
class Sales(Role):
def __init__(
self,
name="Xiaomei",
profile="Retail sales guide",
desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I "
"will answer questions only based on the information in the knowledge base."
"If I feel that you can't get the answer from the reference material, then I will directly reply that"
" I don't know, and I won't tell you that this is from the knowledge base,"
"but pretend to be what I know. Note that each of my replies will be replied in the tone of a "
"professional guide",
store=None
self,
name="Xiaomei",
profile="Retail sales guide",
desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I "
"will answer questions only based on the information in the knowledge base."
"If I feel that you can't get the answer from the reference material, then I will directly reply that"
" I don't know, and I won't tell you that this is from the knowledge base,"
"but pretend to be what I know. Note that each of my replies will be replied in the tone of a "
"professional guide",
store=None,
):
super().__init__(name, profile, desc=desc)
self._set_store(store)
@ -32,4 +32,3 @@ class Sales(Role):
else:
action = SearchAndSummarize()
self._init_actions([action])

View file

@ -15,7 +15,7 @@ from metagpt.tools import SearchEngineType
class Searcher(Role):
"""
Represents a Searcher role responsible for providing search services to users.
Attributes:
name (str): Name of the searcher.
profile (str): Role profile.
@ -23,17 +23,19 @@ class Searcher(Role):
constraints (str): Constraints or limitations for the searcher.
engine (SearchEngineType): The type of search engine to use.
"""
def __init__(self,
name: str = 'Alice',
profile: str = 'Smart Assistant',
goal: str = 'Provide search services for users',
constraints: str = 'Answer is rich and complete',
engine=SearchEngineType.SERPAPI_GOOGLE,
**kwargs) -> None:
def __init__(
self,
name: str = "Alice",
profile: str = "Smart Assistant",
goal: str = "Provide search services for users",
constraints: str = "Answer is rich and complete",
engine=SearchEngineType.SERPAPI_GOOGLE,
**kwargs,
) -> None:
"""
Initializes the Searcher role with given attributes.
Args:
name (str): Name of the searcher.
profile (str): Role profile.
@ -53,10 +55,14 @@ class Searcher(Role):
"""Performs the search action in a single process."""
logger.info(f"{self._setting}: ready to {self._rc.todo}")
response = await self._rc.todo.run(self._rc.memory.get(k=0))
if isinstance(response, ActionOutput):
msg = Message(content=response.content, instruct_content=response.instruct_content,
role=self.profile, cause_by=type(self._rc.todo))
msg = Message(
content=response.content,
instruct_content=response.instruct_content,
role=self.profile,
cause_by=type(self._rc.todo),
)
else:
msg = Message(content=response, role=self.profile, cause_by=type(self._rc.todo))
self._rc.memory.add(msg)

View file

@ -9,7 +9,7 @@
from datetime import datetime
from typing import Dict
from metagpt.actions.write_tutorial import WriteDirectory, WriteContent
from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
from metagpt.const import TUTORIAL_PATH
from metagpt.logs import logger
from metagpt.roles import Role
@ -110,5 +110,5 @@ class TutorialAssistant(Role):
break
msg = await self._act()
root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
await File.write(root_path, f"{self.main_title}.md", self.total_content.encode('utf-8'))
await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8"))
return msg

View file

@ -23,9 +23,10 @@ class RawMessage(TypedDict):
@dataclass
class Message:
"""list[<role>: <content>]"""
content: str
instruct_content: BaseModel = field(default=None)
role: str = field(default='user') # system / user / assistant
role: str = field(default="user") # system / user / assistant
cause_by: Type["Action"] = field(default="")
sent_from: str = field(default="")
send_to: str = field(default="")
@ -39,45 +40,45 @@ class Message:
return self.__str__()
def to_dict(self) -> dict:
return {
"role": self.role,
"content": self.content
}
return {"role": self.role, "content": self.content}
@dataclass
class UserMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'user')
super().__init__(content, "user")
@dataclass
class SystemMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'system')
super().__init__(content, "system")
@dataclass
class AIMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'assistant')
super().__init__(content, "assistant")
if __name__ == '__main__':
test_content = 'test_message'
if __name__ == "__main__":
test_content = "test_message"
msgs = [
UserMessage(test_content),
SystemMessage(test_content),
AIMessage(test_content),
Message(test_content, role='QA')
Message(test_content, role="QA"),
]
logger.info(msgs)

View file

@ -21,6 +21,7 @@ class SoftwareCompany(BaseModel):
Software Company: Possesses a team, SOP (Standard Operating Procedures), and a platform for instant messaging,
dedicated to writing executable code.
"""
environment: Environment = Field(default_factory=Environment)
investment: float = Field(default=10.0)
idea: str = Field(default="")
@ -36,11 +37,11 @@ class SoftwareCompany(BaseModel):
"""Invest company. raise NoMoneyException when exceed max_budget."""
self.investment = investment
CONFIG.max_budget = investment
logger.info(f'Investment: ${investment}.')
logger.info(f"Investment: ${investment}.")
def _check_balance(self):
if CONFIG.total_cost > CONFIG.max_budget:
raise NoMoneyException(CONFIG.total_cost, f'Insufficient funds: {CONFIG.max_budget}')
raise NoMoneyException(CONFIG.total_cost, f"Insufficient funds: {CONFIG.max_budget}")
def start_project(self, idea):
"""Start a project from publishing boss requirement."""
@ -59,4 +60,3 @@ class SoftwareCompany(BaseModel):
self._check_balance()
await self.environment.run()
return self.environment.history

View file

@ -1,22 +1,26 @@
import inspect
import re
from typing import List, Callable, Dict
import textwrap
from pathlib import Path
from typing import Callable, Dict, List
import wrapt
import textwrap
import inspect
from interpreter.core.core import Interpreter
from metagpt.logs import logger
from metagpt.actions.clone_function import (
CloneFunction,
run_function_code,
run_function_script,
)
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.highlight import highlight
from metagpt.actions.clone_function import CloneFunction, run_function_code, run_function_script
def extract_python_code(code: str):
"""Extract code blocks: If the code comments are the same, only the last code block is kept."""
# Use regular expressions to match comment blocks and related code.
pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)'
pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)"
matches = re.findall(pattern, code, re.DOTALL)
# Extract the last code block when encountering the same comment.
@ -25,8 +29,8 @@ def extract_python_code(code: str):
unique_comments[comment] = code_block
# concatenate into functional form
result_code = '\n'.join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
header_code = code[:code.find("#")]
result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
header_code = code[: code.find("#")]
code = header_code + result_code
logger.info(f"Extract python code: \n {highlight(code)}")
@ -36,6 +40,7 @@ def extract_python_code(code: str):
class OpenCodeInterpreter(object):
"""https://github.com/KillianLucas/open-interpreter"""
def __init__(self, auto_run: bool = True) -> None:
interpreter = Interpreter()
interpreter.auto_run = auto_run
@ -50,15 +55,16 @@ class OpenCodeInterpreter(object):
return self.interpreter.chat(query)
@staticmethod
def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
function_format: str = None) -> str:
def extract_function(
query_respond: List, function_name: str, *, language: str = "python", function_format: str = None
) -> str:
"""create a function from query_respond."""
if language not in ('python'):
if language not in ("python"):
raise NotImplementedError(f"Not support to parse language {language}!")
# set function form
if function_format is None:
assert language == 'python', f"Expect python language for default function_format, but got {language}."
assert language == "python", f"Expect python language for default function_format, but got {language}."
function_format = """def {function_name}():\n{code}"""
# Extract the code module in the open-interpreter respond message.
# The query_respond of open-interpreter before v0.1.4 is:
@ -68,25 +74,29 @@ class OpenCodeInterpreter(object):
# "parsed_arguments": {"language": "python", "code": code of first plan}
# ...]
if "function_call" in query_respond[1]:
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and 'language' in item["function_call"]['parsed_arguments']
and item["function_call"]['parsed_arguments']['language'] == language]
code = [
item["function_call"]["parsed_arguments"]["code"]
for item in query_respond
if "function_call" in item
and "parsed_arguments" in item["function_call"]
and "language" in item["function_call"]["parsed_arguments"]
and item["function_call"]["parsed_arguments"]["language"] == language
]
# The query_respond of open-interpreter v0.1.7 is:
# [{'role': 'user', 'message': your query string},
# {'role': 'assistant', 'message': plan from llm, 'language': 'python',
# 'code': code of first plan, 'output': output of first plan code},
# ...]
elif "code" in query_respond[1]:
code = [item['code'] for item in query_respond
if "code" in item
and 'language' in item
and item['language'] == language]
code = [
item["code"]
for item in query_respond
if "code" in item and "language" in item and item["language"] == language
]
else:
raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}")
# add indent.
indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
indented_code_str = textwrap.indent("\n".join(code), " " * 4)
# Return the code after deduplication.
if language == "python":
return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))
@ -115,13 +125,13 @@ class OpenInterpreterDecorator(object):
def _have_code(self, rsp: List[Dict]):
# Is there any code generated?
return 'code' in rsp[1] and rsp[1]['code'] not in ("", None)
return "code" in rsp[1] and rsp[1]["code"] not in ("", None)
def _is_faild_plan(self, rsp: List[Dict]):
# is faild plan?
func_code = OpenCodeInterpreter.extract_function(rsp, 'function')
func_code = OpenCodeInterpreter.extract_function(rsp, "function")
# If there is no more than 1 '\n', the plan execution fails.
if isinstance(func_code, str) and func_code.count('\n') <= 1:
if isinstance(func_code, str) and func_code.count("\n") <= 1:
return True
return False
@ -184,4 +194,5 @@ class OpenInterpreterDecorator(object):
logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
raise Exception("Could not evaluate Python code", e)
return res
return wrapper(wrapped)

View file

@ -10,8 +10,9 @@ from typing import Union
class GPTPromptGenerator:
"""Using LLM, given an output, request LLM to provide input (supporting instruction, chatbot, and query styles)"""
def __init__(self):
self._generators = {i: getattr(self, f"gen_{i}_style") for i in ['instruction', 'chatbot', 'query']}
self._generators = {i: getattr(self, f"gen_{i}_style") for i in ["instruction", "chatbot", "query"]}
def gen_instruction_style(self, example):
"""Instruction style: Given an output, request LLM to provide input"""
@ -35,7 +36,7 @@ Query: X
Document: {example} What is the detailed query X?
X:"""
def gen(self, example: str, style: str = 'all') -> Union[list[str], str]:
def gen(self, example: str, style: str = "all") -> Union[list[str], str]:
"""
Generate one or multiple outputs using the example, allowing LLM to reply with the corresponding input
@ -43,7 +44,7 @@ X:"""
:param style: (all|instruction|chatbot|query)
:return: Expected LLM input sample (one or multiple)
"""
if style != 'all':
if style != "all":
return self._generators[style](example)
return [f(example) for f in self._generators.values()]

View file

@ -120,11 +120,13 @@ def decode_base64_to_image(img, save_name):
image.save(f"{save_name}.png", pnginfo=pnginfo)
return pnginfo, image
def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
for idx, _img in enumerate(imgs):
save_name = join(save_dir, save_name)
decode_base64_to_image(_img, save_name=save_name)
if __name__ == "__main__":
engine = SDEngine()
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"

View file

@ -6,7 +6,7 @@
@File : search_engine.py
"""
import importlib
from typing import Callable, Coroutine, Literal, overload, Optional, Union
from typing import Callable, Coroutine, Literal, Optional, Union, overload
from semantic_kernel.skill_definition import sk_function
@ -43,8 +43,8 @@ class SearchEngine:
def __init__(
self,
engine: Optional[SearchEngineType] = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
engine: Optional[SearchEngineType] = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
):
engine = engine or CONFIG.search_engine
if engine == SearchEngineType.SERPAPI_GOOGLE:

View file

@ -29,7 +29,7 @@ class MeilisearchEngine:
def add_documents(self, data_source: DataSource, documents: List[dict]):
index_name = f"{data_source.name}_index"
if index_name not in self.client.get_indexes():
self.client.create_index(uid=index_name, options={'primaryKey': 'id'})
self.client.create_index(uid=index_name, options={"primaryKey": "id"})
index = self.client.get_index(index_name)
index.add_documents(documents)
self.set_index(index)
@ -37,7 +37,7 @@ class MeilisearchEngine:
def search(self, query):
try:
search_results = self._index.search(query)
return search_results['hits']
return search_results["hits"]
except Exception as e:
# Handle MeiliSearch API errors
print(f"MeiliSearch API error: {e}")

View file

@ -6,7 +6,7 @@
@File : translator.py
"""
prompt = '''
prompt = """
# 指令
接下来作为一位拥有20年翻译经验的翻译专家当我给出英文句子或段落时你将提供通顺且具有可读性的{LANG}翻译注意以下要求
1. 确保翻译结果流畅且易于理解
@ -17,11 +17,10 @@ prompt = '''
{ORIGINAL}
# 译文
'''
"""
class Translator:
@classmethod
def translate_prompt(cls, original, lang='中文'):
return prompt.format(LANG=lang, ORIGINAL=original)
def translate_prompt(cls, original, lang="中文"):
return prompt.format(LANG=lang, ORIGINAL=original)

View file

@ -6,7 +6,7 @@ from pathlib import Path
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
ICL_SAMPLE = '''Interface definition:
ICL_SAMPLE = """Interface definition:
```text
Interface Name: Element Tagging
Interface Path: /projects/{project_key}/node-tags
@ -60,20 +60,20 @@ def test_node_tags(project_key, nodes, operations, expected_msg):
# 3. If comments are needed, use Chinese.
# If you understand, please wait for me to give the interface definition and just answer "Understood" to save tokens.
'''
"""
ACT_PROMPT_PREFIX = '''Refer to the test types: such as missing request parameters, field boundary verification, incorrect field type.
ACT_PROMPT_PREFIX = """Refer to the test types: such as missing request parameters, field boundary verification, incorrect field type.
Please output 10 test cases within one `@pytest.mark.parametrize` scope.
```text
'''
"""
YFT_PROMPT_PREFIX = '''Refer to the test types: such as SQL injection, cross-site scripting (XSS), unauthorized access and privilege escalation,
YFT_PROMPT_PREFIX = """Refer to the test types: such as SQL injection, cross-site scripting (XSS), unauthorized access and privilege escalation,
authentication and authorization, parameter verification, exception handling, file upload and download.
Please output 10 test cases within one `@pytest.mark.parametrize` scope.
```text
'''
"""
OCR_API_DOC = '''```text
OCR_API_DOC = """```text
Interface Name: OCR recognition
Interface Path: /api/v1/contract/treaty/task/ocr
Method: POST
@ -96,14 +96,20 @@ code integer Yes
message string Yes
data object Yes
```
'''
"""
class UTGenerator:
"""UT Generator: Construct UT through API documentation"""
def __init__(self, swagger_file: str, ut_py_path: str, questions_path: str,
chatgpt_method: str = "API", template_prefix=YFT_PROMPT_PREFIX) -> None:
def __init__(
self,
swagger_file: str,
ut_py_path: str,
questions_path: str,
chatgpt_method: str = "API",
template_prefix=YFT_PROMPT_PREFIX,
) -> None:
"""Initialize UT Generator
Args:
@ -274,7 +280,7 @@ class UTGenerator:
def gpt_msgs_to_code(self, messages: list) -> str:
"""Choose based on different calling methods"""
result = ''
result = ""
if self.chatgpt_method == "API":
result = GPTAPI().ask_code(msgs=messages)

View file

@ -6,9 +6,10 @@
@File : file.py
@Describe : General file operations.
"""
import aiofiles
from pathlib import Path
import aiofiles
from metagpt.logs import logger
@ -66,10 +67,9 @@ class File:
if not chunk:
break
chunks.append(chunk)
content = b''.join(chunks)
content = b"".join(chunks)
logger.debug(f"Successfully read file, the path of file: {file_path}")
return content
except Exception as e:
logger.error(f"Error reading file: {e}")
raise e

View file

@ -1,22 +1,22 @@
# 添加代码语法高亮显示
from pygments import highlight as highlight_
from pygments.formatters import HtmlFormatter, TerminalFormatter
from pygments.lexers import PythonLexer, SqlLexer
from pygments.formatters import TerminalFormatter, HtmlFormatter
def highlight(code: str, language: str = 'python', formatter: str = 'terminal'):
def highlight(code: str, language: str = "python", formatter: str = "terminal"):
# 指定要高亮的语言
if language.lower() == 'python':
if language.lower() == "python":
lexer = PythonLexer()
elif language.lower() == 'sql':
elif language.lower() == "sql":
lexer = SqlLexer()
else:
raise ValueError(f"Unsupported language: {language}")
# 指定输出格式
if formatter.lower() == 'terminal':
if formatter.lower() == "terminal":
formatter = TerminalFormatter()
elif formatter.lower() == 'html':
elif formatter.lower() == "html":
formatter = HtmlFormatter()
else:
raise ValueError(f"Unsupported formatter: {formatter}")

View file

@ -6,9 +6,9 @@
@File : mermaid.py
"""
import base64
import os
from aiohttp import ClientSession,ClientError
from aiohttp import ClientError, ClientSession
from metagpt.logs import logger
@ -29,7 +29,7 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix):
async with session.get(url) as response:
if response.status == 200:
text = await response.content.read()
with open(output_file, 'wb') as f:
with open(output_file, "wb") as f:
f.write(text)
logger.info(f"Generating {output_file}..")
else:

View file

@ -8,10 +8,13 @@
import os
from urllib.parse import urljoin
from playwright.async_api import async_playwright
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
"""
Converts the given Mermaid code to various output formats and saves them to files.
@ -24,66 +27,72 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
Returns:
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
"""
suffixes=['png', 'svg', 'pdf']
suffixes = ["png", "svg", "pdf"]
__dirname = os.path.dirname(os.path.abspath(__file__))
async with async_playwright() as p:
browser = await p.chromium.launch()
device_scale_factor = 1.0
context = await browser.new_context(
viewport={'width': width, 'height': height},
device_scale_factor=device_scale_factor,
)
viewport={"width": width, "height": height},
device_scale_factor=device_scale_factor,
)
page = await context.new_page()
async def console_message(msg):
logger.info(msg.text)
page.on('console', console_message)
page.on("console", console_message)
try:
await page.set_viewport_size({'width': width, 'height': height})
await page.set_viewport_size({"width": width, "height": height})
mermaid_html_path = os.path.abspath(
os.path.join(__dirname, 'index.html'))
mermaid_html_url = urljoin('file:', mermaid_html_path)
mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
mermaid_html_url = urljoin("file:", mermaid_html_path)
await page.goto(mermaid_html_url)
await page.wait_for_load_state("networkidle")
await page.wait_for_selector("div#container", state="attached")
mermaid_config = {}
# mermaid_config = {}
background_color = "#ffffff"
my_css = ""
# my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
#
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}''', [mermaid_code, mermaid_config, my_css, background_color])
if 'svg' in suffixes :
svg_xml = await page.evaluate('''() => {
if "svg" in suffixes:
svg_xml = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const xmlSerializer = new XMLSerializer();
return xmlSerializer.serializeToString(svg);
}''')
}"""
)
logger.info(f"Generating {output_file_without_suffix}.svg..")
with open(f'{output_file_without_suffix}.svg', 'wb') as f:
f.write(svg_xml.encode('utf-8'))
with open(f"{output_file_without_suffix}.svg", "wb") as f:
f.write(svg_xml.encode("utf-8"))
if 'png' in suffixes:
clip = await page.evaluate('''() => {
if "png" in suffixes:
clip = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const rect = svg.getBoundingClientRect();
return {
@ -92,16 +101,17 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
width: Math.ceil(rect.width),
height: Math.ceil(rect.height)
};
}''')
await page.set_viewport_size({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height']})
screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
}"""
)
await page.set_viewport_size({"width": clip["x"] + clip["width"], "height": clip["y"] + clip["height"]})
screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
logger.info(f"Generating {output_file_without_suffix}.png..")
with open(f'{output_file_without_suffix}.png', 'wb') as f:
with open(f"{output_file_without_suffix}.png", "wb") as f:
f.write(screenshot)
if 'pdf' in suffixes:
if "pdf" in suffixes:
pdf_data = await page.pdf(scale=device_scale_factor)
logger.info(f"Generating {output_file_without_suffix}.pdf..")
with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
with open(f"{output_file_without_suffix}.pdf", "wb") as f:
f.write(pdf_data)
return 0
except Exception as e:

View file

@ -7,11 +7,14 @@
"""
import os
from urllib.parse import urljoin
from pyppeteer import launch
from metagpt.logs import logger
from metagpt.config import CONFIG
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048)-> int:
from pyppeteer import launch
from metagpt.config import CONFIG
from metagpt.logs import logger
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
"""
Converts the given Mermaid code to various output formats and saves them to files.
@ -24,15 +27,15 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
Returns:
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
"""
suffixes = ['png', 'svg', 'pdf']
suffixes = ["png", "svg", "pdf"]
__dirname = os.path.dirname(os.path.abspath(__file__))
if CONFIG.pyppeteer_executable_path:
browser = await launch(headless=True,
executablePath=CONFIG.pyppeteer_executable_path,
args=['--disable-extensions',"--no-sandbox"]
)
browser = await launch(
headless=True,
executablePath=CONFIG.pyppeteer_executable_path,
args=["--disable-extensions", "--no-sandbox"],
)
else:
logger.error("Please set the environment variable:PYPPETEER_EXECUTABLE_PATH.")
return -1
@ -41,50 +44,56 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
async def console_message(msg):
logger.info(msg.text)
page.on('console', console_message)
page.on("console", console_message)
try:
await page.setViewport(viewport={'width': width, 'height': height, 'deviceScaleFactor': device_scale_factor})
await page.setViewport(viewport={"width": width, "height": height, "deviceScaleFactor": device_scale_factor})
mermaid_html_path = os.path.abspath(
os.path.join(__dirname, 'index.html'))
mermaid_html_url = urljoin('file:', mermaid_html_path)
mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html"))
mermaid_html_url = urljoin("file:", mermaid_html_path)
await page.goto(mermaid_html_url)
await page.querySelector("div#container")
mermaid_config = {}
# mermaid_config = {}
background_color = "#ffffff"
my_css = ""
# my_css = ""
await page.evaluate(f'document.body.style.background = "{background_color}";')
metadata = await page.evaluate('''async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
const { mermaid, zenuml } = globalThis;
await mermaid.registerExternalDiagrams([zenuml]);
mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
document.getElementById('container').innerHTML = svg;
const svgElement = document.querySelector('svg');
svgElement.style.backgroundColor = backgroundColor;
# metadata = await page.evaluate(
# """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
# const { mermaid, zenuml } = globalThis;
# await mermaid.registerExternalDiagrams([zenuml]);
# mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
# const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
# document.getElementById('container').innerHTML = svg;
# const svgElement = document.querySelector('svg');
# svgElement.style.backgroundColor = backgroundColor;
#
# if (myCSS) {
# const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
# style.appendChild(document.createTextNode(myCSS));
# svgElement.appendChild(style);
# }
# }""",
# [mermaid_code, mermaid_config, my_css, background_color],
# )
if (myCSS) {
const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
style.appendChild(document.createTextNode(myCSS));
svgElement.appendChild(style);
}
}''', [mermaid_code, mermaid_config, my_css, background_color])
if 'svg' in suffixes :
svg_xml = await page.evaluate('''() => {
if "svg" in suffixes:
svg_xml = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const xmlSerializer = new XMLSerializer();
return xmlSerializer.serializeToString(svg);
}''')
}"""
)
logger.info(f"Generating {output_file_without_suffix}.svg..")
with open(f'{output_file_without_suffix}.svg', 'wb') as f:
f.write(svg_xml.encode('utf-8'))
with open(f"{output_file_without_suffix}.svg", "wb") as f:
f.write(svg_xml.encode("utf-8"))
if 'png' in suffixes:
clip = await page.evaluate('''() => {
if "png" in suffixes:
clip = await page.evaluate(
"""() => {
const svg = document.querySelector('svg');
const rect = svg.getBoundingClientRect();
return {
@ -93,16 +102,23 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
width: Math.ceil(rect.width),
height: Math.ceil(rect.height)
};
}''')
await page.setViewport({'width': clip['x'] + clip['width'], 'height': clip['y'] + clip['height'], 'deviceScaleFactor': device_scale_factor})
screenshot = await page.screenshot(clip=clip, omit_background=True, scale='device')
}"""
)
await page.setViewport(
{
"width": clip["x"] + clip["width"],
"height": clip["y"] + clip["height"],
"deviceScaleFactor": device_scale_factor,
}
)
screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
logger.info(f"Generating {output_file_without_suffix}.png..")
with open(f'{output_file_without_suffix}.png', 'wb') as f:
with open(f"{output_file_without_suffix}.png", "wb") as f:
f.write(screenshot)
if 'pdf' in suffixes:
if "pdf" in suffixes:
pdf_data = await page.pdf(scale=device_scale_factor)
logger.info(f"Generating {output_file_without_suffix}.pdf..")
with open(f'{output_file_without_suffix}.pdf', 'wb') as f:
with open(f"{output_file_without_suffix}.pdf", "wb") as f:
f.write(pdf_data)
return 0
except Exception as e:
@ -110,4 +126,3 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
return -1
finally:
await browser.close()

View file

@ -16,7 +16,7 @@ class WebPage(BaseModel):
class Config:
underscore_attrs_are_private = True
_soup : Optional[BeautifulSoup] = None
_soup: Optional[BeautifulSoup] = None
_title: Optional[str] = None
@property
@ -24,7 +24,7 @@ class WebPage(BaseModel):
if self._soup is None:
self._soup = BeautifulSoup(self.html, "html.parser")
return self._soup
@property
def title(self):
if self._title is None:

View file

@ -37,12 +37,12 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
if not isinstance(expr, cst.Expr):
return None
val = expr.value
if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
return None
evaluated_value = val.evaluated_value
evaluated_value = val.evaluated_value
if isinstance(evaluated_value, bytes):
return None
@ -56,6 +56,7 @@ class DocstringCollector(cst.CSTVisitor):
stack: A list to keep track of the current path in the CST.
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
"""
def __init__(self):
self.stack: list[str] = []
self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}
@ -96,6 +97,7 @@ class DocstringTransformer(cst.CSTTransformer):
stack: A list to keep track of the current path in the CST.
docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
"""
def __init__(
self,
docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
@ -125,7 +127,9 @@ class DocstringTransformer(cst.CSTTransformer):
key = tuple(self.stack)
self.stack.pop()
if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators):
if hasattr(updated_node, "decorators") and any(
(i.decorator.value == "overload") for i in updated_node.decorators
):
return updated_node
statement = self.docstrings.get(key)

View file

@ -8,6 +8,7 @@
import docx
def read_docx(file_path: str) -> list:
"""Open a docx file"""
doc = docx.Document(file_path)

View file

@ -20,4 +20,3 @@ class Singleton(abc.ABCMeta, type):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]

View file

@ -1,4 +1,4 @@
# token to separate different code messages in a WriteCode Message content
MSG_SEP = "#*000*#"
MSG_SEP = "#*000*#"
# token to seperate file name and the actual code text in a code message
FILENAME_CODE_SEP = "#*001*#"

View file

@ -3,7 +3,12 @@ from typing import Generator, Sequence
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
def reduce_message_length(
msgs: Generator[str, None, None],
model_name: str,
system_text: str,
reserved: int = 0,
) -> str:
"""Reduce the length of concatenated message segments to fit within the maximum token size.
Args:
@ -49,9 +54,9 @@ def generate_prompt_chunk(
current_token = 0
current_lines = []
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
reserved = reserved + count_string_tokens(prompt_template + system_text, model_name)
# 100 is a magic number to ensure the maximum context length is not exceeded
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
while paragraphs:
paragraph = paragraphs.pop(0)
@ -103,7 +108,7 @@ def decode_unicode_escape(text: str) -> str:
return text.encode("utf-8").decode("unicode_escape", "ignore")
def _split_by_count(lst: Sequence , count: int):
def _split_by_count(lst: Sequence, count: int):
avg = len(lst) // count
remainder = len(lst) % count
start = 0

View file

@ -44,3 +44,4 @@ ta==0.10.2
semantic-kernel==0.3.13.dev0
wrapt==1.15.0
websocket-client==0.58.0

View file

@ -6,14 +6,14 @@
@File : conftest.py
"""
import asyncio
import re
from unittest.mock import Mock
import pytest
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
import asyncio
import re
class Context:

View file

@ -311,12 +311,10 @@ TASKS = [
"添加数据API接受用户输入的文档库对文档库进行索引\n- 使用MeiliSearch连接并添加文档库",
"搜索API接收用户输入的关键词返回相关的搜索结果\n- 使用MeiliSearch连接并使用接口获得对应数据",
"多条件筛选API接收用户选择的筛选条件返回符合条件的搜索结果。\n- 使用MeiliSearch进行筛选并返回符合条件的搜索结果",
"智能推荐API根据用户的搜索历史记录和搜索行为推荐相关的搜索结果。"
"智能推荐API根据用户的搜索历史记录和搜索行为推荐相关的搜索结果。",
]
TASKS_2 = [
"完成main.py的功能"
]
TASKS_2 = ["完成main.py的功能"]
SEARCH_CODE_SAMPLE = """
import requests
@ -460,7 +458,7 @@ if __name__ == '__main__':
print('No results found.')
'''
MEILI_CODE = '''import meilisearch
MEILI_CODE = """import meilisearch
from typing import List
@ -496,9 +494,9 @@ if __name__ == '__main__':
# 添加文档库到搜索引擎
search_engine.add_documents(books_data_source, documents)
'''
"""
MEILI_ERROR = '''/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
MEILI_ERROR = """/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
Traceback (most recent call last):
File "/Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py", line 44, in <module>
search_engine.add_documents(books_data_source, documents)
@ -506,7 +504,7 @@ Traceback (most recent call last):
index = self.client.get_or_create_index(index_name)
AttributeError: 'Client' object has no attribute 'get_or_create_index'
Process finished with exit code 1'''
Process finished with exit code 1"""
MEILI_CODE_REFINED = """
"""

View file

@ -9,18 +9,21 @@ from typing import List, Tuple
from metagpt.actions import ActionOutput
t_dict = {"Required Python third-party packages": "\"\"\"\nflask==1.1.2\npygame==2.0.1\n\"\"\"\n",
"Required Other language third-party packages": "\"\"\"\nNo third-party packages required for other languages.\n\"\"\"\n",
"Full API spec": "\"\"\"\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n '200':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n '200':\n description: A JSON object of the updated game state\n\"\"\"\n",
"Logic Analysis": [
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
["static/js/script.js", "Handles user interactions and updates the game UI."],
["static/css/styles.css", "Defines the styles for the game UI."],
["templates/index.html", "The main page of the web application. Displays the game UI."]],
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?"}
t_dict = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
"Logic Analysis": [
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
["static/js/script.js", "Handles user interactions and updates the game UI."],
["static/css/styles.css", "Defines the styles for the game UI."],
["templates/index.html", "The main page of the web application. Displays the game UI."],
],
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
}
WRITE_TASKS_OUTPUT_MAPPING = {
"Required Python third-party packages": (str, ...),
@ -45,6 +48,6 @@ def test_create_model_class_with_mapping():
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
if __name__ == '__main__':
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()

View file

@ -10,12 +10,7 @@ from metagpt.actions.azure_tts import AzureTTS
def test_azure_tts():
azure_tts = AzureTTS("azure_tts")
azure_tts.synthesize_speech(
"zh-CN",
"zh-CN-YunxiNeural",
"Boy",
"你好,我是卡卡",
"output.wav")
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav")
# 运行需要先配置 SUBSCRIPTION_KEY
# TODO: 这里如果要检验还要额外加上对应的asr才能确保前后生成是接近一致的但现在还没有

View file

@ -2,7 +2,6 @@ import pytest
from metagpt.actions.clone_function import CloneFunction, run_function_code
source_code = """
import pandas as pd
import ta
@ -31,14 +30,18 @@ def get_expected_res():
import ta
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
stock_data[["Date", "Close", "SMA"]].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
)
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
return stock_data
@ -46,9 +49,9 @@ def get_expected_res():
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert 'def ' in code
stock_path = './tests/data/baba_stock.csv'
df, msg = run_function_code(code, 'stock_indicator', stock_path)
assert "def " in code
stock_path = "./tests/data/baba_stock.csv"
df, msg = run_function_code(code, "stock_indicator", stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)

View file

@ -144,12 +144,12 @@ Engineer
---
'''
@pytest.mark.asyncio
async def test_debug_error():
debug_error = DebugError("debug_error")
file_name, rewritten_code = await debug_error.run(context=EXAMPLE_MSG_CONTENT)
assert "class Player" in rewritten_code # rewrite the same class
assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12")
assert "class Player" in rewritten_code # rewrite the same class
assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12")

View file

@ -10,6 +10,7 @@ import pytest
from metagpt.actions.detail_mining import DetailMining
from metagpt.logs import logger
@pytest.mark.asyncio
async def test_detail_mining():
topic = "如何做一个生日蛋糕"
@ -17,7 +18,6 @@ async def test_detail_mining():
detail_mining = DetailMining("detail_mining")
rsp = await detail_mining.run(topic=topic, record=record)
logger.info(f"{rsp.content=}")
assert '##OUTPUT' in rsp.content
assert '蛋糕' in rsp.content
assert "##OUTPUT" in rsp.content
assert "蛋糕" in rsp.content

View file

@ -8,12 +8,11 @@
"""
import os
from typing import List
import pytest
from pathlib import Path
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
import pytest
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
@pytest.mark.asyncio
@ -22,7 +21,7 @@ from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
[
"../../data/invoices/invoice-3.jpg",
"../../data/invoices/invoice-4.zip",
]
],
)
async def test_invoice_ocr(invoice_path: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
@ -35,18 +34,8 @@ async def test_invoice_ocr(invoice_path: str):
@pytest.mark.parametrize(
("invoice_path", "expected_result"),
[
(
"../../data/invoices/invoice-1.pdf",
[
{
"收款人": "小明",
"城市": "深圳市",
"总费用/元": "412.00",
"开票日期": "2023年02月03日"
}
]
),
]
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
],
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
@ -59,9 +48,7 @@ async def test_generate_table(invoice_path: str, expected_result: list[dict]):
@pytest.mark.asyncio
@pytest.mark.parametrize(
("invoice_path", "query", "expected_result"),
[
("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")
]
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
)
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
@ -69,4 +56,3 @@ async def test_reply_question(invoice_path: str, query: dict, expected_result: s
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
assert expected_result in result

View file

@ -4,7 +4,7 @@
#
from tests.metagpt.roles.ui_role import UIDesign
llm_resp= '''
llm_resp = """
# UI Design Description
```The user interface for the snake game will be designed in a way that is simple, clean, and intuitive. The main elements of the game such as the game grid, snake, food, score, and game over message will be clearly defined and easy to understand. The game grid will be centered on the screen with the score displayed at the top. The game controls will be intuitive and easy to use. The design will be modern and minimalist with a pleasing color scheme.```
@ -98,12 +98,13 @@ body {
left: 50%;
transform: translate(-50%, -50%);
font-size: 3em;
'''
"""
def test_ui_design_parse_css():
ui_design_work = UIDesign(name="UI design action")
css = '''
css = """
body {
display: flex;
flex-direction: column;
@ -160,14 +161,14 @@ def test_ui_design_parse_css():
left: 50%;
transform: translate(-50%, -50%);
font-size: 3em;
'''
assert ui_design_work.parse_css_code(context=llm_resp)==css
"""
assert ui_design_work.parse_css_code(context=llm_resp) == css
def test_ui_design_parse_html():
ui_design_work = UIDesign(name="UI design action")
html = '''
html = """
<!DOCTYPE html>
<html lang="en">
<head>
@ -184,8 +185,5 @@ def test_ui_design_parse_html():
<div class="game-over">Game Over</div>
</body>
</html>
'''
assert ui_design_work.parse_css_code(context=llm_resp)==html
"""
assert ui_design_work.parse_css_code(context=llm_resp) == html

View file

@ -22,13 +22,13 @@ async def test_write_code():
logger.info(code)
# 我们不能精确地预测生成的代码,但我们可以检查某些关键字
assert 'def add' in code
assert 'return' in code
assert "def add" in code
assert "return" in code
@pytest.mark.asyncio
async def test_write_code_directly():
prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0]
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
llm = LLM()
rsp = await llm.aask(prompt)
logger.info(rsp)

View file

@ -2,7 +2,7 @@ import pytest
from metagpt.actions.write_docstring import WriteDocstring
code = '''
code = """
def add_numbers(a: int, b: int):
return a + b
@ -14,7 +14,7 @@ class Person:
def greet(self):
return f"Hello, my name is {self.name} and I am {self.age} years old."
'''
"""
@pytest.mark.asyncio
@ -25,7 +25,7 @@ class Person:
("numpy", "Parameters"),
("sphinx", ":param name:"),
],
ids=["google", "numpy", "sphinx"]
ids=["google", "numpy", "sphinx"],
)
async def test_write_docstring(style: str, part: str):
ret = await WriteDocstring().run(code, style=style)

View file

@ -9,14 +9,11 @@ from typing import Dict
import pytest
from metagpt.actions.write_tutorial import WriteDirectory, WriteContent
from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
@pytest.mark.asyncio
@pytest.mark.parametrize(
("language", "topic"),
[("English", "Write a tutorial about Python")]
)
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
async def test_write_directory(language: str, topic: str):
ret = await WriteDirectory(language=language).run(topic=topic)
assert isinstance(ret, dict)
@ -30,7 +27,7 @@ async def test_write_directory(language: str, topic: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
("language", "topic", "directory"),
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})]
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
)
async def test_write_content(language: str, topic: str, directory: Dict):
ret = await WriteContent(language=language, directory=directory).run(topic=topic)

View file

@ -12,12 +12,12 @@ from metagpt.document_store.chromadb_store import ChromaStore
def test_chroma_store():
"""FIXMEchroma使用感觉很诡异一用Python就挂测试用例里也是"""
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
document_store = ChromaStore('sample_collection_1')
document_store = ChromaStore("sample_collection_1")
# 使用 write 方法添加多个文档
document_store.write(["This is document1", "This is document2"],
[{"source": "google-docs"}, {"source": "notion"}],
["doc1", "doc2"])
document_store.write(
["This is document1", "This is document2"], [{"source": "google-docs"}, {"source": "notion"}], ["doc1", "doc2"]
)
# 使用 add 方法添加一个文档
document_store.add("This is document3", {"source": "notion"}, "doc3")

View file

@ -39,11 +39,11 @@ user: 没有了
@pytest.mark.asyncio
async def test_faiss_store_search():
store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
store.add(['油皮洗面奶'])
store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
store.add(["油皮洗面奶"])
role = Sales(store=store)
queries = ['油皮洗面奶', '介绍下欧莱雅的']
queries = ["油皮洗面奶", "介绍下欧莱雅的"]
for query in queries:
rsp = await role.run(query)
assert rsp
@ -60,7 +60,10 @@ def customer_service():
async def test_faiss_store_customer_service():
allq = [
# ["我的餐怎么两小时都没到", "退货吧"],
["你好收不到取餐码,麻烦帮我开箱", "14750187158", ]
[
"你好收不到取餐码,麻烦帮我开箱",
"14750187158",
]
]
role = customer_service()
for queries in allq:
@ -71,4 +74,4 @@ async def test_faiss_store_customer_service():
def test_faiss_store_no_file():
with pytest.raises(FileNotFoundError):
FaissStore(DATA_PATH / 'wtf.json')
FaissStore(DATA_PATH / "wtf.json")

View file

@ -5,27 +5,33 @@
@Author : unkn-wn (Leon Yee)
@File : test_lancedb_store.py
"""
from metagpt.document_store.lancedb_store import LanceStore
import pytest
import random
import pytest
from metagpt.document_store.lancedb_store import LanceStore
@pytest
def test_lance_store():
# This simply establishes the connection to the database, so we can drop the table if it exists
store = LanceStore('test')
store = LanceStore("test")
store.drop('test')
store.drop("test")
store.write(data=[[random.random() for _ in range(100)] for _ in range(2)],
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
ids=["doc1", "doc2"])
store.write(
data=[[random.random() for _ in range(100)] for _ in range(2)],
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
ids=["doc1", "doc2"],
)
store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3")
result = store.search([random.random() for _ in range(100)], n_results=3)
assert(len(result) == 3)
assert len(result) == 3
store.delete("doc2")
result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine')
assert(len(result) == 1)
result = store.search(
[random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric="cosine"
)
assert len(result) == 1

View file

@ -12,7 +12,7 @@ import numpy as np
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
from metagpt.logs import logger
book_columns = {'idx': int, 'name': str, 'desc': str, 'emb': np.ndarray, 'price': float}
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
book_data = [
[i for i in range(10)],
[f"book-{i}" for i in range(10)],
@ -25,12 +25,12 @@ book_data = [
def test_milvus_store():
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
milvus_store = MilvusStore(milvus_connection)
milvus_store.drop('Book')
milvus_store.create_collection('Book', book_columns)
milvus_store.drop("Book")
milvus_store.create_collection("Book", book_columns)
milvus_store.add(book_data)
milvus_store.build_index('emb')
milvus_store.build_index("emb")
milvus_store.load_collection()
results = milvus_store.search([[1.0, 1.0]], field='emb')
results = milvus_store.search([[1.0, 1.0]], field="emb")
logger.info(results)
assert results

View file

@ -24,9 +24,7 @@ random.seed(seed_value)
vectors = [[random.random() for _ in range(2)] for _ in range(10)]
points = [
PointStruct(
id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
)
PointStruct(id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10})
for idx, vector in enumerate(vectors)
]
@ -57,9 +55,7 @@ def test_milvus_store():
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
),
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
)
assert results[0]["id"] == 8
assert results[0]["score"] == 0.9100373450784073
@ -68,9 +64,7 @@ def test_milvus_store():
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
),
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
return_vector=True,
)
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]

View file

@ -30,7 +30,7 @@ def test_skill_manager():
rsp = manager.retrieve_skill("写测试用例")
logger.info(rsp)
assert rsp[0] == 'WriteTest'
assert rsp[0] == "WriteTest"
rsp = manager.retrieve_skill_scored("写PRD")
logger.info(rsp)

View file

@ -2,11 +2,11 @@
# -*- coding: utf-8 -*-
# @Desc : unittest of `metagpt/memory/longterm_memory.py`
from metagpt.config import CONFIG
from metagpt.schema import Message
from metagpt.actions import BossRequirement
from metagpt.roles.role import RoleContext
from metagpt.config import CONFIG
from metagpt.memory import LongTermMemory
from metagpt.roles.role import RoleContext
from metagpt.schema import Message
def test_ltm_search():
@ -14,25 +14,25 @@ def test_ltm_search():
openai_api_key = CONFIG.openai_api_key
assert len(openai_api_key) > 20
role_id = 'UTUserLtm(Product Manager)'
role_id = "UTUserLtm(Product Manager)"
rc = RoleContext(watch=[BossRequirement])
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)
idea = 'Write a cli snake game'
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
idea = "Write a cli snake game"
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
news = ltm.find_news([message])
assert len(news) == 1
ltm.add(message)
sim_idea = 'Write a game of cli snake'
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
sim_idea = "Write a game of cli snake"
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
news = ltm.find_news([sim_message])
assert len(news) == 0
ltm.add(sim_message)
new_idea = 'Write a 2048 web game'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
new_idea = "Write a 2048 web game"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
news = ltm.find_news([new_message])
assert len(news) == 1
ltm.add(new_message)
@ -47,8 +47,8 @@ def test_ltm_search():
news = ltm_new.find_news([sim_message])
assert len(news) == 0
new_idea = 'Write a Battle City'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
new_idea = "Write a Battle City"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
news = ltm_new.find_news([new_message])
assert len(news) == 1

View file

@ -4,17 +4,16 @@
from typing import List
from metagpt.actions import BossRequirement, WritePRD
from metagpt.actions.action_output import ActionOutput
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
from metagpt.actions import BossRequirement
from metagpt.actions import WritePRD
from metagpt.actions.action_output import ActionOutput
def test_idea_message():
idea = 'Write a cli snake game'
role_id = 'UTUser1(Product Manager)'
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
idea = "Write a cli snake game"
role_id = "UTUser1(Product Manager)"
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
@ -23,13 +22,13 @@ def test_idea_message():
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_idea = 'Write a game of cli snake'
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
sim_idea = "Write a game of cli snake"
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
new_messages = memory_storage.search(sim_message)
assert len(new_messages) == 0 # similar, return []
assert len(new_messages) == 0 # similar, return []
new_idea = 'Write a 2048 web game'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
new_idea = "Write a 2048 web game"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
new_messages = memory_storage.search(new_message)
assert new_messages[0].content == message.content
@ -38,22 +37,15 @@ def test_idea_message():
def test_actionout_message():
out_mapping = {
'field1': (str, ...),
'field2': (List[str], ...)
}
out_data = {
'field1': 'field1 value',
'field2': ['field2 value1', 'field2 value2']
}
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
role_id = 'UTUser2(Architect)'
content = 'The boss has requested the creation of a command-line interface (CLI) snake game'
message = Message(content=content,
instruct_content=ic_obj(**out_data),
role='user',
cause_by=WritePRD) # WritePRD as test action
role_id = "UTUser2(Architect)"
content = "The boss has requested the creation of a command-line interface (CLI) snake game"
message = Message(
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
) # WritePRD as test action
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
@ -62,19 +54,13 @@ def test_actionout_message():
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_conent = 'The request is command-line interface (CLI) snake game'
sim_message = Message(content=sim_conent,
instruct_content=ic_obj(**out_data),
role='user',
cause_by=WritePRD)
sim_conent = "The request is command-line interface (CLI) snake game"
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(sim_message)
assert len(new_messages) == 0 # similar, return []
assert len(new_messages) == 0 # similar, return []
new_conent = 'Incorporate basic features of a snake game such as scoring and increasing difficulty'
new_message = Message(content=new_conent,
instruct_content=ic_obj(**out_data),
role='user',
cause_by=WritePRD)
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(new_message)
assert new_messages[0].content == message.content

Some files were not shown because too many files have changed in this diff Show more