diff --git a/.github/workflows/auto-unittest.yaml b/.github/workflows/fulltest.yaml similarity index 57% rename from .github/workflows/auto-unittest.yaml rename to .github/workflows/fulltest.yaml index a58163b4d..68d3c382f 100644 --- a/.github/workflows/auto-unittest.yaml +++ b/.github/workflows/fulltest.yaml @@ -1,16 +1,19 @@ -name: Auto Unit Tests +name: Unit Tests on: + workflow_dispatch: pull_request_target: push: branches: - 'main' - 'dev' - '*-release' + - '*-debugger' jobs: build: runs-on: ubuntu-latest + environment: unittest strategy: matrix: # python-version: ['3.9', '3.10', '3.11'] @@ -28,10 +31,30 @@ jobs: - name: Install dependencies run: | sh tests/scripts/run_install_deps.sh + - name: Run reverse proxy script for ssh service + if: contains(github.ref, '-debugger') + continue-on-error: true + env: + FPR_SERVER_ADDR: ${{ secrets.FPR_SERVER_ADDR }} + FPR_TOKEN: ${{ secrets.FPR_TOKEN }} + FPR_SSH_REMOTE_PORT: ${{ secrets.FPR_SSH_REMOTE_PORT }} + RSA_PUB: ${{ secrets.RSA_PUB }} + SSH_PORT: ${{ vars.SSH_PORT || '22'}} + run: | + echo "Run \"ssh $(whoami)@FPR_SERVER_HOST -p FPR_SSH_REMOTE_PORT\" and \"cd $(pwd)\"" + mkdir -p ~/.ssh/ + echo $RSA_PUB >> ~/.ssh/authorized_keys + chmod 600 ~/.ssh/authorized_keys + wget https://github.com/fatedier/frp/releases/download/v0.32.1/frp_0.32.1_linux_amd64.tar.gz -O frp.tar.gz + tar xvzf frp.tar.gz -C /opt + mv /opt/frp* /opt/frp + /opt/frp/frpc tcp --server_addr $FPR_SERVER_ADDR --token $FPR_TOKEN --local_port $SSH_PORT --remote_port $FPR_SSH_REMOTE_PORT - name: Test with pytest run: | export ALLOW_OPENAI_API_CALL=0 - mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml + echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml + mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml + echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt - name: Show coverage report run: | diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 68d3c382f..2e7e3ce2b 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -1,19 +1,16 @@ name: Unit Tests on: - workflow_dispatch: pull_request_target: push: branches: - 'main' - 'dev' - '*-release' - - '*-debugger' jobs: build: runs-on: ubuntu-latest - environment: unittest strategy: matrix: # python-version: ['3.9', '3.10', '3.11'] @@ -31,30 +28,10 @@ jobs: - name: Install dependencies run: | sh tests/scripts/run_install_deps.sh - - name: Run reverse proxy script for ssh service - if: contains(github.ref, '-debugger') - continue-on-error: true - env: - FPR_SERVER_ADDR: ${{ secrets.FPR_SERVER_ADDR }} - FPR_TOKEN: ${{ secrets.FPR_TOKEN }} - FPR_SSH_REMOTE_PORT: ${{ secrets.FPR_SSH_REMOTE_PORT }} - RSA_PUB: ${{ secrets.RSA_PUB }} - SSH_PORT: ${{ vars.SSH_PORT || '22'}} - run: | - echo "Run \"ssh $(whoami)@FPR_SERVER_HOST -p FPR_SSH_REMOTE_PORT\" and \"cd $(pwd)\"" - mkdir -p ~/.ssh/ - echo $RSA_PUB >> ~/.ssh/authorized_keys - chmod 600 ~/.ssh/authorized_keys - wget https://github.com/fatedier/frp/releases/download/v0.32.1/frp_0.32.1_linux_amd64.tar.gz -O frp.tar.gz - tar xvzf frp.tar.gz -C /opt - mv /opt/frp* /opt/frp - /opt/frp/frpc tcp --server_addr $FPR_SERVER_ADDR --token $FPR_TOKEN --local_port $SSH_PORT --remote_port $FPR_SSH_REMOTE_PORT - name: Test with pytest run: | export ALLOW_OPENAI_API_CALL=0 - echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml - mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml - echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml + mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt - name: Show coverage report run: | diff --git a/config/config2.yaml.example b/config/config2.yaml.example index bead3c626..8f4a33fc1 100644 --- a/config/config2.yaml.example +++ b/config/config2.yaml.example @@ -11,6 +11,10 @@ search: api_key: "YOUR_API_KEY" cse_id: "YOUR_CSE_ID" +browser: + engine: "playwright" # playwright/selenium + browser_type: "chromium" # playwright: chromium/firefox/webkit; selenium: chrome/firefox/edge/ie + mermaid: engine: "pyppeteer" path: "/Applications/Google Chrome.app" diff --git a/examples/dalle_gpt4v_agent.py b/examples/dalle_gpt4v_agent.py new file mode 100644 index 000000000..28215dba3 --- /dev/null +++ b/examples/dalle_gpt4v_agent.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : use gpt4v to improve prompt and draw image with dall-e-3 + +"""set `model: "gpt-4-vision-preview"` in `config2.yaml` first""" + +import asyncio + +from PIL import Image + +from metagpt.actions.action import Action +from metagpt.logs import logger +from metagpt.roles.role import Role +from metagpt.schema import Message +from metagpt.utils.common import encode_image + + +class GenAndImproveImageAction(Action): + save_image: bool = True + + async def generate_image(self, prompt: str) -> Image: + imgs = await self.llm.gen_image(model="dall-e-3", prompt=prompt) + return imgs[0] + + async def refine_prompt(self, old_prompt: str, image: Image) -> str: + msg = ( + f"You are a creative painter, with the given generated image and old prompt: {old_prompt}, " + f"please refine the prompt and generate new one. Just output the new prompt." + ) + b64_img = encode_image(image) + new_prompt = await self.llm.aask(msg=msg, images=[b64_img]) + return new_prompt + + async def evaluate_images(self, old_prompt: str, images: list[Image]) -> str: + msg = ( + "With the prompt and two generated image, to judge if the second one is better than the first one. " + "If so, just output True else output False" + ) + b64_imgs = [encode_image(img) for img in images] + res = await self.llm.aask(msg=msg, images=b64_imgs) + return res + + async def run(self, messages: list[Message]) -> str: + prompt = messages[-1].content + + old_img: Image = await self.generate_image(prompt) + new_prompt = await self.refine_prompt(old_prompt=prompt, image=old_img) + logger.info(f"original prompt: {prompt}") + logger.info(f"refined prompt: {new_prompt}") + new_img: Image = await self.generate_image(new_prompt) + if self.save_image: + old_img.save("./img_by-dall-e_old.png") + new_img.save("./img_by-dall-e_new.png") + res = await self.evaluate_images(old_prompt=prompt, images=[old_img, new_img]) + opinion = f"The second generated image is better than the first one: {res}" + logger.info(f"evaluate opinion: {opinion}") + return opinion + + +class Painter(Role): + name: str = "MaLiang" + profile: str = "Painter" + goal: str = "to generate fine painting" + + def __init__(self, **data): + super().__init__(**data) + + self.set_actions([GenAndImproveImageAction]) + + +async def main(): + role = Painter() + await role.run(with_message="a girl with flowers") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 219a303c8..1d132eb8a 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -6,9 +6,11 @@ @File : llm_hello_world.py """ import asyncio +from pathlib import Path from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.utils.common import encode_image async def main(): @@ -27,6 +29,12 @@ async def main(): if hasattr(llm, "completion"): logger.info(llm.completion(hello_msg)) + # check if the configured llm supports llm-vision capacity. If not, it will throw a error + invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") + img_base64 = encode_image(invoice_path) + res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64]) + assert "true" in res.lower() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/search_with_specific_engine.py b/examples/search_with_specific_engine.py index 9406a2965..97b1378ee 100644 --- a/examples/search_with_specific_engine.py +++ b/examples/search_with_specific_engine.py @@ -5,17 +5,20 @@ import asyncio from metagpt.roles import Searcher -from metagpt.tools import SearchEngineType +from metagpt.tools.search_engine import SearchEngine, SearchEngineType async def main(): question = "What are the most interesting human facts?" + kwargs = {"api_key": "", "cse_id": "", "proxy": None} # Serper API - # await Searcher(engine=SearchEngineType.SERPER_GOOGLE).run(question) + # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPER_GOOGLE, **kwargs)).run(question) # SerpAPI - await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run(question) + # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPAPI_GOOGLE, **kwargs)).run(question) # Google API - # await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run(question) + # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DIRECT_GOOGLE, **kwargs)).run(question) + # DDG API + await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO, **kwargs)).run(question) if __name__ == "__main__": diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index a3efb214e..09da4a988 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -413,12 +413,13 @@ class ActionNode: prompt: str, output_class_name: str, output_data_mapping: dict, + images: Optional[Union[str, list[str]]] = None, system_msgs: Optional[list[str]] = None, schema="markdown", # compatible to original format timeout=3, ) -> (str, BaseModel): """Use ActionOutput to wrap the output of aask""" - content = await self.llm.aask(prompt, system_msgs, timeout=timeout) + content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout) logger.debug(f"llm raw output:\n{content}") output_class = self.create_model_class(output_class_name, output_data_mapping) @@ -447,13 +448,15 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, schema, mode, timeout=3, exclude=None): + async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None): prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) if schema != "raw": mapping = self.get_mapping(mode, exclude=exclude) class_name = f"{self.key}_AN" - content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) + content, scontent = await self._aask_v1( + prompt, class_name, mapping, images=images, schema=schema, timeout=timeout + ) self.content = content self.instruct_content = scontent else: @@ -462,7 +465,17 @@ class ActionNode: return self - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]): + async def fill( + self, + context, + llm, + schema="json", + mode="auto", + strgy="simple", + images: Optional[Union[str, list[str]]] = None, + timeout=3, + exclude=[], + ): """Fill the node(s) with mode. :param context: Everything we should know when filling node. @@ -478,6 +491,7 @@ class ActionNode: :param strgy: simple/complex - simple: run only once - complex: run each node + :param images: the list of image url or base64 for gpt4-v :param timeout: Timeout for llm invocation. :param exclude: The keys of ActionNode to exclude. :return: self @@ -488,14 +502,14 @@ class ActionNode: schema = self.schema if strgy == "simple": - return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): if exclude and i.key in exclude: continue - child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + child = await i.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude) tmp.update(child.instruct_content.model_dump()) cls = self._create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 2755628c9..2ebeadb66 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -3,15 +3,15 @@ from __future__ import annotations import asyncio -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union -from pydantic import Field, parse_obj_as +from pydantic import TypeAdapter, model_validator from metagpt.actions import Action from metagpt.config2 import config from metagpt.logs import logger from metagpt.tools.search_engine import SearchEngine -from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType +from metagpt.tools.web_browser_engine import WebBrowserEngine from metagpt.utils.common import OutputParser from metagpt.utils.text import generate_prompt_chunk, reduce_message_length @@ -81,10 +81,16 @@ class CollectLinks(Action): name: str = "CollectLinks" i_context: Optional[str] = None desc: str = "Collect links from a search engine." - - search_engine: SearchEngine = Field(default_factory=SearchEngine) + search_func: Optional[Any] = None + search_engine: Optional[SearchEngine] = None rank_func: Optional[Callable[[list[str]], None]] = None + @model_validator(mode="after") + def validate_engine_and_run_func(self): + if self.search_engine is None: + self.search_engine = SearchEngine.from_search_config(self.config.search, proxy=self.config.proxy) + return self + async def run( self, topic: str, @@ -107,7 +113,7 @@ class CollectLinks(Action): keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text]) try: keywords = OutputParser.extract_struct(keywords, list) - keywords = parse_obj_as(list[str], keywords) + keywords = TypeAdapter(list[str]).validate_python(keywords) except Exception as e: logger.exception(f"fail to get keywords related to the research topic '{topic}' for {e}") keywords = [topic] @@ -133,7 +139,7 @@ class CollectLinks(Action): queries = await self._aask(prompt, [system_text]) try: queries = OutputParser.extract_struct(queries, list) - queries = parse_obj_as(list[str], queries) + queries = TypeAdapter(list[str]).validate_python(queries) except Exception as e: logger.exception(f"fail to break down the research question due to {e}") queries = keywords @@ -178,15 +184,17 @@ class WebBrowseAndSummarize(Action): i_context: Optional[str] = None desc: str = "Explore the web and provide summaries of articles and webpages." browse_func: Union[Callable[[list[str]], None], None] = None - web_browser_engine: Optional[WebBrowserEngine] = WebBrowserEngineType.PLAYWRIGHT + web_browser_engine: Optional[WebBrowserEngine] = None - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.web_browser_engine = WebBrowserEngine( - engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT, - run_func=self.browse_func, - ) + @model_validator(mode="after") + def validate_engine_and_run_func(self): + if self.web_browser_engine is None: + self.web_browser_engine = WebBrowserEngine.from_browser_config( + self.config.browser, + browse_func=self.browse_func, + proxy=self.config.proxy, + ) + return self async def run( self, diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 59b35cd58..7eed7381b 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -5,7 +5,7 @@ @Author : alexanderwu @File : search_google.py """ -from typing import Any, Optional +from typing import Optional import pydantic from pydantic import model_validator @@ -13,7 +13,6 @@ from pydantic import model_validator from metagpt.actions import Action from metagpt.logs import logger from metagpt.schema import Message -from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements @@ -105,21 +104,19 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - engine: Optional[SearchEngineType] = None - search_func: Optional[Any] = None search_engine: SearchEngine = None result: str = "" @model_validator(mode="after") - def validate_engine_and_run_func(self): - if self.engine is None: - self.engine = self.config.search_engine - try: - search_engine = SearchEngine(engine=self.engine, run_func=self.search_func) - except pydantic.ValidationError: - search_engine = None + def validate_search_engine(self): + if self.search_engine is None: + try: + config = self.config + search_engine = SearchEngine.from_search_config(config.search, proxy=config.proxy) + except pydantic.ValidationError: + search_engine = None - self.search_engine = search_engine + self.search_engine = search_engine return self async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: diff --git a/metagpt/config2.py b/metagpt/config2.py index bbf574818..d983a43c3 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -51,7 +51,7 @@ class Config(CLIParams, YamlModel): proxy: str = "" # Tool Parameters - search: Optional[SearchConfig] = None + search: SearchConfig = SearchConfig() browser: BrowserConfig = BrowserConfig() mermaid: MermaidConfig = MermaidConfig() diff --git a/metagpt/configs/browser_config.py b/metagpt/configs/browser_config.py index 00f918735..2f8024f44 100644 --- a/metagpt/configs/browser_config.py +++ b/metagpt/configs/browser_config.py @@ -15,6 +15,6 @@ class BrowserConfig(YamlModel): """Config for Browser""" engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT - browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome" - driver: Literal["chromium", "firefox", "webkit"] = "chromium" - path: str = "" + browser_type: Literal["chromium", "firefox", "webkit", "chrome", "firefox", "edge", "ie"] = "chromium" + """If the engine is Playwright, the value should be one of "chromium", "firefox", or "webkit". If it is Selenium, the value + should be either "chrome", "firefox", "edge", or "ie".""" diff --git a/metagpt/configs/search_config.py b/metagpt/configs/search_config.py index a8ae918db..af928b02a 100644 --- a/metagpt/configs/search_config.py +++ b/metagpt/configs/search_config.py @@ -5,6 +5,8 @@ @Author : alexanderwu @File : search_config.py """ +from typing import Callable, Optional + from metagpt.tools import SearchEngineType from metagpt.utils.yaml_model import YamlModel @@ -12,6 +14,7 @@ from metagpt.utils.yaml_model import YamlModel class SearchConfig(YamlModel): """Config for Search""" - api_key: str - api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE + api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO + api_key: str = "" cse_id: str = "" # for google + search_func: Optional[Callable] = None diff --git a/metagpt/context_mixin.py b/metagpt/context_mixin.py index bdf2d0734..59daa692f 100644 --- a/metagpt/context_mixin.py +++ b/metagpt/context_mixin.py @@ -7,7 +7,7 @@ """ from typing import Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from metagpt.config2 import Config from metagpt.context import Context @@ -17,7 +17,7 @@ from metagpt.provider.base_llm import BaseLLM class ContextMixin(BaseModel): """Mixin class for context and config""" - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") # Pydantic has bug on _private_attr when using inheritance, so we use private_* instead # - https://github.com/pydantic/pydantic/issues/7142 @@ -32,18 +32,17 @@ class ContextMixin(BaseModel): # Env/Role/Action will use this llm as private llm, or use self.context._llm instance private_llm: Optional[BaseLLM] = Field(default=None, exclude=True) - def __init__( - self, - context: Optional[Context] = None, - config: Optional[Config] = None, - llm: Optional[BaseLLM] = None, - **kwargs, - ): - """Initialize with config""" - super().__init__(**kwargs) - self.set_context(context) - self.set_config(config) - self.set_llm(llm) + @model_validator(mode="after") + def validate_context_mixin_extra(self): + self._process_context_mixin_extra() + return self + + def _process_context_mixin_extra(self): + """Process the extra field""" + kwargs = self.model_extra or {} + self.set_context(kwargs.pop("context", None)) + self.set_config(kwargs.pop("config", None)) + self.set_llm(kwargs.pop("llm", None)) def set(self, k, v, override=False): """Set attribute""" diff --git a/metagpt/environment/README.md b/metagpt/environment/README.md new file mode 100644 index 000000000..9476ac75a --- /dev/null +++ b/metagpt/environment/README.md @@ -0,0 +1,38 @@ +Here is a environment description of MetaGPT env for different situation. +For now, the code only define the environment and still some todos like migrate roles/actions to current version. + +## Function +- Define `ExtEnv`(Base Class) which help users to integrate with external environment like games through apis or construct the game logics. +- Define `Environment`(Base Class) which is the env that MetaGPT directly used. And it includes roles and so on. +- Define the `EnvAPIRegistry` to mark the read/write apis that `ExtEnv` provide observe/step ability. And then, users can call the particular one to get observation from env or feedback to env. + +## Usage + +init environment +``` +android_env = env.create(EnvType.ANDROID) + +assistant = Role(name="Bob", profile="android assistant") +team = Team(investment=10.0, env=android_env, roles=[assistant]) +``` + +observe & step inside role's actions +``` +from metagpt.environment.api.env_api import EnvAPIAbstract + +# get screenshot from ExtEnv +screenshot_path: Path = env.observe( + EnvAPIAbstract( + api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} + ) + ) + +# do a `tap` action on the screen +res = env.step(EnvAPIAbstract("system_tap", kwargs={"x": x, "y": y})) +``` + +## TODO +- add android app operation assistant under `examples/android_assistant` +- migrate roles/actions of werewolf game from old version into current version +- migrate roles/actions of mincraft game from old version into current version +- migrate roles/actions of stanford_town game from old version into current version diff --git a/metagpt/environment/__init__.py b/metagpt/environment/__init__.py new file mode 100644 index 000000000..692672fa7 --- /dev/null +++ b/metagpt/environment/__init__.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from metagpt.environment.base_env import Environment +from metagpt.environment.android_env.android_env import AndroidEnv +from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv +from metagpt.environment.werewolf_env.werewolf_env import WerewolfEnv +from metagpt.environment.stanford_town_env.stanford_town_env import StanfordTownEnv +from metagpt.environment.software_env.software_env import SoftwareEnv + + +__all__ = ["AndroidEnv", "MincraftExtEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"] diff --git a/metagpt/environment/android_env/__init__.py b/metagpt/environment/android_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/metagpt/environment/android_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/android_env/android_env.py b/metagpt/environment/android_env/android_env.py new file mode 100644 index 000000000..c27e20541 --- /dev/null +++ b/metagpt/environment/android_env/android_env.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Android Env + +from pydantic import Field + +from metagpt.environment.android_env.android_ext_env import AndroidExtEnv +from metagpt.environment.base_env import Environment + + +class AndroidEnv(Environment, AndroidExtEnv): + rows: int = Field(default=0, description="rows of a grid on the screenshot") + cols: int = Field(default=0, description="cols of a grid on the screenshot") diff --git a/metagpt/environment/android_env/android_ext_env.py b/metagpt/environment/android_env/android_ext_env.py new file mode 100644 index 000000000..b81b2cd26 --- /dev/null +++ b/metagpt/environment/android_env/android_ext_env.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The Android external environment to integrate with Android apps + +import subprocess +from pathlib import Path +from typing import Any, Optional + +from pydantic import Field + +from metagpt.environment.android_env.const import ADB_EXEC_FAIL +from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable + + +class AndroidExtEnv(ExtEnv): + device_id: Optional[str] = Field(default=None) + screenshot_dir: Optional[Path] = Field(default=None) + xml_dir: Optional[Path] = Field(default=None) + width: int = Field(default=720, description="device screen width") + height: int = Field(default=1080, description="device screen height") + + def __init__(self, **data: Any): + super().__init__(**data) + if data.get("device_id"): + (width, height) = self.device_shape + self.width = data.get("width", width) + self.height = data.get("height", height) + + @property + def adb_prefix_si(self): + """adb cmd prefix with `device_id` and `shell input`""" + return f"adb -s {self.device_id} shell input " + + @property + def adb_prefix_shell(self): + """adb cmd prefix with `device_id` and `shell`""" + return f"adb -s {self.device_id} shell " + + @property + def adb_prefix(self): + """adb cmd prefix with `device_id`""" + return f"adb -s {self.device_id} " + + def execute_adb_with_cmd(self, adb_cmd: str) -> str: + res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + exec_res = ADB_EXEC_FAIL + if not res.returncode: + exec_res = res.stdout.strip() + return exec_res + + @property + def device_shape(self) -> tuple[int, int]: + adb_cmd = f"{self.adb_prefix_shell} wm size" + shape = (0, 0) + shape_res = self.execute_adb_with_cmd(adb_cmd) + if shape_res != ADB_EXEC_FAIL: + shape = tuple(map(int, shape_res.split(": ")[1].split("x"))) + return shape + + def list_devices(self): + adb_cmd = "adb devices" + res = self.execute_adb_with_cmd(adb_cmd) + devices = [] + if res != ADB_EXEC_FAIL: + devices = res.split("\n")[1:] + devices = [device.split()[0] for device in devices] + return devices + + @mark_as_readable + def get_screenshot(self, ss_name: str, local_save_dir: Path) -> Path: + """ + ss_name: screenshot file name + local_save_dir: local dir to store image from virtual machine + """ + assert self.screenshot_dir + ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png") + ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}" + ss_res = self.execute_adb_with_cmd(ss_cmd) + + res = ADB_EXEC_FAIL + if ss_res != ADB_EXEC_FAIL: + ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png") + pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}" + pull_res = self.execute_adb_with_cmd(pull_cmd) + if pull_res != ADB_EXEC_FAIL: + res = ss_local_path + return Path(res) + + @mark_as_readable + def get_xml(self, xml_name: str, local_save_dir: Path) -> Path: + xml_remote_path = Path(self.xml_dir).joinpath(f"{xml_name}.xml") + dump_cmd = f"{self.adb_prefix_shell} uiautomator dump {xml_remote_path}" + xml_res = self.execute_adb_with_cmd(dump_cmd) + + res = ADB_EXEC_FAIL + if xml_res != ADB_EXEC_FAIL: + xml_local_path = Path(local_save_dir).joinpath(f"{xml_name}.xml") + pull_cmd = f"{self.adb_prefix} pull {xml_remote_path} {xml_local_path}" + pull_res = self.execute_adb_with_cmd(pull_cmd) + if pull_res != ADB_EXEC_FAIL: + res = xml_local_path + return Path(res) + + @mark_as_writeable + def system_back(self) -> str: + adb_cmd = f"{self.adb_prefix_si} keyevent KEYCODE_BACK" + back_res = self.execute_adb_with_cmd(adb_cmd) + return back_res + + @mark_as_writeable + def system_tap(self, x: int, y: int) -> str: + adb_cmd = f"{self.adb_prefix_si} tap {x} {y}" + tap_res = self.execute_adb_with_cmd(adb_cmd) + return tap_res + + @mark_as_writeable + def user_input(self, input_txt: str) -> str: + input_txt = input_txt.replace(" ", "%s").replace("'", "") + adb_cmd = f"{self.adb_prefix_si} text {input_txt}" + input_res = self.execute_adb_with_cmd(adb_cmd) + return input_res + + @mark_as_writeable + def user_longpress(self, x: int, y: int, duration: int = 500) -> str: + adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x} {y} {duration}" + press_res = self.execute_adb_with_cmd(adb_cmd) + return press_res + + @mark_as_writeable + def user_swipe(self, x: int, y: int, orient: str = "up", dist: str = "medium", if_quick: bool = False) -> str: + dist_unit = int(self.width / 10) + if dist == "long": + dist_unit *= 3 + elif dist == "medium": + dist_unit *= 2 + + if orient == "up": + offset = 0, -2 * dist_unit + elif orient == "down": + offset = 0, 2 * dist_unit + elif orient == "left": + offset = -1 * dist_unit, 0 + elif orient == "right": + offset = dist_unit, 0 + else: + return ADB_EXEC_FAIL + + duration = 100 if if_quick else 400 + adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x + offset[0]} {y + offset[1]} {duration}" + swipe_res = self.execute_adb_with_cmd(adb_cmd) + return swipe_res + + @mark_as_writeable + def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400): + adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}" + swipe_res = self.execute_adb_with_cmd(adb_cmd) + return swipe_res diff --git a/metagpt/environment/android_env/const.py b/metagpt/environment/android_env/const.py new file mode 100644 index 000000000..8811289bf --- /dev/null +++ b/metagpt/environment/android_env/const.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +# For Android Assistant Agent +ADB_EXEC_FAIL = "FAILED" diff --git a/metagpt/environment/api/__init__.py b/metagpt/environment/api/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/metagpt/environment/api/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/api/env_api.py b/metagpt/environment/api/env_api.py new file mode 100644 index 000000000..1e6df544d --- /dev/null +++ b/metagpt/environment/api/env_api.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the environment api store + +from typing import Any, Callable, Union + +from pydantic import BaseModel, Field + + +class EnvAPIAbstract(BaseModel): + """api/interface summary description""" + + api_name: str = Field(default="", description="the api function name or id") + args: set = Field(default={}, description="the api function `args` params") + kwargs: dict = Field(default=dict(), description="the api function `kwargs` params") + + +class EnvAPIRegistry(BaseModel): + """the registry to store environment w&r api/interface""" + + registry: dict[str, dict[str, Union[dict, Any, str]]] = Field(default=dict(), exclude=True) + + def get(self, api_name: str): + if api_name not in self.registry: + raise ValueError + return self.registry.get(api_name) + + def __getitem__(self, api_name: str) -> Callable: + return self.get(api_name) + + def __setitem__(self, api_name: str, func: Callable): + self.registry[api_name] = func + + def __len__(self): + return len(self.registry) + + def get_apis(self, as_str=True) -> dict[str, dict[str, Union[dict, Any, str]]]: + """return func schema without func instance""" + apis = dict() + for func_name, func_schema in self.registry.items(): + new_func_schema = dict() + for key, value in func_schema.items(): + if key == "func": + continue + new_func_schema[key] = str(value) if as_str else value + new_func_schema = new_func_schema + apis[func_name] = new_func_schema + return apis + + +class WriteAPIRegistry(EnvAPIRegistry): + """just as a explicit class name""" + + pass + + +class ReadAPIRegistry(EnvAPIRegistry): + """just as a explicit class name""" + + pass diff --git a/metagpt/environment.py b/metagpt/environment/base_env.py similarity index 50% rename from metagpt/environment.py rename to metagpt/environment/base_env.py index 5a2dd339b..9c195b023 100644 --- a/metagpt/environment.py +++ b/metagpt/environment/base_env.py @@ -1,29 +1,101 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 22:12 -@Author : alexanderwu -@File : environment.py -@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.2 of RFC 116: - 1. Remove the functionality of `Environment` class as a public message buffer. - 2. Standardize the message forwarding behavior of the `Environment` class. - 3. Add the `is_idle` property. -@Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing - functionality is to be consolidated into the `Environment` class. -""" +# @Desc : base env of executing environment + import asyncio -from typing import Iterable, Set +from enum import Enum +from typing import TYPE_CHECKING, Any, Iterable, Optional, Set, Union from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator from metagpt.context import Context +from metagpt.environment.api.env_api import ( + EnvAPIAbstract, + ReadAPIRegistry, + WriteAPIRegistry, +) from metagpt.logs import logger -from metagpt.roles.role import Role from metagpt.schema import Message -from metagpt.utils.common import is_send_to +from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to + +if TYPE_CHECKING: + from metagpt.roles.role import Role # noqa: F401 -class Environment(BaseModel): +class EnvType(Enum): + ANDROID = "Android" + GYM = "Gym" + WEREWOLF = "Werewolf" + MINCRAFT = "Mincraft" + STANFORDTOWN = "StanfordTown" + + +env_write_api_registry = WriteAPIRegistry() +env_read_api_registry = ReadAPIRegistry() + + +def mark_as_readable(func): + """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv""" + env_read_api_registry[func.__name__] = get_function_schema(func) + return func + + +def mark_as_writeable(func): + """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv""" + env_write_api_registry[func.__name__] = get_function_schema(func) + return func + + +class ExtEnv(BaseModel): + """External Env to intergate actual game environment""" + + def _check_api_exist(self, rw_api: Optional[str] = None): + if not rw_api: + raise ValueError(f"{rw_api} not exists") + + def get_all_available_apis(self, mode: str = "read") -> list[Any]: + """get available read/write apis definition""" + assert mode in ["read", "write"] + if mode == "read": + return env_read_api_registry.get_apis() + else: + return env_write_api_registry.get_apis() + + async def observe(self, env_action: Union[str, EnvAPIAbstract]): + """get observation from particular api of ExtEnv""" + if isinstance(env_action, str): + read_api = env_read_api_registry.get(api_name=env_action)["func"] + self._check_api_exist(read_api) + if is_coroutine_func(read_api): + res = await read_api(self) + else: + res = read_api(self) + elif isinstance(env_action, EnvAPIAbstract): + read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"] + self._check_api_exist(read_api) + if is_coroutine_func(read_api): + res = await read_api(self, *env_action.args, **env_action.kwargs) + else: + res = read_api(self, *env_action.args, **env_action.kwargs) + return res + + async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): + """execute through particular api of ExtEnv""" + res = None + if isinstance(env_action, Message): + self.publish_message(env_action) + elif isinstance(env_action, EnvAPIAbstract): + write_api = env_write_api_registry.get(env_action.api_name)["func"] + self._check_api_exist(write_api) + if is_coroutine_func(write_api): + res = await write_api(self, *env_action.args, **env_action.kwargs) + else: + res = write_api(self, *env_action.args, **env_action.kwargs) + + return res + + +class Environment(ExtEnv): """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到 Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles """ @@ -31,8 +103,8 @@ class Environment(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) desc: str = Field(default="") # 环境描述 - roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True) - member_addrs: dict[Role, Set] = Field(default_factory=dict, exclude=True) + roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True) + member_addrs: dict["Role", Set] = Field(default_factory=dict, exclude=True) history: str = "" # For debug context: Context = Field(default_factory=Context, exclude=True) @@ -41,7 +113,7 @@ class Environment(BaseModel): self.add_roles(self.roles.values()) return self - def add_role(self, role: Role): + def add_role(self, role: "Role"): """增加一个在当前环境的角色 Add a role in the current environment """ @@ -49,7 +121,7 @@ class Environment(BaseModel): role.set_env(self) role.context = self.context - def add_roles(self, roles: Iterable[Role]): + def add_roles(self, roles: Iterable["Role"]): """增加一批在当前环境的角色 Add a batch of characters in the current environment """ @@ -95,13 +167,13 @@ class Environment(BaseModel): await asyncio.gather(*futures) logger.debug(f"is idle: {self.is_idle}") - def get_roles(self) -> dict[str, Role]: + def get_roles(self) -> dict[str, "Role"]: """获得环境内的所有角色 Process all Role runs at once """ return self.roles - def get_role(self, name: str) -> Role: + def get_role(self, name: str) -> "Role": """获得环境内的指定角色 get all the environment roles """ @@ -129,3 +201,12 @@ class Environment(BaseModel): def archive(self, auto_archive=True): if auto_archive and self.context.git_repo: self.context.git_repo.archive() + + @classmethod + def model_rebuild(cls, **kwargs): + from metagpt.roles.role import Role # noqa: F401 + + super().model_rebuild(**kwargs) + + +Environment.model_rebuild() diff --git a/metagpt/environment/mincraft_env/__init__.py b/metagpt/environment/mincraft_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/metagpt/environment/mincraft_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/mincraft_env/const.py b/metagpt/environment/mincraft_env/const.py new file mode 100644 index 000000000..a7222f9cd --- /dev/null +++ b/metagpt/environment/mincraft_env/const.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from metagpt.const import METAGPT_ROOT + +# For Mincraft Game Agent +MC_CKPT_DIR = METAGPT_ROOT / "data/mincraft/ckpt" +MC_LOG_DIR = METAGPT_ROOT / "logs" +MC_DEFAULT_WARMUP = { + "context": 15, + "biome": 10, + "time": 15, + "nearby_blocks": 0, + "other_blocks": 10, + "nearby_entities": 5, + "health": 15, + "hunger": 15, + "position": 0, + "equipment": 0, + "inventory": 0, + "optional_inventory_items": 7, + "chests": 0, + "completed_tasks": 0, + "failed_tasks": 0, +} +MC_CURRICULUM_OB = [ + "context", + "biome", + "time", + "nearby_blocks", + "other_blocks", + "nearby_entities", + "health", + "hunger", + "position", + "equipment", + "inventory", + "chests", + "completed_tasks", + "failed_tasks", +] +MC_CORE_INVENTORY_ITEMS = r".*_log|.*_planks|stick|crafting_table|furnace" +r"|cobblestone|dirt|coal|.*_pickaxe|.*_sword|.*_axe", # curriculum_agent: only show these items in inventory before optional_inventory_items reached in warm up diff --git a/metagpt/environment/mincraft_env/mincraft_env.py b/metagpt/environment/mincraft_env/mincraft_env.py new file mode 100644 index 000000000..6327aa3f4 --- /dev/null +++ b/metagpt/environment/mincraft_env/mincraft_env.py @@ -0,0 +1,391 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Mincraft Env +# refs to `voyager voyager.py` + +import json +import re +import time +from typing import Any, Iterable + +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.vectorstores import Chroma +from pydantic import ConfigDict, Field + +from metagpt.config2 import config as CONFIG +from metagpt.environment.base_env import Environment +from metagpt.environment.mincraft_env.const import MC_CKPT_DIR +from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv +from metagpt.logs import logger +from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file + + +class MincraftEnv(Environment, MincraftExtEnv): + """MincraftEnv, including shared memory of cache and infomation between roles""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + event: dict[str, Any] = Field(default_factory=dict) + current_task: str = Field(default="Mine 1 wood log") + task_execution_time: float = Field(default=float) + context: str = Field(default="You can mine one of oak, birch, spruce, jungle, acacia, dark oak, or mangrove logs.") + code: str = Field(default="") + program_code: str = Field(default="") # write in skill/code/*.js + program_name: str = Field(default="") + critique: str = Field(default="") + skills: dict = Field(default_factory=dict) # for skills.json + retrieve_skills: list[str] = Field(default_factory=list) + event_summary: str = Field(default="") + + qa_cache: dict[str, str] = Field(default_factory=dict) + completed_tasks: list[str] = Field(default_factory=list) # Critique things + failed_tasks: list[str] = Field(default_factory=list) + + skill_desp: str = Field(default="") + + chest_memory: dict[str, Any] = Field(default_factory=dict) # eg: {'(1344, 64, 1381)': 'Unknown'} + chest_observation: str = Field(default="") # eg: "Chests: None\n\n" + + runtime_status: bool = False # equal to action execution status: success or failed + + vectordb: Chroma = Field(default_factory=Chroma) + + qa_cache_questions_vectordb: Chroma = Field(default_factory=Chroma) + + @property + def progress(self): + # return len(self.completed_tasks) + 10 # Test only + return len(self.completed_tasks) + + @property + def programs(self): + programs = "" + if self.code == "": + return programs # TODO: maybe fix 10054 now, a better way is isolating env.step() like voyager + for skill_name, entry in self.skills.items(): + programs += f"{entry['code']}\n\n" + for primitives in load_mc_skills_code(): # TODO add skills_dir + programs += f"{primitives}\n\n" + return programs + + def set_mc_port(self, mc_port): + super().set_mc_port(mc_port) + self.set_mc_resume() + + def set_mc_resume(self): + self.qa_cache_questions_vectordb = Chroma( + collection_name="qa_cache_questions_vectordb", + embedding_function=OpenAIEmbeddings(), + persist_directory=f"{MC_CKPT_DIR}/curriculum/vectordb", + ) + + self.vectordb = Chroma( + collection_name="skill_vectordb", + embedding_function=OpenAIEmbeddings(), + persist_directory=f"{MC_CKPT_DIR}/skill/vectordb", + ) + + if CONFIG.resume: + logger.info(f"Loading Action Developer from {MC_CKPT_DIR}/action") + self.chest_memory = read_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json") + + logger.info(f"Loading Curriculum Agent from {MC_CKPT_DIR}/curriculum") + self.completed_tasks = read_json_file(f"{MC_CKPT_DIR}/curriculum/completed_tasks.json") + self.failed_tasks = read_json_file(f"{MC_CKPT_DIR}/curriculum/failed_tasks.json") + + logger.info(f"Loading Skill Manager from {MC_CKPT_DIR}/skill\033[0m") + self.skills = read_json_file(f"{MC_CKPT_DIR}/skill/skills.json") + + logger.info(f"Loading Qa Cache from {MC_CKPT_DIR}/curriculum\033[0m") + self.qa_cache = read_json_file(f"{MC_CKPT_DIR}/curriculum/qa_cache.json") + + if self.vectordb._collection.count() == 0: + logger.info(self.vectordb._collection.count()) + # Set vdvs for skills & qa_cache + skill_desps = [skill["description"] for program_name, skill in self.skills.items()] + program_names = [program_name for program_name, skill in self.skills.items()] + metadatas = [{"name": program_name} for program_name in program_names] + # add vectordb from file + self.vectordb.add_texts( + texts=skill_desps, + ids=program_names, + metadatas=metadatas, + ) + self.vectordb.persist() + + logger.info(self.qa_cache_questions_vectordb._collection.count()) + if self.qa_cache_questions_vectordb._collection.count() == 0: + questions = [question for question, answer in self.qa_cache.items()] + + self.qa_cache_questions_vectordb.add_texts(texts=questions) + + self.qa_cache_questions_vectordb.persist() + + logger.info( + f"INIT_CHECK: There are {self.vectordb._collection.count()} skills in vectordb and {len(self.skills)} skills in skills.json." + ) + # Check if Skill Manager's vectordb right using + assert self.vectordb._collection.count() == len(self.skills), ( + f"Skill Manager's vectordb is not synced with skills.json.\n" + f"There are {self.vectordb._collection.count()} skills in vectordb but {len(self.skills)} skills in skills.json.\n" + f"Did you set resume=False when initializing the manager?\n" + f"You may need to manually delete the vectordb directory for running from scratch." + ) + + logger.info( + f"INIT_CHECK: There are {self.qa_cache_questions_vectordb._collection.count()} qa_cache in vectordb and {len(self.qa_cache)} questions in qa_cache.json." + ) + assert self.qa_cache_questions_vectordb._collection.count() == len(self.qa_cache), ( + f"Curriculum Agent's qa cache question vectordb is not synced with qa_cache.json.\n" + f"There are {self.qa_cache_questions_vectordb._collection.count()} questions in vectordb " + f"but {len(self.qa_cache)} questions in qa_cache.json.\n" + f"Did you set resume=False when initializing the agent?\n" + f"You may need to manually delete the qa cache question vectordb directory for running from scratch.\n" + ) + + def register_roles(self, roles: Iterable["Minecraft"]): + for role in roles: + role.set_memory(self) + + def update_event(self, event: dict): + if self.event == event: + return + self.event = event + self.update_chest_memory(event) + self.update_chest_observation() + # self.event_summary = self.summarize_chatlog(event) + + def update_task(self, task: str): + self.current_task = task + + def update_context(self, context: str): + self.context = context + + def update_program_code(self, program_code: str): + self.program_code = program_code + + def update_code(self, code: str): + self.code = code # action_developer.gen_action_code to HERE + + def update_program_name(self, program_name: str): + self.program_name = program_name + + def update_critique(self, critique: str): + self.critique = critique # critic_agent.check_task_success to HERE + + def append_skill(self, skill: dict): + self.skills[self.program_name] = skill # skill_manager.retrieve_skills to HERE + + def update_retrieve_skills(self, retrieve_skills: list): + self.retrieve_skills = retrieve_skills + + def update_skill_desp(self, skill_desp: str): + self.skill_desp = skill_desp + + async def update_qa_cache(self, qa_cache: dict): + self.qa_cache = qa_cache + + def update_chest_memory(self, events: dict): + """ + Input: events: Dict + Result: self.chest_memory update & save to json + """ + nearbyChests = events[-1][1]["nearbyChests"] + for position, chest in nearbyChests.items(): + if position in self.chest_memory: + if isinstance(chest, dict): + self.chest_memory[position] = chest + if chest == "Invalid": + logger.info(f"Action Developer removing chest {position}: {chest}") + self.chest_memory.pop(position) + else: + if chest != "Invalid": + logger.info(f"Action Developer saving chest {position}: {chest}") + self.chest_memory[position] = chest + + write_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json", self.chest_memory) + + def update_chest_observation(self): + """ + update chest_memory to chest_observation. + Refer to @ https://github.com/MineDojo/Voyager/blob/main/voyager/agents/action.py + """ + + chests = [] + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, dict) and len(chest) > 0: + chests.append(f"{chest_position}: {chest}") + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, dict) and len(chest) == 0: + chests.append(f"{chest_position}: Empty") + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, str): + assert chest == "Unknown" + chests.append(f"{chest_position}: Unknown items inside") + assert len(chests) == len(self.chest_memory) + if chests: + chests = "\n".join(chests) + self.chest_observation = f"Chests:\n{chests}\n\n" + else: + self.chest_observation = "Chests: None\n\n" + + def summarize_chatlog(self, events): + def filter_item(message: str): + craft_pattern = r"I cannot make \w+ because I need: (.*)" + craft_pattern2 = r"I cannot make \w+ because there is no crafting table nearby" + mine_pattern = r"I need at least a (.*) to mine \w+!" + if re.match(craft_pattern, message): + self.event_summary = re.match(craft_pattern, message).groups()[0] + elif re.match(craft_pattern2, message): + self.event_summary = "a nearby crafting table" + elif re.match(mine_pattern, message): + self.event_summary = re.match(mine_pattern, message).groups()[0] + else: + self.event_summary = "" + return self.event_summary + + chatlog = set() + for event_type, event in events: + if event_type == "onChat": + item = filter_item(event["onChat"]) + if item: + chatlog.add(item) + self.event_summary = "I also need " + ", ".join(chatlog) + "." if chatlog else "" + + def reset_block_info(self): + # revert all the placing event in the last step + pass + + def update_exploration_progress(self, success: bool): + """ + Split task into completed_tasks or failed_tasks + Args: info = { + "task": self.task, + "success": success, + "conversations": self.conversations, + } + """ + self.runtime_status = success + task = self.current_task + if task.startswith("Deposit useless items into the chest at"): + return + if success: + logger.info(f"Completed task {task}.") + self.completed_tasks.append(task) + else: + logger.info(f"Failed to complete task {task}. Skipping to next task.") + self.failed_tasks.append(task) + # when not success, below to update event! + # revert all the placing event in the last step + blocks = [] + positions = [] + for event_type, event in self.event: + if event_type == "onSave" and event["onSave"].endswith("_placed"): + block = event["onSave"].split("_placed")[0] + position = event["status"]["position"] + blocks.append(block) + positions.append(position) + new_events = self.step( + f"await givePlacedItemBack(bot, {json.dumps(blocks)}, {json.dumps(positions)})", + programs=self.programs, + ) + self.event[-1][1]["inventory"] = new_events[-1][1]["inventory"] + self.event[-1][1]["voxels"] = new_events[-1][1]["voxels"] + + self.save_sorted_tasks() + + def save_sorted_tasks(self): + updated_completed_tasks = [] + # record repeated failed tasks + updated_failed_tasks = self.failed_tasks + # dedup but keep order + for task in self.completed_tasks: + if task not in updated_completed_tasks: + updated_completed_tasks.append(task) + + # remove completed tasks from failed tasks + for task in updated_completed_tasks: + while task in updated_failed_tasks: + updated_failed_tasks.remove(task) + + self.completed_tasks = updated_completed_tasks + self.failed_tasks = updated_failed_tasks + + # dump to json + write_json_file(f"{MC_CKPT_DIR}/curriculum/completed_tasks.json", self.completed_tasks) + write_json_file(f"{MC_CKPT_DIR}/curriculum/failed_tasks.json", self.failed_tasks) + + async def on_event_retrieve(self, *args): + """ + Retrieve Minecraft events. + + Returns: + list: A list of Minecraft events. + + Raises: + Exception: If there is an issue retrieving events. + """ + try: + self.reset( + options={ + "mode": "soft", + "wait_ticks": 20, + } + ) + # difficulty = "easy" if len(self.completed_tasks) > 15 else "peaceful" + difficulty = "peaceful" + + events = self.step("bot.chat(`/time set ${getNextTime()}`);\n" + f"bot.chat('/difficulty {difficulty}');") + self.update_event(events) + return events + except Exception as e: + time.sleep(3) # wait for mineflayer to exit + # reset bot status here + events = self.reset( + options={ + "mode": "hard", + "wait_ticks": 20, + "inventory": self.event[-1][1]["inventory"], + "equipment": self.event[-1][1]["status"]["equipment"], + "position": self.event[-1][1]["status"]["position"], + } + ) + self.update_event(events) + logger.error(f"Failed to retrieve Minecraft events: {str(e)}") + return events + + async def on_event_execute(self, *args): + """ + Execute Minecraft events. + + This function is used to obtain events from the Minecraft environment. Check the implementation in + the 'voyager/env/bridge.py step()' function to capture events generated within the game. + + Returns: + list: A list of Minecraft events. + + Raises: + Exception: If there is an issue retrieving events. + """ + try: + events = self.step( + code=self.code, + programs=self.programs, + ) + self.update_event(events) + return events + except Exception as e: + time.sleep(3) # wait for mineflayer to exit + # reset bot status here + events = self.reset( + options={ + "mode": "hard", + "wait_ticks": 20, + "inventory": self.event[-1][1]["inventory"], + "equipment": self.event[-1][1]["status"]["equipment"], + "position": self.event[-1][1]["status"]["position"], + } + ) + self.update_event(events) + logger.error(f"Failed to execute Minecraft events: {str(e)}") + return events diff --git a/metagpt/environment/mincraft_env/mincraft_ext_env.py b/metagpt/environment/mincraft_env/mincraft_ext_env.py new file mode 100644 index 000000000..b86250d8c --- /dev/null +++ b/metagpt/environment/mincraft_env/mincraft_ext_env.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The Mincraft external environment to integrate with Mincraft game +# refs to `voyager bridge.py` + +import json +import time +from typing import Optional + +import requests +from pydantic import ConfigDict, Field, model_validator + +from metagpt.environment.base_env import ExtEnv, mark_as_writeable +from metagpt.environment.mincraft_env.const import ( + MC_CKPT_DIR, + MC_CORE_INVENTORY_ITEMS, + MC_CURRICULUM_OB, + MC_DEFAULT_WARMUP, + METAGPT_ROOT, +) +from metagpt.environment.mincraft_env.process_monitor import SubprocessMonitor +from metagpt.logs import logger + + +class MincraftExtEnv(ExtEnv): + model_config = ConfigDict(arbitrary_types_allowed=True) + + mc_port: Optional[int] = Field(default=None) + server_host: str = Field(default="http://127.0.0.1") + server_port: str = Field(default=3000) + request_timeout: int = Field(default=600) + + mineflayer: Optional[SubprocessMonitor] = Field(default=None, validate_default=True) + + has_reset: bool = Field(default=False) + reset_options: Optional[dict] = Field(default=None) + connected: bool = Field(default=False) + server_paused: bool = Field(default=False) + warm_up: dict = Field(default=dict()) + + @property + def server(self) -> str: + return f"{self.server_host}:{self.server_port}" + + @model_validator(mode="after") + def _post_init_ext_env(self): + if not self.mineflayer: + self.mineflayer = SubprocessMonitor( + commands=[ + "node", + METAGPT_ROOT.joinpath("metagpt", "environment", "mincraft_env", "mineflayer", "index.js"), + str(self.server_port), + ], + name="mineflayer", + ready_match=r"Server started on port (\d+)", + ) + if not self.warm_up: + warm_up = MC_DEFAULT_WARMUP + if "optional_inventory_items" in warm_up: + assert MC_CORE_INVENTORY_ITEMS is not None + # self.core_inv_items_regex = re.compile(MC_CORE_INVENTORY_ITEMS) + self.warm_up["optional_inventory_items"] = warm_up["optional_inventory_items"] + else: + self.warm_up["optional_inventory_items"] = 0 + for key in MC_CURRICULUM_OB: + self.warm_up[key] = warm_up.get(key, MC_DEFAULT_WARMUP[key]) + self.warm_up["nearby_blocks"] = 0 + self.warm_up["inventory"] = 0 + self.warm_up["completed_tasks"] = 0 + self.warm_up["failed_tasks"] = 0 + + # init ckpt sub-forders + MC_CKPT_DIR.joinpath("curriculum/vectordb").mkdir(parents=True, exist_ok=True) + MC_CKPT_DIR.joinpath("action").mkdir(exist_ok=True) + MC_CKPT_DIR.joinpath("skill/code").mkdir(parents=True, exist_ok=True) + MC_CKPT_DIR.joinpath("skill/description").mkdir(exist_ok=True) + MC_CKPT_DIR.joinpath("skill/vectordb").mkdir(exist_ok=True) + + def set_mc_port(self, mc_port: int): + self.mc_port = mc_port + + @mark_as_writeable + def close(self) -> bool: + self.unpause() + if self.connected: + res = requests.post(f"{self.server}/stop") + if res.status_code == 200: + self.connected = False + self.mineflayer.stop() + return not self.connected + + @mark_as_writeable + def check_process(self) -> dict: + retry = 0 + while not self.mineflayer.is_running: + logger.info("Mineflayer process has exited, restarting") + self.mineflayer.run() + if not self.mineflayer.is_running: + if retry > 3: + logger.error("Mineflayer process failed to start") + raise {} + else: + retry += 1 + continue + logger.info(self.mineflayer.ready_line) + res = requests.post( + f"{self.server}/start", + json=self.reset_options, + timeout=self.request_timeout, + ) + if res.status_code != 200: + self.mineflayer.stop() + logger.error(f"Minecraft server reply with code {res.status_code}") + raise {} + return res.json() + + @mark_as_writeable + def reset(self, *, seed=None, options=None) -> dict: + if options is None: + options = {} + if options.get("inventory", {}) and options.get("mode", "hard") != "hard": + logger.error("inventory can only be set when options is hard") + raise {} + + self.reset_options = { + "port": self.mc_port, + "reset": options.get("mode", "hard"), + "inventory": options.get("inventory", {}), + "equipment": options.get("equipment", []), + "spread": options.get("spread", False), + "waitTicks": options.get("wait_ticks", 5), + "position": options.get("position", None), + } + + self.unpause() + self.mineflayer.stop() + time.sleep(1) # wait for mineflayer to exit + + returned_data = self.check_process() + self.has_reset = True + self.connected = True + # All the reset in step will be soft + self.reset_options["reset"] = "soft" + self.pause() + return json.loads(returned_data) + + @mark_as_writeable + def step(self, code: str, programs: str = "") -> dict: + if not self.has_reset: + raise RuntimeError("Environment has not been reset yet") + self.check_process() + self.unpause() + data = { + "code": code, + "programs": programs, + } + res = requests.post(f"{self.server}/step", json=data, timeout=self.request_timeout) + if res.status_code != 200: + raise RuntimeError("Failed to step Minecraft server") + returned_data = res.json() + self.pause() + return json.loads(returned_data) + + @mark_as_writeable + def pause(self) -> bool: + if self.mineflayer.is_running and not self.server_paused: + res = requests.post(f"{self.server}/pause") + if res.status_code == 200: + self.server_paused = True + return self.server_paused + + @mark_as_writeable + def unpause(self) -> bool: + if self.mineflayer.is_running and self.server_paused: + res = requests.post(f"{self.server}/pause") + if res.status_code == 200: + self.server_paused = False + else: + logger.info(f"mineflayer pause result: {res.json()}") + return self.server_paused diff --git a/metagpt/environment/mincraft_env/mineflayer/.gitignore b/metagpt/environment/mincraft_env/mineflayer/.gitignore new file mode 100644 index 000000000..0fd468410 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/.gitignore @@ -0,0 +1 @@ +!/lib \ No newline at end of file diff --git a/metagpt/environment/mincraft_env/mineflayer/.prettierignore b/metagpt/environment/mincraft_env/mineflayer/.prettierignore new file mode 100644 index 000000000..1b07c39e9 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/.prettierignore @@ -0,0 +1,3 @@ +# Ignore artifacts: +build +coverage \ No newline at end of file diff --git a/metagpt/environment/mincraft_env/mineflayer/.prettierrc.json b/metagpt/environment/mincraft_env/mineflayer/.prettierrc.json new file mode 100644 index 000000000..0a02bcefd --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/.prettierrc.json @@ -0,0 +1,3 @@ +{ + "tabWidth": 4 +} diff --git a/metagpt/environment/mincraft_env/mineflayer/index.js b/metagpt/environment/mincraft_env/mineflayer/index.js new file mode 100644 index 000000000..7fb0a8787 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/index.js @@ -0,0 +1,425 @@ +const fs = require("fs"); +const express = require("express"); +const bodyParser = require("body-parser"); +const mineflayer = require("mineflayer"); + +const skills = require("./lib/skillLoader"); +const { initCounter, getNextTime } = require("./lib/utils"); +const obs = require("./lib/observation/base"); +const OnChat = require("./lib/observation/onChat"); +const OnError = require("./lib/observation/onError"); +const { Voxels, BlockRecords } = require("./lib/observation/voxels"); +const Status = require("./lib/observation/status"); +const Inventory = require("./lib/observation/inventory"); +const OnSave = require("./lib/observation/onSave"); +const Chests = require("./lib/observation/chests"); +const { plugin: tool } = require("mineflayer-tool"); + +let bot = null; + +const app = express(); + +app.use(bodyParser.json({ limit: "50mb" })); +app.use(bodyParser.urlencoded({ limit: "50mb", extended: false })); + +app.post("/start", (req, res) => { + if (bot) onDisconnect("Restarting bot"); + bot = null; + console.log(req.body); + bot = mineflayer.createBot({ + host: "localhost", // minecraft server ip + port: req.body.port, // minecraft server port + username: "bot", + disableChatSigning: true, + checkTimeoutInterval: 60 * 60 * 1000, + }); + bot.once("error", onConnectionFailed); + + // Event subscriptions + bot.waitTicks = req.body.waitTicks; + bot.globalTickCounter = 0; + bot.stuckTickCounter = 0; + bot.stuckPosList = []; + bot.iron_pickaxe = false; + + bot.on("kicked", onDisconnect); + + // mounting will cause physicsTick to stop + bot.on("mount", () => { + bot.dismount(); + }); + + bot.once("spawn", async () => { + bot.removeListener("error", onConnectionFailed); + let itemTicks = 1; + if (req.body.reset === "hard") { + bot.chat("/clear @s"); + bot.chat("/kill @s"); + const inventory = req.body.inventory ? req.body.inventory : {}; + const equipment = req.body.equipment + ? req.body.equipment + : [null, null, null, null, null, null]; + for (let key in inventory) { + bot.chat(`/give @s minecraft:${key} ${inventory[key]}`); + itemTicks += 1; + } + const equipmentNames = [ + "armor.head", + "armor.chest", + "armor.legs", + "armor.feet", + "weapon.mainhand", + "weapon.offhand", + ]; + for (let i = 0; i < 6; i++) { + if (i === 4) continue; + if (equipment[i]) { + bot.chat( + `/item replace entity @s ${equipmentNames[i]} with minecraft:${equipment[i]}` + ); + itemTicks += 1; + } + } + } + + if (req.body.position) { + bot.chat( + `/tp @s ${req.body.position.x} ${req.body.position.y} ${req.body.position.z}` + ); + } + + // if iron_pickaxe is in bot's inventory + if ( + bot.inventory.items().find((item) => item.name === "iron_pickaxe") + ) { + bot.iron_pickaxe = true; + } + + const { pathfinder } = require("mineflayer-pathfinder"); + const tool = require("mineflayer-tool").plugin; + const collectBlock = require("mineflayer-collectblock").plugin; + const pvp = require("mineflayer-pvp").plugin; + const minecraftHawkEye = require("minecrafthawkeye"); + bot.loadPlugin(pathfinder); + bot.loadPlugin(tool); + bot.loadPlugin(collectBlock); + bot.loadPlugin(pvp); + bot.loadPlugin(minecraftHawkEye); + + // bot.collectBlock.movements.digCost = 0; + // bot.collectBlock.movements.placeCost = 0; + + obs.inject(bot, [ + OnChat, + OnError, + Voxels, + Status, + Inventory, + OnSave, + Chests, + BlockRecords, + ]); + skills.inject(bot); + + if (req.body.spread) { + bot.chat(`/spreadplayers ~ ~ 0 300 under 80 false @s`); + await bot.waitForTicks(bot.waitTicks); + } + + await bot.waitForTicks(bot.waitTicks * itemTicks); + res.json(bot.observe()); + + initCounter(bot); + bot.chat("/gamerule keepInventory true"); + bot.chat("/gamerule doDaylightCycle false"); + }); + + function onConnectionFailed(e) { + console.log(e); + bot = null; + res.status(400).json({ error: e }); + } + function onDisconnect(message) { + if (bot.viewer) { + bot.viewer.close(); + } + bot.end(); + console.log(message); + bot = null; + } +}); + +app.post("/step", async (req, res) => { + // import useful package + let response_sent = false; + function otherError(err) { + console.log("Uncaught Error"); + bot.emit("error", handleError(err)); + bot.waitForTicks(bot.waitTicks).then(() => { + if (!response_sent) { + response_sent = true; + res.json(bot.observe()); + } + }); + } + + process.on("uncaughtException", otherError); + + const mcData = require("minecraft-data")(bot.version); + mcData.itemsByName["leather_cap"] = mcData.itemsByName["leather_helmet"]; + mcData.itemsByName["leather_tunic"] = + mcData.itemsByName["leather_chestplate"]; + mcData.itemsByName["leather_pants"] = + mcData.itemsByName["leather_leggings"]; + mcData.itemsByName["leather_boots"] = mcData.itemsByName["leather_boots"]; + mcData.itemsByName["lapis_lazuli_ore"] = mcData.itemsByName["lapis_ore"]; + mcData.blocksByName["lapis_lazuli_ore"] = mcData.blocksByName["lapis_ore"]; + const { + Movements, + goals: { + Goal, + GoalBlock, + GoalNear, + GoalXZ, + GoalNearXZ, + GoalY, + GoalGetToBlock, + GoalLookAtBlock, + GoalBreakBlock, + GoalCompositeAny, + GoalCompositeAll, + GoalInvert, + GoalFollow, + GoalPlaceBlock, + }, + pathfinder, + Move, + ComputedPath, + PartiallyComputedPath, + XZCoordinates, + XYZCoordinates, + SafeBlock, + GoalPlaceBlockOptions, + } = require("mineflayer-pathfinder"); + const { Vec3 } = require("vec3"); + + // Set up pathfinder + const movements = new Movements(bot, mcData); + bot.pathfinder.setMovements(movements); + + bot.globalTickCounter = 0; + bot.stuckTickCounter = 0; + bot.stuckPosList = []; + + function onTick() { + bot.globalTickCounter++; + if (bot.pathfinder.isMoving()) { + bot.stuckTickCounter++; + if (bot.stuckTickCounter >= 100) { + onStuck(1.5); + bot.stuckTickCounter = 0; + } + } + } + + bot.on("physicTick", onTick); + + // initialize fail count + let _craftItemFailCount = 0; + let _killMobFailCount = 0; + let _mineBlockFailCount = 0; + let _placeItemFailCount = 0; + let _smeltItemFailCount = 0; + + // Retrieve array form post bod + const code = req.body.code; + const programs = req.body.programs; + bot.cumulativeObs = []; + await bot.waitForTicks(bot.waitTicks); + const r = await evaluateCode(code, programs); + process.off("uncaughtException", otherError); + if (r !== "success") { + bot.emit("error", handleError(r)); + } + await returnItems(); + // wait for last message + await bot.waitForTicks(bot.waitTicks); + if (!response_sent) { + response_sent = true; + res.json(bot.observe()); + } + bot.removeListener("physicTick", onTick); + + async function evaluateCode(code, programs) { + // Echo the code produced for players to see it. Don't echo when the bot code is already producing dialog or it will double echo + try { + await eval("(async () => {" + programs + "\n" + code + "})()"); + return "success"; + } catch (err) { + return err; + } + } + + function onStuck(posThreshold) { + const currentPos = bot.entity.position; + bot.stuckPosList.push(currentPos); + + // Check if the list is full + if (bot.stuckPosList.length === 5) { + const oldestPos = bot.stuckPosList[0]; + const posDifference = currentPos.distanceTo(oldestPos); + + if (posDifference < posThreshold) { + teleportBot(); // execute the function + } + + // Remove the oldest time from the list + bot.stuckPosList.shift(); + } + } + + function teleportBot() { + const blocks = bot.findBlocks({ + matching: (block) => { + return block.type === 0; + }, + maxDistance: 1, + count: 27, + }); + + if (blocks) { + // console.log(blocks.length); + const randomIndex = Math.floor(Math.random() * blocks.length); + const block = blocks[randomIndex]; + bot.chat(`/tp @s ${block.x} ${block.y} ${block.z}`); + } else { + bot.chat("/tp @s ~ ~1.25 ~"); + } + } + + function returnItems() { + bot.chat("/gamerule doTileDrops false"); + const crafting_table = bot.findBlock({ + matching: mcData.blocksByName.crafting_table.id, + maxDistance: 128, + }); + if (crafting_table) { + bot.chat( + `/setblock ${crafting_table.position.x} ${crafting_table.position.y} ${crafting_table.position.z} air destroy` + ); + bot.chat("/give @s crafting_table"); + } + const furnace = bot.findBlock({ + matching: mcData.blocksByName.furnace.id, + maxDistance: 128, + }); + if (furnace) { + bot.chat( + `/setblock ${furnace.position.x} ${furnace.position.y} ${furnace.position.z} air destroy` + ); + bot.chat("/give @s furnace"); + } + if (bot.inventoryUsed() >= 32) { + // if chest is not in bot's inventory + if (!bot.inventory.items().find((item) => item.name === "chest")) { + bot.chat("/give @s chest"); + } + } + // if iron_pickaxe not in bot's inventory and bot.iron_pickaxe + if ( + bot.iron_pickaxe && + !bot.inventory.items().find((item) => item.name === "iron_pickaxe") + ) { + bot.chat("/give @s iron_pickaxe"); + } + bot.chat("/gamerule doTileDrops true"); + } + + function handleError(err) { + let stack = err.stack; + if (!stack) { + return err; + } + console.log(stack); + const final_line = stack.split("\n")[1]; + const regex = /:(\d+):\d+\)/; + + const programs_length = programs.split("\n").length; + let match_line = null; + for (const line of stack.split("\n")) { + const match = regex.exec(line); + if (match) { + const line_num = parseInt(match[1]); + if (line_num >= programs_length) { + match_line = line_num - programs_length; + break; + } + } + } + if (!match_line) { + return err.message; + } + let f_line = final_line.match( + /\((?.*):(?\d+):(?\d+)\)/ + ); + if (f_line && f_line.groups && fs.existsSync(f_line.groups.file)) { + const { file, line, pos } = f_line.groups; + const f = fs.readFileSync(file, "utf8").split("\n"); + // let filename = file.match(/(?<=node_modules\\)(.*)/)[1]; + let source = file + `:${line}\n${f[line - 1].trim()}\n `; + + const code_source = + "at " + + code.split("\n")[match_line - 1].trim() + + " in your code"; + return source + err.message + "\n" + code_source; + } else if ( + f_line && + f_line.groups && + f_line.groups.file.includes("") + ) { + const { file, line, pos } = f_line.groups; + let source = + "Your code" + + `:${match_line}\n${code.split("\n")[match_line - 1].trim()}\n `; + let code_source = ""; + if (line < programs_length) { + source = + "In your program code: " + + programs.split("\n")[line - 1].trim() + + "\n"; + code_source = `at line ${match_line}:${code + .split("\n") + [match_line - 1].trim()} in your code`; + } + return source + err.message + "\n" + code_source; + } + return err.message; + } +}); + +app.post("/stop", (req, res) => { + bot.end(); + res.json({ + message: "Bot stopped", + }); +}); + +app.post("/pause", (req, res) => { + if (!bot) { + res.status(400).json({ error: "Bot not spawned" }); + return; + } + bot.chat("/pause"); + bot.waitForTicks(bot.waitTicks).then(() => { + res.json({ message: "Success" }); + }); +}); + +// Server listening to PORT 3000 + +const DEFAULT_PORT = 3000; +const PORT = process.argv[2] || DEFAULT_PORT; +app.listen(PORT, () => { + console.log(`Server started on port ${PORT}`); +}); diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/base.js b/metagpt/environment/mincraft_env/mineflayer/lib/observation/base.js new file mode 100644 index 000000000..b661a24b5 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/observation/base.js @@ -0,0 +1,45 @@ +class Observation { + constructor(bot) { + if (new.target === Observation) { + throw new TypeError( + "Cannot instantiate abstract class Observation" + ); + } + + this.bot = bot; + this.name = "Observation"; + } + + observe() { + throw new TypeError("Method 'observe()' must be implemented."); + } + + reset() {} +} + +function inject(bot, obs_list) { + bot.obsList = []; + bot.cumulativeObs = []; + bot.eventMemory = {}; + obs_list.forEach((obs) => { + bot.obsList.push(new obs(bot)); + }); + bot.event = function (event_name) { + let result = {}; + bot.obsList.forEach((obs) => { + if (obs.name.startsWith("on") && obs.name !== event_name) { + return; + } + result[obs.name] = obs.observe(); + }); + bot.cumulativeObs.push([event_name, result]); + }; + bot.observe = function () { + bot.event("observe"); + const result = bot.cumulativeObs; + bot.cumulativeObs = []; + return JSON.stringify(result); + }; +} + +module.exports = { Observation, inject }; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/chests.js b/metagpt/environment/mincraft_env/mineflayer/lib/observation/chests.js new file mode 100644 index 000000000..842bd171d --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/observation/chests.js @@ -0,0 +1,31 @@ +const { Observation } = require("./base"); + +class Chests extends Observation { + constructor(bot) { + super(bot); + this.name = "nearbyChests"; + this.chestsItems = {}; + bot.on("closeChest", (chestItems, position) => { + this.chestsItems[position] = chestItems; + }); + bot.on("removeChest", (chestPosition) => { + this.chestsItems[chestPosition] = "Invalid"; + }); + } + + observe() { + const chests = this.bot.findBlocks({ + matching: this.bot.registry.blocksByName.chest.id, + maxDistance: 16, + count: 999, + }); + chests.forEach((chest) => { + if (!this.chestsItems.hasOwnProperty(chest)) { + this.chestsItems[chest] = "Unknown"; + } + }); + return this.chestsItems; + } +} + +module.exports = Chests; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/inventory.js b/metagpt/environment/mincraft_env/mineflayer/lib/observation/inventory.js new file mode 100644 index 000000000..0645d1bfa --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/observation/inventory.js @@ -0,0 +1,39 @@ +const { Observation } = require("./base"); + +class Inventory extends Observation { + constructor(bot) { + super(bot); + this.name = "inventory"; + } + + observe() { + return listItems(this.bot); + } +} + +function listItems(bot) { + const items = getInventoryItems(bot); + return items.reduce(itemToDict, {}); +} + +function getInventoryItems(bot) { + const inventory = bot.currentWindow || bot.inventory; + return inventory.items(); +} + +function itemToDict(acc, cur) { + if (cur.name && cur.count) { + //if both name and count property are defined + if (acc[cur.name]) { + //if the item is already in the dict + acc[cur.name] += cur.count; + } else { + //if the item is not in the dict + acc[cur.name] = cur.count; + } + } + return acc; +} + +//export modules +module.exports = Inventory; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/onChat.js b/metagpt/environment/mincraft_env/mineflayer/lib/observation/onChat.js new file mode 100644 index 000000000..54b411e2a --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/observation/onChat.js @@ -0,0 +1,26 @@ +const Observation = require("./base.js").Observation; + +class onChat extends Observation { + constructor(bot) { + super(bot); + this.name = "onChat"; + this.obs = ""; + bot.on("chatEvent", (username, message) => { + // Save entity status to local variable + if (message.startsWith("/")) { + return; + } + + this.obs += message; + this.bot.event(this.name); + }); + } + + observe() { + const result = this.obs; + this.obs = ""; + return result; + } +} + +module.exports = onChat; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/onError.js b/metagpt/environment/mincraft_env/mineflayer/lib/observation/onError.js new file mode 100644 index 000000000..ac8fed9e5 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/observation/onError.js @@ -0,0 +1,22 @@ +const Observation = require("./base.js").Observation; + +class onError extends Observation { + constructor(bot) { + super(bot); + this.name = "onError"; + this.obs = null; + bot.on("error", (err) => { + // Save entity status to local variable + this.obs = err; + this.bot.event(this.name); + }); + } + + observe() { + const result = this.obs; + this.obs = null; + return result; + } +} + +module.exports = onError; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/onSave.js b/metagpt/environment/mincraft_env/mineflayer/lib/observation/onSave.js new file mode 100644 index 000000000..e5983590f --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/observation/onSave.js @@ -0,0 +1,22 @@ +const Observation = require("./base.js").Observation; + +class onSave extends Observation { + constructor(bot) { + super(bot); + this.name = "onSave"; + this.obs = null; + bot.on("save", (eventName) => { + // Save entity status to local variable + this.obs = eventName; + this.bot.event(this.name); + }); + } + + observe() { + const result = this.obs; + this.obs = null; + return result; + } +} + +module.exports = onSave; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/status.js b/metagpt/environment/mincraft_env/mineflayer/lib/observation/status.js new file mode 100644 index 000000000..b031fbcf2 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/observation/status.js @@ -0,0 +1,103 @@ +const Observation = require("./base.js").Observation; + +class Status extends Observation { + constructor(bot) { + super(bot); + this.name = "status"; + } + + observe() { + return { + health: this.bot.health, + food: this.bot.food, + saturation: this.bot.foodSaturation, + oxygen: this.bot.oxygenLevel, + position: this.bot.entity.position, + velocity: this.bot.entity.velocity, + yaw: this.bot.entity.yaw, + pitch: this.bot.entity.pitch, + onGround: this.bot.entity.onGround, + equipment: this.getEquipment(), + name: this.bot.entity.username, + timeSinceOnGround: this.bot.entity.timeSinceOnGround, + isInWater: this.bot.entity.isInWater, + isInLava: this.bot.entity.isInLava, + isInWeb: this.bot.entity.isInWeb, + isCollidedHorizontally: this.bot.entity.isCollidedHorizontally, + isCollidedVertically: this.bot.entity.isCollidedVertically, + biome: this.bot.blockAt(this.bot.entity.position) + ? this.bot.blockAt(this.bot.entity.position).biome.name + : "None", + entities: this.getEntities(), + timeOfDay: this.getTime(), + inventoryUsed: this.bot.inventoryUsed(), + elapsedTime: this.bot.globalTickCounter, + }; + } + + itemToObs(item) { + if (!item) return null; + return item.name; + } + + getTime() { + const timeOfDay = this.bot.time.timeOfDay; + let time = ""; + if (timeOfDay < 1000) { + time = "sunrise"; + } else if (timeOfDay < 6000) { + time = "day"; + } else if (timeOfDay < 12000) { + time = "noon"; + } else if (timeOfDay < 13000) { + time = "sunset"; + } else if (timeOfDay < 18000) { + time = "night"; + } else if (timeOfDay < 22000) { + time = "midnight"; + } else { + time = "sunrise"; + } + return time; + } + + // For each item in equipment, if it exists, return the name of the item + // otherwise return null + getEquipment() { + const slots = this.bot.inventory.slots; + const mainHand = this.bot.heldItem; + return slots + .slice(5, 9) + .concat(mainHand, slots[45]) + .map(this.itemToObs); + } + + getEntities() { + const entities = this.bot.entities; + if (!entities) return {}; + // keep all monsters in one list, keep other mobs in another list + const mobs = {}; + for (const id in entities) { + const entity = entities[id]; + if (!entity.displayName) continue; + if (entity.name === "player" || entity.name === "item") continue; + if (entity.position.distanceTo(this.bot.entity.position) < 32) { + if (!mobs[entity.name]) { + mobs[entity.name] = entity.position.distanceTo( + this.bot.entity.position + ); + } else if ( + mobs[entity.name] > + entity.position.distanceTo(this.bot.entity.position) + ) { + mobs[entity.name] = entity.position.distanceTo( + this.bot.entity.position + ); + } + } + } + return mobs; + } +} + +module.exports = Status; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/observation/voxels.js b/metagpt/environment/mincraft_env/mineflayer/lib/observation/voxels.js new file mode 100644 index 000000000..ecb0c14b7 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/observation/voxels.js @@ -0,0 +1,67 @@ +// Blocks = require("./blocks") +const { Observation } = require("./base"); + +class Voxels extends Observation { + constructor(bot) { + super(bot); + this.name = "voxels"; + } + + observe() { + return Array.from(getSurroundingBlocks(this.bot, 8, 2, 8)); + } +} + +class BlockRecords extends Observation { + constructor(bot) { + super(bot); + this.name = "blockRecords"; + this.records = new Set(); + this.tick = 0; + bot.on("physicsTick", () => { + this.tick++; + if (this.tick >= 100) { + const items = getInventoryItems(this.bot); + getSurroundingBlocks(this.bot, 8, 2, 8).forEach((block) => { + if (!items.has(block)) this.records.add(block); + }); + this.tick = 0; + } + }); + } + + observe() { + return Array.from(this.records); + } + + reset() { + this.records = new Set(); + } +} + +function getSurroundingBlocks(bot, x_distance, y_distance, z_distance) { + const surroundingBlocks = new Set(); + + for (let x = -x_distance; x <= x_distance; x++) { + for (let y = -y_distance; y <= y_distance; y++) { + for (let z = -z_distance; z <= z_distance; z++) { + const block = bot.blockAt(bot.entity.position.offset(x, y, z)); + if (block && block.type !== 0) { + surroundingBlocks.add(block.name); + } + } + } + } + // console.log(surroundingBlocks); + return surroundingBlocks; +} + +function getInventoryItems(bot) { + const items = new Set(); + bot.inventory.items().forEach((item) => { + if (item) items.add(item.name); + }); + return items; +} + +module.exports = { Voxels, BlockRecords }; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/skillLoader.js b/metagpt/environment/mincraft_env/mineflayer/lib/skillLoader.js new file mode 100644 index 000000000..d78cf7820 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/skillLoader.js @@ -0,0 +1,79 @@ +function inject(bot) { + bot._sleep = bot.sleep; + bot.sleep = async (bedBlock) => { + await bot.waitForTicks(20); + await bot._sleep(bedBlock); + await bot.waitForTicks(135); + }; + + bot._fish = bot.fish; + bot.fish = async () => { + if (bot.heldItem?.name !== "fishing_rod") { + bot.chat("I'm not holding a fishing rod!"); + return; + } + let timeout = null; + await Promise.race([ + bot._fish(), + new Promise( + (resolve, reject) => + (timeout = setTimeout(() => { + bot.activateItem(); + reject( + new Error( + "Finishing timeout, make sure you get to and look at a water block!" + ) + ); + }, 60000)) + ), + ]); + clearTimeout(timeout); + await bot.waitForTicks(20); + }; + + bot._consume = bot.consume; + bot.consume = async () => { + // action_count.activateItem++; + await bot._consume(); + await bot.waitForTicks(20); + }; + + bot._useOn = bot.useOn; + bot.useOn = async (entity) => { + if (entity.position.distanceTo(bot.entity.position) > 6) { + bot.chat("Please goto a place near the entity first!"); + return; + } + await bot._useOn(entity); + await bot.waitForTicks(20); + }; + + bot._activateBlock = bot.activateBlock; + bot.activateBlock = async (block) => { + if (block.position.distanceTo(bot.entity.position) > 6) { + bot.chat("Please goto a place near the block first!"); + return; + } + // action_count.activateBlock++; + await bot._activateBlock(block); + }; + + bot._chat = bot.chat; + bot.chat = (message) => { + // action_count.chat++; + bot.emit("chatEvent", "bot", message); + bot._chat(message); + }; + + bot.inventoryUsed = () => { + return bot.inventory.slots.slice(9, 45).filter((item) => item !== null) + .length; + }; + + bot.save = function (eventName) { + bot.emit("save", eventName); + }; +} + +// export all control_primitives +module.exports = { inject }; diff --git a/metagpt/environment/mincraft_env/mineflayer/lib/utils.js b/metagpt/environment/mincraft_env/mineflayer/lib/utils.js new file mode 100644 index 000000000..68af30796 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/lib/utils.js @@ -0,0 +1,31 @@ +let gameTimeCounter = 0; +let gameTimeList = []; +const initCounter = (bot) => { + gameTimeList = []; + for (let i = 0; i < 13000; i += 1000) { + gameTimeList.push(i); + } + for (let i = 13000; i < 24000; i += 2000) { + gameTimeList.push(i); + } + const timeOfDay = bot.time.timeOfDay; + for (let i = 0; i < gameTimeList.length; i++) { + if (gameTimeList[i] > timeOfDay) { + gameTimeCounter = i - 1; + break; + } + } +}; + +const getNextTime = () => { + gameTimeCounter++; + if (gameTimeCounter >= gameTimeList.length) { + gameTimeCounter = 0; + } + return gameTimeList[gameTimeCounter]; +}; + +module.exports = { + initCounter, + getNextTime, +}; diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/.gitignore b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/.gitignore new file mode 100644 index 000000000..0578fdca3 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/.gitignore @@ -0,0 +1,107 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# TypeScript v1 declaration files +typings/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Microbundle cache +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env +.env.test + +# parcel-bundler cache (https://parceljs.org/) +.cache + +# Next.js build output +.next + +# Nuxt.js build / generate output +.nuxt +dist + +# Gatsby files +.cache/ +# Comment in the public line in if your project uses Gatsby and *not* Next.js +# https://nextjs.org/blog/next-9-1#public-directory-support +# public + +# vuepress build output +.vuepress/dist + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# TernJS port file +.tern-port + +lib/ +package-lock.json diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/LICENSE b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/LICENSE new file mode 100644 index 000000000..f2896b56e --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 TheDudeFromCI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/README.md b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/README.md new file mode 100644 index 000000000..555acb761 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/README.md @@ -0,0 +1,89 @@ +

