Merge branch 'main' into main

This commit is contained in:
better629 2024-10-17 16:25:31 +08:00 committed by GitHub
commit d99054ab5e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
98 changed files with 1697 additions and 496 deletions

View file

@ -243,12 +243,19 @@ class ActionNode:
"""基于pydantic v2的模型动态生成用来检验结果类型正确性"""
def check_fields(cls, values):
required_fields = set(mapping.keys())
all_fields = set(mapping.keys())
required_fields = set()
for k, v in mapping.items():
type_v, field_info = v
if ActionNode.is_optional_type(type_v):
continue
required_fields.add(k)
missing_fields = required_fields - set(values.keys())
if missing_fields:
raise ValueError(f"Missing fields: {missing_fields}")
unrecognized_fields = set(values.keys()) - required_fields
unrecognized_fields = set(values.keys()) - all_fields
if unrecognized_fields:
logger.warning(f"Unrecognized fields: {unrecognized_fields}")
return values
@ -850,3 +857,12 @@ class ActionNode:
root_node.add_child(child_node)
return root_node
@staticmethod
def is_optional_type(tp) -> bool:
"""Return True if `tp` is `typing.Optional[...]`"""
if typing.get_origin(tp) is Union:
args = typing.get_args(tp)
non_none_types = [arg for arg in args if arg is not type(None)]
return len(non_none_types) == 1 and len(args) == 2
return False

View file

@ -5,7 +5,7 @@
@Author : alexanderwu
@File : design_api_an.py
"""
from typing import List
from typing import List, Optional
from metagpt.actions.action_node import ActionNode
from metagpt.utils.mermaid import MMC1, MMC2
@ -45,9 +45,10 @@ REFINED_FILE_LIST = ActionNode(
example=["main.py", "game.py", "new_feature.py"],
)
# optional,because low success reproduction of class diagram in non py project.
DATA_STRUCTURES_AND_INTERFACES = ActionNode(
key="Data structures and interfaces",
expected_type=str,
expected_type=Optional[str],
instruction="Use mermaid classDiagram code syntax, including classes, method(__init__ etc.) and functions with type"
" annotations, CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. "
"The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.",
@ -66,7 +67,7 @@ REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode(
PROGRAM_CALL_FLOW = ActionNode(
key="Program call flow",
expected_type=str,
expected_type=Optional[str],
instruction="Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE "
"accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.",
example=MMC2,

View file

@ -5,14 +5,14 @@
@Author : alexanderwu
@File : project_management_an.py
"""
from typing import List
from typing import List, Optional
from metagpt.actions.action_node import ActionNode
REQUIRED_PACKAGES = ActionNode(
key="Required packages",
expected_type=List[str],
instruction="Provide required packages in requirements.txt format.",
expected_type=Optional[List[str]],
instruction="Provide required third-party packages in requirements.txt format.",
example=["flask==1.1.2", "bcrypt==3.2.0"],
)

View file

@ -161,6 +161,8 @@ class CollectLinks(Action):
"""
max_results = max(num_results * 2, 6)
results = await self.search_engine.run(query, max_results=max_results, as_string=False)
if len(results) == 0:
return []
_results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
logger.debug(prompt)

View file

@ -139,7 +139,7 @@ Language: Please use the same language as the user requirement, but the title an
end", "Anything UNCLEAR": "目前项目要求明确没有不清楚的地方"}
## Tasks
{"Required packages": ["无需Python"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式确保游戏界面美观"], ["main.js", "包含Main类负责初始化游戏和绑定事件"], ["game.js", "包含Game类负责游戏逻辑如开始游戏、移动方块等"], ["storage.js", "包含Storage类用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
{"Required packages": ["无需第三方"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式确保游戏界面美观"], ["main.js", "包含Main类负责初始化游戏和绑定事件"], ["game.js", "包含Game类负责游戏逻辑如开始游戏、移动方块等"], ["storage.js", "包含Storage类用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
## Code Files
----- index.html

View file

@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.file_parser_config import OmniParseConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig
@ -51,6 +52,9 @@ class Config(CLIParams, YamlModel):
# RAG Embedding
embedding: EmbeddingConfig = EmbeddingConfig()
# omniparse
omniparse: OmniParseConfig = OmniParseConfig()
# Global Proxy. Will be used if llm.proxy is not set
proxy: str = ""
@ -69,6 +73,7 @@ class Config(CLIParams, YamlModel):
workspace: WorkspaceConfig = WorkspaceConfig()
enable_longterm_memory: bool = False
code_review_k_times: int = 2
agentops_api_key: str = ""
# Will be removed in the future
metagpt_tti_url: str = ""

View file

@ -0,0 +1,6 @@
from metagpt.utils.yaml_model import YamlModel
class OmniParseConfig(YamlModel):
api_key: str = ""
base_url: str = ""

View file

@ -33,7 +33,7 @@ class LLMType(Enum):
YI = "yi" # lingyiwanwu
OPENROUTER = "openrouter"
BEDROCK = "bedrock"
ARK = "ark"
ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk
def __missing__(self, key):
return self.OPENAI
@ -90,6 +90,9 @@ class LLMConfig(YamlModel):
# Cost Control
calc_usage: bool = True
# For Messages Control
use_system_prompt: bool = True
@field_validator("api_key")
@classmethod
def check_llm_key(cls, v):

View file

@ -1,14 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/1 11:59
@Author : alexanderwu
@File : const.py
@Modified By: mashenquan, 2023-11-1. According to Section 2.2.1 and 2.2.2 of RFC 116, added key definitions for
common properties in the Message.
@Modified By: mashenquan, 2023-11-27. Defines file repository paths according to Section 2.2.3.4 of RFC 135.
@Modified By: mashenquan, 2023/12/5. Add directories for code summarization..
"""
import os
from pathlib import Path

View file

