mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
update env api schema
This commit is contained in:
parent
52a94470db
commit
9f4ee42079
4 changed files with 44 additions and 9 deletions
|
|
@ -2,7 +2,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the environment api store
|
||||
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -18,9 +18,11 @@ class EnvAPIAbstract(BaseModel):
|
|||
class EnvAPIRegistry(BaseModel):
|
||||
"""the registry to store environment w&r api/interface"""
|
||||
|
||||
registry: dict[str, Callable] = Field(default=dict(), exclude=True)
|
||||
registry: dict[str, dict[str, Union[dict, Any, str]]] = Field(default=dict(), exclude=True)
|
||||
|
||||
def get(self, api_name: str):
|
||||
if api_name not in self.registry:
|
||||
raise ValueError
|
||||
return self.registry.get(api_name)
|
||||
|
||||
def __getitem__(self, api_name: str) -> Callable:
|
||||
|
|
@ -32,6 +34,19 @@ class EnvAPIRegistry(BaseModel):
|
|||
def __len__(self):
|
||||
return len(self.registry)
|
||||
|
||||
def get_apis(self, as_str=True) -> dict[str, dict[str, Union[dict, Any, str]]]:
|
||||
"""return func schema without func instance"""
|
||||
apis = dict()
|
||||
for func_name, func_schema in self.registry.items():
|
||||
new_func_schema = dict()
|
||||
for key, value in func_schema.items():
|
||||
if key == "func":
|
||||
continue
|
||||
new_func_schema[key] = str(value) if as_str else value
|
||||
new_func_schema = new_func_schema
|
||||
apis[func_name] = new_func_schema
|
||||
return apis
|
||||
|
||||
|
||||
class WriteAPIRegistry(EnvAPIRegistry):
|
||||
"""just as a explicit class name"""
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Iterable, Optional, Set, Union
|
||||
from typing import Any, Iterable, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ from metagpt.environment.api.env_api import (
|
|||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import is_coroutine_func, is_send_to
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
|
|
@ -34,13 +34,13 @@ env_read_api_registry = ReadAPIRegistry()
|
|||
|
||||
def mark_as_readable(func):
|
||||
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
env_read_api_registry[func.__name__] = func
|
||||
env_read_api_registry[func.__name__] = get_function_schema(func)
|
||||
return func
|
||||
|
||||
|
||||
def mark_as_writeable(func):
|
||||
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
env_write_api_registry[func.__name__] = func
|
||||
env_write_api_registry[func.__name__] = get_function_schema(func)
|
||||
return func
|
||||
|
||||
|
||||
|
|
@ -51,17 +51,25 @@ class ExtEnv(BaseModel):
|
|||
if not rw_api:
|
||||
raise ValueError(f"{rw_api} not exists")
|
||||
|
||||
def get_all_available_apis(self, mode: str = "read") -> list[Any]:
|
||||
"""get available read/write apis definition"""
|
||||
assert mode in ["read", "write"]
|
||||
if mode == "read":
|
||||
return env_read_api_registry.get_apis()
|
||||
else:
|
||||
return env_write_api_registry.get_apis()
|
||||
|
||||
async def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
"""get observation from particular api of ExtEnv"""
|
||||
if isinstance(env_action, str):
|
||||
read_api = env_read_api_registry.get(api_name=env_action)
|
||||
read_api = env_read_api_registry.get(api_name=env_action)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self)
|
||||
else:
|
||||
res = read_api(self)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
read_api = env_read_api_registry.get(api_name=env_action.api_name)
|
||||
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
|
@ -75,7 +83,7 @@ class ExtEnv(BaseModel):
|
|||
if isinstance(env_action, Message):
|
||||
self.publish_message(env_action)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
write_api = env_write_api_registry.get(env_action.api_name)
|
||||
write_api = env_write_api_registry.get(env_action.api_name)["func"]
|
||||
self._check_api_exist(write_api)
|
||||
if is_coroutine_func(write_api):
|
||||
res = await write_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
|
|
|||
|
|
@ -340,6 +340,14 @@ def print_members(module, indent=0):
|
|||
print(f"{prefix}Method: {name}")
|
||||
|
||||
|
||||
def get_function_schema(func: Callable) -> dict[str, Union[dict, Any, str]]:
|
||||
sig = inspect.signature(func)
|
||||
parameters = sig.parameters
|
||||
return_type = sig.return_annotation
|
||||
param_schema = {name: parameter.annotation for name, parameter in parameters.items()}
|
||||
return {"input_params": param_schema, "return_type": return_type, "func_desc": func.__doc__, "func": func}
|
||||
|
||||
|
||||
def parse_recipient(text):
|
||||
# FIXME: use ActionNode instead.
|
||||
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
|
||||
|
|
|
|||
|
|
@ -40,6 +40,10 @@ async def test_ext_env():
|
|||
assert len(env_read_api_registry) > 0
|
||||
assert len(env_write_api_registry) > 0
|
||||
|
||||
apis = env.get_all_available_apis(mode="read")
|
||||
assert len(apis) > 0
|
||||
assert len(apis["read_api"]) == 3
|
||||
|
||||
_ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
assert env.value == 15
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue