feat: merge geekan:dev

This commit is contained in:
莘权 马 2024-02-02 16:47:52 +08:00
commit dadd09bfb5
105 changed files with 5201 additions and 350 deletions

View file

@ -413,12 +413,13 @@ class ActionNode:
prompt: str,
output_class_name: str,
output_data_mapping: dict,
images: Optional[Union[str, list[str]]] = None,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=3,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
logger.debug(f"llm raw output:\n{content}")
output_class = self.create_model_class(output_class_name, output_data_mapping)
@ -447,13 +448,15 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode, timeout=3, exclude=None):
async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None):
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
if schema != "raw":
mapping = self.get_mapping(mode, exclude=exclude)
class_name = f"{self.key}_AN"
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
content, scontent = await self._aask_v1(
prompt, class_name, mapping, images=images, schema=schema, timeout=timeout
)
self.content = content
self.instruct_content = scontent
else:
@ -462,7 +465,17 @@ class ActionNode:
return self
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]):
async def fill(
self,
context,
llm,
schema="json",
mode="auto",
strgy="simple",
images: Optional[Union[str, list[str]]] = None,
timeout=3,
exclude=[],
):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
@ -478,6 +491,7 @@ class ActionNode:
:param strgy: simple/complex
- simple: run only once
- complex: run each node
:param images: the list of image url or base64 for gpt4-v
:param timeout: Timeout for llm invocation.
:param exclude: The keys of ActionNode to exclude.
:return: self
@ -488,14 +502,14 @@ class ActionNode:
schema = self.schema
if strgy == "simple":
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
elif strgy == "complex":
# 这里隐式假设了拥有children
tmp = {}
for _, i in self.children.items():
if exclude and i.key in exclude:
continue
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
child = await i.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
tmp.update(child.instruct_content.model_dump())
cls = self._create_children_class()
self.instruct_content = cls(**tmp)

View file

@ -3,15 +3,15 @@
from __future__ import annotations
import asyncio
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union
from pydantic import Field, parse_obj_as
from pydantic import TypeAdapter, model_validator
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.tools.search_engine import SearchEngine
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
from metagpt.tools.web_browser_engine import WebBrowserEngine
from metagpt.utils.common import OutputParser
from metagpt.utils.text import generate_prompt_chunk, reduce_message_length
@ -81,10 +81,16 @@ class CollectLinks(Action):
name: str = "CollectLinks"
i_context: Optional[str] = None
desc: str = "Collect links from a search engine."
search_engine: SearchEngine = Field(default_factory=SearchEngine)
search_func: Optional[Any] = None
search_engine: Optional[SearchEngine] = None
rank_func: Optional[Callable[[list[str]], None]] = None
@model_validator(mode="after")
def validate_engine_and_run_func(self):
if self.search_engine is None:
self.search_engine = SearchEngine.from_search_config(self.config.search, proxy=self.config.proxy)
return self
async def run(
self,
topic: str,
@ -107,7 +113,7 @@ class CollectLinks(Action):
keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
try:
keywords = OutputParser.extract_struct(keywords, list)
keywords = parse_obj_as(list[str], keywords)
keywords = TypeAdapter(list[str]).validate_python(keywords)
except Exception as e:
logger.exception(f"fail to get keywords related to the research topic '{topic}' for {e}")
keywords = [topic]
@ -133,7 +139,7 @@ class CollectLinks(Action):
queries = await self._aask(prompt, [system_text])
try:
queries = OutputParser.extract_struct(queries, list)
queries = parse_obj_as(list[str], queries)
queries = TypeAdapter(list[str]).validate_python(queries)
except Exception as e:
logger.exception(f"fail to break down the research question due to {e}")
queries = keywords
@ -178,15 +184,17 @@ class WebBrowseAndSummarize(Action):
i_context: Optional[str] = None
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: Optional[WebBrowserEngine] = WebBrowserEngineType.PLAYWRIGHT
web_browser_engine: Optional[WebBrowserEngine] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT,
run_func=self.browse_func,
)
@model_validator(mode="after")
def validate_engine_and_run_func(self):
if self.web_browser_engine is None:
self.web_browser_engine = WebBrowserEngine.from_browser_config(
self.config.browser,
browse_func=self.browse_func,
proxy=self.config.proxy,
)
return self
async def run(
self,

View file

@ -5,7 +5,7 @@
@Author : alexanderwu
@File : search_google.py
"""
from typing import Any, Optional
from typing import Optional
import pydantic
from pydantic import model_validator
@ -13,7 +13,6 @@ from pydantic import model_validator
from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements
@ -105,21 +104,19 @@ You are a member of a professional butler team and will provide helpful suggesti
class SearchAndSummarize(Action):
name: str = ""
content: Optional[str] = None
engine: Optional[SearchEngineType] = None
search_func: Optional[Any] = None
search_engine: SearchEngine = None
result: str = ""
@model_validator(mode="after")
def validate_engine_and_run_func(self):
if self.engine is None:
self.engine = self.config.search_engine
try:
search_engine = SearchEngine(engine=self.engine, run_func=self.search_func)
except pydantic.ValidationError:
search_engine = None
def validate_search_engine(self):
if self.search_engine is None:
try:
config = self.config
search_engine = SearchEngine.from_search_config(config.search, proxy=config.proxy)
except pydantic.ValidationError:
search_engine = None
self.search_engine = search_engine
self.search_engine = search_engine
return self
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:

View file

@ -51,7 +51,7 @@ class Config(CLIParams, YamlModel):
proxy: str = ""
# Tool Parameters
search: Optional[SearchConfig] = None
search: SearchConfig = SearchConfig()
browser: BrowserConfig = BrowserConfig()
mermaid: MermaidConfig = MermaidConfig()

View file

@ -15,6 +15,6 @@ class BrowserConfig(YamlModel):
"""Config for Browser"""
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
driver: Literal["chromium", "firefox", "webkit"] = "chromium"
path: str = ""
browser_type: Literal["chromium", "firefox", "webkit", "chrome", "firefox", "edge", "ie"] = "chromium"
"""If the engine is Playwright, the value should be one of "chromium", "firefox", or "webkit". If it is Selenium, the value
should be either "chrome", "firefox", "edge", or "ie"."""

View file

@ -5,6 +5,8 @@
@Author : alexanderwu
@File : search_config.py
"""
from typing import Callable, Optional
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
@ -12,6 +14,7 @@ from metagpt.utils.yaml_model import YamlModel
class SearchConfig(YamlModel):
"""Config for Search"""
api_key: str
api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO
api_key: str = ""
cse_id: str = "" # for google
search_func: Optional[Callable] = None

View file

@ -7,7 +7,7 @@
"""
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from metagpt.config2 import Config
from metagpt.context import Context
@ -17,7 +17,7 @@ from metagpt.provider.base_llm import BaseLLM
class ContextMixin(BaseModel):
"""Mixin class for context and config"""
model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
# Pydantic has bug on _private_attr when using inheritance, so we use private_* instead
# - https://github.com/pydantic/pydantic/issues/7142
@ -32,15 +32,18 @@ class ContextMixin(BaseModel):
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
def __init__(
@model_validator(mode="after")
def validate_extra(self):
self._process_extra(**(self.model_extra or {}))
return self
def _process_extra(
self,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
**kwargs,
):
"""Initialize with config"""
super().__init__(**kwargs)
"""Process the extra field"""
self.set_context(context)
self.set_config(config)
self.set_llm(llm)

View file

@ -0,0 +1,38 @@
Here is an environment description of the MetaGPT env for different situations.
For now, the code only defines the environments; some TODOs remain, such as migrating roles/actions to the current version.
## Function
- Define `ExtEnv`(Base Class) which help users to integrate with external environment like games through apis or construct the game logics.
- Define `Environment`(Base Class) which is the env that MetaGPT directly used. And it includes roles and so on.
- Define the `EnvAPIRegistry` to mark the read/write apis that `ExtEnv` provide observe/step ability. And then, users can call the particular one to get observation from env or feedback to env.
## Usage
init environment
```
android_env = env.create(EnvType.ANDROID)
assistant = Role(name="Bob", profile="android assistant")
team = Team(investment=10.0, env=android_env, roles=[assistant])
```
observe & step inside role's actions
```
from metagpt.environment.api.env_api import EnvAPIAbstract
# get screenshot from ExtEnv
screenshot_path: Path = env.observe(
EnvAPIAbstract(
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
)
)
# do a `tap` action on the screen
res = env.step(EnvAPIAbstract("system_tap", kwargs={"x": x, "y": y}))
```
## TODO
- add an Android app operation assistant under `examples/android_assistant`
- migrate roles/actions of the werewolf game from the old version into the current version
- migrate roles/actions of the Minecraft (`mincraft`) game from the old version into the current version
- migrate roles/actions of the stanford_town game from the old version into the current version

View file

@ -0,0 +1,13 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.environment.base_env import Environment
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv
from metagpt.environment.werewolf_env.werewolf_env import WerewolfEnv
from metagpt.environment.stanford_town_env.stanford_town_env import StanfordTownEnv
from metagpt.environment.software_env.software_env import SoftwareEnv
__all__ = ["AndroidEnv", "MincraftExtEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"]

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,13 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : MG Android Env
from pydantic import Field
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
from metagpt.environment.base_env import Environment
class AndroidEnv(Environment, AndroidExtEnv):
    """MetaGPT Android environment: combines the role-hosting `Environment`
    with the adb-backed `AndroidExtEnv` external-device interface.
    """

    # Grid overlay dimensions used when annotating screenshots; 0 means not set yet.
    rows: int = Field(default=0, description="rows of a grid on the screenshot")
    cols: int = Field(default=0, description="cols of a grid on the screenshot")

View file

@ -0,0 +1,157 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : The Android external environment to integrate with Android apps
import subprocess
from pathlib import Path
from typing import Any, Optional
from pydantic import Field
from metagpt.environment.android_env.const import ADB_EXEC_FAIL
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
class AndroidExtEnv(ExtEnv):
    """External environment driving a real or emulated Android device via adb.

    Every action shells out to the `adb` binary; `device_id` selects the target
    device, and `width`/`height` describe its screen in pixels. Methods marked
    with `mark_as_readable` observe the device; `mark_as_writeable` ones act on it.
    """

    device_id: Optional[str] = Field(default=None)
    screenshot_dir: Optional[Path] = Field(default=None)
    xml_dir: Optional[Path] = Field(default=None)
    width: int = Field(default=720, description="device screen width")
    height: int = Field(default=1080, description="device screen height")

    def __init__(self, **data: Any):
        """If a device_id is supplied, query the real screen size and use it as
        the default for width/height (explicit kwargs still take precedence)."""
        super().__init__(**data)
        if data.get("device_id"):
            (width, height) = self.device_shape
            self.width = data.get("width", width)
            self.height = data.get("height", height)

    @property
    def adb_prefix_si(self):
        """adb cmd prefix with `device_id` and `shell input` (note trailing space)."""
        return f"adb -s {self.device_id} shell input "

    @property
    def adb_prefix_shell(self):
        """adb cmd prefix with `device_id` and `shell` (note trailing space)."""
        return f"adb -s {self.device_id} shell "

    @property
    def adb_prefix(self):
        """adb cmd prefix with `device_id` (note trailing space)."""
        return f"adb -s {self.device_id} "

    def execute_adb_with_cmd(self, adb_cmd: str) -> str:
        """Run an adb command line; return stripped stdout, or ADB_EXEC_FAIL on a
        non-zero exit code.

        NOTE(review): shell=True with an interpolated command string — commands
        are built internally here, but never route user-supplied text into them.
        """
        res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        exec_res = ADB_EXEC_FAIL
        if not res.returncode:
            exec_res = res.stdout.strip()
        return exec_res

    @property
    def device_shape(self) -> tuple[int, int]:
        """Screen (width, height) reported by `wm size`; (0, 0) if the query fails."""
        adb_cmd = f"{self.adb_prefix_shell} wm size"
        shape = (0, 0)
        shape_res = self.execute_adb_with_cmd(adb_cmd)
        if shape_res != ADB_EXEC_FAIL:
            # Output looks like "Physical size: 720x1080" — parse after ": ".
            shape = tuple(map(int, shape_res.split(": ")[1].split("x")))
        return shape

    def list_devices(self):
        """Return device serials from `adb devices` (empty list on failure)."""
        adb_cmd = "adb devices"
        res = self.execute_adb_with_cmd(adb_cmd)
        devices = []
        if res != ADB_EXEC_FAIL:
            # Skip the "List of devices attached" header; first column is the serial.
            devices = res.split("\n")[1:]
            devices = [device.split()[0] for device in devices]
        return devices

    @mark_as_readable
    def get_screenshot(self, ss_name: str, local_save_dir: Path) -> Path:
        """
        ss_name: screenshot file name
        local_save_dir: local dir to store image from virtual machine
        """
        # On any adb failure this returns Path(ADB_EXEC_FAIL) ("FAILED") —
        # callers must check for that sentinel path.
        assert self.screenshot_dir
        ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png")
        ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}"
        ss_res = self.execute_adb_with_cmd(ss_cmd)
        res = ADB_EXEC_FAIL
        if ss_res != ADB_EXEC_FAIL:
            ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png")
            pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}"
            pull_res = self.execute_adb_with_cmd(pull_cmd)
            if pull_res != ADB_EXEC_FAIL:
                res = ss_local_path
        return Path(res)

    @mark_as_readable
    def get_xml(self, xml_name: str, local_save_dir: Path) -> Path:
        """Dump the UI hierarchy to xml on-device and pull it to local_save_dir;
        returns the local path, or Path(ADB_EXEC_FAIL) on failure."""
        xml_remote_path = Path(self.xml_dir).joinpath(f"{xml_name}.xml")
        dump_cmd = f"{self.adb_prefix_shell} uiautomator dump {xml_remote_path}"
        xml_res = self.execute_adb_with_cmd(dump_cmd)
        res = ADB_EXEC_FAIL
        if xml_res != ADB_EXEC_FAIL:
            xml_local_path = Path(local_save_dir).joinpath(f"{xml_name}.xml")
            pull_cmd = f"{self.adb_prefix} pull {xml_remote_path} {xml_local_path}"
            pull_res = self.execute_adb_with_cmd(pull_cmd)
            if pull_res != ADB_EXEC_FAIL:
                res = xml_local_path
        return Path(res)

    @mark_as_writeable
    def system_back(self) -> str:
        """Press the hardware BACK key."""
        adb_cmd = f"{self.adb_prefix_si} keyevent KEYCODE_BACK"
        back_res = self.execute_adb_with_cmd(adb_cmd)
        return back_res

    @mark_as_writeable
    def system_tap(self, x: int, y: int) -> str:
        """Tap the screen at pixel coordinates (x, y)."""
        adb_cmd = f"{self.adb_prefix_si} tap {x} {y}"
        tap_res = self.execute_adb_with_cmd(adb_cmd)
        return tap_res

    @mark_as_writeable
    def user_input(self, input_txt: str) -> str:
        """Type text on the device.

        `adb shell input text` uses "%s" to encode a space; single quotes are
        stripped because they would break the shell-quoted command string.
        """
        input_txt = input_txt.replace(" ", "%s").replace("'", "")
        adb_cmd = f"{self.adb_prefix_si} text {input_txt}"
        input_res = self.execute_adb_with_cmd(adb_cmd)
        return input_res

    @mark_as_writeable
    def user_longpress(self, x: int, y: int, duration: int = 500) -> str:
        """Long-press at (x, y) by swiping in place for `duration` ms."""
        adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x} {y} {duration}"
        press_res = self.execute_adb_with_cmd(adb_cmd)
        return press_res

    @mark_as_writeable
    def user_swipe(self, x: int, y: int, orient: str = "up", dist: str = "medium", if_quick: bool = False) -> str:
        """Swipe from (x, y) in direction `orient` ("up"/"down"/"left"/"right").

        `dist` scales the distance ("short" base unit, "medium" x2, "long" x3,
        relative to one tenth of the screen width); `if_quick` shortens the swipe
        duration. Returns ADB_EXEC_FAIL for an unknown `orient`.
        """
        dist_unit = int(self.width / 10)
        if dist == "long":
            dist_unit *= 3
        elif dist == "medium":
            dist_unit *= 2

        if orient == "up":
            offset = 0, -2 * dist_unit
        elif orient == "down":
            offset = 0, 2 * dist_unit
        elif orient == "left":
            offset = -1 * dist_unit, 0
        elif orient == "right":
            offset = dist_unit, 0
        else:
            return ADB_EXEC_FAIL

        duration = 100 if if_quick else 400
        adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x + offset[0]} {y + offset[1]} {duration}"
        swipe_res = self.execute_adb_with_cmd(adb_cmd)
        return swipe_res

    @mark_as_writeable
    def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400):
        """Swipe from `start` to `end` pixel coordinates over `duration` ms."""
        adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}"
        swipe_res = self.execute_adb_with_cmd(adb_cmd)
        return swipe_res

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
# For Android Assistant Agent
# Sentinel string returned by AndroidExtEnv.execute_adb_with_cmd when an adb command fails.
ADB_EXEC_FAIL = "FAILED"

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the environment api store
from typing import Any, Callable, Union
from pydantic import BaseModel, Field
class EnvAPIAbstract(BaseModel):
    """api/interface summary description

    Describes one env API call: its name plus the positional (`args`) and
    keyword (`kwargs`) parameters to invoke it with.
    """

    api_name: str = Field(default="", description="the api function name or id")
    # fix: the original used `default={}` — a *dict* literal — as the default for a
    # set-typed field; use default_factory so each instance gets its own fresh set.
    args: set = Field(default_factory=set, description="the api function `args` params")
    # fix: default_factory avoids a shared mutable default dict.
    kwargs: dict = Field(default_factory=dict, description="the api function `kwargs` params")
class EnvAPIRegistry(BaseModel):
    """the registry to store environment w&r api/interface

    Maps api name -> schema dict (including a "func" entry holding the callable).
    """

    # fix: default_factory so each registry instance owns its own dict instead of
    # a shared mutable default.
    registry: dict[str, dict[str, Union[dict, Any, str]]] = Field(default_factory=dict, exclude=True)

    def get(self, api_name: str):
        """Return the stored schema for `api_name`; raise ValueError if unregistered."""
        if api_name not in self.registry:
            # fix: the original raised a bare ValueError with no message.
            raise ValueError(f"api `{api_name}` not found in the registry")
        return self.registry.get(api_name)

    def __getitem__(self, api_name: str) -> Callable:
        return self.get(api_name)

    def __setitem__(self, api_name: str, func: Callable):
        self.registry[api_name] = func

    def __len__(self):
        return len(self.registry)

    def get_apis(self, as_str=True) -> dict[str, dict[str, Union[dict, Any, str]]]:
        """return func schema without func instance"""
        apis = dict()
        for func_name, func_schema in self.registry.items():
            # Copy every field except the callable itself; stringify values when as_str.
            # (fix: dropped the original no-op line `new_func_schema = new_func_schema`.)
            new_func_schema = {
                key: (str(value) if as_str else value) for key, value in func_schema.items() if key != "func"
            }
            apis[func_name] = new_func_schema
        return apis
class WriteAPIRegistry(EnvAPIRegistry):
    """Registry for writeable (state-mutating) env APIs; exists just to give an explicit class name."""

    pass
class ReadAPIRegistry(EnvAPIRegistry):
    """Registry for readable (observation) env APIs; exists just to give an explicit class name."""

    pass

View file

@ -1,29 +1,99 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/11 22:12
@Author : alexanderwu
@File : environment.py
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.2 of RFC 116:
1. Remove the functionality of `Environment` class as a public message buffer.
2. Standardize the message forwarding behavior of the `Environment` class.
3. Add the `is_idle` property.
@Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing
functionality is to be consolidated into the `Environment` class.
"""
# @Desc : base env of executing environment
import asyncio
from typing import Iterable, Set
from enum import Enum
from typing import Any, Iterable, Optional, Set, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.context import Context
from metagpt.environment.api.env_api import (
EnvAPIAbstract,
ReadAPIRegistry,
WriteAPIRegistry,
)
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_send_to
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
class Environment(BaseModel):
class EnvType(Enum):
    """Supported concrete environment flavors."""

    ANDROID = "Android"
    GYM = "Gym"
    WEREWOLF = "Werewolf"
    # NOTE(review): "Mincraft" presumably means Minecraft — spelling kept for
    # consistency with the existing module/value names.
    MINCRAFT = "Mincraft"
    STANFORDTOWN = "StanfordTown"
# Module-level registries recording the read/write APIs exposed by ExtEnv subclasses.
env_write_api_registry = WriteAPIRegistry()
env_read_api_registry = ReadAPIRegistry()


def mark_as_readable(func):
    """Mark a function as a readable one in ExtEnv; it observes something from ExtEnv."""
    env_read_api_registry[func.__name__] = get_function_schema(func)
    return func


def mark_as_writeable(func):
    """Mark a function as a writeable one in ExtEnv; it does something to ExtEnv."""
    env_write_api_registry[func.__name__] = get_function_schema(func)
    return func
class ExtEnv(BaseModel):
    """Base class for external environments that integrate an actual game/app backend.

    Subclasses expose their capabilities through functions decorated with
    `mark_as_readable` (observations) and `mark_as_writeable` (actions);
    `observe`/`step` dispatch to the registered api by name.
    """

    def _check_api_exist(self, rw_api: Optional[str] = None):
        """Raise ValueError if the resolved read/write api is missing (falsy)."""
        if not rw_api:
            raise ValueError(f"{rw_api} not exists")

    def get_all_available_apis(self, mode: str = "read") -> list[Any]:
        """get available read/write apis definition"""
        assert mode in ["read", "write"]
        if mode == "read":
            return env_read_api_registry.get_apis()
        else:
            return env_write_api_registry.get_apis()

    async def observe(self, env_action: Union[str, EnvAPIAbstract]):
        """get observation from particular api of ExtEnv

        :param env_action: the api name, or an EnvAPIAbstract carrying args/kwargs
        :return: whatever the read api returns; None for unsupported action types
        """
        # fix: `res` was unbound (UnboundLocalError on return) when env_action
        # matched neither branch; `step` already initialized it — now consistent.
        res = None
        if isinstance(env_action, str):
            read_api = env_read_api_registry.get(api_name=env_action)["func"]
            self._check_api_exist(read_api)
            if is_coroutine_func(read_api):
                res = await read_api(self)
            else:
                res = read_api(self)
        elif isinstance(env_action, EnvAPIAbstract):
            read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
            self._check_api_exist(read_api)
            if is_coroutine_func(read_api):
                res = await read_api(self, *env_action.args, **env_action.kwargs)
            else:
                res = read_api(self, *env_action.args, **env_action.kwargs)
        return res

    async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
        """execute through particular api of ExtEnv

        Message inputs are published to the env (`publish_message` is provided by
        the `Environment` subclass — NOTE(review): not defined on ExtEnv itself);
        EnvAPIAbstract inputs dispatch to a registered write api.
        """
        res = None
        if isinstance(env_action, Message):
            self.publish_message(env_action)
        elif isinstance(env_action, EnvAPIAbstract):
            write_api = env_write_api_registry.get(env_action.api_name)["func"]
            self._check_api_exist(write_api)
            if is_coroutine_func(write_api):
                res = await write_api(self, *env_action.args, **env_action.kwargs)
            else:
                res = write_api(self, *env_action.args, **env_action.kwargs)
        return res
class Environment(ExtEnv):
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
"""

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,44 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.const import METAGPT_ROOT
# For Mincraft (Minecraft) Game Agent
MC_CKPT_DIR = METAGPT_ROOT / "data/mincraft/ckpt"  # checkpoint root for agent state (skills, curriculum, qa cache)
MC_LOG_DIR = METAGPT_ROOT / "logs"

# NOTE(review): values appear to be per-field warm-up thresholds (progress counts
# before each observation field is exposed to the curriculum agent) — confirm
# against the curriculum agent implementation.
MC_DEFAULT_WARMUP = {
    "context": 15,
    "biome": 10,
    "time": 15,
    "nearby_blocks": 0,
    "other_blocks": 10,
    "nearby_entities": 5,
    "health": 15,
    "hunger": 15,
    "position": 0,
    "equipment": 0,
    "inventory": 0,
    "optional_inventory_items": 7,
    "chests": 0,
    "completed_tasks": 0,
    "failed_tasks": 0,
}
# Observation field names consumed by the curriculum agent.
MC_CURRICULUM_OB = [
    "context",
    "biome",
    "time",
    "nearby_blocks",
    "other_blocks",
    "nearby_entities",
    "health",
    "hunger",
    "position",
    "equipment",
    "inventory",
    "chests",
    "completed_tasks",
    "failed_tasks",
]
# Regex matching core inventory items always shown to the curriculum agent.
# fix: in the original, the second pattern line was a separate (discarded) tuple
# expression statement, so half the alternatives were silently lost; parentheses
# make the two raw strings one implicitly-concatenated pattern.
MC_CORE_INVENTORY_ITEMS = (
    r".*_log|.*_planks|stick|crafting_table|furnace"
    r"|cobblestone|dirt|coal|.*_pickaxe|.*_sword|.*_axe"
)  # curriculum_agent: only show these items in inventory before optional_inventory_items reached in warm up

View file

@ -0,0 +1,391 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : MG Mincraft Env
# refs to `voyager voyager.py`
import json
import re
import time
from typing import Any, Iterable
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from pydantic import ConfigDict, Field
from metagpt.config2 import config as CONFIG
from metagpt.environment.base_env import Environment
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
from metagpt.logs import logger
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
class MincraftEnv(Environment, MincraftExtEnv):
"""MincraftEnv, including shared memory of cache and infomation between roles"""
model_config = ConfigDict(arbitrary_types_allowed=True)
event: dict[str, Any] = Field(default_factory=dict)
current_task: str = Field(default="Mine 1 wood log")
task_execution_time: float = Field(default=float)
context: str = Field(default="You can mine one of oak, birch, spruce, jungle, acacia, dark oak, or mangrove logs.")
code: str = Field(default="")
program_code: str = Field(default="") # write in skill/code/*.js
program_name: str = Field(default="")
critique: str = Field(default="")
skills: dict = Field(default_factory=dict) # for skills.json
retrieve_skills: list[str] = Field(default_factory=list)
event_summary: str = Field(default="")
qa_cache: dict[str, str] = Field(default_factory=dict)
completed_tasks: list[str] = Field(default_factory=list) # Critique things
failed_tasks: list[str] = Field(default_factory=list)
skill_desp: str = Field(default="")
chest_memory: dict[str, Any] = Field(default_factory=dict) # eg: {'(1344, 64, 1381)': 'Unknown'}
chest_observation: str = Field(default="") # eg: "Chests: None\n\n"
runtime_status: bool = False # equal to action execution status: success or failed
vectordb: Chroma = Field(default_factory=Chroma)
qa_cache_questions_vectordb: Chroma = Field(default_factory=Chroma)
@property
def progress(self):
# return len(self.completed_tasks) + 10 # Test only
return len(self.completed_tasks)
@property
def programs(self):
programs = ""
if self.code == "":
return programs # TODO: maybe fix 10054 now, a better way is isolating env.step() like voyager
for skill_name, entry in self.skills.items():
programs += f"{entry['code']}\n\n"
for primitives in load_mc_skills_code(): # TODO add skills_dir
programs += f"{primitives}\n\n"
return programs
def set_mc_port(self, mc_port):
super().set_mc_port(mc_port)
self.set_mc_resume()
def set_mc_resume(self):
self.qa_cache_questions_vectordb = Chroma(
collection_name="qa_cache_questions_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{MC_CKPT_DIR}/curriculum/vectordb",
)
self.vectordb = Chroma(
collection_name="skill_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{MC_CKPT_DIR}/skill/vectordb",
)
if CONFIG.resume:
logger.info(f"Loading Action Developer from {MC_CKPT_DIR}/action")
self.chest_memory = read_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json")
logger.info(f"Loading Curriculum Agent from {MC_CKPT_DIR}/curriculum")
self.completed_tasks = read_json_file(f"{MC_CKPT_DIR}/curriculum/completed_tasks.json")
self.failed_tasks = read_json_file(f"{MC_CKPT_DIR}/curriculum/failed_tasks.json")
logger.info(f"Loading Skill Manager from {MC_CKPT_DIR}/skill\033[0m")
self.skills = read_json_file(f"{MC_CKPT_DIR}/skill/skills.json")
logger.info(f"Loading Qa Cache from {MC_CKPT_DIR}/curriculum\033[0m")
self.qa_cache = read_json_file(f"{MC_CKPT_DIR}/curriculum/qa_cache.json")
if self.vectordb._collection.count() == 0:
logger.info(self.vectordb._collection.count())
# Set vdvs for skills & qa_cache
skill_desps = [skill["description"] for program_name, skill in self.skills.items()]
program_names = [program_name for program_name, skill in self.skills.items()]
metadatas = [{"name": program_name} for program_name in program_names]
# add vectordb from file
self.vectordb.add_texts(
texts=skill_desps,
ids=program_names,
metadatas=metadatas,
)
self.vectordb.persist()
logger.info(self.qa_cache_questions_vectordb._collection.count())
if self.qa_cache_questions_vectordb._collection.count() == 0:
questions = [question for question, answer in self.qa_cache.items()]
self.qa_cache_questions_vectordb.add_texts(texts=questions)
self.qa_cache_questions_vectordb.persist()
logger.info(
f"INIT_CHECK: There are {self.vectordb._collection.count()} skills in vectordb and {len(self.skills)} skills in skills.json."
)
# Check if Skill Manager's vectordb right using
assert self.vectordb._collection.count() == len(self.skills), (
f"Skill Manager's vectordb is not synced with skills.json.\n"
f"There are {self.vectordb._collection.count()} skills in vectordb but {len(self.skills)} skills in skills.json.\n"
f"Did you set resume=False when initializing the manager?\n"
f"You may need to manually delete the vectordb directory for running from scratch."
)
logger.info(
f"INIT_CHECK: There are {self.qa_cache_questions_vectordb._collection.count()} qa_cache in vectordb and {len(self.qa_cache)} questions in qa_cache.json."
)
assert self.qa_cache_questions_vectordb._collection.count() == len(self.qa_cache), (
f"Curriculum Agent's qa cache question vectordb is not synced with qa_cache.json.\n"
f"There are {self.qa_cache_questions_vectordb._collection.count()} questions in vectordb "
f"but {len(self.qa_cache)} questions in qa_cache.json.\n"
f"Did you set resume=False when initializing the agent?\n"
f"You may need to manually delete the qa cache question vectordb directory for running from scratch.\n"
)
def register_roles(self, roles: Iterable["Minecraft"]):
for role in roles:
role.set_memory(self)
def update_event(self, event: dict):
if self.event == event:
return
self.event = event
self.update_chest_memory(event)
self.update_chest_observation()
# self.event_summary = self.summarize_chatlog(event)
def update_task(self, task: str):
self.current_task = task
def update_context(self, context: str):
self.context = context
def update_program_code(self, program_code: str):
self.program_code = program_code
def update_code(self, code: str):
self.code = code # action_developer.gen_action_code to HERE
def update_program_name(self, program_name: str):
self.program_name = program_name
def update_critique(self, critique: str):
self.critique = critique # critic_agent.check_task_success to HERE
def append_skill(self, skill: dict):
self.skills[self.program_name] = skill # skill_manager.retrieve_skills to HERE
def update_retrieve_skills(self, retrieve_skills: list):
self.retrieve_skills = retrieve_skills
def update_skill_desp(self, skill_desp: str):
self.skill_desp = skill_desp
async def update_qa_cache(self, qa_cache: dict):
self.qa_cache = qa_cache
def update_chest_memory(self, events: dict):
"""
Input: events: Dict
Result: self.chest_memory update & save to json
"""
nearbyChests = events[-1][1]["nearbyChests"]
for position, chest in nearbyChests.items():
if position in self.chest_memory:
if isinstance(chest, dict):
self.chest_memory[position] = chest
if chest == "Invalid":
logger.info(f"Action Developer removing chest {position}: {chest}")
self.chest_memory.pop(position)
else:
if chest != "Invalid":
logger.info(f"Action Developer saving chest {position}: {chest}")
self.chest_memory[position] = chest
write_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json", self.chest_memory)
def update_chest_observation(self):
    """Render chest_memory into a human-readable observation string.

    Refer to https://github.com/MineDojo/Voyager/blob/main/voyager/agents/action.py
    """
    described = []
    # Non-empty chests first, then empty ones, then unopened ("Unknown") ones.
    for pos, contents in self.chest_memory.items():
        if isinstance(contents, dict) and contents:
            described.append(f"{pos}: {contents}")
    for pos, contents in self.chest_memory.items():
        if isinstance(contents, dict) and not contents:
            described.append(f"{pos}: Empty")
    for pos, contents in self.chest_memory.items():
        if isinstance(contents, str):
            assert contents == "Unknown"
            described.append(f"{pos}: Unknown items inside")
    assert len(described) == len(self.chest_memory)
    if described:
        self.chest_observation = "Chests:\n" + "\n".join(described) + "\n\n"
    else:
        self.chest_observation = "Chests: None\n\n"
def summarize_chatlog(self, events):
    """Condense failure chat messages into one "I also need ..." summary.

    Scans onChat events for known complaint patterns (missing ingredients,
    missing crafting table, missing tool) and stores the combined summary
    in self.event_summary ("" when nothing was missing).
    """
    craft_missing = r"I cannot make \w+ because I need: (.*)"
    craft_no_table = r"I cannot make \w+ because there is no crafting table nearby"
    mine_tool = r"I need at least a (.*) to mine \w+!"

    def filter_item(message: str):
        matched = re.match(craft_missing, message)
        if matched:
            self.event_summary = matched.groups()[0]
            return self.event_summary
        if re.match(craft_no_table, message):
            self.event_summary = "a nearby crafting table"
            return self.event_summary
        matched = re.match(mine_tool, message)
        if matched:
            self.event_summary = matched.groups()[0]
            return self.event_summary
        self.event_summary = ""
        return self.event_summary

    needed = set()
    for kind, payload in events:
        if kind == "onChat":
            item = filter_item(payload["onChat"])
            if item:
                needed.add(item)
    self.event_summary = "I also need " + ", ".join(needed) + "." if needed else ""
def reset_block_info(self):
    """Revert all placing events from the last step (not yet implemented)."""
    pass
def update_exploration_progress(self, success: bool):
    """Sort the current task into completed_tasks or failed_tasks and persist.

    On failure, blocks placed during the last step are reclaimed via the
    in-game `givePlacedItemBack` helper so inventory state stays consistent,
    and the cached event is patched with the refreshed inventory/voxels.

    Args:
        success: whether the current task was achieved this step.
    """
    self.runtime_status = success
    task = self.current_task
    # Depositing into a chest is housekeeping, not exploration progress.
    if task.startswith("Deposit useless items into the chest at"):
        return
    if success:
        logger.info(f"Completed task {task}.")
        self.completed_tasks.append(task)
    else:
        logger.info(f"Failed to complete task {task}. Skipping to next task.")
        self.failed_tasks.append(task)
        # when not success, below to update event!
        # revert all the placing event in the last step
        blocks = []
        positions = []
        for event_type, event in self.event:
            # "<block>_placed" save events record what was placed and where.
            if event_type == "onSave" and event["onSave"].endswith("_placed"):
                block = event["onSave"].split("_placed")[0]
                position = event["status"]["position"]
                blocks.append(block)
                positions.append(position)
        new_events = self.step(
            f"await givePlacedItemBack(bot, {json.dumps(blocks)}, {json.dumps(positions)})",
            programs=self.programs,
        )
        # Patch the cached event with the post-revert inventory and voxels.
        self.event[-1][1]["inventory"] = new_events[-1][1]["inventory"]
        self.event[-1][1]["voxels"] = new_events[-1][1]["voxels"]
    self.save_sorted_tasks()
def save_sorted_tasks(self):
    """Deduplicate completed tasks, drop them from failed tasks, and persist both lists."""
    # Dedup while preserving first-seen order.
    deduped_completed = list(dict.fromkeys(self.completed_tasks))
    # record repeated failed tasks (kept as-is, minus tasks that succeeded)
    remaining_failed = self.failed_tasks
    # A task that eventually succeeded no longer counts as failed.
    for done in deduped_completed:
        while done in remaining_failed:
            remaining_failed.remove(done)
    self.completed_tasks = deduped_completed
    self.failed_tasks = remaining_failed
    # dump to json
    write_json_file(f"{MC_CKPT_DIR}/curriculum/completed_tasks.json", self.completed_tasks)
    write_json_file(f"{MC_CKPT_DIR}/curriculum/failed_tasks.json", self.failed_tasks)
async def on_event_retrieve(self, *args):
    """
    Retrieve Minecraft events.

    Performs a soft reset, advances game time, and pins difficulty to
    "peaceful", then returns the resulting event batch.

    Returns:
        list: A list of Minecraft events.

    Raises:
        Exception: If there is an issue retrieving events.
    """
    try:
        self.reset(
            options={
                "mode": "soft",
                "wait_ticks": 20,
            }
        )
        # difficulty = "easy" if len(self.completed_tasks) > 15 else "peaceful"
        difficulty = "peaceful"
        events = self.step("bot.chat(`/time set ${getNextTime()}`);\n" + f"bot.chat('/difficulty {difficulty}');")
        self.update_event(events)
        return events
    except Exception as e:
        time.sleep(3)  # wait for mineflayer to exit
        # reset bot status here: hard reset restoring the last known state
        events = self.reset(
            options={
                "mode": "hard",
                "wait_ticks": 20,
                "inventory": self.event[-1][1]["inventory"],
                "equipment": self.event[-1][1]["status"]["equipment"],
                "position": self.event[-1][1]["status"]["position"],
            }
        )
        self.update_event(events)
        logger.error(f"Failed to retrieve Minecraft events: {str(e)}")
        return events
async def on_event_execute(self, *args):
    """
    Execute Minecraft events.

    This function is used to obtain events from the Minecraft environment. Check the implementation in
    the 'voyager/env/bridge.py step()' function to capture events generated within the game.

    Returns:
        list: A list of Minecraft events.

    Raises:
        Exception: If there is an issue retrieving events.
    """
    try:
        # Run the previously generated action code in-game.
        events = self.step(
            code=self.code,
            programs=self.programs,
        )
        self.update_event(events)
        return events
    except Exception as e:
        time.sleep(3)  # wait for mineflayer to exit
        # reset bot status here: hard reset restoring the last known state
        events = self.reset(
            options={
                "mode": "hard",
                "wait_ticks": 20,
                "inventory": self.event[-1][1]["inventory"],
                "equipment": self.event[-1][1]["status"]["equipment"],
                "position": self.event[-1][1]["status"]["position"],
            }
        )
        self.update_event(events)
        logger.error(f"Failed to execute Minecraft events: {str(e)}")
        return events

View file

@ -0,0 +1,180 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : The Minecraft external environment to integrate with the Minecraft game
# refs to `voyager bridge.py`
import json
import time
from typing import Optional
import requests
from pydantic import ConfigDict, Field, model_validator
from metagpt.environment.base_env import ExtEnv, mark_as_writeable
from metagpt.environment.mincraft_env.const import (
MC_CKPT_DIR,
MC_CORE_INVENTORY_ITEMS,
MC_CURRICULUM_OB,
MC_DEFAULT_WARMUP,
METAGPT_ROOT,
)
from metagpt.environment.mincraft_env.process_monitor import SubprocessMonitor
from metagpt.logs import logger
class MincraftExtEnv(ExtEnv):
    """External environment bridging MetaGPT and a Minecraft (mineflayer) server.

    Manages the node.js mineflayer subprocess and proxies reset/step/pause
    commands to it over HTTP.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    mc_port: Optional[int] = Field(default=None)
    server_host: str = Field(default="http://127.0.0.1")
    # Kept as str since it is only interpolated into URLs/argv; default made a
    # str to match the annotation (was the int 3000).
    server_port: str = Field(default="3000")
    request_timeout: int = Field(default=600)

    mineflayer: Optional[SubprocessMonitor] = Field(default=None, validate_default=True)

    has_reset: bool = Field(default=False)
    reset_options: Optional[dict] = Field(default=None)
    connected: bool = Field(default=False)
    server_paused: bool = Field(default=False)
    warm_up: dict = Field(default_factory=dict)  # default_factory avoids a shared mutable default

    @property
    def server(self) -> str:
        """Base URL of the mineflayer HTTP bridge."""
        return f"{self.server_host}:{self.server_port}"

    @model_validator(mode="after")
    def _post_init_ext_env(self):
        """Lazily build the mineflayer monitor and warm-up schedule, and create checkpoint dirs."""
        if not self.mineflayer:
            self.mineflayer = SubprocessMonitor(
                commands=[
                    "node",
                    METAGPT_ROOT.joinpath("metagpt", "environment", "mincraft_env", "mineflayer", "index.js"),
                    str(self.server_port),
                ],
                name="mineflayer",
                # Matches the node server's startup log line.
                ready_match=r"Server started on port (\d+)",
            )
        if not self.warm_up:
            warm_up = MC_DEFAULT_WARMUP
            if "optional_inventory_items" in warm_up:
                assert MC_CORE_INVENTORY_ITEMS is not None
                # self.core_inv_items_regex = re.compile(MC_CORE_INVENTORY_ITEMS)
                self.warm_up["optional_inventory_items"] = warm_up["optional_inventory_items"]
            else:
                self.warm_up["optional_inventory_items"] = 0
            for key in MC_CURRICULUM_OB:
                self.warm_up[key] = warm_up.get(key, MC_DEFAULT_WARMUP[key])
            # These observations are always available, so no warm-up delay.
            self.warm_up["nearby_blocks"] = 0
            self.warm_up["inventory"] = 0
            self.warm_up["completed_tasks"] = 0
            self.warm_up["failed_tasks"] = 0

        # init ckpt sub-folders
        MC_CKPT_DIR.joinpath("curriculum/vectordb").mkdir(parents=True, exist_ok=True)
        MC_CKPT_DIR.joinpath("action").mkdir(exist_ok=True)
        MC_CKPT_DIR.joinpath("skill/code").mkdir(parents=True, exist_ok=True)
        MC_CKPT_DIR.joinpath("skill/description").mkdir(exist_ok=True)
        MC_CKPT_DIR.joinpath("skill/vectordb").mkdir(exist_ok=True)
        # pydantic v2 `mode="after"` validators must return the model instance.
        return self

    def set_mc_port(self, mc_port: int):
        """Record the Minecraft game server port used for bot connections."""
        self.mc_port = mc_port

    @mark_as_writeable
    def close(self) -> bool:
        """Stop the bot and the mineflayer subprocess; return True on success."""
        self.unpause()
        if self.connected:
            res = requests.post(f"{self.server}/stop")
            if res.status_code == 200:
                self.connected = False
        self.mineflayer.stop()
        return not self.connected

    @mark_as_writeable
    def check_process(self) -> dict:
        """Ensure mineflayer is running and (re)issue the /start request.

        Returns:
            The parsed JSON body of the /start response.

        Raises:
            RuntimeError: if mineflayer cannot be started or /start fails.
        """
        retry = 0
        while not self.mineflayer.is_running:
            logger.info("Mineflayer process has exited, restarting")
            self.mineflayer.run()
            if not self.mineflayer.is_running:
                if retry > 3:
                    # Was `raise {}`, which itself raises a TypeError.
                    logger.error("Mineflayer process failed to start")
                    raise RuntimeError("Mineflayer process failed to start")
                retry += 1
                continue
            logger.info(self.mineflayer.ready_line)
        res = requests.post(
            f"{self.server}/start",
            json=self.reset_options,
            timeout=self.request_timeout,
        )
        if res.status_code != 200:
            self.mineflayer.stop()
            logger.error(f"Minecraft server reply with code {res.status_code}")
            raise RuntimeError(f"Minecraft server reply with code {res.status_code}")
        return res.json()

    @mark_as_writeable
    def reset(self, *, seed=None, options=None) -> dict:
        """Restart the bot with the given options and return the first observation batch."""
        if options is None:
            options = {}

        if options.get("inventory", {}) and options.get("mode", "hard") != "hard":
            logger.error("inventory can only be set when options is hard")
            # Was `raise {}`, which itself raises a TypeError.
            raise ValueError("inventory can only be set when mode is hard")

        self.reset_options = {
            "port": self.mc_port,
            "reset": options.get("mode", "hard"),
            "inventory": options.get("inventory", {}),
            "equipment": options.get("equipment", []),
            "spread": options.get("spread", False),
            "waitTicks": options.get("wait_ticks", 5),
            "position": options.get("position", None),
        }

        self.unpause()
        self.mineflayer.stop()
        time.sleep(1)  # wait for mineflayer to exit

        returned_data = self.check_process()
        self.has_reset = True
        self.connected = True
        # All the reset in step will be soft
        self.reset_options["reset"] = "soft"
        self.pause()
        # /start replies with a JSON-encoded string of observation events.
        return json.loads(returned_data)

    @mark_as_writeable
    def step(self, code: str, programs: str = "") -> dict:
        """Run `code` (with helper `programs`) in the game and return observations."""
        if not self.has_reset:
            raise RuntimeError("Environment has not been reset yet")

        self.check_process()
        self.unpause()
        data = {
            "code": code,
            "programs": programs,
        }
        res = requests.post(f"{self.server}/step", json=data, timeout=self.request_timeout)
        if res.status_code != 200:
            raise RuntimeError("Failed to step Minecraft server")
        returned_data = res.json()
        self.pause()
        # /step replies with a JSON-encoded string of observation events.
        return json.loads(returned_data)

    @mark_as_writeable
    def pause(self) -> bool:
        """Pause the game if running and not already paused; return the paused state."""
        if self.mineflayer.is_running and not self.server_paused:
            res = requests.post(f"{self.server}/pause")
            if res.status_code == 200:
                self.server_paused = True
        return self.server_paused

    @mark_as_writeable
    def unpause(self) -> bool:
        """Resume the game if paused; return the paused state (False on success)."""
        if self.mineflayer.is_running and self.server_paused:
            res = requests.post(f"{self.server}/pause")
            if res.status_code == 200:
                self.server_paused = False
            else:
                logger.info(f"mineflayer pause result: {res.json()}")
        return self.server_paused

View file

@ -0,0 +1 @@
!/lib

View file

@ -0,0 +1,3 @@
# Ignore artifacts:
build
coverage

View file

@ -0,0 +1,3 @@
{
"tabWidth": 4
}

View file

@ -0,0 +1,425 @@
const fs = require("fs");
const express = require("express");
const bodyParser = require("body-parser");
const mineflayer = require("mineflayer");
const skills = require("./lib/skillLoader");
const { initCounter, getNextTime } = require("./lib/utils");
const obs = require("./lib/observation/base");
const OnChat = require("./lib/observation/onChat");
const OnError = require("./lib/observation/onError");
const { Voxels, BlockRecords } = require("./lib/observation/voxels");
const Status = require("./lib/observation/status");
const Inventory = require("./lib/observation/inventory");
const OnSave = require("./lib/observation/onSave");
const Chests = require("./lib/observation/chests");
const { plugin: tool } = require("mineflayer-tool");
// Single shared mineflayer bot instance; (re)created on each POST /start.
let bot = null;

const app = express();

// Observation payloads can be large; raise the default body size limit.
app.use(bodyParser.json({ limit: "50mb" }));
app.use(bodyParser.urlencoded({ limit: "50mb", extended: false }));
// POST /start: (re)spawn the bot per the request options and reply with the
// first batch of observations once the bot has settled.
app.post("/start", (req, res) => {
    // Tear down any previous bot before spawning a new one.
    if (bot) onDisconnect("Restarting bot");
    bot = null;
    console.log(req.body);
    bot = mineflayer.createBot({
        host: "localhost", // minecraft server ip
        port: req.body.port, // minecraft server port
        username: "bot",
        disableChatSigning: true,
        checkTimeoutInterval: 60 * 60 * 1000,
    });
    // Until spawn succeeds, connection errors answer this request with 400.
    bot.once("error", onConnectionFailed);

    // Event subscriptions
    bot.waitTicks = req.body.waitTicks;
    bot.globalTickCounter = 0;
    bot.stuckTickCounter = 0;
    bot.stuckPosList = [];
    bot.iron_pickaxe = false;

    bot.on("kicked", onDisconnect);

    // mounting will cause physicsTick to stop
    bot.on("mount", () => {
        bot.dismount();
    });

    bot.once("spawn", async () => {
        bot.removeListener("error", onConnectionFailed);
        // Each queued /give or /item command needs a tick budget to apply
        // before the first observation is taken.
        let itemTicks = 1;
        if (req.body.reset === "hard") {
            // Hard reset: wipe and respawn, then restore requested
            // inventory/equipment via server commands.
            bot.chat("/clear @s");
            bot.chat("/kill @s");
            const inventory = req.body.inventory ? req.body.inventory : {};
            const equipment = req.body.equipment
                ? req.body.equipment
                : [null, null, null, null, null, null];
            for (let key in inventory) {
                bot.chat(`/give @s minecraft:${key} ${inventory[key]}`);
                itemTicks += 1;
            }
            const equipmentNames = [
                "armor.head",
                "armor.chest",
                "armor.legs",
                "armor.feet",
                "weapon.mainhand",
                "weapon.offhand",
            ];
            for (let i = 0; i < 6; i++) {
                // NOTE(review): index 4 (weapon.mainhand) is skipped —
                // presumably restored through inventory instead; confirm.
                if (i === 4) continue;
                if (equipment[i]) {
                    bot.chat(
                        `/item replace entity @s ${equipmentNames[i]} with minecraft:${equipment[i]}`
                    );
                    itemTicks += 1;
                }
            }
        }
        if (req.body.position) {
            bot.chat(
                `/tp @s ${req.body.position.x} ${req.body.position.y} ${req.body.position.z}`
            );
        }

        // if iron_pickaxe is in bot's inventory
        if (
            bot.inventory.items().find((item) => item.name === "iron_pickaxe")
        ) {
            bot.iron_pickaxe = true;
        }

        const { pathfinder } = require("mineflayer-pathfinder");
        const tool = require("mineflayer-tool").plugin;
        const collectBlock = require("mineflayer-collectblock").plugin;
        const pvp = require("mineflayer-pvp").plugin;
        const minecraftHawkEye = require("minecrafthawkeye");
        bot.loadPlugin(pathfinder);
        bot.loadPlugin(tool);
        bot.loadPlugin(collectBlock);
        bot.loadPlugin(pvp);
        bot.loadPlugin(minecraftHawkEye);

        // bot.collectBlock.movements.digCost = 0;
        // bot.collectBlock.movements.placeCost = 0;

        // Wire up all observation providers and skill helpers.
        obs.inject(bot, [
            OnChat,
            OnError,
            Voxels,
            Status,
            Inventory,
            OnSave,
            Chests,
            BlockRecords,
        ]);
        skills.inject(bot);

        if (req.body.spread) {
            bot.chat(`/spreadplayers ~ ~ 0 300 under 80 false @s`);
            await bot.waitForTicks(bot.waitTicks);
        }

        // Let all queued commands apply before taking the first observation.
        await bot.waitForTicks(bot.waitTicks * itemTicks);

        res.json(bot.observe());

        initCounter(bot);
        bot.chat("/gamerule keepInventory true");
        bot.chat("/gamerule doDaylightCycle false");
    });

    // Connection failures before spawn reply 400 to the waiting /start call.
    function onConnectionFailed(e) {
        console.log(e);
        bot = null;
        res.status(400).json({ error: e });
    }
    function onDisconnect(message) {
        if (bot.viewer) {
            bot.viewer.close();
        }
        bot.end();
        console.log(message);
        bot = null;
    }
});
// POST /step: eval the generated code inside the live bot and reply with the
// observations accumulated during this step.
app.post("/step", async (req, res) => {
    // import useful package
    let response_sent = false;
    // Last-resort handler: convert uncaught errors into an error observation
    // and still answer the HTTP request exactly once.
    function otherError(err) {
        console.log("Uncaught Error");
        bot.emit("error", handleError(err));
        bot.waitForTicks(bot.waitTicks).then(() => {
            if (!response_sent) {
                response_sent = true;
                res.json(bot.observe());
            }
        });
    }
    process.on("uncaughtException", otherError);

    const mcData = require("minecraft-data")(bot.version);
    // Aliases so generated code using colloquial item names still resolves.
    mcData.itemsByName["leather_cap"] = mcData.itemsByName["leather_helmet"];
    mcData.itemsByName["leather_tunic"] =
        mcData.itemsByName["leather_chestplate"];
    mcData.itemsByName["leather_pants"] =
        mcData.itemsByName["leather_leggings"];
    // NOTE(review): self-assignment — "leather_boots" is already canonical.
    mcData.itemsByName["leather_boots"] = mcData.itemsByName["leather_boots"];
    mcData.itemsByName["lapis_lazuli_ore"] = mcData.itemsByName["lapis_ore"];
    mcData.blocksByName["lapis_lazuli_ore"] = mcData.blocksByName["lapis_ore"];

    // All names below are deliberately pulled into scope so the eval'd
    // user code can reference them directly.
    const {
        Movements,
        goals: {
            Goal,
            GoalBlock,
            GoalNear,
            GoalXZ,
            GoalNearXZ,
            GoalY,
            GoalGetToBlock,
            GoalLookAtBlock,
            GoalBreakBlock,
            GoalCompositeAny,
            GoalCompositeAll,
            GoalInvert,
            GoalFollow,
            GoalPlaceBlock,
        },
        pathfinder,
        Move,
        ComputedPath,
        PartiallyComputedPath,
        XZCoordinates,
        XYZCoordinates,
        SafeBlock,
        GoalPlaceBlockOptions,
    } = require("mineflayer-pathfinder");
    const { Vec3 } = require("vec3");

    // Set up pathfinder
    const movements = new Movements(bot, mcData);
    bot.pathfinder.setMovements(movements);

    bot.globalTickCounter = 0;
    bot.stuckTickCounter = 0;
    bot.stuckPosList = [];

    // Teleport the bot if it has been stuck while pathfinding for 100 ticks.
    function onTick() {
        bot.globalTickCounter++;
        if (bot.pathfinder.isMoving()) {
            bot.stuckTickCounter++;
            if (bot.stuckTickCounter >= 100) {
                onStuck(1.5);
                bot.stuckTickCounter = 0;
            }
        }
    }

    // NOTE(review): "physicTick" is mineflayer's legacy alias of
    // "physicsTick" — confirm against the pinned mineflayer version.
    bot.on("physicTick", onTick);

    // initialize fail count
    let _craftItemFailCount = 0;
    let _killMobFailCount = 0;
    let _mineBlockFailCount = 0;
    let _placeItemFailCount = 0;
    let _smeltItemFailCount = 0;

    // Retrieve code and program strings from the POST body
    const code = req.body.code;
    const programs = req.body.programs;
    bot.cumulativeObs = [];
    await bot.waitForTicks(bot.waitTicks);
    const r = await evaluateCode(code, programs);
    process.off("uncaughtException", otherError);
    if (r !== "success") {
        bot.emit("error", handleError(r));
    }
    await returnItems();
    // wait for last message
    await bot.waitForTicks(bot.waitTicks);
    if (!response_sent) {
        response_sent = true;
        res.json(bot.observe());
    }
    bot.removeListener("physicTick", onTick);

    async function evaluateCode(code, programs) {
        // Echo the code produced for players to see it. Don't echo when the bot code is already producing dialog or it will double echo
        try {
            await eval("(async () => {" + programs + "\n" + code + "})()");
            return "success";
        } catch (err) {
            return err;
        }
    }

    // If the bot barely moved over the last 5 stuck checks, teleport it.
    function onStuck(posThreshold) {
        const currentPos = bot.entity.position;
        bot.stuckPosList.push(currentPos);

        // Check if the list is full
        if (bot.stuckPosList.length === 5) {
            const oldestPos = bot.stuckPosList[0];
            const posDifference = currentPos.distanceTo(oldestPos);

            if (posDifference < posThreshold) {
                teleportBot(); // execute the function
            }

            // Remove the oldest time from the list
            bot.stuckPosList.shift();
        }
    }

    // Teleport into a nearby air block, or nudge the bot upward as fallback.
    function teleportBot() {
        const blocks = bot.findBlocks({
            matching: (block) => {
                return block.type === 0;
            },
            maxDistance: 1,
            count: 27,
        });

        if (blocks) {
            // console.log(blocks.length);
            const randomIndex = Math.floor(Math.random() * blocks.length);
            const block = blocks[randomIndex];
            bot.chat(`/tp @s ${block.x} ${block.y} ${block.z}`);
        } else {
            bot.chat("/tp @s ~ ~1.25 ~");
        }
    }

    // Reclaim placed utility blocks (crafting table, furnace) and top up
    // essentials so the next step starts from a consistent inventory.
    function returnItems() {
        bot.chat("/gamerule doTileDrops false");
        const crafting_table = bot.findBlock({
            matching: mcData.blocksByName.crafting_table.id,
            maxDistance: 128,
        });
        if (crafting_table) {
            bot.chat(
                `/setblock ${crafting_table.position.x} ${crafting_table.position.y} ${crafting_table.position.z} air destroy`
            );
            bot.chat("/give @s crafting_table");
        }
        const furnace = bot.findBlock({
            matching: mcData.blocksByName.furnace.id,
            maxDistance: 128,
        });
        if (furnace) {
            bot.chat(
                `/setblock ${furnace.position.x} ${furnace.position.y} ${furnace.position.z} air destroy`
            );
            bot.chat("/give @s furnace");
        }
        if (bot.inventoryUsed() >= 32) {
            // if chest is not in bot's inventory
            if (!bot.inventory.items().find((item) => item.name === "chest")) {
                bot.chat("/give @s chest");
            }
        }
        // if iron_pickaxe not in bot's inventory and bot.iron_pickaxe
        if (
            bot.iron_pickaxe &&
            !bot.inventory.items().find((item) => item.name === "iron_pickaxe")
        ) {
            bot.chat("/give @s iron_pickaxe");
        }
        bot.chat("/gamerule doTileDrops true");
    }

    // Map a raw stack trace back to the offending line of generated code so
    // the LLM gets an actionable error message.
    function handleError(err) {
        let stack = err.stack;
        if (!stack) {
            return err;
        }
        console.log(stack);
        const final_line = stack.split("\n")[1];
        const regex = /<anonymous>:(\d+):\d+\)/;

        const programs_length = programs.split("\n").length;
        let match_line = null;
        for (const line of stack.split("\n")) {
            const match = regex.exec(line);
            if (match) {
                const line_num = parseInt(match[1]);
                // Lines past the helper programs belong to the user code.
                if (line_num >= programs_length) {
                    match_line = line_num - programs_length;
                    break;
                }
            }
        }
        if (!match_line) {
            return err.message;
        }
        let f_line = final_line.match(
            /\((?<file>.*):(?<line>\d+):(?<pos>\d+)\)/
        );
        if (f_line && f_line.groups && fs.existsSync(f_line.groups.file)) {
            const { file, line, pos } = f_line.groups;
            const f = fs.readFileSync(file, "utf8").split("\n");
            // let filename = file.match(/(?<=node_modules\\)(.*)/)[1];
            let source = file + `:${line}\n${f[line - 1].trim()}\n `;

            const code_source =
                "at " +
                code.split("\n")[match_line - 1].trim() +
                " in your code";
            return source + err.message + "\n" + code_source;
        } else if (
            f_line &&
            f_line.groups &&
            f_line.groups.file.includes("<anonymous>")
        ) {
            const { file, line, pos } = f_line.groups;
            let source =
                "Your code" +
                `:${match_line}\n${code.split("\n")[match_line - 1].trim()}\n `;
            let code_source = "";
            if (line < programs_length) {
                source =
                    "In your program code: " +
                    programs.split("\n")[line - 1].trim() +
                    "\n";
                code_source = `at line ${match_line}:${code
                    .split("\n")
                    [match_line - 1].trim()} in your code`;
            }
            return source + err.message + "\n" + code_source;
        }
        return err.message;
    }
});
// POST /stop: disconnect the bot from the game server.
app.post("/stop", (req, res) => {
    // Guard against a stop request before any bot was spawned; previously
    // this crashed with a TypeError on `bot.end()` (mirrors /pause's guard).
    if (!bot) {
        res.status(400).json({ error: "Bot not spawned" });
        return;
    }
    bot.end();
    res.json({
        message: "Bot stopped",
    });
});
// POST /pause: toggle the game's pause state via the in-game /pause command.
app.post("/pause", (req, res) => {
    // Cannot pause a game that has no bot attached yet.
    if (!bot) {
        res.status(400).json({ error: "Bot not spawned" });
        return;
    }
    bot.chat("/pause");
    bot.waitForTicks(bot.waitTicks).then(() => res.json({ message: "Success" }));
});
// Server listening to PORT 3000
// The port may be overridden via argv[2]; the Python SubprocessMonitor waits
// for the "Server started on port ..." line to detect readiness.
const DEFAULT_PORT = 3000;
const PORT = process.argv[2] || DEFAULT_PORT;
app.listen(PORT, () => {
    console.log(`Server started on port ${PORT}`);
});