@ -0,0 +1,99 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from metagpt.document_store.base_store import BaseStore
@dataclass
class MilvusConnection:
"""
Args:
uri: milvus url
token: milvus token
"""
uri: str = None
token: str = None
class MilvusStore(BaseStore):
def __init__(self, connect: MilvusConnection):
try:
from pymilvus import MilvusClient
except ImportError:
raise Exception("Please install pymilvus first.")
if not connect.uri:
raise Exception("please check MilvusConnection, uri must be set.")
self.client = MilvusClient(uri=connect.uri, token=connect.token)
def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
from pymilvus import DataType
if self.client.has_collection(collection_name=collection_name):
self.client.drop_collection(collection_name=collection_name)
schema = self.client.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)
index_params = self.client.prepare_index_params()
index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
self.client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params,
enable_dynamic_schema=enable_dynamic_schema,
)
@staticmethod
def build_filter(key, value) -> str:
if isinstance(value, str):
filter_expression = f'{key} == "{value}"'
else:
if isinstance(value, list):
filter_expression = f"{key} in {value}"
else:
filter_expression = f"{key} == {value}"
return filter_expression
def search(
self,
collection_name: str,
query: List[float],
filter: Dict = None,
limit: int = 10,
output_fields: Optional[List[str]] = None,
) -> List[dict]:
filter_expression = " and ".join([self.build_filter(key, value) for key, value in filter.items()])
print(filter_expression)
res = self.client.search(
collection_name=collection_name,
data=[query],
filter=filter_expression,
limit=limit,
output_fields=output_fields,
)[0]
return res
def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
data = dict()
for i, id in enumerate(_ids):
data["id"] = id
data["vector"] = vector[i]
data["metadata"] = metadata[i]
self.client.upsert(collection_name=collection_name, data=data)
def delete(self, collection_name: str, _ids: List[str]):
self.client.delete(collection_name=collection_name, ids=_ids)
def write(self, *args, **kwargs):
pass

View file

@ -266,7 +266,7 @@ class STRole(Role):
# We will order our percept based on the distance, with the closest ones
# getting priorities.
percept_events_list = []
# First, we put all events that are occuring in the nearby tiles into the
# First, we put all events that are occurring in the nearby tiles into the
# percept_events_list
for tile in nearby_tiles:
tile_details = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile))

View file

@ -81,7 +81,7 @@ class Memory(BaseModel):
return self.storage[-k:]
def find_news(self, observed: list[Message], k=0) -> list[Message]:
"""find news (previously unseen messages) from the the most recent k memories, from all memories when k=0"""
"""find news (previously unseen messages) from the most recent k memories, from all memories when k=0"""
already_observed = self.get(k)
news: list[Message] = []
for i in observed:

View file

@ -1,12 +1,33 @@
from openai import AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Provider for volcengine.
See Also: https://console.volcengine.com/ark/region:ark+cn-beijing/model
config2.yaml example:
```yaml
llm:
base_url: "https://ark.cn-beijing.volces.com/api/v3"
api_type: "ark"
endpoint: "ep-2024080514****-d****"
api_key: "d47****b-****-****-****-d6e****0fd77"
pricing_plan: "doubao-lite"
```
"""
from typing import Optional, Union
from pydantic import BaseModel
from volcenginesdkarkruntime import AsyncArk
from volcenginesdkarkruntime._base_client import AsyncHttpxClientWrapper
from volcenginesdkarkruntime._streaming import AsyncStream
from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
from metagpt.configs.llm_config import LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS
@register_provider(LLMType.ARK)
@ -16,11 +37,45 @@ class ArkLLM(OpenAILLM):
https://www.volcengine.com/docs/82379/1263482
"""
aclient: Optional[AsyncArk] = None
def _init_client(self):
"""SDK: https://github.com/openai/openai-python#async-usage"""
self.model = (
self.config.endpoint or self.config.model
) # endpoint name, See more: https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
self.pricing_plan = self.config.pricing_plan or self.model
kwargs = self._make_client_kwargs()
self.aclient = AsyncArk(**kwargs)
def _make_client_kwargs(self) -> dict:
kvs = {
"ak": self.config.access_key,
"sk": self.config.secret_key,
"api_key": self.config.api_key,
"base_url": self.config.base_url,
}
kwargs = {k: v for k, v in kvs.items() if v}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
if next(iter(DOUBAO_TOKEN_COSTS)) not in self.cost_manager.token_costs:
self.cost_manager.token_costs.update(DOUBAO_TOKEN_COSTS)
if model in self.cost_manager.token_costs:
self.pricing_plan = model
if self.pricing_plan in self.cost_manager.token_costs:
super()._update_costs(usage, self.pricing_plan, local_calc_usage)
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
stream=True,
extra_body={"stream_options": {"include_usage": True}} # 只有增加这个参数才会在流式时最后返回usage
extra_body={"stream_options": {"include_usage": True}}, # 只有增加这个参数才会在流式时最后返回usage
)
usage = None
collected_messages = []
@ -30,7 +85,7 @@ class ArkLLM(OpenAILLM):
collected_messages.append(chunk_message)
if chunk.usage:
# 火山方舟的流式调用会在最后一个chunk中返回usage,最后一个chunk的choices为[]
usage = CompletionUsage(**chunk.usage)
usage = chunk.usage
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)

View file

@ -27,6 +27,7 @@ SUPPORT_STREAM_MODELS = {
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
# currently (2024-4-29) only available at US West (Oregon) AWS Region.

View file

@ -1,5 +1,7 @@
import asyncio
import json
from typing import Literal
from functools import partial
from typing import List, Literal
import boto3
from botocore.eventstream import EventStream
@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM):
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
logger.warning("Amazon bedrock doesn't support asynchronous now")
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
logger.warning(f"model {self.config.model} doesn't support streaming output!")
@ -64,15 +65,21 @@ class BedrockLLM(BaseLLM):
]
logger.info("\n" + "\n".join(summaries))
def invoke_model(self, request_body: str) -> dict:
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
async def invoke_model(self, request_body: str) -> dict:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
return response
@ -97,7 +104,7 @@ class BedrockLLM(BaseLLM):
async def acompletion(self, messages: list[dict]) -> dict:
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
response_body = self.invoke_model(request_body)
response_body = await self.invoke_model(request_body)
return response_body
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
@ -111,14 +118,8 @@ class BedrockLLM(BaseLLM):
return full_text
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
response = self.invoke_model_with_response_stream(request_body)
collected_content = []
for event in response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
stream_response = await self.invoke_model_with_response_stream(request_body)
collected_content = await self._get_stream_response_body(stream_response)
log_llm_stream("\n")
full_text = ("".join(collected_content)).lstrip()
return full_text
@ -127,6 +128,18 @@ class BedrockLLM(BaseLLM):
response_body = json.loads(response["body"].read())
return response_body
async def _get_stream_response_body(self, stream_response) -> List[str]:
def collect_content() -> str:
collected_content = []
for event in stream_response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
return collected_content
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, collect_content)
def _get_usage(self, response) -> dict[str, int]:
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))

View file

@ -48,13 +48,17 @@ def build_api_arequest(
request_timeout,
form,
resources,
base_address,
_,
) = _get_protocol_params(kwargs)
task_id = kwargs.pop("task_id", None)
if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
if not dashscope.base_http_api_url.endswith("/"):
http_url = dashscope.base_http_api_url + "/"
if base_address is None:
base_address = dashscope.base_http_api_url
if not base_address.endswith("/"):
http_url = base_address + "/"
else:
http_url = dashscope.base_http_api_url
http_url = base_address
if is_service:
http_url = http_url + SERVICE_API_PATH + "/"

View file

@ -81,7 +81,9 @@ class GeneralAPIRequestor(APIRequestor):
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
if stream and (
"text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
self._interpret_response_line(line, result.status, result.headers, stream=True)

View file

@ -37,7 +37,11 @@ def register_provider(keys):
def create_llm_instance(config: LLMConfig) -> BaseLLM:
"""get the default llm provider"""
return LLM_REGISTRY.get_provider(config.api_type)(config)
llm = LLM_REGISTRY.get_provider(config.api_type)(config)
if llm.use_system_prompt and not config.use_system_prompt:
# for models like o1-series, default openai provider.use_system_prompt is True, but it should be False for o1-*
llm.use_system_prompt = config.use_system_prompt
return llm
# Registry instance

View file

@ -51,9 +51,17 @@ class OllamaLLM(BaseLLM):
return json.loads(chunk)
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
headers = (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {
"Authorization": f"Bearer {self.config.api_key}",
}
)
resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
headers=headers,
params=self._const_kwargs(messages),
request_timeout=self.get_timeout(timeout),
)
@ -66,9 +74,17 @@ class OllamaLLM(BaseLLM):
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
headers = (
None
if not self.config.api_key or self.config.api_key == "sk-"
else {
"Authorization": f"Bearer {self.config.api_key}",
}
)
stream_resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
headers=headers,
stream=True,
params=self._const_kwargs(messages, stream=True),
request_timeout=self.get_timeout(timeout),

View file

@ -37,7 +37,6 @@ from metagpt.utils.token_counter import (
count_input_tokens,
count_output_tokens,
get_max_completion_tokens,
get_openrouter_tokens,
)
@ -92,6 +91,7 @@ class OpenAILLM(BaseLLM):
)
usage = None
collected_messages = []
has_finished = False
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
finish_reason = (
@ -99,8 +99,13 @@ class OpenAILLM(BaseLLM):
)
log_llm_stream(chunk_message)
collected_messages.append(chunk_message)
chunk_has_usage = hasattr(chunk, "usage") and chunk.usage
if has_finished:
# for oneapi, there has a usage chunk after finish_reason not none chunk
if chunk_has_usage:
usage = CompletionUsage(**chunk.usage)
if finish_reason:
if hasattr(chunk, "usage") and chunk.usage is not None:
if chunk_has_usage:
# Some services have usage as an attribute of the chunk, such as Fireworks
if isinstance(chunk.usage, CompletionUsage):
usage = chunk.usage
@ -109,9 +114,10 @@ class OpenAILLM(BaseLLM):
elif hasattr(chunk.choices[0], "usage"):
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
usage = CompletionUsage(**chunk.choices[0].usage)
elif "openrouter.ai" in self.config.base_url:
elif "openrouter.ai" in self.config.base_url and chunk_has_usage:
# due to it get token cost from api
usage = await get_openrouter_tokens(chunk)
usage = chunk.usage
has_finished = True
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)
@ -132,6 +138,10 @@ class OpenAILLM(BaseLLM):
"model": self.model,
"timeout": self.get_timeout(timeout),
}
if "o1-" in self.model:
# compatible to openai o1-series
kwargs["temperature"] = 1
kwargs.pop("max_tokens")
if extra_kwargs:
kwargs.update(extra_kwargs)
return kwargs

View file

@ -50,6 +50,9 @@ class QianFanLLM(BaseLLM):
else:
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
if self.config.base_url:
os.environ.setdefault("QIANFAN_BASE_URL", self.config.base_url)
support_system_pairs = [
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
("ERNIE-Bot-8k", "ernie_bot_8k"),
@ -103,13 +106,13 @@ class QianFanLLM(BaseLLM):
def get_choice_text(self, resp: JsonBody) -> str:
return resp.get("result", "")
def completion(self, messages: list[dict]) -> JsonBody:
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
def completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False), request_timeout=timeout)
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False), request_timeout=timeout)
self._update_costs(resp.body.get("usage", {}))
return resp.body
@ -117,7 +120,7 @@ class QianFanLLM(BaseLLM):
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True), request_timeout=timeout)
collected_content = []
usage = {}
async for chunk in resp:

View file

