From 9f4ee420791f1d7bf87ff22d10e46441d93af7da Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 1 Feb 2024 18:18:35 +0800 Subject: [PATCH] update env api schema --- metagpt/environment/api/env_api.py | 19 +++++++++++++++++-- metagpt/environment/base_env.py | 22 +++++++++++++++------- metagpt/utils/common.py | 8 ++++++++ tests/metagpt/environment/test_base_env.py | 4 ++++ 4 files changed, 44 insertions(+), 9 deletions(-) diff --git a/metagpt/environment/api/env_api.py b/metagpt/environment/api/env_api.py index 6469e5b4c..1e6df544d 100644 --- a/metagpt/environment/api/env_api.py +++ b/metagpt/environment/api/env_api.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : the environment api store -from typing import Callable +from typing import Any, Callable, Union from pydantic import BaseModel, Field @@ -18,9 +18,11 @@ class EnvAPIAbstract(BaseModel): class EnvAPIRegistry(BaseModel): """the registry to store environment w&r api/interface""" - registry: dict[str, Callable] = Field(default=dict(), exclude=True) + registry: dict[str, dict[str, Union[dict, Any, str]]] = Field(default=dict(), exclude=True) def get(self, api_name: str): + if api_name not in self.registry: + raise ValueError return self.registry.get(api_name) def __getitem__(self, api_name: str) -> Callable: @@ -32,6 +34,19 @@ class EnvAPIRegistry(BaseModel): def __len__(self): return len(self.registry) + def get_apis(self, as_str=True) -> dict[str, dict[str, Union[dict, Any, str]]]: + """return func schema without func instance""" + apis = dict() + for func_name, func_schema in self.registry.items(): + new_func_schema = dict() + for key, value in func_schema.items(): + if key == "func": + continue + new_func_schema[key] = str(value) if as_str else value + new_func_schema = new_func_schema + apis[func_name] = new_func_schema + return apis + class WriteAPIRegistry(EnvAPIRegistry): """just as a explicit class name""" diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index 1bdcfe373..7ba34dfaf 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -4,7 +4,7 @@ import asyncio from enum import Enum -from typing import Iterable, Optional, Set, Union +from typing import Any, Iterable, Optional, Set, Union from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator @@ -17,7 +17,7 @@ from metagpt.environment.api.env_api import ( from metagpt.logs import logger from metagpt.roles.role import Role from metagpt.schema import Message -from metagpt.utils.common import is_coroutine_func, is_send_to +from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to class EnvType(Enum): @@ -34,13 +34,13 @@ env_read_api_registry = ReadAPIRegistry() def mark_as_readable(func): """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv""" - env_read_api_registry[func.__name__] = func + env_read_api_registry[func.__name__] = get_function_schema(func) return func def mark_as_writeable(func): """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv""" - env_write_api_registry[func.__name__] = func + env_write_api_registry[func.__name__] = get_function_schema(func) return func @@ -51,17 +51,25 @@ class ExtEnv(BaseModel): if not rw_api: raise ValueError(f"{rw_api} not exists") + def get_all_available_apis(self, mode: str = "read") -> list[Any]: + """get available read/write apis definition""" + assert mode in ["read", "write"] + if mode == "read": + return env_read_api_registry.get_apis() + else: + return env_write_api_registry.get_apis() + async def observe(self, env_action: Union[str, EnvAPIAbstract]): """get observation from particular api of ExtEnv""" if isinstance(env_action, str): - read_api = env_read_api_registry.get(api_name=env_action) + read_api = env_read_api_registry.get(api_name=env_action)["func"] self._check_api_exist(read_api) if is_coroutine_func(read_api): res = await read_api(self) else: res = read_api(self) elif isinstance(env_action, EnvAPIAbstract): - read_api = env_read_api_registry.get(api_name=env_action.api_name) + read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"] self._check_api_exist(read_api) if is_coroutine_func(read_api): res = await read_api(self, *env_action.args, **env_action.kwargs) @@ -75,7 +83,7 @@ class ExtEnv(BaseModel): if isinstance(env_action, Message): self.publish_message(env_action) elif isinstance(env_action, EnvAPIAbstract): - write_api = env_write_api_registry.get(env_action.api_name) + write_api = env_write_api_registry.get(env_action.api_name)["func"] self._check_api_exist(write_api) if is_coroutine_func(write_api): res = await write_api(self, *env_action.args, **env_action.kwargs) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 2e05afa74..d3a922c5f 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -340,6 +340,14 @@ def print_members(module, indent=0): print(f"{prefix}Method: {name}") +def get_function_schema(func: Callable) -> dict[str, Union[dict, Any, str]]: + sig = inspect.signature(func) + parameters = sig.parameters + return_type = sig.return_annotation + param_schema = {name: parameter.annotation for name, parameter in parameters.items()} + return {"input_params": param_schema, "return_type": return_type, "func_desc": func.__doc__, "func": func} + + def parse_recipient(text): # FIXME: use ActionNode instead. pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now diff --git a/tests/metagpt/environment/test_base_env.py b/tests/metagpt/environment/test_base_env.py index ce8165f2f..fd73679d8 100644 --- a/tests/metagpt/environment/test_base_env.py +++ b/tests/metagpt/environment/test_base_env.py @@ -40,6 +40,10 @@ async def test_ext_env(): assert len(env_read_api_registry) > 0 assert len(env_write_api_registry) > 0 + apis = env.get_all_available_apis(mode="read") + assert len(apis) > 0 + assert len(apis["read_api"]) == 3 + _ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10})) assert env.value == 15