View file

@ -0,0 +1,45 @@
// Abstract base for every observation provider attached to the bot.
class Observation {
    constructor(bot) {
        // Forbid direct instantiation; subclasses must implement observe().
        if (new.target === Observation) {
            throw new TypeError(
                "Cannot instantiate abstract class Observation"
            );
        }
        this.bot = bot;
        this.name = "Observation";
    }

    observe() {
        throw new TypeError("Method 'observe()' must be implemented.");
    }

    reset() {}
}

// Wire observation bookkeeping onto a bot instance.
function inject(bot, obs_list) {
    bot.obsList = obs_list.map((ObsCls) => new ObsCls(bot));
    bot.cumulativeObs = [];
    bot.eventMemory = {};

    // Record one named event with a snapshot from every relevant observer.
    bot.event = function (event_name) {
        const snapshot = {};
        for (const observer of bot.obsList) {
            // "on*" observers only fire for their own event name.
            if (observer.name.startsWith("on") && observer.name !== event_name) {
                continue;
            }
            snapshot[observer.name] = observer.observe();
        }
        bot.cumulativeObs.push([event_name, snapshot]);
    };

    // Flush accumulated events as a JSON string (clears the buffer).
    bot.observe = function () {
        bot.event("observe");
        const collected = bot.cumulativeObs;
        bot.cumulativeObs = [];
        return JSON.stringify(collected);
    };
}

module.exports = { Observation, inject };

View file

@ -0,0 +1,31 @@
const { Observation } = require("./base");
// Tracks the contents of nearby chests as the bot discovers them.
class Chests extends Observation {
    constructor(bot) {
        super(bot);
        this.name = "nearbyChests";
        this.chestsItems = {};
        // Remember contents whenever a chest is closed; mark broken chests.
        bot.on("closeChest", (chestItems, position) => {
            this.chestsItems[position] = chestItems;
        });
        bot.on("removeChest", (chestPosition) => {
            this.chestsItems[chestPosition] = "Invalid";
        });
    }

    observe() {
        // Any chest within 16 blocks that was never opened is "Unknown".
        const nearby = this.bot.findBlocks({
            matching: this.bot.registry.blocksByName.chest.id,
            maxDistance: 16,
            count: 999,
        });
        for (const chest of nearby) {
            if (!this.chestsItems.hasOwnProperty(chest)) {
                this.chestsItems[chest] = "Unknown";
            }
        }
        return this.chestsItems;
    }
}

module.exports = Chests;

View file

@ -0,0 +1,39 @@
const { Observation } = require("./base");
// Reports the bot's inventory as a {item name: total count} mapping.
class Inventory extends Observation {
    constructor(bot) {
        super(bot);
        this.name = "inventory";
    }

    observe() {
        return listItems(this.bot);
    }
}

// Aggregate inventory item stacks into a {name: total count} dictionary.
function listItems(bot) {
    return getInventoryItems(bot).reduce(itemToDict, {});
}

// Prefer the currently open window so chest views are included.
function getInventoryItems(bot) {
    return (bot.currentWindow || bot.inventory).items();
}

// Reducer: fold one item stack into the name -> count accumulator.
function itemToDict(acc, cur) {
    if (cur.name && cur.count) {
        acc[cur.name] = (acc[cur.name] || 0) + cur.count;
    }
    return acc;
}

//export modules
module.exports = Inventory;

View file

@ -0,0 +1,26 @@
const Observation = require("./base.js").Observation;
// Buffers chat lines the bot sees between observations.
class onChat extends Observation {
    constructor(bot) {
        super(bot);
        this.name = "onChat";
        this.obs = "";
        bot.on("chatEvent", (username, message) => {
            // Ignore slash commands; accumulate everything else.
            if (message.startsWith("/")) {
                return;
            }
            this.obs += message;
            this.bot.event(this.name);
        });
    }

    observe() {
        // Reading the buffer clears it.
        const buffered = this.obs;
        this.obs = "";
        return buffered;
    }
}

module.exports = onChat;

View file

@ -0,0 +1,22 @@
const Observation = require("./base.js").Observation;
// Captures the most recent bot error between observations.
class onError extends Observation {
    constructor(bot) {
        super(bot);
        this.name = "onError";
        this.obs = null;
        bot.on("error", (err) => {
            // Save the error and record an onError event immediately.
            this.obs = err;
            this.bot.event(this.name);
        });
    }

    observe() {
        // Reading the stored error clears it.
        const latest = this.obs;
        this.obs = null;
        return latest;
    }
}

module.exports = onError;

View file

@ -0,0 +1,22 @@
const Observation = require("./base.js").Observation;
// Captures the most recent "save" marker emitted by skill code.
class onSave extends Observation {
    constructor(bot) {
        super(bot);
        this.name = "onSave";
        this.obs = null;
        bot.on("save", (eventName) => {
            // Store the marker and record an onSave event immediately.
            this.obs = eventName;
            this.bot.event(this.name);
        });
    }

    observe() {
        // Reading the stored marker clears it.
        const latest = this.obs;
        this.obs = null;
        return latest;
    }
}

module.exports = onSave;

View file

@ -0,0 +1,103 @@
const Observation = require("./base.js").Observation;
// Snapshot of the bot's vitals, pose, surroundings, and nearby mobs.
class Status extends Observation {
    constructor(bot) {
        super(bot);
        this.name = "status";
    }

    observe() {
        return {
            health: this.bot.health,
            food: this.bot.food,
            saturation: this.bot.foodSaturation,
            oxygen: this.bot.oxygenLevel,
            position: this.bot.entity.position,
            velocity: this.bot.entity.velocity,
            yaw: this.bot.entity.yaw,
            pitch: this.bot.entity.pitch,
            onGround: this.bot.entity.onGround,
            equipment: this.getEquipment(),
            name: this.bot.entity.username,
            timeSinceOnGround: this.bot.entity.timeSinceOnGround,
            isInWater: this.bot.entity.isInWater,
            isInLava: this.bot.entity.isInLava,
            isInWeb: this.bot.entity.isInWeb,
            isCollidedHorizontally: this.bot.entity.isCollidedHorizontally,
            isCollidedVertically: this.bot.entity.isCollidedVertically,
            // Biome of the block the bot is standing in, if loaded.
            biome: this.bot.blockAt(this.bot.entity.position)
                ? this.bot.blockAt(this.bot.entity.position).biome.name
                : "None",
            entities: this.getEntities(),
            timeOfDay: this.getTime(),
            inventoryUsed: this.bot.inventoryUsed(),
            elapsedTime: this.bot.globalTickCounter,
        };
    }

    itemToObs(item) {
        if (!item) return null;
        return item.name;
    }

    // Bucket the raw in-game clock (0..24000 ticks) into a named time of day.
    getTime() {
        const timeOfDay = this.bot.time.timeOfDay;
        let time = "";
        if (timeOfDay < 1000) {
            time = "sunrise";
        } else if (timeOfDay < 6000) {
            time = "day";
        } else if (timeOfDay < 12000) {
            time = "noon";
        } else if (timeOfDay < 13000) {
            time = "sunset";
        } else if (timeOfDay < 18000) {
            time = "night";
        } else if (timeOfDay < 22000) {
            time = "midnight";
        } else {
            time = "sunrise";
        }
        return time;
    }

    // For each item in equipment, if it exists, return the name of the item
    // otherwise return null
    getEquipment() {
        // Slots 5..8 are the armor slots; slot 45 is the off-hand.
        const slots = this.bot.inventory.slots;
        const mainHand = this.bot.heldItem;
        return slots
            .slice(5, 9)
            .concat(mainHand, slots[45])
            .map(this.itemToObs);
    }