@ -14,6 +14,7 @@ from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.readers.base import BaseReader
from llama_index.core.response_synthesizers import (
BaseSynthesizer,
get_response_synthesizer,
@ -28,6 +29,7 @@ from llama_index.core.schema import (
TransformComponent,
)
from metagpt.config2 import config
from metagpt.rag.factories import (
get_index,
get_rag_embedding,
@ -36,6 +38,7 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
@ -44,6 +47,9 @@ from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ObjectNode,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.common import import_class
@ -100,7 +106,10 @@ class SimpleEngine(RetrieverQueryEngine):
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
file_extractor = cls._get_file_extractor()
documents = SimpleDirectoryReader(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
).load_data()
cls._fix_document_metadata(documents)
transformations = transformations or cls._default_transformations()
@ -301,3 +310,23 @@ class SimpleEngine(RetrieverQueryEngine):
@staticmethod
def _default_transformations():
return [SentenceSplitter()]
@staticmethod
def _get_file_extractor() -> dict[str:BaseReader]:
"""
Get the file extractor.
Currently, only PDF use OmniParse. Other document types use the built-in reader from llama_index.
Returns:
dict[file_type: BaseReader]
"""
file_extractor: dict[str:BaseReader] = {}
if config.omniparse.base_url:
pdf_parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
)
file_extractor[".pdf"] = pdf_parser
return file_extractor

View file

@ -8,6 +8,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
@ -17,6 +18,7 @@ from metagpt.rag.schema import (
ElasticsearchIndexConfig,
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
MilvusIndexConfig,
)
@ -28,6 +30,7 @@ class RAGIndexFactory(ConfigBasedFactory):
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
MilvusIndexConfig: self._create_milvus
}
super().__init__(creators)
@ -46,6 +49,11 @@ class RAGIndexFactory(ConfigBasedFactory):
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
def _create_milvus(self, config: MilvusIndexConfig, **kwargs) -> VectorStoreIndex:
vector_store = MilvusVectorStore(collection_name=config.collection_name, uri=config.uri, token=config.token)
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)

View file

@ -12,6 +12,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
@ -20,6 +21,7 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
@ -27,6 +29,7 @@ from metagpt.rag.schema import (
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
)
@ -56,6 +59,7 @@ class RetrieverFactory(ConfigBasedFactory):
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
MilvusRetrieverConfig: self._create_milvus_retriever,
}
super().__init__(creators)
@ -76,6 +80,11 @@ class RetrieverFactory(ConfigBasedFactory):
return index.as_retriever()
def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever:
config.index = self._build_milvus_index(config, **kwargs)
return MilvusRetriever(**config.model_dump())
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
config.index = self._build_faiss_index(config, **kwargs)
@ -128,6 +137,12 @@ class RetrieverFactory(ConfigBasedFactory):
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions)
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())

View file

@ -0,0 +1,3 @@
from metagpt.rag.parsers.omniparse import OmniParse
__all__ = ["OmniParse"]

View file

@ -0,0 +1,139 @@
import asyncio
from fileinput import FileInput
from pathlib import Path
from typing import List, Optional, Union
from llama_index.core import Document
from llama_index.core.async_utils import run_jobs
from llama_index.core.readers.base import BaseReader
from metagpt.logs import logger
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.omniparse_client import OmniParseClient
class OmniParse(BaseReader):
"""OmniParse"""
def __init__(
self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: OmniParse Base URL for the API.
parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
"""
self.parse_options = parse_options or OmniParseOptions()
self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)
@property
def parse_type(self):
return self.parse_options.parse_type
@property
def result_type(self):
return self.parse_options.result_type
@parse_type.setter
def parse_type(self, parse_type: Union[str, OmniParseType]):
if isinstance(parse_type, str):
parse_type = OmniParseType(parse_type)
self.parse_options.parse_type = parse_type
@result_type.setter
def result_type(self, result_type: Union[str, ParseResultType]):
if isinstance(result_type, str):
result_type = ParseResultType(result_type)
self.parse_options.result_type = result_type
async def _aload_data(
self,
file_path: Union[str, bytes, Path],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Returns:
List[Document]
"""
try:
if self.parse_type == OmniParseType.PDF:
# pdf parse
parsed_result = await self.omniparse_client.parse_pdf(file_path)
else:
# other parse use omniparse_client.parse_document
# For compatible byte data, additional filename is required
extra_info = extra_info or {}
filename = extra_info.get("filename")
parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename)
# Get the specified structured data based on result_type
content = getattr(parsed_result, self.result_type)
docs = [
Document(
text=content,
metadata=extra_info or {},
)
]
except Exception as e:
logger.error(f"OMNI Parse Error: {e}")
docs = []
return docs
async def aload_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls _aload_data for processing.
Returns:
List[Document]
"""
docs = []
if isinstance(file_path, (str, bytes, Path)):
# Processing single file
docs = await self._aload_data(file_path, extra_info)
elif isinstance(file_path, list):
# Concurrently process multiple files
parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
docs = [doc for docs in doc_ret_list for doc in docs]
return docs
def load_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""
Load data from the input file_path.
Args:
file_path: File path or file byte data.
extra_info: Optional dictionary containing additional information.
Notes:
This method ultimately calls aload_data for processing.
Returns:
List[Document]
"""
NestAsyncio.apply_once() # Ensure compatibility with nested async calls
return asyncio.run(self.aload_data(file_path, extra_info))

View file

@ -0,0 +1,17 @@
"""Milvus retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class MilvusRetriever(VectorIndexRetriever):
"""Milvus retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist.
Milvus automatically saves, so there is no need to implement."""

View file

@ -1,14 +1,14 @@
"""RAG schemas."""
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, Literal, Optional, Union
from typing import Any, ClassVar, List, Literal, Optional, Union
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator, validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
@ -62,6 +62,36 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
_no_embedding: bool = PrivateAttr(default=True)
class MilvusRetrieverConfig(IndexRetrieverConfig):
"""Config for Milvus-based retrievers."""
uri: str = Field(default="./milvus_local.db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
token: str = Field(default=None, description="The token for Milvus")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
EmbeddingType.GEMINI: 768,
EmbeddingType.OLLAMA: 4096,
}
@model_validator(mode="after")
def check_dimensions(self):
if self.dimensions == 0:
self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
config.embedding.api_type, 1536
)
if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
logger.warning(
f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
)
return self
class ChromaRetrieverConfig(IndexRetrieverConfig):
"""Config for Chroma-based retrievers."""
@ -169,6 +199,16 @@ class ChromaIndexConfig(VectorIndexConfig):
default=None, description="Optional metadata to associate with the collection"
)
class MilvusIndexConfig(VectorIndexConfig):
"""Config for milvus-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
uri: str = Field(default="./milvus_local.db", description="The uri of the index.")
token: Optional[str] = Field(default=None, description="The token of the index.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""
@ -214,3 +254,51 @@ class ObjectNode(TextNode):
)
return metadata.model_dump()
class OmniParseType(str, Enum):
"""OmniParseType"""
PDF = "PDF"
DOCUMENT = "DOCUMENT"
class ParseResultType(str, Enum):
"""The result type for the parser."""
TXT = "text"
MD = "markdown"
JSON = "json"
class OmniParseOptions(BaseModel):
"""OmniParse Options config"""
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
num_workers: int = Field(
default=5,
gt=0,
lt=10,
description="Number of concurrent requests for multiple files",
)
class OminParseImage(BaseModel):
image: str = Field(default="", description="image str bytes")
image_name: str = Field(default="", description="image name")
image_info: Optional[dict] = Field(default={}, description="image info")
class OmniParsedResult(BaseModel):
markdown: str = Field(default="", description="markdown text")
text: str = Field(default="", description="plain text")
images: Optional[List[OminParseImage]] = Field(default=[], description="images")
metadata: Optional[dict] = Field(default={}, description="metadata")
@model_validator(mode="before")
def set_markdown(cls, values):
if not values.get("markdown"):
values["markdown"] = values.get("text")
return values

View file

@ -6,6 +6,7 @@
@File : architect.py
"""
from metagpt.actions import WritePRD
from metagpt.actions.design_api import WriteDesign
from metagpt.roles.role import Role

View file

@ -80,19 +80,17 @@ class InvoiceOCRAssistant(Role):
raise Exception("Invoice file not uploaded")
resp = await todo.run(file_path)
actions = list(self.actions)
if len(resp) == 1:
# Single file support for questioning based on OCR recognition results
self.set_actions([GenerateTable, ReplyQuestion])
actions.extend([GenerateTable, ReplyQuestion])
self.orc_data = resp[0]
else:
self.set_actions([GenerateTable])
self.set_todo(None)
actions.append(GenerateTable)
self.set_actions(actions)
self.rc.max_react_loop = len(self.actions)
content = INVOICE_OCR_SUCCESS
resp = OCRResults(ocr_result=json.dumps(resp))
msg = Message(content=content, instruct_content=resp)
self.rc.memory.add(msg)
return await super().react()
elif isinstance(todo, GenerateTable):
ocr_results: OCRResults = msg.instruct_content
resp = await todo.run(json.loads(ocr_results.ocr_result), self.filename)

View file

@ -7,6 +7,7 @@
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
"""
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.roles.role import Role, RoleReactMode

View file

@ -6,6 +6,7 @@
@File : project_manager.py
"""
from metagpt.actions import WriteTasks
from metagpt.actions.design_api import WriteDesign
from metagpt.roles.role import Role

View file

@ -15,6 +15,7 @@
of SummarizeCode.
"""
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import MESSAGE_ROUTE_TO_NONE

View file

@ -58,7 +58,9 @@ class Researcher(Role):
)
elif isinstance(todo, WebBrowseAndSummarize):
links = instruct_content.links
todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items())
todos = (
todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items() if url
)
if self.enable_concurrency:
summaries = await asyncio.gather(*todos)
else:

View file

@ -170,7 +170,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
self._check_actions()
self.llm.system_prompt = self._get_prefix()
self.llm.cost_manager = self.context.cost_manager
self._watch(kwargs.pop("watch", [UserRequirement]))
if not self.rc.watch:
self._watch(kwargs.pop("watch", [UserRequirement]))
if self.latest_observed_msg:
self.recovered = True
@ -421,8 +422,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
"""Prepare new messages for processing from the message buffer and other sources."""
# Read unprocessed messages from the msg buffer.
news = []
if self.recovered:
news = [self.latest_observed_msg] if self.latest_observed_msg else []
if self.recovered and self.latest_observed_msg:
news = self.rc.memory.find_news(observed=[self.latest_observed_msg], k=10)
if not news:
news = self.rc.msg_buffer.pop_all()
# Store the read messages in your own memory to prevent duplicate processing.

View file

