Mirror of https://github.com/FoundationAgents/MetaGPT.git — synced 2026-04-29 19:06:23 +02:00

Commit d99054ab5e: Merge branch 'main' into main
98 changed files with 1697 additions and 496 deletions
metagpt/actions/action_node.py
@@ -243,12 +243,19 @@ class ActionNode:
         """Dynamically generated pydantic v2 model, used to validate the type correctness of results"""

         def check_fields(cls, values):
-            required_fields = set(mapping.keys())
+            all_fields = set(mapping.keys())
+            required_fields = set()
+            for k, v in mapping.items():
+                type_v, field_info = v
+                if ActionNode.is_optional_type(type_v):
+                    continue
+                required_fields.add(k)

             missing_fields = required_fields - set(values.keys())
             if missing_fields:
                 raise ValueError(f"Missing fields: {missing_fields}")

-            unrecognized_fields = set(values.keys()) - required_fields
+            unrecognized_fields = set(values.keys()) - all_fields
             if unrecognized_fields:
                 logger.warning(f"Unrecognized fields: {unrecognized_fields}")
             return values
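The effect of the change: fields typed `Optional[...]` are dropped from `required_fields` but still count as recognized. A minimal standalone sketch of the same semantics (the model and field names below are hypothetical, not from the commit; pydantic v2 assumed):

# Hypothetical sketch of the new required/optional split (not the committed code).
from typing import Optional
from pydantic import BaseModel, ValidationError, model_validator


class Demo(BaseModel):
    title: str
    subtitle: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_fields(cls, values):
        required_fields = {"title"}  # only non-Optional keys are required
        missing_fields = required_fields - set(values.keys())
        if missing_fields:
            raise ValueError(f"Missing fields: {missing_fields}")
        return values


Demo(title="ok")  # passes; "subtitle" may be omitted because it is Optional
try:
    Demo()
except ValidationError as e:
    print(e)  # reports: Missing fields: {'title'}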
@@ -850,3 +857,12 @@ class ActionNode:
         root_node.add_child(child_node)

         return root_node
+
+    @staticmethod
+    def is_optional_type(tp) -> bool:
+        """Return True if `tp` is `typing.Optional[...]`"""
+        if typing.get_origin(tp) is Union:
+            args = typing.get_args(tp)
+            non_none_types = [arg for arg in args if arg is not type(None)]
+            return len(non_none_types) == 1 and len(args) == 2
+        return False
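For reference, the standard-library behavior this helper relies on (runnable as-is):

# Optional[X] is Union[X, None]; get_origin/get_args expose that structure.
import typing
from typing import Optional, Union

print(typing.get_origin(Optional[str]) is Union)  # True
print(typing.get_args(Optional[str]))             # (<class 'str'>, <class 'NoneType'>)
print(typing.get_origin(str))                     # None, so is_optional_type(str) is False
print(typing.get_args(Union[int, str, None]))     # three args, so not treated as Optional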
metagpt/actions/design_api_an.py
@@ -5,7 +5,7 @@
 @Author  : alexanderwu
 @File    : design_api_an.py
 """
-from typing import List
+from typing import List, Optional

 from metagpt.actions.action_node import ActionNode
 from metagpt.utils.mermaid import MMC1, MMC2

@@ -45,9 +45,10 @@ REFINED_FILE_LIST = ActionNode(
     example=["main.py", "game.py", "new_feature.py"],
 )

+# Optional, because class diagrams reproduce with low success in non-Python projects.
 DATA_STRUCTURES_AND_INTERFACES = ActionNode(
     key="Data structures and interfaces",
-    expected_type=str,
+    expected_type=Optional[str],
     instruction="Use mermaid classDiagram code syntax, including classes, method(__init__ etc.) and functions with type"
     " annotations, CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. "
     "The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.",

@@ -66,7 +67,7 @@ REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode(

 PROGRAM_CALL_FLOW = ActionNode(
     key="Program call flow",
-    expected_type=str,
+    expected_type=Optional[str],
     instruction="Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE "
     "accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.",
     example=MMC2,
metagpt/actions/project_management_an.py
@@ -5,14 +5,14 @@
 @Author  : alexanderwu
 @File    : project_management_an.py
 """
-from typing import List
+from typing import List, Optional

 from metagpt.actions.action_node import ActionNode

 REQUIRED_PACKAGES = ActionNode(
     key="Required packages",
-    expected_type=List[str],
-    instruction="Provide required packages in requirements.txt format.",
+    expected_type=Optional[List[str]],
+    instruction="Provide required third-party packages in requirements.txt format.",
     example=["flask==1.1.2", "bcrypt==3.2.0"],
 )
metagpt/actions/research.py
@@ -161,6 +161,8 @@ class CollectLinks(Action):
         """
        max_results = max(num_results * 2, 6)
         results = await self.search_engine.run(query, max_results=max_results, as_string=False)
+        if len(results) == 0:
+            return []
         _results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
         prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
         logger.debug(prompt)
@@ -139,7 +139,7 @@ Language: Please use the same language as the user requirement, but the title an
 end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}

 ## Tasks
-{"Required packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
+{"Required packages": ["无需第三方包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}

 ## Code Files
 ----- index.html
metagpt/config2.py
@@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator

 from metagpt.configs.browser_config import BrowserConfig
 from metagpt.configs.embedding_config import EmbeddingConfig
+from metagpt.configs.file_parser_config import OmniParseConfig
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.configs.mermaid_config import MermaidConfig
 from metagpt.configs.redis_config import RedisConfig

@@ -51,6 +52,9 @@ class Config(CLIParams, YamlModel):
     # RAG Embedding
     embedding: EmbeddingConfig = EmbeddingConfig()

+    # omniparse
+    omniparse: OmniParseConfig = OmniParseConfig()
+
     # Global Proxy. Will be used if llm.proxy is not set
     proxy: str = ""

@@ -69,6 +73,7 @@ class Config(CLIParams, YamlModel):
     workspace: WorkspaceConfig = WorkspaceConfig()
     enable_longterm_memory: bool = False
     code_review_k_times: int = 2
+    agentops_api_key: str = ""

     # Will be removed in the future
     metagpt_tti_url: str = ""
metagpt/configs/file_parser_config.py (new file)
@@ -0,0 +1,6 @@
from metagpt.utils.yaml_model import YamlModel


class OmniParseConfig(YamlModel):
    api_key: str = ""
    base_url: str = ""
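Since OmniParseConfig is a YamlModel, it can be populated from config2.yaml. A plausible snippet (the URL is a placeholder, not from the commit):

omniparse:
  api_key: ""
  base_url: "http://localhost:8000"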
metagpt/configs/llm_config.py
@@ -33,7 +33,7 @@ class LLMType(Enum):
     YI = "yi"  # lingyiwanwu
     OPENROUTER = "openrouter"
     BEDROCK = "bedrock"
-    ARK = "ark"
+    ARK = "ark"  # https://www.volcengine.com/docs/82379/1263482#python-sdk

     def __missing__(self, key):
         return self.OPENAI

@@ -90,6 +90,9 @@ class LLMConfig(YamlModel):
     # Cost Control
     calc_usage: bool = True

+    # For Messages Control
+    use_system_prompt: bool = True
+
     @field_validator("api_key")
     @classmethod
     def check_llm_key(cls, v):
metagpt/const.py
@@ -1,14 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 """
 @Time    : 2023/5/1 11:59
 @Author  : alexanderwu
 @File    : const.py
-@Modified By: mashenquan, 2023-11-1. According to Section 2.2.1 and 2.2.2 of RFC 116, added key definitions for
-        common properties in the Message.
-@Modified By: mashenquan, 2023-11-27. Defines file repository paths according to Section 2.2.3.4 of RFC 135.
-@Modified By: mashenquan, 2023/12/5. Add directories for code summarization.
 """

 import os
 from pathlib import Path
metagpt/document_store/milvus_store.py (new file)
@@ -0,0 +1,99 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from metagpt.document_store.base_store import BaseStore


@dataclass
class MilvusConnection:
    """
    Args:
        uri: milvus url
        token: milvus token
    """

    uri: str = None
    token: str = None


class MilvusStore(BaseStore):
    def __init__(self, connect: MilvusConnection):
        try:
            from pymilvus import MilvusClient
        except ImportError:
            raise Exception("Please install pymilvus first.")
        if not connect.uri:
            raise Exception("Please check MilvusConnection: uri must be set.")
        self.client = MilvusClient(uri=connect.uri, token=connect.token)

    def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
        from pymilvus import DataType

        # Recreate the collection from scratch if it already exists.
        if self.client.has_collection(collection_name=collection_name):
            self.client.drop_collection(collection_name=collection_name)

        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=False,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)

        index_params = self.client.prepare_index_params()
        index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            index_params=index_params,
            enable_dynamic_schema=enable_dynamic_schema,
        )

    @staticmethod
    def build_filter(key, value) -> str:
        if isinstance(value, str):
            filter_expression = f'{key} == "{value}"'
        elif isinstance(value, list):
            filter_expression = f"{key} in {value}"
        else:
            filter_expression = f"{key} == {value}"

        return filter_expression

    def search(
        self,
        collection_name: str,
        query: List[float],
        filter: Dict = None,
        limit: int = 10,
        output_fields: Optional[List[str]] = None,
    ) -> List[dict]:
        # Guard against filter=None so .items() cannot fail.
        filter_expression = " and ".join(self.build_filter(key, value) for key, value in (filter or {}).items())

        res = self.client.search(
            collection_name=collection_name,
            data=[query],
            filter=filter_expression,
            limit=limit,
            output_fields=output_fields,
        )[0]

        return res

    def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
        # Build one row per id so every vector is upserted.
        data = [{"id": _id, "vector": vector[i], "metadata": metadata[i]} for i, _id in enumerate(_ids)]

        self.client.upsert(collection_name=collection_name, data=data)

    def delete(self, collection_name: str, _ids: List[str]):
        self.client.delete(collection_name=collection_name, ids=_ids)

    def write(self, *args, **kwargs):
        pass
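A brief usage sketch of the new store (the collection name, vectors, and local Milvus Lite uri are hypothetical; assumes a pymilvus version with Milvus Lite support):

# Hypothetical usage of MilvusStore (not part of the commit).
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore

store = MilvusStore(MilvusConnection(uri="./milvus_local.db"))
store.create_collection("demo", dim=4)
store.add(
    "demo",
    _ids=["a", "b"],
    vector=[[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
    metadata=[{"tag": "x"}, {"tag": "y"}],
)
hits = store.search("demo", query=[0.1, 0.2, 0.3, 0.4], limit=2)
print(hits)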
@@ -266,7 +266,7 @@ class STRole(Role):
         # We will order our percept based on the distance, with the closest ones
         # getting priorities.
         percept_events_list = []
-        # First, we put all events that are occuring in the nearby tiles into the
+        # First, we put all events that are occurring in the nearby tiles into the
         # percept_events_list
         for tile in nearby_tiles:
             tile_details = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile))
metagpt/memory/memory.py
@@ -81,7 +81,7 @@ class Memory(BaseModel):
         return self.storage[-k:]

     def find_news(self, observed: list[Message], k=0) -> list[Message]:
-        """find news (previously unseen messages) from the the most recent k memories, from all memories when k=0"""
+        """find news (previously unseen messages) from the most recent k memories, from all memories when k=0"""
         already_observed = self.get(k)
         news: list[Message] = []
         for i in observed:
metagpt/provider/ark_api.py
@@ -1,12 +1,33 @@
-from openai import AsyncStream
-from openai.types import CompletionUsage
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Provider for volcengine.
+See Also: https://console.volcengine.com/ark/region:ark+cn-beijing/model
+
+config2.yaml example:
+```yaml
+llm:
+  base_url: "https://ark.cn-beijing.volces.com/api/v3"
+  api_type: "ark"
+  endpoint: "ep-2024080514****-d****"
+  api_key: "d47****b-****-****-****-d6e****0fd77"
+  pricing_plan: "doubao-lite"
+```
+"""
+from typing import Optional, Union
+
+from pydantic import BaseModel
+from volcenginesdkarkruntime import AsyncArk
+from volcenginesdkarkruntime._base_client import AsyncHttpxClientWrapper
+from volcenginesdkarkruntime._streaming import AsyncStream
+from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk

 from metagpt.configs.llm_config import LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.provider.openai_api import OpenAILLM
 from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS


 @register_provider(LLMType.ARK)
@@ -16,11 +37,45 @@ class ArkLLM(OpenAILLM):
     See: https://www.volcengine.com/docs/82379/1263482
     """

+    aclient: Optional[AsyncArk] = None
+
+    def _init_client(self):
+        """SDK: https://github.com/openai/openai-python#async-usage"""
+        self.model = (
+            self.config.endpoint or self.config.model
+        )  # endpoint name, see https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
+        self.pricing_plan = self.config.pricing_plan or self.model
+        kwargs = self._make_client_kwargs()
+        self.aclient = AsyncArk(**kwargs)
+
+    def _make_client_kwargs(self) -> dict:
+        kvs = {
+            "ak": self.config.access_key,
+            "sk": self.config.secret_key,
+            "api_key": self.config.api_key,
+            "base_url": self.config.base_url,
+        }
+        kwargs = {k: v for k, v in kvs.items() if v}
+
+        # to use a proxy, openai v1 needs an http_client
+        if proxy_params := self._get_proxy_params():
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
+
+        return kwargs
+
+    def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
+        if next(iter(DOUBAO_TOKEN_COSTS)) not in self.cost_manager.token_costs:
+            self.cost_manager.token_costs.update(DOUBAO_TOKEN_COSTS)
+        if model in self.cost_manager.token_costs:
+            self.pricing_plan = model
+        if self.pricing_plan in self.cost_manager.token_costs:
+            super()._update_costs(usage, self.pricing_plan, local_calc_usage)
+
     async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
             stream=True,
-            extra_body={"stream_options": {"include_usage": True}}  # usage is only returned in the final streamed chunk if this parameter is set
+            extra_body={"stream_options": {"include_usage": True}},  # usage is only returned in the final streamed chunk if this parameter is set
         )
         usage = None
         collected_messages = []

@@ -30,7 +85,7 @@ class ArkLLM(OpenAILLM):
             collected_messages.append(chunk_message)
             if chunk.usage:
                 # Volcengine Ark streaming returns usage in the last chunk, whose choices list is empty
-                usage = CompletionUsage(**chunk.usage)
+                usage = chunk.usage

         log_llm_stream("\n")
         full_reply_content = "".join(collected_messages)
@@ -27,6 +27,7 @@ SUPPORT_STREAM_MODELS = {
     "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
     "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
     "anthropic.claude-3-haiku-20240307-v1:0": 200000,
+    "anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
     "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
     "anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
     # currently (2024-4-29) only available at US West (Oregon) AWS Region.
metagpt/provider/bedrock_api.py
@@ -1,5 +1,7 @@
+import asyncio
 import json
-from typing import Literal
+from functools import partial
+from typing import List, Literal

 import boto3
 from botocore.eventstream import EventStream

@@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM):
         self.__client = self.__init_client("bedrock-runtime")
         self.__provider = get_provider(self.config.model)
         self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
-        logger.warning("Amazon bedrock doesn't support asynchronous now")
         if self.config.model in NOT_SUUPORT_STREAM_MODELS:
             logger.warning(f"model {self.config.model} doesn't support streaming output!")

@@ -64,15 +65,21 @@ class BedrockLLM(BaseLLM):
         ]
         logger.info("\n" + "\n".join(summaries))

-    def invoke_model(self, request_body: str) -> dict:
-        response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
+    async def invoke_model(self, request_body: str) -> dict:
+        loop = asyncio.get_running_loop()
+        response = await loop.run_in_executor(
+            None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
+        )
         usage = self._get_usage(response)
         self._update_costs(usage, self.config.model)
         response_body = self._get_response_body(response)
         return response_body

-    def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
-        response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
+    async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
+        loop = asyncio.get_running_loop()
+        response = await loop.run_in_executor(
+            None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
+        )
         usage = self._get_usage(response)
         self._update_costs(usage, self.config.model)
         return response

@@ -97,7 +104,7 @@ class BedrockLLM(BaseLLM):

     async def acompletion(self, messages: list[dict]) -> dict:
         request_body = self.__provider.get_request_body(messages, self._const_kwargs)
-        response_body = self.invoke_model(request_body)
+        response_body = await self.invoke_model(request_body)
         return response_body

     async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:

@@ -111,14 +118,8 @@ class BedrockLLM(BaseLLM):
             return full_text

         request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)

-        response = self.invoke_model_with_response_stream(request_body)
-        collected_content = []
-        for event in response["body"]:
-            chunk_text = self.__provider.get_choice_text_from_stream(event)
-            collected_content.append(chunk_text)
-            log_llm_stream(chunk_text)
-
+        stream_response = await self.invoke_model_with_response_stream(request_body)
+        collected_content = await self._get_stream_response_body(stream_response)
         log_llm_stream("\n")
         full_text = ("".join(collected_content)).lstrip()
         return full_text

@@ -127,6 +128,18 @@ class BedrockLLM(BaseLLM):
         response_body = json.loads(response["body"].read())
         return response_body

+    async def _get_stream_response_body(self, stream_response) -> List[str]:
+        def collect_content() -> List[str]:
+            collected_content = []
+            for event in stream_response["body"]:
+                chunk_text = self.__provider.get_choice_text_from_stream(event)
+                collected_content.append(chunk_text)
+                log_llm_stream(chunk_text)
+            return collected_content
+
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, collect_content)
+
     def _get_usage(self, response) -> dict[str, int]:
         headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
         prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
metagpt/provider/dashscope_api.py
@@ -48,13 +48,17 @@ def build_api_arequest(
         request_timeout,
         form,
         resources,
+        base_address,
         _,
     ) = _get_protocol_params(kwargs)
     task_id = kwargs.pop("task_id", None)
     if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
-        if not dashscope.base_http_api_url.endswith("/"):
-            http_url = dashscope.base_http_api_url + "/"
+        if base_address is None:
+            base_address = dashscope.base_http_api_url
+        if not base_address.endswith("/"):
+            http_url = base_address + "/"
         else:
-            http_url = dashscope.base_http_api_url
+            http_url = base_address

         if is_service:
             http_url = http_url + SERVICE_API_PATH + "/"
metagpt/provider/general_api_requestor.py
@@ -81,7 +81,9 @@ class GeneralAPIRequestor(APIRequestor):
         self, result: aiohttp.ClientResponse, stream: bool
     ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
         content_type = result.headers.get("Content-Type", "")
-        if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
+        if stream and (
+            "text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
+        ):
             # the `Content-Type` of an ollama stream response is "application/x-ndjson"
             return (
                 self._interpret_response_line(line, result.status, result.headers, stream=True)
metagpt/provider/llm_provider_registry.py
@@ -37,7 +37,11 @@ def register_provider(keys):

 def create_llm_instance(config: LLMConfig) -> BaseLLM:
     """get the default llm provider"""
-    return LLM_REGISTRY.get_provider(config.api_type)(config)
+    llm = LLM_REGISTRY.get_provider(config.api_type)(config)
+    if llm.use_system_prompt and not config.use_system_prompt:
+        # for models like the o1 series, the default openai provider sets use_system_prompt=True, but it must be False for o1-*
+        llm.use_system_prompt = config.use_system_prompt
+    return llm


 # Registry instance
metagpt/provider/ollama_api.py
@@ -51,9 +51,17 @@ class OllamaLLM(BaseLLM):
         return json.loads(chunk)

     async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
+        headers = (
+            None
+            if not self.config.api_key or self.config.api_key == "sk-"
+            else {
+                "Authorization": f"Bearer {self.config.api_key}",
+            }
+        )
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.suffix_url,
+            headers=headers,
             params=self._const_kwargs(messages),
             request_timeout=self.get_timeout(timeout),
         )

@@ -66,9 +74,17 @@ class OllamaLLM(BaseLLM):
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

     async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+        headers = (
+            None
+            if not self.config.api_key or self.config.api_key == "sk-"
+            else {
+                "Authorization": f"Bearer {self.config.api_key}",
+            }
+        )
         stream_resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.suffix_url,
+            headers=headers,
             stream=True,
             params=self._const_kwargs(messages, stream=True),
             request_timeout=self.get_timeout(timeout),
metagpt/provider/openai_api.py
@@ -37,7 +37,6 @@ from metagpt.utils.token_counter import (
     count_input_tokens,
     count_output_tokens,
     get_max_completion_tokens,
-    get_openrouter_tokens,
 )


@@ -92,6 +91,7 @@ class OpenAILLM(BaseLLM):
         )
         usage = None
         collected_messages = []
+        has_finished = False
         async for chunk in response:
             chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
             finish_reason = (

@@ -99,8 +99,13 @@ class OpenAILLM(BaseLLM):
             )
             log_llm_stream(chunk_message)
             collected_messages.append(chunk_message)
+            chunk_has_usage = hasattr(chunk, "usage") and chunk.usage
+            if has_finished:
+                # for oneapi, a usage chunk arrives after the chunk whose finish_reason is not None
+                if chunk_has_usage:
+                    usage = CompletionUsage(**chunk.usage)
             if finish_reason:
-                if hasattr(chunk, "usage") and chunk.usage is not None:
+                if chunk_has_usage:
                     # Some services have usage as an attribute of the chunk, such as Fireworks
                     if isinstance(chunk.usage, CompletionUsage):
                         usage = chunk.usage

@@ -109,9 +114,10 @@ class OpenAILLM(BaseLLM):
                 elif hasattr(chunk.choices[0], "usage"):
                     # The usage of some services is an attribute of chunk.choices[0], such as Moonshot
                     usage = CompletionUsage(**chunk.choices[0].usage)
-                elif "openrouter.ai" in self.config.base_url:
+                elif "openrouter.ai" in self.config.base_url and chunk_has_usage:
                     # because openrouter reports token cost through the api
-                    usage = await get_openrouter_tokens(chunk)
+                    usage = chunk.usage
+                has_finished = True

         log_llm_stream("\n")
         full_reply_content = "".join(collected_messages)

@@ -132,6 +138,10 @@ class OpenAILLM(BaseLLM):
             "model": self.model,
             "timeout": self.get_timeout(timeout),
         }
+        if "o1-" in self.model:
+            # compatibility with the openai o1 series
+            kwargs["temperature"] = 1
+            kwargs.pop("max_tokens")
         if extra_kwargs:
             kwargs.update(extra_kwargs)
         return kwargs
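A sketch of what the o1 branch does to the request kwargs (the values below are illustrative, not from the commit):

# Illustrative only: the o1 branch rewrites request kwargs like this.
kwargs = {
    "messages": [{"role": "user", "content": "hi"}],
    "max_tokens": 4096,
    "temperature": 0.3,
    "model": "o1-mini",
    "timeout": 600,
}
if "o1-" in kwargs["model"]:
    kwargs["temperature"] = 1  # o1 models accept only the default temperature
    kwargs.pop("max_tokens")   # o1 models reject max_tokens
print(kwargs)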
metagpt/provider/qianfan_api.py
@@ -50,6 +50,9 @@ class QianFanLLM(BaseLLM):
         else:
             raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")

+        if self.config.base_url:
+            os.environ.setdefault("QIANFAN_BASE_URL", self.config.base_url)
+
         support_system_pairs = [
             ("ERNIE-Bot-4", "completions_pro"),  # (model, corresponding-endpoint)
             ("ERNIE-Bot-8k", "ernie_bot_8k"),

@@ -103,13 +106,13 @@ class QianFanLLM(BaseLLM):
     def get_choice_text(self, resp: JsonBody) -> str:
         return resp.get("result", "")

-    def completion(self, messages: list[dict]) -> JsonBody:
-        resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
+    def completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
+        resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False), request_timeout=timeout)
         self._update_costs(resp.body.get("usage", {}))
         return resp.body

     async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
-        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
+        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False), request_timeout=timeout)
         self._update_costs(resp.body.get("usage", {}))
         return resp.body

@@ -117,7 +120,7 @@ class QianFanLLM(BaseLLM):
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

     async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
-        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
+        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True), request_timeout=timeout)
         collected_content = []
         usage = {}
         async for chunk in resp:
metagpt/rag/engines/simple.py
@@ -14,6 +14,7 @@ from llama_index.core.llms import LLM
 from llama_index.core.node_parser import SentenceSplitter
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.readers.base import BaseReader
 from llama_index.core.response_synthesizers import (
     BaseSynthesizer,
     get_response_synthesizer,

@@ -28,6 +29,7 @@ from llama_index.core.schema import (
     TransformComponent,
 )

+from metagpt.config2 import config
 from metagpt.rag.factories import (
     get_index,
     get_rag_embedding,

@@ -36,6 +38,7 @@ from metagpt.rag.factories import (
     get_retriever,
 )
 from metagpt.rag.interface import NoEmbedding, RAGObject
+from metagpt.rag.parsers import OmniParse
 from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
 from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
 from metagpt.rag.schema import (

@@ -44,6 +47,9 @@ from metagpt.rag.schema import (
     BaseRetrieverConfig,
     BM25RetrieverConfig,
     ObjectNode,
+    OmniParseOptions,
+    OmniParseType,
+    ParseResultType,
 )
 from metagpt.utils.common import import_class

@@ -100,7 +106,10 @@ class SimpleEngine(RetrieverQueryEngine):
         if not input_dir and not input_files:
             raise ValueError("Must provide either `input_dir` or `input_files`.")

-        documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
+        file_extractor = cls._get_file_extractor()
+        documents = SimpleDirectoryReader(
+            input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
+        ).load_data()
         cls._fix_document_metadata(documents)

         transformations = transformations or cls._default_transformations()

@@ -301,3 +310,23 @@ class SimpleEngine(RetrieverQueryEngine):
     @staticmethod
     def _default_transformations():
         return [SentenceSplitter()]
+
+    @staticmethod
+    def _get_file_extractor() -> dict[str, BaseReader]:
+        """
+        Get the file extractor.
+        Currently, only PDF uses OmniParse; other document types use the built-in readers from llama_index.
+
+        Returns:
+            dict mapping file type to BaseReader
+        """
+        file_extractor: dict[str, BaseReader] = {}
+        if config.omniparse.base_url:
+            pdf_parser = OmniParse(
+                api_key=config.omniparse.api_key,
+                base_url=config.omniparse.base_url,
+                parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
+            )
+            file_extractor[".pdf"] = pdf_parser
+
+        return file_extractor
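A plausible usage sketch: once config2.yaml has an omniparse.base_url, PDFs passed to SimpleEngine.from_docs go through OmniParse while other file types keep the llama_index readers (the file path and question below are hypothetical):

# Hypothetical usage (not part of the commit).
from metagpt.rag.engines import SimpleEngine

engine = SimpleEngine.from_docs(input_files=["./docs/whitepaper.pdf"])
print(engine.query("What problem does the paper address?"))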
metagpt/rag/factories/index.py
@@ -8,6 +8,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
 from llama_index.vector_stores.chroma import ChromaVectorStore
 from llama_index.vector_stores.elasticsearch import ElasticsearchStore
 from llama_index.vector_stores.faiss import FaissVectorStore
+from llama_index.vector_stores.milvus import MilvusVectorStore

 from metagpt.rag.factories.base import ConfigBasedFactory
 from metagpt.rag.schema import (

@@ -17,6 +18,7 @@ from metagpt.rag.schema import (
     ElasticsearchIndexConfig,
     ElasticsearchKeywordIndexConfig,
     FAISSIndexConfig,
+    MilvusIndexConfig,
 )


@@ -28,6 +30,7 @@ class RAGIndexFactory(ConfigBasedFactory):
             BM25IndexConfig: self._create_bm25,
             ElasticsearchIndexConfig: self._create_es,
             ElasticsearchKeywordIndexConfig: self._create_es,
+            MilvusIndexConfig: self._create_milvus,
         }
         super().__init__(creators)

@@ -46,6 +49,11 @@ class RAGIndexFactory(ConfigBasedFactory):

         return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

+    def _create_milvus(self, config: MilvusIndexConfig, **kwargs) -> VectorStoreIndex:
+        vector_store = MilvusVectorStore(collection_name=config.collection_name, uri=config.uri, token=config.token)
+
+        return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
+
     def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
         db = chromadb.PersistentClient(str(config.persist_path))
         chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
metagpt/rag/factories/retriever.py
@@ -12,6 +12,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
 from llama_index.vector_stores.chroma import ChromaVectorStore
 from llama_index.vector_stores.elasticsearch import ElasticsearchStore
 from llama_index.vector_stores.faiss import FaissVectorStore
+from llama_index.vector_stores.milvus import MilvusVectorStore

 from metagpt.rag.factories.base import ConfigBasedFactory
 from metagpt.rag.retrievers.base import RAGRetriever

@@ -20,6 +21,7 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
 from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
 from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
 from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
+from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
 from metagpt.rag.schema import (
     BaseRetrieverConfig,
     BM25RetrieverConfig,

@@ -27,6 +29,7 @@ from metagpt.rag.schema import (
     ElasticsearchKeywordRetrieverConfig,
     ElasticsearchRetrieverConfig,
     FAISSRetrieverConfig,
+    MilvusRetrieverConfig,
 )


@@ -56,6 +59,7 @@ class RetrieverFactory(ConfigBasedFactory):
             ChromaRetrieverConfig: self._create_chroma_retriever,
             ElasticsearchRetrieverConfig: self._create_es_retriever,
             ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
+            MilvusRetrieverConfig: self._create_milvus_retriever,
         }
         super().__init__(creators)

@@ -76,6 +80,11 @@ class RetrieverFactory(ConfigBasedFactory):

         return index.as_retriever()

+    def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever:
+        config.index = self._build_milvus_index(config, **kwargs)
+
+        return MilvusRetriever(**config.model_dump())
+
     def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
         config.index = self._build_faiss_index(config, **kwargs)

@@ -128,6 +137,12 @@ class RetrieverFactory(ConfigBasedFactory):

         return self._build_index_from_vector_store(config, vector_store, **kwargs)

+    @get_or_build_index
+    def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
+        vector_store = MilvusVectorStore(
+            uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions
+        )
+
+        return self._build_index_from_vector_store(config, vector_store, **kwargs)
+
     @get_or_build_index
     def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
         vector_store = ElasticsearchStore(**config.store_config.model_dump())
metagpt/rag/parsers/__init__.py (new file)
@@ -0,0 +1,3 @@
from metagpt.rag.parsers.omniparse import OmniParse

__all__ = ["OmniParse"]
metagpt/rag/parsers/omniparse.py (new file)
@@ -0,0 +1,139 @@
import asyncio
from fileinput import FileInput
from pathlib import Path
from typing import List, Optional, Union

from llama_index.core import Document
from llama_index.core.async_utils import run_jobs
from llama_index.core.readers.base import BaseReader

from metagpt.logs import logger
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.omniparse_client import OmniParseClient


class OmniParse(BaseReader):
    """OmniParse"""

    def __init__(
        self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
    ):
        """
        Args:
            api_key: Default None, can be used for authentication later.
            base_url: OmniParse base URL for the API.
            parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
        """
        self.parse_options = parse_options or OmniParseOptions()
        self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)

    @property
    def parse_type(self):
        return self.parse_options.parse_type

    @property
    def result_type(self):
        return self.parse_options.result_type

    @parse_type.setter
    def parse_type(self, parse_type: Union[str, OmniParseType]):
        if isinstance(parse_type, str):
            parse_type = OmniParseType(parse_type)
        self.parse_options.parse_type = parse_type

    @result_type.setter
    def result_type(self, result_type: Union[str, ParseResultType]):
        if isinstance(result_type, str):
            result_type = ParseResultType(result_type)
        self.parse_options.result_type = result_type

    async def _aload_data(
        self,
        file_path: Union[str, bytes, Path],
        extra_info: Optional[dict] = None,
    ) -> List[Document]:
        """
        Load data from the input file_path.

        Args:
            file_path: File path or file byte data.
            extra_info: Optional dictionary containing additional information.

        Returns:
            List[Document]
        """
        try:
            if self.parse_type == OmniParseType.PDF:
                # pdf parse
                parsed_result = await self.omniparse_client.parse_pdf(file_path)
            else:
                # other types are parsed with omniparse_client.parse_document;
                # for byte data, an additional filename is required for compatibility
                extra_info = extra_info or {}
                filename = extra_info.get("filename")
                parsed_result = await self.omniparse_client.parse_document(file_path, bytes_filename=filename)

            # Get the specified structured data based on result_type
            content = getattr(parsed_result, self.result_type)
            docs = [
                Document(
                    text=content,
                    metadata=extra_info or {},
                )
            ]
        except Exception as e:
            logger.error(f"OMNI Parse Error: {e}")
            docs = []

        return docs

    async def aload_data(
        self,
        file_path: Union[List[FileInput], FileInput],
        extra_info: Optional[dict] = None,
    ) -> List[Document]:
        """
        Load data from the input file_path.

        Args:
            file_path: File path or file byte data.
            extra_info: Optional dictionary containing additional information.

        Notes:
            This method ultimately calls _aload_data for processing.

        Returns:
            List[Document]
        """
        docs = []
        if isinstance(file_path, (str, bytes, Path)):
            # Process a single file
            docs = await self._aload_data(file_path, extra_info)
        elif isinstance(file_path, list):
            # Concurrently process multiple files
            parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
            doc_ret_list = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
            docs = [doc for docs in doc_ret_list for doc in docs]
        return docs

    def load_data(
        self,
        file_path: Union[List[FileInput], FileInput],
        extra_info: Optional[dict] = None,
    ) -> List[Document]:
        """
        Load data from the input file_path.

        Args:
            file_path: File path or file byte data.
            extra_info: Optional dictionary containing additional information.

        Notes:
            This method ultimately calls aload_data for processing.

        Returns:
            List[Document]
        """
        NestAsyncio.apply_once()  # Ensure compatibility with nested async calls
        return asyncio.run(self.aload_data(file_path, extra_info))
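A brief usage sketch of the reader (the server URL and file path are hypothetical; assumes a running OmniParse server):

# Hypothetical usage of the OmniParse reader (not part of the commit).
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType

parser = OmniParse(
    base_url="http://localhost:8000",
    parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
)
docs = parser.load_data("./docs/report.pdf")  # returns llama_index Documents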
metagpt/rag/retrievers/milvus_retriever.py (new file)
@@ -0,0 +1,17 @@
"""Milvus retriever."""

from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode


class MilvusRetriever(VectorIndexRetriever):
    """Milvus retriever."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Support add nodes."""
        self._index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Support persist.

        Milvus automatically saves, so there is no need to implement."""
metagpt/rag/schema.py
@@ -1,14 +1,14 @@
 """RAG schemas."""

 from enum import Enum
 from pathlib import Path
-from typing import Any, ClassVar, Literal, Optional, Union
+from typing import Any, ClassVar, List, Literal, Optional, Union

 from chromadb.api.types import CollectionMetadata
 from llama_index.core.embeddings import BaseEmbedding
 from llama_index.core.indices.base import BaseIndex
 from llama_index.core.schema import TextNode
 from llama_index.core.vector_stores.types import VectorStoreQueryMode
-from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator, validator

 from metagpt.config2 import config
 from metagpt.configs.embedding_config import EmbeddingType

@@ -62,6 +62,36 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
     _no_embedding: bool = PrivateAttr(default=True)


+class MilvusRetrieverConfig(IndexRetrieverConfig):
+    """Config for Milvus-based retrievers."""
+
+    uri: str = Field(default="./milvus_local.db", description="The directory to save data.")
+    collection_name: str = Field(default="metagpt", description="The name of the collection.")
+    token: str = Field(default=None, description="The token for Milvus")
+    metadata: Optional[CollectionMetadata] = Field(
+        default=None, description="Optional metadata to associate with the collection"
+    )
+    dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")
+
+    _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
+        EmbeddingType.GEMINI: 768,
+        EmbeddingType.OLLAMA: 4096,
+    }
+
+    @model_validator(mode="after")
+    def check_dimensions(self):
+        if self.dimensions == 0:
+            self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
+                config.embedding.api_type, 1536
+            )
+            if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
+                logger.warning(
+                    f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
+                )
+
+        return self
+
+
 class ChromaRetrieverConfig(IndexRetrieverConfig):
     """Config for Chroma-based retrievers."""


@@ -169,6 +199,16 @@ class ChromaIndexConfig(VectorIndexConfig):
         default=None, description="Optional metadata to associate with the collection"
     )

+
+class MilvusIndexConfig(VectorIndexConfig):
+    """Config for milvus-based index."""
+
+    collection_name: str = Field(default="metagpt", description="The name of the collection.")
+    uri: str = Field(default="./milvus_local.db", description="The uri of the index.")
+    token: Optional[str] = Field(default=None, description="The token of the index.")
+    metadata: Optional[CollectionMetadata] = Field(
+        default=None, description="Optional metadata to associate with the collection"
+    )
+

 class BM25IndexConfig(BaseIndexConfig):
     """Config for bm25-based index."""

@@ -214,3 +254,51 @@ class ObjectNode(TextNode):
         )

         return metadata.model_dump()
+
+
+class OmniParseType(str, Enum):
+    """OmniParseType"""
+
+    PDF = "PDF"
+    DOCUMENT = "DOCUMENT"
+
+
+class ParseResultType(str, Enum):
+    """The result type for the parser."""
+
+    TXT = "text"
+    MD = "markdown"
+    JSON = "json"
+
+
+class OmniParseOptions(BaseModel):
+    """OmniParse options config"""
+
+    result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
+    parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")
+    max_timeout: Optional[int] = Field(default=120, description="Maximum timeout for OmniParse service requests")
+    num_workers: int = Field(
+        default=5,
+        gt=0,
+        lt=10,
+        description="Number of concurrent requests for multiple files",
+    )
+
+
+class OminParseImage(BaseModel):
+    image: str = Field(default="", description="image str bytes")
+    image_name: str = Field(default="", description="image name")
+    image_info: Optional[dict] = Field(default={}, description="image info")
+
+
+class OmniParsedResult(BaseModel):
+    markdown: str = Field(default="", description="markdown text")
+    text: str = Field(default="", description="plain text")
+    images: Optional[List[OminParseImage]] = Field(default=[], description="images")
+    metadata: Optional[dict] = Field(default={}, description="metadata")
+
+    @model_validator(mode="before")
+    def set_markdown(cls, values):
+        if not values.get("markdown"):
+            values["markdown"] = values.get("text")
+        return values
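How the retriever config resolves vector dimensions, assuming the global embedding config is loaded (a sketch, not from the commit):

# Hypothetical: MilvusRetrieverConfig() starts with dimensions=0; check_dimensions
# then falls back to config.embedding.dimensions, then to the per-EmbeddingType
# table (GEMINI=768, OLLAMA=4096), else 1536 with a warning.
from metagpt.rag.schema import MilvusRetrieverConfig

cfg = MilvusRetrieverConfig()
print(cfg.dimensions)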
metagpt/roles/architect.py
@@ -6,6 +6,7 @@
 @File    : architect.py
 """

+
 from metagpt.actions import WritePRD
 from metagpt.actions.design_api import WriteDesign
 from metagpt.roles.role import Role
metagpt/roles/invoice_ocr_assistant.py
@@ -80,19 +80,17 @@ class InvoiceOCRAssistant(Role):
                 raise Exception("Invoice file not uploaded")

             resp = await todo.run(file_path)
+            actions = list(self.actions)
             if len(resp) == 1:
                 # Single file support for questioning based on OCR recognition results
-                self.set_actions([GenerateTable, ReplyQuestion])
+                actions.extend([GenerateTable, ReplyQuestion])
                 self.orc_data = resp[0]
             else:
-                self.set_actions([GenerateTable])
-
-            self.set_todo(None)
+                actions.append(GenerateTable)
+            self.set_actions(actions)
+            self.rc.max_react_loop = len(self.actions)
             content = INVOICE_OCR_SUCCESS
             resp = OCRResults(ocr_result=json.dumps(resp))
             msg = Message(content=content, instruct_content=resp)
             self.rc.memory.add(msg)
             return await super().react()
         elif isinstance(todo, GenerateTable):
             ocr_results: OCRResults = msg.instruct_content
             resp = await todo.run(json.loads(ocr_results.ocr_result), self.filename)
metagpt/roles/product_manager.py
@@ -7,6 +7,7 @@
 @Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
 """

+
 from metagpt.actions import UserRequirement, WritePRD
 from metagpt.actions.prepare_documents import PrepareDocuments
 from metagpt.roles.role import Role, RoleReactMode

metagpt/roles/project_manager.py
@@ -6,6 +6,7 @@
 @File    : project_manager.py
 """

+
 from metagpt.actions import WriteTasks
 from metagpt.actions.design_api import WriteDesign
 from metagpt.roles.role import Role
metagpt/roles/qa_engineer.py
@@ -15,6 +15,7 @@
     of SummarizeCode.
 """

+
 from metagpt.actions import DebugError, RunCode, WriteTest
 from metagpt.actions.summarize_code import SummarizeCode
 from metagpt.const import MESSAGE_ROUTE_TO_NONE
metagpt/roles/researcher.py
@@ -58,7 +58,9 @@ class Researcher(Role):
             )
         elif isinstance(todo, WebBrowseAndSummarize):
             links = instruct_content.links
-            todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items())
+            todos = (
+                todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items() if url
+            )
             if self.enable_concurrency:
                 summaries = await asyncio.gather(*todos)
             else:
metagpt/roles/role.py
@@ -170,7 +170,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
         self._check_actions()
         self.llm.system_prompt = self._get_prefix()
         self.llm.cost_manager = self.context.cost_manager
-        self._watch(kwargs.pop("watch", [UserRequirement]))
+        if not self.rc.watch:
+            self._watch(kwargs.pop("watch", [UserRequirement]))

         if self.latest_observed_msg:
             self.recovered = True

@@ -421,8 +422,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
         """Prepare new messages for processing from the message buffer and other sources."""
         # Read unprocessed messages from the msg buffer.
         news = []
-        if self.recovered:
-            news = [self.latest_observed_msg] if self.latest_observed_msg else []
+        if self.recovered and self.latest_observed_msg:
+            news = self.rc.memory.find_news(observed=[self.latest_observed_msg], k=10)
         if not news:
             news = self.rc.msg_buffer.pop_all()
         # Store the read messages in your own memory to prevent duplicate processing.
metagpt/roles/sk_agent.py (file deleted)
@@ -1,87 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/9/13 12:23
@Author  : femto Zheng
@File    : sk_agent.py
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
        distribution feature for message filtering.
"""
from typing import Any, Callable, Union

from pydantic import Field
from semantic_kernel import Kernel
from semantic_kernel.planning import SequentialPlanner
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from semantic_kernel.planning.basic_planner import BasicPlanner, Plan

from metagpt.actions import UserRequirement
from metagpt.actions.execute_task import ExecuteTask
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.make_sk_kernel import make_sk_kernel


class SkAgent(Role):
    """
    Represents an SkAgent implemented using semantic kernel

    Attributes:
        name (str): Name of the SkAgent.
        profile (str): Role profile, default is 'sk_agent'.
        goal (str): Goal of the SkAgent.
        constraints (str): Constraints for the SkAgent.
    """

    name: str = "Sunshine"
    profile: str = "sk_agent"
    goal: str = "Execute task based on passed in task description"
    constraints: str = ""

    plan: Plan = Field(default=None, exclude=True)
    planner_cls: Any = None
    planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
    kernel: Kernel = Field(default_factory=Kernel)
    import_semantic_skill_from_directory: Callable = Field(default=None, exclude=True)
    import_skill: Callable = Field(default=None, exclude=True)

    def __init__(self, **data: Any) -> None:
        """Initializes the SkAgent role with given attributes."""
        super().__init__(**data)
        self.set_actions([ExecuteTask()])
        self._watch([UserRequirement])
        self.kernel = make_sk_kernel()

        # the planner interfaces are inconsistent
        if self.planner_cls == BasicPlanner or self.planner_cls is None:
            self.planner = BasicPlanner()
        elif self.planner_cls in [SequentialPlanner, ActionPlanner]:
            self.planner = self.planner_cls(self.kernel)
        else:
            raise Exception(f"Unsupported planner of type {self.planner_cls}")

        self.import_semantic_skill_from_directory = self.kernel.import_semantic_skill_from_directory
        self.import_skill = self.kernel.import_skill

    async def _think(self) -> None:
        self._set_state(0)
        # the planner interfaces are inconsistent
        if isinstance(self.planner, BasicPlanner):
            self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content, self.kernel)
            logger.info(self.plan.generated_plan)
        elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
            self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content)

    async def _act(self) -> Message:
        # the planner interfaces are inconsistent
        result = None
        if isinstance(self.planner, BasicPlanner):
            result = await self.planner.execute_plan_async(self.plan, self.kernel)
        elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
            result = (await self.plan.invoke_async()).result
        logger.info(result)

        msg = Message(content=result, role=self.profile, cause_by=self.rc.todo)
        self.rc.memory.add(msg)
        return msg
metagpt/software_company.py
@@ -4,6 +4,7 @@
 import asyncio
 from pathlib import Path

+import agentops
 import typer

 from metagpt.const import CONFIG_ROOT

@@ -38,6 +39,9 @@ def generate_repo(
     )
     from metagpt.team import Team

+    if config.agentops_api_key != "":
+        agentops.init(config.agentops_api_key, tags=["software_company"])
+
     config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
     ctx = Context(config=config)

@@ -68,6 +72,9 @@ def generate_repo(
     company.run_project(idea)
     asyncio.run(company.run(n_round=n_round))

+    if config.agentops_api_key != "":
+        agentops.end_session("Success")
+
     return ctx.repo
metagpt/team.py
@@ -126,6 +126,9 @@ class Team(BaseModel):
             self.run_project(idea=idea, send_to=send_to)

         while n_round > 0:
+            if self.env.is_idle:
+                logger.debug("All roles are idle.")
+                break
             n_round -= 1
             self._check_balance()
             await self.env.run()
|
||||
|
|
|
|||
|
|
@@ -6,37 +6,15 @@
 @File    : search_engine.py
 """
 import importlib
-from typing import Callable, Coroutine, Literal, Optional, Union, overload
+from typing import Annotated, Callable, Coroutine, Literal, Optional, Union, overload

-from pydantic import BaseModel, ConfigDict, model_validator
-from semantic_kernel.skill_definition import sk_function
+from pydantic import BaseModel, ConfigDict, Field, model_validator

 from metagpt.configs.search_config import SearchConfig
 from metagpt.logs import logger
 from metagpt.tools import SearchEngineType


-class SkSearchEngine:
-    """A search engine class for executing searches.
-
-    Attributes:
-        search_engine: The search engine instance used for executing searches.
-    """
-
-    def __init__(self, **kwargs):
-        self.search_engine = SearchEngine(**kwargs)
-
-    @sk_function(
-        description="searches results from Google. Useful when you need to find short "
-        "and succinct answers about a specific topic. Input should be a search query.",
-        name="searchAsync",
-        input_description="search",
-    )
-    async def run(self, query: str) -> str:
-        result = await self.search_engine.run(query)
-        return result
-
-
 class SearchEngine(BaseModel):
     """A model for configuring and executing searches with different search engines.
@@ -51,7 +29,9 @@ class SearchEngine(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

     engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE
-    run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None
+    run_func: Annotated[
+        Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]], Field(exclude=True)
+    ] = None
     api_key: Optional[str] = None
     proxy: Optional[str] = None
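Wrapping `run_func` in `Annotated[..., Field(exclude=True)]` keeps the callable usable on the model while dropping it from serialization, where a coroutine function would otherwise leak into dumps or break JSON export. A minimal sketch of the effect:

```python
from typing import Annotated, Callable, Optional

from pydantic import BaseModel, Field

class Engine(BaseModel):
    name: str = "serper"
    run_func: Annotated[Optional[Callable], Field(exclude=True)] = None

e = Engine(run_func=lambda q: q)
print(e.model_dump())  # {'name': 'serper'} -- run_func is excluded
```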
@@ -87,8 +87,11 @@ class SerpAPIWrapper(BaseModel):
         get_focused = lambda x: {i: j for i, j in x.items() if i in focus}

         if "error" in res.keys():
-            raise ValueError(f"Got error from SerpAPI: {res['error']}")
-        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
+            if res["error"] == "Google hasn't returned any results for this query.":
+                toret = "No good search result found"
+            else:
+                raise ValueError(f"Got error from SerpAPI: {res['error']}")
+        elif "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
             toret = res["answer_box"]["answer"]
         elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
             toret = res["answer_box"]["snippet"]
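The reworked branch downgrades SerpAPI's benign "no results" error to a sentinel string and still raises on everything else. A toy reproduction of just the branching, with `res` standing in for the parsed SerpAPI response and `extract_answer` a hypothetical name:

```python
def extract_answer(res: dict) -> str:
    if "error" in res:
        if res["error"] == "Google hasn't returned any results for this query.":
            return "No good search result found"
        raise ValueError(f"Got error from SerpAPI: {res['error']}")
    if "answer_box" in res and "answer" in res["answer_box"]:
        return res["answer_box"]["answer"]
    if "answer_box" in res and "snippet" in res["answer_box"]:
        return res["answer_box"]["snippet"]
    return "No good search result found"

print(extract_answer({"error": "Google hasn't returned any results for this query."}))
```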
@@ -3,9 +3,9 @@
 from __future__ import annotations

 import importlib
-from typing import Any, Callable, Coroutine, Optional, Union, overload
+from typing import Annotated, Any, Callable, Coroutine, Optional, Union, overload

-from pydantic import BaseModel, ConfigDict, model_validator
+from pydantic import BaseModel, ConfigDict, Field, model_validator

 from metagpt.configs.browser_config import BrowserConfig
 from metagpt.tools import WebBrowserEngineType
@@ -29,7 +29,10 @@ class WebBrowserEngine(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

     engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
-    run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None
+    run_func: Annotated[
+        Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]],
+        Field(exclude=True),
+    ] = None
     proxy: Optional[str] = None

     @model_validator(mode="after")
@@ -23,8 +23,8 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
 class DiGraphRepository(GraphRepository):
     """Graph repository based on DiGraph."""

-    def __init__(self, name: str, **kwargs):
-        super().__init__(name=name, **kwargs)
+    def __init__(self, name: str | Path, **kwargs):
+        super().__init__(name=str(name), **kwargs)
         self._repo = networkx.DiGraph()

     async def insert(self, subject: str, predicate: str, object_: str):
@@ -112,8 +112,28 @@ class DiGraphRepository(GraphRepository):
     async def load(self, pathname: str | Path):
         """Load a directed graph repository from a JSON file."""
         data = await aread(filename=pathname, encoding="utf-8")
-        m = json.loads(data)
+        self.load_json(data)
+
+    def load_json(self, val: str):
+        """
+        Loads a JSON-encoded string representing a graph structure and updates
+        the internal repository (_repo) with the parsed graph.
+
+        Args:
+            val (str): A JSON-encoded string representing a graph structure.
+
+        Returns:
+            self: Returns the instance of the class with the updated _repo attribute.
+
+        Raises:
+            TypeError: If val is not a valid JSON string or cannot be parsed into
+                a valid graph structure.
+        """
+        if not val:
+            return self
+        m = json.loads(val)
         self._repo = networkx.node_link_graph(m)
+        return self

     @staticmethod
     async def load_from(pathname: str | Path) -> GraphRepository:
@@ -126,9 +146,7 @@ class DiGraphRepository(GraphRepository):
         GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
         """
         pathname = Path(pathname)
-        name = pathname.with_suffix("").name
-        root = pathname.parent
-        graph = DiGraphRepository(name=name, root=root)
+        graph = DiGraphRepository(name=pathname.stem, root=pathname.parent)
         if pathname.exists():
             await graph.load(pathname=pathname)
         return graph
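`load_json` delegates to networkx's node-link format, so a repository serialized with `node_link_data` round-trips cleanly. A minimal sketch of that round trip using networkx directly:

```python
import json

import networkx

g = networkx.DiGraph()
g.add_edge("Game", "run", predicate="hasMethod")  # edge attributes survive the round trip

dumped = json.dumps(networkx.node_link_data(g))   # what load()/load_json() consume
restored = networkx.node_link_graph(json.loads(dumped))
assert list(restored.edges(data=True)) == list(g.edges(data=True))
```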
@@ -1,32 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-@Time    : 2023/9/13 12:29
-@Author  : femto Zheng
-@File    : make_sk_kernel.py
-"""
-import semantic_kernel as sk
-from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import (
-    AzureChatCompletion,
-)
-from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import (
-    OpenAIChatCompletion,
-)
-
-from metagpt.config2 import config
-
-
-def make_sk_kernel():
-    kernel = sk.Kernel()
-    if llm := config.get_azure_llm():
-        kernel.add_chat_service(
-            "chat_completion",
-            AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
-        )
-    elif llm := config.get_openai_llm():
-        kernel.add_chat_service(
-            "chat_completion",
-            OpenAIChatCompletion(llm.model, llm.api_key),
-        )
-
-    return kernel
metagpt/utils/omniparse_client.py (new file, 239 lines)
@@ -0,0 +1,239 @@
+import mimetypes
+import os
+from pathlib import Path
+from typing import Union
+
+import httpx
+
+from metagpt.rag.schema import OmniParsedResult
+from metagpt.utils.common import aread_bin
+
+
+class OmniParseClient:
+    """
+    OmniParse Server Client.
+    This client interacts with the OmniParse server to parse different types of media and documents.
+
+    OmniParse API Documentation: https://docs.cognitivelab.in/api
+
+    Attributes:
+        ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
+        ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
+        ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
+    """
+
+    ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
+    ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
+    ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}
+
+    def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
+        """
+        Args:
+            api_key: Default None, can be used for authentication later.
+            base_url: Base URL for the API.
+            max_timeout: Maximum request timeout in seconds.
+        """
+        self.api_key = api_key
+        self.base_url = base_url
+        self.max_timeout = max_timeout
+
+        self.parse_media_endpoint = "/parse_media"
+        self.parse_website_endpoint = "/parse_website"
+        self.parse_document_endpoint = "/parse_document"
+
+    async def _request_parse(
+        self,
+        endpoint: str,
+        method: str = "POST",
+        files: dict = None,
+        params: dict = None,
+        data: dict = None,
+        json: dict = None,
+        headers: dict = None,
+        **kwargs,
+    ) -> dict:
+        """
+        Request the OmniParse API to parse a document.
+
+        Args:
+            endpoint (str): API endpoint.
+            method (str, optional): HTTP method to use. Default is "POST".
+            files (dict, optional): Files to include in the request.
+            params (dict, optional): Query string parameters.
+            data (dict, optional): Form data to include in the request body.
+            json (dict, optional): JSON data to include in the request body.
+            headers (dict, optional): HTTP headers to include in the request.
+            **kwargs: Additional keyword arguments for httpx.AsyncClient.request().
+
+        Returns:
+            dict: JSON response data.
+        """
+        url = f"{self.base_url}{endpoint}"
+        method = method.upper()
+        headers = headers or {}
+        _headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
+        headers.update(**_headers)
+        async with httpx.AsyncClient() as client:
+            response = await client.request(
+                url=url,
+                method=method,
+                files=files,
+                params=params,
+                json=json,
+                data=data,
+                headers=headers,
+                timeout=self.max_timeout,
+                **kwargs,
+            )
+            response.raise_for_status()
+            return response.json()
+
+    async def parse_document(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> OmniParsedResult:
+        """
+        Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").
+
+        Args:
+            file_input: File path or file byte data.
+            bytes_filename: Filename for byte data, useful for determining the MIME type for the HTTP request.
+
+        Raises:
+            ValueError: If the file extension is not allowed.
+
+        Returns:
+            OmniParsedResult: The result of the document parsing.
+        """
+        self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
+        file_info = await self.get_file_info(file_input, bytes_filename)
+        resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
+        data = OmniParsedResult(**resp)
+        return data
+
+    async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> OmniParsedResult:
+        """
+        Parse a PDF document.
+
+        Args:
+            file_input: File path or file byte data.
+
+        Raises:
+            ValueError: If the file extension is not allowed.
+
+        Returns:
+            OmniParsedResult: The result of the PDF parsing.
+        """
+        self.verify_file_ext(file_input, {".pdf"})
+        # parse_pdf supports parsing by accepting only the byte data of the file.
+        file_info = await self.get_file_info(file_input, only_bytes=True)
+        endpoint = f"{self.parse_document_endpoint}/pdf"
+        resp = await self._request_parse(endpoint=endpoint, files={"file": file_info})
+        data = OmniParsedResult(**resp)
+        return data
+
+    async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
+        """
+        Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").
+
+        Args:
+            file_input: File path or file byte data.
+            bytes_filename: Filename for byte data, useful for determining the MIME type for the HTTP request.
+
+        Raises:
+            ValueError: If the file extension is not allowed.
+
+        Returns:
+            dict: JSON response data.
+        """
+        self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
+        file_info = await self.get_file_info(file_input, bytes_filename)
+        return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})
+
+    async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
+        """
+        Parse audio-type data (supports ".mp3", ".wav", ".aac").
+
+        Args:
+            file_input: File path or file byte data.
+            bytes_filename: Filename for byte data, useful for determining the MIME type for the HTTP request.
+
+        Raises:
+            ValueError: If the file extension is not allowed.
+
+        Returns:
+            dict: JSON response data.
+        """
+        self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
+        file_info = await self.get_file_info(file_input, bytes_filename)
+        return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})
+
+    @staticmethod
+    def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
+        """
+        Verify the file extension.
+
+        Args:
+            file_input: File path or file byte data.
+            allowed_file_extensions: Set of allowed file extensions.
+            bytes_filename: Filename to use for verification when `file_input` is byte data.
+
+        Raises:
+            ValueError: If the file extension is not allowed.
+        """
+        verify_file_path = None
+        if isinstance(file_input, (str, Path)):
+            verify_file_path = str(file_input)
+        elif isinstance(file_input, bytes) and bytes_filename:
+            verify_file_path = bytes_filename
+
+        if not verify_file_path:
+            # Do not verify if only byte data is provided.
+            return
+
+        file_ext = os.path.splitext(verify_file_path)[1].lower()
+        if file_ext not in allowed_file_extensions:
+            raise ValueError(f"File extension {file_ext} is not allowed; it must be one of {allowed_file_extensions}")
+
+    @staticmethod
+    async def get_file_info(
+        file_input: Union[str, bytes, Path],
+        bytes_filename: str = None,
+        only_bytes: bool = False,
+    ) -> Union[bytes, tuple]:
+        """
+        Get file information.
+
+        Args:
+            file_input: File path or file byte data.
+            bytes_filename: Filename to use when uploading byte data, useful for determining the MIME type.
+            only_bytes: Whether to return only byte data. Default is False, which returns a tuple.
+
+        Raises:
+            ValueError: If bytes_filename is not provided when file_input is bytes, or if file_input is not a valid type.
+
+        Notes:
+            Since `parse_document`, `parse_video`, and `parse_audio` support parsing various file types,
+            the MIME type of the file must be specified when uploading.
+
+        Returns:
+            Union[bytes, tuple]: Bytes if only_bytes is True, otherwise a tuple (filename, file_bytes, mime_type).
+        """
+        if isinstance(file_input, (str, Path)):
+            filename = os.path.basename(str(file_input))
+            file_bytes = await aread_bin(file_input)
+
+            if only_bytes:
+                return file_bytes
+
+            mime_type = mimetypes.guess_type(file_input)[0]
+            return filename, file_bytes, mime_type
+        elif isinstance(file_input, bytes):
+            if only_bytes:
+                return file_input
+            if not bytes_filename:
+                raise ValueError("bytes_filename must be set when passing bytes")
+
+            mime_type = mimetypes.guess_type(bytes_filename)[0]
+            return bytes_filename, file_input, mime_type
+        else:
+            raise ValueError("file_input must be a file path (str or Path) or bytes.")
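A minimal usage sketch for the new client, assuming an OmniParse server is reachable at the default base URL and that the local files referenced here exist:

```python
import asyncio

from metagpt.utils.omniparse_client import OmniParseClient

async def main():
    client = OmniParseClient(base_url="http://localhost:8000")

    parsed = await client.parse_pdf("example.pdf")  # returns an OmniParsedResult
    print(parsed)

    # Byte uploads need a filename so the MIME type can be inferred.
    with open("talk.mp3", "rb") as f:
        resp = await client.parse_audio(f.read(), bytes_filename="talk.mp3")
    print(resp)

asyncio.run(main())
```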
@@ -10,7 +10,7 @@ from __future__ import annotations
 import traceback
 from datetime import timedelta

-import aioredis  # https://aioredis.readthedocs.io/en/latest/getting-started/
+import redis.asyncio as aioredis

 from metagpt.configs.redis_config import RedisConfig
 from metagpt.logs import logger
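The standalone `aioredis` package is archived; its API was merged into redis-py as `redis.asyncio`, so aliasing the import keeps the rest of the module unchanged. A quick sketch of the aliased client, assuming a local Redis instance:

```python
import asyncio

import redis.asyncio as aioredis

async def main():
    client = aioredis.from_url("redis://localhost:6379")
    await client.set("greeting", "hello", ex=60)  # 60s TTL
    print(await client.get("greeting"))
    await client.aclose()  # redis-py >= 5; use close() on 4.x

asyncio.run(main())
```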
@@ -11,8 +11,10 @@ from multiprocessing import Pipe


 class StreamPipe:
-    parent_conn, child_conn = Pipe()
-    finish: bool = False
+    def __init__(self, name=None):
+        self.name = name
+        self.parent_conn, self.child_conn = Pipe()
+        self.finish: bool = False

     format_data = {
         "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",
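The old class-level `Pipe()` ran once at class-definition time, so every StreamPipe instance shared the same connection pair; moving it into `__init__` gives each instance its own pipe. The difference in one snippet:

```python
from multiprocessing import Pipe

class SharedPipe:
    parent_conn, child_conn = Pipe()  # evaluated once, shared by all instances

class OwnPipe:
    def __init__(self):
        self.parent_conn, self.child_conn = Pipe()  # fresh pair per instance

a, b = SharedPipe(), SharedPipe()
print(a.parent_conn is b.parent_conn)  # True -- cross-talk between instances

c, d = OwnPipe(), OwnPipe()
print(c.parent_conn is d.parent_conn)  # False
```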
@@ -41,11 +41,19 @@ TOKEN_COSTS = {
     "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
     "gpt-4o": {"prompt": 0.005, "completion": 0.015},
+    "gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
+    "gpt-4o-mini-2024-07-18": {"prompt": 0.00015, "completion": 0.0006},
     "gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015},
+    "gpt-4o-2024-08-06": {"prompt": 0.0025, "completion": 0.01},
+    "o1-preview": {"prompt": 0.015, "completion": 0.06},
+    "o1-preview-2024-09-12": {"prompt": 0.015, "completion": 0.06},
+    "o1-mini": {"prompt": 0.003, "completion": 0.012},
+    "o1-mini-2024-09-12": {"prompt": 0.003, "completion": 0.012},
     "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
     "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007},  # 128k version, prompt + completion tokens=0.005¥/k-tokens
     "glm-4": {"prompt": 0.014, "completion": 0.014},  # 128k version, prompt + completion tokens=0.1¥/k-tokens
     "gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
     "gemini-1.5-flash": {"prompt": 0.000075, "completion": 0.0003},
     "gemini-1.5-pro": {"prompt": 0.0035, "completion": 0.0105},
     "gemini-1.0-pro": {"prompt": 0.0005, "completion": 0.0015},
     "moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012},  # prompt + completion tokens=0.012¥/k-tokens
     "moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
     "moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
@@ -69,15 +77,20 @@ TOKEN_COSTS = {
     "llama3-70b-8192": {"prompt": 0.0059, "completion": 0.0079},
     "openai/gpt-3.5-turbo-0125": {"prompt": 0.0005, "completion": 0.0015},
     "openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
+    "openai/o1-preview": {"prompt": 0.015, "completion": 0.06},
+    "openai/o1-mini": {"prompt": 0.003, "completion": 0.012},
     "anthropic/claude-3-opus": {"prompt": 0.015, "completion": 0.075},
+    "anthropic/claude-3.5-sonnet": {"prompt": 0.003, "completion": 0.015},
     "google/gemini-pro-1.5": {"prompt": 0.0025, "completion": 0.0075},  # for openrouter, end
     "deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
     "deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
+    # For ark model https://www.volcengine.com/docs/82379/1099320
-    "doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
-    "doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
-    "doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
-    "doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
-    "doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
-    "doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
+    "doubao-lite-4k-240515": {"prompt": 0.000043, "completion": 0.000086},
+    "doubao-lite-32k-240515": {"prompt": 0.000043, "completion": 0.000086},
+    "doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00014},
+    "doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00029},
+    "doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00029},
+    "doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0013},
     "llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
     "llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
 }
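All prices in the table are USD per 1k tokens, so a call's cost is a straight linear combination. A minimal sketch (the table excerpt and helper are illustrative, not the module's API):

```python
TOKEN_COSTS = {"gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006}}  # USD per 1k tokens

def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    price = TOKEN_COSTS[model]
    return prompt_tokens / 1000 * price["prompt"] + completion_tokens / 1000 * price["completion"]

print(f"{estimate_cost('gpt-4o-mini', 12_000, 2_000):.6f}")  # 0.003000
```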
@@ -138,8 +151,17 @@ QIANFAN_ENDPOINT_TOKEN_COSTS = {
 """
 DashScope token price: https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
 Different models have different detail pages. Note that some models are free for a limited time.
+Some new models published by Alibaba are released on Model Studio first rather than on DashScope.
+Token prices on Model Studio are listed at https://help.aliyun.com/zh/model-studio/getting-started/models#ced16cb6cdfsy
 """
 DASHSCOPE_TOKEN_COSTS = {
+    "qwen2.5-72b-instruct": {"prompt": 0.00057, "completion": 0.0017},  # per 1k tokens
+    "qwen2.5-32b-instruct": {"prompt": 0.0005, "completion": 0.001},
+    "qwen2.5-14b-instruct": {"prompt": 0.00029, "completion": 0.00086},
+    "qwen2.5-7b-instruct": {"prompt": 0.00014, "completion": 0.00029},
+    "qwen2.5-3b-instruct": {"prompt": 0.0, "completion": 0.0},
+    "qwen2.5-1.5b-instruct": {"prompt": 0.0, "completion": 0.0},
+    "qwen2.5-0.5b-instruct": {"prompt": 0.0, "completion": 0.0},
     "qwen2-72b-instruct": {"prompt": 0.000714, "completion": 0.001428},
     "qwen2-57b-a14b-instruct": {"prompt": 0.0005, "completion": 0.001},
     "qwen2-7b-instruct": {"prompt": 0.000143, "completion": 0.000286},
@@ -190,16 +212,24 @@ FIREWORKS_GRADE_TOKEN_COSTS = {

 # https://console.volcengine.com/ark/region:ark+cn-beijing/model
 DOUBAO_TOKEN_COSTS = {
-    "doubao-lite": {"prompt": 0.0003, "completion": 0.0006},
-    "doubao-lite-128k": {"prompt": 0.0008, "completion": 0.0010},
-    "doubao-pro": {"prompt": 0.0008, "completion": 0.0020},
-    "doubao-pro-128k": {"prompt": 0.0050, "completion": 0.0090},
+    "doubao-lite": {"prompt": 0.000043, "completion": 0.000086},
+    "doubao-lite-128k": {"prompt": 0.00011, "completion": 0.00014},
+    "doubao-pro": {"prompt": 0.00011, "completion": 0.00029},
+    "doubao-pro-128k": {"prompt": 0.00071, "completion": 0.0013},
+    "doubao-pro-256k": {"prompt": 0.00071, "completion": 0.0013},
 }

 # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
 TOKEN_MAX = {
-    "gpt-4o-2024-05-13": 128000,
+    "o1-preview": 128000,
+    "o1-preview-2024-09-12": 128000,
+    "o1-mini": 128000,
+    "o1-mini-2024-09-12": 128000,
     "gpt-4o": 128000,
+    "gpt-4o-2024-05-13": 128000,
+    "gpt-4o-2024-08-06": 128000,
+    "gpt-4o-mini-2024-07-18": 128000,
+    "gpt-4o-mini": 128000,
     "gpt-4-turbo-2024-04-09": 128000,
     "gpt-4-0125-preview": 128000,
     "gpt-4-turbo-preview": 128000,
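`TOKEN_MAX` is what callers consult before sending a request. A small illustrative guard (the helper, default fallback, and output reserve are hypothetical):

```python
TOKEN_MAX = {"gpt-4o": 128000}

def fits_context(model: str, n_input_tokens: int, reserve_for_output: int = 4096) -> bool:
    # Keep room for the completion inside the model's context window.
    return n_input_tokens + reserve_for_output <= TOKEN_MAX.get(model, 8192)

print(fits_context("gpt-4o", 120_000))  # True
```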
@@ -222,7 +252,9 @@ TOKEN_MAX = {
     "text-embedding-ada-002": 8192,
     "glm-3-turbo": 128000,
     "glm-4": 128000,
-    "gemini-pro": 32768,
+    "gemini-1.5-flash": 1000000,
+    "gemini-1.5-pro": 2000000,
+    "gemini-1.0-pro": 32000,
     "moonshot-v1-8k": 8192,
     "moonshot-v1-32k": 32768,
     "moonshot-v1-128k": 128000,
@@ -246,6 +278,11 @@ TOKEN_MAX = {
     "llama3-70b-8192": 8192,
     "openai/gpt-3.5-turbo-0125": 16385,
     "openai/gpt-4-turbo-preview": 128000,
+    "openai/o1-preview": 128000,
+    "openai/o1-mini": 128000,
     "anthropic/claude-3-opus": 200000,
+    "anthropic/claude-3.5-sonnet": 200000,
     "google/gemini-pro-1.5": 4000000,
     "deepseek-chat": 32768,
     "deepseek-coder": 16385,
     "doubao-lite-4k-240515": 4000,
@@ -255,6 +292,13 @@ TOKEN_MAX = {
     "doubao-pro-32k-240515": 32000,
     "doubao-pro-128k-240515": 128000,
+    # Qwen https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-72b-api-detailes?spm=a2c4g.11186623.0.i20
+    "qwen2.5-72b-instruct": 131072,
+    "qwen2.5-32b-instruct": 131072,
+    "qwen2.5-14b-instruct": 131072,
+    "qwen2.5-7b-instruct": 131072,
+    "qwen2.5-3b-instruct": 32768,
+    "qwen2.5-1.5b-instruct": 32768,
+    "qwen2.5-0.5b-instruct": 32768,
     "qwen2-57b-a14b-instruct": 32768,
     "qwen2-72b-instruct": 131072,
     "qwen2-7b-instruct": 32768,
@@ -354,13 +398,19 @@ def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
         "gpt-4-turbo",
         "gpt-4-turbo-preview",
         "gpt-4-0125-preview",
         "gpt-4-1106-preview",
         "gpt-4-turbo",
         "gpt-4-vision-preview",
         "gpt-4-1106-vision-preview",
         "gpt-4o",
         "gpt-4o-2024-05-13",
+        "gpt-4o-2024-08-06",
         "gpt-4o-mini",
-        "claude-3-5-sonnet-20240620"
+        "gpt-4o-mini-2024-07-18",
+        "o1-preview",
+        "o1-preview-2024-09-12",
+        "o1-mini",
+        "o1-mini-2024-09-12",
     }:
         tokens_per_message = 3  # every reply is primed with <|start|>assistant<|message|>
         tokens_per_name = 1
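For every model in the set above, the counting rule is 3 tokens of overhead per message, 1 extra when a `name` field is present, plus 3 to prime the reply. A sketch of that scheme with tiktoken; the `o200k_base` encoding choice for the 4o/o1 family is an assumption, and this helper is illustrative rather than the module's exact function:

```python
import tiktoken

def count_input_tokens(messages: list[dict], encoding_name: str = "o200k_base") -> int:
    enc = tiktoken.get_encoding(encoding_name)
    num = 0
    for message in messages:
        num += 3  # per-message wrapper tokens
        for key, value in message.items():
            num += len(enc.encode(value))
            if key == "name":
                num += 1
    return num + 3  # every reply is primed with <|start|>assistant<|message|>

print(count_input_tokens([{"role": "user", "content": "hello"}]))
```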