    // Distance to the nearest instance of each named mob within 32 blocks.
    getEntities() {
        const entities = this.bot.entities;
        if (!entities) return {};
        // keep all monsters in one list, keep other mobs in another list
        const mobs = {};
        for (const id in entities) {
            const entity = entities[id];
            if (!entity.displayName) continue;
            if (entity.name === "player" || entity.name === "item") continue;
            if (entity.position.distanceTo(this.bot.entity.position) < 32) {
                if (!mobs[entity.name]) {
                    mobs[entity.name] = entity.position.distanceTo(
                        this.bot.entity.position
                    );
                } else if (
                    mobs[entity.name] >
                    entity.position.distanceTo(this.bot.entity.position)
                ) {
                    mobs[entity.name] = entity.position.distanceTo(
                        this.bot.entity.position
                    );
                }
            }
        }
        return mobs;
    }
}

module.exports = Status;

View file

@ -0,0 +1,67 @@
// Blocks = require("./blocks")
const { Observation } = require("./base");
// Reports distinct block names within an 8x2x8 box around the bot.
class Voxels extends Observation {
    constructor(bot) {
        super(bot);
        this.name = "voxels";
    }

    observe() {
        return Array.from(getSurroundingBlocks(this.bot, 8, 2, 8));
    }
}
// Periodically records block names seen nearby that are not inventory items.
class BlockRecords extends Observation {
    constructor(bot) {
        super(bot);
        this.name = "blockRecords";
        this.records = new Set();
        this.tick = 0;
        // Sample the surroundings every 100 physics ticks.
        bot.on("physicsTick", () => {
            this.tick++;
            if (this.tick >= 100) {
                const held = getInventoryItems(this.bot);
                for (const blockName of getSurroundingBlocks(this.bot, 8, 2, 8)) {
                    if (!held.has(blockName)) this.records.add(blockName);
                }
                this.tick = 0;
            }
        });
    }

    observe() {
        return Array.from(this.records);
    }

    reset() {
        this.records = new Set();
    }
}
// Collect names of all non-air blocks within a box around the bot.
function getSurroundingBlocks(bot, x_distance, y_distance, z_distance) {
    const names = new Set();
    for (let dx = -x_distance; dx <= x_distance; dx++) {
        for (let dy = -y_distance; dy <= y_distance; dy++) {
            for (let dz = -z_distance; dz <= z_distance; dz++) {
                const block = bot.blockAt(bot.entity.position.offset(dx, dy, dz));
                // type 0 is air; also skip unloaded (null) blocks.
                if (block && block.type !== 0) {
                    names.add(block.name);
                }
            }
        }
    }
    // console.log(surroundingBlocks);
    return names;
}
// Set of item names currently held in the bot's inventory.
function getInventoryItems(bot) {
    const names = new Set();
    for (const item of bot.inventory.items()) {
        if (item) names.add(item.name);
    }
    return names;
}

module.exports = { Voxels, BlockRecords };

View file

@ -0,0 +1,79 @@
// Patch higher-level behaviour onto the raw mineflayer bot: throttled
// sleep/fish/consume/use wrappers, chat echoing, and inventory helpers.
// Originals are preserved on `_`-prefixed names.
function inject(bot) {
    bot._sleep = bot.sleep;
    // Sleeping needs generous tick padding before and after to settle.
    bot.sleep = async (bedBlock) => {
        await bot.waitForTicks(20);
        await bot._sleep(bedBlock);
        await bot.waitForTicks(135);
    };

    bot._fish = bot.fish;
    bot.fish = async () => {
        if (bot.heldItem?.name !== "fishing_rod") {
            bot.chat("I'm not holding a fishing rod!");
            return;
        }
        let timeout = null;
        // Abort the cast after 60s; re-activating the rod reels it back in.
        await Promise.race([
            bot._fish(),
            new Promise(
                (resolve, reject) =>
                    (timeout = setTimeout(() => {
                        bot.activateItem();
                        reject(
                            new Error(
                                // was "Finishing timeout" — this is the fishing wrapper
                                "Fishing timeout, make sure you get to and look at a water block!"
                            )
                        );
                    }, 60000))
            ),
        ]);
        clearTimeout(timeout);
        await bot.waitForTicks(20);
    };

    bot._consume = bot.consume;
    bot.consume = async () => {
        // action_count.activateItem++;
        await bot._consume();
        await bot.waitForTicks(20);
    };

    bot._useOn = bot.useOn;
    bot.useOn = async (entity) => {
        if (entity.position.distanceTo(bot.entity.position) > 6) {
            bot.chat("Please goto a place near the entity first!");
            return;
        }
        await bot._useOn(entity);
        await bot.waitForTicks(20);
    };

    bot._activateBlock = bot.activateBlock;
    bot.activateBlock = async (block) => {
        if (block.position.distanceTo(bot.entity.position) > 6) {
            bot.chat("Please goto a place near the block first!");
            return;
        }
        // action_count.activateBlock++;
        await bot._activateBlock(block);
    };

    bot._chat = bot.chat;
    // Echo chat through the observation pipeline before sending it.
    bot.chat = (message) => {
        // action_count.chat++;
        bot.emit("chatEvent", "bot", message);
        bot._chat(message);
    };

    // Count occupied main-inventory slots (9..44, excludes armor/crafting).
    bot.inventoryUsed = () => {
        return bot.inventory.slots.slice(9, 45).filter((item) => item !== null)
            .length;
    };

    bot.save = function (eventName) {
        bot.emit("save", eventName);
    };
}
// export all control_primitives
module.exports = { inject };

View file

@ -0,0 +1,31 @@
let gameTimeCounter = 0;
let gameTimeList = [];
const initCounter = (bot) => {
gameTimeList = [];
for (let i = 0; i < 13000; i += 1000) {
gameTimeList.push(i);
}
for (let i = 13000; i < 24000; i += 2000) {
gameTimeList.push(i);
}
const timeOfDay = bot.time.timeOfDay;
for (let i = 0; i < gameTimeList.length; i++) {
if (gameTimeList[i] > timeOfDay) {
gameTimeCounter = i - 1;
break;
}
}
};
const getNextTime = () => {
gameTimeCounter++;
if (gameTimeCounter >= gameTimeList.length) {
gameTimeCounter = 0;
}
return gameTimeList[gameTimeCounter];
};
module.exports = {
initCounter,
getNextTime,
};

View file

@ -0,0 +1,107 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# TypeScript v1 declaration files
typings/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variables file
.env
.env.test
# parcel-bundler cache (https://parceljs.org/)
.cache
# Next.js build output
.next
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and *not* Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
lib/
package-lock.json

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020 TheDudeFromCI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,89 @@
<h1 align="center">mineflayer-collectblock</h1>
<p align="center"><i>A small utility plugin for allowing users to collect blocks using a higher level API.</i></p>
<p align="center">
<img src="https://github.com/TheDudeFromCI/mineflayer-collectblock/workflows/Build/badge.svg" />
<a href="https://www.npmjs.com/package/mineflayer-collectblock"><img src="https://img.shields.io/npm/v/mineflayer-collectblock" /></a>
<img src="https://img.shields.io/github/repo-size/TheDudeFromCI/mineflayer-collectblock" />
<img src="https://img.shields.io/npm/dm/mineflayer-collectblock" />
<img src="https://img.shields.io/github/contributors/TheDudeFromCI/mineflayer-collectblock" />
<img src="https://img.shields.io/github/license/TheDudeFromCI/mineflayer-collectblock" />
</p>
---
## This is a modified version to better support Voyager
## Showcase
You can see a video of the plugin in action, [here.](https://youtu.be/5T_rcCnNnf4)
The source code of the bot in the video can be seen in the examples folder, [here.](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/examples/collector.js)
### Description
This plugin is a wrapper for mineflayer that allows for easier API usage when collecting blocks or item drops. This plugin is designed to reduce some of the boilerplate code based around the act of pathfinding to a block _(handled by_ ***mineflayer-pathfinder***_)_, selecting the best tool to mine that block _(handled by_ ***mineflayer-tool***_)_, actually mining it, then moving to collect the item drops from that block. This plugin allows for all of that basic concept to be wrapped up into a single API function.
In addition to the usage above, some additional quality of life features are available in this plugin. These include the ability to automatically deposit items into a chest when the bot's inventory is full, collecting new tools from a chest if the bot doesn't currently have a required tool _(also handled by_ ***mineflayer-tool***_)_, and allowing for queueing of multiple blocks or item drops to the collection task, so they can be processed later.
### Getting Started
This plugin is built using Node and can be installed using:
```bash
npm install --save mineflayer-collectblock
```
### Simple Bot
A minimal example bot that repeatedly finds and collects nearby grass blocks.
```js
// Create your bot
const mineflayer = require("mineflayer")
const bot = mineflayer.createBot({
host: 'localhost',
username: 'Player',
})
let mcData
// Load collect block
bot.loadPlugin(require('mineflayer-collectblock').plugin)
async function collectGrass() {
// Find a nearby grass block
const grass = bot.findBlock({
matching: mcData.blocksByName.grass_block.id,
maxDistance: 64
})
if (grass) {
// If we found one, collect it.
try {
await bot.collectBlock.collect(grass)
collectGrass() // Collect another grass block
} catch (err) {
console.log(err) // Handle errors, if any
}
}
}
// On spawn, start collecting all nearby grass
bot.once('spawn', () => {
mcData = require('minecraft-data')(bot.version)
collectGrass()
})
```
### Documentation
[API](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/docs/api.md)
[Examples](https://github.com/TheDudeFromCI/mineflayer-collectblock/tree/master/examples)
### License
This project uses the [MIT](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/LICENSE) license.
### Contributions
This project is accepting PRs and Issues. See something you think can be improved? Go for it! Any and all help is highly appreciated!
For larger changes, it is recommended to discuss these changes in the issues tab before writing any code. It's also preferred to make many smaller PRs than one large one, where applicable.

View file

@ -0,0 +1 @@
theme: jekyll-theme-cayman

View file

@ -0,0 +1,52 @@
# API <!-- omit in toc -->
Welcome to the *mineflayer-collectblock* API documentation page.
## Table of Contents <!-- omit in toc -->
- [1. Summary](#1-summary)
- [Properties](#properties)
- [`bot.collectblock.movements: Movements`](#botcollectblockmovements-movements)
- [Functions](#functions)
- [collect](#collect)
- [Options:](#options)
## 1. Summary
The collect block plugin is a utility plugin that can be used to help make collecting blocks and item drops very easy, using only a single API call. No need to worry about pathfinding to the block, selecting the right tool, or moving to pick up the item drop after mining.
## Properties
### `bot.collectblock.movements: Movements`
The movements object used by the pathfinder plugin to define the movement configuration. This object is passed to the pathfinder plugin when any API from this plugin is called in order to control how pathfinding should work when collecting the given blocks or item.
If set to null, the pathfinder plugin's movements are not updated.
Defaults to a new movements object instance.
## Functions
### collect
Usage: `bot.collectblock.collect(target: Collectable | Collectable[], options?: CollectOptions, cb: (err?: Error) => void): void`
Causes the bot to collect the given block, item drop, or list of those. If the target is a block, the bot will move to the block, mine it, and pick up the item drop. If the target is an item drop, the bot will move to the item drop and pick it up. If the target is a list of collectables, the bot will move from target to target in order of closest to furthest and collect each target in turn.
#### Options:
* `append: boolean`
If true, the target(s) will be appended to the existing target list instead of starting a new task. Defaults to false.
* `ignoreNoPath: boolean`
If true, errors will not be thrown when a path to the target block cannot be found. The bot will attempt to choose the best available position it can find, instead. Errors are still thrown if the bot cannot interact with the block from its final location. Defaults to false.
* `chestLocations: Vec3[]`
Gets the list of chest locations to use when storing items after the bot's inventory becomes full. If undefined, it defaults to the chest location list on the bot.collectBlock plugin.
* `itemFilter: ItemFilter`
When transferring items to a chest, this filter is used to determine what items are allowed to be moved, and what items aren't allowed to be moved. Defaults to the item filter specified on the bot.collectBlock plugin.

View file

@ -0,0 +1,70 @@
/**
 * This bot example shows how to direct a bot to collect a specific block type
 * or a group of nearby blocks of that type.
 *
 * Chat usage: "collect <type>" or "collect <count> <type>".
 */
const mineflayer = require('mineflayer')
const collectBlock = require('mineflayer-collectblock').plugin
if (process.argv.length < 4 || process.argv.length > 6) {
  console.log('Usage : node collector.js <host> <port> [<name>] [<password>]')
  process.exit(1)
}
const bot = mineflayer.createBot({
  host: process.argv[2],
  port: process.argv[3],
  username: process.argv[4] || 'collector',
  password: process.argv[5]
})
bot.loadPlugin(collectBlock)
let mcData
bot.once('spawn', () => {
  mcData = require('minecraft-data')(bot.version)
})
bot.on('chat', async (username, message) => {
  const args = message.split(' ')
  if (args[0] !== 'collect') return
  // If a number was given, it is the 2nd arg and the block type is the 3rd.
  let count = 1
  if (args.length === 3) count = parseInt(args[1], 10)
  let type = args[1]
  if (args.length === 3) type = args[2]
  const blockType = mcData.blocksByName[type]
  if (!blockType) {
    // Tell the player instead of failing silently (matches the other examples).
    bot.chat(`I don't know any blocks named ${type}.`)
    return
  }
  const blocks = bot.findBlocks({
    matching: blockType.id,
    maxDistance: 64,
    count: count
  })
  if (blocks.length === 0) {
    bot.chat("I don't see that block nearby.")
    return
  }
  // Convert block positions into Block objects for the collect API.
  const targets = []
  for (let i = 0; i < Math.min(blocks.length, count); i++) {
    targets.push(bot.blockAt(blocks[i]))
  }
  bot.chat(`Found ${targets.length} ${type}(s)`)
  try {
    await bot.collectBlock.collect(targets)
    // All blocks have been collected.
    bot.chat('Done')
  } catch (err) {
    // An error occurred, report it.
    bot.chat(err.message)
    console.log(err)
  }
})

View file

@ -0,0 +1,59 @@
/**
 * This bot example shows how to collect a vein of ores quickly after only finding a single block.
 * This makes it easy to collect a vein of ores or mine a tree without looking for every block in the
 * area.
 *
 * Chat usage: "collect <blockName>".
 */
const mineflayer = require('mineflayer')
const collectBlock = require('mineflayer-collectblock').plugin
if (process.argv.length < 4 || process.argv.length > 6) {
  console.log('Usage : node oreMiner.js <host> <port> [<name>] [<password>]')
  process.exit(1)
}
const bot = mineflayer.createBot({
  host: process.argv[2],
  port: process.argv[3],
  username: process.argv[4] || 'oreMiner',
  password: process.argv[5]
})
bot.loadPlugin(collectBlock)
// Load mcData once the server version is known.
let mcData
bot.once('spawn', () => {
  mcData = require('minecraft-data')(bot.version)
})
bot.on('chat', async (username, message) => {
  const args = message.split(' ')
  if (args[0] !== 'collect') return
  const blockType = mcData.blocksByName[args[1]]
  if (!blockType) {
    bot.chat(`I don't know any blocks named ${args[1]}.`)
    return
  }
  // Find a single seed block of the requested type nearby.
  const block = bot.findBlock({
    matching: blockType.id,
    maxDistance: 64
  })
  if (!block) {
    bot.chat("I don't see that block nearby.")
    return
  }
  // Flood-fill from the seed block to grab the whole vein.
  const targets = bot.collectBlock.findFromVein(block)
  try {
    await bot.collectBlock.collect(targets)
    // All blocks have been collected.
    bot.chat('Done')
  } catch (err) {
    // An error occurred, report it.
    bot.chat(err.message)
    console.log(err)
  }
})

View file

@ -0,0 +1,107 @@
/**
 * This bot example shows how to use the chest filling mechanic of the plugin.
 * Simply provide a given storage chest, and the bot will automatically try and
 * store its inventory in that chest when the bot's inventory becomes full.
 *
 * Chat usage: "collect <type>" or "collect <count> <type>".
 */
if (process.argv.length < 4 || process.argv.length > 6) {
  console.log('Usage : node storageBot.js <host> <port> [<name>] [<password>]')
  process.exit(1)
}
// Load your libraries
const mineflayer = require('mineflayer')
const collectBlock = require('mineflayer-collectblock').plugin
// Create your bot
const bot = mineflayer.createBot({
  host: process.argv[2],
  port: parseInt(process.argv[3]),
  username: process.argv[4] ? process.argv[4] : 'storageBot',
  password: process.argv[5]
})
// Load the collect block plugin
bot.loadPlugin(collectBlock)
// Load mcData on login
let mcData
bot.once('login', () => {
  mcData = require('minecraft-data')(bot.version)
})
// On spawn, try to find any nearby chests and save those as storage locations.
// When the bot's inventory becomes too full, it will empty its inventory into
// these chests before collecting more resources. If a chest gets full, it moves
// to the next one in order until its inventory is empty or it runs out of chests.
bot.once('spawn', () => {
  bot.collectBlock.chestLocations = bot.findBlocks({
    matching: mcData.blocksByName.chest.id,
    maxDistance: 16,
    count: 999999 // Get as many chests as we can
  })
  if (bot.collectBlock.chestLocations.length === 0) {
    bot.chat("I don't see any chests nearby.")
  } else {
    for (const chestPos of bot.collectBlock.chestLocations) {
      bot.chat(`I found a chest at ${chestPos}`)
    }
  }
})
// Wait for someone to say something
bot.on('chat', async (username, message) => {
  // If the player says something that starts with "collect"
  // Otherwise, do nothing
  const args = message.split(' ')
  if (args[0] !== 'collect') return
  // If the player specifies a number, collect that many. Otherwise, default to 1.
  let count = 1
  if (args.length === 3) count = parseInt(args[1])
  // If a number was given the item number is the 3rd arg, not the 2nd.
  let type = args[1]
  if (args.length === 3) type = args[2]
  // Get the id of that block type for this version of Minecraft.
  const blockType = mcData.blocksByName[type]
  if (!blockType) {
    bot.chat(`I don't know any blocks named ${type}.`)
    return
  }
  // Find all nearby blocks of that type, up to the given count, within 64 blocks.
  const blocks = bot.findBlocks({
    matching: blockType.id,
    maxDistance: 64,
    count: count
  })
  // Complain if we can't find any nearby blocks of that type.
  if (blocks.length === 0) {
    bot.chat("I don't see that block nearby.")
    return
  }
  // Convert the block position array into a block array to pass to collect block.
  const targets = []
  for (let i = 0; i < Math.min(blocks.length, count); i++) {
    targets.push(bot.blockAt(blocks[i]))
  }
  // Announce what we found.
  bot.chat(`Found ${targets.length} ${type}(s)`)
  // Tell the bot to collect all of the given blocks in the block list.
  try {
    await bot.collectBlock.collect(targets)
    // All blocks have been collected.
    bot.chat('Done')
  } catch (err) {
    // An error occurred, report it.
    bot.chat(err.message)
    console.log(err)
  }
})

View file

@ -0,0 +1,44 @@
{
"name": "mineflayer-collectblock",
"version": "1.4.1",
  "description": "A simple utility plugin for Mineflayer that adds a higher level API for collecting blocks.",
"main": "lib/index.js",
"types": "lib/index.d.ts",
"scripts": {
"build": "ts-standard && tsc && require-self",
"clean": "rm -rf lib",
"test": "test"
},
"repository": {
"type": "git",
"url": "git+https://github.com/TheDudeFromCI/mineflayer-collectblock.git"
},
"keywords": [
"mineflayer",
"plugin",
"api",
"utility",
"helper",
"collect"
],
"author": "TheDudeFromCI",
"license": "MIT",
"bugs": {
"url": "https://github.com/TheDudeFromCI/mineflayer-collectblock/issues"
},
"homepage": "https://github.com/TheDudeFromCI/mineflayer-collectblock#readme",
"dependencies": {
"mineflayer": "^4.0.0",
"mineflayer-pathfinder": "^2.1.1",
"mineflayer-tool": "^1.1.0"
},
"devDependencies": {
"@types/node": "^18.6.4",
"require-self": "^0.2.3",
"ts-standard": "^11.0.0",
"typescript": "^4.1.3"
},
"files": [
"lib/**/*"
]
}

View file

@ -0,0 +1,35 @@
import { Bot } from 'mineflayer'
import { Block } from 'prismarine-block'
/**
 * Flood-fills outward from `block` and returns all touching blocks of the
 * same type (an ore vein, a tree, etc.).
 *
 * @param bot - The bot whose world is queried.
 * @param block - The starting block of the vein.
 * @param maxBlocks - Maximum number of blocks to collect before stopping.
 * @param maxDistance - Maximum manhattan distance from the starting block.
 * @param floodRadius - Neighbor offset radius considered "touching".
 *
 * @returns The blocks found, starting with `block` itself.
 */
export function findFromVein (bot: Bot, block: Block, maxBlocks: number, maxDistance: number, floodRadius: number): Block[] {
  const targets: Block[] = []
  const open: Block[] = [block]
  const type = block.type
  const center = block.position
  // Track visited positions by value: bot.blockAt() returns a fresh Block
  // object on each call, so reference-based includes() checks never match
  // and the same position could be queued (and returned) repeatedly.
  const seen = new Set<string>([center.toString()])
  for (let i = 0; i < maxBlocks; i++) {
    const next = open.pop()
    if (next == null) break
    targets.push(next)
    for (let x = -floodRadius; x <= floodRadius; x++) {
      for (let y = -floodRadius; y <= floodRadius; y++) {
        for (let z = -floodRadius; z <= floodRadius; z++) {
          const neighborPos = next.position.offset(x, y, z)
          if (neighborPos.manhattanDistanceTo(center) > maxDistance) continue
          const key = neighborPos.toString()
          if (seen.has(key)) continue
          const neighbor = bot.blockAt(neighborPos)
          if (neighbor == null || neighbor.type !== type) continue
          seen.add(key)
          open.push(neighbor)
        }
      }
    }
  }
  return targets
}

View file

@ -0,0 +1,451 @@
import { Bot } from "mineflayer";
import { Block } from "prismarine-block";
import { Movements, goals } from "mineflayer-pathfinder";
import { TemporarySubscriber } from "./TemporarySubscriber";
import { Entity } from "prismarine-entity";
import { error } from "./Util";
import { Vec3 } from "vec3";
import { emptyInventoryIfFull, ItemFilter } from "./Inventory";
import { findFromVein } from "./BlockVeins";
import { Collectable, Targets } from "./Targets";
import { Item } from "prismarine-item";
import mcDataLoader from "minecraft-data";
import { once } from "events";
import { callbackify } from "util";
export type Callback = (err?: Error) => void;
/**
 * Drains `options.targets`, repeatedly collecting the closest remaining
 * target (mining a Block, or chasing/picking up an Entity drop) until the
 * list is empty. Empties the inventory into chests whenever it fills up.
 * Stops counting successes once `options.count` blocks have been mined;
 * remaining targets are then removed without being collected.
 *
 * When `options.ignoreNoPath` is true, per-target errors are reported via
 * console/chat and the target is skipped; otherwise the error is rethrown.
 */
async function collectAll(
    bot: Bot,
    options: CollectOptionsFull
): Promise<void> {
    let success_count = 0;
    while (!options.targets.empty) {
        await emptyInventoryIfFull(
            bot,
            options.chestLocations,
            options.itemFilter
        );
        const closest = options.targets.getClosest();
        if (closest == null) break;
        // Dispatch on the runtime type of the target.
        switch (closest.constructor.name) {
            case "Block": {
                try {
                    // Already collected enough blocks; fall through so the
                    // target is still removed from the queue below.
                    if (success_count >= options.count) {
                        break;
                    }
                    await bot.tool.equipForBlock(
                        closest as Block,
                        equipToolOptions
                    );
                    const goal = new goals.GoalLookAtBlock(
                        closest.position,
                        bot.world
                    );
                    await bot.pathfinder.goto(goal);
                    await mineBlock(bot, closest as Block, options);
                    success_count++;
                    // TODO: options.ignoreNoPath
                } catch (err) {
                    // @ts-ignore
                    // console.log(err.stack)
                    // bot.pathfinder.stop()
                    // bot.waitForTicks(10)
                    // Best-effort: clear any stale pathfinder goal.
                    try {
                        bot.pathfinder.setGoal(null);
                    } catch (err) {}
                    if (options.ignoreNoPath) {
                        // Error names below are set by mineBlock() via error().
                        // @ts-ignore
                        if (err.name === "Invalid block") {
                            console.log(
                                `Block ${closest.name} at ${closest.position} is not valid! Skip it!`
                            );
                        } // @ts-ignore
                        else if (err.name === "Unsafe block") {
                            console.log(
                                `${closest.name} at ${closest.position} is not safe to break! Skip it!`
                            );
                            // @ts-ignore
                        } else if (err.name === "NoItem") {
                            const properties =
                                bot.registry.blocksByName[closest.name];
                            const leastTool = Object.keys(
                                properties.harvestTools
                            )[0];
                            const item = bot.registry.items[leastTool];
                            bot.chat(
                                `I need at least a ${item.name} to mine ${closest.name}! Skip it!`
                            );
                            // Missing tool is unrecoverable: abort the whole task.
                            return;
                        } else if (
                            // @ts-ignore
                            err.name === "NoPath" ||
                            // @ts-ignore
                            err.name === "Timeout"
                        ) {
                            // If we ended up standing on the block anyway, mine it.
                            if (
                                bot.entity.position.distanceTo(
                                    closest.position
                                ) < 0.5
                            ) {
                                await mineBlock(bot, closest as Block, options);
                                break;
                            }
                            console.log(
                                `No path to ${closest.name} at ${closest.position}! Skip it!`
                            );
                            // @ts-ignore
                        } else if (err.message === "Digging aborted") {
                            console.log(`Digging aborted! Skip it!`);
                        } else {
                            // @ts-ignore
                            bot.chat(`Error: ${err.message}`);
                        }
                        break;
                    }
                    throw err;
                }
                break;
            }
            case "Entity": {
                // Don't collect any entities that are marked as 'invalid'
                if (!(closest as Entity).isValid) break;
                try {
                    const tempEvents = new TemporarySubscriber(bot);
                    // Resolves when the entity despawns (picked up); rejects
                    // after a 10 second timeout.
                    const waitForPickup = new Promise<void>(
                        (resolve, reject) => {
                            const timeout = setTimeout(() => {
                                // After 10 seconds, reject the promise
                                clearTimeout(timeout);
                                tempEvents.cleanup();
                                reject(new Error("Failed to pickup item"));
                            }, 10000);
                            tempEvents.subscribeTo(
                                "entityGone",
                                (entity: Entity) => {
                                    if (entity === closest) {
                                        clearTimeout(timeout);
                                        tempEvents.cleanup();
                                        resolve();
                                    }
                                }
                            );
                        }
                    );
                    bot.pathfinder.setGoal(
                        new goals.GoalFollow(closest as Entity, 0)
                    );
                    // await bot.pathfinder.goto(new goals.GoalBlock(closest.position.x, closest.position.y, closest.position.z))
                    await waitForPickup;
                } catch (err) {
                    // @ts-ignore
                    console.log(err.stack);
                    try {
                        bot.pathfinder.setGoal(null);
                    } catch (err) {}
                    if (options.ignoreNoPath) {
                        // @ts-ignore
                        if (err.message === "Failed to pickup item") {
                            bot.chat(`Failed to pickup item! Skip it!`);
                        }
                        break;
                    }
                    throw err;
                }
                break;
            }
            default: {
                throw error(
                    "UnknownType",
                    `Target ${closest.constructor.name} is not a Block or Entity!`
                );
            }
        }
        // Remove the target whether it was collected or skipped.
        options.targets.removeTarget(closest);
    }
    bot.chat(`Collect finish!`);
}
// Shared options for mineflayer-tool's equipForBlock: require a tool that can
// actually harvest the block, never fetch tools from chests, and consider at
// most two candidate tools.
const equipToolOptions = {
    requireHarvest: true,
    getFromChest: false,
    maxTools: 2,
};
/**
 * Equips a suitable tool, digs `block`, and queues any item drops that land
 * on the block's position as new collection targets.
 *
 * Throws errors with names "Invalid block" (block changed/air), "Unsafe block"
 * (movements.safeToBreak rejected it), or "NoItem" (no harvesting tool held);
 * collectAll() dispatches on these names. In each of those cases the block is
 * removed from the target list before throwing.
 */
async function mineBlock(
    bot: Bot,
    block: Block,
    options: CollectOptionsFull
): Promise<void> {
    // Re-check the world: the block may have changed since it was queued.
    if (
        bot.blockAt(block.position)?.type !== block.type ||
        bot.blockAt(block.position)?.type === 0
    ) {
        options.targets.removeTarget(block);
        throw error("Invalid block", "Block is not valid!");
        // @ts-expect-error
    } else if (!bot.pathfinder.movements.safeToBreak(block)) {
        options.targets.removeTarget(block);
        throw error("Unsafe block", "Block is not safe to break!");
    }
    await bot.tool.equipForBlock(block, equipToolOptions);
    if (!block.canHarvest(bot.heldItem ? bot.heldItem.type : bot.heldItem)) {
        options.targets.removeTarget(block);
        throw error("NoItem", "Bot does not have a harvestable tool!");
    }
    // Queue item drops that appear within the broken block's cell.
    const tempEvents = new TemporarySubscriber(bot);
    tempEvents.subscribeTo("itemDrop", (entity: Entity) => {
        if (
            entity.position.distanceTo(block.position.offset(0.5, 0.5, 0.5)) <=
            0.5
        ) {
            options.targets.appendTarget(entity);
        }
    });
    try {
        await bot.dig(block);
        // Waiting for items to drop
        await new Promise<void>((resolve) => {
            let remainingTicks = 10;
            tempEvents.subscribeTo("physicTick", () => {
                remainingTicks--;
                if (remainingTicks <= 0) {
                    tempEvents.cleanup();
                    resolve();
                }
            });
        });
    } finally {
        tempEvents.cleanup();
    }
}
/**
 * A set of options to apply when collecting the given targets.
 */
export interface CollectOptions {
    /**
     * If true, the target(s) will be appended to the existing target list instead of
     * starting a new task. Defaults to false.
     */
    append?: boolean;
    /**
     * If true, errors will not be thrown when a path to the target block cannot
     * be found. The bot will attempt to choose the best available position it
     * can find, instead. Errors are still thrown if the bot cannot interact with
     * the block from it's final location. Defaults to false.
     */
    ignoreNoPath?: boolean;
    /**
     * Gets the list of chest locations to use when storing items after the bot's
     * inventory becomes full. If undefined, it defaults to the chest location
     * list on the bot.collectBlock plugin.
     */
    chestLocations?: Vec3[];
    /**
     * When transferring items to a chest, this filter is used to determine what
     * items are allowed to be moved, and what items aren't allowed to be moved.
     * Defaults to the item filter specified on the bot.collectBlock plugin.
     */
    itemFilter?: ItemFilter;
    /**
     * The total number of items to collect
     */
    count?: number;
}
/**
 * A version of collect options where all values are assigned.
 */
interface CollectOptionsFull {
    append: boolean;
    ignoreNoPath: boolean;
    chestLocations: Vec3[];
    itemFilter: ItemFilter;
    // The shared target queue for the active task.
    targets: Targets;
    // Number of blocks to mine; Infinity when the caller gave no count.
    count: number;
}
/**
 * The collect block plugin.
 */
export class CollectBlock {
    /**
     * The bot.
     */
    private readonly bot: Bot;
    /**
     * The list of active targets being collected.
     */
    private readonly targets: Targets;
    /**
     * The movements configuration to be sent to the pathfinder plugin.
     */
    movements?: Movements;
    /**
     * A list of chest locations which the bot is allowed to empty their inventory into
     * if it becomes full while the bot is collecting resources.
     */
    chestLocations: Vec3[] = [];
    /**
     * When collecting items, this filter is used to determine what items should be placed
     * into a chest if the bot's inventory becomes full. By default, returns true for all
     * items except for tools, weapons, and armor.
     *
     * @param item - The item stack in the bot's inventory to check.
     *
     * @returns True if the item should be moved into the chest. False otherwise.
     */
    itemFilter: ItemFilter = (item: Item) => {
        if (item.name.includes("helmet")) return false;
        if (item.name.includes("chestplate")) return false;
        if (item.name.includes("leggings")) return false;
        if (item.name.includes("boots")) return false;
        if (item.name.includes("shield")) return false;
        if (item.name.includes("sword")) return false;
        if (item.name.includes("pickaxe")) return false;
        if (item.name.includes("axe")) return false;
        if (item.name.includes("shovel")) return false;
        if (item.name.includes("hoe")) return false;
        return true;
    };
    /**
     * Creates a new instance of the create block plugin.
     *
     * @param bot - The bot this plugin is acting on.
     */
    constructor(bot: Bot) {
        this.bot = bot;
        this.targets = new Targets(bot);
        // @ts-ignore
        this.movements = new Movements(bot, mcDataLoader(bot.version));
    }
    /**
     * If target is a block:
     * Causes the bot to break and collect the target block.
     *
     * If target is an item drop:
     * Causes the bot to collect the item drop.
     *
     * If target is an array containing items or blocks, preforms the correct action for
     * all targets in that array sorting dynamically by distance.
     *
     * @param target - The block(s) or item(s) to collect.
     * @param options - The set of options to use when handling these targets
     * @param cb - The callback that is called finished.
     */
    async collect(
        target: Collectable | Collectable[],
        options: CollectOptions | Callback = {},
        cb?: Callback
    ): Promise<void> {
        // Support the legacy callback form: collect(target, cb).
        if (typeof options === "function") {
            cb = options;
            options = {};
        }
        // @ts-expect-error
        if (cb != null) return callbackify(this.collect)(target, options, cb);
        // Fill in unset options from the plugin-level defaults.
        const optionsFull: CollectOptionsFull = {
            append: options.append ?? false,
            ignoreNoPath: options.ignoreNoPath ?? false,
            chestLocations: options.chestLocations ?? this.chestLocations,
            itemFilter: options.itemFilter ?? this.itemFilter,
            targets: this.targets,
            count: options.count ?? Infinity,
        };
        if (this.bot.pathfinder == null) {
            throw error(
                "UnresolvedDependency",
                "The mineflayer-collectblock plugin relies on the mineflayer-pathfinder plugin to run!"
            );
        }
        if (this.bot.tool == null) {
            throw error(
                "UnresolvedDependency",
                "The mineflayer-collectblock plugin relies on the mineflayer-tool plugin to run!"
            );
        }
        if (this.movements != null) {
            this.bot.pathfinder.setMovements(this.movements);
        }
        // Unless appending, stop any in-flight task before queuing new targets.
        if (!optionsFull.append) await this.cancelTask();
        if (Array.isArray(target)) {
            this.targets.appendTargets(target);
        } else {
            this.targets.appendTarget(target);
        }
        try {
            await collectAll(this.bot, optionsFull);
            this.targets.clear();
        } catch (err) {
            this.targets.clear();
            // Ignore path stopped error for cancelTask to work properly (imo we shouldn't throw any pathing errors)
            // @ts-expect-error
            if (err.name !== "PathStopped") throw err;
        } finally {
            // @ts-expect-error
            this.bot.emit("collectBlock_finished");
        }
    }
    /**
     * Loads all touching blocks of the same type to the given block and returns them as an array.
     * This effectively acts as a flood fill algorithm to retrieve blocks in the same ore vein and similar.
     *
     * @param block - The starting block.
     * @param maxBlocks - The maximum number of blocks to look for before stopping.
     * @param maxDistance - The max distance from the starting block to look.
     * @param floodRadius - The max distance distance from block A to block B to be considered "touching"
     */
    findFromVein(
        block: Block,
        maxBlocks = 100,
        maxDistance = 16,
        floodRadius = 1
    ): Block[] {
        return findFromVein(
            this.bot,
            block,
            maxBlocks,
            maxDistance,
            floodRadius
        );
    }
    /**
     * Cancels the current collection task, if still active.
     *
     * @param cb - The callback to use when the task is stopped.
     */
    async cancelTask(cb?: Callback): Promise<void> {
        // No active task: resolve (and call back) immediately.
        if (this.targets.empty) {
            if (cb != null) cb();
            return await Promise.resolve();
        }
        // Stopping the pathfinder makes collect() fail with "PathStopped",
        // which ultimately emits "collectBlock_finished" (awaited below).
        this.bot.pathfinder.stop();
        if (cb != null) {
            // @ts-expect-error
            this.bot.once("collectBlock_finished", cb);
        }
        await once(this.bot, "collectBlock_finished");
    }
}

View file

@ -0,0 +1,87 @@
import { Bot } from 'mineflayer'
import { Callback } from './CollectBlock'
import { Vec3 } from 'vec3'
import { error } from './Util'
import { Item } from 'prismarine-item'
import { goals } from 'mineflayer-pathfinder'
import { callbackify } from 'util'
export type ItemFilter = (item: Item) => boolean
/**
 * Picks the chest location nearest to the bot and removes it from the
 * candidate list, so repeated calls walk the chests in order of distance.
 *
 * @returns The closest chest position, or null if the list is empty.
 */
function getClosestChest (bot: Bot, chestLocations: Vec3[]): Vec3 | null {
  let best: Vec3 | null = null
  let bestDist = Infinity
  for (const candidate of chestLocations) {
    const dist = candidate.distanceTo(bot.entity.position)
    if (best == null || dist < bestDist) {
      best = candidate
      bestDist = dist
    }
  }
  if (best != null) {
    chestLocations.splice(chestLocations.indexOf(best), 1)
  }
  return best
}
export async function emptyInventoryIfFull (bot: Bot, chestLocations: Vec3[], itemFilter: ItemFilter, cb?: Callback): Promise<void> {
// @ts-expect-error
if (cb != null) return callbackify(emptyInventoryIfFull)(bot, chestLocations, cb)
if (bot.inventory.emptySlotCount() > 0) return
return await emptyInventory(bot, chestLocations, itemFilter)
}
export async function emptyInventory (bot: Bot, chestLocations: Vec3[], itemFilter: ItemFilter, cb?: Callback): Promise<void> {
// @ts-expect-error
if (cb != null) return callbackify(emptyInventory)(bot, chestLocations, cb)
if (chestLocations.length === 0) {
throw error('NoChests', 'There are no defined chest locations!')
}
// Shallow clone so we can safely remove chests from the list that are full.
chestLocations = [...chestLocations]
while (true) {
const chest = getClosestChest(bot, chestLocations)
if (chest == null) {
throw error('NoChests', 'All chests are full.')
}
const hasRemaining = await tryEmptyInventory(bot, chest, itemFilter)
if (!hasRemaining) return
}
}
/**
 * Walks to the given chest and deposits filter-approved items into it.
 *
 * @returns True if some items did not fit and remain in the inventory.
 */
async function tryEmptyInventory (bot: Bot, chestLocation: Vec3, itemFilter: ItemFilter, cb?: (err: Error | undefined, hasRemaining: boolean) => void): Promise<boolean> {
  // @ts-expect-error
  if (cb != null) return callbackify(tryEmptyInventory)(bot, chestLocation, itemFilter, cb)
  await gotoChest(bot, chestLocation)
  return await placeItems(bot, chestLocation, itemFilter)
}
async function gotoChest (bot: Bot, location: Vec3, cb?: Callback): Promise<void> {
// @ts-expect-error
if (cb != null) return callbackify(gotoChest)(bot, location)
await bot.pathfinder.goto(new goals.GoalGetToBlock(location.x, location.y, location.z))
}
/**
 * Opens the chest at `chestPos` and deposits every filter-approved inventory
 * item until the chest runs out of empty slots.
 *
 * @returns True if some items did not fit and remain in the inventory.
 *
 * @throws UnloadedChunk - When the chest block is not loaded.
 */
async function placeItems (bot: Bot, chestPos: Vec3, itemFilter: ItemFilter, cb?: (err: Error | undefined, hasRemaining: boolean) => void): Promise<boolean> {
  // @ts-expect-error
  if (cb != null) return callbackify(placeItems)(bot, chestPos, itemFilter, cb)
  const chestBlock = bot.blockAt(chestPos)
  if (chestBlock == null) {
    throw error('UnloadedChunk', 'Chest is in an unloaded chunk!')
  }
  const chest = await bot.openChest(chestBlock)
  for (const item of bot.inventory.items()) {
    if (!itemFilter(item)) continue
    if (chest.firstEmptyContainerSlot() === null) {
      // We have items that didn't fit.
      return true
    }
    await chest.deposit(item.type, item.metadata, item.count)
  }
  return false
}

View file

@ -0,0 +1,60 @@
import { Bot } from 'mineflayer'
import { Block } from 'prismarine-block'
import { Entity } from 'prismarine-entity'
export type Collectable = Block | Entity
/**
 * An ordered, duplicate-free queue of blocks and item drops awaiting
 * collection, with distance-based lookup relative to the bot.
 */
export class Targets {
  private readonly bot: Bot
  private targets: Collectable[] = []

  constructor (bot: Bot) {
    this.bot = bot
  }

  /** Appends every entry of the given list, skipping duplicates. */
  appendTargets (targets: Collectable[]): void {
    targets.forEach((candidate) => this.appendTarget(candidate))
  }

  /** Appends a single target unless it is already queued. */
  appendTarget (target: Collectable): void {
    if (!this.targets.includes(target)) {
      this.targets.push(target)
    }
  }

  /**
   * Gets the closest target to the bot in this list.
   *
   * @returns The closest target, or null if there are no targets.
   */
  getClosest (): Collectable | null {
    let best: Collectable | null = null
    let bestDist = Infinity
    for (const candidate of this.targets) {
      const dist = candidate.position.distanceTo(this.bot.entity.position)
      if (best == null || dist < bestDist) {
        best = candidate
        bestDist = dist
      }
    }
    return best
  }

  /** True when no targets remain. */
  get empty (): boolean {
    return this.targets.length === 0
  }

  /** Drops all queued targets. */
  clear (): void {
    this.targets = []
  }

  /** Removes the given target if present; no-op otherwise. */
  removeTarget (target: Collectable): void {
    const index = this.targets.indexOf(target)
    if (index >= 0) {
      this.targets.splice(index, 1)
    }
  }
}

View file

@ -0,0 +1,77 @@
import type { Callback } from './index'
export type Task = (cb: Callback) => void
export type SyncTask = () => void
/**
* A simple utility class for queuing up a series of async tasks to execute.
*/
/**
 * A simple utility class for queuing up a series of async tasks to execute.
 */
export class TaskQueue {
  // Tasks queued since the last runAll(); drained when runAll() starts.
  private tasks: Task[] = []
  /**
   * If true, the task list will stop executing if one of the tasks throws an error.
   */
  readonly stopOnError: boolean = true
  /**
   * Adds a new async task to this queue. The provided callback should be executed when
   * the async task is complete.
   *
   * @param task - The async task to add.
   */
  add (task: Task): void {
    this.tasks.push(task)
  }
  /**
   * Adds a synchronous task to this queue.
   *
   * @param task - The sync task to add.
   */
  addSync (task: SyncTask): void {
    this.add((cb) => {
      try {
        task()
        cb()
      } catch (err: any) {
        cb(err)
      }
    })
  }
  /**
   * Runs all tasks currently in this queue and empties the queue.
   *
   * @param cb - The optional callback to be executed when all tasks in this queue have
   * finished executing.
   */
  runAll (cb?: Callback): void {
    // Snapshot and reset so tasks added while running go into the next batch.
    const taskList = this.tasks
    this.tasks = []
    let index = -1
    const runNext: () => void = () => {
      index++
      if (index >= taskList.length) {
        if (cb !== undefined) cb()
        return
      }
      try {
        taskList[index]((err) => {
          if (err !== undefined) {
            if (cb !== undefined) cb(err)
            // NOTE(review): when stopOnError is false the run continues after
            // cb(err), so cb may fire again on completion — but stopOnError is
            // readonly and always true, so that path is currently unreachable.
            if (this.stopOnError) return
          }
          runNext()
        })
      } catch (err: any) {
        if (cb !== undefined) cb(err)
      }
    }
    runNext()
  }
}

View file

@ -0,0 +1,34 @@
import { Bot } from 'mineflayer'
/**
 * A single event-listener registration: the event name and its callback.
 */
class Subscription {
  constructor (readonly eventName: string, readonly callback: Function) {}
}
/**
 * Tracks event listeners attached to a bot so that they can all be
 * removed again in a single call.
 */
export class TemporarySubscriber {
  private readonly subscriptions: Subscription[] = []

  constructor (readonly bot: Bot) {}

  /**
   * Adds a new temporary event listener to the bot.
   *
   * @param event - The event to subscribe to.
   * @param callback - The function to execute.
   */
  subscribeTo (event: string, callback: Function): void {
    const sub = new Subscription(event, callback)
    this.subscriptions.push(sub)
    // @ts-expect-error
    this.bot.on(sub.eventName, sub.callback)
  }

  /**
   * Removes all attached event listeners from the bot.
   */
  cleanup (): void {
    this.subscriptions.forEach(sub => {
      // @ts-expect-error
      this.bot.removeListener(sub.eventName, sub.callback)
    })
  }
}

View file

@ -0,0 +1,13 @@
/**
* Creates a new error object with the given type and message.
*
* @param type - The error type.
* @param message - The error message.
*
* @returns The error object.
*/
/**
 * Creates a new error object with the given type and message.
 *
 * @param type - The error type, stored in the error's `name` field.
 * @param message - The error message.
 *
 * @returns The error object.
 */
export function error (type: string, message: string): Error {
  return Object.assign(new Error(message), { name: type })
}

View file

@ -0,0 +1,25 @@
import { Bot } from 'mineflayer'
import { CollectBlock } from './CollectBlock'
import { pathfinder as pathfinderPlugin } from 'mineflayer-pathfinder'
import { plugin as toolPlugin } from 'mineflayer-tool'
/**
 * Installs the collect-block plugin on the bot and lazily loads its
 * pathfinder/tool plugin dependencies.
 */
export function plugin (bot: Bot): void {
  // @ts-expect-error
  bot.collectBlock = new CollectBlock(bot)

  // Load plugins if not loaded manually.
  // Deferred one tick — presumably so plugins the caller loads manually in the
  // same tick are detected first and not loaded twice; confirm against callers.
  setTimeout(() => loadPathfinderPlugin(bot), 0)
  setTimeout(() => loadToolPlugin(bot), 0)
}
// Loads mineflayer-pathfinder onto the bot unless it is already present.
function loadPathfinderPlugin (bot: Bot): void {
  if (bot.pathfinder == null) {
    bot.loadPlugin(pathfinderPlugin)
  }
}
// Loads mineflayer-tool onto the bot unless it is already present.
function loadToolPlugin (bot: Bot): void {
  if (bot.tool == null) {
    bot.loadPlugin(toolPlugin)
  }
}
export { CollectBlock, Callback, CollectOptions } from './CollectBlock'

View file

@ -0,0 +1,69 @@
{
"compilerOptions": {
/* Visit https://aka.ms/tsconfig.json to read more about this file */
/* Basic Options */
// "incremental": true, /* Enable incremental compilation */
"target": "ES2015", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019', 'ES2020', or 'ESNEXT'. */
"module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', 'es2020', or 'ESNext'. */
// "lib": [], /* Specify library files to be included in the compilation. */
"allowJs": true, /* Allow javascript files to be compiled. */
"checkJs": true, /* Report errors in .js files. */
// "jsx": "preserve", /* Specify JSX code generation: 'preserve', 'react-native', or 'react'. */
"declaration": true,
// "declarationMap": true, /* Generates a sourcemap for each corresponding '.d.ts' file. */
// "sourceMap": true, /* Generates corresponding '.map' file. */
// "outFile": "./", /* Concatenate and emit output to single file. */
"outDir": "./lib",
// "rootDir": "./", /* Specify the root directory of input files. Use to control the output directory structure with --outDir. */
// "composite": true, /* Enable project compilation */
// "tsBuildInfoFile": "./", /* Specify file to store incremental compilation information */
// "removeComments": true, /* Do not emit comments to output. */
// "noEmit": true, /* Do not emit outputs. */
// "importHelpers": true, /* Import emit helpers from 'tslib'. */
// "downlevelIteration": true, /* Provide full support for iterables in 'for-of', spread, and destructuring when targeting 'ES5' or 'ES3'. */
// "isolatedModules": true, /* Transpile each file as a separate module (similar to 'ts.transpileModule'). */
/* Strict Type-Checking Options */
"strict": true, /* Enable all strict type-checking options. */
// "noImplicitAny": true, /* Raise error on expressions and declarations with an implied 'any' type. */
"strictNullChecks": true, /* Enable strict null checks. */
// "strictFunctionTypes": true, /* Enable strict checking of function types. */
// "strictBindCallApply": true, /* Enable strict 'bind', 'call', and 'apply' methods on functions. */
// "strictPropertyInitialization": true, /* Enable strict checking of property initialization in classes. */
// "noImplicitThis": true, /* Raise error on 'this' expressions with an implied 'any' type. */
"alwaysStrict": true, /* Parse in strict mode and emit "use strict" for each source file. */
/* Additional Checks */
"noUnusedLocals": true, /* Report errors on unused locals. */
// "noUnusedParameters": true, /* Report errors on unused parameters. */
"noImplicitReturns": true, /* Report error when not all code paths in function return a value. */
// "noFallthroughCasesInSwitch": true, /* Report errors for fallthrough cases in switch statement. */
/* Module Resolution Options */
// "moduleResolution": "node", /* Specify module resolution strategy: 'node' (Node.js) or 'classic' (TypeScript pre-1.6). */
// "baseUrl": "./", /* Base directory to resolve non-absolute module names. */
// "paths": {}, /* A series of entries which re-map imports to lookup locations relative to the 'baseUrl'. */
// "rootDirs": [], /* List of root folders whose combined content represents the structure of the project at runtime. */
// "typeRoots": [], /* List of folders to include type definitions from. */
// "types": [], /* Type declaration files to be included in compilation. */
// "allowSyntheticDefaultImports": true, /* Allow default imports from modules with no default export. This does not affect code emit, just typechecking. */
"esModuleInterop": true, /* Enables emit interoperability between CommonJS and ES Modules via creation of namespace objects for all imports. Implies 'allowSyntheticDefaultImports'. */
// "preserveSymlinks": true, /* Do not resolve the real path of symlinks. */
// "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */
/* Source Map Options */
// "sourceRoot": "", /* Specify the location where debugger should locate TypeScript files instead of source locations. */
// "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
// "inlineSourceMap": true, /* Emit a single file with source maps instead of having a separate file. */
// "inlineSources": true, /* Emit the source alongside the sourcemaps within a single file; requires '--inlineSourceMap' or '--sourceMap' to be set. */
/* Experimental Options */
// "experimentalDecorators": true, /* Enables experimental support for ES7 decorators. */
// "emitDecoratorMetadata": true, /* Enables experimental support for emitting type metadata for decorators. */
/* Advanced Options */
"skipLibCheck": true, /* Skip type checking of declaration files. */
"forceConsistentCasingInFileNames": true /* Disallow inconsistently-cased references to the same file. */
},
"include": [
"src"
],
"exclude": [
"node_modules",
"**/__tests__/*"
]
}

View file

@ -0,0 +1,38 @@
{
"name": "voyager",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"keywords": [],
"author": "",
"license": "ISC",
"dependencies": {
"body-parser": "^1.20.2",
"express": "^4.18.2",
"magic-string": "^0.30.0",
"minecraft-data": "^3.31.0",
"minecrafthawkeye": "^1.3.6",
"mineflayer": "^4.8.1",
"mineflayer-collectblock": "file:mineflayer-collectblock",
"mineflayer-pathfinder": "^2.4.2",
"mineflayer-pvp": "^1.3.2",
"mineflayer-tool": "^1.2.0",
"mocha": "^10.2.0",
"prismarine-biome": "^1.3.0",
"prismarine-block": "=1.16.3",
"prismarine-entity": "^2.2.0",
"prismarine-item": "^1.12.1",
"prismarine-nbt": "^2.2.1",
"prismarine-recipe": "^1.3.1",
"prismarine-viewer": "^1.24.0",
"typescript": "^4.9.5",
"vec3": "^0.1.8",
"graceful-fs": "^4.2.11"
},
"devDependencies": {
"prettier": "2.8.5"
}
}

View file

@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# refs to `voyager process_monitor.py`
import re
import subprocess
import threading
import warnings
from typing import List
import psutil
from metagpt.logs import define_log_level
class SubprocessMonitor:
    """Run a command as a child process and watch its stdout on a background thread.

    The monitor sets `ready_event` once a stdout line matches `ready_match`,
    invokes `callback` for every line matching `callback_match`, and invokes
    `finished_callback` when the process's stdout closes.
    """

    def __init__(
        self,
        commands: List[str],
        name: str,
        ready_match: str = r".*",
        callback_match: str = r"^(?!x)x$",  # regex that will never match
        callback: callable = None,
        finished_callback: callable = None,
    ):
        self.commands = commands
        self.name = name
        self.logger = define_log_level(name=name)
        self.process = None
        self.ready_match = ready_match
        self.ready_event = None
        self.ready_line = None
        self.callback_match = callback_match
        self.callback = callback
        self.finished_callback = finished_callback
        self.thread = None

    def _start(self):
        """Spawn the subprocess and stream its stdout until the stream closes."""
        self.logger.info(f"Starting subprocess with commands: {self.commands}")
        # stderr is merged into stdout so a single reader loop sees everything.
        self.process = psutil.Popen(
            self.commands,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
        )
        self.logger.info(f"Subprocess {self.name} started with PID {self.process.pid}.")
        for line in iter(self.process.stdout.readline, ""):
            self.logger.info(line.strip())
            if re.search(self.ready_match, line):
                self.ready_line = line
                self.logger.info("Subprocess is ready.")
                self.ready_event.set()
            if re.search(self.callback_match, line):
                # BUG FIX: callback defaults to None; guard so a user-supplied
                # callback_match without a callback does not raise TypeError.
                if self.callback is not None:
                    self.callback()
        # stdout closed without a ready line: unblock run() and warn.
        if not self.ready_event.is_set():
            self.ready_event.set()
            warnings.warn(f"Subprocess {self.name} failed to start.")
        if self.finished_callback:
            self.finished_callback()

    def run(self):
        """Start monitoring on a background thread and block until ready (or exit)."""
        self.ready_event = threading.Event()
        self.ready_line = None
        self.thread = threading.Thread(target=self._start)
        self.thread.start()
        self.ready_event.wait()

    def stop(self):
        """Terminate the subprocess if it is still running and wait for it to exit."""
        self.logger.info("Stopping subprocess.")
        if self.process and self.process.is_running():
            self.process.terminate()
            self.process.wait()

    @property
    def is_running(self):
        # False before run() has spawned the process.
        if self.process is None:
            return False
        return self.process.is_running()

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : MG Software Env
from metagpt.environment.base_env import Environment
class SoftwareEnv(Environment):
    """A software-development environment; a specific alias of the base Environment."""

    pass

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : MG StanfordTown Env
from metagpt.environment.base_env import Environment
from metagpt.environment.stanford_town_env.stanford_town_ext_env import (
StanfordTownExtEnv,
)
class StanfordTownEnv(Environment, StanfordTownExtEnv):
    """Stanford Town environment: the generic role Environment combined with the maze ExtEnv."""

    pass

View file

@ -0,0 +1,379 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : The StanfordTown external environment to interact with the web interface
# refs to `generative_agents maze.py`
import math
from pathlib import Path
from typing import Optional, Tuple
from pydantic import ConfigDict, Field, model_validator
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.utils.common import read_csv_to_list, read_json_file
class StanfordTownExtEnv(ExtEnv):
    """External environment for the Stanford Town ("generative agents") maze.

    Loads the Tiled-exported maze matrices found under `maze_asset_path` and
    exposes readable/writeable accessors over tiles, collision data, string
    addresses and per-tile event sets.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    maze_asset_path: Optional[Path] = Field(default=None, description="the path to store maze assets")
    maze_width: int = Field(default=140, description="maze map width")
    maze_height: int = Field(default=100, description="maze map height")
    sq_tile_size: int = Field(default=32, description="the pixel height/width of a tile")
    special_constraint: str = Field(
        default="", description="a string description of any relevant special constraints " "the world might have"
    )
    # tiles[y][x] holds the detail dict for tile (x, y)
    tiles: list[list[dict]] = Field(default=[])
    # reverse lookup: string address -> set of (x, y) tile coordinates
    address_tiles: dict[str, set] = Field(default=dict())
    collision_maze: list[list] = Field(default=[])

    @model_validator(mode="before")
    @classmethod
    def _init_maze(cls, values):
        """Load all maze matrices and derived lookup tables before model validation."""
        maze_asset_path = values["maze_asset_path"]
        assert maze_asset_path
        maze_asset_path = Path(maze_asset_path)

        maze_matrix_path = maze_asset_path.joinpath("matrix")
        meta_info = read_json_file(maze_matrix_path.joinpath("maze_meta_info.json"))

        maze_width = int(meta_info["maze_width"])
        maze_height = int(meta_info["maze_height"])
        values["maze_width"] = maze_width
        values["maze_height"] = maze_height
        values["sq_tile_size"] = int(meta_info["sq_tile_size"])
        values["special_constraint"] = meta_info["special_constraint"]

        # READING IN SPECIAL BLOCKS
        # Special blocks are those that are colored in the Tiled map.
        # Here is an example row for the arena block file:
        # e.g, "25331, Double Studio, Studio, Bedroom 2, Painting"
        blocks_folder = maze_matrix_path.joinpath("special_blocks")

        _wb = blocks_folder.joinpath("world_blocks.csv")
        wb_rows = read_csv_to_list(_wb, header=False)
        wb = wb_rows[0][-1]

        _sb = blocks_folder.joinpath("sector_blocks.csv")
        sb_rows = read_csv_to_list(_sb, header=False)
        sb_dict = dict()
        for i in sb_rows:
            sb_dict[i[0]] = i[-1]

        _ab = blocks_folder.joinpath("arena_blocks.csv")
        ab_rows = read_csv_to_list(_ab, header=False)
        ab_dict = dict()
        for i in ab_rows:
            ab_dict[i[0]] = i[-1]

        _gob = blocks_folder.joinpath("game_object_blocks.csv")
        gob_rows = read_csv_to_list(_gob, header=False)
        gob_dict = dict()
        for i in gob_rows:
            gob_dict[i[0]] = i[-1]

        _slb = blocks_folder.joinpath("spawning_location_blocks.csv")
        slb_rows = read_csv_to_list(_slb, header=False)
        slb_dict = dict()
        for i in slb_rows:
            slb_dict[i[0]] = i[-1]

        # [SECTION 3] Reading in the matrices
        # This is your typical two dimensional matrices. It's made up of 0s and
        # the number that represents the color block from the blocks folder.
        maze_folder = maze_matrix_path.joinpath("maze")

        _cm = maze_folder.joinpath("collision_maze.csv")
        collision_maze_raw = read_csv_to_list(_cm, header=False)[0]
        _sm = maze_folder.joinpath("sector_maze.csv")
        sector_maze_raw = read_csv_to_list(_sm, header=False)[0]
        _am = maze_folder.joinpath("arena_maze.csv")
        arena_maze_raw = read_csv_to_list(_am, header=False)[0]
        _gom = maze_folder.joinpath("game_object_maze.csv")
        game_object_maze_raw = read_csv_to_list(_gom, header=False)[0]
        _slm = maze_folder.joinpath("spawning_location_maze.csv")
        spawning_location_maze_raw = read_csv_to_list(_slm, header=False)[0]

        # Loading the maze. The mazes are taken directly from the json exports of
        # Tiled maps. They should be in csv format.
        # Importantly, they are "not" in a 2-d matrix format -- they are single
        # row matrices with the length of width x height of the maze. So we need
        # to convert here.
        # example format: [['0', '0', ... '25309', '0',...], ['0',...]...]
        # 25309 is the collision bar number right now.
        collision_maze = []
        sector_maze = []
        arena_maze = []
        game_object_maze = []
        spawning_location_maze = []
        for i in range(0, len(collision_maze_raw), maze_width):
            tw = maze_width
            collision_maze += [collision_maze_raw[i : i + tw]]
            sector_maze += [sector_maze_raw[i : i + tw]]
            arena_maze += [arena_maze_raw[i : i + tw]]
            game_object_maze += [game_object_maze_raw[i : i + tw]]
            spawning_location_maze += [spawning_location_maze_raw[i : i + tw]]
        values["collision_maze"] = collision_maze

        # Build the per-tile detail dicts from the block lookups above.
        tiles = []
        for i in range(maze_height):
            row = []
            for j in range(maze_width):
                tile_details = dict()
                tile_details["world"] = wb

                tile_details["sector"] = ""
                if sector_maze[i][j] in sb_dict:
                    tile_details["sector"] = sb_dict[sector_maze[i][j]]

                tile_details["arena"] = ""
                if arena_maze[i][j] in ab_dict:
                    tile_details["arena"] = ab_dict[arena_maze[i][j]]

                tile_details["game_object"] = ""
                if game_object_maze[i][j] in gob_dict:
                    tile_details["game_object"] = gob_dict[game_object_maze[i][j]]

                tile_details["spawning_location"] = ""
                if spawning_location_maze[i][j] in slb_dict:
                    tile_details["spawning_location"] = slb_dict[spawning_location_maze[i][j]]

                tile_details["collision"] = False
                if collision_maze[i][j] != "0":
                    tile_details["collision"] = True

                tile_details["events"] = set()
                row += [tile_details]
            tiles += [row]
        values["tiles"] = tiles

        # Each game object occupies an event in the tile. We are setting up the
        # default event value here.
        for i in range(maze_height):
            for j in range(maze_width):
                if tiles[i][j]["game_object"]:
                    object_name = ":".join(
                        [tiles[i][j]["world"], tiles[i][j]["sector"], tiles[i][j]["arena"], tiles[i][j]["game_object"]]
                    )
                    go_event = (object_name, None, None, None)
                    tiles[i][j]["events"].add(go_event)

        # Reverse tile access.
        # <address_tiles> -- given a string address, we return a set of all
        # tile coordinates belonging to that address (this is opposite of
        # tiles that give you the string address given a coordinate). This is
        # an optimization component for finding paths for the personas' movement.
        # address_tiles['<spawn_loc>bedroom-2-a'] == {(58, 9)}
        # address_tiles['double studio:recreation:pool table']
        # == {(29, 14), (31, 11), (30, 14), (32, 11), ...},
        address_tiles = dict()
        for i in range(maze_height):
            for j in range(maze_width):
                addresses = []
                if tiles[i][j]["sector"]:
                    add = f'{tiles[i][j]["world"]}:'
                    add += f'{tiles[i][j]["sector"]}'
                    addresses += [add]
                if tiles[i][j]["arena"]:
                    add = f'{tiles[i][j]["world"]}:'
                    add += f'{tiles[i][j]["sector"]}:'
                    add += f'{tiles[i][j]["arena"]}'
                    addresses += [add]
                if tiles[i][j]["game_object"]:
                    add = f'{tiles[i][j]["world"]}:'
                    add += f'{tiles[i][j]["sector"]}:'
                    add += f'{tiles[i][j]["arena"]}:'
                    add += f'{tiles[i][j]["game_object"]}'
                    addresses += [add]
                if tiles[i][j]["spawning_location"]:
                    add = f'<spawn_loc>{tiles[i][j]["spawning_location"]}'
                    addresses += [add]

                for add in addresses:
                    if add in address_tiles:
                        address_tiles[add].add((j, i))
                    else:
                        address_tiles[add] = set([(j, i)])
        values["address_tiles"] = address_tiles
        return values

    def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
        """
        Turns a pixel coordinate to a tile coordinate.
        """
        x = math.ceil(px_coordinate[0] / self.sq_tile_size)
        y = math.ceil(px_coordinate[1] / self.sq_tile_size)
        return (x, y)

    @mark_as_readable
    def get_collision_maze(self) -> list:
        # 2-D matrix of collision codes; "0" marks a walkable tile.
        return self.collision_maze

    @mark_as_readable
    def get_address_tiles(self) -> dict:
        # string address -> set of (x, y) tile coordinates carrying that address
        return self.address_tiles

    @mark_as_readable
    def access_tile(self, tile: tuple[int, int]) -> dict:
        """
        Returns the tiles details dictionary that is stored in self.tiles of the
        designated x, y location.

        INPUT
          tile: The tile coordinate of our interest in (x, y) form.
        OUTPUT
          The tile detail dictionary for the designated tile.
        EXAMPLE OUTPUT
          Given (58, 9),
          self.tiles[9][58] = {'world': 'double studio',
                'sector': 'double studio', 'arena': 'bedroom 2',
                'game_object': 'bed', 'spawning_location': 'bedroom-2-a',
                'collision': False,
                'events': {('double studio:double studio:bedroom 2:bed',
                           None, None)}}
        """
        x = tile[0]
        y = tile[1]
        return self.tiles[y][x]

    @mark_as_readable
    def get_tile_path(self, tile: tuple[int, int], level: str) -> str:
        """
        Get the tile string address given its coordinate. You designate the level
        by giving it a string level description.

        INPUT:
          tile: The tile coordinate of our interest in (x, y) form.
          level: world, sector, arena, or game object
        OUTPUT
          The string address for the tile.
        EXAMPLE OUTPUT
          Given tile=(58, 9), and level=arena,
          "double studio:double studio:bedroom 2"
        """
        x = tile[0]
        y = tile[1]
        tile = self.tiles[y][x]

        path = f"{tile['world']}"
        if level == "world":
            return path
        else:
            path += f":{tile['sector']}"

        if level == "sector":
            return path
        else:
            path += f":{tile['arena']}"

        if level == "arena":
            return path
        else:
            path += f":{tile['game_object']}"

        return path

    @mark_as_readable
    def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]:
        """
        Given the current tile and vision_r, return a list of tiles that are
        within the radius. Note that this implementation looks at a square
        boundary when determining what is within the radius.
        i.e., for vision_r, returns x's.
        x x x x x
        x x x x x
        x x P x x
        x x x x x
        x x x x x

        INPUT:
          tile: The tile coordinate of our interest in (x, y) form.
          vision_r: The radius of the persona's vision.
        OUTPUT:
          nearby_tiles: a list of tiles that are within the radius.
        """
        left_end = 0
        if tile[0] - vision_r > left_end:
            left_end = tile[0] - vision_r

        right_end = self.maze_width - 1
        if tile[0] + vision_r + 1 < right_end:
            right_end = tile[0] + vision_r + 1

        bottom_end = self.maze_height - 1
        if tile[1] + vision_r + 1 < bottom_end:
            bottom_end = tile[1] + vision_r + 1

        top_end = 0
        if tile[1] - vision_r > top_end:
            top_end = tile[1] - vision_r

        nearby_tiles = []
        for i in range(left_end, right_end):
            for j in range(top_end, bottom_end):
                nearby_tiles += [(i, j)]
        return nearby_tiles

    @mark_as_writeable
    def add_tiles_event(self, pt_y: int, pt_x: int, event: Tuple[str, str, str, str]):
        # Register an event tuple on the tile at (pt_x, pt_y).
        self.tiles[pt_y][pt_x]["events"].add(event)

    @mark_as_writeable
    def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
        """
        Add an event triple to a tile.

        INPUT:
          curr_event: Current event triple.
            e.g., ('double studio:double studio:bedroom 2:bed', None,
                   None)
          tile: The tile coordinate of our interest in (x, y) form.
        OUTPUT:
          None
        """
        self.tiles[tile[1]][tile[0]]["events"].add(curr_event)

    @mark_as_writeable
    def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
        """
        Remove an event triple from a tile.

        INPUT:
          curr_event: Current event triple.
            e.g., ('double studio:double studio:bedroom 2:bed', None,
                   None)
          tile: The tile coordinate of our interest in (x, y) form.
        OUTPUT:
          None
        """
        # Iterate over a copy so removal does not mutate the set mid-iteration.
        curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
        for event in curr_tile_ev_cp:
            if event == curr_event:
                self.tiles[tile[1]][tile[0]]["events"].remove(event)

    @mark_as_writeable
    def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
        # Replace the matching event with an "idle" event keeping only its subject.
        curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
        for event in curr_tile_ev_cp:
            if event == curr_event:
                self.tiles[tile[1]][tile[0]]["events"].remove(event)
                new_event = (event[0], None, None, None)
                self.tiles[tile[1]][tile[0]]["events"].add(new_event)

    @mark_as_writeable
    def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None:
        """
        Remove an event triple that has the input subject from a tile.

        INPUT:
          subject: "Isabella Rodriguez"
          tile: The tile coordinate of our interest in (x, y) form.
        OUTPUT:
          None
        """
        curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
        for event in curr_tile_ev_cp:
            if event[0] == subject:
                self.tiles[tile[1]][tile[0]]["events"].remove(event)

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : MG Werewolf Env
from pydantic import Field
from metagpt.environment.base_env import Environment
from metagpt.environment.werewolf_env.werewolf_ext_env import WerewolfExtEnv
from metagpt.logs import logger
from metagpt.schema import Message
class WerewolfEnv(Environment, WerewolfExtEnv):
    """Werewolf game environment: the role Environment combined with the werewolf ExtEnv."""

    # Monotonic counter used to prefix published messages (see publish_message).
    timestamp: int = Field(default=0)

    def publish_message(self, message: Message, add_timestamp: bool = True):
        """Post information to the current environment"""
        logger.debug(f"publish_message: {message.dump()}")
        if add_timestamp:
            # Because the content of the message may be repeated, for example, killing the same person in two nights
            # Therefore, a unique timestamp prefix needs to be added so that the same message will not be automatically deduplicated when added to the memory.
            message.content = f"{self.timestamp} | " + message.content
        self.memory.add(message)
        self.history += f"\n{message}"

    async def run(self, k=1):
        """Process all Role runs by order"""
        for _ in range(k):
            for role in self.roles.values():
                await role.run()
            # One timestamp tick per full pass over all roles.
            self.timestamp += 1

View file

@ -0,0 +1,335 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : The werewolf game external environment for the roles to integrate with
import random
from collections import Counter
from enum import Enum
from typing import Callable, Optional
from pydantic import ConfigDict, Field
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.logs import logger
class RoleState(Enum):
    """Life-cycle state of a player in the werewolf game."""

    ALIVE = "alive"  # the role is alive
    KILLED = "killed"  # the role is killed by werewolf or voting
    POISONED = "poisoned"  # the role is killed by poison
    SAVED = "saved"  # the role is saved by antidote
# the ordered rules by the moderator to announce to everyone each step
# Each entry: content template (may contain {placeholders} filled at runtime),
# send_to (addressee role), restricted_to (who may see the message; "" = public).
STEP_INSTRUCTIONS = {
    0: {
        "content": "Its dark, everyone close your eyes. I will talk with you/your team secretly at night.",
        "send_to": "Moderator",  # for moderator to continue speaking
        "restricted_to": "",
    },
    1: {
        "content": "Guard, please open your eyes!",
        "send_to": "Moderator",  # for moderator to continue speaking
        "restricted_to": "",
    },
    2: {
        "content": """Guard, now tell me who you protect tonight?
You only choose one from the following living options please: {living_players}.
Or you can pass. For example: Protect ...""",
        "send_to": "Guard",
        "restricted_to": "Moderator,Guard",
    },
    3: {"content": "Guard, close your eyes", "send_to": "Moderator", "restricted_to": ""},
    4: {"content": "Werewolves, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
    5: {
        "content": """Werewolves, I secretly tell you that {werewolf_players} are
all of the 2 werewolves! Keep in mind you are teammates. The rest players are not werewolves.
choose one from the following living options please:
{living_players}. For example: Kill ...""",
        "send_to": "Werewolf",
        "restricted_to": "Moderator,Werewolf",
    },
    6: {"content": "Werewolves, close your eyes", "send_to": "Moderator", "restricted_to": ""},
    7: {"content": "Witch, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
    8: {
        "content": """Witch, tonight {player_hunted} has been killed by the werewolves.
You have a bottle of antidote, would you like to save him/her? If so, say "Save", else, say "Pass".""",
        "send_to": "Witch",
        "restricted_to": "Moderator,Witch",
    },  # first check whether the witch still has the antidote, then ask whether she uses it to save the victim
    9: {
        "content": """Witch, you also have a bottle of poison, would you like to use it to kill one of the living players?
Choose one from the following living options: {living_players}.
If so, say ONLY "Poison PlayerX", replace PlayerX with the actual player name, else, say "Pass".""",
        "send_to": "Witch",
        "restricted_to": "Moderator,Witch",
    },
    10: {"content": "Witch, close your eyes", "send_to": "Moderator", "restricted_to": ""},
    11: {"content": "Seer, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
    12: {
        "content": """Seer, you can check one player's identity. Who are you going to verify its identity tonight?
Choose only one from the following living options:{living_players}.""",
        "send_to": "Seer",
        "restricted_to": "Moderator,Seer",
    },
    13: {"content": "Seer, close your eyes", "send_to": "Moderator", "restricted_to": ""},
    # The 1-st daytime
    14: {
        "content": """It's daytime. Everyone woke up except those who had been killed.""",
        "send_to": "Moderator",
        "restricted_to": "",
    },
    15: {"content": "{player_current_dead} was killed last night!", "send_to": "Moderator", "restricted_to": ""},
    16: {
        "content": """Living players: {living_players}, now freely talk about the current situation based on your observation and
reflection with a few sentences. Decide whether to reveal your identity based on your reflection.""",
        "send_to": "",  # send to all to speak in daytime
        "restricted_to": "",
    },
    17: {
        "content": """Now vote and tell me who you think is the werewolf. Dont mention your role.
You only choose one from the following living options please:
{living_players}. Say ONLY: I vote to eliminate ...""",
        "send_to": "",
        "restricted_to": "",
    },
    18: {"content": """{player_current_dead} was eliminated.""", "send_to": "Moderator", "restricted_to": ""},
}
class WerewolfExtEnv(ExtEnv):
    """External environment tracking werewolf game state: players, rounds, hunts and votes."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    players_state: dict[str, tuple[str, RoleState]] = Field(
        default=dict(), description="the player's role type and state by player_name"
    )

    round_idx: int = Field(default=0)  # the current round
    step_idx: int = Field(default=0)  # the current step of current round
    eval_step_idx: int = Field(default=0)
    per_round_steps: int = Field(default=len(STEP_INSTRUCTIONS))

    # game global states
    game_setup: str = Field(default="", description="game setup including role and its num")
    special_role_players: list[str] = Field(default=[])
    winner: Optional[str] = Field(default=None)
    win_reason: Optional[str] = Field(default=None)
    witch_poison_left: int = Field(default=1)  # one use of poison per game
    witch_antidote_left: int = Field(default=1)  # one use of antidote per game

    # game current round states, a round is from closing your eyes to the next time you close your eyes
    round_hunts: dict[str, str] = Field(default=dict(), description="nighttime wolf hunt result")
    round_votes: dict[str, str] = Field(
        default=dict(), description="daytime all players vote result, key=voteer, value=voted one"
    )
    player_hunted: Optional[str] = Field(default=None)
    player_protected: Optional[str] = Field(default=None)
    is_hunted_player_saved: bool = Field(default=False)
    player_poisoned: Optional[str] = Field(default=None)
    player_current_dead: list[str] = Field(default=[])
@property
def living_players(self) -> list[str]:
    """Names of players whose state is still alive (or saved by the witch)."""
    living_states = (RoleState.ALIVE, RoleState.SAVED)
    return [name for name, (_, state) in self.players_state.items() if state in living_states]
def _role_type_players(self, role_type: str) -> list[str]:
    """return player name of particular role type"""
    return [name for name, (profile, _) in self.players_state.items() if role_type in profile]
@property
def werewolf_players(self) -> list[str]:
    """Names of all werewolf players."""
    return self._role_type_players(role_type="Werewolf")
@property
def villager_players(self) -> list[str]:
    """Names of all villager players."""
    return self._role_type_players(role_type="Villager")
def _init_players_state(self, players: list["Role"]):
    """Record each player's (profile, ALIVE) state and cache the special-role names."""
    for play in players:
        self.players_state[play.name] = (play.profile, RoleState.ALIVE)
    # Special roles are the living players who are neither werewolves nor villagers
    # (e.g. Guard / Witch / Seer).
    self.special_role_players = [
        p for p in self.living_players if p not in self.werewolf_players + self.villager_players
    ]
def init_game_setup(
    self,
    role_uniq_objs: list[object],
    num_villager: int = 2,
    num_werewolf: int = 2,
    shuffle=True,
    add_human=False,
    use_reflection=True,
    use_experience=False,
    use_memory_selection=False,
    new_experience_version="",
    prepare_human_player=Callable,
) -> tuple[str, list]:
    """init players using different roles' num

    Expands the unique role classes into the full seat list (num_villager
    villagers, num_werewolf werewolves, one of each special role), optionally
    shuffles the seating, optionally swaps one random seat for a human player,
    then instantiates the roles and records their initial state.

    Returns:
        (game_setup, players): the textual game-setup summary and the Role instances.
    """
    role_objs = []
    for role_obj in role_uniq_objs:
        if str(role_obj) == "Villager":
            role_objs.extend([role_obj] * num_villager)
        elif str(role_obj) == "Werewolf":
            role_objs.extend([role_obj] * num_werewolf)
        else:
            role_objs.append(role_obj)
    if shuffle:
        # BUG FIX: random.shuffle shuffles a sequence in place; the original code
        # called random.shuffle(len(role_objs)) with an int, which raises TypeError.
        random.shuffle(role_objs)
    if add_human:
        assigned_role_idx = random.randint(0, len(role_objs) - 1)
        assigned_role = role_objs[assigned_role_idx]
        role_objs[assigned_role_idx] = prepare_human_player(assigned_role)  # TODO

    players = [
        role(
            name=f"Player{i + 1}",
            use_reflection=use_reflection,
            use_experience=use_experience,
            use_memory_selection=use_memory_selection,
            new_experience_version=new_experience_version,
        )
        for i, role in enumerate(role_objs)
    ]

    if add_human:
        logger.info(f"You are assigned {players[assigned_role_idx].name}({players[assigned_role_idx].profile})")

    game_setup = ["Game setup:"] + [f"{player.name}: {player.profile}," for player in players]
    self.game_setup = "\n".join(game_setup)

    self._init_players_state(players)  # init players state
    return self.game_setup, players
def _update_players_state(self, player_names: list[str], state: RoleState = RoleState.KILLED):
    """Set the given state on every listed player that is known; unknown names are ignored."""
    for name in player_names:
        if name not in self.players_state:
            continue
        role_type, _ = self.players_state[name]
        self.players_state[name] = (role_type, state)
def _check_valid_role(self, player: "Role", role_type: str) -> bool:
    """True if the player's role (its string form) contains the given role type."""
    # IDIOM: `role_type in str(player)` is already a bool; the original wrapped
    # it in `True if ... else False`.
    return role_type in str(player)
def _check_player_continue(self, player_name: str, particular_step: int = -1) -> bool:
step_idx = self.step_idx % self.per_round_steps
if particular_step > 0 and step_idx != particular_step: # step no
# particular_step = 18, not daytime vote time, ignore
# particular_step = 15, not nighttime hunt time, ignore
return False
if player_name not in self.living_players:
return False
return True
@mark_as_readable
def curr_step_instruction(self) -> dict:
    """Return the instruction for the current step, then advance the step counter."""
    idx = self.step_idx % len(STEP_INSTRUCTIONS)
    self.step_idx += 1
    return STEP_INSTRUCTIONS[idx]
@mark_as_readable
def get_players_state(self, player_names: list[str]) -> dict[str, RoleState]:
    """Map each known player name to its current role state (the profile is omitted)."""
    known = self.players_state
    # players_state values are (profile, state) pairs; expose only the state.
    return {name: known[name][1] for name in player_names if name in known}
@mark_as_writeable
def vote_kill_someone(self, voteer: "Role", player_name: str = None):
    """Record a daytime vote; player_name=None means the voter abstains."""
    # Voting is only legal at step 18 of the round, and only for living players.
    if not self._check_player_continue(voteer.name, particular_step=18):
        return
    self.round_votes[voteer.name] = player_name
    # Wait until every living player has voted before tallying.
    if list(self.round_votes.keys()) != self.living_players:
        return
    ballots = [name for name in self.round_votes.values() if name]  # drop abstentions
    # TODO in case of tie vote, check who was voted first
    self.player_current_dead = Counter(ballots).most_common()[0][0]
    self._update_players_state([self.player_current_dead])
@mark_as_writeable
def wolf_kill_someone(self, wolf: "Role", player_name: str):
    """Record a werewolf's hunt target; resolve once all living wolves have chosen."""
    if not self._check_valid_role(wolf, "Werewolf"):
        return
    # Hunting is only legal at step 5 of the round.
    if not self._check_player_continue(wolf.name, particular_step=5):
        return
    self.round_hunts[wolf.name] = player_name
    living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
    # Once every living wolf has hunted, the majority target becomes the victim.
    if list(self.round_hunts.keys()) == living_werewolf:
        self.player_hunted = Counter(self.round_hunts.values()).most_common()[0][0]
@mark_as_writeable
def witch_poison_someone(self, witch: "Role", player_name: str = None):
    """Witch poisons the named player (witch only; target must be a living player)."""
    if not self._check_valid_role(witch, "Witch") or not self._check_player_continue(player_name):
        return
    self._update_players_state([player_name], RoleState.POISONED)
    self.player_poisoned = player_name
@mark_as_writeable
def witch_save_someone(self, witch: "Role", player_name: str = None):
    """Witch uses the antidote on the named player (witch only; target must be living)."""
    can_save = self._check_valid_role(witch, "Witch") and self._check_player_continue(player_name)
    if not can_save:
        return
    self._update_players_state([player_name], RoleState.SAVED)
    self.player_protected = player_name
@mark_as_writeable
def update_game_states(self, memories: list):
    """Settle round outcomes and check the game's termination condition.

    Runs only at the two settlement steps of a round (15 = night end,
    18 = after the daytime vote) and at most once per absolute step.
    """
    step_idx = self.step_idx % self.per_round_steps
    if step_idx not in [15, 18] or self.step_idx in self.eval_step_idx:
        return
    else:
        self.eval_step_idx.append(self.step_idx)  # record evaluation, avoid repetitive evaluation at the same step

    if step_idx == 15:  # step no
        # night ends: after all special roles acted, process the whole night
        self.player_current_dead = []  # reset
        if self.player_hunted != self.player_protected and not self.is_hunted_player_saved:
            self.player_current_dead.append(self.player_hunted)
        if self.player_poisoned:
            self.player_current_dead.append(self.player_poisoned)
        # BUG FIX: player_current_dead is already a list; wrapping it in
        # another list made _update_players_state look up a list as a dict
        # key, so night victims were never actually marked dead.
        self._update_players_state(self.player_current_dead)

        # reset
        self.player_hunted = None
        self.player_protected = None
        self.is_hunted_player_saved = False
        self.player_poisoned = None

    # game's termination condition
    living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
    living_villagers = [p for p in self.villager_players if p in self.living_players]
    living_special_roles = [p for p in self.special_role_players if p in self.living_players]
    if not living_werewolf:
        self.winner = "good guys"
        self.win_reason = "werewolves all dead"
    elif not living_villagers or not living_special_roles:
        self.winner = "werewolf"
        self.win_reason = "villagers all dead" if not living_villagers else "special roles all dead"
    if self.winner is not None:
        self._record_all_experiences()  # TODO

View file

@ -15,14 +15,15 @@ from loguru import logger as _logger
from metagpt.const import METAGPT_ROOT
def define_log_level(print_level="INFO", logfile_level="DEBUG"):
def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None):
"""Adjust the log level to above level"""
current_date = datetime.now()
formatted_date = current_date.strftime("%Y%m%d")
log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name
_logger.remove()
_logger.add(sys.stderr, level=print_level)
_logger.add(METAGPT_ROOT / f"logs/{formatted_date}.txt", level=logfile_level)
_logger.add(METAGPT_ROOT / f"logs/{log_name}.txt", level=logfile_level)
return _logger

View file

@ -39,8 +39,26 @@ class BaseLLM(ABC):
def __init__(self, config: LLMConfig):
pass
def _user_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "content": msg}
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]:
if images:
# as gpt-4v, chat with image
return self._user_msg_with_imgs(msg, images)
else:
return {"role": "user", "content": msg}
def _user_msg_with_imgs(self, msg: str, images: Optional[Union[str, list[str]]]):
"""
images: can be list of http(s) url or base64
"""
if isinstance(images, str):
images = [images]
content = [{"type": "text", "text": msg}]
for image in images:
# image url or image base64
url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}"
        # it supports multiple-image inputs
content.append({"type": "image_url", "image_url": url})
return {"role": "user", "content": content}
def _assistant_msg(self, msg: str) -> dict[str, str]:
return {"role": "assistant", "content": msg}
@ -59,6 +77,7 @@ class BaseLLM(ABC):
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
stream=True,
) -> str:
@ -70,7 +89,7 @@ class BaseLLM(ABC):
message = []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
message.append(self._user_msg(msg, images=images))
logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp

View file

@ -29,6 +29,7 @@ from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import decode_image
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
@ -101,7 +102,7 @@ class OpenAILLM(BaseLLM):
"messages": messages,
"max_tokens": self._get_max_tokens(messages),
"n": 1,
"stop": None,
# "stop": None, # default it's None and gpt4-v can't have this one
"temperature": 0.3,
"model": self.model,
"timeout": max(self.config.timeout, timeout),
@ -240,3 +241,24 @@ class OpenAILLM(BaseLLM):
async def aspeech_to_text(self, **kwargs):
"""speech to text"""
return await self.aclient.audio.transcriptions.create(**kwargs)
async def gen_image(
self,
prompt: str,
size: str = "1024x1024",
quality: str = "standard",
model: str = None,
resp_format: str = "url",
) -> list["Image"]:
"""image generate"""
assert resp_format in ["url", "b64_json"]
if not model:
model = self.model
res = await self.aclient.images.generate(
model=model, prompt=prompt, size=size, quality=quality, n=1, response_format=resp_format
)
imgs = []
for item in res.data:
img_url_or_b64 = item.url if resp_format == "url" else item.b64_json
imgs.append(decode_image(img_url_or_b64))
return imgs

View file

@ -8,12 +8,12 @@
from typing import Optional
from pydantic import Field
from pydantic import Field, model_validator
from metagpt.actions import SearchAndSummarize, UserRequirement
from metagpt.document_store.base_store import BaseStore
from metagpt.roles import Role
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
class Sales(Role):
@ -29,14 +29,13 @@ class Sales(Role):
store: Optional[BaseStore] = Field(default=None, exclude=True)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._set_store(self.store)
def _set_store(self, store):
if store:
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch)
@model_validator(mode="after")
def validate_stroe(self):
if self.store:
search_engine = SearchEngine.from_search_func(search_func=self.store.asearch, proxy=self.config.proxy)
action = SearchAndSummarize(search_engine=search_engine, context=self.context)
else:
action = SearchAndSummarize()
action = SearchAndSummarize
self.set_actions([action])
self._watch([UserRequirement])
return self

View file

@ -8,7 +8,9 @@
the `cause_by` value in the `Message` to a string to support the new message distribution feature.
"""
from pydantic import Field
from typing import Optional
from pydantic import Field, model_validator
from metagpt.actions import SearchAndSummarize
from metagpt.actions.action_node import ActionNode
@ -16,7 +18,7 @@ from metagpt.actions.action_output import ActionOutput
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
class Searcher(Role):
@ -28,33 +30,22 @@ class Searcher(Role):
profile (str): Role profile.
goal (str): Goal of the searcher.
constraints (str): Constraints or limitations for the searcher.
engine (SearchEngineType): The type of search engine to use.
search_engine (SearchEngine): The search engine to use.
"""
name: str = Field(default="Alice")
profile: str = Field(default="Smart Assistant")
goal: str = "Provide search services for users"
constraints: str = "Answer is rich and complete"
engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
search_engine: Optional[SearchEngine] = None
def __init__(self, **kwargs) -> None:
"""
Initializes the Searcher role with given attributes.
Args:
name (str): Name of the searcher.
profile (str): Role profile.
goal (str): Goal of the searcher.
constraints (str): Constraints or limitations for the searcher.
engine (SearchEngineType): The type of search engine to use.
"""
super().__init__(**kwargs)
self.set_actions([SearchAndSummarize(engine=self.engine)])
def set_search_func(self, search_func):
"""Sets a custom search function for the searcher."""
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func)
self.set_actions([action])
@model_validator(mode="after")
def post_root(self):
if self.search_engine:
self.set_actions([SearchAndSummarize(search_engine=self.search_engine, context=self.context)])
else:
self.set_actions([SearchAndSummarize])
return self
async def _act_sp(self) -> Message:
"""Performs the search action in a single process."""

