Merge branch 'main' into main

Commit d99054ab5e by better629, 2024-10-17 16:25:31 +08:00 (committed via GitHub)
98 changed files with 1697 additions and 496 deletions

.coveragerc (new file, +7)

@@ -0,0 +1,7 @@
[run]
source =
./metagpt/
omit =
*/metagpt/environment/android/*
*/metagpt/ext/android_assistant/*
*/metagpt/ext/werewolf/*

.gitattributes (vendored, +1)

@@ -14,6 +14,7 @@
*.ico binary
*.jpeg binary
*.mp3 binary
*.mp4 binary
*.zip binary
*.bin binary

@@ -30,7 +30,10 @@ jobs:
cache: 'pip'
- name: Install dependencies
run: |
sh tests/scripts/run_install_deps.sh
python -m pip install --upgrade pip
pip install -e .[test]
npm install -g @mermaid-js/mermaid-cli
playwright install --with-deps
- name: Run reverse proxy script for ssh service
if: contains(github.ref, '-debugger')
continue-on-error: true

@@ -27,20 +27,57 @@ jobs:
cache: 'pip'
- name: Install dependencies
run: |
sh tests/scripts/run_install_deps.sh
python -m pip install --upgrade pip
pip install -e .[test]
npm install -g @mermaid-js/mermaid-cli
playwright install --with-deps
- name: Test with pytest
run: |
export ALLOW_OPENAI_API_CALL=0
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
pytest --continue-on-collection-errors tests/ \
--ignore=tests/metagpt/environment/android_env \
--ignore=tests/metagpt/ext/android_assistant \
--ignore=tests/metagpt/ext/stanford_town \
--ignore=tests/metagpt/provider/test_bedrock_api.py \
--ignore=tests/metagpt/rag/factories/test_embedding.py \
--ignore=tests/metagpt/ext/werewolf/actions/test_experience_operation.py \
--ignore=tests/metagpt/provider/test_openai.py \
--ignore=tests/metagpt/planner/test_action_planner.py \
--ignore=tests/metagpt/planner/test_basic_planner.py \
--ignore=tests/metagpt/actions/test_project_management.py \
--ignore=tests/metagpt/actions/test_write_code.py \
--ignore=tests/metagpt/actions/test_write_code_review.py \
--ignore=tests/metagpt/actions/test_write_prd.py \
--ignore=tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py \
--ignore=tests/metagpt/memory/test_brain_memory.py \
--ignore=tests/metagpt/roles/test_assistant.py \
--ignore=tests/metagpt/roles/test_engineer.py \
--ignore=tests/metagpt/serialize_deserialize/test_write_code_review.py \
--ignore=tests/metagpt/test_environment.py \
--ignore=tests/metagpt/test_llm.py \
--ignore=tests/metagpt/tools/test_metagpt_oas3_api_svc.py \
--ignore=tests/metagpt/tools/test_moderation.py \
--ignore=tests/metagpt/tools/test_search_engine.py \
--ignore=tests/metagpt/tools/test_tool_convert.py \
--ignore=tests/metagpt/tools/test_web_browser_engine_playwright.py \
--ignore=tests/metagpt/utils/test_mermaid.py \
--ignore=tests/metagpt/utils/test_redis.py \
--ignore=tests/metagpt/utils/test_tree.py \
--ignore=tests/metagpt/serialize_deserialize/test_sk_agent.py \
--ignore=tests/metagpt/utils/test_text.py \
--ignore=tests/metagpt/actions/di/test_write_analysis_code.py \
--ignore=tests/metagpt/provider/test_ark.py \
--doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov \
--durations=20 | tee unittest.txt
- name: Show coverage report
run: |
coverage report -m
- name: Show failed tests and overall summary
run: |
grep -E "FAILED tests|ERROR tests|[0-9]+ passed," unittest.txt
failed_count=$(grep -E "FAILED|ERROR" unittest.txt | wc -l)
if [[ "$failed_count" -gt 0 ]]; then
failed_count=$(grep -E "FAILED tests|ERROR tests" unittest.txt | wc -l | tr -d '[:space:]')
if [[ $failed_count -gt 0 ]]; then
echo "$failed_count failed lines found! Task failed."
exit 1
fi

@@ -31,7 +31,7 @@ ## News
🚀 Feb. 08, 2024: [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) released, supporting assigning different LLMs to different Roles. We also introduced [Data Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/di/README.md), a powerful agent capable of solving a wide range of real-world problems.
🚀 Jan. 16, 2024: Our paper [MetaGPT: Meta Programming for A Multi-Agent Collaborative Framework
](https://arxiv.org/abs/2308.00352) accepted for **oral presentation (top 1.2%)** at ICLR 2024, **ranking #1** in the LLM-based Agent category.
](https://openreview.net/forum?id=VtmBAGCN7o) accepted for **oral presentation (top 1.2%)** at ICLR 2024, **ranking #1** in the LLM-based Agent category.
🚀 Jan. 03, 2024: [v0.6.0](https://github.com/geekan/MetaGPT/releases/tag/v0.6.0) released, new features include serialization, upgraded OpenAI package and supported multiple LLM, provided [minimal example for debate](https://github.com/geekan/MetaGPT/blob/main/examples/debate_simple.py) etc.
@@ -59,7 +59,7 @@ ## Get Started
### Installation
> Ensure that Python 3.9+ is installed on your system. You can check this by using: `python --version`.
> Ensure that Python 3.9 or later, but less than 3.12, is installed on your system. You can check this by using: `python --version`.
> You can use conda like this: `conda create -n metagpt python=3.9 && conda activate metagpt`
```bash
@@ -166,16 +166,15 @@ ## Citation
To stay updated with the latest research and development, follow [@MetaGPT_](https://twitter.com/MetaGPT_) on Twitter.
To cite [MetaGPT](https://arxiv.org/abs/2308.00352) or [Data Interpreter](https://arxiv.org/abs/2402.18679) in publications, please use the following BibTeX entries.
To cite [MetaGPT](https://openreview.net/forum?id=VtmBAGCN7o) or [Data Interpreter](https://arxiv.org/abs/2402.18679) in publications, please use the following BibTeX entries.
```bibtex
@misc{hong2023metagpt,
title={MetaGPT: Meta Programming for A Multi-Agent Collaborative Framework},
author={Sirui Hong and Mingchen Zhuge and Jonathan Chen and Xiawu Zheng and Yuheng Cheng and Ceyao Zhang and Jinlin Wang and Zili Wang and Steven Ka Shing Yau and Zijuan Lin and Liyang Zhou and Chenyu Ran and Lingfeng Xiao and Chenglin Wu and Jürgen Schmidhuber},
year={2023},
eprint={2308.00352},
archivePrefix={arXiv},
primaryClass={cs.AI}
@inproceedings{hong2024metagpt,
title={Meta{GPT}: Meta Programming for A Multi-Agent Collaborative Framework},
author={Sirui Hong and Mingchen Zhuge and Jonathan Chen and Xiawu Zheng and Yuheng Cheng and Jinlin Wang and Ceyao Zhang and Zili Wang and Steven Ka Shing Yau and Zijuan Lin and Liyang Zhou and Chenyu Ran and Lingfeng Xiao and Chenglin Wu and J{\"u}rgen Schmidhuber},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=VtmBAGCN7o}
}
@misc{hong2024data,
title={Data Interpreter: An LLM Agent For Data Science},
@@ -185,6 +184,4 @@ ## Citation
archivePrefix={arXiv},
primaryClass={cs.AI}
}
```

@@ -59,3 +59,27 @@ iflytek_api_key: "YOUR_API_KEY"
iflytek_api_secret: "YOUR_API_SECRET"
metagpt_tti_url: "YOUR_MODEL_URL"
omniparse:
api_key: "YOUR_API_KEY"
base_url: "YOUR_BASE_URL"
models:
# "YOUR_MODEL_NAME_1 or YOUR_API_TYPE_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo
# api_type: "openai" # or azure / ollama / groq etc.
# base_url: "YOUR_BASE_URL"
# api_key: "YOUR_API_KEY"
# proxy: "YOUR_PROXY" # for LLM API requests
# # timeout: 600 # Optional. If set to 0, default value is 300.
# # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
# pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
# "YOUR_MODEL_NAME_2 or YOUR_API_TYPE_2": # api_type: "openai" # or azure / ollama / groq etc.
# api_type: "openai" # or azure / ollama / groq etc.
# base_url: "YOUR_BASE_URL"
# api_key: "YOUR_API_KEY"
# proxy: "YOUR_PROXY" # for LLM API requests
# # timeout: 600 # Optional. If set to 0, default value is 300.
# # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
# pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
agentops_api_key: "YOUR_AGENTOPS_API_KEY" # get key from https://app.agentops.ai/settings/projects

@@ -2,4 +2,4 @@ llm:
api_type: 'claude' # or anthropic
base_url: 'https://api.anthropic.com'
api_key: 'YOUR_API_KEY'
model: 'claude-3-opus-20240229'
model: 'claude-3-5-sonnet-20240620' # or 'claude-3-opus-20240229'

@@ -119,13 +119,12 @@ ## Citation
If you use MetaGPT or Data Interpreter in your research papers, please cite our work:
```bibtex
@misc{hong2023metagpt,
title={MetaGPT: Meta Programming for Multi-Agent Collaborative Framework},
author={Sirui Hong and Xiawu Zheng and Jonathan Chen and Yuheng Cheng and Jinlin Wang and Ceyao Zhang and Zili Wang and Steven Ka Shing Yau and Zijuan Lin and Liyang Zhou and Chenyu Ran and Lingfeng Xiao and Chenglin Wu},
year={2023},
eprint={2308.00352},
archivePrefix={arXiv},
primaryClass={cs.AI}
@inproceedings{hong2024metagpt,
title={Meta{GPT}: Meta Programming for A Multi-Agent Collaborative Framework},
author={Sirui Hong and Mingchen Zhuge and Jonathan Chen and Xiawu Zheng and Yuheng Cheng and Jinlin Wang and Ceyao Zhang and Zili Wang and Steven Ka Shing Yau and Zijuan Lin and Liyang Zhou and Chenyu Ran and Lingfeng Xiao and Chenglin Wu and J{\"u}rgen Schmidhuber},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=VtmBAGCN7o}
}
@misc{hong2024data,
title={Data Interpreter: An LLM Agent For Data Science},

@@ -298,13 +298,12 @@ ## Citation
If you use MetaGPT or Data Interpreter in your research papers, please cite our work as follows:
```bibtex
@misc{hong2023metagpt,
title={MetaGPT: Meta Programming for A Multi-Agent Collaborative Framework},
author={Sirui Hong and Mingchen Zhuge and Jonathan Chen and Xiawu Zheng and Yuheng Cheng and Ceyao Zhang and Jinlin Wang and Zili Wang and Steven Ka Shing Yau and Zijuan Lin and Liyang Zhou and Chenyu Ran and Lingfeng Xiao and Chenglin Wu and Jürgen Schmidhuber},
year={2023},
eprint={2308.00352},
archivePrefix={arXiv},
primaryClass={cs.AI}
@inproceedings{hong2024metagpt,
title={Meta{GPT}: Meta Programming for A Multi-Agent Collaborative Framework},
author={Sirui Hong and Mingchen Zhuge and Jonathan Chen and Xiawu Zheng and Yuheng Cheng and Jinlin Wang and Ceyao Zhang and Zili Wang and Steven Ka Shing Yau and Zijuan Lin and Liyang Zhou and Chenyu Ran and Lingfeng Xiao and Chenglin Wu and J{\"u}rgen Schmidhuber},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=VtmBAGCN7o}
}
@misc{hong2024data,
title={Data Interpreter: An LLM Agent For Data Science},


examples/rag/omniparse.py (new file, +64)

@@ -0,0 +1,64 @@
import asyncio
from metagpt.config2 import config
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.logs import logger
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
from metagpt.utils.omniparse_client import OmniParseClient
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
async def omniparse_client_example():
client = OmniParseClient(base_url=config.omniparse.base_url)
# docx
with open(TEST_DOCX, "rb") as f:
file_input = f.read()
document_parse_ret = await client.parse_document(file_input=file_input, bytes_filename="test_01.docx")
logger.info(document_parse_ret)
# pdf
pdf_parse_ret = await client.parse_pdf(file_input=TEST_PDF)
logger.info(pdf_parse_ret)
# video
video_parse_ret = await client.parse_video(file_input=TEST_VIDEO)
logger.info(video_parse_ret)
# audio
audio_parse_ret = await client.parse_audio(file_input=TEST_AUDIO)
logger.info(audio_parse_ret)
async def omniparse_example():
parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(
parse_type=OmniParseType.PDF,
result_type=ParseResultType.MD,
max_timeout=120,
num_workers=3,
),
)
ret = parser.load_data(file_path=TEST_PDF)
logger.info(ret)
file_paths = [TEST_DOCX, TEST_PDF]
parser.parse_type = OmniParseType.DOCUMENT
ret = await parser.aload_data(file_path=file_paths)
logger.info(ret)
async def main():
await omniparse_client_example()
await omniparse_example()
if __name__ == "__main__":
asyncio.run(main())

@@ -2,7 +2,7 @@
import asyncio
from examples.rag_pipeline import DOC_PATH, QUESTION
from examples.rag.rag_pipeline import DOC_PATH, QUESTION
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.roles import Sales

@@ -1,82 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/13 12:36
@Author : femto Zheng
@File : sk_agent.py
"""
import asyncio
from semantic_kernel.core_skills import FileIOSkill, MathSkill, TextSkill, TimeSkill
from semantic_kernel.planning import SequentialPlanner
# from semantic_kernel.planning import SequentialPlanner
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from metagpt.actions import UserRequirement
from metagpt.const import SKILL_DIRECTORY
from metagpt.roles.sk_agent import SkAgent
from metagpt.schema import Message
from metagpt.tools.search_engine import SkSearchEngine
async def main():
# await basic_planner_example()
# await action_planner_example()
# await sequential_planner_example()
await basic_planner_web_search_example()
async def basic_planner_example():
task = """
Tomorrow is Valentine's day. I need to come up with a few date ideas. She speaks French so write it in French.
Convert the text to uppercase"""
role = SkAgent()
# let's give the agent some skills
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "SummarizeSkill")
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill")
role.import_skill(TextSkill(), "TextSkill")
# using BasicPlanner
await role.run(Message(content=task, cause_by=UserRequirement))
async def sequential_planner_example():
task = """
Tomorrow is Valentine's day. I need to come up with a few date ideas. She speaks French so write it in French.
Convert the text to uppercase"""
role = SkAgent(planner_cls=SequentialPlanner)
# let's give the agent some skills
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "SummarizeSkill")
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill")
role.import_skill(TextSkill(), "TextSkill")
# using BasicPlanner
await role.run(Message(content=task, cause_by=UserRequirement))
async def basic_planner_web_search_example():
task = """
Question: Who made the 1989 comic book, the film version of which Jon Raymond Polito appeared in?"""
role = SkAgent()
role.import_skill(SkSearchEngine(), "WebSearchSkill")
# role.import_semantic_skill_from_directory(skills_directory, "QASkill")
await role.run(Message(content=task, cause_by=UserRequirement))
async def action_planner_example():
role = SkAgent(planner_cls=ActionPlanner)
# let's give the agent 4 skills
role.import_skill(MathSkill(), "math")
role.import_skill(FileIOSkill(), "fileIO")
role.import_skill(TimeSkill(), "time")
role.import_skill(TextSkill(), "text")
task = "What is the sum of 110 and 990?"
await role.run(Message(content=task, cause_by=UserRequirement)) # it will choose mathskill.Add
if __name__ == "__main__":
asyncio.run(main())

examples/ui_with_chainlit/.gitignore (vendored, new file, +3)

@@ -0,0 +1,3 @@
*.chainlit
chainlit.md
.files

@@ -0,0 +1,34 @@
# MetaGPT in UI with Chainlit! 🤖
- MetaGPT functionality in UI using Chainlit.
- It takes a **one line requirement** as input and outputs **user stories / competitive analysis / requirements / data structures / APIs / documents, etc.**, but with `everything in the UI`.
## Install Chainlit
- Set up the initial MetaGPT config as described in the [main README](../../README.md).
```bash
pip install chainlit
```
## Usage
```bash
chainlit run app.py
```
- Now go to: http://localhost:8000
- Select one of:
  - `Create a 2048 game`
  - `Write a cli Blackjack Game`
  - `Type your own message...`
- It will run a MetaGPT software company.
## To set up with your own application
- We can override `Environment.run`, `Team.run`, `Role.run`, `Role._act`, or `Action.run`.
- This code overrides `Environment.run`, as it was the easiest place to hook in.
- We also need `metagpt.logs.set_llm_stream_logfunc` to stream messages into the UI as a Chainlit Message; see the sketch below.
- To surface output anywhere else, call `chainlit.Message(content="").send()` with the content.
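
A minimal sketch of that streaming hook, mirroring what `init_setup.py` in this change does (the module-level message handle is this example's own simplification):

```python
import chainlit as cl
from metagpt.logs import set_llm_stream_logfunc

chainlit_message = cl.Message(content="")  # the message currently being streamed to

def log_llm_stream_chainlit(msg: str) -> None:
    # Forward each LLM token into the active Chainlit message.
    cl.run_sync(chainlit_message.stream_token(msg))

set_llm_stream_logfunc(func=log_llm_stream_chainlit)
```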

@@ -0,0 +1,83 @@
import chainlit as cl
from init_setup import ChainlitEnv
from metagpt.roles import (
Architect,
Engineer,
ProductManager,
ProjectManager,
QaEngineer,
)
from metagpt.team import Team
# https://docs.chainlit.io/concepts/starters
@cl.set_chat_profiles
async def chat_profile() -> list[cl.ChatProfile]:
"""Generates a chat profile containing starter messages which can be triggered to run MetaGPT
Returns:
list[chainlit.ChatProfile]: List of Chat Profile
"""
return [
cl.ChatProfile(
name="MetaGPT",
icon="/public/MetaGPT-new-log.jpg",
markdown_description="It takes a **one line requirement** as input and outputs **user stories / competitive analysis / requirements / data structures / APIs / documents, etc.**, But `everything in UI`.",
starters=[
cl.Starter(
label="Create a 2048 Game",
message="Create a 2048 game",
icon="/public/2048.jpg",
),
cl.Starter(
label="Write a cli Blackjack Game",
message="Write a cli Blackjack Game",
icon="/public/blackjack.jpg",
),
],
)
]
# https://docs.chainlit.io/concepts/message
@cl.on_message
async def startup(message: cl.Message) -> None:
"""On Message in UI, Create a MetaGPT software company
Args:
message (chainlit.Message): message by chainlist
"""
idea = message.content
company = Team(env=ChainlitEnv())
# Similar to software_company.py
company.hire(
[
ProductManager(),
Architect(),
ProjectManager(),
Engineer(n_borg=5, use_code_review=True),
QaEngineer(),
]
)
company.invest(investment=3.0)
company.run_project(idea=idea)
await company.run(n_round=5)
workdir = company.env.context.git_repo.workdir
files = company.env.context.git_repo.get_files(workdir)
files = "\n".join([f"{workdir}/{file}" for file in files if not file.startswith(".git")])
await cl.Message(
content=f"""
Codes can be found here:
{files}
---
Total cost: `{company.cost_manager.total_cost}`
"""
).send()

@@ -0,0 +1,69 @@
import asyncio
import chainlit as cl
from metagpt.environment import Environment
from metagpt.logs import logger, set_llm_stream_logfunc
from metagpt.roles import Role
from metagpt.utils.common import any_to_name
def log_llm_stream_chainlit(msg):
# Stream the message token into Chainlit UI.
cl.run_sync(chainlit_message.stream_token(msg))
set_llm_stream_logfunc(func=log_llm_stream_chainlit)
class ChainlitEnv(Environment):
"""Chainlit Environment for UI Integration"""
async def run(self, k=1):
"""处理一次所有信息的运行
Process all Role runs at once
"""
for _ in range(k):
futures = []
for role in self.roles.values():
# Call role.run with chainlit configuration
future = self._chainlit_role_run(role=role)
futures.append(future)
await asyncio.gather(*futures)
logger.debug(f"is idle: {self.is_idle}")
async def _chainlit_role_run(self, role: Role) -> None:
"""To run the role with chainlit config
Args:
role (Role): metagpt.role.Role
"""
global chainlit_message
chainlit_message = cl.Message(content="")
message = await role.run()
# If message is from role._act() publish to UI.
if message is not None and message.content != "No actions taken yet":
# Convert a message from an ActionNode into markdown JSON format
chainlit_message.content = await self._convert_message_to_markdownjson(message=chainlit_message.content)
# message content from which role and its action...
chainlit_message.content += f"---\n\nAction: `{any_to_name(message.cause_by)}` done by `{role._setting}`."
await chainlit_message.send()
# for clean view in UI
async def _convert_message_to_markdownjson(self, message: str) -> str:
"""If the message is from MetaGPT Action Node output, then
convert it into markdown json for clear view in UI.
Args:
message (str): message by role._act
Returns:
str: message in markdown form
"""
if message.startswith("[CONTENT]"):
return f"```json\n{message}\n```\n"
return message


@@ -243,12 +243,19 @@ class ActionNode:
"""基于pydantic v2的模型动态生成用来检验结果类型正确性"""
def check_fields(cls, values):
required_fields = set(mapping.keys())
all_fields = set(mapping.keys())
required_fields = set()
for k, v in mapping.items():
type_v, field_info = v
if ActionNode.is_optional_type(type_v):
continue
required_fields.add(k)
missing_fields = required_fields - set(values.keys())
if missing_fields:
raise ValueError(f"Missing fields: {missing_fields}")
unrecognized_fields = set(values.keys()) - required_fields
unrecognized_fields = set(values.keys()) - all_fields
if unrecognized_fields:
logger.warning(f"Unrecognized fields: {unrecognized_fields}")
return values
@@ -850,3 +857,12 @@ class ActionNode:
root_node.add_child(child_node)
return root_node
@staticmethod
def is_optional_type(tp) -> bool:
"""Return True if `tp` is `typing.Optional[...]`"""
if typing.get_origin(tp) is Union:
args = typing.get_args(tp)
non_none_types = [arg for arg in args if arg is not type(None)]
return len(non_none_types) == 1 and len(args) == 2
return False
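
A quick standalone check of the optional-type detection added above (a sketch; in the change itself this logic lives on `ActionNode.is_optional_type`):

```python
import typing
from typing import Optional, Union

def is_optional_type(tp) -> bool:
    # True only for Union[T, None] with exactly one non-None member.
    if typing.get_origin(tp) is Union:
        args = typing.get_args(tp)
        non_none = [arg for arg in args if arg is not type(None)]
        return len(non_none) == 1 and len(args) == 2
    return False

assert is_optional_type(Optional[str])        # Union[str, None] -> True
assert not is_optional_type(Union[str, int])  # no None member -> False
assert not is_optional_type(str)              # not a Union -> False
```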

@@ -5,7 +5,7 @@
@Author : alexanderwu
@File : design_api_an.py
"""
from typing import List
from typing import List, Optional
from metagpt.actions.action_node import ActionNode
from metagpt.utils.mermaid import MMC1, MMC2
@@ -45,9 +45,10 @@ REFINED_FILE_LIST = ActionNode(
example=["main.py", "game.py", "new_feature.py"],
)
# optional, because class diagrams reproduce poorly in non-Python projects.
DATA_STRUCTURES_AND_INTERFACES = ActionNode(
key="Data structures and interfaces",
expected_type=str,
expected_type=Optional[str],
instruction="Use mermaid classDiagram code syntax, including classes, method(__init__ etc.) and functions with type"
" annotations, CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. "
"The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.",
@@ -66,7 +67,7 @@ REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode(
PROGRAM_CALL_FLOW = ActionNode(
key="Program call flow",
expected_type=str,
expected_type=Optional[str],
instruction="Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE "
"accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.",
example=MMC2,

@@ -5,14 +5,14 @@
@Author : alexanderwu
@File : project_management_an.py
"""
from typing import List
from typing import List, Optional
from metagpt.actions.action_node import ActionNode
REQUIRED_PACKAGES = ActionNode(
key="Required packages",
expected_type=List[str],
instruction="Provide required packages in requirements.txt format.",
expected_type=Optional[List[str]],
instruction="Provide required third-party packages in requirements.txt format.",
example=["flask==1.1.2", "bcrypt==3.2.0"],
)

@@ -161,6 +161,8 @@ class CollectLinks(Action):
"""
max_results = max(num_results * 2, 6)
results = await self.search_engine.run(query, max_results=max_results, as_string=False)
if len(results) == 0:
return []
_results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
logger.debug(prompt)

@@ -139,7 +139,7 @@ Language: Please use the same language as the user requirement, but the title an
end", "Anything UNCLEAR": "目前项目要求明确没有不清楚的地方"}
## Tasks
{"Required packages": ["无需Python"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式确保游戏界面美观"], ["main.js", "包含Main类负责初始化游戏和绑定事件"], ["game.js", "包含Game类负责游戏逻辑如开始游戏、移动方块等"], ["storage.js", "包含Storage类用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
{"Required packages": ["无需第三方"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式确保游戏界面美观"], ["main.js", "包含Main类负责初始化游戏和绑定事件"], ["game.js", "包含Game类负责游戏逻辑如开始游戏、移动方块等"], ["storage.js", "包含Storage类用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
## Code Files
----- index.html

@@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.file_parser_config import OmniParseConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig
@@ -51,6 +52,9 @@ class Config(CLIParams, YamlModel):
# RAG Embedding
embedding: EmbeddingConfig = EmbeddingConfig()
# omniparse
omniparse: OmniParseConfig = OmniParseConfig()
# Global Proxy. Will be used if llm.proxy is not set
proxy: str = ""
@@ -69,6 +73,7 @@ class Config(CLIParams, YamlModel):
workspace: WorkspaceConfig = WorkspaceConfig()
enable_longterm_memory: bool = False
code_review_k_times: int = 2
agentops_api_key: str = ""
# Will be removed in the future
metagpt_tti_url: str = ""

@@ -0,0 +1,6 @@
from metagpt.utils.yaml_model import YamlModel
class OmniParseConfig(YamlModel):
api_key: str = ""
base_url: str = ""

@@ -33,7 +33,7 @@ class LLMType(Enum):
YI = "yi" # lingyiwanwu
OPENROUTER = "openrouter"
BEDROCK = "bedrock"
ARK = "ark"
ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk
def __missing__(self, key):
return self.OPENAI
@@ -90,6 +90,9 @@ class LLMConfig(YamlModel):
# Cost Control
calc_usage: bool = True
# For Messages Control
use_system_prompt: bool = True
@field_validator("api_key")
@classmethod
def check_llm_key(cls, v):

@@ -1,14 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/1 11:59
@Author : alexanderwu
@File : const.py
@Modified By: mashenquan, 2023-11-1. According to Section 2.2.1 and 2.2.2 of RFC 116, added key definitions for
common properties in the Message.
@Modified By: mashenquan, 2023-11-27. Defines file repository paths according to Section 2.2.3.4 of RFC 135.
@Modified By: mashenquan, 2023/12/5. Add directories for code summarization..
"""
import os
from pathlib import Path

@@ -0,0 +1,99 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from metagpt.document_store.base_store import BaseStore
@dataclass
class MilvusConnection:
"""
Args:
uri: milvus url
token: milvus token
"""
uri: Optional[str] = None
token: Optional[str] = None
class MilvusStore(BaseStore):
def __init__(self, connect: MilvusConnection):
try:
from pymilvus import MilvusClient
except ImportError:
raise Exception("Please install pymilvus first.")
if not connect.uri:
raise Exception("Please check MilvusConnection: uri must be set.")
self.client = MilvusClient(uri=connect.uri, token=connect.token)
def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
from pymilvus import DataType
if self.client.has_collection(collection_name=collection_name):
self.client.drop_collection(collection_name=collection_name)
schema = self.client.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)
index_params = self.client.prepare_index_params()
index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
self.client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params,
enable_dynamic_schema=enable_dynamic_schema,
)
@staticmethod
def build_filter(key, value) -> str:
if isinstance(value, str):
filter_expression = f'{key} == "{value}"'
else:
if isinstance(value, list):
filter_expression = f"{key} in {value}"
else:
filter_expression = f"{key} == {value}"
return filter_expression
def search(
self,
collection_name: str,
query: List[float],
filter: Dict = None,
limit: int = 10,
output_fields: Optional[List[str]] = None,
) -> List[dict]:
filter_expression = " and ".join([self.build_filter(key, value) for key, value in (filter or {}).items()])
res = self.client.search(
collection_name=collection_name,
data=[query],
filter=filter_expression,
limit=limit,
output_fields=output_fields,
)[0]
return res
def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
# Build one row per id, then upsert them in a single call
data = [{"id": id_, "vector": vector[i], "metadata": metadata[i]} for i, id_ in enumerate(_ids)]
self.client.upsert(collection_name=collection_name, data=data)
def delete(self, collection_name: str, _ids: List[str]):
self.client.delete(collection_name=collection_name, ids=_ids)
def write(self, *args, **kwargs):
pass
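
A hedged usage sketch of the store above (assumes a local Milvus Lite file; collection name and data are illustrative):

```python
store = MilvusStore(MilvusConnection(uri="./milvus_local.db"))
store.create_collection("demo", dim=4)

# Upsert one vector; metadata is stored under the dynamic "metadata" field.
store.add("demo", _ids=["a"], vector=[[0.1, 0.2, 0.3, 0.4]], metadata=[{"tag": 1}])

# build_filter turns each key/value pair into a Milvus filter expression, e.g. id == "a".
hits = store.search("demo", query=[0.1, 0.2, 0.3, 0.4], filter={"id": "a"})
print(hits)
```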

@@ -266,7 +266,7 @@ class STRole(Role):
# We will order our percept based on the distance, with the closest ones
# getting priorities.
percept_events_list = []
# First, we put all events that are occuring in the nearby tiles into the
# First, we put all events that are occurring in the nearby tiles into the
# percept_events_list
for tile in nearby_tiles:
tile_details = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile))

@@ -81,7 +81,7 @@ class Memory(BaseModel):
return self.storage[-k:]
def find_news(self, observed: list[Message], k=0) -> list[Message]:
"""find news (previously unseen messages) from the the most recent k memories, from all memories when k=0"""
"""find news (previously unseen messages) from the most recent k memories, from all memories when k=0"""
already_observed = self.get(k)
news: list[Message] = []
for i in observed:

@@ -1,12 +1,33 @@
from openai import AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Provider for volcengine.
See Also: https://console.volcengine.com/ark/region:ark+cn-beijing/model
config2.yaml example:
```yaml
llm:
base_url: "https://ark.cn-beijing.volces.com/api/v3"
api_type: "ark"
endpoint: "ep-2024080514****-d****"
api_key: "d47****b-****-****-****-d6e****0fd77"
pricing_plan: "doubao-lite"
```
"""
from typing import Optional, Union
from pydantic import BaseModel
from volcenginesdkarkruntime import AsyncArk
from volcenginesdkarkruntime._base_client import AsyncHttpxClientWrapper
from volcenginesdkarkruntime._streaming import AsyncStream
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
from metagpt.configs.llm_config import LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS
@register_provider(LLMType.ARK)
@@ -16,11 +37,45 @@ class ArkLLM(OpenAILLM):
https://www.volcengine.com/docs/82379/1263482
"""
aclient: Optional[AsyncArk] = None
def _init_client(self):
"""SDK: https://github.com/openai/openai-python#async-usage"""
self.model = (
self.config.endpoint or self.config.model
) # endpoint name, See more: https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
self.pricing_plan = self.config.pricing_plan or self.model
kwargs = self._make_client_kwargs()
self.aclient = AsyncArk(**kwargs)
def _make_client_kwargs(self) -> dict:
kvs = {
"ak": self.config.access_key,
"sk": self.config.secret_key,
"api_key": self.config.api_key,
"base_url": self.config.base_url,
}
kwargs = {k: v for k, v in kvs.items() if v}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
if next(iter(DOUBAO_TOKEN_COSTS)) not in self.cost_manager.token_costs:
self.cost_manager.token_costs.update(DOUBAO_TOKEN_COSTS)
if model in self.cost_manager.token_costs:
self.pricing_plan = model
if self.pricing_plan in self.cost_manager.token_costs:
super()._update_costs(usage, self.pricing_plan, local_calc_usage)
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
stream=True,
extra_body={"stream_options": {"include_usage": True}} # usage is only returned in the final chunk of a streaming call when this is set
extra_body={"stream_options": {"include_usage": True}}, # usage is only returned in the final chunk of a streaming call when this is set
)
usage = None
collected_messages = []
@@ -30,7 +85,7 @@ class ArkLLM(OpenAILLM):
collected_messages.append(chunk_message)
if chunk.usage:
# Volcengine Ark streaming returns usage in the final chunk, whose choices list is empty
usage = CompletionUsage(**chunk.usage)
usage = chunk.usage
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)

@@ -27,6 +27,7 @@ SUPPORT_STREAM_MODELS = {
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
# currently (2024-4-29) only available at US West (Oregon) AWS Region.

@@ -1,5 +1,7 @@
import asyncio
import json
from typing import Literal
from functools import partial
from typing import List, Literal
import boto3
from botocore.eventstream import EventStream
@@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM):
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
logger.warning("Amazon bedrock doesn't support asynchronous now")
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
logger.warning(f"model {self.config.model} doesn't support streaming output!")
@@ -64,15 +65,21 @@ class BedrockLLM(BaseLLM):
]
logger.info("\n" + "\n".join(summaries))
def invoke_model(self, request_body: str) -> dict:
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
async def invoke_model(self, request_body: str) -> dict:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
return response
@@ -97,7 +104,7 @@ class BedrockLLM(BaseLLM):
async def acompletion(self, messages: list[dict]) -> dict:
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
response_body = self.invoke_model(request_body)
response_body = await self.invoke_model(request_body)
return response_body
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
@@ -111,14 +118,8 @@
return full_text
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
response = self.invoke_model_with_response_stream(request_body)
collected_content = []
for event in response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
stream_response = await self.invoke_model_with_response_stream(request_body)
collected_content = await self._get_stream_response_body(stream_response)
log_llm_stream("\n")
full_text = ("".join(collected_content)).lstrip()
return full_text
@@ -127,6 +128,18 @@
response_body = json.loads(response["body"].read())
return response_body
async def _get_stream_response_body(self, stream_response) -> List[str]:
def collect_content() -> str:
collected_content = []
for event in stream_response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
return collected_content
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, collect_content)
def _get_usage(self, response) -> dict[str, int]:
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
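
The pattern used throughout this Bedrock change, sketched standalone (`blocking_call` is a stand-in for boto3's synchronous `invoke_model`):

```python
import asyncio
from functools import partial

def blocking_call(x: int) -> int:
    # Stands in for a synchronous SDK call such as client.invoke_model(...)
    return x * 2

async def main():
    loop = asyncio.get_running_loop()
    # Off-load the blocking call to the default thread pool so the event loop stays free.
    result = await loop.run_in_executor(None, partial(blocking_call, 21))
    print(result)  # 42

asyncio.run(main())
```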

@@ -48,13 +48,17 @@ def build_api_arequest(
request_timeout,
form,
resources,
base_address,
_,
) = _get_protocol_params(kwargs)
task_id = kwargs.pop("task_id", None)
if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
if not dashscope.base_http_api_url.endswith("/"):
http_url = dashscope.base_http_api_url + "/"
if base_address is None:
base_address = dashscope.base_http_api_url
if not base_address.endswith("/"):
http_url = base_address + "/"
else:
http_url = dashscope.base_http_api_url
http_url = base_address
if is_service:
http_url = http_url + SERVICE_API_PATH + "/"

@@ -81,7 +81,9 @@ class GeneralAPIRequestor(APIRequestor):
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
if stream and (
"text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
self._interpret_response_line(line, result.status, result.headers, stream=True)

@@ -37,7 +37,11 @@ def register_provider(keys):
def create_llm_instance(config: LLMConfig) -> BaseLLM:
"""get the default llm provider"""
return LLM_REGISTRY.get_provider(config.api_type)(config)
llm = LLM_REGISTRY.get_provider(config.api_type)(config)
if llm.use_system_prompt and not config.use_system_prompt:
# for models like the o1 series: the default openai provider use_system_prompt is True, but it must be False for o1-*
llm.use_system_prompt = config.use_system_prompt
return llm
# Registry instance

@@ -51,9 +51,17 @@ class OllamaLLM(BaseLLM):
return json.loads(chunk)
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
headers = (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {
"Authorization": f"Bearer {self.config.api_key}",
}
)
resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
headers=headers,
params=self._const_kwargs(messages),
request_timeout=self.get_timeout(timeout),
)
@@ -66,9 +74,17 @@
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
headers = (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {
"Authorization": f"Bearer {self.config.api_key}",
}
)
stream_resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
headers=headers,
stream=True,
params=self._const_kwargs(messages, stream=True),
request_timeout=self.get_timeout(timeout),

@@ -37,7 +37,6 @@ from metagpt.utils.token_counter import (
count_input_tokens,
count_output_tokens,
get_max_completion_tokens,
get_openrouter_tokens,
)
@@ -92,6 +91,7 @@ class OpenAILLM(BaseLLM):
)
usage = None
collected_messages = []
has_finished = False
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
finish_reason = (
@@ -99,8 +99,13 @@
)
log_llm_stream(chunk_message)
collected_messages.append(chunk_message)
chunk_has_usage = hasattr(chunk, "usage") and chunk.usage
if has_finished:
# for oneapi, a usage chunk arrives after the chunk whose finish_reason is not None
if chunk_has_usage:
usage = CompletionUsage(**chunk.usage)
if finish_reason:
if hasattr(chunk, "usage") and chunk.usage is not None:
if chunk_has_usage:
# Some services have usage as an attribute of the chunk, such as Fireworks
if isinstance(chunk.usage, CompletionUsage):
usage = chunk.usage
@@ -109,9 +114,10 @@
elif hasattr(chunk.choices[0], "usage"):
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
usage = CompletionUsage(**chunk.choices[0].usage)
elif "openrouter.ai" in self.config.base_url:
elif "openrouter.ai" in self.config.base_url and chunk_has_usage:
# due to it get token cost from api
usage = await get_openrouter_tokens(chunk)
usage = chunk.usage
has_finished = True
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)
@@ -132,6 +138,10 @@
"model": self.model,
"timeout": self.get_timeout(timeout),
}
if "o1-" in self.model:
# compatibility with the openai o1 series
kwargs["temperature"] = 1
kwargs.pop("max_tokens")
if extra_kwargs:
kwargs.update(extra_kwargs)
return kwargs
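
Putting the o1-series handling together (a sketch with placeholder values; `use_system_prompt=False` mirrors the new `LLMConfig` flag honored by `create_llm_instance` above):

```python
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.llm_provider_registry import create_llm_instance

config = LLMConfig(
    api_type="openai",
    model="o1-mini",          # any o1-* model triggers temperature=1 and removal of max_tokens
    api_key="sk-xxx",         # placeholder key
    use_system_prompt=False,  # o1-series models do not accept system prompts
)
llm = create_llm_instance(config)
```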

@@ -50,6 +50,9 @@ class QianFanLLM(BaseLLM):
else:
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
if self.config.base_url:
os.environ.setdefault("QIANFAN_BASE_URL", self.config.base_url)
support_system_pairs = [
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
("ERNIE-Bot-8k", "ernie_bot_8k"),
@@ -103,13 +106,13 @@
def get_choice_text(self, resp: JsonBody) -> str:
return resp.get("result", "")
def completion(self, messages: list[dict]) -> JsonBody:
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
def completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False), request_timeout=timeout)
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False), request_timeout=timeout)
self._update_costs(resp.body.get("usage", {}))
return resp.body
@@ -117,7 +120,7 @@ class QianFanLLM(BaseLLM):
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True), request_timeout=timeout)
collected_content = []
usage = {}
async for chunk in resp:

@@ -14,6 +14,7 @@ from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.readers.base import BaseReader
from llama_index.core.response_synthesizers import (
BaseSynthesizer,
get_response_synthesizer,
@@ -28,6 +29,7 @@ from llama_index.core.schema import (
TransformComponent,
)
from metagpt.config2 import config
from metagpt.rag.factories import (
get_index,
get_rag_embedding,
@@ -36,6 +38,7 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
@@ -44,6 +47,9 @@ from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ObjectNode,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.common import import_class
@@ -100,7 +106,10 @@ class SimpleEngine(RetrieverQueryEngine):
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
file_extractor = cls._get_file_extractor()
documents = SimpleDirectoryReader(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
).load_data()
cls._fix_document_metadata(documents)
transformations = transformations or cls._default_transformations()
@@ -301,3 +310,23 @@
@staticmethod
def _default_transformations():
return [SentenceSplitter()]
@staticmethod
def _get_file_extractor() -> dict[str, BaseReader]:
"""
Get the file extractor.
Currently, only PDFs use OmniParse; other document types use the built-in readers from llama_index.
Returns:
dict[file_type, BaseReader]
"""
file_extractor: dict[str, BaseReader] = {}
if config.omniparse.base_url:
pdf_parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
)
file_extractor[".pdf"] = pdf_parser
return file_extractor
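
Sketch of the effect: once `omniparse.base_url` is set in `config2.yaml`, PDFs loaded through `SimpleEngine` are routed through OmniParse transparently (file name is illustrative):

```python
from metagpt.rag.engines import SimpleEngine

# .pdf goes through the OmniParse extractor built above;
# other extensions fall back to llama_index's default readers.
engine = SimpleEngine.from_docs(input_files=["example.pdf"])
answer = engine.query("What is this document about?")
print(answer)
```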

@@ -8,6 +8,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
@@ -17,6 +18,7 @@ from metagpt.rag.schema import (
ElasticsearchIndexConfig,
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
MilvusIndexConfig,
)
@@ -28,6 +30,7 @@ class RAGIndexFactory(ConfigBasedFactory):
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
MilvusIndexConfig: self._create_milvus,
}
super().__init__(creators)
@@ -46,6 +49,11 @@
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
def _create_milvus(self, config: MilvusIndexConfig, **kwargs) -> VectorStoreIndex:
vector_store = MilvusVectorStore(collection_name=config.collection_name, uri=config.uri, token=config.token)
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)

@@ -12,6 +12,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
@@ -20,6 +21,7 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
@@ -27,6 +29,7 @@ from metagpt.rag.schema import (
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
)
@@ -56,6 +59,7 @@ class RetrieverFactory(ConfigBasedFactory):
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
MilvusRetrieverConfig: self._create_milvus_retriever,
}
super().__init__(creators)
@@ -76,6 +80,11 @@
return index.as_retriever()
def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever:
config.index = self._build_milvus_index(config, **kwargs)
return MilvusRetriever(**config.model_dump())
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
config.index = self._build_faiss_index(config, **kwargs)
@@ -128,6 +137,12 @@
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions)
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())

@@ -0,0 +1,3 @@
from metagpt.rag.parsers.omniparse import OmniParse
__all__ = ["OmniParse"]

@@ -0,0 +1,139 @@
import asyncio
from fileinput import FileInput
from pathlib import Path
from typing import List, Optional, Union
from llama_index.core import Document
from llama_index.core.async_utils import run_jobs
from llama_index.core.readers.base import BaseReader
from metagpt.logs import logger
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.omniparse_client import OmniParseClient
class OmniParse(BaseReader):
"""OmniParse"""
def __init__(
self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: OmniParse Base URL for the API.
parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
"""
self.parse_options = parse_options or OmniParseOptions()
self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)
@property
def parse_type(self):
return self.parse_options.parse_type
@property
def result_type(self):
return self.parse_options.result_type
@parse_type.setter
def parse_type(self, parse_type: Union[str, OmniParseType]):
if isinstance(parse_type, str):
parse_type = OmniParseType(parse_type)
self.parse_options.parse_type = parse_type
@result_type.setter
def result_type(self, result_type: Union[str, ParseResultType]):
if isinstance(result_type, str):
result_type = ParseResultType(result_type)
self.parse_options.result_type = result_type
async def _aload_data(
self,
file_path: Union[str, bytes, Path],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Returns:
List[Document]
"""
try:
if self.parse_type == OmniParseType.PDF:
# pdf parse
parsed_result = await self.omniparse_client.parse_pdf(file_path)
else:
# other parse types use omniparse_client.parse_document
# for byte data, an additional filename is required for compatibility
extra_info = extra_info or {}
filename = extra_info.get("filename")
parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename)
# Get the specified structured data based on result_type
content = getattr(parsed_result, self.result_type)
docs = [
Document(
text=content,
metadata=extra_info or {},
)
]
except Exception as e:
logger.error(f"OMNI Parse Error: {e}")
docs = []
return docs
async def aload_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls _aload_data for processing.
Returns:
List[Document]
"""
docs = []
if isinstance(file_path, (str, bytes, Path)):
# Processing single file
docs = await self._aload_data(file_path, extra_info)
elif isinstance(file_path, list):
# Concurrently process multiple files
parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
docs = [doc for docs in doc_ret_list for doc in docs]
return docs
def load_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls aload_data for processing.
Returns:
List[Document]
"""
NestAsyncio.apply_once() # Ensure compatibility with nested async calls
return asyncio.run(self.aload_data(file_path, extra_info))

@@ -0,0 +1,17 @@
"""Milvus retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class MilvusRetriever(VectorIndexRetriever):
"""Milvus retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Add nodes to the underlying index."""
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Persist the index.
Milvus saves automatically, so there is nothing to do here."""

@@ -1,14 +1,14 @@
"""RAG schemas."""
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, Literal, Optional, Union
from typing import Any, ClassVar, List, Literal, Optional, Union
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator, validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
@@ -62,6 +62,36 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
_no_embedding: bool = PrivateAttr(default=True)
class MilvusRetrieverConfig(IndexRetrieverConfig):
"""Config for Milvus-based retrievers."""
uri: str = Field(default="./milvus_local.db", description="The Milvus uri: a local file path (Milvus Lite) or a server address.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
token: Optional[str] = Field(default=None, description="The token for Milvus.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
EmbeddingType.GEMINI: 768,
EmbeddingType.OLLAMA: 4096,
}
@model_validator(mode="after")
def check_dimensions(self):
if self.dimensions == 0:
self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
config.embedding.api_type, 1536
)
if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
logger.warning(
f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
)
return self
class ChromaRetrieverConfig(IndexRetrieverConfig):
"""Config for Chroma-based retrievers."""
@ -169,6 +199,16 @@ class ChromaIndexConfig(VectorIndexConfig):
default=None, description="Optional metadata to associate with the collection"
)
class MilvusIndexConfig(VectorIndexConfig):
"""Config for milvus-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
uri: str = Field(default="./milvus_local.db", description="The uri of the index.")
token: Optional[str] = Field(default=None, description="The token of the index.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""
@ -214,3 +254,51 @@ class ObjectNode(TextNode):
)
return metadata.model_dump()
class OmniParseType(str, Enum):
"""OmniParseType"""
PDF = "PDF"
DOCUMENT = "DOCUMENT"
class ParseResultType(str, Enum):
"""The result type for the parser."""
TXT = "text"
MD = "markdown"
JSON = "json"
class OmniParseOptions(BaseModel):
"""OmniParse Options config"""
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
num_workers: int = Field(
default=5,
gt=0,
lt=10,
description="Number of concurrent requests for multiple files",
)
class OmniParseImage(BaseModel):
image: str = Field(default="", description="image bytes as a string")
image_name: str = Field(default="", description="image name")
image_info: Optional[dict] = Field(default={}, description="image info")
class OmniParsedResult(BaseModel):
markdown: str = Field(default="", description="markdown text")
text: str = Field(default="", description="plain text")
images: Optional[List[OmniParseImage]] = Field(default=[], description="images")
metadata: Optional[dict] = Field(default={}, description="metadata")
@model_validator(mode="before")
def set_markdown(cls, values):
if not values.get("markdown"):
values["markdown"] = values.get("text")
return values
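A one-line sketch of the validator above; markdown falls back to the plain text when not supplied:
ret = OmniParsedResult(text="plain body")
assert ret.markdown == "plain body"  # set_markdown copied text into markdown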

View file

@ -6,6 +6,7 @@
@File : architect.py
"""
from metagpt.actions import WritePRD
from metagpt.actions.design_api import WriteDesign
from metagpt.roles.role import Role

View file

@ -80,19 +80,17 @@ class InvoiceOCRAssistant(Role):
raise Exception("Invoice file not uploaded")
resp = await todo.run(file_path)
actions = list(self.actions)
if len(resp) == 1:
# A single file supports follow-up questions based on the OCR results
self.set_actions([GenerateTable, ReplyQuestion])
actions.extend([GenerateTable, ReplyQuestion])
self.orc_data = resp[0]
else:
self.set_actions([GenerateTable])
self.set_todo(None)
actions.append(GenerateTable)
self.set_actions(actions)
self.rc.max_react_loop = len(self.actions)
content = INVOICE_OCR_SUCCESS
resp = OCRResults(ocr_result=json.dumps(resp))
msg = Message(content=content, instruct_content=resp)
self.rc.memory.add(msg)
return await super().react()
elif isinstance(todo, GenerateTable):
ocr_results: OCRResults = msg.instruct_content
resp = await todo.run(json.loads(ocr_results.ocr_result), self.filename)

View file

@ -7,6 +7,7 @@
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
"""
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.roles.role import Role, RoleReactMode

View file

@ -6,6 +6,7 @@
@File : project_manager.py
"""
from metagpt.actions import WriteTasks
from metagpt.actions.design_api import WriteDesign
from metagpt.roles.role import Role

View file

@ -15,6 +15,7 @@
of SummarizeCode.
"""
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import MESSAGE_ROUTE_TO_NONE

View file

@ -58,7 +58,9 @@ class Researcher(Role):
)
elif isinstance(todo, WebBrowseAndSummarize):
links = instruct_content.links
todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items())
todos = (
todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items() if url
)
if self.enable_concurrency:
summaries = await asyncio.gather(*todos)
else:

View file

@ -170,7 +170,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
self._check_actions()
self.llm.system_prompt = self._get_prefix()
self.llm.cost_manager = self.context.cost_manager
self._watch(kwargs.pop("watch", [UserRequirement]))
if not self.rc.watch:
self._watch(kwargs.pop("watch", [UserRequirement]))
if self.latest_observed_msg:
self.recovered = True
@ -421,8 +422,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
"""Prepare new messages for processing from the message buffer and other sources."""
# Read unprocessed messages from the msg buffer.
news = []
if self.recovered:
news = [self.latest_observed_msg] if self.latest_observed_msg else []
if self.recovered and self.latest_observed_msg:
news = self.rc.memory.find_news(observed=[self.latest_observed_msg], k=10)
if not news:
news = self.rc.msg_buffer.pop_all()
# Store the read messages in your own memory to prevent duplicate processing.

View file

@ -1,87 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/13 12:23
@Author : femto Zheng
@File : sk_agent.py
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message filtering.
"""
from typing import Any, Callable, Union
from pydantic import Field
from semantic_kernel import Kernel
from semantic_kernel.planning import SequentialPlanner
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from semantic_kernel.planning.basic_planner import BasicPlanner, Plan
from metagpt.actions import UserRequirement
from metagpt.actions.execute_task import ExecuteTask
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.make_sk_kernel import make_sk_kernel
class SkAgent(Role):
"""
Represents an SkAgent implemented using semantic kernel
Attributes:
name (str): Name of the SkAgent.
profile (str): Role profile, default is 'sk_agent'.
goal (str): Goal of the SkAgent.
constraints (str): Constraints for the SkAgent.
"""
name: str = "Sunshine"
profile: str = "sk_agent"
goal: str = "Execute task based on passed in task description"
constraints: str = ""
plan: Plan = Field(default=None, exclude=True)
planner_cls: Any = None
planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
kernel: Kernel = Field(default_factory=Kernel)
import_semantic_skill_from_directory: Callable = Field(default=None, exclude=True)
import_skill: Callable = Field(default=None, exclude=True)
def __init__(self, **data: Any) -> None:
"""Initializes the Engineer role with given attributes."""
super().__init__(**data)
self.set_actions([ExecuteTask()])
self._watch([UserRequirement])
self.kernel = make_sk_kernel()
# how funny the interface is inconsistent
if self.planner_cls == BasicPlanner or self.planner_cls is None:
self.planner = BasicPlanner()
elif self.planner_cls in [SequentialPlanner, ActionPlanner]:
self.planner = self.planner_cls(self.kernel)
else:
raise Exception(f"Unsupported planner of type {self.planner_cls}")
self.import_semantic_skill_from_directory = self.kernel.import_semantic_skill_from_directory
self.import_skill = self.kernel.import_skill
async def _think(self) -> None:
self._set_state(0)
# how funny the interface is inconsistent
if isinstance(self.planner, BasicPlanner):
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content, self.kernel)
logger.info(self.plan.generated_plan)
elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content)
async def _act(self) -> Message:
# how funny the interface is inconsistent
result = None
if isinstance(self.planner, BasicPlanner):
result = await self.planner.execute_plan_async(self.plan, self.kernel)
elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
result = (await self.plan.invoke_async()).result
logger.info(result)
msg = Message(content=result, role=self.profile, cause_by=self.rc.todo)
self.rc.memory.add(msg)
return msg

View file

@ -4,6 +4,7 @@
import asyncio
from pathlib import Path
import agentops
import typer
from metagpt.const import CONFIG_ROOT
@ -38,6 +39,9 @@ def generate_repo(
)
from metagpt.team import Team
if config.agentops_api_key != "":
agentops.init(config.agentops_api_key, tags=["software_company"])
config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
ctx = Context(config=config)
@ -68,6 +72,9 @@ def generate_repo(
company.run_project(idea)
asyncio.run(company.run(n_round=n_round))
if config.agentops_api_key != "":
agentops.end_session("Success")
return ctx.repo

View file

@ -126,6 +126,9 @@ class Team(BaseModel):
self.run_project(idea=idea, send_to=send_to)
while n_round > 0:
if self.env.is_idle:
logger.debug("All roles are idle.")
break
n_round -= 1
self._check_balance()
await self.env.run()

View file

@ -6,37 +6,15 @@
@File : search_engine.py
"""
import importlib
from typing import Callable, Coroutine, Literal, Optional, Union, overload
from typing import Annotated, Callable, Coroutine, Literal, Optional, Union, overload
from pydantic import BaseModel, ConfigDict, model_validator
from semantic_kernel.skill_definition import sk_function
from pydantic import BaseModel, ConfigDict, Field, model_validator
from metagpt.configs.search_config import SearchConfig
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
class SkSearchEngine:
"""A search engine class for executing searches.
Attributes:
search_engine: The search engine instance used for executing searches.
"""
def __init__(self, **kwargs):
self.search_engine = SearchEngine(**kwargs)
@sk_function(
description="searches results from Google. Useful when you need to find short "
"and succinct answers about a specific topic. Input should be a search query.",
name="searchAsync",
input_description="search",
)
async def run(self, query: str) -> str:
result = await self.search_engine.run(query)
return result
class SearchEngine(BaseModel):
"""A model for configuring and executing searches with different search engines.
@ -51,7 +29,9 @@ class SearchEngine(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None
run_func: Annotated[
Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]], Field(exclude=True)
] = None
api_key: Optional[str] = None
proxy: Optional[str] = None
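The Annotated[..., Field(exclude=True)] pattern keeps the non-serializable callable out of model dumps. A minimal self-contained demonstration of the pattern (not MetaGPT code):
from typing import Annotated, Callable, Optional
from pydantic import BaseModel, Field

class Demo(BaseModel):
    run_func: Annotated[Optional[Callable], Field(exclude=True)] = None

assert "run_func" not in Demo(run_func=print).model_dump()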

View file

@ -87,8 +87,11 @@ class SerpAPIWrapper(BaseModel):
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
if "error" in res.keys():
raise ValueError(f"Got error from SerpAPI: {res['error']}")
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
if res["error"] == "Google hasn't returned any results for this query.":
toret = "No good search result found"
else:
raise ValueError(f"Got error from SerpAPI: {res['error']}")
elif "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]

View file

@ -3,9 +3,9 @@
from __future__ import annotations
import importlib
from typing import Any, Callable, Coroutine, Optional, Union, overload
from typing import Annotated, Any, Callable, Coroutine, Optional, Union, overload
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.tools import WebBrowserEngineType
@ -29,7 +29,10 @@ class WebBrowserEngine(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None
run_func: Annotated[
Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]],
Field(exclude=True),
] = None
proxy: Optional[str] = None
@model_validator(mode="after")

View file

@ -23,8 +23,8 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
"""Graph repository based on DiGraph."""
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
def __init__(self, name: str | Path, **kwargs):
super().__init__(name=str(name), **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
@ -112,8 +112,28 @@ class DiGraphRepository(GraphRepository):
async def load(self, pathname: str | Path):
"""Load a directed graph repository from a JSON file."""
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self.load_json(data)
def load_json(self, val: str):
"""
Loads a JSON-encoded string representing a graph structure and updates
the internal repository (_repo) with the parsed graph.
Args:
val (str): A JSON-encoded string representing a graph structure.
Returns:
self: Returns the instance of the class with the updated _repo attribute.
Raises:
ValueError: If val is not a valid JSON string or cannot be parsed into a valid graph structure.
"""
if not val:
return self
m = json.loads(val)
self._repo = networkx.node_link_graph(m)
return self
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
@ -126,9 +146,7 @@ class DiGraphRepository(GraphRepository):
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
"""
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
graph = DiGraphRepository(name=pathname.stem, root=pathname.parent)
if pathname.exists():
await graph.load(pathname=pathname)
return graph
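A sketch of load_json with hand-built node-link JSON, the format networkx.node_link_graph expects; it assumes the repository can be constructed with only a name:
import json

data = json.dumps({"directed": True, "multigraph": False, "graph": {}, "nodes": [{"id": "A"}], "links": []})
repo = DiGraphRepository(name="demo").load_json(data)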

View file

@ -1,32 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/13 12:29
@Author : femto Zheng
@File : make_sk_kernel.py
"""
import semantic_kernel as sk
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import (
AzureChatCompletion,
)
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import (
OpenAIChatCompletion,
)
from metagpt.config2 import config
def make_sk_kernel():
kernel = sk.Kernel()
if llm := config.get_azure_llm():
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
)
elif llm := config.get_openai_llm():
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(llm.model, llm.api_key),
)
return kernel

View file

@ -0,0 +1,239 @@
import mimetypes
import os
from pathlib import Path
from typing import Union
import httpx
from metagpt.rag.schema import OmniParsedResult
from metagpt.utils.common import aread_bin
class OmniParseClient:
"""
OmniParse Server Client
This client interacts with the OmniParse server to parse different types of media and documents.
OmniParse API Documentation: https://docs.cognitivelab.in/api
Attributes:
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
"""
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: Base URL for the API.
max_timeout: Maximum request timeout in seconds.
"""
self.api_key = api_key
self.base_url = base_url
self.max_timeout = max_timeout
self.parse_media_endpoint = "/parse_media"
self.parse_website_endpoint = "/parse_website"
self.parse_document_endpoint = "/parse_document"
async def _request_parse(
self,
endpoint: str,
method: str = "POST",
files: dict = None,
params: dict = None,
data: dict = None,
json: dict = None,
headers: dict = None,
**kwargs,
) -> dict:
"""
Request OmniParse API to parse a document.
Args:
endpoint (str): API endpoint.
method (str, optional): HTTP method to use. Default is "POST".
files (dict, optional): Files to include in the request.
params (dict, optional): Query string parameters.
data (dict, optional): Form data to include in the request body.
json (dict, optional): JSON data to include in the request body.
headers (dict, optional): HTTP headers to include in the request.
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
Returns:
dict: JSON response data.
"""
url = f"{self.base_url}{endpoint}"
method = method.upper()
headers = headers or {}
_headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
headers.update(**_headers)
async with httpx.AsyncClient() as client:
response = await client.request(
url=url,
method=method,
files=files,
params=params,
json=json,
data=data,
headers=headers,
timeout=self.max_timeout,
**kwargs,
)
response.raise_for_status()
return response.json()
async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the document parsing.
"""
self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
"""
Parse pdf document.
Args:
file_input: File path or file byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the pdf parsing.
"""
self.verify_file_ext(file_input, {".pdf"})
# The PDF endpoint accepts only the file's raw byte data.
file_info = await self.get_file_info(file_input, only_bytes=True)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse audio-type data (supports ".mp3", ".wav", ".aac").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
@staticmethod
def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
"""
Verify the file extension.
Args:
file_input: File path or file byte data.
allowed_file_extensions: Set of allowed file extensions.
bytes_filename: Filename to use for verification when `file_input` is byte data.
Raises:
ValueError: If the file extension is not allowed.
"""
verify_file_path = None
if isinstance(file_input, (str, Path)):
verify_file_path = str(file_input)
elif isinstance(file_input, bytes) and bytes_filename:
verify_file_path = bytes_filename
if not verify_file_path:
# Do not verify if only byte data is provided
return
file_ext = os.path.splitext(verify_file_path)[1].lower()
if file_ext not in allowed_file_extensions:
raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
@staticmethod
async def get_file_info(
file_input: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes: bool = False,
) -> Union[bytes, tuple]:
"""
Get file information.
Args:
file_input: File path or file byte data.
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
Raises:
ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.
Notes:
Since `parse_document`, `parse_video`, and `parse_audio` support multiple file types,
the MIME type of the file must be specified when uploading.
Returns: Union[bytes, tuple]
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
"""
if isinstance(file_input, (str, Path)):
filename = os.path.basename(str(file_input))
file_bytes = await aread_bin(file_input)
if only_bytes:
return file_bytes
mime_type = mimetypes.guess_type(file_input)[0]
return filename, file_bytes, mime_type
elif isinstance(file_input, bytes):
if only_bytes:
return file_input
if not bytes_filename:
raise ValueError("bytes_filename must be set when passing bytes")
mime_type = mimetypes.guess_type(bytes_filename)[0]
return bytes_filename, file_input, mime_type
else:
raise ValueError("file_input must be a string (file path) or bytes.")

View file

@ -10,7 +10,7 @@ from __future__ import annotations
import traceback
from datetime import timedelta
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
import redis.asyncio as aioredis
from metagpt.configs.redis_config import RedisConfig
from metagpt.logs import logger

View file

@ -11,8 +11,10 @@ from multiprocessing import Pipe
class StreamPipe:
parent_conn, child_conn = Pipe()
finish: bool = False
def __init__(self, name=None):
self.name = name
self.parent_conn, self.child_conn = Pipe()
self.finish: bool = False
format_data = {
"id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",

View file

@ -41,11 +41,19 @@ TOKEN_COSTS = {
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4o": {"prompt": 0.005, "completion": 0.015},
"gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
"gpt-4o-mini-2024-07-18": {"prompt": 0.00015, "completion": 0.0006},
"gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015},
"gpt-4o-2024-08-06": {"prompt": 0.0025, "completion": 0.01},
"o1-preview": {"prompt": 0.015, "completion": 0.06},
"o1-preview-2024-09-12": {"prompt": 0.015, "completion": 0.06},
"o1-mini": {"prompt": 0.003, "completion": 0.012},
"o1-mini-2024-09-12": {"prompt": 0.003, "completion": 0.012},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
"gemini-1.5-flash": {"prompt": 0.000075, "completion": 0.0003},
"gemini-1.5-pro": {"prompt": 0.0035, "completion": 0.0105},
"gemini-1.0-pro": {"prompt": 0.0005, "completion": 0.0015},
"moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens
"moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
"moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
@ -69,15 +77,20 @@ TOKEN_COSTS = {
"llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079},
"openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015},
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
"openai/o1-preview": {"prompt": 0.015, "completion": 0.06},
"openai/o1-mini": {"prompt": 0.003, "completion": 0.012},
"anthropic/claude-3-opus": {"prompt": 0.015, "completion": 0.075},
"anthropic/claude-3.5-sonnet": {"prompt": 0.003, "completion": 0.015},
"google/gemini-pro-1.5": {"prompt": 0.0025, "completion": 0.0075}, # for openrouter, end
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
# For ark model https://www.volcengine.com/docs/82379/1099320
"doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
"doubao-lite-4k-240515": {"prompt": 0.000043, "completion": 0.000086},
"doubao-lite-32k-240515": {"prompt": 0.000043, "completion": 0.000086},
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00014},
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00029},
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00029},
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0013},
"llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
"llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
}
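Prices above are USD per 1k tokens, so 1,000 prompt plus 500 completion tokens on gpt-4o-2024-08-06 cost 1.0 * 0.0025 + 0.5 * 0.01 = $0.0075. A hypothetical helper (not part of this module) that reads the table:
def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    """USD cost of one call, from the per-1k-token TOKEN_COSTS table above."""
    price = TOKEN_COSTS[model]
    return prompt_tokens / 1000 * price["prompt"] + completion_tokens / 1000 * price["completion"]

# estimate_cost("gpt-4o-2024-08-06", 1000, 500) -> 0.0075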
@ -138,8 +151,17 @@ QIANFAN_ENDPOINT_TOKEN_COSTS = {
"""
DashScope token prices: https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
Different models have different detail pages. Note that some models are free for a limited time.
New models published by Alibaba may be released on Model Studio first rather than on DashScope.
Token prices on Model Studio are listed at https://help.aliyun.com/zh/model-studio/getting-started/models#ced16cb6cdfsy
"""
DASHSCOPE_TOKEN_COSTS = {
"qwen2.5-72b-instruct": {"prompt": 0.00057, "completion": 0.0017}, # per 1k tokens
"qwen2.5-32b-instruct": {"prompt": 0.0005, "completion": 0.001},
"qwen2.5-14b-instruct": {"prompt": 0.00029, "completion": 0.00086},
"qwen2.5-7b-instruct": {"prompt": 0.00014, "completion": 0.00029},
"qwen2.5-3b-instruct": {"prompt": 0.0, "completion": 0.0},
"qwen2.5-1.5b-instruct": {"prompt": 0.0, "completion": 0.0},
"qwen2.5-0.5b-instruct": {"prompt": 0.0, "completion": 0.0},
"qwen2-72b-instruct": {"prompt": 0.000714, "completion": 0.001428},
"qwen2-57b-a14b-instruct": {"prompt": 0.0005, "completion": 0.001},
"qwen2-7b-instruct": {"prompt": 0.000143, "completion": 0.000286},
@ -190,16 +212,24 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
# https://console.volcengine.com/ark/region:ark+cn-beijing/model
DOUBAO_TOKEN_COSTS = {
"doubao-lite": {"prompt": 0.0003, "completion": 0.0006},
"doubao-lite-128k": {"prompt": 0.0008, "completion": 0.0010},
"doubao-pro": {"prompt": 0.0008, "completion": 0.0020},
"doubao-pro-128k": {"prompt": 0.0050, "completion": 0.0090},
"doubao-lite": {"prompt": 0.000043, "completion": 0.000086},
"doubao-lite-128k": {"prompt": 0.00011, "completion": 0.00014},
"doubao-pro": {"prompt": 0.00011, "completion": 0.00029},
"doubao-pro-128k": {"prompt": 0.00071, "completion": 0.0013},
"doubao-pro-256k": {"prompt": 0.00071, "completion": 0.0013},
}
# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
TOKEN_MAX = {
"gpt-4o-2024-05-13": 128000,
"o1-preview": 128000,
"o1-preview-2024-09-12": 128000,
"o1-mini": 128000,
"o1-mini-2024-09-12": 128000,
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4o-2024-08-06": 128000,
"gpt-4o-mini-2024-07-18": 128000,
"gpt-4o-mini": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
@ -222,7 +252,9 @@ TOKEN_MAX = {
"text-embedding-ada-002": 8192,
"glm-3-turbo": 128000,
"glm-4": 128000,
"gemini-pro": 32768,
"gemini-1.5-flash": 1000000,
"gemini-1.5-pro": 2000000,
"gemini-1.0-pro": 32000,
"moonshot-v1-8k": 8192,
"moonshot-v1-32k": 32768,
"moonshot-v1-128k": 128000,
@ -246,6 +278,11 @@ TOKEN_MAX = {
"llama3-70b-8192": 8192,
"openai/gpt-3.5-turbo-0125": 16385,
"openai/gpt-4-turbo-preview": 128000,
"openai/o1-preview": 128000,
"openai/o1-mini": 128000,
"anthropic/claude-3-opus": 200000,
"anthropic/claude-3.5-sonnet": 200000,
"google/gemini-pro-1.5": 4000000,
"deepseek-chat": 32768,
"deepseek-coder": 16385,
"doubao-lite-4k-240515": 4000,
@ -255,6 +292,13 @@ TOKEN_MAX = {
"doubao-pro-32k-240515": 32000,
"doubao-pro-128k-240515": 128000,
# Qwen https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-72b-api-detailes?spm=a2c4g.11186623.0.i20
"qwen2.5-72b-instruct": 131072,
"qwen2.5-32b-instruct": 131072,
"qwen2.5-14b-instruct": 131072,
"qwen2.5-7b-instruct": 131072,
"qwen2.5-3b-instruct": 32768,
"qwen2.5-1.5b-instruct": 32768,
"qwen2.5-0.5b-instruct": 32768,
"qwen2-57b-a14b-instruct": 32768,
"qwen2-72b-instruct": 131072,
"qwen2-7b-instruct": 32768,
@ -354,13 +398,19 @@ def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4-turbo",
"gpt-4-vision-preview",
"gpt-4-1106-vision-preview",
"gpt-4o",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"gpt-4o-mini",
"claude-3-5-sonnet-20240620"
"gpt-4o-mini-2024-07-18",
"o1-preview",
"o1-preview-2024-09-12",
"o1-mini",
"o1-mini-2024-09-12",
}:
tokens_per_message = 3  # every reply is primed with <|start|>assistant<|message|>
tokens_per_name = 1
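A sketch of the counter in use (requires tiktoken; the import path is assumed):
from metagpt.utils.token_counter import count_input_tokens

msgs = [{"role": "user", "content": "hello"}]
n = count_input_tokens(msgs, model="gpt-4o-mini")  # small positive int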

View file

@ -19,7 +19,7 @@ beautifulsoup4==4.12.3
pandas==2.1.1
pydantic>=2.5.3
#pygame==2.1.3
#pymilvus==2.2.8
# pymilvus==2.4.6
# pytest==7.2.2 # test extras require
python_docx==0.8.11
PyYAML==6.0.1
@ -43,7 +43,9 @@ wrapt==1.15.0
#aiohttp_jinja2
# azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py
#aioboto3~=12.4.0 # Used by metagpt/utils/s3.py
aioredis~=2.0.1 # Used by metagpt/utils/redis.py
redis~=5.0.0 # Used by metagpt/utils/redis.py
curl-cffi~=0.7.0
httplib2~=0.22.0
websocket-client~=1.8.0
aiofiles==23.2.1
gitpython==3.1.40
@ -66,9 +68,14 @@ anytree
ipywidgets==8.1.1
Pillow
imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan~=0.3.16
qianfan~=0.4.4
dashscope~=1.19.3
rank-bm25==0.2.2 # for tool recommendation
jieba==0.42.1 # for tool recommendation
volcengine-python-sdk[ark]~=1.0.94
# llama-index-vector-stores-elasticsearch~=0.2.5 # Used by `metagpt/memory/longterm_memory.py`
# llama-index-vector-stores-chroma~=0.1.10 # Used by `metagpt/memory/longterm_memory.py`
gymnasium==0.29.1
boto3~=1.34.69
spark_ai_python~=0.3.30
agentops

View file

@ -43,32 +43,9 @@ extras_require = {
"llama-index-postprocessor-cohere-rerank==0.1.4",
"llama-index-postprocessor-colbert-rerank==0.1.1",
"llama-index-postprocessor-flag-embedding-reranker==0.1.2",
# "llama-index-vector-stores-milvus==0.1.23",
"docx2txt==0.8",
],
"android_assistant": [
"pyshine==0.0.9",
"opencv-python==4.6.0.66",
"protobuf<3.20,>=3.9.2",
"modelscope",
"tensorflow==2.9.1; os_name == 'linux'",
"tensorflow==2.9.1; os_name == 'win32'",
"tensorflow-macos==2.9; os_name == 'darwin'",
"keras==2.9.0",
"torch",
"torchvision",
"transformers",
"opencv-python",
"matplotlib",
"pycocotools",
"SentencePiece",
"tf_slim",
"tf_keras",
"pyclipper",
"shapely",
"groundingdino-py",
"datasets==2.18.0",
"clip-openai",
],
}
extras_require["test"] = [
@ -85,6 +62,9 @@ extras_require["test"] = [
"aioboto3~=12.4.0",
"gradio==3.0.0",
"grpcio-status==1.48.2",
"grpcio-tools==1.48.2",
"google-api-core==2.17.1",
"protobuf==3.19.6",
"pylint==3.0.3",
"pybrowsers",
]
@ -93,7 +73,30 @@ extras_require["pyppeteer"] = [
"pyppeteer>=1.0.2"
] # pyppeteer is unmaintained and there are conflicts with dependencies
extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pre-commit~=3.6.0"],)
extras_require["android_assistant"] = [
"pyshine==0.0.9",
"opencv-python==4.6.0.66",
"protobuf<3.20,>=3.9.2",
"modelscope",
"tensorflow==2.9.1; os_name == 'linux'",
"tensorflow==2.9.1; os_name == 'win32'",
"tensorflow-macos==2.9; os_name == 'darwin'",
"keras==2.9.0",
"torch",
"torchvision",
"transformers",
"opencv-python",
"matplotlib",
"pycocotools",
"SentencePiece",
"tf_slim",
"tf_keras",
"pyclipper",
"shapely",
"groundingdino-py",
"datasets==2.18.0",
"clip-openai",
]
setup(
name="metagpt",
@ -107,7 +110,7 @@ setup(
license="MIT",
keywords="metagpt multi-agent multi-role programming gpt llm metaprogramming",
packages=find_packages(exclude=["contrib", "docs", "examples", "tests*"]),
python_requires=">=3.9",
python_requires=">=3.9, <3.12",
install_requires=requirements,
extras_require=extras_require,
cmdclass={

View file

@ -0,0 +1,27 @@
llm:
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_gpt-3.5-turbo_BASE_URL"
api_key: "YOUR_gpt-3.5-turbo_API_KEY"
model: "gpt-3.5-turbo" # or gpt-3.5-turbo
# proxy: "YOUR_gpt-3.5-turbo_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
models:
"YOUR_MODEL_NAME_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_MODEL_1_BASE_URL"
api_key: "YOUR_MODEL_1_API_KEY"
# proxy: "YOUR_MODEL_1_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
"YOUR_MODEL_NAME_2": # model: "gpt-4-turbo" # or gpt-3.5-turbo
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_MODEL_2_BASE_URL"
api_key: "YOUR_MODEL_2_API_KEY"
proxy: "YOUR_MODEL_2_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's

View file

@ -6,7 +6,7 @@
@File : test_action_node.py
"""
from pathlib import Path
from typing import List, Tuple
from typing import List, Optional, Tuple
import pytest
from pydantic import BaseModel, Field, ValidationError
@ -302,6 +302,19 @@ def test_action_node_from_pydantic_and_print_everything():
assert "tasks" in code, "tasks should be in code"
def test_optional():
mapping = {
"Logic Analysis": (Optional[List[Tuple[str, str]]], Field(default=None)),
"Task list": (Optional[List[str]], None),
"Plan": (Optional[str], ""),
"Anything UNCLEAR": (Optional[str], None),
}
m = {"Anything UNCLEAR": "a"}
t = ActionNode.create_model_class("test_class_1", mapping)
t1 = t(**m)
assert t1
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()
pytest.main([__file__, "-s"])

View file

View file

@ -0,0 +1,34 @@
import pytest
from metagpt.actions.talk_action import TalkAction
from metagpt.configs.models_config import ModelsConfig
from metagpt.const import METAGPT_ROOT, TEST_DATA_PATH
from metagpt.utils.common import aread, awrite
@pytest.mark.asyncio
async def test_models_configs(context):
default_model = ModelsConfig.default()
assert default_model is not None
models = ModelsConfig.from_yaml_file(TEST_DATA_PATH / "config/config2.yaml")
assert models
default_models = ModelsConfig.default()
backup = ""
if not default_models.models:
backup = await aread(filename=METAGPT_ROOT / "config/config2.yaml")
test_data = await aread(filename=TEST_DATA_PATH / "config/config2.yaml")
await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=test_data)
try:
action = TalkAction(context=context, i_context="who are you?", llm_name_or_type="YOUR_MODEL_NAME_1")
assert action.private_llm.config.model == "YOUR_MODEL_NAME_1"
assert context.config.llm.model != "YOUR_MODEL_NAME_1"
finally:
if backup:
await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=backup)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,48 @@
import random
import pytest
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
seed_value = 42
random.seed(seed_value)
vectors = [[random.random() for _ in range(8)] for _ in range(10)]
ids = [f"doc_{i}" for i in range(10)]
metadata = [{"color": "red", "rand_number": i % 10} for i in range(10)]
def assert_almost_equal(actual, expected):
delta = 1e-10
if isinstance(expected, list):
assert len(actual) == len(expected)
for ac, exp in zip(actual, expected):
assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
else:
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
@pytest.mark.skip() # Skip because the pymilvus dependency is not installed by default
def test_milvus_store():
milvus_connection = MilvusConnection(uri="./milvus_local.db")
milvus_store = MilvusStore(milvus_connection)
collection_name = "TestCollection"
milvus_store.create_collection(collection_name, dim=8)
milvus_store.add(collection_name, ids, vectors, metadata)
search_results = milvus_store.search(collection_name, query=[1.0] * 8)
assert len(search_results) > 0
first_result = search_results[0]
assert first_result["id"] == "doc_0"
search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1})
assert len(search_results_with_filter) > 0
assert search_results_with_filter[0]["id"] == "doc_1"
milvus_store.delete(collection_name, _ids=["doc_0"])
deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
assert deleted_results[0]["id"] != "doc_0"
milvus_store.client.drop_collection(collection_name)

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : __init__.py
"""

View file

@ -1,32 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : test_basic_planner.py
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import pytest
from semantic_kernel.core_skills import FileIOSkill, MathSkill, TextSkill, TimeSkill
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from metagpt.actions import UserRequirement
from metagpt.roles.sk_agent import SkAgent
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_action_planner():
role = SkAgent(planner_cls=ActionPlanner)
# let's give the agent 4 skills
role.import_skill(MathSkill(), "math")
role.import_skill(FileIOSkill(), "fileIO")
role.import_skill(TimeSkill(), "time")
role.import_skill(TextSkill(), "text")
task = "What is the sum of 110 and 990?"
role.put_message(Message(content=task, cause_by=UserRequirement))
await role._observe()
await role._think() # it will choose mathskill.Add
assert "1100" == (await role._act()).content

View file

@ -1,37 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : test_basic_planner.py
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import pytest
from semantic_kernel.core_skills import TextSkill
from metagpt.actions import UserRequirement
from metagpt.const import SKILL_DIRECTORY
from metagpt.roles.sk_agent import SkAgent
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_basic_planner():
task = """
Tomorrow is Valentine's day. I need to come up with a few date ideas. She speaks French so write it in French.
Convert the text to uppercase"""
role = SkAgent()
# let's give the agent some skills
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "SummarizeSkill")
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill")
role.import_skill(TextSkill(), "TextSkill")
# using BasicPlanner
role.put_message(Message(content=task, cause_by=UserRequirement))
await role._observe()
await role._think()
# assuming sk_agent will think he needs WriterSkill.Brainstorm and WriterSkill.Translate
assert "WriterSkill.Brainstorm" in role.plan.generated_plan.result
assert "WriterSkill.Translate" in role.plan.generated_plan.result
# assert "SALUT" in (await role._act()).content #content will be some French

View file

@ -64,7 +64,7 @@ def is_subset(subset, superset) -> bool:
superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}}
is_subset(subset, superset)
```
>>>False
"""
for key, value in subset.items():
if key not in superset:

View file

@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
@ -37,6 +38,10 @@ class TestSimpleEngine:
def mock_get_response_synthesizer(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
@pytest.fixture
def mock_get_file_extractor(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
def test_from_docs(
self,
mocker,
@ -44,6 +49,7 @@ class TestSimpleEngine:
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
mock_get_file_extractor,
):
# Mock
mock_simple_directory_reader.return_value.load_data.return_value = [
@ -53,6 +59,8 @@ class TestSimpleEngine:
mock_get_retriever.return_value = mocker.MagicMock()
mock_get_rankers.return_value = [mocker.MagicMock()]
mock_get_response_synthesizer.return_value = mocker.MagicMock()
file_extractor = mocker.MagicMock()
mock_get_file_extractor.return_value = file_extractor
# Setup
input_dir = "test_dir"
@ -75,7 +83,9 @@ class TestSimpleEngine:
)
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_simple_directory_reader.assert_called_once_with(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
@ -298,3 +308,17 @@ class TestSimpleEngine:
# Assert
assert "obj" in node.node.metadata
assert node.node.metadata["obj"] == expected_obj
def test_get_file_extractor(self, mocker):
# mock no omniparse config
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
mock_omniparse_config.base_url = ""
file_extractor = SimpleEngine._get_file_extractor()
assert file_extractor == {}
# mock have omniparse config
mock_omniparse_config.base_url = "http://localhost:8000"
file_extractor = SimpleEngine._get_file_extractor()
assert ".pdf" in file_extractor
assert isinstance(file_extractor[".pdf"], OmniParse)

View file

@ -7,7 +7,7 @@ from metagpt.rag.schema import (
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
FAISSIndexConfig,
FAISSIndexConfig,
MilvusIndexConfig,
)
@ -20,6 +20,10 @@ class TestRAGIndexFactory:
def faiss_config(self):
return FAISSIndexConfig(persist_path="")
@pytest.fixture
def milvus_config(self):
return MilvusIndexConfig(uri="", collection_name="")
@pytest.fixture
def chroma_config(self):
return ChromaIndexConfig(persist_path="", collection_name="")
@ -65,6 +69,16 @@ class TestRAGIndexFactory:
):
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding):
# Mock
mock_milvus_store = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore")
# Exec
self.index_factory.get_index(milvus_config, embed_model=mock_embedding)
# Assert
mock_milvus_store.assert_called_once()
def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
# Mock
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")

View file

@ -5,6 +5,7 @@ from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
@ -12,12 +13,14 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
)
@ -41,6 +44,10 @@ class TestRetrieverFactory:
def mock_chroma_vector_store(self, mocker):
return mocker.MagicMock(spec=ChromaVectorStore)
@pytest.fixture
def mock_milvus_vector_store(self, mocker):
return mocker.MagicMock(spec=MilvusVectorStore)
@pytest.fixture
def mock_es_vector_store(self, mocker):
return mocker.MagicMock(spec=ElasticsearchStore)
@ -91,6 +98,14 @@ class TestRetrieverFactory:
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_milvus_config(self, mocker, mock_milvus_vector_store, mock_embedding):
mock_config = MilvusRetrieverConfig(uri="/path/to/milvus.db", collection_name="test_collection")
mocker.patch("metagpt.rag.factories.retriever.MilvusVectorStore", return_value=mock_milvus_vector_store)
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, MilvusRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)

View file

@ -0,0 +1,118 @@
import pytest
from llama_index.core import Document
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import (
OmniParsedResult,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.omniparse_client import OmniParseClient
# test data
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
class TestOmniParseClient:
parse_client = OmniParseClient()
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_parse_pdf(self, mock_request_parse):
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_document(self, mock_request_parse):
mock_content = "#test title\ntest_parse_document"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
with open(TEST_DOCX, "rb") as f:
file_bytes = f.read()
with pytest.raises(ValueError):
# bytes data must provide bytes_filename
await self.parse_client.parse_document(file_bytes)
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_video(self, mock_request_parse):
mock_content = "#test title\ntest_parse_video"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
with pytest.raises(ValueError):
# Wrong file extension test
await self.parse_client.parse_video(TEST_DOCX)
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
@pytest.mark.asyncio
async def test_parse_audio(self, mock_request_parse):
mock_content = "#test title\ntest_parse_audio"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
class TestOmniParse:
@pytest.fixture
def mock_omniparse(self):
parser = OmniParse(
parse_options=OmniParseOptions(
parse_type=OmniParseType.PDF,
result_type=ParseResultType.MD,
max_timeout=120,
num_workers=3,
)
)
return parser
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_load_data(self, mock_omniparse, mock_request_parse):
# mock
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
# single file
documents = mock_omniparse.load_data(file_path=TEST_PDF)
doc = documents[0]
assert isinstance(doc, Document)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
# multi files
file_paths = [TEST_DOCX, TEST_PDF]
mock_omniparse.parse_type = OmniParseType.DOCUMENT
documents = await mock_omniparse.aload_data(file_path=file_paths)
doc = documents[0]
# assert
assert isinstance(doc, Document)
assert len(documents) == len(file_paths)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown

View file

@ -1,3 +1,4 @@
import tempfile
from pathlib import Path
from random import random
from tempfile import TemporaryDirectory
@ -6,6 +7,7 @@ import pytest
from metagpt.actions.research import CollectLinks
from metagpt.roles import researcher
from metagpt.team import Team
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@ -57,5 +59,13 @@ def test_write_report(mocker, context):
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")
@pytest.mark.asyncio
async def test_serialize():
team = Team()
team.hire([researcher.Researcher()])
with tempfile.TemporaryDirectory() as dirname:
team.serialize(Path(dirname) / "team.json")
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -5,6 +5,7 @@ import pytest
from metagpt.provider.human_provider import HumanProvider
from metagpt.roles.role import Role
from metagpt.schema import Message, UserMessage
def test_role_desc():
@ -18,5 +19,15 @@ def test_role_human(context):
assert isinstance(role.llm, HumanProvider)
@pytest.mark.asyncio
async def test_recovered():
role = Role(profile="Tester", desc="Tester", recovered=True)
role.put_message(UserMessage(content="2"))
role.latest_observed_msg = Message(content="1")
await role._observe()
await role._observe()
assert role.rc.msg_buffer.empty()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
@ -55,6 +55,7 @@ def test_environment_serdeser(context):
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
def test_environment_serdeser_v2(context):
@ -69,6 +70,7 @@ def test_environment_serdeser_v2(context):
assert isinstance(role, ProjectManager)
assert isinstance(role.actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
assert list(new_env.roles.values())[0].rc.watch == pm.rc.watch
def test_environment_serdeser_save(context):
@ -85,3 +87,8 @@ def test_environment_serdeser_save(context):
new_env: Environment = Environment(**env_dict, context=context)
assert len(new_env.roles) == 1
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -28,9 +28,9 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
def test_roles(context):
role_a = RoleA()
assert len(role_a.rc.watch) == 1
assert len(role_a.rc.watch) == 2
role_b = RoleB()
assert len(role_a.rc.watch) == 1
assert len(role_a.rc.watch) == 2
assert len(role_b.rc.watch) == 1
role_d = RoleD(actions=[ActionOK()])

View file

@ -8,9 +8,9 @@ from typing import Optional
from pydantic import BaseModel, Field
from metagpt.actions import Action, ActionOutput
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.fix_bug import FixBug
from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
@ -68,7 +68,7 @@ class RoleA(Role):
def __init__(self, **kwargs):
super(RoleA, self).__init__(**kwargs)
self.set_actions([ActionPass])
self._watch([UserRequirement])
self._watch([FixBug, UserRequirement])
class RoleB(Role):
@ -93,7 +93,7 @@ class RoleC(Role):
def __init__(self, **kwargs):
super(RoleC, self).__init__(**kwargs)
self.set_actions([ActionOK, ActionRaise])
self._watch([UserRequirement])
self._watch([FixBug, UserRequirement])
self.rc.react_mode = RoleReactMode.BY_ORDER
self.rc.memory.ignore_id = True

View file

@ -1,17 +0,0 @@
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.roles.sk_agent import SkAgent
@pytest.mark.asyncio
async def test_sk_agent_serdeser():
role = SkAgent()
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
assert "name" in ser_role_dict
assert "planner" in ser_role_dict
new_role = SkAgent(**ser_role_dict)
assert new_role.name == "Sunshine"
assert len(new_role.actions) == 1

View file

@ -29,3 +29,7 @@ def div(a: int, b: int = 0):
assert new_action.name == "WriteCodeReview"
await new_action.run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,8 +14,8 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config
def test_config_1():
cfg = Config.default()
llm = cfg.get_openai_llm()
assert llm is not None
assert llm.api_type == LLMType.OPENAI
if cfg.llm.api_type == LLMType.OPENAI:
assert llm is not None
def test_config_from_dict():

View file

@ -53,8 +53,8 @@ def test_context_1():
def test_context_2():
ctx = Context()
llm = ctx.config.get_openai_llm()
assert llm is not None
assert llm.api_type == LLMType.OPENAI
if ctx.config.llm.api_type == LLMType.OPENAI:
assert llm is not None
kwargs = ctx.kwargs
assert kwargs is not None

View file

@ -114,7 +114,6 @@ class MockLLM(OriginalLLM):
raise ValueError(
"In current test setting, api call is not allowed, you should properly mock your tests, "
"or add expected api response in tests/data/rsp_cache.json. "
f"The prompt you want for api call: {msg_key}"
)
# Call the original unmocked method
rsp = await ask_func(*args, **kwargs)