update env api schema

This commit is contained in:
better629 2024-02-01 18:18:35 +08:00
parent 52a94470db
commit 9f4ee42079
4 changed files with 44 additions and 9 deletions

View file

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : the environment api store
from typing import Callable
from typing import Any, Callable, Union
from pydantic import BaseModel, Field
@ -18,9 +18,11 @@ class EnvAPIAbstract(BaseModel):
class EnvAPIRegistry(BaseModel):
"""the registry to store environment w&r api/interface"""
registry: dict[str, Callable] = Field(default=dict(), exclude=True)
registry: dict[str, dict[str, Union[dict, Any, str]]] = Field(default=dict(), exclude=True)
def get(self, api_name: str):
if api_name not in self.registry:
raise ValueError
return self.registry.get(api_name)
def __getitem__(self, api_name: str) -> Callable:
@ -32,6 +34,19 @@ class EnvAPIRegistry(BaseModel):
def __len__(self):
return len(self.registry)
def get_apis(self, as_str=True) -> dict[str, dict[str, Union[dict, Any, str]]]:
"""return func schema without func instance"""
apis = dict()
for func_name, func_schema in self.registry.items():
new_func_schema = dict()
for key, value in func_schema.items():
if key == "func":
continue
new_func_schema[key] = str(value) if as_str else value
new_func_schema = new_func_schema
apis[func_name] = new_func_schema
return apis
class WriteAPIRegistry(EnvAPIRegistry):
"""just as a explicit class name"""

View file

@ -4,7 +4,7 @@
import asyncio
from enum import Enum
from typing import Iterable, Optional, Set, Union
from typing import Any, Iterable, Optional, Set, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -17,7 +17,7 @@ from metagpt.environment.api.env_api import (
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_coroutine_func, is_send_to
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
class EnvType(Enum):
@ -34,13 +34,13 @@ env_read_api_registry = ReadAPIRegistry()
def mark_as_readable(func):
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
env_read_api_registry[func.__name__] = func
env_read_api_registry[func.__name__] = get_function_schema(func)
return func
def mark_as_writeable(func):
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
env_write_api_registry[func.__name__] = func
env_write_api_registry[func.__name__] = get_function_schema(func)
return func
@ -51,17 +51,25 @@ class ExtEnv(BaseModel):
if not rw_api:
raise ValueError(f"{rw_api} not exists")
def get_all_available_apis(self, mode: str = "read") -> list[Any]:
"""get available read/write apis definition"""
assert mode in ["read", "write"]
if mode == "read":
return env_read_api_registry.get_apis()
else:
return env_write_api_registry.get_apis()
async def observe(self, env_action: Union[str, EnvAPIAbstract]):
"""get observation from particular api of ExtEnv"""
if isinstance(env_action, str):
read_api = env_read_api_registry.get(api_name=env_action)
read_api = env_read_api_registry.get(api_name=env_action)["func"]
self._check_api_exist(read_api)
if is_coroutine_func(read_api):
res = await read_api(self)
else:
res = read_api(self)
elif isinstance(env_action, EnvAPIAbstract):
read_api = env_read_api_registry.get(api_name=env_action.api_name)
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
self._check_api_exist(read_api)
if is_coroutine_func(read_api):
res = await read_api(self, *env_action.args, **env_action.kwargs)
@ -75,7 +83,7 @@ class ExtEnv(BaseModel):
if isinstance(env_action, Message):
self.publish_message(env_action)
elif isinstance(env_action, EnvAPIAbstract):
write_api = env_write_api_registry.get(env_action.api_name)
write_api = env_write_api_registry.get(env_action.api_name)["func"]
self._check_api_exist(write_api)
if is_coroutine_func(write_api):
res = await write_api(self, *env_action.args, **env_action.kwargs)

View file

@ -340,6 +340,14 @@ def print_members(module, indent=0):
print(f"{prefix}Method: {name}")
def get_function_schema(func: Callable) -> dict[str, Union[dict, Any, str]]:
sig = inspect.signature(func)
parameters = sig.parameters
return_type = sig.return_annotation
param_schema = {name: parameter.annotation for name, parameter in parameters.items()}
return {"input_params": param_schema, "return_type": return_type, "func_desc": func.__doc__, "func": func}
def parse_recipient(text):
# FIXME: use ActionNode instead.
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now

View file

@ -40,6 +40,10 @@ async def test_ext_env():
assert len(env_read_api_registry) > 0
assert len(env_write_api_registry) > 0
apis = env.get_all_available_apis(mode="read")
assert len(apis) > 0
assert len(apis["read_api"]) == 3
_ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
assert env.value == 15