View file

@ -8,14 +8,23 @@
import importlib
from typing import Callable, Coroutine, Literal, Optional, Union, overload
from pydantic import BaseModel, ConfigDict, model_validator
from semantic_kernel.skill_definition import sk_function
from metagpt.configs.search_config import SearchConfig
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
class SkSearchEngine:
def __init__(self):
self.search_engine = SearchEngine()
"""A search engine class for executing searches.
Attributes:
search_engine: The search engine instance used for executing searches.
"""
def __init__(self, **kwargs):
self.search_engine = SearchEngine(**kwargs)
@sk_function(
description="searches results from Google. Useful when you need to find short "
@ -28,43 +37,85 @@ class SkSearchEngine:
return result
class SearchEngine:
"""Class representing a search engine.
Args:
engine: The search engine type. Defaults to the search engine specified in the config.
run_func: The function to run the search. Defaults to None.
class SearchEngine(BaseModel):
"""A model for configuring and executing searches with different search engines.
Attributes:
run_func: The function to run the search.
engine: The search engine type.
model_config: Configuration for the model allowing arbitrary types.
engine: The type of search engine to use.
run_func: An optional callable for running the search. If not provided, it will be determined based on the engine.
api_key: An optional API key for the search engine.
proxy: An optional proxy for the search engine requests.
"""
def __init__(
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None
api_key: Optional[str] = None
proxy: Optional[str] = None
@model_validator(mode="after")
def validate_extra(self):
"""Validates extra fields provided to the model and updates the run function accordingly."""
data = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True)
if self.model_extra:
data.update(self.model_extra)
self._process_extra(**data)
return self
def _process_extra(
self,
engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE,
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None,
**kwargs,
):
if engine == SearchEngineType.SERPAPI_GOOGLE:
"""Processes extra configuration and updates the run function based on the search engine type.
Args:
run_func: An optional callable for running the search. If not provided, it will be determined based on the engine.
"""
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run
elif engine == SearchEngineType.SERPER_GOOGLE:
elif self.engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper(**kwargs).run
elif engine == SearchEngineType.DIRECT_GOOGLE:
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run
elif engine == SearchEngineType.DUCK_DUCK_GO:
elif self.engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
run_func = self.run_func
else:
raise NotImplementedError
self.engine = engine
self.run_func = run_func
@classmethod
def from_search_config(cls, config: SearchConfig, **kwargs):
"""Creates a SearchEngine instance from a SearchConfig.
Args:
config: The search configuration to use for creating the SearchEngine instance.
"""
data = config.model_dump(exclude={"api_type", "search_func"})
if config.search_func is not None:
data["run_func"] = config.search_func
return cls(engine=config.api_type, **data, **kwargs)
@classmethod
def from_search_func(
cls, search_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]], **kwargs
):
"""Creates a SearchEngine instance from a custom search function.
Args:
search_func: A callable that executes the search.
"""
return cls(engine=SearchEngineType.CUSTOM_ENGINE, run_func=search_func, **kwargs)
@overload
def run(
self,
@ -83,15 +134,29 @@ class SearchEngine:
) -> list[dict[str, str]]:
...
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> Union[str, list[dict[str, str]]]:
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
ignore_errors: bool = False,
) -> Union[str, list[dict[str, str]]]:
"""Run a search query.
Args:
query: The search query.
max_results: The maximum number of results to return. Defaults to 8.
as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
ignore_errors: Whether to ignore errors during the search. Defaults to False.
Returns:
The search results as a string or a list of dictionaries.
"""
return await self.run_func(query, max_results=max_results, as_string=as_string)
try:
return await self.run_func(query, max_results=max_results, as_string=as_string)
except Exception as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
if not ignore_errors:
raise e
return "" if as_string else []

View file

@ -5,9 +5,9 @@ from __future__ import annotations
import asyncio
import json
from concurrent import futures
from typing import Literal, overload
from typing import Literal, Optional, overload
from metagpt.config2 import config
from pydantic import BaseModel, ConfigDict
try:
from duckduckgo_search import DDGS
@ -18,24 +18,16 @@ except ImportError:
)
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
class DDGAPIWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
To use this module, you should have the `duckduckgo_search` Python package installed.
"""
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
proxy: Optional[str] = None
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
kwargs = {}
if config.proxy:
kwargs["proxies"] = config.proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)
@property
def ddgs(self):
return DDGS(proxies=self.proxy)
@overload
def run(

View file

@ -4,19 +4,16 @@ from __future__ import annotations
import asyncio
import json
import warnings
from concurrent import futures
from typing import Optional
from urllib.parse import urlparse
import httplib2
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config2 import config
from metagpt.logs import logger
from pydantic import BaseModel, ConfigDict, model_validator
try:
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
except ImportError:
raise ImportError(
"To use this module, you should have the `google-api-python-client` Python package installed. "
@ -27,40 +24,41 @@ except ImportError:
class GoogleAPIWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
google_api_key: Optional[str] = Field(default=None, validate_default=True)
google_cse_id: Optional[str] = Field(default=None, validate_default=True)
api_key: str
cse_id: str
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
proxy: Optional[str] = None
@field_validator("google_api_key", mode="before")
@model_validator(mode="before")
@classmethod
def check_google_api_key(cls, val: str):
val = val or config.search.api_key
if not val:
def validate_google(cls, values: dict) -> dict:
if "google_api_key" in values:
values.setdefault("api_key", values["google_api_key"])
warnings.warn("`google_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
if "api_key" not in values:
raise ValueError(
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
"ensure that the environment variable GOOGLE_API_KEY is set with your API key. You can obtain "
"To use google search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
"an API key from https://console.cloud.google.com/apis/credentials."
)
return val
@field_validator("google_cse_id", mode="before")
@classmethod
def check_google_cse_id(cls, val: str):
val = val or config.search.cse_id
if not val:
if "google_cse_id" in values:
values.setdefault("cse_id", values["google_cse_id"])
warnings.warn("`google_cse_id` is deprecated, use `cse_id` instead", DeprecationWarning, stacklevel=2)
if "cse_id" not in values:
raise ValueError(
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
"ensure that the environment variable GOOGLE_CSE_ID is set with your API key. You can obtain "
"an API key from https://programmablesearchengine.google.com/controlpanel/create."
"To use google search engine, make sure you provide the `cse_id` when constructing an object. You can obtain "
"the cse_id from https://programmablesearchengine.google.com/controlpanel/create."
)
return val
return values
@property
def google_api_client(self):
build_kwargs = {"developerKey": self.google_api_key}
if config.proxy:
parse_result = urlparse(config.proxy)
build_kwargs = {"developerKey": self.api_key}
if self.proxy:
parse_result = urlparse(self.proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"
@ -96,17 +94,11 @@ class GoogleAPIWrapper(BaseModel):
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.cse_id).execute
)
try:
result = await future
# Extract the search result items from the response
search_results = result.get("items", [])
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
result = await future
# Extract the search result items from the response
search_results = result.get("items", [])
focus = focus or ["snippet", "link", "title"]
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]

View file

@ -5,18 +5,17 @@
@Author : alexanderwu
@File : search_engine_serpapi.py
"""
import warnings
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config2 import config
from pydantic import BaseModel, ConfigDict, Field, model_validator
class SerpAPIWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
search_engine: Any = None #: :meta private:
api_key: str
params: dict = Field(
default_factory=lambda: {
"engine": "google",
@ -25,21 +24,22 @@ class SerpAPIWrapper(BaseModel):
"hl": "en",
}
)
# should add `validate_default=True` to check with default value
serpapi_api_key: Optional[str] = Field(default=None, validate_default=True)
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@field_validator("serpapi_api_key", mode="before")
@model_validator(mode="before")
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or config.search.api_key
if not val:
def validate_serpapi(cls, values: dict) -> dict:
if "serpapi_api_key" in values:
values.setdefault("api_key", values["serpapi_api_key"])
warnings.warn("`serpapi_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
if "api_key" not in values:
raise ValueError(
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
"ensure that the environment variable SERPAPI_API_KEY is set with your API key. You can obtain "
"an API key from https://serpapi.com/."
"To use serpapi search engine, make sure you provide the `api_key` when constructing an object. You can obtain"
" an API key from https://serpapi.com/."
)
return val
return values
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
@ -60,11 +60,11 @@ class SerpAPIWrapper(BaseModel):
url, params = construct_url_and_params()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
async with session.get(url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get(url, params=params) as response:
async with self.aiosession.get(url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
@ -73,7 +73,7 @@ class SerpAPIWrapper(BaseModel):
def get_params(self, query: str) -> Dict[str, str]:
"""Get parameters for SerpAPI."""
_params = {
"api_key": self.serpapi_api_key,
"api_key": self.api_key,
"q": query,
}
params = {**self.params, **_params}

View file

@ -6,33 +6,34 @@
@File : search_engine_serpapi.py
"""
import json
import warnings
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config2 import config
from pydantic import BaseModel, ConfigDict, Field, model_validator
class SerperWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
search_engine: Any = None #: :meta private:
api_key: str
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
serper_api_key: Optional[str] = Field(default=None, validate_default=True)
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@field_validator("serper_api_key", mode="before")
@model_validator(mode="before")
@classmethod
def check_serper_api_key(cls, val: str):
val = val or config.search.api_key
if not val:
def validate_serper(cls, values: dict) -> dict:
if "serper_api_key" in values:
values.setdefault("api_key", values["serper_api_key"])
warnings.warn("`serper_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
if "api_key" not in values:
raise ValueError(
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
"ensure that the environment variable SERPER_API_KEY is set with your API key. You can obtain "
"To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
"an API key from https://serper.dev/."
)
return val
return values
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through Serper and parse result async."""
@ -54,11 +55,11 @@ class SerperWrapper(BaseModel):
url, payloads, headers = construct_url_and_payload_and_headers()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.post(url, data=payloads, headers=headers) as response:
async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get.post(url, data=payloads, headers=headers) as response:
async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
@ -76,7 +77,7 @@ class SerperWrapper(BaseModel):
return json.dumps(payloads, sort_keys=True)
def get_headers(self) -> Dict[str, str]:
headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"}
headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
return headers
@staticmethod

View file

@ -1,36 +1,95 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import importlib
from typing import Any, Callable, Coroutine, overload
from typing import Any, Callable, Coroutine, Optional, Union, overload
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.tools import WebBrowserEngineType
from metagpt.utils.parse_html import WebPage
class WebBrowserEngine:
def __init__(
self,
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
if engine is None:
raise NotImplementedError
class WebBrowserEngine(BaseModel):
"""Defines a web browser engine configuration for automated browsing and data extraction.
if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT:
This class encapsulates the configuration and operational logic for different web browser engines,
such as Playwright, Selenium, or custom implementations. It provides a unified interface to run
browser automation tasks.
Attributes:
model_config: Configuration dictionary allowing arbitrary types and extra fields.
engine: The type of web browser engine to use.
run_func: An optional coroutine function to run the browser engine.
proxy: An optional proxy server URL to use with the browser engine.
"""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None
proxy: Optional[str] = None
@model_validator(mode="after")
def validate_extra(self):
"""Validates and processes extra configuration data after model initialization.
This method is automatically called by Pydantic to validate and process any extra configuration
data provided to the model. It ensures that the extra data is properly integrated into the model's
configuration and operational logic.
Returns:
The instance itself after processing the extra data.
"""
data = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True)
if self.model_extra:
data.update(self.model_extra)
self._process_extra(**data)
return self
def _process_extra(self, **kwargs):
"""Processes extra configuration data to set up the browser engine run function.
Depending on the specified engine type, this method dynamically imports and configures
the appropriate browser engine wrapper and its run function.
Args:
**kwargs: Arbitrary keyword arguments representing extra configuration data.
Raises:
NotImplementedError: If the engine type is not supported.
"""
if self.engine is WebBrowserEngineType.PLAYWRIGHT:
module = "metagpt.tools.web_browser_engine_playwright"
run_func = importlib.import_module(module).PlaywrightWrapper().run
elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM:
run_func = importlib.import_module(module).PlaywrightWrapper(**kwargs).run
elif self.engine is WebBrowserEngineType.SELENIUM:
module = "metagpt.tools.web_browser_engine_selenium"
run_func = importlib.import_module(module).SeleniumWrapper().run
elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM:
run_func = run_func
run_func = importlib.import_module(module).SeleniumWrapper(**kwargs).run
elif self.engine is WebBrowserEngineType.CUSTOM:
run_func = self.run_func
else:
raise NotImplementedError
self.run_func = run_func
self.engine = engine
@classmethod
def from_browser_config(cls, config: BrowserConfig, **kwargs):
"""Creates a WebBrowserEngine instance from a BrowserConfig object and additional keyword arguments.
This class method facilitates the creation of a WebBrowserEngine instance by extracting
configuration data from a BrowserConfig object and optionally merging it with additional
keyword arguments.
Args:
config: A BrowserConfig object containing base configuration data.
**kwargs: Optional additional keyword arguments to override or extend the configuration.
Returns:
A new instance of WebBrowserEngine configured according to the provided arguments.
"""
data = config.model_dump()
return cls(**data, **kwargs)
@overload
async def run(self, url: str) -> WebPage:
@ -41,4 +100,16 @@ class WebBrowserEngine:
...
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
"""Runs the browser engine to load one or more web pages.
This method is the implementation of the overloaded run signatures. It delegates the task
of loading web pages to the configured run function, handling either a single URL or multiple URLs.
Args:
url: The URL of the first web page to load.
*urls: Additional URLs of web pages to load, if any.
Returns:
A WebPage object if a single URL is provided, or a list of WebPage objects if multiple URLs are provided.
"""
return await self.run_func(url, *urls)

View file

@ -6,16 +6,16 @@ from __future__ import annotations
import asyncio
import sys
from pathlib import Path
from typing import Literal
from typing import Literal, Optional
from playwright.async_api import async_playwright
from pydantic import BaseModel, Field, PrivateAttr
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils.parse_html import WebPage
class PlaywrightWrapper:
class PlaywrightWrapper(BaseModel):
"""Wrapper around Playwright.
To use this module, you should have the `playwright` Python package installed and ensure that
@ -24,24 +24,23 @@ class PlaywrightWrapper:
command `playwright install` for the first time.
"""
def __init__(
self,
browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium",
launch_kwargs: dict | None = None,
**kwargs,
) -> None:
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if config.proxy and "proxy" not in launch_kwargs:
browser_type: Literal["chromium", "firefox", "webkit"] = "chromium"
launch_kwargs: dict = Field(default_factory=dict)
proxy: Optional[str] = None
context_kwargs: dict = Field(default_factory=dict)
_has_run_precheck: bool = PrivateAttr(False)
def __init__(self, **kwargs):
super().__init__(**kwargs)
launch_kwargs = self.launch_kwargs
if self.proxy and "proxy" not in launch_kwargs:
args = launch_kwargs.get("args", [])
if not any(str.startswith(i, "--proxy-server=") for i in args):
launch_kwargs["proxy"] = {"server": config.proxy}
self.launch_kwargs = launch_kwargs
context_kwargs = {}
launch_kwargs["proxy"] = {"server": self.proxy}
if "ignore_https_errors" in kwargs:
context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"]
self._context_kwargs = context_kwargs
self._has_run_precheck = False
self.context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"]
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
async with async_playwright() as ap:
@ -55,7 +54,7 @@ class PlaywrightWrapper:
return await _scrape(browser, url)
async def _scrape(self, browser, url):
context = await browser.new_context(**self._context_kwargs)
context = await browser.new_context(**self.context_kwargs)
page = await context.new_page()
async with page:
try:
@ -75,8 +74,8 @@ class PlaywrightWrapper:
executable_path = Path(browser_type.executable_path)
if not executable_path.exists() and "executable_path" not in self.launch_kwargs:
kwargs = {}
if config.proxy:
kwargs["env"] = {"ALL_PROXY": config.proxy}
if self.proxy:
kwargs["env"] = {"ALL_PROXY": self.proxy}
await _install_browsers(self.browser_type, **kwargs)
if self._has_run_precheck:

View file

@ -7,19 +7,19 @@ import asyncio
import importlib
from concurrent import futures
from copy import deepcopy
from typing import Literal
from typing import Callable, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.core.download_manager import WDMDownloadManager
from webdriver_manager.core.http import WDMHttpClient
from metagpt.config2 import config
from metagpt.utils.parse_html import WebPage
class SeleniumWrapper:
class SeleniumWrapper(BaseModel):
"""Wrapper around Selenium.
To use this module, you should check the following:
@ -31,25 +31,28 @@ class SeleniumWrapper:
can scrape web pages using the Selenium WebBrowserEngine.
"""
def __init__(
self,
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome",
launch_kwargs: dict | None = None,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
) -> None:
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if config.proxy and "proxy-server" not in launch_kwargs:
launch_kwargs["proxy-server"] = config.proxy
model_config = ConfigDict(arbitrary_types_allowed=True)
self.executable_path = launch_kwargs.pop("executable_path", None)
self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()]
self._has_run_precheck = False
self._get_driver = None
self.loop = loop
self.executor = executor
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
launch_kwargs: dict = Field(default_factory=dict)
proxy: Optional[str] = None
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
_has_run_precheck: bool = PrivateAttr(False)
_get_driver: Optional[Callable] = PrivateAttr(None)
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
if self.proxy and "proxy-server" not in self.launch_kwargs:
self.launch_kwargs["proxy-server"] = self.proxy
@property
def launch_args(self):
return [f"--{k}={v}" for k, v in self.launch_kwargs.items() if k != "executable_path"]
@property
def executable_path(self):
return self.launch_kwargs.get("executable_path")
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
await self._run_precheck()
@ -66,7 +69,9 @@ class SeleniumWrapper:
self.loop = self.loop or asyncio.get_event_loop()
self._get_driver = await self.loop.run_in_executor(
self.executor,
lambda: _gen_get_driver_func(self.browser_type, *self.launch_args, executable_path=self.executable_path),
lambda: _gen_get_driver_func(
self.browser_type, *self.launch_args, executable_path=self.executable_path, proxy=self.proxy
),
)
self._has_run_precheck = True
@ -92,13 +97,17 @@ _webdriver_manager_types = {
class WDMHttpProxyClient(WDMHttpClient):
def __init__(self, proxy: Optional[str] = None):
    """Create an HTTP client whose downloads can go through *proxy*."""
    super().__init__()
    # Proxy URL; forwarded to requests as the "all_proxy" entry in get().
    self.proxy = proxy
def get(self, url, **kwargs):
if "proxies" not in kwargs and config.proxy:
kwargs["proxies"] = {"all_proxy": config.proxy}
if "proxies" not in kwargs and self.proxy:
kwargs["proxies"] = {"all_proxy": self.proxy}
return super().get(url, **kwargs)
def _gen_get_driver_func(browser_type, *args, executable_path=None):
def _gen_get_driver_func(browser_type, *args, executable_path=None, proxy=None):
WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver")
Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service")
Options = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.options"), "Options")
@ -106,7 +115,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
if not executable_path:
module_name, type_name = _webdriver_manager_types[browser_type]
DriverManager = getattr(importlib.import_module(module_name), type_name)
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient()))
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient(proxy=proxy)))
# driver_manager.driver_cache.find_driver(driver_manager.driver))
executable_path = driver_manager.install()

View file

@ -12,7 +12,9 @@
from __future__ import annotations
import ast
import base64
import contextlib
import csv
import importlib
import inspect
import json
@ -22,12 +24,15 @@ import re
import sys
import traceback
import typing
from io import BytesIO
from pathlib import Path
from typing import Any, List, Tuple, Union
from typing import Any, Callable, List, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
import loguru
import requests
from PIL import Image
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, RetryError, _utils
@ -336,6 +341,14 @@ def print_members(module, indent=0):
print(f"{prefix}Method: {name}")
def get_function_schema(func: Callable) -> dict[str, Union[dict, Any, str]]:
    """Describe a callable's interface as a plain dict.

    Returns a mapping with the callable's parameter annotations
    ("input_params"), its return annotation ("return_type"), its docstring
    ("func_desc"), and the callable itself ("func").
    """
    signature = inspect.signature(func)
    return {
        "input_params": {name: p.annotation for name, p in signature.parameters.items()},
        "return_type": signature.return_annotation,
        "func_desc": func.__doc__,
        "func": func,
    }
def parse_recipient(text):
# FIXME: use ActionNode instead.
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
@ -583,6 +596,29 @@ def write_json_file(json_file: str, data: list, encoding=None):
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
    """Read a csv file into a list of rows (each row a list of strings).

    Args:
        curr_file: Path to the csv file.
        header: If True, return a ``(header_row, data_rows)`` tuple instead
            of a single list of all rows.
        strip_trail: If True, strip surrounding whitespace from every cell.

    Returns:
        List of rows, or ``(header row, remaining rows)`` when ``header``
        is True.
    """
    logger.debug(f"start read csv: {curr_file}")
    # newline="" is required by the csv module so that quoted fields
    # containing embedded newlines are parsed correctly.
    with open(curr_file, newline="") as f_analysis_file:
        analysis_list = [
            [cell.strip() for cell in row] if strip_trail else row
            for row in csv.reader(f_analysis_file, delimiter=",")
        ]
    if header:
        return analysis_list[0], analysis_list[1:]
    return analysis_list
def import_class(class_name: str, module_name: str) -> type:
module = importlib.import_module(module_name)
a_class = getattr(module, class_name)
@ -748,3 +784,45 @@ async def awrite_bin(filename: str | Path, data: bytes):
pathname.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(str(pathname), mode="wb") as writer:
await writer.write(data)
def is_coroutine_func(func: Callable) -> bool:
    """Return True when *func* was defined with ``async def``."""
    return inspect.iscoroutinefunction(func)
def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None) -> list[str]:
    """Load Minecraft skill snippets from ``.js`` files.

    Args:
        skill_names: Skill (file stem) names to load; when None, every
            ``*.js`` file found in ``skills_dir`` is loaded.
        skills_dir: Directory containing the skill files; defaults to this
            module's directory.

    Returns:
        The source code of each requested skill, one string per skill.

    Raises:
        FileNotFoundError: If a named skill file does not exist.
    """
    if not skills_dir:
        skills_dir = Path(__file__).parent.absolute()
    if skill_names is None:
        # Derive the skill names from the .js files present in the directory.
        skill_names = [path.stem for path in skills_dir.glob("*.js")]
    return [(skills_dir / f"{skill_name}.js").read_text() for skill_name in skill_names]
def encode_image(image_path_or_pil: Union[str, Path, Image.Image], encoding: str = "utf-8") -> str:
    """Encode an image into a base64 string.

    Note: the original annotation used ``Image`` (the PIL *module*) as a
    type; the correct type is ``Image.Image``. Plain ``str`` paths are now
    accepted as well.

    Args:
        image_path_or_pil: A PIL image, or the path of an image file.
        encoding: Text encoding used to decode the base64 bytes.

    Returns:
        The base64-encoded image data.

    Raises:
        FileNotFoundError: If a path is given and no file exists there.
    """
    if isinstance(image_path_or_pil, (str, Path)):
        image_path = Path(image_path_or_pil)
        if not image_path.exists():
            raise FileNotFoundError(f"{image_path_or_pil} not exists")
        bytes_data = image_path.read_bytes()
    else:
        # In-memory PIL image: serialize it to JPEG bytes first.
        buffer = BytesIO()
        image_path_or_pil.save(buffer, format="JPEG")
        bytes_data = buffer.getvalue()
    return base64.b64encode(bytes_data).decode(encoding)
def decode_image(img_url_or_b64: str) -> Image.Image:
    """Decode an image from an http(s) URL or base64 data into a PIL image.

    Args:
        img_url_or_b64: An http(s) image URL, or base64 image data
            (optionally prefixed with a ``data:image/...;base64,`` header).

    Returns:
        The decoded PIL image.
    """
    if img_url_or_b64.startswith("http"):
        # Remote image: download the raw bytes. The timeout keeps a dead
        # server from hanging the caller indefinitely.
        resp = requests.get(img_url_or_b64, timeout=60)
        img = Image.open(BytesIO(resp.content))
    else:
        # Inline base64 data: drop any data-URI header before decoding.
        b64_data = re.sub("^data:image/.+;base64,", "", img_url_or_b64)
        img_data = BytesIO(base64.b64decode(b64_data))
        img = Image.open(img_data)
    return img

View file

@ -60,23 +60,22 @@ class DependencyFile:
root = self._filename.parent
try:
key = Path(filename).relative_to(root)
key = Path(filename).relative_to(root).as_posix()
except ValueError:
key = filename
skey = re.sub(r"\\+", "/", str(key)) # Compatible with windows path
key = str(key)
if dependencies:
relative_paths = []
for i in dependencies:
try:
s = str(Path(i).relative_to(root))
s = str(Path(i).relative_to(root).as_posix())
except ValueError:
s = str(i)
s = re.sub(r"\\+", "/", s) # Compatible with windows path
relative_paths.append(s)
self._dependencies[skey] = relative_paths
elif skey in self._dependencies:
del self._dependencies[skey]
self._dependencies[key] = relative_paths
elif key in self._dependencies:
del self._dependencies[key]
if persist:
await self.save()
@ -93,7 +92,7 @@ class DependencyFile:
root = self._filename.parent
try:
key = Path(filename).relative_to(root)
key = Path(filename).relative_to(root).as_posix()
except ValueError:
key = filename
return set(self._dependencies.get(str(key), {}))

View file

@ -29,6 +29,7 @@ TOKEN_COSTS = {
"gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
@ -57,6 +58,7 @@ TOKEN_MAX = {
"gpt-4-turbo-preview": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4-1106-vision-preview": 128000,
"text-embedding-ada-002": 8192,
"chatglm_turbo": 32768,
@ -85,6 +87,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4-1106-vision-preview",
}:
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
@ -115,7 +118,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
content = value
if isinstance(value, list):
# for gpt-4v
for item in value:
if isinstance(item, dict) and item.get("type") in ["text"]:
content = item.get("text", "")
num_tokens += len(encoding.encode(content))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>