@ -1,87 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/13 12:23
@Author : femto Zheng
@File : sk_agent.py
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message filtering.
"""
from typing import Any, Callable, Union
from pydantic import Field
from semantic_kernel import Kernel
from semantic_kernel.planning import SequentialPlanner
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from semantic_kernel.planning.basic_planner import BasicPlanner, Plan
from metagpt.actions import UserRequirement
from metagpt.actions.execute_task import ExecuteTask
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.make_sk_kernel import make_sk_kernel
class SkAgent(Role):
"""
Represents an SkAgent implemented using semantic kernel
Attributes:
name (str): Name of the SkAgent.
profile (str): Role profile, default is 'sk_agent'.
goal (str): Goal of the SkAgent.
constraints (str): Constraints for the SkAgent.
"""
name: str = "Sunshine"
profile: str = "sk_agent"
goal: str = "Execute task based on passed in task description"
constraints: str = ""
plan: Plan = Field(default=None, exclude=True)
planner_cls: Any = None
planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
kernel: Kernel = Field(default_factory=Kernel)
import_semantic_skill_from_directory: Callable = Field(default=None, exclude=True)
import_skill: Callable = Field(default=None, exclude=True)
def __init__(self, **data: Any) -> None:
"""Initializes the Engineer role with given attributes."""
super().__init__(**data)
self.set_actions([ExecuteTask()])
self._watch([UserRequirement])
self.kernel = make_sk_kernel()
# how funny the interface is inconsistent
if self.planner_cls == BasicPlanner or self.planner_cls is None:
self.planner = BasicPlanner()
elif self.planner_cls in [SequentialPlanner, ActionPlanner]:
self.planner = self.planner_cls(self.kernel)
else:
raise Exception(f"Unsupported planner of type {self.planner_cls}")
self.import_semantic_skill_from_directory = self.kernel.import_semantic_skill_from_directory
self.import_skill = self.kernel.import_skill
async def _think(self) -> None:
self._set_state(0)
# how funny the interface is inconsistent
if isinstance(self.planner, BasicPlanner):
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content, self.kernel)
logger.info(self.plan.generated_plan)
elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content)
async def _act(self) -> Message:
# how funny the interface is inconsistent
result = None
if isinstance(self.planner, BasicPlanner):
result = await self.planner.execute_plan_async(self.plan, self.kernel)
elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
result = (await self.plan.invoke_async()).result
logger.info(result)
msg = Message(content=result, role=self.profile, cause_by=self.rc.todo)
self.rc.memory.add(msg)
return msg

View file

@ -4,6 +4,7 @@
import asyncio
from pathlib import Path
import agentops
import typer
from metagpt.const import CONFIG_ROOT
@ -38,6 +39,9 @@ def generate_repo(
)
from metagpt.team import Team
if config.agentops_api_key != "":
agentops.init(config.agentops_api_key, tags=["software_company"])
config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
ctx = Context(config=config)
@ -68,6 +72,9 @@ def generate_repo(
company.run_project(idea)
asyncio.run(company.run(n_round=n_round))
if config.agentops_api_key != "":
agentops.end_session("Success")
return ctx.repo

View file

@ -126,6 +126,9 @@ class Team(BaseModel):
self.run_project(idea=idea, send_to=send_to)
while n_round > 0:
if self.env.is_idle:
logger.debug("All roles are idle.")
break
n_round -= 1
self._check_balance()
await self.env.run()

View file

@ -6,37 +6,15 @@
@File : search_engine.py
"""
import importlib
from typing import Callable, Coroutine, Literal, Optional, Union, overload
from typing import Annotated, Callable, Coroutine, Literal, Optional, Union, overload
from pydantic import BaseModel, ConfigDict, model_validator
from semantic_kernel.skill_definition import sk_function
from pydantic import BaseModel, ConfigDict, Field, model_validator
from metagpt.configs.search_config import SearchConfig
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
class SkSearchEngine:
"""A search engine class for executing searches.
Attributes:
search_engine: The search engine instance used for executing searches.
"""
def __init__(self, **kwargs):
self.search_engine = SearchEngine(**kwargs)
@sk_function(
description="searches results from Google. Useful when you need to find short "
"and succinct answers about a specific topic. Input should be a search query.",
name="searchAsync",
input_description="search",
)
async def run(self, query: str) -> str:
result = await self.search_engine.run(query)
return result
class SearchEngine(BaseModel):
"""A model for configuring and executing searches with different search engines.
@ -51,7 +29,9 @@ class SearchEngine(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None
run_func: Annotated[
Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]], Field(exclude=True)
] = None
api_key: Optional[str] = None
proxy: Optional[str] = None

View file

@ -87,8 +87,11 @@ class SerpAPIWrapper(BaseModel):
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
if "error" in res.keys():
raise ValueError(f"Got error from SerpAPI: {res['error']}")
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
if res["error"] == "Google hasn't returned any results for this query.":
toret = "No good search result found"
else:
raise ValueError(f"Got error from SerpAPI: {res['error']}")
elif "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]

View file

@ -3,9 +3,9 @@
from __future__ import annotations
import importlib
from typing import Any, Callable, Coroutine, Optional, Union, overload
from typing import Annotated, Any, Callable, Coroutine, Optional, Union, overload
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.tools import WebBrowserEngineType
@ -29,7 +29,10 @@ class WebBrowserEngine(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None
run_func: Annotated[
Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]],
Field(exclude=True),
] = None
proxy: Optional[str] = None
@model_validator(mode="after")

View file

@ -23,8 +23,8 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
"""Graph repository based on DiGraph."""
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
def __init__(self, name: str | Path, **kwargs):
super().__init__(name=str(name), **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
@ -112,8 +112,28 @@ class DiGraphRepository(GraphRepository):
async def load(self, pathname: str | Path):
"""Load a directed graph repository from a JSON file."""
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self.load_json(data)
def load_json(self, val: str):
"""
Loads a JSON-encoded string representing a graph structure and updates
the internal repository (_repo) with the parsed graph.
Args:
val (str): A JSON-encoded string representing a graph structure.
Returns:
self: Returns the instance of the class with the updated _repo attribute.
Raises:
TypeError: If val is not a valid JSON string or cannot be parsed into
a valid graph structure.
"""
if not val:
return self
m = json.loads(val)
self._repo = networkx.node_link_graph(m)
return self
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
@ -126,9 +146,7 @@ class DiGraphRepository(GraphRepository):
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
"""
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
graph = DiGraphRepository(name=name, root=root)
graph = DiGraphRepository(name=pathname.stem, root=pathname.parent)
if pathname.exists():
await graph.load(pathname=pathname)
return graph

View file

@ -1,32 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/13 12:29
@Author : femto Zheng
@File : make_sk_kernel.py
"""
import semantic_kernel as sk
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import (
AzureChatCompletion,
)
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import (
OpenAIChatCompletion,
)
from metagpt.config2 import config
def make_sk_kernel():
kernel = sk.Kernel()
if llm := config.get_azure_llm():
kernel.add_chat_service(
"chat_completion",
AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
)
elif llm := config.get_openai_llm():
kernel.add_chat_service(
"chat_completion",
OpenAIChatCompletion(llm.model, llm.api_key),
)
return kernel

View file

@ -0,0 +1,239 @@
import mimetypes
import os
from pathlib import Path
from typing import Union
import httpx
from metagpt.rag.schema import OmniParsedResult
from metagpt.utils.common import aread_bin
class OmniParseClient:
"""
OmniParse Server Client
This client interacts with the OmniParse server to parse different types of media, documents.
OmniParse API Documentation: https://docs.cognitivelab.in/api
Attributes:
ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
"""
ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
"""
Args:
api_key: Default None, can be used for authentication later.
base_url: Base URL for the API.
max_timeout: Maximum request timeout in seconds.
"""
self.api_key = api_key
self.base_url = base_url
self.max_timeout = max_timeout
self.parse_media_endpoint = "/parse_media"
self.parse_website_endpoint = "/parse_website"
self.parse_document_endpoint = "/parse_document"
async def _request_parse(
self,
endpoint: str,
method: str = "POST",
files: dict = None,
params: dict = None,
data: dict = None,
json: dict = None,
headers: dict = None,
**kwargs,
) -> dict:
"""
Request OmniParse API to parse a document.
Args:
endpoint (str): API endpoint.
method (str, optional): HTTP method to use. Default is "POST".
files (dict, optional): Files to include in the request.
params (dict, optional): Query string parameters.
data (dict, optional): Form data to include in the request body.
json (dict, optional): JSON data to include in the request body.
headers (dict, optional): HTTP headers to include in the request.
**kwargs: Additional keyword arguments for httpx.AsyncClient.request()
Returns:
dict: JSON response data.
"""
url = f"{self.base_url}{endpoint}"
method = method.upper()
headers = headers or {}
_headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
headers.update(**_headers)
async with httpx.AsyncClient() as client:
response = await client.request(
url=url,
method=method,
files=files,
params=params,
json=json,
data=data,
headers=headers,
timeout=self.max_timeout,
**kwargs,
)
response.raise_for_status()
return response.json()
async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
"""
Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the document parsing.
"""
self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
"""
Parse pdf document.
Args:
file_input: File path or file byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
OmniParsedResult: The result of the pdf parsing.
"""
self.verify_file_ext(file_input, {".pdf"})
# parse_pdf supports parsing by accepting only the byte data of the file.
file_info = await self.get_file_info(file_input, only_bytes=True)
endpoint = f"{self.parse_document_endpoint}/pdf"
resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
data = OmniParsedResult(**resp)
return data
async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
"""
Parse audio-type data (supports ".mp3", ".wav", ".aac").
Args:
file_input: File path or file byte data.
bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.
Raises:
ValueError: If the file extension is not allowed.
Returns:
dict: JSON response data.
"""
self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
file_info = await self.get_file_info(file_input, bytes_filename)
return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
@staticmethod
def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
"""
Verify the file extension.
Args:
file_input: File path or file byte data.
allowed_file_extensions: Set of allowed file extensions.
bytes_filename: Filename to use for verification when `file_input` is byte data.
Raises:
ValueError: If the file extension is not allowed.
Returns:
"""
verify_file_path = None
if isinstance(file_input, (str, Path)):
verify_file_path = str(file_input)
elif isinstance(file_input, bytes) and bytes_filename:
verify_file_path = bytes_filename
if not verify_file_path:
# Do not verify if only byte data is provided
return
file_ext = os.path.splitext(verify_file_path)[1].lower()
if file_ext not in allowed_file_extensions:
raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
@staticmethod
async def get_file_info(
file_input: Union[str, bytes, Path],
bytes_filename: str = None,
only_bytes: bool = False,
) -> Union[bytes, tuple]:
"""
Get file information.
Args:
file_input: File path or file byte data.
bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
Raises:
ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.
Notes:
Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
the MIME type of the file must be specified when uploading.
Returns: [bytes, tuple]
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
"""
if isinstance(file_input, (str, Path)):
filename = os.path.basename(str(file_input))
file_bytes = await aread_bin(file_input)
if only_bytes:
return file_bytes
mime_type = mimetypes.guess_type(file_input)[0]
return filename, file_bytes, mime_type
elif isinstance(file_input, bytes):
if only_bytes:
return file_input
if not bytes_filename:
raise ValueError("bytes_filename must be set when passing bytes")
mime_type = mimetypes.guess_type(bytes_filename)[0]
return bytes_filename, file_input, mime_type
else:
raise ValueError("file_input must be a string (file path) or bytes.")

View file

@ -10,7 +10,7 @@ from __future__ import annotations
import traceback
from datetime import timedelta
import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/
import redis.asyncio as aioredis
from metagpt.configs.redis_config import RedisConfig
from metagpt.logs import logger

View file

@ -11,8 +11,10 @@ from multiprocessing import Pipe
class StreamPipe:
parent_conn, child_conn = Pipe()
finish: bool = False
def __init__(self, name=None):
self.name = name
self.parent_conn, self.child_conn = Pipe()
self.finish: bool = False
format_data = {
"id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",

View file

@ -41,11 +41,19 @@ TOKEN_COSTS = {
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4o": {"prompt": 0.005, "completion": 0.015},
"gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
"gpt-4o-mini-2024-07-18": {"prompt": 0.00015, "completion": 0.0006},
"gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015},
"gpt-4o-2024-08-06": {"prompt": 0.0025, "completion": 0.01},
"o1-preview": {"prompt": 0.015, "completion": 0.06},
"o1-preview-2024-09-12": {"prompt": 0.015, "completion": 0.06},
"o1-mini": {"prompt": 0.003, "completion": 0.012},
"o1-mini-2024-09-12": {"prompt": 0.003, "completion": 0.012},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
"gemini-1.5-flash": {"prompt": 0.000075, "completion": 0.0003},
"gemini-1.5-pro": {"prompt": 0.0035, "completion": 0.0105},
"gemini-1.0-pro": {"prompt": 0.0005, "completion": 0.0015},
"moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens
"moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
"moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
@ -69,15 +77,20 @@ TOKEN_COSTS = {
"llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079},
"openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015},
"openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
"openai/o1-preview": {"prompt": 0.015, "completion": 0.06},
"openai/o1-mini": {"prompt": 0.003, "completion": 0.012},
"anthropic/claude-3-opus": {"prompt": 0.015, "completion": 0.075},
"anthropic/claude-3.5-sonnet": {"prompt": 0.003, "completion": 0.015},
"google/gemini-pro-1.5": {"prompt": 0.0025, "completion": 0.0075}, # for openrouter, end
"deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
"deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
# For ark model https://www.volcengine.com/docs/82379/1099320
"doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
"doubao-lite-4k-240515": {"prompt": 0.000043, "completion": 0.000086},
"doubao-lite-32k-240515": {"prompt": 0.000043, "completion": 0.000086},
"doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00014},
"doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00029},
"doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00029},
"doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0013},
"llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
"llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
}
@ -138,8 +151,17 @@ QIANFAN_ENDPOINT_TOKEN_COSTS = {
"""
DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
Different model has different detail page. Attention, some model are free for a limited time.
Some new model published by Alibaba will be prioritized to be released on the Model Studio instead of the Dashscope.
Token price on Model Studio shows on https://help.aliyun.com/zh/model-studio/getting-started/models#ced16cb6cdfsy
"""
DASHSCOPE_TOKEN_COSTS = {
"qwen2.5-72b-instruct": {"prompt": 0.00057, "completion": 0.0017}, # per 1k tokens
"qwen2.5-32b-instruct": {"prompt": 0.0005, "completion": 0.001},
"qwen2.5-14b-instruct": {"prompt": 0.00029, "completion": 0.00086},
"qwen2.5-7b-instruct": {"prompt": 0.00014, "completion": 0.00029},
"qwen2.5-3b-instruct": {"prompt": 0.0, "completion": 0.0},
"qwen2.5-1.5b-instruct": {"prompt": 0.0, "completion": 0.0},
"qwen2.5-0.5b-instruct": {"prompt": 0.0, "completion": 0.0},
"qwen2-72b-instruct": {"prompt": 0.000714, "completion": 0.001428},
"qwen2-57b-a14b-instruct": {"prompt": 0.0005, "completion": 0.001},
"qwen2-7b-instruct": {"prompt": 0.000143, "completion": 0.000286},
@ -190,16 +212,24 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
# https://console.volcengine.com/ark/region:ark+cn-beijing/model
DOUBAO_TOKEN_COSTS = {
"doubao-lite": {"prompt": 0.0003, "completion": 0.0006},
"doubao-lite-128k": {"prompt": 0.0008, "completion": 0.0010},
"doubao-pro": {"prompt": 0.0008, "completion": 0.0020},
"doubao-pro-128k": {"prompt": 0.0050, "completion": 0.0090},
"doubao-lite": {"prompt": 0.000043, "completion": 0.000086},
"doubao-lite-128k": {"prompt": 0.00011, "completion": 0.00014},
"doubao-pro": {"prompt": 0.00011, "completion": 0.00029},
"doubao-pro-128k": {"prompt": 0.00071, "completion": 0.0013},
"doubao-pro-256k": {"prompt": 0.00071, "completion": 0.0013},
}
# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
TOKEN_MAX = {
"gpt-4o-2024-05-13": 128000,
"o1-preview": 128000,
"o1-preview-2024-09-12": 128000,
"o1-mini": 128000,
"o1-mini-2024-09-12": 128000,
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4o-2024-08-06": 128000,
"gpt-4o-mini-2024-07-18": 128000,
"gpt-4o-mini": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
@ -222,7 +252,9 @@ TOKEN_MAX = {
"text-embedding-ada-002": 8192,
"glm-3-turbo": 128000,
"glm-4": 128000,
"gemini-pro": 32768,
"gemini-1.5-flash": 1000000,
"gemini-1.5-pro": 2000000,
"gemini-1.0-pro": 32000,
"moonshot-v1-8k": 8192,
"moonshot-v1-32k": 32768,
"moonshot-v1-128k": 128000,
@ -246,6 +278,11 @@ TOKEN_MAX = {
"llama3-70b-8192": 8192,
"openai/gpt-3.5-turbo-0125": 16385,
"openai/gpt-4-turbo-preview": 128000,
"openai/o1-preview": 128000,
"openai/o1-mini": 128000,
"anthropic/claude-3-opus": 200000,
"anthropic/claude-3.5-sonnet": 200000,
"google/gemini-pro-1.5": 4000000,
"deepseek-chat": 32768,
"deepseek-coder": 16385,
"doubao-lite-4k-240515": 4000,
@ -255,6 +292,13 @@ TOKEN_MAX = {
"doubao-pro-32k-240515": 32000,
"doubao-pro-128k-240515": 128000,
# Qwen https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-72b-api-detailes?spm=a2c4g.11186623.0.i20
"qwen2.5-72b-instruct": 131072,
"qwen2.5-32b-instruct": 131072,
"qwen2.5-14b-instruct": 131072,
"qwen2.5-7b-instruct": 131072,
"qwen2.5-3b-instruct": 32768,
"qwen2.5-1.5b-instruct": 32768,
"qwen2.5-0.5b-instruct": 32768,
"qwen2-57b-a14b-instruct": 32768,
"qwen2-72b-instruct": 131072,
"qwen2-7b-instruct": 32768,
@ -354,13 +398,19 @@ def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4-turbo",
"gpt-4-vision-preview",
"gpt-4-1106-vision-preview",
"gpt-4o",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"gpt-4o-mini",
"claude-3-5-sonnet-20240620"
"gpt-4o-mini-2024-07-18",
"o1-preview",
"o1-preview-2024-09-12",
"o1-mini",
"o1-mini-2024-09-12",
}:
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
tokens_per_name = 1