mineflayer-collectblock

+

A small utility plugin for allowing users to collect blocks using a higher level API.

+ +

+ + + + + + +

+ +--- +## This is a modified version to better support Voyager + +## Showcase + +You can see a video of the plugin in action, [here.](https://youtu.be/5T_rcCnNnf4) +The source code of the bot in the video can be seen in the examples folder, [here.](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/examples/collector.js) + +### Description + +This plugin is a wrapper for mineflayer that allows for easier API usage when collecting blocks or item drops. This plugin is designed to reduce some of the boilerplate code based around the act of pathfinding to a block _(handled by_ ***mineflayer-pathfinder***_)_, selecting the best tool to mine that block _(handled by_ ***mineflayer-tool***_)_, actually mining it, then moving to collect the item drops from that block. This plugin allows for all of that basic concept to be wrapped up into a single API function. + +In addition to the usage above, some additional quality of life features are available in this plugin. These include the ability to automatically deposit items into a chest when the bot's inventory is full, collecting new tools from a chest if the bot doesn't currently have a required tool _(also handled by_ ***mineflayer-tool***_)_, and allowing for queueing of multiple blocks or item drops to the collection task, so they can be processed later. + +### Getting Started + +This plugin is built using Node and can be installed using: +```bash +npm install --save mineflayer-collectblock +``` + +### Simple Bot + +The brief description goes here. + +```js +// Create your bot +const mineflayer = require("mineflayer") +const bot = mineflayer.createBot({ + host: 'localhost', + username: 'Player', +}) +let mcData + +// Load collect block +bot.loadPlugin(require('mineflayer-collectblock').plugin) + +async function collectGrass() { + // Find a nearby grass block + const grass = bot.findBlock({ + matching: mcData.blocksByName.grass_block.id, + maxDistance: 64 + }) + + if (grass) { + // If we found one, collect it. + try { + await bot.collectBlock.collect(grass) + collectGrass() // Collect another grass block + } catch (err) { + console.log(err) // Handle errors, if any + } + } +} + +// On spawn, start collecting all nearby grass +bot.once('spawn', () => { + mcData = require('minecraft-data')(bot.version) + collectGrass() +}) +``` + +### Documentation + +[API](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/docs/api.md) + +[Examples](https://github.com/TheDudeFromCI/mineflayer-collectblock/tree/master/examples) + +### License + +This project uses the [MIT](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/LICENSE) license. + +### Contributions + +This project is accepting PRs and Issues. See something you think can be improved? Go for it! Any and all help is highly appreciated! + +For larger changes, it is recommended to discuss these changes in the issues tab before writing any code. It's also preferred to make many smaller PRs than one large one, where applicable. diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/_config.yml b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/_config.yml new file mode 100644 index 000000000..c4192631f --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman \ No newline at end of file diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/docs/api.md b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/docs/api.md new file mode 100644 index 000000000..66d8a3ecc --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/docs/api.md @@ -0,0 +1,52 @@ +# API + +Welcome to the *mineflayer-collectblock* API documentation page. + +## Table of Contents + +- [1. Summary](#1-summary) +- [Properties](#properties) + - [`bot.collectblock.movements: Movements`](#botcollectblockmovements-movements) +- [Functions](#functions) + - [collect](#collect) + - [Options:](#options) + +## 1. Summary + +The collect block plugin is a utility plugin that can be used to help make collecting blocks and item drops very easy, using only a single API call. No need to worry about pathfinding to the block, selecting the right tool, or moving to pick up the item drop after mining. + +## Properties + +### `bot.collectblock.movements: Movements` + +The movements object used by the pathfinder plugin to define the movement configuration. This object is passed to the pathfinder plugin when any API from this plugin is called in order to control how pathfinding should work when collecting the given blocks or item. + +If set to null, the pathfinder plugin movements is not updated. + +Defaults to a new movements object instance. + +## Functions + +### collect + +Usage: `bot.collectblock.collect(target: Collectable | Collectable[], options?: CollectOptions, cb: (err?: Error) => void): void` + +Causes the bot to collect the given block, item drop, or list of those. If the target is a block, the bot will move to the block, mine it, and pick up the item drop. If the target is an item drop, the bot will move to the item drop and pick it up. If the target is a list of collectables, the bot will move from target to target in order of closest to furthest and collect each target in turn. + +#### Options: + + * `append: boolean` + + If true, the target(s) will be appended to the existing target list instead of starting a new task. Defaults to false. + + * `ignoreNoPath: boolean` + + If true, errors will not be thrown when a path to the target block cannot be found. The bot will attempt to choose the best available position it can find, instead. Errors are still thrown if the bot cannot interact with the block from it's final location. Defaults to false. + + * `chestLocations: Vec3[]` + + Gets the list of chest locations to use when storing items after the bot's inventory becomes full. If undefined, it defaults to the chest location list on the bot.collectBlock plugin. + + * `itemFilter: ItemFilter` + + When transferring items to a chest, this filter is used to determine what items are allowed to be moved, and what items aren't allowed to be moved. Defaults to the item filter specified on the bot.collectBlock plugin. \ No newline at end of file diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/collector.js b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/collector.js new file mode 100644 index 000000000..b9bb8faf9 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/collector.js @@ -0,0 +1,70 @@ +/** + * This bot example show how to direct a bot to collect a specific block type + * or a group of nearby blocks of that type. + */ + +const mineflayer = require('mineflayer') +const collectBlock = require('mineflayer-collectblock').plugin + +if (process.argv.length < 4 || process.argv.length > 6) { + console.log('Usage : node collector.js [] []') + process.exit(1) +} + +const bot = mineflayer.createBot({ + host: process.argv[2], + port: process.argv[3], + username: process.argv[4] || 'collector', + password: process.argv[5] +}) + +bot.loadPlugin(collectBlock) + +let mcData +bot.once('spawn', () => { + mcData = require('minecraft-data')(bot.version) +}) + +bot.on('chat', async (username, message) => { + const args = message.split(' ') + if (args[0] !== 'collect') return + + let count = 1 + if (args.length === 3) count = parseInt(args[1]) + + let type = args[1] + if (args.length === 3) type = args[2] + + const blockType = mcData.blocksByName[type] + if (!blockType) { + return + } + + const blocks = bot.findBlocks({ + matching: blockType.id, + maxDistance: 64, + count: count + }) + + if (blocks.length === 0) { + bot.chat("I don't see that block nearby.") + return + } + + const targets = [] + for (let i = 0; i < Math.min(blocks.length, count); i++) { + targets.push(bot.blockAt(blocks[i])) + } + + bot.chat(`Found ${targets.length} ${type}(s)`) + + try { + await bot.collectBlock.collect(targets) + // All blocks have been collected. + bot.chat('Done') + } catch (err) { + // An error occurred, report it. + bot.chat(err.message) + console.log(err) + } +}) diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/oreMiner.js b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/oreMiner.js new file mode 100644 index 000000000..6accac88f --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/oreMiner.js @@ -0,0 +1,59 @@ +/** + * This bot example shows how to collect a vein of ores quickly after only finding a single block. + * This makes it easy to collect a vein of ores or mine a tree without looking for every block in the + * area. + */ + +const mineflayer = require('mineflayer') +const collectBlock = require('mineflayer-collectblock').plugin + +if (process.argv.length < 4 || process.argv.length > 6) { + console.log('Usage : node oreMiner.js [] []') + process.exit(1) +} + +const bot = mineflayer.createBot({ + host: process.argv[2], + port: process.argv[3], + username: process.argv[4] || 'oreMiner', + password: process.argv[5] +}) + +bot.loadPlugin(collectBlock) + +let mcData +bot.once('spawn', () => { + mcData = require('minecraft-data')(bot.version) +}) + +bot.on('chat', async (username, message) => { + const args = message.split(' ') + if (args[0] !== 'collect') return + + const blockType = mcData.blocksByName[args[1]] + if (!blockType) { + bot.chat(`I don't know any blocks named ${args[1]}.`) + return + } + + const block = bot.findBlock({ + matching: blockType.id, + maxDistance: 64 + }) + + if (!block) { + bot.chat("I don't see that block nearby.") + return + } + + const targets = bot.collectBlock.findFromVein(block) + try { + await bot.collectBlock.collect(targets) + // All blocks have been collected. + bot.chat('Done') + } catch (err) { + // An error occurred, report it. + bot.chat(err.message) + console.log(err) + } +}) diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/storageBot.js b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/storageBot.js new file mode 100644 index 000000000..b6f9971f2 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/examples/storageBot.js @@ -0,0 +1,107 @@ +/** + * This bot example shows how to use the chest filling mechanic of the plugin. + * Simply provide a given storage chest, and the bot will automatically try and + * store it's inventory in that chest when the bot's inventory becomes full. + */ + +if (process.argv.length < 4 || process.argv.length > 6) { + console.log('Usage : node storageBot.js [] []') + process.exit(1) +} + +// Load your libraries +const mineflayer = require('mineflayer') +const collectBlock = require('mineflayer-collectblock').plugin + +// Create your bot +const bot = mineflayer.createBot({ + host: process.argv[2], + port: parseInt(process.argv[3]), + username: process.argv[4] ? process.argv[4] : 'storageBot', + password: process.argv[5] +}) + +// Load the collect block plugin +bot.loadPlugin(collectBlock) + +// Load mcData on login +let mcData +bot.once('login', () => { + mcData = require('minecraft-data')(bot.version) +}) + +// On spawn, try to find any nearby chests and save those as storage locations. +// When the bot's inventory becomes too full, it will empty it's inventory into +// these chests before collecting more resources. If a chest gets full, it moves +// to the next one in order until it's inventory is empty or it runs out of chests. +bot.once('spawn', () => { + bot.collectBlock.chestLocations = bot.findBlocks({ + matching: mcData.blocksByName.chest.id, + maxDistance: 16, + count: 999999 // Get as many chests as we can + }) + + if (bot.collectBlock.chestLocations.length === 0) { + bot.chat("I don't see any chests nearby.") + } else { + for (const chestPos of bot.collectBlock.chestLocations) { + bot.chat(`I found a chest at ${chestPos}`) + } + } +}) + +// Wait for someone to say something +bot.on('chat', async (username, message) => { + // If the player says something start starts with "collect" + // Otherwise, do nothing + const args = message.split(' ') + if (args[0] !== 'collect') return + + // If the player specifies a number, collect that many. Otherwise, default to 1. + let count = 1 + if (args.length === 3) count = parseInt(args[1]) + + // If a number was given the item number is the 3rd arg, not the 2nd. + let type = args[1] + if (args.length === 3) type = args[2] + + // Get the id of that block type for this version of Minecraft. + const blockType = mcData.blocksByName[type] + if (!blockType) { + bot.chat(`I don't know any blocks named ${type}.`) + return + } + + // Find all nearby blocks of that type, up to the given count, within 64 blocks. + const blocks = bot.findBlocks({ + matching: blockType.id, + maxDistance: 64, + count: count + }) + + // Complain if we can't find any nearby blocks of that type. + if (blocks.length === 0) { + bot.chat("I don't see that block nearby.") + return + } + + // Convert the block position array into a block array to pass to collect block. + const targets = [] + for (let i = 0; i < Math.min(blocks.length, count); i++) { + targets.push(bot.blockAt(blocks[i])) + } + + // Announce what we found. + bot.chat(`Found ${targets.length} ${type}(s)`) + + // Tell the bot to collect all of the given blocks in the block list. + try { + await bot.collectBlock.collect(targets) + // All blocks have been collected. + bot.chat('Done') + } catch (err) { + // An error occurred, report it. + bot.chat(err.message) + console.log(err) + } +}) diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/package.json b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/package.json new file mode 100644 index 000000000..0f59e7aa6 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/package.json @@ -0,0 +1,44 @@ +{ + "name": "mineflayer-collectblock", + "version": "1.4.1", + "description": "A simple utility plugin for Mineflayer that add a higher level API for collecting blocks.", + "main": "lib/index.js", + "types": "lib/index.d.ts", + "scripts": { + "build": "ts-standard && tsc && require-self", + "clean": "rm -rf lib", + "test": "test" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/TheDudeFromCI/mineflayer-collectblock.git" + }, + "keywords": [ + "mineflayer", + "plugin", + "api", + "utility", + "helper", + "collect" + ], + "author": "TheDudeFromCI", + "license": "MIT", + "bugs": { + "url": "https://github.com/TheDudeFromCI/mineflayer-collectblock/issues" + }, + "homepage": "https://github.com/TheDudeFromCI/mineflayer-collectblock#readme", + "dependencies": { + "mineflayer": "^4.0.0", + "mineflayer-pathfinder": "^2.1.1", + "mineflayer-tool": "^1.1.0" + }, + "devDependencies": { + "@types/node": "^18.6.4", + "require-self": "^0.2.3", + "ts-standard": "^11.0.0", + "typescript": "^4.1.3" + }, + "files": [ + "lib/**/*" + ] +} diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/BlockVeins.ts b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/BlockVeins.ts new file mode 100644 index 000000000..ae5542ce3 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/BlockVeins.ts @@ -0,0 +1,35 @@ +import { Bot } from 'mineflayer' +import { Block } from 'prismarine-block' + +export function findFromVein (bot: Bot, block: Block, maxBlocks: number, maxDistance: number, floodRadius: number): Block[] { + const targets: Block[] = [] + const open: Block[] = [block] + const type = block.type + const center = block.position + + for (let i = 0; i < maxBlocks; i++) { + const next = open.pop() + if (next == null) break + + targets.push(next) + + for (let x = -floodRadius; x <= floodRadius; x++) { + for (let y = -floodRadius; y <= floodRadius; y++) { + for (let z = -floodRadius; z <= floodRadius; z++) { + const neighborPos = next.position.offset(x, y, z) + if (neighborPos.manhattanDistanceTo(center) > maxDistance) continue + + const neighbor = bot.blockAt(neighborPos) + if (neighbor == null || neighbor.type !== type) continue + + if (targets.includes(neighbor)) continue + if (open.includes(neighbor)) continue + + open.push(neighbor) + } + } + } + } + + return targets +} diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/CollectBlock.ts b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/CollectBlock.ts new file mode 100644 index 000000000..d2be87822 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/CollectBlock.ts @@ -0,0 +1,451 @@ +import { Bot } from "mineflayer"; +import { Block } from "prismarine-block"; +import { Movements, goals } from "mineflayer-pathfinder"; +import { TemporarySubscriber } from "./TemporarySubscriber"; +import { Entity } from "prismarine-entity"; +import { error } from "./Util"; +import { Vec3 } from "vec3"; +import { emptyInventoryIfFull, ItemFilter } from "./Inventory"; +import { findFromVein } from "./BlockVeins"; +import { Collectable, Targets } from "./Targets"; +import { Item } from "prismarine-item"; +import mcDataLoader from "minecraft-data"; +import { once } from "events"; +import { callbackify } from "util"; + +export type Callback = (err?: Error) => void; + +async function collectAll( + bot: Bot, + options: CollectOptionsFull +): Promise { + let success_count = 0; + while (!options.targets.empty) { + await emptyInventoryIfFull( + bot, + options.chestLocations, + options.itemFilter + ); + const closest = options.targets.getClosest(); + if (closest == null) break; + switch (closest.constructor.name) { + case "Block": { + try { + if (success_count >= options.count) { + break; + } + await bot.tool.equipForBlock( + closest as Block, + equipToolOptions + ); + const goal = new goals.GoalLookAtBlock( + closest.position, + bot.world + ); + await bot.pathfinder.goto(goal); + await mineBlock(bot, closest as Block, options); + success_count++; + // TODO: options.ignoreNoPath + } catch (err) { + // @ts-ignore + // console.log(err.stack) + // bot.pathfinder.stop() + // bot.waitForTicks(10) + try { + bot.pathfinder.setGoal(null); + } catch (err) {} + if (options.ignoreNoPath) { + // @ts-ignore + if (err.name === "Invalid block") { + console.log( + `Block ${closest.name} at ${closest.position} is not valid! Skip it!` + ); + } // @ts-ignore + else if (err.name === "Unsafe block") { + console.log( + `${closest.name} at ${closest.position} is not safe to break! Skip it!` + ); + // @ts-ignore + } else if (err.name === "NoItem") { + const properties = + bot.registry.blocksByName[closest.name]; + const leastTool = Object.keys( + properties.harvestTools + )[0]; + const item = bot.registry.items[leastTool]; + bot.chat( + `I need at least a ${item.name} to mine ${closest.name}! Skip it!` + ); + return; + } else if ( + // @ts-ignore + err.name === "NoPath" || + // @ts-ignore + err.name === "Timeout" + ) { + if ( + bot.entity.position.distanceTo( + closest.position + ) < 0.5 + ) { + await mineBlock(bot, closest as Block, options); + break; + } + console.log( + `No path to ${closest.name} at ${closest.position}! Skip it!` + ); + // @ts-ignore + } else if (err.message === "Digging aborted") { + console.log(`Digging aborted! Skip it!`); + } else { + // @ts-ignore + bot.chat(`Error: ${err.message}`); + } + break; + } + throw err; + } + break; + } + case "Entity": { + // Don't collect any entities that are marked as 'invalid' + if (!(closest as Entity).isValid) break; + try { + const tempEvents = new TemporarySubscriber(bot); + const waitForPickup = new Promise( + (resolve, reject) => { + const timeout = setTimeout(() => { + // After 10 seconds, reject the promise + clearTimeout(timeout); + tempEvents.cleanup(); + reject(new Error("Failed to pickup item")); + }, 10000); + tempEvents.subscribeTo( + "entityGone", + (entity: Entity) => { + if (entity === closest) { + clearTimeout(timeout); + tempEvents.cleanup(); + resolve(); + } + } + ); + } + ); + bot.pathfinder.setGoal( + new goals.GoalFollow(closest as Entity, 0) + ); + // await bot.pathfinder.goto(new goals.GoalBlock(closest.position.x, closest.position.y, closest.position.z)) + await waitForPickup; + } catch (err) { + // @ts-ignore + console.log(err.stack); + try { + bot.pathfinder.setGoal(null); + } catch (err) {} + if (options.ignoreNoPath) { + // @ts-ignore + if (err.message === "Failed to pickup item") { + bot.chat(`Failed to pickup item! Skip it!`); + } + break; + } + throw err; + } + break; + } + default: { + throw error( + "UnknownType", + `Target ${closest.constructor.name} is not a Block or Entity!` + ); + } + } + options.targets.removeTarget(closest); + } + bot.chat(`Collect finish!`); +} + +const equipToolOptions = { + requireHarvest: true, + getFromChest: false, + maxTools: 2, +}; + +async function mineBlock( + bot: Bot, + block: Block, + options: CollectOptionsFull +): Promise { + if ( + bot.blockAt(block.position)?.type !== block.type || + bot.blockAt(block.position)?.type === 0 + ) { + options.targets.removeTarget(block); + throw error("Invalid block", "Block is not valid!"); + // @ts-expect-error + } else if (!bot.pathfinder.movements.safeToBreak(block)) { + options.targets.removeTarget(block); + throw error("Unsafe block", "Block is not safe to break!"); + } + + await bot.tool.equipForBlock(block, equipToolOptions); + + if (!block.canHarvest(bot.heldItem ? bot.heldItem.type : bot.heldItem)) { + options.targets.removeTarget(block); + throw error("NoItem", "Bot does not have a harvestable tool!"); + } + + const tempEvents = new TemporarySubscriber(bot); + tempEvents.subscribeTo("itemDrop", (entity: Entity) => { + if ( + entity.position.distanceTo(block.position.offset(0.5, 0.5, 0.5)) <= + 0.5 + ) { + options.targets.appendTarget(entity); + } + }); + try { + await bot.dig(block); + // Waiting for items to drop + await new Promise((resolve) => { + let remainingTicks = 10; + tempEvents.subscribeTo("physicTick", () => { + remainingTicks--; + if (remainingTicks <= 0) { + tempEvents.cleanup(); + resolve(); + } + }); + }); + } finally { + tempEvents.cleanup(); + } +} + +/** + * A set of options to apply when collecting the given targets. + */ +export interface CollectOptions { + /** + * If true, the target(s) will be appended to the existing target list instead of + * starting a new task. Defaults to false. + */ + append?: boolean; + + /** + * If true, errors will not be thrown when a path to the target block cannot + * be found. The bot will attempt to choose the best available position it + * can find, instead. Errors are still thrown if the bot cannot interact with + * the block from it's final location. Defaults to false. + */ + ignoreNoPath?: boolean; + + /** + * Gets the list of chest locations to use when storing items after the bot's + * inventory becomes full. If undefined, it defaults to the chest location + * list on the bot.collectBlock plugin. + */ + chestLocations?: Vec3[]; + + /** + * When transferring items to a chest, this filter is used to determine what + * items are allowed to be moved, and what items aren't allowed to be moved. + * Defaults to the item filter specified on the bot.collectBlock plugin. + */ + itemFilter?: ItemFilter; + + /** + * The total number of items to collect + */ + count?: number; +} + +/** + * A version of collect options where all values are assigned. + */ +interface CollectOptionsFull { + append: boolean; + ignoreNoPath: boolean; + chestLocations: Vec3[]; + itemFilter: ItemFilter; + targets: Targets; + count: number; +} + +/** + * The collect block plugin. + */ +export class CollectBlock { + /** + * The bot. + */ + private readonly bot: Bot; + + /** + * The list of active targets being collected. + */ + private readonly targets: Targets; + + /** + * The movements configuration to be sent to the pathfinder plugin. + */ + movements?: Movements; + + /** + * A list of chest locations which the bot is allowed to empty their inventory into + * if it becomes full while the bot is collecting resources. + */ + chestLocations: Vec3[] = []; + + /** + * When collecting items, this filter is used to determine what items should be placed + * into a chest if the bot's inventory becomes full. By default, returns true for all + * items except for tools, weapons, and armor. + * + * @param item - The item stack in the bot's inventory to check. + * + * @returns True if the item should be moved into the chest. False otherwise. + */ + itemFilter: ItemFilter = (item: Item) => { + if (item.name.includes("helmet")) return false; + if (item.name.includes("chestplate")) return false; + if (item.name.includes("leggings")) return false; + if (item.name.includes("boots")) return false; + if (item.name.includes("shield")) return false; + if (item.name.includes("sword")) return false; + if (item.name.includes("pickaxe")) return false; + if (item.name.includes("axe")) return false; + if (item.name.includes("shovel")) return false; + if (item.name.includes("hoe")) return false; + return true; + }; + + /** + * Creates a new instance of the create block plugin. + * + * @param bot - The bot this plugin is acting on. + */ + constructor(bot: Bot) { + this.bot = bot; + this.targets = new Targets(bot); + // @ts-ignore + this.movements = new Movements(bot, mcDataLoader(bot.version)); + } + + /** + * If target is a block: + * Causes the bot to break and collect the target block. + * + * If target is an item drop: + * Causes the bot to collect the item drop. + * + * If target is an array containing items or blocks, preforms the correct action for + * all targets in that array sorting dynamically by distance. + * + * @param target - The block(s) or item(s) to collect. + * @param options - The set of options to use when handling these targets + * @param cb - The callback that is called finished. + */ + async collect( + target: Collectable | Collectable[], + options: CollectOptions | Callback = {}, + cb?: Callback + ): Promise { + if (typeof options === "function") { + cb = options; + options = {}; + } + // @ts-expect-error + if (cb != null) return callbackify(this.collect)(target, options, cb); + + const optionsFull: CollectOptionsFull = { + append: options.append ?? false, + ignoreNoPath: options.ignoreNoPath ?? false, + chestLocations: options.chestLocations ?? this.chestLocations, + itemFilter: options.itemFilter ?? this.itemFilter, + targets: this.targets, + count: options.count ?? Infinity, + }; + + if (this.bot.pathfinder == null) { + throw error( + "UnresolvedDependency", + "The mineflayer-collectblock plugin relies on the mineflayer-pathfinder plugin to run!" + ); + } + + if (this.bot.tool == null) { + throw error( + "UnresolvedDependency", + "The mineflayer-collectblock plugin relies on the mineflayer-tool plugin to run!" + ); + } + + if (this.movements != null) { + this.bot.pathfinder.setMovements(this.movements); + } + + if (!optionsFull.append) await this.cancelTask(); + if (Array.isArray(target)) { + this.targets.appendTargets(target); + } else { + this.targets.appendTarget(target); + } + + try { + await collectAll(this.bot, optionsFull); + this.targets.clear(); + } catch (err) { + this.targets.clear(); + // Ignore path stopped error for cancelTask to work properly (imo we shouldn't throw any pathing errors) + // @ts-expect-error + if (err.name !== "PathStopped") throw err; + } finally { + // @ts-expect-error + this.bot.emit("collectBlock_finished"); + } + } + + /** + * Loads all touching blocks of the same type to the given block and returns them as an array. + * This effectively acts as a flood fill algorithm to retrieve blocks in the same ore vein and similar. + * + * @param block - The starting block. + * @param maxBlocks - The maximum number of blocks to look for before stopping. + * @param maxDistance - The max distance from the starting block to look. + * @param floodRadius - The max distance distance from block A to block B to be considered "touching" + */ + findFromVein( + block: Block, + maxBlocks = 100, + maxDistance = 16, + floodRadius = 1 + ): Block[] { + return findFromVein( + this.bot, + block, + maxBlocks, + maxDistance, + floodRadius + ); + } + + /** + * Cancels the current collection task, if still active. + * + * @param cb - The callback to use when the task is stopped. + */ + async cancelTask(cb?: Callback): Promise { + if (this.targets.empty) { + if (cb != null) cb(); + return await Promise.resolve(); + } + this.bot.pathfinder.stop(); + if (cb != null) { + // @ts-expect-error + this.bot.once("collectBlock_finished", cb); + } + await once(this.bot, "collectBlock_finished"); + } +} diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Inventory.ts b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Inventory.ts new file mode 100644 index 000000000..6a17d0cc5 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Inventory.ts @@ -0,0 +1,87 @@ +import { Bot } from 'mineflayer' +import { Callback } from './CollectBlock' +import { Vec3 } from 'vec3' +import { error } from './Util' +import { Item } from 'prismarine-item' +import { goals } from 'mineflayer-pathfinder' +import { callbackify } from 'util' + +export type ItemFilter = (item: Item) => boolean + +function getClosestChest (bot: Bot, chestLocations: Vec3[]): Vec3 | null { + let chest = null + let distance = 0 + + for (const c of chestLocations) { + const dist = c.distanceTo(bot.entity.position) + if (chest == null || dist < distance) { + chest = c + distance = dist + } + } + + if (chest != null) { + chestLocations.splice(chestLocations.indexOf(chest), 1) + } + + return chest +} + +export async function emptyInventoryIfFull (bot: Bot, chestLocations: Vec3[], itemFilter: ItemFilter, cb?: Callback): Promise { + // @ts-expect-error + if (cb != null) return callbackify(emptyInventoryIfFull)(bot, chestLocations, cb) + if (bot.inventory.emptySlotCount() > 0) return + return await emptyInventory(bot, chestLocations, itemFilter) +} + +export async function emptyInventory (bot: Bot, chestLocations: Vec3[], itemFilter: ItemFilter, cb?: Callback): Promise { + // @ts-expect-error + if (cb != null) return callbackify(emptyInventory)(bot, chestLocations, cb) + if (chestLocations.length === 0) { + throw error('NoChests', 'There are no defined chest locations!') + } + + // Shallow clone so we can safely remove chests from the list that are full. + chestLocations = [...chestLocations] + + while (true) { + const chest = getClosestChest(bot, chestLocations) + if (chest == null) { + throw error('NoChests', 'All chests are full.') + } + const hasRemaining = await tryEmptyInventory(bot, chest, itemFilter) + if (!hasRemaining) return + } +} + +async function tryEmptyInventory (bot: Bot, chestLocation: Vec3, itemFilter: ItemFilter, cb?: (err: Error | undefined, hasRemaining: boolean) => void): Promise { + // @ts-expect-error + if (cb != null) return callbackify(tryEmptyInventory)(bot, chestLocation, itemFilter, cb) + await gotoChest(bot, chestLocation) + return await placeItems(bot, chestLocation, itemFilter) +} + +async function gotoChest (bot: Bot, location: Vec3, cb?: Callback): Promise { + // @ts-expect-error + if (cb != null) return callbackify(gotoChest)(bot, location) + await bot.pathfinder.goto(new goals.GoalGetToBlock(location.x, location.y, location.z)) +} + +async function placeItems (bot: Bot, chestPos: Vec3, itemFilter: ItemFilter, cb?: (err: Error | undefined, hasRemaining: boolean) => void): Promise { + // @ts-expect-error + if (cb != null) return callbackify(placeItems)(bot, chestPos, itemFilter, cb) + const chestBlock = bot.blockAt(chestPos) + if (chestBlock == null) { + throw error('UnloadedChunk', 'Chest is in an unloaded chunk!') + } + const chest = await bot.openChest(chestBlock) + for (const item of bot.inventory.items()) { + if (!itemFilter(item)) continue + if (chest.firstEmptyContainerSlot() === null) { + // We have items that didn't fit. + return true + } + await chest.deposit(item.type, item.metadata, item.count) + } + return false +} diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Targets.ts b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Targets.ts new file mode 100644 index 000000000..568d07ad9 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Targets.ts @@ -0,0 +1,60 @@ +import { Bot } from 'mineflayer' +import { Block } from 'prismarine-block' +import { Entity } from 'prismarine-entity' + +export type Collectable = Block | Entity + +export class Targets { + private readonly bot: Bot + private targets: Collectable[] = [] + + constructor (bot: Bot) { + this.bot = bot + } + + appendTargets (targets: Collectable[]): void { + for (const target of targets) { + this.appendTarget(target) + } + } + + appendTarget (target: Collectable): void { + if (this.targets.includes(target)) return + this.targets.push(target) + } + + /** + * Gets the closest target to the bot in this list. + * + * @returns The closest target, or null if there are no targets. + */ + getClosest (): Collectable | null { + let closest: Collectable | null = null + let distance: number = 0 + + for (const target of this.targets) { + const dist = target.position.distanceTo(this.bot.entity.position) + + if (closest == null || dist < distance) { + closest = target + distance = dist + } + } + + return closest + } + + get empty (): boolean { + return this.targets.length === 0 + } + + clear (): void { + this.targets.length = 0 + } + + removeTarget (target: Collectable): void { + const index = this.targets.indexOf(target) + if (index < 0) return + this.targets.splice(index, 1) + } +} diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TaskQueue.ts b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TaskQueue.ts new file mode 100644 index 000000000..81fe3bc5a --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TaskQueue.ts @@ -0,0 +1,77 @@ +import type { Callback } from './index' +export type Task = (cb: Callback) => void +export type SyncTask = () => void + +/** + * A simple utility class for queuing up a series of async tasks to execute. + */ +export class TaskQueue { + private tasks: Task[] = [] + + /** + * If true, the task list will stop executing if one of the tasks throws an error. + */ + readonly stopOnError: boolean = true + + /** + * Adds a new async task to this queue. The provided callback should be executed when + * the async task is complete. + * + * @param task - The async task to add. + */ + add (task: Task): void { + this.tasks.push(task) + } + + /** + * Adds a synchronous task toi this queue. + * + * @param task - The sync task to add. + */ + addSync (task: SyncTask): void { + this.add((cb) => { + try { + task() + cb() + } catch (err: any) { + cb(err) + } + }) + } + + /** + * Runs all tasks currently in this queue and empties the queue. + * + * @param cb - The optional callback to be executed when all tasks in this queue have + * finished executing. + */ + runAll (cb?: Callback): void { + const taskList = this.tasks + this.tasks = [] + + let index = -1 + const runNext: () => void = () => { + index++ + if (index >= taskList.length) { + if (cb !== undefined) cb() + return + } + + try { + taskList[index]((err) => { + if (err !== undefined) { + if (cb !== undefined) cb(err) + + if (this.stopOnError) return + } + + runNext() + }) + } catch (err: any) { + if (cb !== undefined) cb(err) + } + } + + runNext() + } +} diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts new file mode 100644 index 000000000..3f14a607d --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts @@ -0,0 +1,34 @@ +import { Bot } from 'mineflayer' + +class Subscription { + constructor (readonly eventName: string, readonly callback: Function) {} +} + +export class TemporarySubscriber { + private readonly subscriptions: Subscription[] = [] + + constructor (readonly bot: Bot) {} + + /** + * Adds a new temporary event listener to the bot. + * + * @param event - The event to subscribe to. + * @param callback - The function to execute. + */ + subscribeTo (event: string, callback: Function): void { + this.subscriptions.push(new Subscription(event, callback)) + + // @ts-expect-error + this.bot.on(event, callback) + } + + /** + * Removes all attached event listeners from the bot. + */ + cleanup (): void { + for (const sub of this.subscriptions) { + // @ts-expect-error + this.bot.removeListener(sub.eventName, sub.callback) + } + } +} diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Util.ts b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Util.ts new file mode 100644 index 000000000..ee0f29e0c --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/Util.ts @@ -0,0 +1,13 @@ +/** + * Creates a new error object with the given type and message. + * + * @param type - The error type. + * @param message - The error message. + * + * @returns The error object. + */ +export function error (type: string, message: string): Error { + const e = new Error(message) + e.name = type + return e +} diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/index.ts b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/index.ts new file mode 100644 index 000000000..45c9a8508 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/src/index.ts @@ -0,0 +1,25 @@ +import { Bot } from 'mineflayer' +import { CollectBlock } from './CollectBlock' +import { pathfinder as pathfinderPlugin } from 'mineflayer-pathfinder' +import { plugin as toolPlugin } from 'mineflayer-tool' + +export function plugin (bot: Bot): void { + // @ts-expect-error + bot.collectBlock = new CollectBlock(bot) + + // Load plugins if not loaded manually. + setTimeout(() => loadPathfinderPlugin(bot), 0) + setTimeout(() => loadToolPlugin(bot), 0) +} + +function loadPathfinderPlugin (bot: Bot): void { + if (bot.pathfinder != null) return + bot.loadPlugin(pathfinderPlugin) +} + +function loadToolPlugin (bot: Bot): void { + if (bot.tool != null) return + bot.loadPlugin(toolPlugin) +} + +export { CollectBlock, Callback, CollectOptions } from './CollectBlock' diff --git a/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/tsconfig.json b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/tsconfig.json new file mode 100644 index 000000000..a6076bc0c --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/mineflayer-collectblock/tsconfig.json @@ -0,0 +1,69 @@ +{ + "compilerOptions": { + /* Visit https://aka.ms/tsconfig.json to read more about this file */ + /* Basic Options */ + // "incremental": true, /* Enable incremental compilation */ + "target": "ES2015", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019', 'ES2020', or 'ESNEXT'. */ + "module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', 'es2020', or 'ESNext'. */ + // "lib": [], /* Specify library files to be included in the compilation. */ + "allowJs": true, /* Allow javascript files to be compiled. */ + "checkJs": true, /* Report errors in .js files. */ + // "jsx": "preserve", /* Specify JSX code generation: 'preserve', 'react-native', or 'react'. */ + "declaration": true, + // "declarationMap": true, /* Generates a sourcemap for each corresponding '.d.ts' file. */ + // "sourceMap": true, /* Generates corresponding '.map' file. */ + // "outFile": "./", /* Concatenate and emit output to single file. */ + "outDir": "./lib", + // "rootDir": "./", /* Specify the root directory of input files. Use to control the output directory structure with --outDir. */ + // "composite": true, /* Enable project compilation */ + // "tsBuildInfoFile": "./", /* Specify file to store incremental compilation information */ + // "removeComments": true, /* Do not emit comments to output. */ + // "noEmit": true, /* Do not emit outputs. */ + // "importHelpers": true, /* Import emit helpers from 'tslib'. */ + // "downlevelIteration": true, /* Provide full support for iterables in 'for-of', spread, and destructuring when targeting 'ES5' or 'ES3'. */ + // "isolatedModules": true, /* Transpile each file as a separate module (similar to 'ts.transpileModule'). */ + /* Strict Type-Checking Options */ + "strict": true, /* Enable all strict type-checking options. */ + // "noImplicitAny": true, /* Raise error on expressions and declarations with an implied 'any' type. */ + "strictNullChecks": true, /* Enable strict null checks. */ + // "strictFunctionTypes": true, /* Enable strict checking of function types. */ + // "strictBindCallApply": true, /* Enable strict 'bind', 'call', and 'apply' methods on functions. */ + // "strictPropertyInitialization": true, /* Enable strict checking of property initialization in classes. */ + // "noImplicitThis": true, /* Raise error on 'this' expressions with an implied 'any' type. */ + "alwaysStrict": true, /* Parse in strict mode and emit "use strict" for each source file. */ + /* Additional Checks */ + "noUnusedLocals": true, /* Report errors on unused locals. */ + // "noUnusedParameters": true, /* Report errors on unused parameters. */ + "noImplicitReturns": true, /* Report error when not all code paths in function return a value. */ + // "noFallthroughCasesInSwitch": true, /* Report errors for fallthrough cases in switch statement. */ + /* Module Resolution Options */ + // "moduleResolution": "node", /* Specify module resolution strategy: 'node' (Node.js) or 'classic' (TypeScript pre-1.6). */ + // "baseUrl": "./", /* Base directory to resolve non-absolute module names. */ + // "paths": {}, /* A series of entries which re-map imports to lookup locations relative to the 'baseUrl'. */ + // "rootDirs": [], /* List of root folders whose combined content represents the structure of the project at runtime. */ + // "typeRoots": [], /* List of folders to include type definitions from. */ + // "types": [], /* Type declaration files to be included in compilation. */ + // "allowSyntheticDefaultImports": true, /* Allow default imports from modules with no default export. This does not affect code emit, just typechecking. */ + "esModuleInterop": true, /* Enables emit interoperability between CommonJS and ES Modules via creation of namespace objects for all imports. Implies 'allowSyntheticDefaultImports'. */ + // "preserveSymlinks": true, /* Do not resolve the real path of symlinks. */ + // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */ + /* Source Map Options */ + // "sourceRoot": "", /* Specify the location where debugger should locate TypeScript files instead of source locations. */ + // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */ + // "inlineSourceMap": true, /* Emit a single file with source maps instead of having a separate file. */ + // "inlineSources": true, /* Emit the source alongside the sourcemaps within a single file; requires '--inlineSourceMap' or '--sourceMap' to be set. */ + /* Experimental Options */ + // "experimentalDecorators": true, /* Enables experimental support for ES7 decorators. */ + // "emitDecoratorMetadata": true, /* Enables experimental support for emitting type metadata for decorators. */ + /* Advanced Options */ + "skipLibCheck": true, /* Skip type checking of declaration files. */ + "forceConsistentCasingInFileNames": true /* Disallow inconsistently-cased references to the same file. */ + }, + "include": [ + "src" + ], + "exclude": [ + "node_modules", + "**/__tests__/*" + ] +} \ No newline at end of file diff --git a/metagpt/environment/mincraft_env/mineflayer/package.json b/metagpt/environment/mincraft_env/mineflayer/package.json new file mode 100644 index 000000000..9e389d268 --- /dev/null +++ b/metagpt/environment/mincraft_env/mineflayer/package.json @@ -0,0 +1,38 @@ +{ + "name": "voyager", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "", + "license": "ISC", + "dependencies": { + "body-parser": "^1.20.2", + "express": "^4.18.2", + "magic-string": "^0.30.0", + "minecraft-data": "^3.31.0", + "minecrafthawkeye": "^1.3.6", + "mineflayer": "^4.8.1", + "mineflayer-collectblock": "file:mineflayer-collectblock", + "mineflayer-pathfinder": "^2.4.2", + "mineflayer-pvp": "^1.3.2", + "mineflayer-tool": "^1.2.0", + "mocha": "^10.2.0", + "prismarine-biome": "^1.3.0", + "prismarine-block": "=1.16.3", + "prismarine-entity": "^2.2.0", + "prismarine-item": "^1.12.1", + "prismarine-nbt": "^2.2.1", + "prismarine-recipe": "^1.3.1", + "prismarine-viewer": "^1.24.0", + "typescript": "^4.9.5", + "vec3": "^0.1.8", + "graceful-fs": "^4.2.11" + }, + "devDependencies": { + "prettier": "2.8.5" + } +} diff --git a/metagpt/environment/mincraft_env/process_monitor.py b/metagpt/environment/mincraft_env/process_monitor.py new file mode 100644 index 000000000..b62aa6005 --- /dev/null +++ b/metagpt/environment/mincraft_env/process_monitor.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# refs to `voyager process_monitor.py` + +import re +import subprocess +import threading +import warnings +from typing import List + +import psutil + +from metagpt.logs import define_log_level + + +class SubprocessMonitor: + def __init__( + self, + commands: List[str], + name: str, + ready_match: str = r".*", + callback_match: str = r"^(?!x)x$", # regex that will never match + callback: callable = None, + finished_callback: callable = None, + ): + self.commands = commands + self.name = name + self.logger = define_log_level(name=name) + self.process = None + self.ready_match = ready_match + self.ready_event = None + self.ready_line = None + self.callback_match = callback_match + self.callback = callback + self.finished_callback = finished_callback + self.thread = None + + def _start(self): + self.logger.info(f"Starting subprocess with commands: {self.commands}") + + self.process = psutil.Popen( + self.commands, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + self.logger.info(f"Subprocess {self.name} started with PID {self.process.pid}.") + for line in iter(self.process.stdout.readline, ""): + self.logger.info(line.strip()) + if re.search(self.ready_match, line): + self.ready_line = line + self.logger.info("Subprocess is ready.") + self.ready_event.set() + if re.search(self.callback_match, line): + self.callback() + if not self.ready_event.is_set(): + self.ready_event.set() + warnings.warn(f"Subprocess {self.name} failed to start.") + if self.finished_callback: + self.finished_callback() + + def run(self): + self.ready_event = threading.Event() + self.ready_line = None + self.thread = threading.Thread(target=self._start) + self.thread.start() + self.ready_event.wait() + + def stop(self): + self.logger.info("Stopping subprocess.") + if self.process and self.process.is_running(): + self.process.terminate() + self.process.wait() + + @property + def is_running(self): + if self.process is None: + return False + return self.process.is_running() diff --git a/metagpt/environment/software_env/__init__.py b/metagpt/environment/software_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/metagpt/environment/software_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/software_env/software_env.py b/metagpt/environment/software_env/software_env.py new file mode 100644 index 000000000..94bc11659 --- /dev/null +++ b/metagpt/environment/software_env/software_env.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Software Env + + +from metagpt.environment.base_env import Environment + + +class SoftwareEnv(Environment): + """a specific alias name""" + + pass diff --git a/metagpt/environment/stanford_town_env/__init__.py b/metagpt/environment/stanford_town_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/metagpt/environment/stanford_town_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/stanford_town_env/stanford_town_env.py b/metagpt/environment/stanford_town_env/stanford_town_env.py new file mode 100644 index 000000000..8721d6cd1 --- /dev/null +++ b/metagpt/environment/stanford_town_env/stanford_town_env.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG StanfordTown Env + +from metagpt.environment.base_env import Environment +from metagpt.environment.stanford_town_env.stanford_town_ext_env import ( + StanfordTownExtEnv, +) + + +class StanfordTownEnv(Environment, StanfordTownExtEnv): + pass diff --git a/metagpt/environment/stanford_town_env/stanford_town_ext_env.py b/metagpt/environment/stanford_town_env/stanford_town_ext_env.py new file mode 100644 index 000000000..8a9a65965 --- /dev/null +++ b/metagpt/environment/stanford_town_env/stanford_town_ext_env.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The StanfordTown external environment to interate with the web interface +# refs to `generative_agents maze.py` + +import math +from pathlib import Path +from typing import Optional, Tuple + +from pydantic import ConfigDict, Field, model_validator + +from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.utils.common import read_csv_to_list, read_json_file + + +class StanfordTownExtEnv(ExtEnv): + model_config = ConfigDict(arbitrary_types_allowed=True) + + maze_asset_path: Optional[Path] = Field(default=None, description="the path to store maze assets") + maze_width: int = Field(default=140, description="maze map width") + maze_height: int = Field(default=100, description="maze map height") + sq_tile_size: int = Field(default=32, description="the pixel height/width of a tile") + special_constraint: str = Field( + default="", description="a string description of any relevant special constraints " "the world might have" + ) + tiles: list[list[dict]] = Field(default=[]) + address_tiles: dict[str, set] = Field(default=dict()) + collision_maze: list[list] = Field(default=[]) + + @model_validator(mode="before") + @classmethod + def _init_maze(cls, values): + maze_asset_path = values["maze_asset_path"] + assert maze_asset_path + maze_asset_path = Path(maze_asset_path) + + maze_matrix_path = maze_asset_path.joinpath("matrix") + meta_info = read_json_file(maze_matrix_path.joinpath("maze_meta_info.json")) + + maze_width = int(meta_info["maze_width"]) + maze_height = int(meta_info["maze_height"]) + values["maze_width"] = maze_width + values["maze_height"] = maze_height + values["sq_tile_size"] = int(meta_info["sq_tile_size"]) + values["special_constraint"] = meta_info["special_constraint"] + + # READING IN SPECIAL BLOCKS + # Special blocks are those that are colored in the Tiled map. + # Here is an example row for the arena block file: + # e.g, "25331, Double Studio, Studio, Bedroom 2, Painting" + + blocks_folder = maze_matrix_path.joinpath("special_blocks") + + _wb = blocks_folder.joinpath("world_blocks.csv") + wb_rows = read_csv_to_list(_wb, header=False) + wb = wb_rows[0][-1] + + _sb = blocks_folder.joinpath("sector_blocks.csv") + sb_rows = read_csv_to_list(_sb, header=False) + sb_dict = dict() + for i in sb_rows: + sb_dict[i[0]] = i[-1] + + _ab = blocks_folder.joinpath("arena_blocks.csv") + ab_rows = read_csv_to_list(_ab, header=False) + ab_dict = dict() + for i in ab_rows: + ab_dict[i[0]] = i[-1] + + _gob = blocks_folder.joinpath("game_object_blocks.csv") + gob_rows = read_csv_to_list(_gob, header=False) + gob_dict = dict() + for i in gob_rows: + gob_dict[i[0]] = i[-1] + + _slb = blocks_folder.joinpath("spawning_location_blocks.csv") + slb_rows = read_csv_to_list(_slb, header=False) + slb_dict = dict() + for i in slb_rows: + slb_dict[i[0]] = i[-1] + + # [SECTION 3] Reading in the matrices + # This is your typical two dimensional matrices. It's made up of 0s and + # the number that represents the color block from the blocks folder. + maze_folder = maze_matrix_path.joinpath("maze") + + _cm = maze_folder.joinpath("collision_maze.csv") + collision_maze_raw = read_csv_to_list(_cm, header=False)[0] + _sm = maze_folder.joinpath("sector_maze.csv") + sector_maze_raw = read_csv_to_list(_sm, header=False)[0] + _am = maze_folder.joinpath("arena_maze.csv") + arena_maze_raw = read_csv_to_list(_am, header=False)[0] + _gom = maze_folder.joinpath("game_object_maze.csv") + game_object_maze_raw = read_csv_to_list(_gom, header=False)[0] + _slm = maze_folder.joinpath("spawning_location_maze.csv") + spawning_location_maze_raw = read_csv_to_list(_slm, header=False)[0] + + # Loading the maze. The mazes are taken directly from the json exports of + # Tiled maps. They should be in csv format. + # Importantly, they are "not" in a 2-d matrix format -- they are single + # row matrices with the length of width x height of the maze. So we need + # to convert here. + # example format: [['0', '0', ... '25309', '0',...], ['0',...]...] + # 25309 is the collision bar number right now. + collision_maze = [] + sector_maze = [] + arena_maze = [] + game_object_maze = [] + spawning_location_maze = [] + for i in range(0, len(collision_maze_raw), maze_width): + tw = maze_width + collision_maze += [collision_maze_raw[i : i + tw]] + sector_maze += [sector_maze_raw[i : i + tw]] + arena_maze += [arena_maze_raw[i : i + tw]] + game_object_maze += [game_object_maze_raw[i : i + tw]] + spawning_location_maze += [spawning_location_maze_raw[i : i + tw]] + values["collision_maze"] = collision_maze + + tiles = [] + for i in range(maze_height): + row = [] + for j in range(maze_width): + tile_details = dict() + tile_details["world"] = wb + + tile_details["sector"] = "" + if sector_maze[i][j] in sb_dict: + tile_details["sector"] = sb_dict[sector_maze[i][j]] + + tile_details["arena"] = "" + if arena_maze[i][j] in ab_dict: + tile_details["arena"] = ab_dict[arena_maze[i][j]] + + tile_details["game_object"] = "" + if game_object_maze[i][j] in gob_dict: + tile_details["game_object"] = gob_dict[game_object_maze[i][j]] + + tile_details["spawning_location"] = "" + if spawning_location_maze[i][j] in slb_dict: + tile_details["spawning_location"] = slb_dict[spawning_location_maze[i][j]] + + tile_details["collision"] = False + if collision_maze[i][j] != "0": + tile_details["collision"] = True + + tile_details["events"] = set() + + row += [tile_details] + tiles += [row] + values["tiles"] = tiles + + # Each game object occupies an event in the tile. We are setting up the + # default event value here. + for i in range(maze_height): + for j in range(maze_width): + if tiles[i][j]["game_object"]: + object_name = ":".join( + [tiles[i][j]["world"], tiles[i][j]["sector"], tiles[i][j]["arena"], tiles[i][j]["game_object"]] + ) + go_event = (object_name, None, None, None) + tiles[i][j]["events"].add(go_event) + + # Reverse tile access. + # -- given a string address, we return a set of all + # tile coordinates belonging to that address (this is opposite of + # tiles that give you the string address given a coordinate). This is + # an optimization component for finding paths for the personas' movement. + # address_tiles['bedroom-2-a'] == {(58, 9)} + # address_tiles['double studio:recreation:pool table'] + # == {(29, 14), (31, 11), (30, 14), (32, 11), ...}, + address_tiles = dict() + for i in range(maze_height): + for j in range(maze_width): + addresses = [] + if tiles[i][j]["sector"]: + add = f'{tiles[i][j]["world"]}:' + add += f'{tiles[i][j]["sector"]}' + addresses += [add] + if tiles[i][j]["arena"]: + add = f'{tiles[i][j]["world"]}:' + add += f'{tiles[i][j]["sector"]}:' + add += f'{tiles[i][j]["arena"]}' + addresses += [add] + if tiles[i][j]["game_object"]: + add = f'{tiles[i][j]["world"]}:' + add += f'{tiles[i][j]["sector"]}:' + add += f'{tiles[i][j]["arena"]}:' + add += f'{tiles[i][j]["game_object"]}' + addresses += [add] + if tiles[i][j]["spawning_location"]: + add = f'{tiles[i][j]["spawning_location"]}' + addresses += [add] + + for add in addresses: + if add in address_tiles: + address_tiles[add].add((j, i)) + else: + address_tiles[add] = set([(j, i)]) + values["address_tiles"] = address_tiles + return values + + def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]: + """ + Turns a pixel coordinate to a tile coordinate. + """ + x = math.ceil(px_coordinate[0] / self.sq_tile_size) + y = math.ceil(px_coordinate[1] / self.sq_tile_size) + return (x, y) + + @mark_as_readable + def get_collision_maze(self) -> list: + return self.collision_maze + + @mark_as_readable + def get_address_tiles(self) -> dict: + return self.address_tiles + + @mark_as_readable + def access_tile(self, tile: tuple[int, int]) -> dict: + """ + Returns the tiles details dictionary that is stored in self.tiles of the + designated x, y location. + + INPUT + tile: The tile coordinate of our interest in (x, y) form. + OUTPUT + The tile detail dictionary for the designated tile. + EXAMPLE OUTPUT + Given (58, 9), + self.tiles[9][58] = {'world': 'double studio', + 'sector': 'double studio', 'arena': 'bedroom 2', + 'game_object': 'bed', 'spawning_location': 'bedroom-2-a', + 'collision': False, + 'events': {('double studio:double studio:bedroom 2:bed', + None, None)}} + """ + x = tile[0] + y = tile[1] + return self.tiles[y][x] + + @mark_as_readable + def get_tile_path(self, tile: tuple[int, int], level: str) -> str: + """ + Get the tile string address given its coordinate. You designate the level + by giving it a string level description. + + INPUT: + tile: The tile coordinate of our interest in (x, y) form. + level: world, sector, arena, or game object + OUTPUT + The string address for the tile. + EXAMPLE OUTPUT + Given tile=(58, 9), and level=arena, + "double studio:double studio:bedroom 2" + """ + x = tile[0] + y = tile[1] + tile = self.tiles[y][x] + + path = f"{tile['world']}" + if level == "world": + return path + else: + path += f":{tile['sector']}" + + if level == "sector": + return path + else: + path += f":{tile['arena']}" + + if level == "arena": + return path + else: + path += f":{tile['game_object']}" + + return path + + @mark_as_readable + def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]: + """ + Given the current tile and vision_r, return a list of tiles that are + within the radius. Note that this implementation looks at a square + boundary when determining what is within the radius. + i.e., for vision_r, returns x's. + x x x x x + x x x x x + x x P x x + x x x x x + x x x x x + + INPUT: + tile: The tile coordinate of our interest in (x, y) form. + vision_r: The radius of the persona's vision. + OUTPUT: + nearby_tiles: a list of tiles that are within the radius. + """ + left_end = 0 + if tile[0] - vision_r > left_end: + left_end = tile[0] - vision_r + + right_end = self.maze_width - 1 + if tile[0] + vision_r + 1 < right_end: + right_end = tile[0] + vision_r + 1 + + bottom_end = self.maze_height - 1 + if tile[1] + vision_r + 1 < bottom_end: + bottom_end = tile[1] + vision_r + 1 + + top_end = 0 + if tile[1] - vision_r > top_end: + top_end = tile[1] - vision_r + + nearby_tiles = [] + for i in range(left_end, right_end): + for j in range(top_end, bottom_end): + nearby_tiles += [(i, j)] + return nearby_tiles + + @mark_as_writeable + def add_tiles_event(self, pt_y: int, pt_x: int, event: Tuple[str, str, str, str]): + self.tiles[pt_y][pt_x]["events"].add(event) + + @mark_as_writeable + def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + """ + Add an event triple to a tile. + + INPUT: + curr_event: Current event triple. + e.g., ('double studio:double studio:bedroom 2:bed', None, + None) + tile: The tile coordinate of our interest in (x, y) form. + OUPUT: + None + """ + self.tiles[tile[1]][tile[0]]["events"].add(curr_event) + + @mark_as_writeable + def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + """dswaq + Remove an event triple from a tile. + + INPUT: + curr_event: Current event triple. + e.g., ('double studio:double studio:bedroom 2:bed', None, + None) + tile: The tile coordinate of our interest in (x, y) form. + OUPUT: + None + """ + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event == curr_event: + self.tiles[tile[1]][tile[0]]["events"].remove(event) + + @mark_as_writeable + def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event == curr_event: + self.tiles[tile[1]][tile[0]]["events"].remove(event) + new_event = (event[0], None, None, None) + self.tiles[tile[1]][tile[0]]["events"].add(new_event) + + @mark_as_writeable + def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None: + """ + Remove an event triple that has the input subject from a tile. + + INPUT: + subject: "Isabella Rodriguez" + tile: The tile coordinate of our interest in (x, y) form. + OUPUT: + None + """ + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event[0] == subject: + self.tiles[tile[1]][tile[0]]["events"].remove(event) diff --git a/metagpt/environment/werewolf_env/__init__.py b/metagpt/environment/werewolf_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/metagpt/environment/werewolf_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/werewolf_env/werewolf_env.py b/metagpt/environment/werewolf_env/werewolf_env.py new file mode 100644 index 000000000..d174f322c --- /dev/null +++ b/metagpt/environment/werewolf_env/werewolf_env.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Werewolf Env + +from pydantic import Field + +from metagpt.environment.base_env import Environment +from metagpt.environment.werewolf_env.werewolf_ext_env import WerewolfExtEnv +from metagpt.logs import logger +from metagpt.schema import Message + + +class WerewolfEnv(Environment, WerewolfExtEnv): + timestamp: int = Field(default=0) + + def publish_message(self, message: Message, add_timestamp: bool = True): + """Post information to the current environment""" + logger.debug(f"publish_message: {message.dump()}") + if add_timestamp: + # Because the content of the message may be repeated, for example, killing the same person in two nights + # Therefore, a unique timestamp prefix needs to be added so that the same message will not be automatically deduplicated when added to the memory. + message.content = f"{self.timestamp} | " + message.content + self.memory.add(message) + self.history += f"\n{message}" + + async def run(self, k=1): + """Process all Role runs by order""" + for _ in range(k): + for role in self.roles.values(): + await role.run() + self.timestamp += 1 diff --git a/metagpt/environment/werewolf_env/werewolf_ext_env.py b/metagpt/environment/werewolf_env/werewolf_ext_env.py new file mode 100644 index 000000000..7c4b4c475 --- /dev/null +++ b/metagpt/environment/werewolf_env/werewolf_ext_env.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The werewolf game external environment to integrate with + +import random +from collections import Counter +from enum import Enum +from typing import Callable, Optional + +from pydantic import ConfigDict, Field + +from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.logs import logger + + +class RoleState(Enum): + ALIVE = "alive" # the role is alive + KILLED = "killed" # the role is killed by werewolf or voting + POISONED = "poisoned" # the role is killed by posion + SAVED = "saved" # the role is saved by antidote + + +# the ordered rules by the moderator to announce to everyone each step +STEP_INSTRUCTIONS = { + 0: { + "content": "It’s dark, everyone close your eyes. I will talk with you/your team secretly at night.", + "send_to": "Moderator", # for moderator to continuen speaking + "restricted_to": "", + }, + 1: { + "content": "Guard, please open your eyes!", + "send_to": "Moderator", # for moderator to continuen speaking + "restricted_to": "", + }, + 2: { + "content": """Guard, now tell me who you protect tonight? + You only choose one from the following living options please: {living_players}. + Or you can pass. For example: Protect ...""", + "send_to": "Guard", + "restricted_to": "Moderator,Guard", + }, + 3: {"content": "Guard, close your eyes", "send_to": "Moderator", "restricted_to": ""}, + 4: {"content": "Werewolves, please open your eyes!", "send_to": "Moderator", "restricted_to": ""}, + 5: { + "content": """Werewolves, I secretly tell you that {werewolf_players} are + all of the 2 werewolves! Keep in mind you are teammates. The rest players are not werewolves. + choose one from the following living options please: + {living_players}. For example: Kill ...""", + "send_to": "Werewolf", + "restricted_to": "Moderator,Werewolf", + }, + 6: {"content": "Werewolves, close your eyes", "send_to": "Moderator", "restricted_to": ""}, + 7: {"content": "Witch, please open your eyes!", "send_to": "Moderator", "restricted_to": ""}, + 8: { + "content": """Witch, tonight {player_hunted} has been killed by the werewolves. + You have a bottle of antidote, would you like to save him/her? If so, say "Save", else, say "Pass".""", + "send_to": "Witch", + "restricted_to": "Moderator,Witch", + }, # 要先判断女巫是否有解药,再去询问女巫是否使用解药救人 + 9: { + "content": """Witch, you also have a bottle of poison, would you like to use it to kill one of the living players? + Choose one from the following living options: {living_players}. + If so, say ONLY "Poison PlayerX", replace PlayerX with the actual player name, else, say "Pass".""", + "send_to": "Witch", + "restricted_to": "Moderator,Witch", + }, # + 10: {"content": "Witch, close your eyes", "send_to": "Moderator", "restricted_to": ""}, + 11: {"content": "Seer, please open your eyes!", "send_to": "Moderator", "restricted_to": ""}, + 12: { + "content": """Seer, you can check one player's identity. Who are you going to verify its identity tonight? + Choose only one from the following living options:{living_players}.""", + "send_to": "Seer", + "restricted_to": "Moderator,Seer", + }, + 13: {"content": "Seer, close your eyes", "send_to": "Moderator", "restricted_to": ""}, + # The 1-st daytime + 14: { + "content": """It's daytime. Everyone woke up except those who had been killed.""", + "send_to": "Moderator", + "restricted_to": "", + }, + 15: {"content": "{player_current_dead} was killed last night!", "send_to": "Moderator", "restricted_to": ""}, + 16: { + "content": """Living players: {living_players}, now freely talk about the current situation based on your observation and + reflection with a few sentences. Decide whether to reveal your identity based on your reflection.""", + "send_to": "", # send to all to speak in daytime + "restricted_to": "", + }, + 17: { + "content": """Now vote and tell me who you think is the werewolf. Don’t mention your role. + You only choose one from the following living options please: + {living_players}. Say ONLY: I vote to eliminate ...""", + "send_to": "", + "restricted_to": "", + }, + 18: {"content": """{player_current_dead} was eliminated.""", "send_to": "Moderator", "restricted_to": ""}, +} + + +class WerewolfExtEnv(ExtEnv): + model_config = ConfigDict(arbitrary_types_allowed=True) + + players_state: dict[str, tuple[str, RoleState]] = Field( + default=dict(), description="the player's role type and state by player_name" + ) + + round_idx: int = Field(default=0) # the current round + step_idx: int = Field(default=0) # the current step of current round + eval_step_idx: int = Field(default=0) + per_round_steps: int = Field(default=len(STEP_INSTRUCTIONS)) + + # game global states + game_setup: str = Field(default="", description="game setup including role and its num") + special_role_players: list[str] = Field(default=[]) + winner: Optional[str] = Field(default=None) + win_reason: Optional[str] = Field(default=None) + witch_poison_left: int = Field(default=1) + witch_antidote_left: int = Field(default=1) + + # game current round states, a round is from closing your eyes to the next time you close your eyes + round_hunts: dict[str, str] = Field(default=dict(), description="nighttime wolf hunt result") + round_votes: dict[str, str] = Field( + default=dict(), description="daytime all players vote result, key=voteer, value=voted one" + ) + player_hunted: Optional[str] = Field(default=None) + player_protected: Optional[str] = Field(default=None) + is_hunted_player_saved: bool = Field(default=False) + player_poisoned: Optional[str] = Field(default=None) + player_current_dead: list[str] = Field(default=[]) + + @property + def living_players(self) -> list[str]: + player_names = [] + for name, roletype_state in self.players_state.items(): + if roletype_state[1] in [RoleState.ALIVE, RoleState.SAVED]: + player_names.append(name) + return player_names + + def _role_type_players(self, role_type: str) -> list[str]: + """return player name of particular role type""" + player_names = [] + for name, roletype_state in self.players_state.items(): + if role_type in roletype_state[0]: + player_names.append(name) + return player_names + + @property + def werewolf_players(self) -> list[str]: + player_names = self._role_type_players(role_type="Werewolf") + return player_names + + @property + def villager_players(self) -> list[str]: + player_names = self._role_type_players(role_type="Villager") + return player_names + + def _init_players_state(self, players: list["Role"]): + for play in players: + self.players_state[play.name] = (play.profile, RoleState.ALIVE) + + self.special_role_players = [ + p for p in self.living_players if p not in self.werewolf_players + self.villager_players + ] + + def init_game_setup( + self, + role_uniq_objs: list[object], + num_villager: int = 2, + num_werewolf: int = 2, + shuffle=True, + add_human=False, + use_reflection=True, + use_experience=False, + use_memory_selection=False, + new_experience_version="", + prepare_human_player=Callable, + ) -> tuple[str, list]: + """init players using different roles' num""" + role_objs = [] + for role_obj in role_uniq_objs: + if str(role_obj) == "Villager": + role_objs.extend([role_obj] * num_villager) + elif str(role_obj) == "Werewolf": + role_objs.extend([role_obj] * num_werewolf) + else: + role_objs.append(role_obj) + if shuffle: + random.shuffle(len(role_objs)) + if add_human: + assigned_role_idx = random.randint(0, len(role_objs) - 1) + assigned_role = role_objs[assigned_role_idx] + role_objs[assigned_role_idx] = prepare_human_player(assigned_role) # TODO + + players = [ + role( + name=f"Player{i + 1}", + use_reflection=use_reflection, + use_experience=use_experience, + use_memory_selection=use_memory_selection, + new_experience_version=new_experience_version, + ) + for i, role in enumerate(role_objs) + ] + + if add_human: + logger.info(f"You are assigned {players[assigned_role_idx].name}({players[assigned_role_idx].profile})") + + game_setup = ["Game setup:"] + [f"{player.name}: {player.profile}," for player in players] + self.game_setup = "\n".join(game_setup) + + self._init_players_state(players) # init players state + + return self.game_setup, players + + def _update_players_state(self, player_names: list[str], state: RoleState = RoleState.KILLED): + for player_name in player_names: + if player_name in self.players_state: + roletype_state = self.players_state[player_name] + self.players_state[player_name] = (roletype_state[0], state) + + def _check_valid_role(self, player: "Role", role_type: str) -> bool: + return True if role_type in str(player) else False + + def _check_player_continue(self, player_name: str, particular_step: int = -1) -> bool: + step_idx = self.step_idx % self.per_round_steps + if particular_step > 0 and step_idx != particular_step: # step no + # particular_step = 18, not daytime vote time, ignore + # particular_step = 15, not nighttime hunt time, ignore + return False + if player_name not in self.living_players: + return False + return True + + @mark_as_readable + def curr_step_instruction(self) -> dict: + step_idx = self.step_idx % len(STEP_INSTRUCTIONS) + instruction = STEP_INSTRUCTIONS[step_idx] + self.step_idx += 1 + return instruction + + @mark_as_readable + def get_players_state(self, player_names: list[str]) -> dict[str, RoleState]: + players_state = { + player_name: self.players_state[player_name][1] # only return role state + for player_name in player_names + if player_name in self.players_state + } + return players_state + + @mark_as_writeable + def vote_kill_someone(self, voteer: "Role", player_name: str = None): + """player vote result at daytime + player_name: if it's None, regard as abstaining from voting + """ + if not self._check_player_continue(voteer.name, particular_step=18): # 18=step no + return + + self.round_votes[voteer.name] = player_name + # check if all living players finish voting, then get the dead one + if list(self.round_votes.keys()) == self.living_players: + voted_all = list(self.round_votes.values()) # TODO in case of tie vote, check who was voted first + voted_all = [item for item in voted_all if item] + self.player_current_dead = Counter(voted_all).most_common()[0][0] + self._update_players_state([self.player_current_dead]) + + @mark_as_writeable + def wolf_kill_someone(self, wolf: "Role", player_name: str): + if not self._check_valid_role(wolf, "Werewolf"): + return + if not self._check_player_continue(wolf.name, particular_step=5): # 5=step no + return + + self.round_hunts[wolf.name] = player_name + living_werewolf = [p for p in self.werewolf_players if p in self.living_players] + # check if all living wolfs finish hunting, then get the hunted one + if list(self.round_hunts.keys()) == living_werewolf: + hunted_all = list(self.round_hunts.values()) + self.player_hunted = Counter(hunted_all).most_common()[0][0] + + @mark_as_writeable + def witch_poison_someone(self, witch: "Role", player_name: str = None): + if not self._check_valid_role(witch, "Witch"): + return + if not self._check_player_continue(player_name): + return + + self._update_players_state([player_name], RoleState.POISONED) + self.player_poisoned = player_name + + @mark_as_writeable + def witch_save_someone(self, witch: "Role", player_name: str = None): + if not self._check_valid_role(witch, "Witch"): + return + if not self._check_player_continue(player_name): + return + + self._update_players_state([player_name], RoleState.SAVED) + self.player_protected = player_name + + @mark_as_writeable + def update_game_states(self, memories: list): + step_idx = self.step_idx % self.per_round_steps + if step_idx not in [15, 18] or self.step_idx in self.eval_step_idx: + return + else: + self.eval_step_idx.append(self.step_idx) # record evaluation, avoid repetitive evaluation at the same step + + if step_idx == 15: # step no + # night ends: after all special roles acted, process the whole night + self.player_current_dead = [] # reset + + if self.player_hunted != self.player_protected and not self.is_hunted_player_saved: + self.player_current_dead.append(self.player_hunted) + if self.player_poisoned: + self.player_current_dead.append(self.player_poisoned) + + self._update_players_state([self.player_current_dead]) + # reset + self.player_hunted = None + self.player_protected = None + self.is_hunted_player_saved = False + self.player_poisoned = None + + # game's termination condition + living_werewolf = [p for p in self.werewolf_players if p in self.living_players] + living_villagers = [p for p in self.villager_players if p in self.living_players] + living_special_roles = [p for p in self.special_role_players if p in self.living_players] + if not living_werewolf: + self.winner = "good guys" + self.win_reason = "werewolves all dead" + elif not living_villagers or not living_special_roles: + self.winner = "werewolf" + self.win_reason = "villagers all dead" if not living_villagers else "special roles all dead" + if self.winner is not None: + self._record_all_experiences() # TODO diff --git a/metagpt/logs.py b/metagpt/logs.py index fb0fdd553..90bac21aa 100644 --- a/metagpt/logs.py +++ b/metagpt/logs.py @@ -15,14 +15,15 @@ from loguru import logger as _logger from metagpt.const import METAGPT_ROOT -def define_log_level(print_level="INFO", logfile_level="DEBUG"): +def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None): """Adjust the log level to above level""" current_date = datetime.now() formatted_date = current_date.strftime("%Y%m%d") + log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name _logger.remove() _logger.add(sys.stderr, level=print_level) - _logger.add(METAGPT_ROOT / f"logs/{formatted_date}.txt", level=logfile_level) + _logger.add(METAGPT_ROOT / f"logs/{log_name}.txt", level=logfile_level) return _logger diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 23278479c..b144471b5 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -34,8 +34,26 @@ class BaseLLM(ABC): def __init__(self, config: LLMConfig): pass - def _user_msg(self, msg: str) -> dict[str, str]: - return {"role": "user", "content": msg} + def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]: + if images: + # as gpt-4v, chat with image + return self._user_msg_with_imgs(msg, images) + else: + return {"role": "user", "content": msg} + + def _user_msg_with_imgs(self, msg: str, images: Optional[Union[str, list[str]]]): + """ + images: can be list of http(s) url or base64 + """ + if isinstance(images, str): + images = [images] + content = [{"type": "text", "text": msg}] + for image in images: + # image url or image base64 + url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}" + # it can with multiple-image inputs + content.append({"type": "image_url", "image_url": url}) + return {"role": "user", "content": content} def _assistant_msg(self, msg: str) -> dict[str, str]: return {"role": "assistant", "content": msg} @@ -54,6 +72,7 @@ class BaseLLM(ABC): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ) -> str: @@ -65,7 +84,7 @@ class BaseLLM(ABC): message = [] if format_msgs: message.extend(format_msgs) - message.append(self._user_msg(msg)) + message.append(self._user_msg(msg, images=images)) logger.debug(message) rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 206701efd..4b5393215 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -29,7 +29,7 @@ from metagpt.provider.base_llm import BaseLLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message -from metagpt.utils.common import CodeParser +from metagpt.utils.common import CodeParser, decode_image from metagpt.utils.cost_manager import CostManager, Costs from metagpt.utils.exceptions import handle_exception from metagpt.utils.token_counter import ( @@ -101,7 +101,7 @@ class OpenAILLM(BaseLLM): "messages": messages, "max_tokens": self._get_max_tokens(messages), "n": 1, - "stop": None, + # "stop": None, # default it's None and gpt4-v can't have this one "temperature": 0.3, "model": self.model, "timeout": max(self.config.timeout, timeout), @@ -303,3 +303,24 @@ class OpenAILLM(BaseLLM): async def aspeech_to_text(self, **kwargs): """speech to text""" return await self.aclient.audio.transcriptions.create(**kwargs) + + async def gen_image( + self, + prompt: str, + size: str = "1024x1024", + quality: str = "standard", + model: str = None, + resp_format: str = "url", + ) -> list["Image"]: + """image generate""" + assert resp_format in ["url", "b64_json"] + if not model: + model = self.model + res = await self.aclient.images.generate( + model=model, prompt=prompt, size=size, quality=quality, n=1, response_format=resp_format + ) + imgs = [] + for item in res.data: + img_url_or_b64 = item.url if resp_format == "url" else item.b64_json + imgs.append(decode_image(img_url_or_b64)) + return imgs diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 9efcf470e..bcfec708c 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -23,7 +23,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Iterable, Optional, Set, Type, Union +from typing import TYPE_CHECKING, Iterable, Optional, Set, Type, Union from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator @@ -40,6 +40,10 @@ from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator from metagpt.utils.project_repo import ProjectRepo from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output +if TYPE_CHECKING: + from metagpt.environment import Environment # noqa: F401 + + PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """ CONSTRAINT_TEMPLATE = "the constraint is {constraints}. " @@ -119,11 +123,17 @@ class RoleContext(BaseModel): def history(self) -> list[Message]: return self.memory.get() + @classmethod + def model_rebuild(cls, **kwargs): + from metagpt.environment.base_env import Environment # noqa: F401 + + super().model_rebuild(**kwargs) + class Role(SerializationMixin, ContextMixin, BaseModel): """Role/Agent""" - model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore") + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") name: str = "" profile: str = "" @@ -152,28 +162,24 @@ class Role(SerializationMixin, ContextMixin, BaseModel): __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` - def __init__(self, **data: Any): - self.pydantic_rebuild_model() - super().__init__(**data) + @model_validator(mode="after") + def validate_role_extra(self): + self._process_role_extra() + return self + + def _process_role_extra(self): + kwargs = self.model_extra or {} if self.is_human: self.llm = HumanProvider(None) self._check_actions() self.llm.system_prompt = self._get_prefix() - self._watch(data.get("watch") or [UserRequirement]) + self._watch(kwargs.pop("watch", [UserRequirement])) if self.latest_observed_msg: self.recovered = True - @staticmethod - def pydantic_rebuild_model(): - """Rebuild model to avoid `RecursionError: maximum recursion depth exceeded in comparison`""" - from metagpt.environment import Environment - - Environment - Role.model_rebuild() - @property def todo(self) -> Action: """Get action to do""" @@ -594,3 +600,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel): if self.actions: return any_to_name(self.actions[0]) return "" + + +RoleContext.model_rebuild() diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index 7929ce7fe..bc449b5cd 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -8,12 +8,12 @@ from typing import Optional -from pydantic import Field +from pydantic import Field, model_validator from metagpt.actions import SearchAndSummarize, UserRequirement from metagpt.document_store.base_store import BaseStore from metagpt.roles import Role -from metagpt.tools import SearchEngineType +from metagpt.tools.search_engine import SearchEngine class Sales(Role): @@ -29,14 +29,13 @@ class Sales(Role): store: Optional[BaseStore] = Field(default=None, exclude=True) - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._set_store(self.store) - - def _set_store(self, store): - if store: - action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch) + @model_validator(mode="after") + def validate_stroe(self): + if self.store: + search_engine = SearchEngine.from_search_func(search_func=self.store.asearch, proxy=self.config.proxy) + action = SearchAndSummarize(search_engine=search_engine, context=self.context) else: - action = SearchAndSummarize() + action = SearchAndSummarize self.set_actions([action]) self._watch([UserRequirement]) + return self diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 19a73a40e..557c5ae95 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -8,7 +8,9 @@ the `cause_by` value in the `Message` to a string to support the new message distribution feature. """ -from pydantic import Field +from typing import Optional + +from pydantic import Field, model_validator from metagpt.actions import SearchAndSummarize from metagpt.actions.action_node import ActionNode @@ -16,7 +18,7 @@ from metagpt.actions.action_output import ActionOutput from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.tools import SearchEngineType +from metagpt.tools.search_engine import SearchEngine class Searcher(Role): @@ -28,33 +30,22 @@ class Searcher(Role): profile (str): Role profile. goal (str): Goal of the searcher. constraints (str): Constraints or limitations for the searcher. - engine (SearchEngineType): The type of search engine to use. + search_engine (SearchEngine): The search engine to use. """ name: str = Field(default="Alice") profile: str = Field(default="Smart Assistant") goal: str = "Provide search services for users" constraints: str = "Answer is rich and complete" - engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE + search_engine: Optional[SearchEngine] = None - def __init__(self, **kwargs) -> None: - """ - Initializes the Searcher role with given attributes. - - Args: - name (str): Name of the searcher. - profile (str): Role profile. - goal (str): Goal of the searcher. - constraints (str): Constraints or limitations for the searcher. - engine (SearchEngineType): The type of search engine to use. - """ - super().__init__(**kwargs) - self.set_actions([SearchAndSummarize(engine=self.engine)]) - - def set_search_func(self, search_func): - """Sets a custom search function for the searcher.""" - action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func) - self.set_actions([action]) + @model_validator(mode="after") + def post_root(self): + if self.search_engine: + self.set_actions([SearchAndSummarize(search_engine=self.search_engine, context=self.context)]) + else: + self.set_actions([SearchAndSummarize]) + return self async def _act_sp(self) -> Message: """Performs the search action in a single process.""" diff --git a/metagpt/tools/search_engine.py b/metagpt/tools/search_engine.py index 0d0db9147..1e540bd0e 100644 --- a/metagpt/tools/search_engine.py +++ b/metagpt/tools/search_engine.py @@ -8,14 +8,23 @@ import importlib from typing import Callable, Coroutine, Literal, Optional, Union, overload +from pydantic import BaseModel, ConfigDict, model_validator from semantic_kernel.skill_definition import sk_function +from metagpt.configs.search_config import SearchConfig +from metagpt.logs import logger from metagpt.tools import SearchEngineType class SkSearchEngine: - def __init__(self): - self.search_engine = SearchEngine() + """A search engine class for executing searches. + + Attributes: + search_engine: The search engine instance used for executing searches. + """ + + def __init__(self, **kwargs): + self.search_engine = SearchEngine(**kwargs) @sk_function( description="searches results from Google. Useful when you need to find short " @@ -28,43 +37,85 @@ class SkSearchEngine: return result -class SearchEngine: - """Class representing a search engine. - - Args: - engine: The search engine type. Defaults to the search engine specified in the config. - run_func: The function to run the search. Defaults to None. +class SearchEngine(BaseModel): + """A model for configuring and executing searches with different search engines. Attributes: - run_func: The function to run the search. - engine: The search engine type. + model_config: Configuration for the model allowing arbitrary types. + engine: The type of search engine to use. + run_func: An optional callable for running the search. If not provided, it will be determined based on the engine. + api_key: An optional API key for the search engine. + proxy: An optional proxy for the search engine requests. """ - def __init__( + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE + run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None + api_key: Optional[str] = None + proxy: Optional[str] = None + + @model_validator(mode="after") + def validate_extra(self): + """Validates extra fields provided to the model and updates the run function accordingly.""" + data = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True) + if self.model_extra: + data.update(self.model_extra) + self._process_extra(**data) + return self + + def _process_extra( self, - engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE, - run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None, + run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None, **kwargs, ): - if engine == SearchEngineType.SERPAPI_GOOGLE: + """Processes extra configuration and updates the run function based on the search engine type. + + Args: + run_func: An optional callable for running the search. If not provided, it will be determined based on the engine. + """ + if self.engine == SearchEngineType.SERPAPI_GOOGLE: module = "metagpt.tools.search_engine_serpapi" run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run - elif engine == SearchEngineType.SERPER_GOOGLE: + elif self.engine == SearchEngineType.SERPER_GOOGLE: module = "metagpt.tools.search_engine_serper" run_func = importlib.import_module(module).SerperWrapper(**kwargs).run - elif engine == SearchEngineType.DIRECT_GOOGLE: + elif self.engine == SearchEngineType.DIRECT_GOOGLE: module = "metagpt.tools.search_engine_googleapi" run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run - elif engine == SearchEngineType.DUCK_DUCK_GO: + elif self.engine == SearchEngineType.DUCK_DUCK_GO: module = "metagpt.tools.search_engine_ddg" run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run - elif engine == SearchEngineType.CUSTOM_ENGINE: - pass # run_func = run_func + elif self.engine == SearchEngineType.CUSTOM_ENGINE: + run_func = self.run_func else: raise NotImplementedError - self.engine = engine self.run_func = run_func + @classmethod + def from_search_config(cls, config: SearchConfig, **kwargs): + """Creates a SearchEngine instance from a SearchConfig. + + Args: + config: The search configuration to use for creating the SearchEngine instance. + """ + data = config.model_dump(exclude={"api_type", "search_func"}) + if config.search_func is not None: + data["run_func"] = config.search_func + + return cls(engine=config.api_type, **data, **kwargs) + + @classmethod + def from_search_func( + cls, search_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]], **kwargs + ): + """Creates a SearchEngine instance from a custom search function. + + Args: + search_func: A callable that executes the search. + """ + return cls(engine=SearchEngineType.CUSTOM_ENGINE, run_func=search_func, **kwargs) + @overload def run( self, @@ -83,15 +134,29 @@ class SearchEngine: ) -> list[dict[str, str]]: ... - async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> Union[str, list[dict[str, str]]]: + async def run( + self, + query: str, + max_results: int = 8, + as_string: bool = True, + ignore_errors: bool = False, + ) -> Union[str, list[dict[str, str]]]: """Run a search query. Args: query: The search query. max_results: The maximum number of results to return. Defaults to 8. as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True. + ignore_errors: Whether to ignore errors during the search. Defaults to False. Returns: The search results as a string or a list of dictionaries. """ - return await self.run_func(query, max_results=max_results, as_string=as_string) + try: + return await self.run_func(query, max_results=max_results, as_string=as_string) + except Exception as e: + # Handle errors in the API call + logger.exception(f"fail to search {query} for {e}") + if not ignore_errors: + raise e + return "" if as_string else [] diff --git a/metagpt/tools/search_engine_ddg.py b/metagpt/tools/search_engine_ddg.py index 3d004a4ee..1412f20cf 100644 --- a/metagpt/tools/search_engine_ddg.py +++ b/metagpt/tools/search_engine_ddg.py @@ -5,9 +5,9 @@ from __future__ import annotations import asyncio import json from concurrent import futures -from typing import Literal, overload +from typing import Literal, Optional, overload -from metagpt.config2 import config +from pydantic import BaseModel, ConfigDict try: from duckduckgo_search import DDGS @@ -18,24 +18,16 @@ except ImportError: ) -class DDGAPIWrapper: - """Wrapper around duckduckgo_search API. +class DDGAPIWrapper(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) - To use this module, you should have the `duckduckgo_search` Python package installed. - """ + loop: Optional[asyncio.AbstractEventLoop] = None + executor: Optional[futures.Executor] = None + proxy: Optional[str] = None - def __init__( - self, - *, - loop: asyncio.AbstractEventLoop | None = None, - executor: futures.Executor | None = None, - ): - kwargs = {} - if config.proxy: - kwargs["proxies"] = config.proxy - self.loop = loop - self.executor = executor - self.ddgs = DDGS(**kwargs) + @property + def ddgs(self): + return DDGS(proxies=self.proxy) @overload def run( diff --git a/metagpt/tools/search_engine_googleapi.py b/metagpt/tools/search_engine_googleapi.py index 0a8f796cb..66b5ba950 100644 --- a/metagpt/tools/search_engine_googleapi.py +++ b/metagpt/tools/search_engine_googleapi.py @@ -4,19 +4,16 @@ from __future__ import annotations import asyncio import json +import warnings from concurrent import futures from typing import Optional from urllib.parse import urlparse import httplib2 -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from metagpt.config2 import config -from metagpt.logs import logger +from pydantic import BaseModel, ConfigDict, model_validator try: from googleapiclient.discovery import build - from googleapiclient.errors import HttpError except ImportError: raise ImportError( "To use this module, you should have the `google-api-python-client` Python package installed. " @@ -27,40 +24,41 @@ except ImportError: class GoogleAPIWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - google_api_key: Optional[str] = Field(default=None, validate_default=True) - google_cse_id: Optional[str] = Field(default=None, validate_default=True) + api_key: str + cse_id: str loop: Optional[asyncio.AbstractEventLoop] = None executor: Optional[futures.Executor] = None + proxy: Optional[str] = None - @field_validator("google_api_key", mode="before") + @model_validator(mode="before") @classmethod - def check_google_api_key(cls, val: str): - val = val or config.search.api_key - if not val: + def validate_google(cls, values: dict) -> dict: + if "google_api_key" in values: + values.setdefault("api_key", values["google_api_key"]) + warnings.warn("`google_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2) + + if "api_key" not in values: raise ValueError( - "To use, make sure you provide the google_api_key when constructing an object. Alternatively, " - "ensure that the environment variable GOOGLE_API_KEY is set with your API key. You can obtain " + "To use google search engine, make sure you provide the `api_key` when constructing an object. You can obtain " "an API key from https://console.cloud.google.com/apis/credentials." ) - return val - @field_validator("google_cse_id", mode="before") - @classmethod - def check_google_cse_id(cls, val: str): - val = val or config.search.cse_id - if not val: + if "google_cse_id" in values: + values.setdefault("cse_id", values["google_cse_id"]) + warnings.warn("`google_cse_id` is deprecated, use `cse_id` instead", DeprecationWarning, stacklevel=2) + + if "cse_id" not in values: raise ValueError( - "To use, make sure you provide the google_cse_id when constructing an object. Alternatively, " - "ensure that the environment variable GOOGLE_CSE_ID is set with your API key. You can obtain " - "an API key from https://programmablesearchengine.google.com/controlpanel/create." + "To use google search engine, make sure you provide the `cse_id` when constructing an object. You can obtain " + "the cse_id from https://programmablesearchengine.google.com/controlpanel/create." ) - return val + return values @property def google_api_client(self): - build_kwargs = {"developerKey": self.google_api_key} - if config.proxy: - parse_result = urlparse(config.proxy) + build_kwargs = {"developerKey": self.api_key} + if self.proxy: + parse_result = urlparse(self.proxy) proxy_type = parse_result.scheme if proxy_type == "https": proxy_type = "http" @@ -96,17 +94,11 @@ class GoogleAPIWrapper(BaseModel): """ loop = self.loop or asyncio.get_event_loop() future = loop.run_in_executor( - self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute + self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.cse_id).execute ) - try: - result = await future - # Extract the search result items from the response - search_results = result.get("items", []) - - except HttpError as e: - # Handle errors in the API call - logger.exception(f"fail to search {query} for {e}") - search_results = [] + result = await future + # Extract the search result items from the response + search_results = result.get("items", []) focus = focus or ["snippet", "link", "title"] details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results] diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 8d27d493d..5744b1b62 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -5,18 +5,17 @@ @Author : alexanderwu @File : search_engine_serpapi.py """ +import warnings from typing import Any, Dict, Optional, Tuple import aiohttp -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from metagpt.config2 import config +from pydantic import BaseModel, ConfigDict, Field, model_validator class SerpAPIWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - search_engine: Any = None #: :meta private: + api_key: str params: dict = Field( default_factory=lambda: { "engine": "google", @@ -25,21 +24,22 @@ class SerpAPIWrapper(BaseModel): "hl": "en", } ) - # should add `validate_default=True` to check with default value - serpapi_api_key: Optional[str] = Field(default=None, validate_default=True) aiosession: Optional[aiohttp.ClientSession] = None + proxy: Optional[str] = None - @field_validator("serpapi_api_key", mode="before") + @model_validator(mode="before") @classmethod - def check_serpapi_api_key(cls, val: str): - val = val or config.search.api_key - if not val: + def validate_serpapi(cls, values: dict) -> dict: + if "serpapi_api_key" in values: + values.setdefault("api_key", values["serpapi_api_key"]) + warnings.warn("`serpapi_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2) + + if "api_key" not in values: raise ValueError( - "To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, " - "ensure that the environment variable SERPAPI_API_KEY is set with your API key. You can obtain " - "an API key from https://serpapi.com/." + "To use serpapi search engine, make sure you provide the `api_key` when constructing an object. You can obtain" + " an API key from https://serpapi.com/." ) - return val + return values async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through SerpAPI and parse result async.""" @@ -60,11 +60,11 @@ class SerpAPIWrapper(BaseModel): url, params = construct_url_and_params() if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.get(url, params=params) as response: + async with session.get(url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get(url, params=params) as response: + async with self.aiosession.get(url, params=params, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() @@ -73,7 +73,7 @@ class SerpAPIWrapper(BaseModel): def get_params(self, query: str) -> Dict[str, str]: """Get parameters for SerpAPI.""" _params = { - "api_key": self.serpapi_api_key, + "api_key": self.api_key, "q": query, } params = {**self.params, **_params} diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index 71ee2f4f9..ba2fb4f93 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -6,33 +6,34 @@ @File : search_engine_serpapi.py """ import json +import warnings from typing import Any, Dict, Optional, Tuple import aiohttp -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from metagpt.config2 import config +from pydantic import BaseModel, ConfigDict, Field, model_validator class SerperWrapper(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - search_engine: Any = None #: :meta private: + api_key: str payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10}) - serper_api_key: Optional[str] = Field(default=None, validate_default=True) aiosession: Optional[aiohttp.ClientSession] = None + proxy: Optional[str] = None - @field_validator("serper_api_key", mode="before") + @model_validator(mode="before") @classmethod - def check_serper_api_key(cls, val: str): - val = val or config.search.api_key - if not val: + def validate_serper(cls, values: dict) -> dict: + if "serper_api_key" in values: + values.setdefault("api_key", values["serper_api_key"]) + warnings.warn("`serper_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2) + + if "api_key" not in values: raise ValueError( - "To use, make sure you provide the serper_api_key when constructing an object. Alternatively, " - "ensure that the environment variable SERPER_API_KEY is set with your API key. You can obtain " + "To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain " "an API key from https://serper.dev/." ) - return val + return values async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through Serper and parse result async.""" @@ -54,11 +55,11 @@ class SerperWrapper(BaseModel): url, payloads, headers = construct_url_and_payload_and_headers() if not self.aiosession: async with aiohttp.ClientSession() as session: - async with session.post(url, data=payloads, headers=headers) as response: + async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() else: - async with self.aiosession.get.post(url, data=payloads, headers=headers) as response: + async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response: response.raise_for_status() res = await response.json() @@ -76,7 +77,7 @@ class SerperWrapper(BaseModel): return json.dumps(payloads, sort_keys=True) def get_headers(self) -> Dict[str, str]: - headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"} + headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"} return headers @staticmethod diff --git a/metagpt/tools/web_browser_engine.py b/metagpt/tools/web_browser_engine.py index 411c1604b..01339e51a 100644 --- a/metagpt/tools/web_browser_engine.py +++ b/metagpt/tools/web_browser_engine.py @@ -1,36 +1,95 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - from __future__ import annotations import importlib -from typing import Any, Callable, Coroutine, overload +from typing import Any, Callable, Coroutine, Optional, Union, overload +from pydantic import BaseModel, ConfigDict, model_validator + +from metagpt.configs.browser_config import BrowserConfig from metagpt.tools import WebBrowserEngineType from metagpt.utils.parse_html import WebPage -class WebBrowserEngine: - def __init__( - self, - engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT, - run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None, - ): - if engine is None: - raise NotImplementedError +class WebBrowserEngine(BaseModel): + """Defines a web browser engine configuration for automated browsing and data extraction. - if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT: + This class encapsulates the configuration and operational logic for different web browser engines, + such as Playwright, Selenium, or custom implementations. It provides a unified interface to run + browser automation tasks. + + Attributes: + model_config: Configuration dictionary allowing arbitrary types and extra fields. + engine: The type of web browser engine to use. + run_func: An optional coroutine function to run the browser engine. + proxy: An optional proxy server URL to use with the browser engine. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT + run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None + proxy: Optional[str] = None + + @model_validator(mode="after") + def validate_extra(self): + """Validates and processes extra configuration data after model initialization. + + This method is automatically called by Pydantic to validate and process any extra configuration + data provided to the model. It ensures that the extra data is properly integrated into the model's + configuration and operational logic. + + Returns: + The instance itself after processing the extra data. + """ + data = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True) + if self.model_extra: + data.update(self.model_extra) + self._process_extra(**data) + return self + + def _process_extra(self, **kwargs): + """Processes extra configuration data to set up the browser engine run function. + + Depending on the specified engine type, this method dynamically imports and configures + the appropriate browser engine wrapper and its run function. + + Args: + **kwargs: Arbitrary keyword arguments representing extra configuration data. + + Raises: + NotImplementedError: If the engine type is not supported. + """ + if self.engine is WebBrowserEngineType.PLAYWRIGHT: module = "metagpt.tools.web_browser_engine_playwright" - run_func = importlib.import_module(module).PlaywrightWrapper().run - elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM: + run_func = importlib.import_module(module).PlaywrightWrapper(**kwargs).run + elif self.engine is WebBrowserEngineType.SELENIUM: module = "metagpt.tools.web_browser_engine_selenium" - run_func = importlib.import_module(module).SeleniumWrapper().run - elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM: - run_func = run_func + run_func = importlib.import_module(module).SeleniumWrapper(**kwargs).run + elif self.engine is WebBrowserEngineType.CUSTOM: + run_func = self.run_func else: raise NotImplementedError self.run_func = run_func - self.engine = engine + + @classmethod + def from_browser_config(cls, config: BrowserConfig, **kwargs): + """Creates a WebBrowserEngine instance from a BrowserConfig object and additional keyword arguments. + + This class method facilitates the creation of a WebBrowserEngine instance by extracting + configuration data from a BrowserConfig object and optionally merging it with additional + keyword arguments. + + Args: + config: A BrowserConfig object containing base configuration data. + **kwargs: Optional additional keyword arguments to override or extend the configuration. + + Returns: + A new instance of WebBrowserEngine configured according to the provided arguments. + """ + data = config.model_dump() + return cls(**data, **kwargs) @overload async def run(self, url: str) -> WebPage: @@ -41,4 +100,16 @@ class WebBrowserEngine: ... async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: + """Runs the browser engine to load one or more web pages. + + This method is the implementation of the overloaded run signatures. It delegates the task + of loading web pages to the configured run function, handling either a single URL or multiple URLs. + + Args: + url: The URL of the first web page to load. + *urls: Additional URLs of web pages to load, if any. + + Returns: + A WebPage object if a single URL is provided, or a list of WebPage objects if multiple URLs are provided. + """ return await self.run_func(url, *urls) diff --git a/metagpt/tools/web_browser_engine_playwright.py b/metagpt/tools/web_browser_engine_playwright.py index 7c33da923..2df288b1a 100644 --- a/metagpt/tools/web_browser_engine_playwright.py +++ b/metagpt/tools/web_browser_engine_playwright.py @@ -6,15 +6,16 @@ from __future__ import annotations import asyncio import sys from pathlib import Path -from typing import Literal +from typing import Literal, Optional from playwright.async_api import async_playwright +from pydantic import BaseModel, Field, PrivateAttr from metagpt.logs import logger from metagpt.utils.parse_html import WebPage -class PlaywrightWrapper: +class PlaywrightWrapper(BaseModel): """Wrapper around Playwright. To use this module, you should have the `playwright` Python package installed and ensure that @@ -23,28 +24,23 @@ class PlaywrightWrapper: command `playwright install` for the first time. """ - def __init__( - self, - browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium", - launch_kwargs: dict | None = None, - **kwargs, - ) -> None: - from metagpt.config2 import ( - config, # avoid circular import error when importing tools" - ) + browser_type: Literal["chromium", "firefox", "webkit"] = "chromium" + launch_kwargs: dict = Field(default_factory=dict) + proxy: Optional[str] = None + context_kwargs: dict = Field(default_factory=dict) + _has_run_precheck: bool = PrivateAttr(False) - self.browser_type = browser_type - launch_kwargs = launch_kwargs or {} - if config.proxy and "proxy" not in launch_kwargs: + def __init__(self, **kwargs): + super().__init__(**kwargs) + + launch_kwargs = self.launch_kwargs + if self.proxy and "proxy" not in launch_kwargs: args = launch_kwargs.get("args", []) if not any(str.startswith(i, "--proxy-server=") for i in args): - launch_kwargs["proxy"] = {"server": config.proxy} - self.launch_kwargs = launch_kwargs - context_kwargs = {} + launch_kwargs["proxy"] = {"server": self.proxy} + if "ignore_https_errors" in kwargs: - context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"] - self._context_kwargs = context_kwargs - self._has_run_precheck = False + self.context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"] async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: async with async_playwright() as ap: @@ -58,7 +54,7 @@ class PlaywrightWrapper: return await _scrape(browser, url) async def _scrape(self, browser, url): - context = await browser.new_context(**self._context_kwargs) + context = await browser.new_context(**self.context_kwargs) page = await context.new_page() async with page: try: @@ -78,8 +74,8 @@ class PlaywrightWrapper: executable_path = Path(browser_type.executable_path) if not executable_path.exists() and "executable_path" not in self.launch_kwargs: kwargs = {} - if config.proxy: - kwargs["env"] = {"ALL_PROXY": config.proxy} + if self.proxy: + kwargs["env"] = {"ALL_PROXY": self.proxy} await _install_browsers(self.browser_type, **kwargs) if self._has_run_precheck: diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 02dd5c173..3b1682291 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -7,19 +7,19 @@ import asyncio import importlib from concurrent import futures from copy import deepcopy -from typing import Literal +from typing import Callable, Literal, Optional +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.wait import WebDriverWait from webdriver_manager.core.download_manager import WDMDownloadManager from webdriver_manager.core.http import WDMHttpClient -from metagpt.config2 import config from metagpt.utils.parse_html import WebPage -class SeleniumWrapper: +class SeleniumWrapper(BaseModel): """Wrapper around Selenium. To use this module, you should check the following: @@ -31,25 +31,28 @@ class SeleniumWrapper: can scrape web pages using the Selenium WebBrowserEngine. """ - def __init__( - self, - browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome", - launch_kwargs: dict | None = None, - *, - loop: asyncio.AbstractEventLoop | None = None, - executor: futures.Executor | None = None, - ) -> None: - self.browser_type = browser_type - launch_kwargs = launch_kwargs or {} - if config.proxy and "proxy-server" not in launch_kwargs: - launch_kwargs["proxy-server"] = config.proxy + model_config = ConfigDict(arbitrary_types_allowed=True) - self.executable_path = launch_kwargs.pop("executable_path", None) - self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()] - self._has_run_precheck = False - self._get_driver = None - self.loop = loop - self.executor = executor + browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome" + launch_kwargs: dict = Field(default_factory=dict) + proxy: Optional[str] = None + loop: Optional[asyncio.AbstractEventLoop] = None + executor: Optional[futures.Executor] = None + _has_run_precheck: bool = PrivateAttr(False) + _get_driver: Optional[Callable] = PrivateAttr(None) + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + if self.proxy and "proxy-server" not in self.launch_kwargs: + self.launch_kwargs["proxy-server"] = self.proxy + + @property + def launch_args(self): + return [f"--{k}={v}" for k, v in self.launch_kwargs.items() if k != "executable_path"] + + @property + def executable_path(self): + return self.launch_kwargs.get("executable_path") async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: await self._run_precheck() @@ -66,7 +69,9 @@ class SeleniumWrapper: self.loop = self.loop or asyncio.get_event_loop() self._get_driver = await self.loop.run_in_executor( self.executor, - lambda: _gen_get_driver_func(self.browser_type, *self.launch_args, executable_path=self.executable_path), + lambda: _gen_get_driver_func( + self.browser_type, *self.launch_args, executable_path=self.executable_path, proxy=self.proxy + ), ) self._has_run_precheck = True @@ -92,13 +97,17 @@ _webdriver_manager_types = { class WDMHttpProxyClient(WDMHttpClient): + def __init__(self, proxy: str = None): + super().__init__() + self.proxy = proxy + def get(self, url, **kwargs): - if "proxies" not in kwargs and config.proxy: - kwargs["proxies"] = {"all_proxy": config.proxy} + if "proxies" not in kwargs and self.proxy: + kwargs["proxies"] = {"all_proxy": self.proxy} return super().get(url, **kwargs) -def _gen_get_driver_func(browser_type, *args, executable_path=None): +def _gen_get_driver_func(browser_type, *args, executable_path=None, proxy=None): WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver") Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service") Options = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.options"), "Options") @@ -106,7 +115,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): if not executable_path: module_name, type_name = _webdriver_manager_types[browser_type] DriverManager = getattr(importlib.import_module(module_name), type_name) - driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient())) + driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient(proxy=proxy))) # driver_manager.driver_cache.find_driver(driver_manager.driver)) executable_path = driver_manager.install() diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 9d6a6bb24..2264255dc 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -13,7 +13,9 @@ from __future__ import annotations import ast +import base64 import contextlib +import csv import importlib import inspect import json @@ -23,11 +25,14 @@ import re import sys import traceback import typing +from io import BytesIO from pathlib import Path -from typing import Any, List, Tuple, Union +from typing import Any, Callable, List, Tuple, Union import aiofiles import loguru +import requests +from PIL import Image from pydantic_core import to_jsonable_python from tenacity import RetryCallState, RetryError, _utils @@ -339,6 +344,14 @@ def print_members(module, indent=0): print(f"{prefix}Method: {name}") +def get_function_schema(func: Callable) -> dict[str, Union[dict, Any, str]]: + sig = inspect.signature(func) + parameters = sig.parameters + return_type = sig.return_annotation + param_schema = {name: parameter.annotation for name, parameter in parameters.items()} + return {"input_params": param_schema, "return_type": return_type, "func_desc": func.__doc__, "func": func} + + def parse_recipient(text): # FIXME: use ActionNode instead. pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now @@ -494,6 +507,29 @@ def write_json_file(json_file: str, data: list, encoding: str = None, indent: in json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) +def read_csv_to_list(curr_file: str, header=False, strip_trail=True): + """ + Reads in a csv file to a list of list. If header is True, it returns a + tuple with (header row, all rows) + ARGS: + curr_file: path to the current csv file. + RETURNS: + List of list where the component lists are the rows of the file. + """ + logger.debug(f"start read csv: {curr_file}") + analysis_list = [] + with open(curr_file) as f_analysis_file: + data_reader = csv.reader(f_analysis_file, delimiter=",") + for count, row in enumerate(data_reader): + if strip_trail: + row = [i.strip() for i in row] + analysis_list += [row] + if not header: + return analysis_list + else: + return analysis_list[0], analysis_list[1:] + + def import_class(class_name: str, module_name: str) -> type: module = importlib.import_module(module_name) a_class = getattr(module, class_name) @@ -602,3 +638,45 @@ def list_files(root: str | Path) -> List[Path]: except Exception as e: logger.error(f"Error: {e}") return files + + +def is_coroutine_func(func: Callable) -> bool: + return inspect.iscoroutinefunction(func) + + +def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None) -> list[str]: + """load mincraft skill from js files""" + if not skills_dir: + skills_dir = Path(__file__).parent.absolute() + if skill_names is None: + skill_names = [skill[:-3] for skill in os.listdir(f"{skills_dir}") if skill.endswith(".js")] + skills = [skills_dir.joinpath(f"{skill_name}.js").read_text() for skill_name in skill_names] + return skills + + +def encode_image(image_path_or_pil: Union[Path, Image], encoding: str = "utf-8") -> str: + """encode image from file or PIL.Image into base64""" + if isinstance(image_path_or_pil, Image.Image): + buffer = BytesIO() + image_path_or_pil.save(buffer, format="JPEG") + bytes_data = buffer.getvalue() + else: + if not image_path_or_pil.exists(): + raise FileNotFoundError(f"{image_path_or_pil} not exists") + with open(str(image_path_or_pil), "rb") as image_file: + bytes_data = image_file.read() + return base64.b64encode(bytes_data).decode(encoding) + + +def decode_image(img_url_or_b64: str) -> Image: + """decode image from url or base64 into PIL.Image""" + if img_url_or_b64.startswith("http"): + # image http(s) url + resp = requests.get(img_url_or_b64) + img = Image.open(BytesIO(resp.content)) + else: + # image b64_json + b64_data = re.sub("^data:image/.+;base64,", "", img_url_or_b64) + img_data = BytesIO(base64.b64decode(b64_data)) + img = Image.open(img_data) + return img diff --git a/metagpt/utils/dependency_file.py b/metagpt/utils/dependency_file.py index c8b3bc4a4..d3add1171 100644 --- a/metagpt/utils/dependency_file.py +++ b/metagpt/utils/dependency_file.py @@ -60,23 +60,22 @@ class DependencyFile: root = self._filename.parent try: - key = Path(filename).relative_to(root) + key = Path(filename).relative_to(root).as_posix() except ValueError: key = filename - skey = re.sub(r"\\+", "/", str(key)) # Compatible with windows path + key = str(key) if dependencies: relative_paths = [] for i in dependencies: try: - s = str(Path(i).relative_to(root)) + s = str(Path(i).relative_to(root).as_posix()) except ValueError: s = str(i) - s = re.sub(r"\\+", "/", s) # Compatible with windows path relative_paths.append(s) - self._dependencies[skey] = relative_paths - elif skey in self._dependencies: - del self._dependencies[skey] + self._dependencies[key] = relative_paths + elif key in self._dependencies: + del self._dependencies[key] if persist: await self.save() @@ -93,7 +92,7 @@ class DependencyFile: root = self._filename.parent try: - key = Path(filename).relative_to(root) + key = Path(filename).relative_to(root).as_posix() except ValueError: key = filename return set(self._dependencies.get(str(key), {})) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 94506e373..a0fb3b70d 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -29,6 +29,7 @@ TOKEN_COSTS = { "gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens @@ -54,6 +55,7 @@ TOKEN_MAX = { "gpt-4-turbo-preview": 128000, "gpt-4-0125-preview": 128000, "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, "gpt-4-1106-vision-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, @@ -82,6 +84,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): "gpt-4-turbo-preview", "gpt-4-0125-preview", "gpt-4-1106-preview", + "gpt-4-vision-preview", "gpt-4-1106-vision-preview", }: tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|> @@ -112,7 +115,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): for message in messages: num_tokens += tokens_per_message for key, value in message.items(): - num_tokens += len(encoding.encode(value)) + content = value + if isinstance(value, list): + # for gpt-4v + for item in value: + if isinstance(item, dict) and item.get("type") in ["text"]: + content = item.get("text", "") + num_tokens += len(encoding.encode(content)) if key == "name": num_tokens += tokens_per_name num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> diff --git a/requirements.txt b/requirements.txt index dff615bdc..6cb25d52b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -65,4 +65,5 @@ networkx~=3.2.1 google-generativeai==0.3.2 # playwright==1.40.0 # playwright extras require anytree -ipywidgets==8.1.1 \ No newline at end of file +ipywidgets==8.1.1 +Pillow diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 1ec9f4f8d..589282879 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -5,6 +5,7 @@ @Author : alexanderwu @File : test_action_node.py """ +from pathlib import Path from typing import List, Tuple import pytest @@ -17,6 +18,7 @@ from metagpt.llm import LLM from metagpt.roles import Role from metagpt.schema import Message from metagpt.team import Team +from metagpt.utils.common import encode_image @pytest.mark.asyncio @@ -241,6 +243,18 @@ def test_create_model_class_with_mapping(): assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] +@pytest.mark.asyncio +async def test_action_node_with_image(): + invoice = ActionNode( + key="invoice", expected_type=bool, instruction="if it's a invoice file, return True else False", example="False" + ) + + invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png") + img_base64 = encode_image(invoice_path) + node = await invoice.fill(context="", llm=LLM(), images=[img_base64]) + assert node.instruct_content.invoice + + class ToolDef(BaseModel): tool_name: str = Field(default="a", description="tool name", examples=[]) description: str = Field(default="b", description="tool description", examples=[]) diff --git a/tests/metagpt/actions/test_research.py b/tests/metagpt/actions/test_research.py index 372a1e876..ed83ce58c 100644 --- a/tests/metagpt/actions/test_research.py +++ b/tests/metagpt/actions/test_research.py @@ -28,9 +28,9 @@ async def test_collect_links(mocker, search_engine_mocker, context): return "[1,2]" mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask) - resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), context=context).run( - "The application of MetaGPT" - ) + resp = await research.CollectLinks( + search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO), context=context + ).run("The application of MetaGPT") for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]: assert i in resp @@ -50,7 +50,9 @@ async def test_collect_links_with_rank_func(mocker, search_engine_mocker, contex mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask) resp = await research.CollectLinks( - search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func, context=context + search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO), + rank_func=rank_func, + context=context, ).run("The application of MetaGPT") for x, y, z in zip(rank_before, rank_after, resp.values()): assert x[::-1] == y diff --git a/tests/metagpt/environment/android_env/__init__.py b/tests/metagpt/environment/android_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/environment/android_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/environment/android_env/test_android_ext_env.py b/tests/metagpt/environment/android_env/test_android_ext_env.py new file mode 100644 index 000000000..c9dfc718b --- /dev/null +++ b/tests/metagpt/environment/android_env/test_android_ext_env.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of AndroidExtEnv + +from pathlib import Path + +from metagpt.environment.android_env.android_ext_env import AndroidExtEnv +from metagpt.environment.android_env.const import ADB_EXEC_FAIL + + +def mock_device_shape(self, adb_cmd: str) -> str: + return "shape: 720x1080" + + +def mock_device_shape_invalid(self, adb_cmd: str) -> str: + return ADB_EXEC_FAIL + + +def mock_list_devices(self, adb_cmd: str) -> str: + return "devices\nemulator-5554" + + +def mock_get_screenshot(self, adb_cmd: str) -> str: + return "screenshot_xxxx-xx-xx" + + +def mock_get_xml(self, adb_cmd: str) -> str: + return "xml_xxxx-xx-xx" + + +def mock_write_read_operation(self, adb_cmd: str) -> str: + return "OK" + + +def test_android_ext_env(mocker): + device_id = "emulator-5554" + mocker.patch( + "metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape + ) + + ext_env = AndroidExtEnv(device_id=device_id, screenshot_dir="/data2/", xml_dir="/data2/") + assert ext_env.adb_prefix == f"adb -s {device_id} " + assert ext_env.adb_prefix_shell == f"adb -s {device_id} shell " + assert ext_env.adb_prefix_si == f"adb -s {device_id} shell input " + + assert ext_env.device_shape == (720, 1080) + + mocker.patch( + "metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid + ) + assert ext_env.device_shape == (0, 0) + + mocker.patch( + "metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_list_devices + ) + assert ext_env.list_devices() == [device_id] + + mocker.patch( + "metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot + ) + assert ext_env.get_screenshot("screenshot_xxxx-xx-xx", "/data/") == Path("/data/screenshot_xxxx-xx-xx.png") + + mocker.patch("metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml) + assert ext_env.get_xml("xml_xxxx-xx-xx", "/data/") == Path("/data/xml_xxxx-xx-xx.xml") + + mocker.patch( + "metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation + ) + res = "OK" + assert ext_env.system_back() == res + assert ext_env.system_tap(10, 10) == res + assert ext_env.user_input("test_input") == res + assert ext_env.user_longpress(10, 10) == res + assert ext_env.user_swipe(10, 10) == res + assert ext_env.user_swipe_to((10, 10), (20, 20)) == res diff --git a/tests/metagpt/environment/api/__init__.py b/tests/metagpt/environment/api/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/environment/api/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/environment/api/test_env_api.py b/tests/metagpt/environment/api/test_env_api.py new file mode 100644 index 000000000..53f98c0d3 --- /dev/null +++ b/tests/metagpt/environment/api/test_env_api.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from metagpt.environment.api.env_api import EnvAPIRegistry + + +def test_env_api_registry(): + def test_func(): + pass + + env_api_registry = EnvAPIRegistry() + env_api_registry["test"] = test_func + + env_api_registry.get("test") == test_func diff --git a/tests/metagpt/environment/mincraft_env/__init__.py b/tests/metagpt/environment/mincraft_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/environment/mincraft_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/environment/mincraft_env/test_mincraft_ext_env.py b/tests/metagpt/environment/mincraft_env/test_mincraft_ext_env.py new file mode 100644 index 000000000..ad3376141 --- /dev/null +++ b/tests/metagpt/environment/mincraft_env/test_mincraft_ext_env.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of MincraftExtEnv + + +from metagpt.environment.mincraft_env.const import MC_CKPT_DIR +from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv + + +def test_mincraft_ext_env(): + ext_env = MincraftExtEnv() + assert ext_env.server, f"{ext_env.server_host}:{ext_env.server_port}" + assert MC_CKPT_DIR.joinpath("skill/code").exists() + assert ext_env.warm_up.get("optional_inventory_items") == 7 diff --git a/tests/metagpt/environment/stanford_town_env/__init__.py b/tests/metagpt/environment/stanford_town_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/environment/stanford_town_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/environment/stanford_town_env/test_stanford_town_ext_env.py b/tests/metagpt/environment/stanford_town_env/test_stanford_town_ext_env.py new file mode 100644 index 000000000..3071f9deb --- /dev/null +++ b/tests/metagpt/environment/stanford_town_env/test_stanford_town_ext_env.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of StanfordTownExtEnv + +from pathlib import Path + +from metagpt.environment.stanford_town_env.stanford_town_ext_env import ( + StanfordTownExtEnv, +) + +maze_asset_path = ( + Path(__file__).absolute().parent.joinpath("..", "..", "..", "data", "environment", "stanford_town", "the_ville") +) + + +def test_stanford_town_ext_env(): + ext_env = StanfordTownExtEnv(maze_asset_path=maze_asset_path) + + tile_coord = ext_env.turn_coordinate_to_tile((64, 64)) + assert tile_coord == (2, 2) + + tile = (58, 9) + assert len(ext_env.get_collision_maze()) == 100 + assert len(ext_env.get_address_tiles()) == 306 + assert ext_env.access_tile(tile=tile)["world"] == "the Ville" + assert ext_env.get_tile_path(tile=tile, level="world") == "the Ville" + assert len(ext_env.get_nearby_tiles(tile=tile, vision_r=5)) == 121 + + event = ("double studio:double studio:bedroom 2:bed", None, None, None) + ext_env.add_tiles_event(tile[1], tile[0], event=event) + ext_env.add_event_from_tile(event, tile) + assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1 + + ext_env.turn_event_from_tile_idle(event, tile) + + ext_env.remove_event_from_tile(event, tile) + assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0 + + ext_env.remove_subject_events_from_tile(subject=event[0], tile=tile) + assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0 diff --git a/tests/metagpt/environment/test_base_env.py b/tests/metagpt/environment/test_base_env.py new file mode 100644 index 000000000..fd73679d8 --- /dev/null +++ b/tests/metagpt/environment/test_base_env.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of ExtEnv&Env + +import pytest + +from metagpt.environment.api.env_api import EnvAPIAbstract +from metagpt.environment.base_env import ( + Environment, + env_read_api_registry, + env_write_api_registry, + mark_as_readable, + mark_as_writeable, +) + + +class ForTestEnv(Environment): + value: int = 0 + + @mark_as_readable + def read_api_no_param(self): + return self.value + + @mark_as_readable + def read_api(self, a: int, b: int): + return a + b + + @mark_as_writeable + def write_api(self, a: int, b: int): + self.value = a + b + + @mark_as_writeable + async def async_read_api(self, a: int, b: int): + return a + b + + +@pytest.mark.asyncio +async def test_ext_env(): + env = ForTestEnv() + assert len(env_read_api_registry) > 0 + assert len(env_write_api_registry) > 0 + + apis = env.get_all_available_apis(mode="read") + assert len(apis) > 0 + assert len(apis["read_api"]) == 3 + + _ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10})) + assert env.value == 15 + + with pytest.raises(ValueError): + await env.observe("not_exist_api") + + assert await env.observe("read_api_no_param") == 15 + assert await env.observe(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10 diff --git a/tests/metagpt/environment/werewolf_env/__init__.py b/tests/metagpt/environment/werewolf_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/environment/werewolf_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py b/tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py new file mode 100644 index 000000000..0694c5c3d --- /dev/null +++ b/tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of WerewolfExtEnv + +from metagpt.environment.werewolf_env.werewolf_ext_env import RoleState, WerewolfExtEnv +from metagpt.roles.role import Role + + +class Werewolf(Role): + profile: str = "Werewolf" + + +class Villager(Role): + profile: str = "Villager" + + +class Witch(Role): + profile: str = "Witch" + + +class Guard(Role): + profile: str = "Guard" + + +def test_werewolf_ext_env(): + players_state = { + "Player0": ("Werewolf", RoleState.ALIVE), + "Player1": ("Werewolf", RoleState.ALIVE), + "Player2": ("Villager", RoleState.ALIVE), + "Player3": ("Witch", RoleState.ALIVE), + "Player4": ("Guard", RoleState.ALIVE), + } + ext_env = WerewolfExtEnv(players_state=players_state, step_idx=4, special_role_players=["Player3", "Player4"]) + + assert len(ext_env.living_players) == 5 + assert len(ext_env.special_role_players) == 2 + assert len(ext_env.werewolf_players) == 2 + + curr_instr = ext_env.curr_step_instruction() + assert ext_env.step_idx == 5 + assert "Werewolves, please open your eyes" in curr_instr["content"] + + # current step_idx = 5 + ext_env.wolf_kill_someone(wolf=Role(name="Player10"), player_name="Player4") + ext_env.wolf_kill_someone(wolf=Werewolf(name="Player0"), player_name="Player4") + ext_env.wolf_kill_someone(wolf=Werewolf(name="Player1"), player_name="Player4") + assert ext_env.player_hunted == "Player4" + assert len(ext_env.living_players) == 5 # hunted but can be saved by witch + + for idx in range(13): + _ = ext_env.curr_step_instruction() + + # current step_idx = 18 + assert ext_env.step_idx == 18 + ext_env.vote_kill_someone(voteer=Werewolf(name="Player0"), player_name="Player2") + ext_env.vote_kill_someone(voteer=Werewolf(name="Player1"), player_name="Player3") + ext_env.vote_kill_someone(voteer=Villager(name="Player2"), player_name="Player3") + ext_env.vote_kill_someone(voteer=Witch(name="Player3"), player_name="Player4") + ext_env.vote_kill_someone(voteer=Guard(name="Player4"), player_name="Player2") + assert ext_env.player_current_dead == "Player2" + assert len(ext_env.living_players) == 4 + + player_names = ["Player0", "Player2"] + assert ext_env.get_players_state(player_names) == dict(zip(player_names, [RoleState.ALIVE, RoleState.KILLED])) diff --git a/tests/metagpt/learn/test_google_search.py b/tests/metagpt/learn/test_google_search.py index 7fda6436a..70a146878 100644 --- a/tests/metagpt/learn/test_google_search.py +++ b/tests/metagpt/learn/test_google_search.py @@ -16,6 +16,6 @@ async def test_google_search(search_engine_mocker): result = await google_search( seed.input, engine=SearchEngineType.SERPER_GOOGLE, - serper_api_key="mock-serper-key", + api_key="mock-serper-key", ) assert result != "" diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index a49d7e85b..5895c9eb8 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -6,6 +6,7 @@ from openai.types.chat import ( ) from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message_tool_call import Function +from PIL import Image from metagpt.const import TEST_DATA_PATH from metagpt.llm import LLM @@ -105,3 +106,15 @@ class TestOpenAI: code["language"] == "markdown" else: code["language"] == "python" + + +@pytest.mark.asyncio +async def test_gen_image(): + llm = LLM() + model = "dall-e-3" + prompt = 'a logo with word "MetaGPT"' + images: list[Image] = await llm.gen_image(model=model, prompt=prompt) + assert images[0].size == (1024, 1024) + + images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json") + assert images[0].size == (1024, 1024) diff --git a/tests/metagpt/roles/test_researcher.py b/tests/metagpt/roles/test_researcher.py index af81777ac..ba05e1296 100644 --- a/tests/metagpt/roles/test_researcher.py +++ b/tests/metagpt/roles/test_researcher.py @@ -36,7 +36,7 @@ async def test_researcher(mocker, search_engine_mocker, context): role = researcher.Researcher(context=context) for i in role.actions: if isinstance(i, CollectLinks): - i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO) + i.search_engine = SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO) await role.run(topic) assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report") diff --git a/tests/metagpt/test_incremental_dev.py b/tests/metagpt/test_incremental_dev.py index 964d4c757..c47397dd7 100644 --- a/tests/metagpt/test_incremental_dev.py +++ b/tests/metagpt/test_incremental_dev.py @@ -6,6 +6,7 @@ @File : test_incremental_dev.py """ import os +import shutil import subprocess import time @@ -45,6 +46,7 @@ PROJECT_NAMES = [ ] +@pytest.mark.skip def test_simple_add_calculator(): result = get_incremental_dev_result(IDEAS[0], PROJECT_NAMES[0]) log_and_check_result(result) @@ -115,10 +117,14 @@ def get_incremental_dev_result(idea, project_name, use_review=True): if not project_path.exists(): # If the project does not exist, extract the project file try: - # Use the tar command to extract the .zip file - subprocess.run(["tar", "-xf", f"{project_path}.zip", "-C", str(project_path.parent)], check=True) + if shutil.which("unzip"): + subprocess.run(["unzip", f"{project_path}.zip", "-d", str(project_path.parent)], check=True) + elif shutil.which("tar"): + subprocess.run(["tar", "-xf", f"{project_path}.zip", "-C", str(project_path.parent)], check=True) + logger.info(f"Extracted project {project_name} successfully.") + except FileNotFoundError as e: + raise FileNotFoundError(f"Neither 'unzip' nor 'tar' command found. Error: {e}") except subprocess.CalledProcessError as e: - # If the extraction fails, throw an exception raise Exception(f"Failed to extract project {project_name}. Error: {e}") check_or_create_base_tag(project_path) diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index 966f53a38..a1f03ef7b 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -12,6 +12,7 @@ from typing import Callable import pytest from metagpt.config2 import config +from metagpt.configs.search_config import SearchConfig from metagpt.logs import logger from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -49,27 +50,34 @@ async def test_search_engine( search_engine_mocker, ): # Prerequisites - search_engine_config = {} + search_engine_config = {"engine": search_engine_type, "run_func": run_func} if search_engine_type is SearchEngineType.SERPAPI_GOOGLE: assert config.search - search_engine_config["serpapi_api_key"] = "mock-serpapi-key" + search_engine_config["api_key"] = "mock-serpapi-key" elif search_engine_type is SearchEngineType.DIRECT_GOOGLE: assert config.search - search_engine_config["google_api_key"] = "mock-google-key" - search_engine_config["google_cse_id"] = "mock-google-cse" + search_engine_config["api_key"] = "mock-google-key" + search_engine_config["cse_id"] = "mock-google-cse" elif search_engine_type is SearchEngineType.SERPER_GOOGLE: assert config.search - search_engine_config["serper_api_key"] = "mock-serper-key" + search_engine_config["api_key"] = "mock-serper-key" - search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config) - rsp = await search_engine.run("metagpt", max_results, as_string) - logger.info(rsp) - if as_string: - assert isinstance(rsp, str) - else: - assert isinstance(rsp, list) - assert len(rsp) <= max_results + async def test(search_engine): + rsp = await search_engine.run("metagpt", max_results, as_string) + logger.info(rsp) + if as_string: + assert isinstance(rsp, str) + else: + assert isinstance(rsp, list) + assert len(rsp) <= max_results + + await test(SearchEngine(**search_engine_config)) + search_engine_config["api_type"] = search_engine_config.pop("engine") + if run_func: + await test(SearchEngine.from_search_func(run_func)) + search_engine_config["search_func"] = search_engine_config.pop("run_func") + await test(SearchEngine.from_search_config(SearchConfig(**search_engine_config))) if __name__ == "__main__": diff --git a/tests/metagpt/tools/test_web_browser_engine_playwright.py b/tests/metagpt/tools/test_web_browser_engine_playwright.py index 0e838a2f8..f35848cf4 100644 --- a/tests/metagpt/tools/test_web_browser_engine_playwright.py +++ b/tests/metagpt/tools/test_web_browser_engine_playwright.py @@ -3,7 +3,6 @@ import pytest -from metagpt.config2 import config from metagpt.tools import web_browser_engine_playwright from metagpt.utils.parse_html import WebPage @@ -19,26 +18,22 @@ from metagpt.utils.parse_html import WebPage ids=["chromium-normal", "firefox-normal", "webkit-normal"], ) async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd): - global_proxy = config.proxy - try: - if use_proxy: - server, proxy_url = await proxy() - config.proxy = proxy_url - browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, **kwagrs) - result = await browser.run(url) - assert isinstance(result, WebPage) - assert "MetaGPT" in result.inner_text + proxy_url = None + if use_proxy: + server, proxy_url = await proxy() + browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs) + result = await browser.run(url) + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text - if urls: - results = await browser.run(url, *urls) - assert isinstance(results, list) - assert len(results) == len(urls) + 1 - assert all(("MetaGPT" in i.inner_text) for i in results) - if use_proxy: - server.close() - assert "Proxy:" in capfd.readouterr().out - finally: - config.proxy = global_proxy + if urls: + results = await browser.run(url, *urls) + assert isinstance(results, list) + assert len(results) == len(urls) + 1 + assert all(("MetaGPT" in i.inner_text) for i in results) + if use_proxy: + server.close() + assert "Proxy:" in capfd.readouterr().out if __name__ == "__main__": diff --git a/tests/metagpt/tools/test_web_browser_engine_selenium.py b/tests/metagpt/tools/test_web_browser_engine_selenium.py index 1b1439d29..a88a5d0f4 100644 --- a/tests/metagpt/tools/test_web_browser_engine_selenium.py +++ b/tests/metagpt/tools/test_web_browser_engine_selenium.py @@ -4,7 +4,6 @@ import browsers import pytest -from metagpt.config2 import config from metagpt.tools import web_browser_engine_selenium from metagpt.utils.parse_html import WebPage @@ -40,27 +39,22 @@ from metagpt.utils.parse_html import WebPage async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd): # Prerequisites # firefox, chrome, Microsoft Edge + proxy_url = None + if use_proxy: + server, proxy_url = await proxy() + browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url) + result = await browser.run(url) + assert isinstance(result, WebPage) + assert "MetaGPT" in result.inner_text - global_proxy = config.proxy - try: - if use_proxy: - server, proxy_url = await proxy() - config.proxy = proxy_url - browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type) - result = await browser.run(url) - assert isinstance(result, WebPage) - assert "MetaGPT" in result.inner_text - - if urls: - results = await browser.run(url, *urls) - assert isinstance(results, list) - assert len(results) == len(urls) + 1 - assert all(("MetaGPT" in i.inner_text) for i in results) - if use_proxy: - server.close() - assert "Proxy:" in capfd.readouterr().out - finally: - config.proxy = global_proxy + if urls: + results = await browser.run(url, *urls) + assert isinstance(results, list) + assert len(results) == len(urls) + 1 + assert all(("MetaGPT" in i.inner_text) for i in results) + if use_proxy: + server.close() + assert "Proxy:" in capfd.readouterr().out if __name__ == "__main__": diff --git a/tests/mock/mock_aiohttp.py b/tests/mock/mock_aiohttp.py index 49dcdba79..a7c022a4b 100644 --- a/tests/mock/mock_aiohttp.py +++ b/tests/mock/mock_aiohttp.py @@ -13,7 +13,8 @@ class MockAioResponse: def __init__(self, session, method, url, **kwargs) -> None: fn = self.check_funcs.get((method, url)) - self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}" + _kwargs = {k: v for k, v in kwargs.items() if k != "proxy"} + self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(_kwargs, sort_keys=True)}" self.mng = self.response = None if self.key not in self.rsp_cache: self.mng = origin_request(session, method, url, **kwargs) diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index 8ee580b8a..50e75dabf 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -44,6 +44,7 @@ class MockLLM(OriginalLLM): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ): @@ -56,7 +57,7 @@ class MockLLM(OriginalLLM): message = [] if format_msgs: message.extend(format_msgs) - message.append(self._user_msg(msg)) + message.append(self._user_msg(msg, images=images)) rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp @@ -83,6 +84,7 @@ class MockLLM(OriginalLLM): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ) -> str: @@ -90,7 +92,7 @@ class MockLLM(OriginalLLM): if system_msgs: joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#" msg_key = joined_system_msg + msg_key - rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, timeout, stream) + rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream) return rsp async def aask_batch(self, msgs: list, timeout=3) -> str: