mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-01 03:46:23 +02:00
feat: merge main
This commit is contained in:
commit
8a92fa0f36
404 changed files with 20076 additions and 1163 deletions
|
|
@ -8,12 +8,14 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.configs.models_config import ModelsConfig
|
||||
from metagpt.context_mixin import ContextMixin
|
||||
from metagpt.provider.llm_provider_registry import create_llm_instance
|
||||
from metagpt.schema import (
|
||||
CodePlanAndChangeContext,
|
||||
CodeSummarizeContext,
|
||||
|
|
@ -35,6 +37,19 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
|
|||
prefix: str = "" # aask*时会加上prefix,作为system_message
|
||||
desc: str = "" # for skill manager
|
||||
node: ActionNode = Field(default=None, exclude=True)
|
||||
# The model name or API type of LLM of the `models` in the `config2.yaml`;
|
||||
# Using `None` to use the `llm` configuration in the `config2.yaml`.
|
||||
llm_name_or_type: Optional[str] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
@classmethod
|
||||
def _update_private_llm(cls, data: Any) -> Any:
|
||||
config = ModelsConfig.default().get(data.llm_name_or_type)
|
||||
if config:
|
||||
llm = create_llm_instance(config)
|
||||
llm.cost_manager = data.llm.cost_manager
|
||||
data.llm = llm
|
||||
return data
|
||||
|
||||
@property
|
||||
def repo(self) -> ProjectRepo:
|
||||
|
|
|
|||
|
|
@ -457,7 +457,6 @@ class ActionNode:
|
|||
self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=USE_CONFIG_TIMEOUT, exclude=None
|
||||
):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
|
||||
|
||||
if schema != "raw":
|
||||
mapping = self.get_mapping(mode, exclude=exclude)
|
||||
class_name = f"{self.key}_AN"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
@Author : alexanderwu
|
||||
@File : design_api_an.py
|
||||
"""
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.utils.mermaid import MMC1, MMC2
|
||||
|
|
@ -45,9 +45,10 @@ REFINED_FILE_LIST = ActionNode(
|
|||
example=["main.py", "game.py", "new_feature.py"],
|
||||
)
|
||||
|
||||
# optional,because low success reproduction of class diagram in non py project.
|
||||
DATA_STRUCTURES_AND_INTERFACES = ActionNode(
|
||||
key="Data structures and interfaces",
|
||||
expected_type=str,
|
||||
expected_type=Optional[str],
|
||||
instruction="Use mermaid classDiagram code syntax, including classes, method(__init__ etc.) and functions with type"
|
||||
" annotations, CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. "
|
||||
"The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.",
|
||||
|
|
@ -66,7 +67,7 @@ REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode(
|
|||
|
||||
PROGRAM_CALL_FLOW = ActionNode(
|
||||
key="Program call flow",
|
||||
expected_type=str,
|
||||
expected_type=Optional[str],
|
||||
instruction="Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE "
|
||||
"accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.",
|
||||
example=MMC2,
|
||||
|
|
|
|||
|
|
@ -486,7 +486,7 @@ class RebuildSequenceView(Action):
|
|||
Returns:
|
||||
List[str]: A list of participants extracted from the sequence diagram.
|
||||
"""
|
||||
pattern = r"participant ([a-zA-Z\.0-9_]+)"
|
||||
pattern = r"participant ([\w\.]+)"
|
||||
matches = re.findall(pattern, mermaid_sequence_diagram)
|
||||
matches = [re.sub(r"[\\/'\"]+", "", i) for i in matches]
|
||||
return matches
|
||||
|
|
|
|||
|
|
@ -161,6 +161,8 @@ class CollectLinks(Action):
|
|||
"""
|
||||
max_results = max(num_results * 2, 6)
|
||||
results = await self.search_engine.run(query, max_results=max_results, as_string=False)
|
||||
if len(results) == 0:
|
||||
return []
|
||||
_results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
|
||||
prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
|
||||
logger.debug(prompt)
|
||||
|
|
|
|||
|
|
@ -128,6 +128,9 @@ CODE_PLAN_AND_CHANGE_CONTEXT = """
|
|||
## User New Requirements
|
||||
{requirement}
|
||||
|
||||
## Issue
|
||||
{issue}
|
||||
|
||||
## PRD
|
||||
{prd}
|
||||
|
||||
|
|
@ -211,7 +214,8 @@ class WriteCodePlanAndChange(Action):
|
|||
design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
|
||||
task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
|
||||
context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
|
||||
requirement=self.i_context.requirement,
|
||||
requirement=f"```text\n{self.i_context.requirement}\n```",
|
||||
issue=f"```text\n{self.i_context.issue}\n```",
|
||||
prd=prd_doc.content,
|
||||
design=design_doc.content,
|
||||
task=task_doc.content,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
# Context
|
||||
{context}
|
||||
|
||||
-----
|
||||
|
||||
## Code to be Reviewed: {filename}
|
||||
```Code
|
||||
{code}
|
||||
|
|
@ -38,7 +40,8 @@ EXAMPLE_AND_INSTRUCTION = """
|
|||
{format_example}
|
||||
|
||||
|
||||
# Instruction: Based on the actual code situation, follow one of the "Format example". Return only 1 file under review.
|
||||
# Instruction: Based on the actual code, follow one of the "Code Review Format example".
|
||||
- Note the code filename should be `{filename}`. Return the only ONE file `{filename}` under review.
|
||||
|
||||
## Code Review: Ordered List. Based on the "Code to be Reviewed", provide key, clear, concise, and specific answer. If any answer is no, explain how to fix it step by step.
|
||||
1. Is the code implemented as per the requirements? If not, how to achieve it? Analyse it step by step.
|
||||
|
|
@ -56,7 +59,9 @@ LGTM/LBTM
|
|||
"""
|
||||
|
||||
FORMAT_EXAMPLE = """
|
||||
# Format example 1
|
||||
-----
|
||||
|
||||
# Code Review Format example 1
|
||||
## Code Review: {filename}
|
||||
1. No, we should fix the logic of class A due to ...
|
||||
2. ...
|
||||
|
|
@ -92,7 +97,9 @@ FORMAT_EXAMPLE = """
|
|||
## Code Review Result
|
||||
LBTM
|
||||
|
||||
# Format example 2
|
||||
-----
|
||||
|
||||
# Code Review Format example 2
|
||||
## Code Review: {filename}
|
||||
1. Yes.
|
||||
2. Yes.
|
||||
|
|
@ -106,10 +113,12 @@ pass
|
|||
|
||||
## Code Review Result
|
||||
LGTM
|
||||
|
||||
-----
|
||||
"""
|
||||
|
||||
REWRITE_CODE_TEMPLATE = """
|
||||
# Instruction: rewrite code based on the Code Review and Actions
|
||||
# Instruction: rewrite the `{filename}` based on the Code Review and Actions
|
||||
## Rewrite Code: CodeBlock. If it still has some bugs, rewrite {filename} with triple quotes. Do your utmost to optimize THIS SINGLE FILE. Return all completed codes and prohibit the return of unfinished codes.
|
||||
```Code
|
||||
## {filename}
|
||||
|
|
@ -169,6 +178,7 @@ class WriteCodeReview(Action):
|
|||
)
|
||||
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
|
||||
format_example=format_example,
|
||||
filename=self.i_context.code_doc.filename,
|
||||
)
|
||||
len1 = len(iterative_code) if iterative_code else 0
|
||||
len2 = len(self.i_context.code_doc.content) if self.i_context.code_doc.content else 0
|
||||
|
|
|
|||
|
|
@ -133,10 +133,10 @@ REQUIREMENT_ANALYSIS = ActionNode(
|
|||
REFINED_REQUIREMENT_ANALYSIS = ActionNode(
|
||||
key="Refined Requirement Analysis",
|
||||
expected_type=List[str],
|
||||
instruction="Review and refine the existing requirement analysis to align with the evolving needs of the project "
|
||||
instruction="Review and refine the existing requirement analysis into a string list to align with the evolving needs of the project "
|
||||
"due to incremental development. Ensure the analysis comprehensively covers the new features and enhancements "
|
||||
"required for the refined project scope.",
|
||||
example=["Require add/update/modify ..."],
|
||||
example=["Require add ...", "Require modify ..."],
|
||||
)
|
||||
|
||||
REQUIREMENT_POOL = ActionNode(
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from typing import Dict, Iterable, List, Literal, Optional
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig
|
||||
from metagpt.configs.file_parser_config import OmniParseConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
|
|
@ -47,6 +49,12 @@ class Config(CLIParams, YamlModel):
|
|||
# Key Parameters
|
||||
llm: LLMConfig
|
||||
|
||||
# RAG Embedding
|
||||
embedding: EmbeddingConfig = EmbeddingConfig()
|
||||
|
||||
# omniparse
|
||||
omniparse: OmniParseConfig = OmniParseConfig()
|
||||
|
||||
# Global Proxy. Will be used if llm.proxy is not set
|
||||
proxy: str = ""
|
||||
|
||||
|
|
@ -65,6 +73,7 @@ class Config(CLIParams, YamlModel):
|
|||
workspace: WorkspaceConfig = WorkspaceConfig()
|
||||
enable_longterm_memory: bool = False
|
||||
code_review_k_times: int = 2
|
||||
agentops_api_key: str = ""
|
||||
|
||||
# Will be removed in the future
|
||||
metagpt_tti_url: str = ""
|
||||
|
|
@ -75,6 +84,7 @@ class Config(CLIParams, YamlModel):
|
|||
iflytek_api_key: str = ""
|
||||
azure_tts_subscription_key: str = ""
|
||||
azure_tts_region: str = ""
|
||||
_extra: dict = dict() # extra config dict
|
||||
|
||||
@classmethod
|
||||
def from_home(cls, path):
|
||||
|
|
@ -127,6 +137,14 @@ class Config(CLIParams, YamlModel):
|
|||
self.reqa_file = reqa_file
|
||||
self.max_auto_summarize_code = max_auto_summarize_code
|
||||
|
||||
@property
|
||||
def extra(self):
|
||||
return self._extra
|
||||
|
||||
@extra.setter
|
||||
def extra(self, value: dict):
|
||||
self._extra = value
|
||||
|
||||
def get_openai_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
|
||||
if self.llm.api_type == LLMType.OPENAI:
|
||||
|
|
|
|||
54
metagpt/configs/embedding_config.py
Normal file
54
metagpt/configs/embedding_config.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class EmbeddingType(Enum):
|
||||
OPENAI = "openai"
|
||||
AZURE = "azure"
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class EmbeddingConfig(YamlModel):
|
||||
"""Config for Embedding.
|
||||
|
||||
Examples:
|
||||
---------
|
||||
api_type: "openai"
|
||||
api_key: "YOU_API_KEY"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
|
||||
api_type: "azure"
|
||||
api_key: "YOU_API_KEY"
|
||||
base_url: "YOU_BASE_URL"
|
||||
api_version: "YOU_API_VERSION"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
|
||||
api_type: "gemini"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
||||
api_type: "ollama"
|
||||
base_url: "YOU_BASE_URL"
|
||||
model: "YOU_MODEL"
|
||||
dimensions: "YOUR_MODEL_DIMENSIONS"
|
||||
"""
|
||||
|
||||
api_type: Optional[EmbeddingType] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None
|
||||
embed_batch_size: Optional[int] = None
|
||||
dimensions: Optional[int] = None # output dimension of embedding model
|
||||
|
||||
@field_validator("api_type", mode="before")
|
||||
@classmethod
|
||||
def check_api_type(cls, v):
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
6
metagpt/configs/file_parser_config.py
Normal file
6
metagpt/configs/file_parser_config.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class OmniParseConfig(YamlModel):
|
||||
api_key: str = ""
|
||||
base_url: str = ""
|
||||
|
|
@ -10,7 +10,7 @@ from typing import Optional
|
|||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.const import CONFIG_ROOT, LLM_API_TIMEOUT, METAGPT_ROOT
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
|
|
@ -31,6 +31,9 @@ class LLMType(Enum):
|
|||
MOONSHOT = "moonshot"
|
||||
MISTRAL = "mistral"
|
||||
YI = "yi" # lingyiwanwu
|
||||
OPENROUTER = "openrouter"
|
||||
BEDROCK = "bedrock"
|
||||
ARK = "ark"
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
|
@ -72,11 +75,15 @@ class LLMConfig(YamlModel):
|
|||
frequency_penalty: float = 0.0
|
||||
best_of: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
stream: bool = False
|
||||
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
|
||||
stream: bool = True
|
||||
# https://cookbook.openai.com/examples/using_logprobs
|
||||
logprobs: Optional[bool] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
timeout: int = 600
|
||||
|
||||
# For Amazon Bedrock
|
||||
region_name: str = None
|
||||
|
||||
# For Network
|
||||
proxy: Optional[str] = None
|
||||
|
||||
|
|
@ -87,7 +94,16 @@ class LLMConfig(YamlModel):
|
|||
@classmethod
|
||||
def check_llm_key(cls, v):
|
||||
if v in ["", None, "YOUR_API_KEY"]:
|
||||
raise ValueError("Please set your API key in config2.yaml")
|
||||
repo_config_path = METAGPT_ROOT / "config/config2.yaml"
|
||||
root_config_path = CONFIG_ROOT / "config2.yaml"
|
||||
if root_config_path.exists():
|
||||
raise ValueError(
|
||||
f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path}, \nthe former will overwrite the latter. This may cause unexpected result.\n"
|
||||
)
|
||||
elif repo_config_path.exists():
|
||||
raise ValueError(f"Please set your API key in {repo_config_path}")
|
||||
else:
|
||||
raise ValueError("Please set your API key in config2.yaml")
|
||||
return v
|
||||
|
||||
@field_validator("timeout")
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.utils.yaml_model import YamlModel
|
|||
class MermaidConfig(YamlModel):
|
||||
"""Config for Mermaid"""
|
||||
|
||||
engine: Literal["nodejs", "ink", "playwright", "pyppeteer"] = "nodejs"
|
||||
engine: Literal["nodejs", "ink", "playwright", "pyppeteer", "none"] = "nodejs"
|
||||
path: str = "mmdc" # mmdc
|
||||
puppeteer_config: str = ""
|
||||
pyppeteer_path: str = "/usr/bin/google-chrome-stable"
|
||||
|
|
|
|||
112
metagpt/configs/models_config.py
Normal file
112
metagpt/configs/models_config.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
models_config.py
|
||||
|
||||
This module defines the ModelsConfig class for handling configuration of LLM models.
|
||||
|
||||
Attributes:
|
||||
CONFIG_ROOT (Path): Root path for configuration files.
|
||||
METAGPT_ROOT (Path): Root path for MetaGPT files.
|
||||
|
||||
Classes:
|
||||
ModelsConfig (YamlModel): Configuration class for LLM models.
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from metagpt.config2 import merge_dict
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class ModelsConfig(YamlModel):
|
||||
"""
|
||||
Configuration class for `models` in `config2.yaml`.
|
||||
|
||||
Attributes:
|
||||
models (Dict[str, LLMConfig]): Dictionary mapping model names or types to LLMConfig objects.
|
||||
|
||||
Methods:
|
||||
update_llm_model(cls, value): Validates and updates LLM model configurations.
|
||||
from_home(cls, path): Loads configuration from ~/.metagpt/config2.yaml.
|
||||
default(cls): Loads default configuration from predefined paths.
|
||||
get(self, name_or_type: str) -> Optional[LLMConfig]: Retrieves LLMConfig by name or API type.
|
||||
"""
|
||||
|
||||
models: Dict[str, LLMConfig] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("models", mode="before")
|
||||
@classmethod
|
||||
def update_llm_model(cls, value):
|
||||
"""
|
||||
Validates and updates LLM model configurations.
|
||||
|
||||
Args:
|
||||
value (Dict[str, Union[LLMConfig, dict]]): Dictionary of LLM configurations.
|
||||
|
||||
Returns:
|
||||
Dict[str, Union[LLMConfig, dict]]: Updated dictionary of LLM configurations.
|
||||
"""
|
||||
for key, config in value.items():
|
||||
if isinstance(config, LLMConfig):
|
||||
config.model = config.model or key
|
||||
elif isinstance(config, dict):
|
||||
config["model"] = config.get("model") or key
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def from_home(cls, path):
|
||||
"""
|
||||
Loads configuration from ~/.metagpt/config2.yaml.
|
||||
|
||||
Args:
|
||||
path (str): Relative path to configuration file.
|
||||
|
||||
Returns:
|
||||
Optional[ModelsConfig]: Loaded ModelsConfig object or None if file doesn't exist.
|
||||
"""
|
||||
pathname = CONFIG_ROOT / path
|
||||
if not pathname.exists():
|
||||
return None
|
||||
return ModelsConfig.from_yaml_file(pathname)
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
"""
|
||||
Loads default configuration from predefined paths.
|
||||
|
||||
Returns:
|
||||
ModelsConfig: Default ModelsConfig object.
|
||||
"""
|
||||
default_config_paths: List[Path] = [
|
||||
METAGPT_ROOT / "config/config2.yaml",
|
||||
CONFIG_ROOT / "config2.yaml",
|
||||
]
|
||||
|
||||
dicts = [ModelsConfig.read_yaml(path) for path in default_config_paths]
|
||||
final = merge_dict(dicts)
|
||||
return ModelsConfig(**final)
|
||||
|
||||
def get(self, name_or_type: str) -> Optional[LLMConfig]:
|
||||
"""
|
||||
Retrieves LLMConfig object by name or API type.
|
||||
|
||||
Args:
|
||||
name_or_type (str): Name or API type of the LLM model.
|
||||
|
||||
Returns:
|
||||
Optional[LLMConfig]: LLMConfig object if found, otherwise None.
|
||||
"""
|
||||
if not name_or_type:
|
||||
return None
|
||||
model = self.models.get(name_or_type)
|
||||
if model:
|
||||
return model
|
||||
for m in self.models.values():
|
||||
if m.api_type == name_or_type:
|
||||
return m
|
||||
return None
|
||||
|
|
@ -1,14 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/1 11:59
|
||||
@Author : alexanderwu
|
||||
@File : const.py
|
||||
@Modified By: mashenquan, 2023-11-1. According to Section 2.2.1 and 2.2.2 of RFC 116, added key definitions for
|
||||
common properties in the Message.
|
||||
@Modified By: mashenquan, 2023-11-27. Defines file repository paths according to Section 2.2.3.4 of RFC 135.
|
||||
@Modified By: mashenquan, 2023/12/5. Add directories for code summarization..
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -51,6 +43,7 @@ DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
|
|||
EXAMPLE_PATH = METAGPT_ROOT / "examples"
|
||||
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
|
||||
DATA_PATH = METAGPT_ROOT / "data"
|
||||
EXAMPLE_BENCHMARK_PATH = EXAMPLE_PATH / "data/rag_bm"
|
||||
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
|
||||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
|
@ -78,12 +78,6 @@ class Context(BaseModel):
|
|||
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
# """Use a LLM instance"""
|
||||
# self._llm_config = self.config.get_llm_config(name, provider)
|
||||
# self._llm = None
|
||||
# return self._llm
|
||||
|
||||
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
|
||||
"""Return a CostManager instance"""
|
||||
if llm_config.api_type == LLMType.FIREWORKS:
|
||||
|
|
@ -108,3 +102,38 @@ class Context(BaseModel):
|
|||
if llm.cost_manager is None:
|
||||
llm.cost_manager = self._select_costmanager(llm_config)
|
||||
return llm
|
||||
|
||||
def serialize(self) -> Dict[str, Any]:
|
||||
"""Serialize the object's attributes into a dictionary.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing serialized data.
|
||||
"""
|
||||
return {
|
||||
"workdir": str(self.repo.workdir) if self.repo else "",
|
||||
"kwargs": {k: v for k, v in self.kwargs.__dict__.items()},
|
||||
"cost_manager": self.cost_manager.model_dump_json(),
|
||||
}
|
||||
|
||||
def deserialize(self, serialized_data: Dict[str, Any]):
|
||||
"""Deserialize the given serialized data and update the object's attributes accordingly.
|
||||
|
||||
Args:
|
||||
serialized_data (Dict[str, Any]): A dictionary containing serialized data.
|
||||
"""
|
||||
if not serialized_data:
|
||||
return
|
||||
workdir = serialized_data.get("workdir")
|
||||
if workdir:
|
||||
self.git_repo = GitRepository(local_path=workdir, auto_init=True)
|
||||
self.repo = ProjectRepo(self.git_repo)
|
||||
src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
if src_workspace.exists():
|
||||
self.src_workspace = src_workspace
|
||||
kwargs = serialized_data.get("kwargs")
|
||||
if kwargs:
|
||||
for k, v in kwargs.items():
|
||||
self.kwargs.set(k, v)
|
||||
cost_manager = serialized_data.get("cost_manager")
|
||||
if cost_manager:
|
||||
self.cost_manager.model_validate_json(cost_manager)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ ## Usage
|
|||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
|
||||
# get screenshot from ExtEnv
|
||||
screenshot_path: Path = env.observe(
|
||||
screenshot_path: Path = await env.observe(
|
||||
EnvAPIAbstract(
|
||||
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,10 +3,11 @@
|
|||
# @Desc :
|
||||
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.android_env.android_env import AndroidEnv
|
||||
from metagpt.environment.werewolf_env.werewolf_env import WerewolfEnv
|
||||
from metagpt.environment.stanford_town_env.stanford_town_env import StanfordTownEnv
|
||||
from metagpt.environment.software_env.software_env import SoftwareEnv
|
||||
|
||||
# from metagpt.environment.android.android_env import AndroidEnv
|
||||
from metagpt.environment.werewolf.werewolf_env import WerewolfEnv
|
||||
from metagpt.environment.stanford_town.stanford_town_env import StanfordTownEnv
|
||||
from metagpt.environment.software.software_env import SoftwareEnv
|
||||
|
||||
|
||||
__all__ = ["AndroidEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"]
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.android.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.base_env import Environment
|
||||
|
||||
|
||||
class AndroidEnv(Environment, AndroidExtEnv):
|
||||
class AndroidEnv(AndroidExtEnv, Environment):
|
||||
"""in order to use actual `reset`&`observe`, inherited order: AndroidExtEnv, Environment"""
|
||||
|
||||
rows: int = Field(default=0, description="rows of a grid on the screenshot")
|
||||
cols: int = Field(default=0, description="cols of a grid on the screenshot")
|
||||
375
metagpt/environment/android/android_ext_env.py
Normal file
375
metagpt/environment/android/android_ext_env.py
Normal file
|
|
@ -0,0 +1,375 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : The Android external environment to integrate with Android apps
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import clip
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
EnvObsValType,
|
||||
)
|
||||
from metagpt.environment.android.text_icon_localization import (
|
||||
clip_for_icon,
|
||||
crop_for_clip,
|
||||
det,
|
||||
load_model,
|
||||
ocr,
|
||||
)
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import download_model
|
||||
|
||||
|
||||
def load_cv_model(device: str = "cpu") -> any:
|
||||
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
|
||||
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
|
||||
file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth"
|
||||
target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights")
|
||||
file_path = download_model(file_url, target_folder)
|
||||
groundingdino_model = load_model(file_path, device=device).eval()
|
||||
return ocr_detection, ocr_recognition, groundingdino_model
|
||||
|
||||
|
||||
class AndroidExtEnv(ExtEnv):
|
||||
device_id: Optional[str] = Field(default=None)
|
||||
screenshot_dir: Optional[Path] = Field(default=None)
|
||||
xml_dir: Optional[Path] = Field(default=None)
|
||||
width: int = Field(default=720, description="device screen width")
|
||||
height: int = Field(default=1080, description="device screen height")
|
||||
ocr_detection: any = Field(default=None, description="ocr detection model")
|
||||
ocr_recognition: any = Field(default=None, description="ocr recognition model")
|
||||
groundingdino_model: any = Field(default=None, description="clip groundingdino model")
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
device_id = data.get("device_id")
|
||||
self.ocr_detection, self.ocr_recognition, self.groundingdino_model = load_cv_model()
|
||||
if device_id:
|
||||
devices = self.list_devices()
|
||||
if device_id not in devices:
|
||||
raise RuntimeError(f"device-id: {device_id} not found")
|
||||
(width, height) = self.device_shape
|
||||
self.width = data.get("width", width)
|
||||
self.height = data.get("height", height)
|
||||
self.create_device_path(self.screenshot_dir)
|
||||
self.create_device_path(self.xml_dir)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
obs = self._get_obs()
|
||||
|
||||
return obs, {}
|
||||
|
||||
def _get_obs(self) -> dict[str, EnvObsValType]:
|
||||
pass
|
||||
|
||||
def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
|
||||
obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE
|
||||
if obs_type == EnvObsType.NONE:
|
||||
pass
|
||||
elif obs_type == EnvObsType.GET_SCREENSHOT:
|
||||
obs = self.get_screenshot(ss_name=obs_params.ss_name, local_save_dir=obs_params.local_save_dir)
|
||||
elif obs_type == EnvObsType.GET_XML:
|
||||
obs = self.get_xml(xml_name=obs_params.xml_name, local_save_dir=obs_params.local_save_dir)
|
||||
return obs
|
||||
|
||||
def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
res = self._execute_env_action(action)
|
||||
|
||||
obs = {}
|
||||
|
||||
ret = (obs, 1.0, False, False, {"res": res})
|
||||
return ret
|
||||
|
||||
def _execute_env_action(self, action: EnvAction):
|
||||
action_type = action.action_type
|
||||
res = None
|
||||
if action_type == EnvActionType.NONE:
|
||||
pass
|
||||
elif action_type == EnvActionType.SYSTEM_BACK:
|
||||
res = self.system_back()
|
||||
elif action_type == EnvActionType.SYSTEM_TAP:
|
||||
res = self.system_tap(x=action.coord[0], y=action.coord[1])
|
||||
elif action_type == EnvActionType.USER_INPUT:
|
||||
res = self.user_input(input_txt=action.input_txt)
|
||||
elif action_type == EnvActionType.USER_LONGPRESS:
|
||||
res = self.user_longpress(x=action.coord[0], y=action.coord[1])
|
||||
elif action_type == EnvActionType.USER_SWIPE:
|
||||
res = self.user_swipe(x=action.coord[0], y=action.coord[1], orient=action.orient, dist=action.dist)
|
||||
elif action_type == EnvActionType.USER_SWIPE_TO:
|
||||
res = self.user_swipe_to(start=action.coord, end=action.tgt_coord)
|
||||
return res
|
||||
|
||||
@property
|
||||
def adb_prefix_si(self):
|
||||
"""adb cmd prefix with `device_id` and `shell input`"""
|
||||
return f"adb -s {self.device_id} shell input "
|
||||
|
||||
@property
|
||||
def adb_prefix_shell(self):
|
||||
"""adb cmd prefix with `device_id` and `shell`"""
|
||||
return f"adb -s {self.device_id} shell "
|
||||
|
||||
@property
|
||||
def adb_prefix(self):
|
||||
"""adb cmd prefix with `device_id`"""
|
||||
return f"adb -s {self.device_id} "
|
||||
|
||||
def execute_adb_with_cmd(self, adb_cmd: str) -> str:
|
||||
adb_cmd = adb_cmd.replace("\\", "/")
|
||||
res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
exec_res = ADB_EXEC_FAIL
|
||||
if not res.returncode:
|
||||
exec_res = res.stdout.strip()
|
||||
return exec_res
|
||||
|
||||
def create_device_path(self, folder_path: Path):
|
||||
adb_cmd = f"{self.adb_prefix_shell} mkdir {folder_path} -p"
|
||||
res = self.execute_adb_with_cmd(adb_cmd)
|
||||
if res == ADB_EXEC_FAIL:
|
||||
raise RuntimeError(f"create device path: {folder_path} failed")
|
||||
|
||||
@property
|
||||
def device_shape(self) -> tuple[int, int]:
|
||||
adb_cmd = f"{self.adb_prefix_shell} wm size"
|
||||
shape = (0, 0)
|
||||
shape_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
if shape_res != ADB_EXEC_FAIL:
|
||||
shape = tuple(map(int, shape_res.split(": ")[1].split("x")))
|
||||
return shape
|
||||
|
||||
def list_devices(self):
|
||||
adb_cmd = "adb devices"
|
||||
res = self.execute_adb_with_cmd(adb_cmd)
|
||||
devices = []
|
||||
if res != ADB_EXEC_FAIL:
|
||||
devices = res.split("\n")[1:]
|
||||
devices = [device.split()[0] for device in devices]
|
||||
return devices
|
||||
|
||||
@mark_as_readable
|
||||
def get_screenshot(self, ss_name: str, local_save_dir: Path) -> Path:
|
||||
"""
|
||||
ss_name: screenshot file name
|
||||
local_save_dir: local dir to store image from virtual machine
|
||||
"""
|
||||
assert self.screenshot_dir
|
||||
ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png")
|
||||
ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
time.sleep(0.1)
|
||||
res = ADB_EXEC_FAIL
|
||||
if ss_res != ADB_EXEC_FAIL:
|
||||
ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png")
|
||||
pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}"
|
||||
pull_res = self.execute_adb_with_cmd(pull_cmd)
|
||||
time.sleep(0.1)
|
||||
if pull_res != ADB_EXEC_FAIL:
|
||||
res = ss_local_path
|
||||
else:
|
||||
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/{ss_name}.png"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
time.sleep(0.1)
|
||||
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/{ss_name}.png"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
time.sleep(0.1)
|
||||
ss_cmd = f"{self.adb_prefix} pull /sdcard/{ss_name}.png {self.screenshot_dir}"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
image_path = Path(f"{self.screenshot_dir}/{ss_name}.png")
|
||||
res = image_path
|
||||
return Path(res)
|
||||
|
||||
@mark_as_readable
|
||||
def get_xml(self, xml_name: str, local_save_dir: Path) -> Path:
|
||||
xml_remote_path = Path(self.xml_dir).joinpath(f"{xml_name}.xml")
|
||||
dump_cmd = f"{self.adb_prefix_shell} uiautomator dump {xml_remote_path}"
|
||||
xml_res = self.execute_adb_with_cmd(dump_cmd)
|
||||
|
||||
res = ADB_EXEC_FAIL
|
||||
if xml_res != ADB_EXEC_FAIL:
|
||||
xml_local_path = Path(local_save_dir).joinpath(f"{xml_name}.xml")
|
||||
pull_cmd = f"{self.adb_prefix} pull {xml_remote_path} {xml_local_path}"
|
||||
pull_res = self.execute_adb_with_cmd(pull_cmd)
|
||||
if pull_res != ADB_EXEC_FAIL:
|
||||
res = xml_local_path
|
||||
return Path(res)
|
||||
|
||||
@mark_as_writeable
|
||||
def system_back(self) -> str:
|
||||
adb_cmd = f"{self.adb_prefix_si} keyevent KEYCODE_BACK"
|
||||
back_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return back_res
|
||||
|
||||
@mark_as_writeable
|
||||
def system_tap(self, x: int, y: int) -> str:
|
||||
adb_cmd = f"{self.adb_prefix_si} tap {x} {y}"
|
||||
tap_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return tap_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_input(self, input_txt: str) -> str:
|
||||
input_txt = input_txt.replace(" ", "%s").replace("'", "")
|
||||
adb_cmd = f"{self.adb_prefix_si} text {input_txt}"
|
||||
input_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return input_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_longpress(self, x: int, y: int, duration: int = 500) -> str:
|
||||
adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x} {y} {duration}"
|
||||
press_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return press_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_swipe(self, x: int, y: int, orient: str = "up", dist: str = "medium", if_quick: bool = False) -> str:
|
||||
dist_unit = int(self.width / 10)
|
||||
if dist == "long":
|
||||
dist_unit *= 3
|
||||
elif dist == "medium":
|
||||
dist_unit *= 2
|
||||
|
||||
if orient == "up":
|
||||
offset = 0, -2 * dist_unit
|
||||
elif orient == "down":
|
||||
offset = 0, 2 * dist_unit
|
||||
elif orient == "left":
|
||||
offset = -1 * dist_unit, 0
|
||||
elif orient == "right":
|
||||
offset = dist_unit, 0
|
||||
else:
|
||||
return ADB_EXEC_FAIL
|
||||
|
||||
duration = 100 if if_quick else 400
|
||||
adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x + offset[0]} {y + offset[1]} {duration}"
|
||||
swipe_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return swipe_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400) -> str:
|
||||
adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}"
|
||||
swipe_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return swipe_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_exit(self) -> str:
|
||||
adb_cmd = f"{self.adb_prefix_shell} am start -a android.intent.action.MAIN -c android.intent.category.HOME"
|
||||
exit_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return exit_res
|
||||
|
||||
def _ocr_text(self, text: str) -> list:
|
||||
image = self.get_screenshot("screenshot", self.screenshot_dir)
|
||||
iw, ih = Image.open(image).size
|
||||
x, y = self.device_shape
|
||||
if iw > ih:
|
||||
x, y = y, x
|
||||
iw, ih = ih, iw
|
||||
in_coordinate, out_coordinate = ocr(image, text, self.ocr_detection, self.ocr_recognition, iw, ih)
|
||||
output_list = [in_coordinate, out_coordinate, x, y, iw, ih, image]
|
||||
return output_list
|
||||
|
||||
@mark_as_writeable
|
||||
def user_open_app(self, app_name: str) -> str:
|
||||
ocr_result = self._ocr_text(app_name)
|
||||
in_coordinate, _, x, y, iw, ih = (
|
||||
ocr_result[0],
|
||||
ocr_result[1],
|
||||
ocr_result[2],
|
||||
ocr_result[3],
|
||||
ocr_result[4],
|
||||
ocr_result[5],
|
||||
)
|
||||
if len(in_coordinate) == 0:
|
||||
logger.info(f"No App named {app_name}.")
|
||||
return "no app here"
|
||||
else:
|
||||
tap_coordinate = [
|
||||
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
|
||||
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
|
||||
]
|
||||
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
|
||||
return self.system_tap(tap_coordinate[0] * x, (tap_coordinate[1] - round(50 / y, 2)) * y)
|
||||
|
||||
@mark_as_writeable
|
||||
def user_click_text(self, text: str) -> str:
|
||||
ocr_result = self._ocr_text(text)
|
||||
in_coordinate, out_coordinate, x, y, iw, ih, _ = (
|
||||
ocr_result[0],
|
||||
ocr_result[1],
|
||||
ocr_result[2],
|
||||
ocr_result[3],
|
||||
ocr_result[4],
|
||||
ocr_result[5],
|
||||
ocr_result[6],
|
||||
)
|
||||
if len(out_coordinate) == 0:
|
||||
logger.info(
|
||||
f'Failed to execute action click text ({text}). The text "{text}" is not detected in the screenshot.'
|
||||
)
|
||||
elif len(out_coordinate) == 1:
|
||||
tap_coordinate = [
|
||||
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
|
||||
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
|
||||
]
|
||||
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
|
||||
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
|
||||
else:
|
||||
logger.info(
|
||||
f'Failed to execute action click text ({text}). There are too many text "{text}" in the screenshot.'
|
||||
)
|
||||
|
||||
@mark_as_writeable
|
||||
def user_stop(self):
|
||||
logger.info("Successful execution of tasks")
|
||||
|
||||
@mark_as_writeable
|
||||
def user_click_icon(self, icon_shape_color: str) -> str:
|
||||
screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir)
|
||||
image = screenshot_path
|
||||
iw, ih = Image.open(image).size
|
||||
x, y = self.device_shape
|
||||
if iw > ih:
|
||||
x, y = y, x
|
||||
iw, ih = ih, iw
|
||||
in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # 检测icon
|
||||
if len(out_coordinate) == 1: # only one icon
|
||||
tap_coordinate = [
|
||||
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
|
||||
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
|
||||
]
|
||||
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
|
||||
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
|
||||
|
||||
else:
|
||||
temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp")
|
||||
temp_file.mkdir(parents=True, exist_ok=True)
|
||||
hash_table, clip_filter = [], []
|
||||
for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)):
|
||||
if crop_for_clip(image, td, i, temp_file):
|
||||
hash_table.append(td)
|
||||
crop_image = f"{i}.png"
|
||||
clip_filter.append(temp_file.joinpath(crop_image))
|
||||
clip_model, clip_preprocess = clip.load("ViT-B/32") # FIXME: device=device
|
||||
clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color)
|
||||
final_box = hash_table[clip_filter]
|
||||
tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2]
|
||||
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
|
||||
print(tap_coordinate[0] * x, tap_coordinate[1] * y)
|
||||
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
|
||||
92
metagpt/environment/android/env_space.py
Normal file
92
metagpt/environment/android/env_space.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from gymnasium import spaces
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.environment.base_env_space import (
|
||||
BaseEnvAction,
|
||||
BaseEnvActionType,
|
||||
BaseEnvObsParams,
|
||||
BaseEnvObsType,
|
||||
)
|
||||
|
||||
|
||||
class EnvActionType(BaseEnvActionType):
|
||||
NONE = 0 # no action to run, just get observation
|
||||
|
||||
SYSTEM_BACK = 1
|
||||
SYSTEM_TAP = 2
|
||||
USER_INPUT = 3
|
||||
USER_LONGPRESS = 4
|
||||
USER_SWIPE = 5
|
||||
USER_SWIPE_TO = 6
|
||||
|
||||
|
||||
class EnvAction(BaseEnvAction):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_type: int = Field(default=EnvActionType.NONE, description="action type")
|
||||
coord: npt.NDArray[np.int64] = Field(
|
||||
default_factory=lambda: np.zeros(2, dtype=np.int64), description="operation coordinate"
|
||||
)
|
||||
tgt_coord: npt.NDArray[np.int64] = Field(
|
||||
default_factory=lambda: np.zeros(2, dtype=np.int64), description="target operation coordinate"
|
||||
)
|
||||
input_txt: str = Field(default="", description="user input text")
|
||||
orient: str = Field(default="up", description="swipe orient")
|
||||
dist: str = Field(default="medium", description="swipe dist")
|
||||
|
||||
@field_validator("coord", "tgt_coord", mode="before")
|
||||
@classmethod
|
||||
def check_coord(cls, coord) -> npt.NDArray[np.int64]:
|
||||
if not isinstance(coord, np.ndarray):
|
||||
return np.array(coord)
|
||||
|
||||
|
||||
class EnvObsType(BaseEnvObsType):
|
||||
NONE = 0 # get whole observation from env
|
||||
|
||||
GET_SCREENSHOT = 1
|
||||
GET_XML = 2
|
||||
|
||||
|
||||
class EnvObsParams(BaseEnvObsParams):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
|
||||
ss_name: str = Field(default="", description="screenshot file name")
|
||||
xml_name: str = Field(default="", description="xml file name")
|
||||
local_save_dir: Union[str, Path] = Field(default="", description="local dir to save file")
|
||||
|
||||
|
||||
EnvObsValType = str
|
||||
|
||||
|
||||
def get_observation_space() -> spaces.Dict:
|
||||
space = spaces.Dict({"screenshot": spaces.Text(256), "xml": spaces.Text(256)})
|
||||
return space
|
||||
|
||||
|
||||
def get_action_space(device_shape: tuple[int, int]) -> spaces.Dict:
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"action_type": spaces.Discrete(len(EnvActionType)),
|
||||
"coord": spaces.Box(
|
||||
np.array([0, 0], dtype=np.int64), np.array([device_shape[0], device_shape[1]], dtype=np.int64)
|
||||
),
|
||||
"tgt_coord": spaces.Box(
|
||||
np.array([0, 0], dtype=np.int64), np.array([device_shape[0], device_shape[1]], dtype=np.int64)
|
||||
),
|
||||
"input_txt": spaces.Text(256),
|
||||
"orient": spaces.Text(16),
|
||||
"dist": spaces.Text(16),
|
||||
}
|
||||
)
|
||||
return space
|
||||
43
metagpt/environment/android/grounding_dino_config.py
Normal file
43
metagpt/environment/android/grounding_dino_config.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
batch_size = 1
|
||||
modelname = "groundingdino"
|
||||
backbone = "swin_T_224_1k"
|
||||
position_embedding = "sine"
|
||||
pe_temperatureH = 20
|
||||
pe_temperatureW = 20
|
||||
return_interm_indices = [1, 2, 3]
|
||||
backbone_freeze_keywords = None
|
||||
enc_layers = 6
|
||||
dec_layers = 6
|
||||
pre_norm = False
|
||||
dim_feedforward = 2048
|
||||
hidden_dim = 256
|
||||
dropout = 0.0
|
||||
nheads = 8
|
||||
num_queries = 900
|
||||
query_dim = 4
|
||||
num_patterns = 0
|
||||
num_feature_levels = 4
|
||||
enc_n_points = 4
|
||||
dec_n_points = 4
|
||||
two_stage_type = "standard"
|
||||
two_stage_bbox_embed_share = False
|
||||
two_stage_class_embed_share = False
|
||||
transformer_activation = "relu"
|
||||
dec_pred_bbox_embed_share = True
|
||||
dn_box_noise_scale = 1.0
|
||||
dn_label_noise_ratio = 0.5
|
||||
dn_label_coef = 1.0
|
||||
dn_bbox_coef = 1.0
|
||||
embed_init_tgt = True
|
||||
dn_labelbook_size = 2000
|
||||
max_text_len = 256
|
||||
text_encoder_type = "bert-base-uncased"
|
||||
use_text_enhancer = True
|
||||
use_fusion_layer = True
|
||||
use_checkpoint = True
|
||||
use_transformer_ckpt = True
|
||||
use_text_cross_attention = True
|
||||
text_dropout = 0.0
|
||||
fusion_dropout = 0.0
|
||||
fusion_droppath = 0.1
|
||||
sub_sentence_present = True
|
||||
368
metagpt/environment/android/text_icon_localization.py
Normal file
368
metagpt/environment/android/text_icon_localization.py
Normal file
|
|
@ -0,0 +1,368 @@
|
|||
# The code in this file was modified by MobileAgent
|
||||
# https://github.com/X-PLUG/MobileAgent.git
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import clip
|
||||
import cv2
|
||||
import groundingdino.datasets.transforms as T
|
||||
import numpy as np
|
||||
import torch
|
||||
from groundingdino.models import build_model
|
||||
from groundingdino.util.slconfig import SLConfig
|
||||
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
||||
from PIL import Image
|
||||
|
||||
################################## text_localization using ocr #######################
|
||||
|
||||
|
||||
def crop_image(img: any, position: any) -> any:
|
||||
def distance(x1, y1, x2, y2):
|
||||
return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
|
||||
|
||||
position = position.tolist()
|
||||
for i in range(4):
|
||||
for j in range(i + 1, 4):
|
||||
if position[i][0] > position[j][0]:
|
||||
tmp = position[j]
|
||||
position[j] = position[i]
|
||||
position[i] = tmp
|
||||
if position[0][1] > position[1][1]:
|
||||
tmp = position[0]
|
||||
position[0] = position[1]
|
||||
position[1] = tmp
|
||||
|
||||
if position[2][1] > position[3][1]:
|
||||
tmp = position[2]
|
||||
position[2] = position[3]
|
||||
position[3] = tmp
|
||||
|
||||
x1, y1 = position[0][0], position[0][1]
|
||||
x2, y2 = position[2][0], position[2][1]
|
||||
x3, y3 = position[3][0], position[3][1]
|
||||
x4, y4 = position[1][0], position[1][1]
|
||||
|
||||
corners = np.zeros((4, 2), np.float32)
|
||||
corners[0] = [x1, y1]
|
||||
corners[1] = [x2, y2]
|
||||
corners[2] = [x4, y4]
|
||||
corners[3] = [x3, y3]
|
||||
|
||||
img_width = distance((x1 + x4) / 2, (y1 + y4) / 2, (x2 + x3) / 2, (y2 + y3) / 2)
|
||||
img_height = distance((x1 + x2) / 2, (y1 + y2) / 2, (x4 + x3) / 2, (y4 + y3) / 2)
|
||||
|
||||
corners_trans = np.zeros((4, 2), np.float32)
|
||||
corners_trans[0] = [0, 0]
|
||||
corners_trans[1] = [img_width - 1, 0]
|
||||
corners_trans[2] = [0, img_height - 1]
|
||||
corners_trans[3] = [img_width - 1, img_height - 1]
|
||||
|
||||
transform = cv2.getPerspectiveTransform(corners, corners_trans)
|
||||
dst = cv2.warpPerspective(img, transform, (int(img_width), int(img_height)))
|
||||
return dst
|
||||
|
||||
|
||||
def calculate_size(box: any) -> any:
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
|
||||
def order_point(cooperation: any) -> any:
|
||||
arr = np.array(cooperation).reshape([4, 2])
|
||||
sum_ = np.sum(arr, 0)
|
||||
centroid = sum_ / arr.shape[0]
|
||||
theta = np.arctan2(arr[:, 1] - centroid[1], arr[:, 0] - centroid[0])
|
||||
sort_points = arr[np.argsort(theta)]
|
||||
sort_points = sort_points.reshape([4, -1])
|
||||
if sort_points[0][0] > centroid[0]:
|
||||
sort_points = np.concatenate([sort_points[3:], sort_points[:3]])
|
||||
sort_points = sort_points.reshape([4, 2]).astype("float32")
|
||||
return sort_points
|
||||
|
||||
|
||||
def longest_common_substring_length(str1: str, str2: str) -> int:
|
||||
m = len(str1)
|
||||
n = len(str2)
|
||||
dp = [[0] * (n + 1) for _ in range(m + 1)]
|
||||
for i in range(1, m + 1):
|
||||
for j in range(1, n + 1):
|
||||
if str1[i - 1] == str2[j - 1]:
|
||||
dp[i][j] = dp[i - 1][j - 1] + 1
|
||||
else:
|
||||
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
|
||||
|
||||
return dp[m][n]
|
||||
|
||||
|
||||
def ocr(image_path: Path, prompt: str, ocr_detection: any, ocr_recognition: any, x: int, y: int) -> any:
|
||||
text_data = []
|
||||
coordinate = []
|
||||
image = Image.open(image_path)
|
||||
iw, ih = image.size
|
||||
|
||||
image_full = cv2.imread(str(image_path))
|
||||
det_result = ocr_detection(image_full)
|
||||
det_result = det_result["polygons"]
|
||||
for i in range(det_result.shape[0]):
|
||||
pts = order_point(det_result[i])
|
||||
image_crop = crop_image(image_full, pts)
|
||||
result = ocr_recognition(image_crop)["text"][0]
|
||||
|
||||
if result == prompt:
|
||||
box = [int(e) for e in list(pts.reshape(-1))]
|
||||
box = [box[0], box[1], box[4], box[5]]
|
||||
|
||||
if calculate_size(box) > 0.05 * iw * ih:
|
||||
continue
|
||||
|
||||
text_data.append(
|
||||
[
|
||||
int(max(0, box[0] - 10) * x / iw),
|
||||
int(max(0, box[1] - 10) * y / ih),
|
||||
int(min(box[2] + 10, iw) * x / iw),
|
||||
int(min(box[3] + 10, ih) * y / ih),
|
||||
]
|
||||
)
|
||||
coordinate.append(
|
||||
[
|
||||
int(max(0, box[0] - 300) * x / iw),
|
||||
int(max(0, box[1] - 400) * y / ih),
|
||||
int(min(box[2] + 300, iw) * x / iw),
|
||||
int(min(box[3] + 400, ih) * y / ih),
|
||||
]
|
||||
)
|
||||
|
||||
max_length = 0
|
||||
if len(text_data) == 0:
|
||||
for i in range(det_result.shape[0]):
|
||||
pts = order_point(det_result[i])
|
||||
image_crop = crop_image(image_full, pts)
|
||||
result = ocr_recognition(image_crop)["text"][0]
|
||||
|
||||
if len(result) < 0.3 * len(prompt):
|
||||
continue
|
||||
|
||||
if result in prompt:
|
||||
now_length = len(result)
|
||||
else:
|
||||
now_length = longest_common_substring_length(result, prompt)
|
||||
|
||||
if now_length > max_length:
|
||||
max_length = now_length
|
||||
box = [int(e) for e in list(pts.reshape(-1))]
|
||||
box = [box[0], box[1], box[4], box[5]]
|
||||
|
||||
text_data = [
|
||||
[
|
||||
int(max(0, box[0] - 10) * x / iw),
|
||||
int(max(0, box[1] - 10) * y / ih),
|
||||
int(min(box[2] + 10, iw) * x / iw),
|
||||
int(min(box[3] + 10, ih) * y / ih),
|
||||
]
|
||||
]
|
||||
coordinate = [
|
||||
[
|
||||
int(max(0, box[0] - 300) * x / iw),
|
||||
int(max(0, box[1] - 400) * y / ih),
|
||||
int(min(box[2] + 300, iw) * x / iw),
|
||||
int(min(box[3] + 400, ih) * y / ih),
|
||||
]
|
||||
]
|
||||
|
||||
if len(prompt) <= 10:
|
||||
if max_length >= 0.8 * len(prompt):
|
||||
return text_data, coordinate
|
||||
else:
|
||||
return [], []
|
||||
elif (len(prompt) > 10) and (len(prompt) <= 20):
|
||||
if max_length >= 0.5 * len(prompt):
|
||||
return text_data, coordinate
|
||||
else:
|
||||
return [], []
|
||||
else:
|
||||
if max_length >= 0.4 * len(prompt):
|
||||
return text_data, coordinate
|
||||
else:
|
||||
return [], []
|
||||
|
||||
else:
|
||||
return text_data, coordinate
|
||||
|
||||
|
||||
################################## icon_localization using clip #######################
|
||||
|
||||
|
||||
def calculate_iou(box1: list, box2: list) -> float:
|
||||
x_a = max(box1[0], box2[0])
|
||||
y_a = max(box1[1], box2[1])
|
||||
x_b = min(box1[2], box2[2])
|
||||
y_b = min(box1[3], box2[3])
|
||||
|
||||
inter_area = max(0, x_b - x_a) * max(0, y_b - y_a)
|
||||
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
||||
union_area = box1_area + box2_area - inter_area
|
||||
iou = inter_area / union_area
|
||||
|
||||
return iou
|
||||
|
||||
|
||||
def in_box(box: list, target: list) -> bool:
|
||||
if (box[0] > target[0]) and (box[1] > target[1]) and (box[2] < target[2]) and (box[3] < target[3]):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def crop_for_clip(image: any, box: any, i: int, temp_file: Path) -> bool:
|
||||
image = Image.open(image)
|
||||
w, h = image.size
|
||||
bound = [0, 0, w, h]
|
||||
if in_box(box, bound):
|
||||
cropped_image = image.crop(box)
|
||||
cropped_image.save(temp_file.joinpath(f"{i}.png"))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def clip_for_icon(clip_model: any, clip_preprocess: any, images: any, prompt: str) -> any:
|
||||
image_features = []
|
||||
for image_file in images:
|
||||
image = clip_preprocess(Image.open(image_file)).unsqueeze(0).to(next(clip_model.parameters()).device)
|
||||
image_feature = clip_model.encode_image(image)
|
||||
image_features.append(image_feature)
|
||||
image_features = torch.cat(image_features)
|
||||
|
||||
text = clip.tokenize([prompt]).to(next(clip_model.parameters()).device)
|
||||
text_features = clip_model.encode_text(text)
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
similarity = (100.0 * image_features @ text_features.T).softmax(dim=0).squeeze(0)
|
||||
_, max_pos = torch.max(similarity, dim=0)
|
||||
pos = max_pos.item()
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def transform_image(image_pil: any) -> any:
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.RandomResize([800], max_size=1333),
|
||||
T.ToTensor(),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
image, _ = transform(image_pil, None) # 3, h, w
|
||||
return image
|
||||
|
||||
|
||||
def load_model(model_checkpoint_path: Path, device: str) -> any:
|
||||
model_config_path = "grounding_dino_config.py"
|
||||
args = SLConfig.fromfile(model_config_path)
|
||||
args.device = device
|
||||
model = build_model(args)
|
||||
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
||||
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
||||
print(load_res)
|
||||
_ = model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def get_grounding_output(
|
||||
model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True
|
||||
) -> any:
|
||||
caption = caption.lower()
|
||||
caption = caption.strip()
|
||||
if not caption.endswith("."):
|
||||
caption = caption + "."
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(image[None], captions=[caption])
|
||||
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
|
||||
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
|
||||
logits.shape[0]
|
||||
|
||||
logits_filt = logits.clone()
|
||||
boxes_filt = boxes.clone()
|
||||
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
||||
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
||||
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
||||
logits_filt.shape[0]
|
||||
|
||||
tokenlizer = model.tokenizer
|
||||
tokenized = tokenlizer(caption)
|
||||
|
||||
pred_phrases = []
|
||||
scores = []
|
||||
for logit, box in zip(logits_filt, boxes_filt):
|
||||
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
|
||||
if with_logits:
|
||||
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
||||
else:
|
||||
pred_phrases.append(pred_phrase)
|
||||
scores.append(logit.max().item())
|
||||
|
||||
return boxes_filt, torch.Tensor(scores), pred_phrases
|
||||
|
||||
|
||||
def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any:
|
||||
boxes_to_remove = set()
|
||||
|
||||
for i in range(len(boxes_filt)):
|
||||
if calculate_size(boxes_filt[i]) > 0.05 * size[0] * size[1]:
|
||||
boxes_to_remove.add(i)
|
||||
for j in range(len(boxes_filt)):
|
||||
if calculate_size(boxes_filt[j]) > 0.05 * size[0] * size[1]:
|
||||
boxes_to_remove.add(j)
|
||||
if i == j:
|
||||
continue
|
||||
if i in boxes_to_remove or j in boxes_to_remove:
|
||||
continue
|
||||
iou = calculate_iou(boxes_filt[i], boxes_filt[j])
|
||||
if iou >= iou_threshold:
|
||||
boxes_to_remove.add(j)
|
||||
|
||||
boxes_filt = [box for idx, box in enumerate(boxes_filt) if idx not in boxes_to_remove]
|
||||
|
||||
return boxes_filt
|
||||
|
||||
|
||||
def det(
|
||||
input_image: any,
|
||||
text_prompt: str,
|
||||
groundingdino_model: any,
|
||||
box_threshold: float = 0.05,
|
||||
text_threshold: float = 0.5,
|
||||
) -> any:
|
||||
image = Image.open(input_image)
|
||||
size = image.size
|
||||
|
||||
image_pil = image.convert("RGB")
|
||||
image = np.array(image_pil)
|
||||
|
||||
transformed_image = transform_image(image_pil)
|
||||
boxes_filt, scores, pred_phrases = get_grounding_output(
|
||||
groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
|
||||
)
|
||||
|
||||
H, W = size[1], size[0]
|
||||
for i in range(boxes_filt.size(0)):
|
||||
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
|
||||
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
|
||||
boxes_filt[i][2:] += boxes_filt[i][:2]
|
||||
|
||||
boxes_filt = boxes_filt.cpu().int().tolist()
|
||||
filtered_boxes = remove_boxes(boxes_filt, size) # [:9]
|
||||
coordinate = []
|
||||
image_data = []
|
||||
for box in filtered_boxes:
|
||||
image_data.append(
|
||||
[max(0, box[0] - 10), max(0, box[1] - 10), min(box[2] + 10, size[0]), min(box[3] + 10, size[1])]
|
||||
)
|
||||
coordinate.append(
|
||||
[max(0, box[0] - 25), max(0, box[1] - 25), min(box[2] + 25, size[0]), min(box[3] + 25, size[1])]
|
||||
)
|
||||
|
||||
return image_data, coordinate
|
||||
|
|
@ -1,157 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : The Android external environment to integrate with Android apps
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.environment.android_env.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
|
||||
|
||||
class AndroidExtEnv(ExtEnv):
|
||||
device_id: Optional[str] = Field(default=None)
|
||||
screenshot_dir: Optional[Path] = Field(default=None)
|
||||
xml_dir: Optional[Path] = Field(default=None)
|
||||
width: int = Field(default=720, description="device screen width")
|
||||
height: int = Field(default=1080, description="device screen height")
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
if data.get("device_id"):
|
||||
(width, height) = self.device_shape
|
||||
self.width = data.get("width", width)
|
||||
self.height = data.get("height", height)
|
||||
|
||||
@property
|
||||
def adb_prefix_si(self):
|
||||
"""adb cmd prefix with `device_id` and `shell input`"""
|
||||
return f"adb -s {self.device_id} shell input "
|
||||
|
||||
@property
|
||||
def adb_prefix_shell(self):
|
||||
"""adb cmd prefix with `device_id` and `shell`"""
|
||||
return f"adb -s {self.device_id} shell "
|
||||
|
||||
@property
|
||||
def adb_prefix(self):
|
||||
"""adb cmd prefix with `device_id`"""
|
||||
return f"adb -s {self.device_id} "
|
||||
|
||||
def execute_adb_with_cmd(self, adb_cmd: str) -> str:
|
||||
res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
exec_res = ADB_EXEC_FAIL
|
||||
if not res.returncode:
|
||||
exec_res = res.stdout.strip()
|
||||
return exec_res
|
||||
|
||||
@property
|
||||
def device_shape(self) -> tuple[int, int]:
|
||||
adb_cmd = f"{self.adb_prefix_shell} wm size"
|
||||
shape = (0, 0)
|
||||
shape_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
if shape_res != ADB_EXEC_FAIL:
|
||||
shape = tuple(map(int, shape_res.split(": ")[1].split("x")))
|
||||
return shape
|
||||
|
||||
def list_devices(self):
|
||||
adb_cmd = "adb devices"
|
||||
res = self.execute_adb_with_cmd(adb_cmd)
|
||||
devices = []
|
||||
if res != ADB_EXEC_FAIL:
|
||||
devices = res.split("\n")[1:]
|
||||
devices = [device.split()[0] for device in devices]
|
||||
return devices
|
||||
|
||||
@mark_as_readable
|
||||
def get_screenshot(self, ss_name: str, local_save_dir: Path) -> Path:
|
||||
"""
|
||||
ss_name: screenshot file name
|
||||
local_save_dir: local dir to store image from virtual machine
|
||||
"""
|
||||
assert self.screenshot_dir
|
||||
ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png")
|
||||
ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
|
||||
res = ADB_EXEC_FAIL
|
||||
if ss_res != ADB_EXEC_FAIL:
|
||||
ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png")
|
||||
pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}"
|
||||
pull_res = self.execute_adb_with_cmd(pull_cmd)
|
||||
if pull_res != ADB_EXEC_FAIL:
|
||||
res = ss_local_path
|
||||
return Path(res)
|
||||
|
||||
@mark_as_readable
|
||||
def get_xml(self, xml_name: str, local_save_dir: Path) -> Path:
|
||||
xml_remote_path = Path(self.xml_dir).joinpath(f"{xml_name}.xml")
|
||||
dump_cmd = f"{self.adb_prefix_shell} uiautomator dump {xml_remote_path}"
|
||||
xml_res = self.execute_adb_with_cmd(dump_cmd)
|
||||
|
||||
res = ADB_EXEC_FAIL
|
||||
if xml_res != ADB_EXEC_FAIL:
|
||||
xml_local_path = Path(local_save_dir).joinpath(f"{xml_name}.xml")
|
||||
pull_cmd = f"{self.adb_prefix} pull {xml_remote_path} {xml_local_path}"
|
||||
pull_res = self.execute_adb_with_cmd(pull_cmd)
|
||||
if pull_res != ADB_EXEC_FAIL:
|
||||
res = xml_local_path
|
||||
return Path(res)
|
||||
|
||||
@mark_as_writeable
|
||||
def system_back(self) -> str:
|
||||
adb_cmd = f"{self.adb_prefix_si} keyevent KEYCODE_BACK"
|
||||
back_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return back_res
|
||||
|
||||
@mark_as_writeable
|
||||
def system_tap(self, x: int, y: int) -> str:
|
||||
adb_cmd = f"{self.adb_prefix_si} tap {x} {y}"
|
||||
tap_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return tap_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_input(self, input_txt: str) -> str:
|
||||
input_txt = input_txt.replace(" ", "%s").replace("'", "")
|
||||
adb_cmd = f"{self.adb_prefix_si} text {input_txt}"
|
||||
input_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return input_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_longpress(self, x: int, y: int, duration: int = 500) -> str:
|
||||
adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x} {y} {duration}"
|
||||
press_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return press_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_swipe(self, x: int, y: int, orient: str = "up", dist: str = "medium", if_quick: bool = False) -> str:
|
||||
dist_unit = int(self.width / 10)
|
||||
if dist == "long":
|
||||
dist_unit *= 3
|
||||
elif dist == "medium":
|
||||
dist_unit *= 2
|
||||
|
||||
if orient == "up":
|
||||
offset = 0, -2 * dist_unit
|
||||
elif orient == "down":
|
||||
offset = 0, 2 * dist_unit
|
||||
elif orient == "left":
|
||||
offset = -1 * dist_unit, 0
|
||||
elif orient == "right":
|
||||
offset = dist_unit, 0
|
||||
else:
|
||||
return ADB_EXEC_FAIL
|
||||
|
||||
duration = 100 if if_quick else 400
|
||||
adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x + offset[0]} {y + offset[1]} {duration}"
|
||||
swipe_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return swipe_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400):
|
||||
adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}"
|
||||
swipe_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return swipe_res
|
||||
|
|
@ -18,11 +18,11 @@ class EnvAPIAbstract(BaseModel):
|
|||
class EnvAPIRegistry(BaseModel):
|
||||
"""the registry to store environment w&r api/interface"""
|
||||
|
||||
registry: dict[str, dict[str, Union[dict, Any, str]]] = Field(default=dict(), exclude=True)
|
||||
registry: dict[str, Callable] = Field(default=dict(), exclude=True)
|
||||
|
||||
def get(self, api_name: str):
|
||||
if api_name not in self.registry:
|
||||
raise ValueError
|
||||
raise KeyError(f"api_name: {api_name} not found")
|
||||
return self.registry.get(api_name)
|
||||
|
||||
def __getitem__(self, api_name: str) -> Callable:
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@
|
|||
# @Desc : base env of executing environment
|
||||
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.context import Context
|
||||
|
|
@ -14,6 +17,7 @@ from metagpt.environment.api.env_api import (
|
|||
ReadAPIRegistry,
|
||||
WriteAPIRegistry,
|
||||
)
|
||||
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
|
||||
|
|
@ -49,6 +53,11 @@ def mark_as_writeable(func):
|
|||
class ExtEnv(BaseModel):
|
||||
"""External Env to integrate actual game environment"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_space: spaces.Space[ActType] = Field(default_factory=spaces.Space, exclude=True)
|
||||
observation_space: spaces.Space[ObsType] = Field(default_factory=spaces.Space, exclude=True)
|
||||
|
||||
def _check_api_exist(self, rw_api: Optional[str] = None):
|
||||
if not rw_api:
|
||||
raise ValueError(f"{rw_api} not exists")
|
||||
|
|
@ -61,39 +70,56 @@ class ExtEnv(BaseModel):
|
|||
else:
|
||||
return env_write_api_registry.get_apis()
|
||||
|
||||
async def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
async def read_from_api(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
"""get observation from particular api of ExtEnv"""
|
||||
if isinstance(env_action, str):
|
||||
read_api = env_read_api_registry.get(api_name=env_action)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self)
|
||||
env_read_api = env_read_api_registry.get(api_name=env_action)["func"]
|
||||
self._check_api_exist(env_read_api)
|
||||
if is_coroutine_func(env_read_api):
|
||||
res = await env_read_api(self)
|
||||
else:
|
||||
res = read_api(self)
|
||||
res = env_read_api(self)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self, *env_action.args, **env_action.kwargs)
|
||||
env_read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
|
||||
self._check_api_exist(env_read_api)
|
||||
if is_coroutine_func(env_read_api):
|
||||
res = await env_read_api(self, *env_action.args, **env_action.kwargs)
|
||||
else:
|
||||
res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
res = env_read_api(self, *env_action.args, **env_action.kwargs)
|
||||
return res
|
||||
|
||||
async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
async def write_thru_api(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
"""execute through particular api of ExtEnv"""
|
||||
res = None
|
||||
if isinstance(env_action, Message):
|
||||
self.publish_message(env_action)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
write_api = env_write_api_registry.get(env_action.api_name)["func"]
|
||||
self._check_api_exist(write_api)
|
||||
if is_coroutine_func(write_api):
|
||||
res = await write_api(self, *env_action.args, **env_action.kwargs)
|
||||
env_write_api = env_write_api_registry.get(env_action.api_name)["func"]
|
||||
self._check_api_exist(env_write_api)
|
||||
if is_coroutine_func(env_write_api):
|
||||
res = await env_write_api(self, *env_action.args, **env_action.kwargs)
|
||||
else:
|
||||
res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
res = env_write_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
return res
|
||||
|
||||
@abstractmethod
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Implement this to get init observation"""
|
||||
|
||||
@abstractmethod
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
|
||||
"""Implement this if you want to get partial observation from the env"""
|
||||
|
||||
@abstractmethod
|
||||
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
"""Implement this to feed a action and then get new observation from the env"""
|
||||
|
||||
|
||||
class Environment(ExtEnv):
|
||||
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
|
||||
|
|
@ -108,6 +134,20 @@ class Environment(ExtEnv):
|
|||
history: str = "" # For debug
|
||||
context: Context = Field(default_factory=Context, exclude=True)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
pass
|
||||
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
|
||||
pass
|
||||
|
||||
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@model_validator(mode="after")
|
||||
def init_roles(self):
|
||||
self.add_roles(self.roles.values())
|
||||
|
|
|
|||
33
metagpt/environment/base_env_space.py
Normal file
33
metagpt/environment/base_env_space.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class BaseEnvActionType(IntEnum):
|
||||
# # NONE = 0 # no action to run, just get observation
|
||||
pass
|
||||
|
||||
|
||||
class BaseEnvAction(BaseModel):
|
||||
"""env action type and its related params of action functions/apis"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_type: int = Field(default=0, description="action type")
|
||||
|
||||
|
||||
class BaseEnvObsType(IntEnum):
|
||||
# # NONE = 0 # get whole observation from env
|
||||
pass
|
||||
|
||||
|
||||
class BaseEnvObsParams(BaseModel):
|
||||
"""observation params for different EnvObsType to get its observe result"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
obs_type: int = Field(default=0, description="observation type")
|
||||
|
|
@ -13,13 +13,13 @@ from pydantic import ConfigDict, Field
|
|||
|
||||
from metagpt.config2 import config as CONFIG
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
|
||||
from metagpt.environment.minecraft.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
|
||||
|
||||
|
||||
class MinecraftEnv(Environment, MinecraftExtEnv):
|
||||
class MinecraftEnv(MinecraftExtEnv, Environment):
|
||||
"""MinecraftEnv, including shared memory of cache and information between roles"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
|
@ -282,7 +282,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv):
|
|||
position = event["status"]["position"]
|
||||
blocks.append(block)
|
||||
positions.append(position)
|
||||
new_events = self.step(
|
||||
new_events = self._step(
|
||||
f"await givePlacedItemBack(bot, {json.dumps(blocks)}, {json.dumps(positions)})",
|
||||
programs=self.programs,
|
||||
)
|
||||
|
|
@ -323,7 +323,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv):
|
|||
Exception: If there is an issue retrieving events.
|
||||
"""
|
||||
try:
|
||||
self.reset(
|
||||
self._reset(
|
||||
options={
|
||||
"mode": "soft",
|
||||
"wait_ticks": 20,
|
||||
|
|
@ -332,13 +332,13 @@ class MinecraftEnv(Environment, MinecraftExtEnv):
|
|||
# difficulty = "easy" if len(self.completed_tasks) > 15 else "peaceful"
|
||||
difficulty = "peaceful"
|
||||
|
||||
events = self.step("bot.chat(`/time set ${getNextTime()}`);\n" + f"bot.chat('/difficulty {difficulty}');")
|
||||
events = self._step("bot.chat(`/time set ${getNextTime()}`);\n" + f"bot.chat('/difficulty {difficulty}');")
|
||||
self.update_event(events)
|
||||
return events
|
||||
except Exception as e:
|
||||
time.sleep(3) # wait for mineflayer to exit
|
||||
# reset bot status here
|
||||
events = self.reset(
|
||||
events = self._reset(
|
||||
options={
|
||||
"mode": "hard",
|
||||
"wait_ticks": 20,
|
||||
|
|
@ -365,7 +365,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv):
|
|||
Exception: If there is an issue retrieving events.
|
||||
"""
|
||||
try:
|
||||
events = self.step(
|
||||
events = self._step(
|
||||
code=self.code,
|
||||
programs=self.programs,
|
||||
)
|
||||
|
|
@ -374,7 +374,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv):
|
|||
except Exception as e:
|
||||
time.sleep(3) # wait for mineflayer to exit
|
||||
# reset bot status here
|
||||
events = self.reset(
|
||||
events = self._reset(
|
||||
options={
|
||||
"mode": "hard",
|
||||
"wait_ticks": 20,
|
||||
|
|
@ -5,20 +5,21 @@
|
|||
|
||||
import json
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_writeable
|
||||
from metagpt.environment.minecraft_env.const import (
|
||||
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
|
||||
from metagpt.environment.minecraft.const import (
|
||||
MC_CKPT_DIR,
|
||||
MC_CORE_INVENTORY_ITEMS,
|
||||
MC_CURRICULUM_OB,
|
||||
MC_DEFAULT_WARMUP,
|
||||
METAGPT_ROOT,
|
||||
)
|
||||
from metagpt.environment.minecraft_env.process_monitor import SubprocessMonitor
|
||||
from metagpt.environment.minecraft.process_monitor import SubprocessMonitor
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -38,6 +39,20 @@ class MinecraftExtEnv(ExtEnv):
|
|||
server_paused: bool = Field(default=False)
|
||||
warm_up: dict = Field(default=dict())
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
pass
|
||||
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
|
||||
pass
|
||||
|
||||
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@property
|
||||
def server(self) -> str:
|
||||
return f"{self.server_host}:{self.server_port}"
|
||||
|
|
@ -48,7 +63,7 @@ class MinecraftExtEnv(ExtEnv):
|
|||
self.mineflayer = SubprocessMonitor(
|
||||
commands=[
|
||||
"node",
|
||||
METAGPT_ROOT.joinpath("metagpt", "environment", "minecraft_env", "mineflayer", "index.js"),
|
||||
METAGPT_ROOT.joinpath("metagpt", "environment", "minecraft", "mineflayer", "index.js"),
|
||||
str(self.server_port),
|
||||
],
|
||||
name="mineflayer",
|
||||
|
|
@ -115,7 +130,7 @@ class MinecraftExtEnv(ExtEnv):
|
|||
return res.json()
|
||||
|
||||
@mark_as_writeable
|
||||
def reset(self, *, seed=None, options=None) -> dict:
|
||||
def _reset(self, *, seed=None, options=None) -> dict:
|
||||
if options is None:
|
||||
options = {}
|
||||
if options.get("inventory", {}) and options.get("mode", "hard") != "hard":
|
||||
|
|
@ -145,7 +160,7 @@ class MinecraftExtEnv(ExtEnv):
|
|||
return json.loads(returned_data)
|
||||
|
||||
@mark_as_writeable
|
||||
def step(self, code: str, programs: str = "") -> dict:
|
||||
def _step(self, code: str, programs: str = "") -> dict:
|
||||
if not self.has_reset:
|
||||
raise RuntimeError("Environment has not been reset yet")
|
||||
self.check_process()
|
||||
105
metagpt/environment/stanford_town/env_space.py
Normal file
105
metagpt/environment/stanford_town/env_space.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from gymnasium import spaces
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.environment.base_env_space import (
|
||||
BaseEnvAction,
|
||||
BaseEnvActionType,
|
||||
BaseEnvObsParams,
|
||||
BaseEnvObsType,
|
||||
)
|
||||
|
||||
|
||||
class EnvActionType(BaseEnvActionType):
|
||||
NONE = 0 # no action to run, just get observation
|
||||
|
||||
ADD_TILE_EVENT = 1 # Add an event triple to a tile
|
||||
RM_TILE_EVENT = 2 # Remove an event triple from a tile
|
||||
TURN_TILE_EVENT_IDLE = 3 # Turn an event triple from a tile into idle
|
||||
RM_TITLE_SUB_EVENT = 4 # Remove an event triple that has the input subject from a tile
|
||||
|
||||
|
||||
class EnvAction(BaseEnvAction):
|
||||
"""env action type and its related params of action functions/apis"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_type: int = Field(default=EnvActionType.NONE, description="action type")
|
||||
coord: npt.NDArray[np.int64] = Field(
|
||||
default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
|
||||
)
|
||||
subject: str = Field(default="", description="subject name of first element in event")
|
||||
event: tuple[str, Optional[str], Optional[str], Optional[str]] = Field(
|
||||
default=["", None, None, None], description="tile event"
|
||||
)
|
||||
|
||||
@field_validator("coord", mode="before")
|
||||
@classmethod
|
||||
def check_coord(cls, coord) -> npt.NDArray[np.int64]:
|
||||
if not isinstance(coord, np.ndarray):
|
||||
return np.array(coord)
|
||||
|
||||
|
||||
class EnvObsType(BaseEnvObsType):
|
||||
"""get part observation with specific params"""
|
||||
|
||||
NONE = 0 # get whole observation from env
|
||||
|
||||
GET_TITLE = 1 # get the tile detail dictionary with given tile coord
|
||||
TILE_PATH = 2 # get the tile address with given tile coord
|
||||
TILE_NBR = 3 # get the neighbors of given tile coord and its vision radius
|
||||
|
||||
|
||||
class EnvObsParams(BaseEnvObsParams):
|
||||
"""observation params for different EnvObsType"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
|
||||
coord: npt.NDArray[np.int64] = Field(
|
||||
default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
|
||||
)
|
||||
level: str = Field(default="", description="different level of title")
|
||||
vision_radius: int = Field(default=0, description="the vision radius of current tile")
|
||||
|
||||
@field_validator("coord", mode="before")
|
||||
@classmethod
|
||||
def check_coord(cls, coord) -> npt.NDArray[np.int64]:
|
||||
if not isinstance(coord, np.ndarray):
|
||||
return np.array(coord)
|
||||
|
||||
|
||||
EnvObsValType = Union[list[list[str]], dict[str, set[tuple[int, int]]], list[list[dict[str, Any]]]]
|
||||
|
||||
|
||||
def get_observation_space() -> spaces.Dict:
|
||||
# it's a
|
||||
space = spaces.Dict(
|
||||
{"collision_maze": spaces.Discrete(2), "tiles": spaces.Discrete(2), "address_tiles": spaces.Discrete(2)}
|
||||
)
|
||||
|
||||
return space
|
||||
|
||||
|
||||
def get_action_space(maze_shape: tuple[int, int]) -> spaces.Dict:
|
||||
"""The fields defined by the space correspond to the input parameters of the action except `action_type`"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"action_type": spaces.Discrete(len(EnvActionType)),
|
||||
"coord": spaces.Box(
|
||||
np.array([0, 0], dtype=np.int64), np.array([maze_shape[0], maze_shape[1]], dtype=np.int64)
|
||||
), # coord of the tile
|
||||
"subject": spaces.Text(256), # the first element of an tile event
|
||||
"event": spaces.Tuple(
|
||||
(spaces.Text(256), spaces.Text(256), spaces.Text(256), spaces.Text(256))
|
||||
), # event is a tuple of four str
|
||||
}
|
||||
)
|
||||
return space
|
||||
10
metagpt/environment/stanford_town/stanford_town_env.py
Normal file
10
metagpt/environment/stanford_town/stanford_town_env.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : MG StanfordTown Env
|
||||
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.stanford_town.stanford_town_ext_env import StanfordTownExtEnv
|
||||
|
||||
|
||||
class StanfordTownEnv(StanfordTownExtEnv, Environment):
|
||||
pass
|
||||
|
|
@ -5,11 +5,20 @@
|
|||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.environment.stanford_town.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
EnvObsValType,
|
||||
get_action_space,
|
||||
get_observation_space,
|
||||
)
|
||||
from metagpt.utils.common import read_csv_to_list, read_json_file
|
||||
|
||||
|
||||
|
|
@ -197,15 +206,82 @@ class StanfordTownExtEnv(ExtEnv):
|
|||
else:
|
||||
address_tiles[add] = set([(j, i)])
|
||||
values["address_tiles"] = address_tiles
|
||||
|
||||
values["action_space"] = get_action_space((maze_width, maze_height))
|
||||
values["observation_space"] = get_observation_space()
|
||||
return values
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, EnvObsValType], dict[str, Any]]:
|
||||
"""reset env and get the init observation
|
||||
Return results corresponding to `observation, info`
|
||||
"""
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
obs = self._get_obs()
|
||||
|
||||
return obs, {}
|
||||
|
||||
def _get_obs(self) -> dict[str, EnvObsValType]:
|
||||
"""Get observation"""
|
||||
return {
|
||||
"collision_maze": self.get_collision_maze(),
|
||||
"tiles": self.tiles,
|
||||
"address_tiles": self.get_address_tiles(),
|
||||
}
|
||||
|
||||
def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
|
||||
"""Get partial or full observation from the env"""
|
||||
obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE
|
||||
if obs_type == EnvObsType.NONE:
|
||||
obs = self._get_obs()
|
||||
elif obs_type == EnvObsType.GET_TITLE:
|
||||
obs = self.access_tile(tile=obs_params.coord)
|
||||
elif obs_type == EnvObsType.TILE_PATH:
|
||||
obs = self.get_tile_path(tile=obs_params.coord, level=obs_params.level)
|
||||
elif obs_type == EnvObsType.TILE_NBR:
|
||||
obs = self.get_nearby_tiles(tile=obs_params.coord, vision_r=obs_params.vision_radius)
|
||||
return obs
|
||||
|
||||
def step(self, action: EnvAction) -> tuple[dict[str, EnvObsValType], float, bool, bool, dict[str, Any]]:
|
||||
"""Execute action and then return observation
|
||||
Return results corresponding to `observation, reward, terminated, truncated, info`
|
||||
"""
|
||||
terminated = False
|
||||
try:
|
||||
self._execute_env_action(action)
|
||||
except Exception:
|
||||
terminated = True
|
||||
|
||||
obs = self._get_obs()
|
||||
|
||||
ret = (obs, 1.0, terminated, False, {})
|
||||
return ret
|
||||
|
||||
def _execute_env_action(self, action: EnvAction):
|
||||
action_type = action.action_type
|
||||
if action_type == EnvActionType.NONE:
|
||||
pass
|
||||
elif action_type == EnvActionType.ADD_TILE_EVENT:
|
||||
self.add_event_from_tile(curr_event=action.event, tile=action.coord)
|
||||
elif action_type == EnvActionType.RM_TILE_EVENT:
|
||||
self.remove_event_from_tile(curr_event=action.event, tile=action.coord)
|
||||
elif action_type == EnvActionType.TURN_TILE_EVENT_IDLE:
|
||||
self.turn_event_from_tile_idle(curr_event=action.event, tile=action.coord)
|
||||
elif action_type == EnvActionType.RM_TITLE_SUB_EVENT:
|
||||
self.remove_subject_events_from_tile(subject=action.subject, tile=action.coord)
|
||||
|
||||
def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
|
||||
"""
|
||||
Turns a pixel coordinate to a tile coordinate.
|
||||
"""
|
||||
x = math.ceil(px_coordinate[0] / self.sq_tile_size)
|
||||
y = math.ceil(px_coordinate[1] / self.sq_tile_size)
|
||||
return (x, y)
|
||||
return x, y
|
||||
|
||||
@mark_as_readable
|
||||
def get_collision_maze(self) -> list:
|
||||
|
|
@ -316,10 +392,6 @@ class StanfordTownExtEnv(ExtEnv):
|
|||
nearby_tiles += [(i, j)]
|
||||
return nearby_tiles
|
||||
|
||||
@mark_as_writeable
|
||||
def add_tiles_event(self, pt_y: int, pt_x: int, event: Tuple[str, str, str, str]):
|
||||
self.tiles[pt_y][pt_x]["events"].add(event)
|
||||
|
||||
@mark_as_writeable
|
||||
def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
|
||||
"""
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : MG StanfordTown Env
|
||||
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.stanford_town_env.stanford_town_ext_env import (
|
||||
StanfordTownExtEnv,
|
||||
)
|
||||
|
||||
|
||||
class StanfordTownEnv(Environment, StanfordTownExtEnv):
|
||||
pass
|
||||
121
metagpt/environment/werewolf/const.py
Normal file
121
metagpt/environment/werewolf/const.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from metagpt.const import MESSAGE_ROUTE_TO_ALL
|
||||
|
||||
|
||||
class RoleType(Enum):
|
||||
VILLAGER = "Villager"
|
||||
WEREWOLF = "Werewolf"
|
||||
GUARD = "Guard"
|
||||
SEER = "Seer"
|
||||
WITCH = "Witch"
|
||||
MODERATOR = "Moderator"
|
||||
|
||||
|
||||
class RoleState(Enum):
|
||||
ALIVE = "alive" # the role is alive
|
||||
DEAD = "dead" # killed or poisoned
|
||||
KILLED = "killed" # killed by werewolf or voting
|
||||
POISONED = "poisoned" # killed by poison
|
||||
SAVED = "saved" # saved by antidote
|
||||
PROTECTED = "projected" # projected by guard
|
||||
|
||||
|
||||
class RoleActionRes(Enum):
|
||||
SAVE = "save"
|
||||
PASS = "pass" # ignore current action output
|
||||
|
||||
|
||||
empty_set = set()
|
||||
|
||||
# the ordered rules by the moderator to announce to everyone each step
|
||||
STEP_INSTRUCTIONS = {
|
||||
0: {
|
||||
"content": "It’s dark, everyone close your eyes. I will talk with you/your team secretly at night.",
|
||||
"send_to": {RoleType.MODERATOR.value}, # for moderator to continue speaking
|
||||
"restricted_to": empty_set,
|
||||
},
|
||||
1: {
|
||||
"content": "Guard, please open your eyes!",
|
||||
"send_to": {RoleType.MODERATOR.value}, # for moderator to continue speaking
|
||||
"restricted_to": empty_set,
|
||||
},
|
||||
2: {
|
||||
"content": """Guard, now tell me who you protect tonight?
|
||||
You only choose one from the following living options please: {living_players}.
|
||||
Or you can pass. For example: Protect ...""",
|
||||
"send_to": {RoleType.GUARD.value},
|
||||
"restricted_to": {RoleType.MODERATOR.value, RoleType.GUARD.value},
|
||||
},
|
||||
3: {"content": "Guard, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
|
||||
4: {
|
||||
"content": "Werewolves, please open your eyes!",
|
||||
"send_to": {RoleType.MODERATOR.value},
|
||||
"restricted_to": empty_set,
|
||||
},
|
||||
5: {
|
||||
"content": """Werewolves, I secretly tell you that {werewolf_players} are
|
||||
all of the {werewolf_num} werewolves! Keep in mind you are teammates. The rest players are not werewolves.
|
||||
choose one from the following living options please:
|
||||
{living_players}. For example: Kill ...""",
|
||||
"send_to": {RoleType.WEREWOLF.value},
|
||||
"restricted_to": {RoleType.MODERATOR.value, RoleType.WEREWOLF.value},
|
||||
},
|
||||
6: {"content": "Werewolves, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
|
||||
7: {"content": "Witch, please open your eyes!", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
|
||||
8: {
|
||||
"content": """Witch, tonight {player_hunted} has been killed by the werewolves.
|
||||
You have a bottle of antidote, would you like to save him/her? If so, say "Save", else, say "Pass".""",
|
||||
"send_to": {RoleType.WITCH.value},
|
||||
"restricted_to": {RoleType.MODERATOR.value, RoleType.WITCH.value},
|
||||
}, # 要先判断女巫是否有解药,再去询问女巫是否使用解药救人
|
||||
9: {
|
||||
"content": """Witch, you also have a bottle of poison, would you like to use it to kill one of the living players?
|
||||
Choose one from the following living options: {living_players}.
|
||||
If so, say ONLY "Poison PlayerX", replace PlayerX with the actual player name, else, say "Pass".""",
|
||||
"send_to": {RoleType.WITCH.value},
|
||||
"restricted_to": {RoleType.MODERATOR.value, RoleType.WITCH.value},
|
||||
}, #
|
||||
10: {"content": "Witch, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
|
||||
11: {"content": "Seer, please open your eyes!", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
|
||||
12: {
|
||||
"content": """Seer, you can check one player's identity. Who are you going to verify its identity tonight?
|
||||
Choose only one from the following living options:{living_players}.""",
|
||||
"send_to": {RoleType.SEER.value},
|
||||
"restricted_to": {RoleType.MODERATOR.value, RoleType.SEER.value},
|
||||
},
|
||||
13: {"content": "Seer, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set},
|
||||
# The 1-st daytime
|
||||
14: {
|
||||
"content": """It's daytime. Everyone woke up except those who had been killed.""",
|
||||
"send_to": {RoleType.MODERATOR.value},
|
||||
"restricted_to": empty_set,
|
||||
},
|
||||
15: {
|
||||
"content": "{player_current_dead} was killed last night!",
|
||||
"send_to": {RoleType.MODERATOR.value},
|
||||
"restricted_to": empty_set,
|
||||
},
|
||||
16: {
|
||||
"content": """Living players: {living_players}, now freely talk about the current situation based on your observation and
|
||||
reflection with a few sentences. Decide whether to reveal your identity based on your reflection.""",
|
||||
"send_to": {MESSAGE_ROUTE_TO_ALL}, # send to all to speak in daytime
|
||||
"restricted_to": empty_set,
|
||||
},
|
||||
17: {
|
||||
"content": """Now vote and tell me who you think is the werewolf. Don’t mention your role.
|
||||
You only choose one from the following living options please:
|
||||
{living_players}. Say ONLY: I vote to eliminate ...""",
|
||||
"send_to": {MESSAGE_ROUTE_TO_ALL},
|
||||
"restricted_to": empty_set,
|
||||
},
|
||||
18: {
|
||||
"content": """{player_current_dead} was eliminated.""",
|
||||
"send_to": {RoleType.MODERATOR.value},
|
||||
"restricted_to": empty_set,
|
||||
},
|
||||
}
|
||||
62
metagpt/environment/werewolf/env_space.py
Normal file
62
metagpt/environment/werewolf/env_space.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : werewolf observation/action space and its action definition
|
||||
|
||||
from gymnasium import spaces
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvActionType
|
||||
from metagpt.environment.werewolf.const import STEP_INSTRUCTIONS
|
||||
|
||||
|
||||
class EnvActionType(BaseEnvActionType):
|
||||
NONE = 0 # no action to run, just get observation
|
||||
WOLF_KILL = 1 # wolf kill someone
|
||||
VOTE_KILL = 2 # vote kill someone
|
||||
WITCH_POISON = 3 # witch poison someone
|
||||
WITCH_SAVE = 4 # witch save someone
|
||||
GUARD_PROTECT = 5 # guard protect someone
|
||||
PROGRESS_STEP = 6 # step increment
|
||||
|
||||
|
||||
class EnvAction(BaseEnvAction):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_type: int = Field(default=EnvActionType.NONE, description="action type")
|
||||
player_name: str = Field(default="", description="the name of the player to do the action")
|
||||
target_player_name: str = Field(default="", description="the name of the player who take the action")
|
||||
|
||||
|
||||
def get_observation_space() -> spaces.Dict:
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"game_setup": spaces.Text(256),
|
||||
"step_idx": spaces.Discrete(len(STEP_INSTRUCTIONS)),
|
||||
"living_players": spaces.Tuple(
|
||||
(spaces.Text(16), spaces.Text(16))
|
||||
), # TODO should be tuple of variable length
|
||||
"werewolf_players": spaces.Tuple(
|
||||
(spaces.Text(16), spaces.Text(16))
|
||||
), # TODO should be tuple of variable length
|
||||
"player_hunted": spaces.Text(16),
|
||||
"player_current_dead": spaces.Tuple(
|
||||
(spaces.Text(16), spaces.Text(16))
|
||||
), # TODO should be tuple of variable length
|
||||
"witch_poison_left": spaces.Discrete(2),
|
||||
"witch_antidote_left": spaces.Discrete(2),
|
||||
"winner": spaces.Text(16),
|
||||
"win_reason": spaces.Text(64),
|
||||
}
|
||||
)
|
||||
return space
|
||||
|
||||
|
||||
def get_action_space() -> spaces.Dict:
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"action_type": spaces.Discrete(len(EnvActionType)),
|
||||
"player_name": spaces.Text(16), # the player to do the action
|
||||
"target_player_name": spaces.Text(16), # the target player who take the action
|
||||
}
|
||||
)
|
||||
return space
|
||||
41
metagpt/environment/werewolf/werewolf_env.py
Normal file
41
metagpt/environment/werewolf/werewolf_env.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : MG Werewolf Env
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.werewolf.werewolf_ext_env import WerewolfExtEnv
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class WerewolfEnv(WerewolfExtEnv, Environment):
|
||||
round_cnt: int = Field(default=0)
|
||||
|
||||
def add_roles(self, roles: Iterable["Role"]):
|
||||
"""增加一批在当前环境的角色
|
||||
Add a batch of characters in the current environment
|
||||
"""
|
||||
for role in roles:
|
||||
self.roles[role.name] = role # use name as key here, due to multi-player can have same profile
|
||||
|
||||
for role in roles: # setup system message with roles
|
||||
role.context = self.context
|
||||
role.set_env(self)
|
||||
|
||||
def publish_message(self, message: Message, add_timestamp: bool = True):
|
||||
"""Post information to the current environment"""
|
||||
if add_timestamp:
|
||||
# Because the content of the message may be repeated, for example, killing the same person in two nights
|
||||
# Therefore, a unique round_cnt prefix needs to be added so that the same message will not be automatically deduplicated when added to the memory.
|
||||
message.content = f"{self.round_cnt} | " + message.content
|
||||
super().publish_message(message)
|
||||
|
||||
async def run(self, k=1):
|
||||
"""Process all Role runs by order"""
|
||||
for _ in range(k):
|
||||
for role in self.roles.values():
|
||||
await role.run()
|
||||
self.round_cnt += 1
|
||||
|
|
@ -4,109 +4,27 @@
|
|||
|
||||
import random
|
||||
from collections import Counter
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.environment.base_env_space import BaseEnvObsParams
|
||||
from metagpt.environment.werewolf.const import STEP_INSTRUCTIONS, RoleState, RoleType
|
||||
from metagpt.environment.werewolf.env_space import EnvAction, EnvActionType
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class RoleState(Enum):
|
||||
ALIVE = "alive" # the role is alive
|
||||
KILLED = "killed" # the role is killed by werewolf or voting
|
||||
POISONED = "poisoned" # the role is killed by posion
|
||||
SAVED = "saved" # the role is saved by antidote
|
||||
|
||||
|
||||
# the ordered rules by the moderator to announce to everyone each step
|
||||
STEP_INSTRUCTIONS = {
|
||||
0: {
|
||||
"content": "It’s dark, everyone close your eyes. I will talk with you/your team secretly at night.",
|
||||
"send_to": "Moderator", # for moderator to continuen speaking
|
||||
"restricted_to": "",
|
||||
},
|
||||
1: {
|
||||
"content": "Guard, please open your eyes!",
|
||||
"send_to": "Moderator", # for moderator to continuen speaking
|
||||
"restricted_to": "",
|
||||
},
|
||||
2: {
|
||||
"content": """Guard, now tell me who you protect tonight?
|
||||
You only choose one from the following living options please: {living_players}.
|
||||
Or you can pass. For example: Protect ...""",
|
||||
"send_to": "Guard",
|
||||
"restricted_to": "Moderator,Guard",
|
||||
},
|
||||
3: {"content": "Guard, close your eyes", "send_to": "Moderator", "restricted_to": ""},
|
||||
4: {"content": "Werewolves, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
|
||||
5: {
|
||||
"content": """Werewolves, I secretly tell you that {werewolf_players} are
|
||||
all of the 2 werewolves! Keep in mind you are teammates. The rest players are not werewolves.
|
||||
choose one from the following living options please:
|
||||
{living_players}. For example: Kill ...""",
|
||||
"send_to": "Werewolf",
|
||||
"restricted_to": "Moderator,Werewolf",
|
||||
},
|
||||
6: {"content": "Werewolves, close your eyes", "send_to": "Moderator", "restricted_to": ""},
|
||||
7: {"content": "Witch, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
|
||||
8: {
|
||||
"content": """Witch, tonight {player_hunted} has been killed by the werewolves.
|
||||
You have a bottle of antidote, would you like to save him/her? If so, say "Save", else, say "Pass".""",
|
||||
"send_to": "Witch",
|
||||
"restricted_to": "Moderator,Witch",
|
||||
}, # 要先判断女巫是否有解药,再去询问女巫是否使用解药救人
|
||||
9: {
|
||||
"content": """Witch, you also have a bottle of poison, would you like to use it to kill one of the living players?
|
||||
Choose one from the following living options: {living_players}.
|
||||
If so, say ONLY "Poison PlayerX", replace PlayerX with the actual player name, else, say "Pass".""",
|
||||
"send_to": "Witch",
|
||||
"restricted_to": "Moderator,Witch",
|
||||
}, #
|
||||
10: {"content": "Witch, close your eyes", "send_to": "Moderator", "restricted_to": ""},
|
||||
11: {"content": "Seer, please open your eyes!", "send_to": "Moderator", "restricted_to": ""},
|
||||
12: {
|
||||
"content": """Seer, you can check one player's identity. Who are you going to verify its identity tonight?
|
||||
Choose only one from the following living options:{living_players}.""",
|
||||
"send_to": "Seer",
|
||||
"restricted_to": "Moderator,Seer",
|
||||
},
|
||||
13: {"content": "Seer, close your eyes", "send_to": "Moderator", "restricted_to": ""},
|
||||
# The 1-st daytime
|
||||
14: {
|
||||
"content": """It's daytime. Everyone woke up except those who had been killed.""",
|
||||
"send_to": "Moderator",
|
||||
"restricted_to": "",
|
||||
},
|
||||
15: {"content": "{player_current_dead} was killed last night!", "send_to": "Moderator", "restricted_to": ""},
|
||||
16: {
|
||||
"content": """Living players: {living_players}, now freely talk about the current situation based on your observation and
|
||||
reflection with a few sentences. Decide whether to reveal your identity based on your reflection.""",
|
||||
"send_to": "", # send to all to speak in daytime
|
||||
"restricted_to": "",
|
||||
},
|
||||
17: {
|
||||
"content": """Now vote and tell me who you think is the werewolf. Don’t mention your role.
|
||||
You only choose one from the following living options please:
|
||||
{living_players}. Say ONLY: I vote to eliminate ...""",
|
||||
"send_to": "",
|
||||
"restricted_to": "",
|
||||
},
|
||||
18: {"content": """{player_current_dead} was eliminated.""", "send_to": "Moderator", "restricted_to": ""},
|
||||
}
|
||||
|
||||
|
||||
class WerewolfExtEnv(ExtEnv):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
players_state: dict[str, tuple[str, RoleState]] = Field(
|
||||
default=dict(), description="the player's role type and state by player_name"
|
||||
default_factory=dict, description="the player's role type and state by player_name"
|
||||
)
|
||||
|
||||
round_idx: int = Field(default=0) # the current round
|
||||
step_idx: int = Field(default=0) # the current step of current round
|
||||
eval_step_idx: int = Field(default=0)
|
||||
eval_step_idx: list[int] = Field(default=[])
|
||||
per_round_steps: int = Field(default=len(STEP_INSTRUCTIONS))
|
||||
|
||||
# game global states
|
||||
|
|
@ -114,13 +32,13 @@ class WerewolfExtEnv(ExtEnv):
|
|||
special_role_players: list[str] = Field(default=[])
|
||||
winner: Optional[str] = Field(default=None)
|
||||
win_reason: Optional[str] = Field(default=None)
|
||||
witch_poison_left: int = Field(default=1)
|
||||
witch_antidote_left: int = Field(default=1)
|
||||
witch_poison_left: int = Field(default=1, description="should be 1 or 0")
|
||||
witch_antidote_left: int = Field(default=1, description="should be 1 or 0")
|
||||
|
||||
# game current round states, a round is from closing your eyes to the next time you close your eyes
|
||||
round_hunts: dict[str, str] = Field(default=dict(), description="nighttime wolf hunt result")
|
||||
round_hunts: dict[str, str] = Field(default_factory=dict, description="nighttime wolf hunt result")
|
||||
round_votes: dict[str, str] = Field(
|
||||
default=dict(), description="daytime all players vote result, key=voteer, value=voted one"
|
||||
default_factory=dict, description="daytime all players vote result, key=voter, value=voted one"
|
||||
)
|
||||
player_hunted: Optional[str] = Field(default=None)
|
||||
player_protected: Optional[str] = Field(default=None)
|
||||
|
|
@ -128,6 +46,76 @@ class WerewolfExtEnv(ExtEnv):
|
|||
player_poisoned: Optional[str] = Field(default=None)
|
||||
player_current_dead: list[str] = Field(default=[])
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""currently unused"""
|
||||
pass
|
||||
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
|
||||
"""currently unused"""
|
||||
pass
|
||||
|
||||
def _get_obs(self):
|
||||
return {
|
||||
"game_setup": self.game_setup,
|
||||
"step_idx": self.step_idx,
|
||||
"living_players": self.living_players,
|
||||
"werewolf_players": self.werewolf_players, # currently, lack observation isolation
|
||||
"player_hunted": self.player_hunted,
|
||||
"player_current_dead": self.player_current_dead,
|
||||
"witch_poison_left": self.witch_poison_left,
|
||||
"witch_antidote_left": self.witch_antidote_left,
|
||||
"winner": self.winner,
|
||||
"win_reason": self.win_reason,
|
||||
}
|
||||
|
||||
def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
action_type = action.action_type
|
||||
player_name = action.player_name
|
||||
target_player_name = action.target_player_name
|
||||
if action_type == EnvActionType.WOLF_KILL:
|
||||
self.wolf_kill_someone(wolf_name=player_name, player_name=target_player_name)
|
||||
elif action_type == EnvActionType.VOTE_KILL:
|
||||
self.vote_kill_someone(voter_name=player_name, player_name=target_player_name)
|
||||
elif action_type == EnvActionType.WITCH_POISON:
|
||||
self.witch_poison_someone(witch_name=player_name, player_name=target_player_name)
|
||||
elif action_type == EnvActionType.WITCH_SAVE:
|
||||
self.witch_save_someone(witch_name=player_name, player_name=target_player_name)
|
||||
elif action_type == EnvActionType.GUARD_PROTECT:
|
||||
self.guard_protect_someone(guard_name=player_name, player_name=target_player_name)
|
||||
elif action_type == EnvActionType.PROGRESS_STEP:
|
||||
self.progress_step()
|
||||
elif action_type == EnvActionType.NONE:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"not supported action_type: {action_type}")
|
||||
|
||||
self.update_game_states()
|
||||
terminated = self._check_game_finish()
|
||||
obs = self._get_obs()
|
||||
return obs, 1.0, terminated, False, {}
|
||||
|
||||
def _check_game_finish(self) -> bool:
|
||||
"""return True if game finished else False"""
|
||||
# game's termination condition
|
||||
terminated = False
|
||||
living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
|
||||
living_villagers = [p for p in self.villager_players if p in self.living_players]
|
||||
living_special_roles = [p for p in self.special_role_players if p in self.living_players]
|
||||
if not living_werewolf:
|
||||
self.winner = "good guys"
|
||||
self.win_reason = "werewolves all dead"
|
||||
terminated = True
|
||||
elif not living_villagers or not living_special_roles:
|
||||
self.winner = "werewolf"
|
||||
self.win_reason = "villagers all dead" if not living_villagers else "special roles all dead"
|
||||
terminated = True
|
||||
return terminated
|
||||
|
||||
@property
|
||||
def living_players(self) -> list[str]:
|
||||
player_names = []
|
||||
|
|
@ -146,12 +134,12 @@ class WerewolfExtEnv(ExtEnv):
|
|||
|
||||
@property
|
||||
def werewolf_players(self) -> list[str]:
|
||||
player_names = self._role_type_players(role_type="Werewolf")
|
||||
player_names = self._role_type_players(role_type=RoleType.WEREWOLF.value)
|
||||
return player_names
|
||||
|
||||
@property
|
||||
def villager_players(self) -> list[str]:
|
||||
player_names = self._role_type_players(role_type="Villager")
|
||||
player_names = self._role_type_players(role_type=RoleType.VILLAGER.value)
|
||||
return player_names
|
||||
|
||||
def _init_players_state(self, players: list["Role"]):
|
||||
|
|
@ -178,14 +166,14 @@ class WerewolfExtEnv(ExtEnv):
|
|||
"""init players using different roles' num"""
|
||||
role_objs = []
|
||||
for role_obj in role_uniq_objs:
|
||||
if str(role_obj) == "Villager":
|
||||
if RoleType.VILLAGER.value in str(role_obj):
|
||||
role_objs.extend([role_obj] * num_villager)
|
||||
elif str(role_obj) == "Werewolf":
|
||||
elif RoleType.WEREWOLF.value in str(role_obj):
|
||||
role_objs.extend([role_obj] * num_werewolf)
|
||||
else:
|
||||
role_objs.append(role_obj)
|
||||
if shuffle:
|
||||
random.shuffle(len(role_objs))
|
||||
random.shuffle(role_objs)
|
||||
if add_human:
|
||||
assigned_role_idx = random.randint(0, len(role_objs) - 1)
|
||||
assigned_role = role_objs[assigned_role_idx]
|
||||
|
|
@ -218,10 +206,12 @@ class WerewolfExtEnv(ExtEnv):
|
|||
roletype_state = self.players_state[player_name]
|
||||
self.players_state[player_name] = (roletype_state[0], state)
|
||||
|
||||
def _check_valid_role(self, player: "Role", role_type: str) -> bool:
|
||||
return True if role_type in str(player) else False
|
||||
def _check_valid_role(self, player_name: str, role_type: str) -> bool:
|
||||
roletype_state = self.players_state.get(player_name)
|
||||
return True if roletype_state and role_type in roletype_state[0] else False
|
||||
|
||||
def _check_player_continue(self, player_name: str, particular_step: int = -1) -> bool:
|
||||
"""to check if can do the operation to the player"""
|
||||
step_idx = self.step_idx % self.per_round_steps
|
||||
if particular_step > 0 and step_idx != particular_step: # step no
|
||||
# particular_step = 18, not daytime vote time, ignore
|
||||
|
|
@ -238,6 +228,10 @@ class WerewolfExtEnv(ExtEnv):
|
|||
self.step_idx += 1
|
||||
return instruction
|
||||
|
||||
@mark_as_writeable
|
||||
def progress_step(self):
|
||||
self.step_idx += 1
|
||||
|
||||
@mark_as_readable
|
||||
def get_players_state(self, player_names: list[str]) -> dict[str, RoleState]:
|
||||
players_state = {
|
||||
|
|
@ -248,57 +242,72 @@ class WerewolfExtEnv(ExtEnv):
|
|||
return players_state
|
||||
|
||||
@mark_as_writeable
|
||||
def vote_kill_someone(self, voteer: "Role", player_name: str = None):
|
||||
def vote_kill_someone(self, voter_name: str, player_name: str = None):
|
||||
"""player vote result at daytime
|
||||
player_name: if it's None, regard as abstaining from voting
|
||||
"""
|
||||
if not self._check_player_continue(voteer.name, particular_step=18): # 18=step no
|
||||
if not self._check_player_continue(voter_name, particular_step=18): # 18=step no
|
||||
return
|
||||
|
||||
self.round_votes[voteer.name] = player_name
|
||||
self.round_votes[voter_name] = player_name
|
||||
# check if all living players finish voting, then get the dead one
|
||||
if list(self.round_votes.keys()) == self.living_players:
|
||||
voted_all = list(self.round_votes.values()) # TODO in case of tie vote, check who was voted first
|
||||
voted_all = [item for item in voted_all if item]
|
||||
self.player_current_dead = Counter(voted_all).most_common()[0][0]
|
||||
self._update_players_state([self.player_current_dead])
|
||||
self.player_current_dead = [Counter(voted_all).most_common()[0][0]]
|
||||
self._update_players_state(self.player_current_dead)
|
||||
|
||||
@mark_as_writeable
|
||||
def wolf_kill_someone(self, wolf: "Role", player_name: str):
|
||||
if not self._check_valid_role(wolf, "Werewolf"):
|
||||
def wolf_kill_someone(self, wolf_name: str, player_name: str):
|
||||
if not self._check_valid_role(wolf_name, RoleType.WEREWOLF.value):
|
||||
return
|
||||
if not self._check_player_continue(wolf.name, particular_step=5): # 5=step no
|
||||
if not self._check_player_continue(wolf_name, particular_step=6): # 5=step no
|
||||
return
|
||||
|
||||
self.round_hunts[wolf.name] = player_name
|
||||
living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
|
||||
self.round_hunts[wolf_name] = player_name
|
||||
# living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
|
||||
# check if all living wolfs finish hunting, then get the hunted one
|
||||
if list(self.round_hunts.keys()) == living_werewolf:
|
||||
hunted_all = list(self.round_hunts.values())
|
||||
self.player_hunted = Counter(hunted_all).most_common()[0][0]
|
||||
# if list(self.round_hunts.keys()) == living_werewolf:
|
||||
# hunted_all = list(self.round_hunts.values())
|
||||
# self.player_hunted = Counter(hunted_all).most_common()[0][0]
|
||||
self.player_hunted = player_name
|
||||
|
||||
@mark_as_writeable
|
||||
def witch_poison_someone(self, witch: "Role", player_name: str = None):
|
||||
if not self._check_valid_role(witch, "Witch"):
|
||||
def _witch_poison_or_save_someone(
|
||||
self, witch_name: str, player_name: str = None, state: RoleState = RoleState.POISONED
|
||||
):
|
||||
if not self._check_valid_role(witch_name, RoleType.WITCH.value):
|
||||
return
|
||||
if not self._check_player_continue(player_name):
|
||||
return
|
||||
|
||||
self._update_players_state([player_name], RoleState.POISONED)
|
||||
self.player_poisoned = player_name
|
||||
assert state in [RoleState.POISONED, RoleState.SAVED]
|
||||
self._update_players_state([player_name], state)
|
||||
if state == RoleState.POISONED:
|
||||
self.player_poisoned = player_name
|
||||
self.witch_poison_left -= 1
|
||||
else:
|
||||
# self.player_protected = player_name
|
||||
self.is_hunted_player_saved = True
|
||||
self.witch_antidote_left -= 1
|
||||
|
||||
@mark_as_writeable
|
||||
def witch_save_someone(self, witch: "Role", player_name: str = None):
|
||||
if not self._check_valid_role(witch, "Witch"):
|
||||
def witch_poison_someone(self, witch_name: str, player_name: str = None):
|
||||
self._witch_poison_or_save_someone(witch_name, player_name, RoleState.POISONED)
|
||||
|
||||
@mark_as_writeable
|
||||
def witch_save_someone(self, witch_name: str, player_name: str = None):
|
||||
self._witch_poison_or_save_someone(witch_name, player_name, RoleState.SAVED)
|
||||
|
||||
@mark_as_writeable
|
||||
def guard_protect_someone(self, guard_name: str, player_name: str = None):
|
||||
if not self._check_valid_role(guard_name, RoleType.GUARD.value):
|
||||
return
|
||||
if not self._check_player_continue(player_name):
|
||||
return
|
||||
|
||||
self._update_players_state([player_name], RoleState.SAVED)
|
||||
self.player_protected = player_name
|
||||
|
||||
@mark_as_writeable
|
||||
def update_game_states(self, memories: list):
|
||||
def update_game_states(self):
|
||||
step_idx = self.step_idx % self.per_round_steps
|
||||
if step_idx not in [15, 18] or self.step_idx in self.eval_step_idx:
|
||||
return
|
||||
|
|
@ -314,22 +323,12 @@ class WerewolfExtEnv(ExtEnv):
|
|||
if self.player_poisoned:
|
||||
self.player_current_dead.append(self.player_poisoned)
|
||||
|
||||
self._update_players_state([self.player_current_dead])
|
||||
self._update_players_state(self.player_current_dead)
|
||||
# reset
|
||||
self.player_hunted = None
|
||||
self.player_protected = None
|
||||
self.is_hunted_player_saved = False
|
||||
self.player_poisoned = None
|
||||
|
||||
# game's termination condition
|
||||
living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
|
||||
living_villagers = [p for p in self.villager_players if p in self.living_players]
|
||||
living_special_roles = [p for p in self.special_role_players if p in self.living_players]
|
||||
if not living_werewolf:
|
||||
self.winner = "good guys"
|
||||
self.win_reason = "werewolves all dead"
|
||||
elif not living_villagers or not living_special_roles:
|
||||
self.winner = "werewolf"
|
||||
self.win_reason = "villagers all dead" if not living_villagers else "special roles all dead"
|
||||
if self.winner is not None:
|
||||
self._record_all_experiences() # TODO
|
||||
elif step_idx == 18:
|
||||
# updated use vote_kill_someone
|
||||
pass
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : MG Werewolf Env
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.werewolf_env.werewolf_ext_env import WerewolfExtEnv
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class WerewolfEnv(Environment, WerewolfExtEnv):
|
||||
timestamp: int = Field(default=0)
|
||||
|
||||
def publish_message(self, message: Message, add_timestamp: bool = True):
|
||||
"""Post information to the current environment"""
|
||||
logger.debug(f"publish_message: {message.dump()}")
|
||||
if add_timestamp:
|
||||
# Because the content of the message may be repeated, for example, killing the same person in two nights
|
||||
# Therefore, a unique timestamp prefix needs to be added so that the same message will not be automatically deduplicated when added to the memory.
|
||||
message.content = f"{self.timestamp} | " + message.content
|
||||
self.memory.add(message)
|
||||
self.history += f"\n{message}"
|
||||
|
||||
async def run(self, k=1):
|
||||
"""Process all Role runs by order"""
|
||||
for _ in range(k):
|
||||
for role in self.roles.values():
|
||||
await role.run()
|
||||
self.timestamp += 1
|
||||
3
metagpt/ext/__init__.py
Normal file
3
metagpt/ext/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
118
metagpt/ext/android_assistant/README.md
Normal file
118
metagpt/ext/android_assistant/README.md
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
# MetaGPT Android Assistant
|
||||
|
||||
The MetaGPT Android Assistant is an intelligent assistance tool driven by a multi-modal large language model based on the advanced MetaGPT framework. It has the ability to self-learn, mastering users' daily usage patterns through learning, and can automatically complete various application operations according to user instructions, achieving comprehensive liberation of users' hands.
|
||||
Next, we will introduce the functions of the MetaGPT Android Assistant and how to use it.
|
||||
|
||||
## Features
|
||||
|
||||
The operation of the MetaGPT Android Assistant mainly includes two stages: learning and automatic execution. Below, we introduce the specific features of the MetaGPT Android Assistant from these two stages.
|
||||
|
||||
### Learning Stage
|
||||
|
||||
By learning from human demonstrations or exploring apps based on human instructions, the MetaGPT Android Assistant can learn the functionality of apps, generate corresponding operation documents for use in the subsequent "automatic execution" stage. Approximately 20 rounds of exploration for any given task objective can significantly improve performance.
|
||||
|
||||
By setting the `stage` to `learn`, you can ask the Android Assistant to enter the learning stage. By setting the `mode` to `auto`, you can instruct the Android Assistant to learn through automatic exploration; by setting the mode to manual, you can instruct the Android Assistant to learn through human manual demonstration. In the usage section, we provide detailed explanations of the script parameters. You can try experimenting with automatic exploration and manual demonstration modes on the "Messenger" app with the following commands:
|
||||
|
||||
```bash
|
||||
cd examples/android_assistant
|
||||
python run_assistant.py "Send 'When will we release this feature?' to +86 8888888" --stage "learn" --mode "auto or manual" --app-name "Messenger"
|
||||
```
|
||||
|
||||
#### Learning Based on Human Demonstration
|
||||
When asking the Android Assistant to perform self-exploration during the learning stage, you can free your hands. However, when instructing it to learn according to your commands, you need to follow the instructions in the terminal for the Android Assistant to accurately learn your operation methods.
|
||||
A possible example is as follows:
|
||||
|
||||
```bash
|
||||
cd examples/android_assistant
|
||||
python run_assistant.py "Send 'When will we release this feature?' to +86 8888888" --stage "learn" --mode "manual" --app-name "Messenger"
|
||||
```
|
||||
|
||||
After running this command, you will first see a screenshot of an Android screen that has been marked at various interactive locations, as shown in the figure below:
|
||||
|
||||
<img src="./resources/manual_example.png" width = 30%>
|
||||
|
||||
After remembering the location where you want to operate, a request similar to the one below will be output in the terminal. Reply to it and thereby direct the Android assistant to learn your demonstration action:
|
||||
|
||||
```bash
|
||||
| INFO | examples.android_assistant.actions.manual_record:run:96 - Which element do you want to tap? Choose a numeric tag from 1 to 11:
|
||||
user_input: 8
|
||||
| INFO | examples.android_assistant.actions.manual_record:run:81 - Choose one of the following actions you want to perform on the current screen:
|
||||
tap, text, long_press, swipe, stop
|
||||
user_input: tap
|
||||
```
|
||||
|
||||
### Automatic Execution Stage
|
||||
After the Android Assistant completes the learning stage, you can command it to complete tasks on the phone through text descriptions. By configuring the operation documents from the self-learning stage, the Android Assistant has richer prior knowledge, and its execution capabilities are further enhanced.
|
||||
You can instruct the Android Assistant to send messages in the "Messenger" app with the following command:
|
||||
```bash
|
||||
python run_assistant.py "Send 'When will we release this feature?' to +86 8888888" --stage "act" --mode "auto or manual" --app-name "Messenger"
|
||||
```
|
||||
Specifically, by selecting `auto` for `mode`, the Android assistant will employ the operational records compiled through self-exploration. Alternatively, if `manual` is chosen as the `mode`, the Android assistant will leverage the operation manuals accrued from learning via human demonstration.
|
||||
|
||||
## Installation
|
||||
To use the Android Assistant, you first need to meet the following conditions:
|
||||
1. Complete the installation of the MetaGPT environment.
|
||||
2. Install [Android Debug Bridge (ADB)](https://developer.android.com/tools/adb?hl=zh-cn) on your PC, which enables interaction between your PC and Android devices.
|
||||
3. Install Android Studio and within it, install the Android emulator to provide an environment for the Android Assistant to learn and execute. For information on how to install the Android emulator, refer to [Quick Installation of Android Studio & Emulator](https://docs.expo.dev/workflow/android-studio-emulator/).
|
||||
4. (Optional) Connect your Android device to the USB port of your PC, which can also provide an environment for the Android Assistant to learn and execute.
|
||||
|
||||
Note ⚠️: When operating with the Android emulator, the emulator model we use is Medium Phone, which is recommended for first-time users to complete the operation.
|
||||
|
||||
After completing these operations, you can enter the following command to check if ADB is installed successfully and if the Android device is connected:
|
||||
```bash
|
||||
adb devices
|
||||
```
|
||||
|
||||
## Usage
|
||||
The MetaGPT Android Assistant is designed within the MetaGPT framework as a collection of Roles and multiple Actions. You can run it by executing the `run_assistant.py` script. The specific parameter description of this script is as follows:
|
||||
```text
|
||||
Usage: run_assistant.py [OPTIONS] TASK_DESC
|
||||
|
||||
Run a Android Assistant
|
||||
|
||||
Arguments:
|
||||
TASK_DESC the task description you want the android assistant to learn or
|
||||
act [required]
|
||||
|
||||
Options:
|
||||
--n-round INTEGER The max round to do an app operation task.
|
||||
[default: 20]
|
||||
--stage TEXT stage: learn / act [default: learn]
|
||||
--mode TEXT mode: auto / manual , when state=learn
|
||||
[default: auto]
|
||||
--app-name TEXT the name of app you want to run [default:
|
||||
demo]
|
||||
--investment FLOAT Dollar amount to invest in the AI company.
|
||||
[default: 5.0]
|
||||
--refine-doc / --no-refine-doc Refine existing operation docs based on the
|
||||
latest observation if True. [default: no-
|
||||
refine-doc]
|
||||
--min-dist INTEGER The minimum distance between elements to
|
||||
prevent overlapping during the labeling
|
||||
process. [default: 30]
|
||||
--android-screenshot-dir TEXT The path to store screenshots on android
|
||||
device. Make sure it exists. [default:
|
||||
/sdcard/Pictures/Screenshots]
|
||||
--android-xml-dir TEXT The path to store xml files for determining
|
||||
UI elements localtion. Make sure it exists.
|
||||
[default: /sdcard]
|
||||
--device-id TEXT The Android device_id [default:
|
||||
emulator-5554]
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
## Acknowledgements
|
||||
The MetaGPT Android Assistant has referenced some ideas and code from the [AppAgent](https://github.com/mnotgod96/AppAgent) project. We thank the developers of the Appagent project.
|
||||
|
||||
### Citation
|
||||
|
||||
```bib
|
||||
@misc{yang2023appagent,
|
||||
title={AppAgent: Multimodal Agents as Smartphone Users},
|
||||
author={Chi Zhang and Zhao Yang and Jiaxuan Liu and Yucheng Han and Xin Chen and Zebiao Huang and Bin Fu and Gang Yu},
|
||||
year={2023},
|
||||
eprint={2312.13771},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
||||
113
metagpt/ext/android_assistant/README_CN.md
Normal file
113
metagpt/ext/android_assistant/README_CN.md
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
# MetaGPT 安卓助理
|
||||
|
||||
MetaGPT安卓助理是一款依托于先进的MetaGPT框架构建的多模态大语言模型驱动的智能辅助工具。
|
||||
它具备自我学习的能力,能够通过学习掌握用户的日常使用方式,同时能够根据用户的指令自动完成各类应用程序的操作任务,实现了用户双手的全面解放。
|
||||
接下来,我们将介绍MetaGPT安卓助理的功能以及如何使用它。
|
||||
|
||||
## 功能
|
||||
|
||||
MetaGPT 安卓助理的执行主要包含两个阶段,分别为自我学习与自动执行。下面,我们将从这两个阶段介绍MetaGPT 安卓助理的具体功能。
|
||||
|
||||
### 自我学习阶段
|
||||
|
||||
通过学习人类演示或基于人类指令对app进行探索,MetaGPT安卓助理可以对app的功能进行学习,生成相应的操作文档,为后续的“自动执行”阶段使用。对于任何给定的任务目标,进行约20轮的探索可以显著提高性能。
|
||||
|
||||
通过设定`stage`为`learn`可要求安卓助理进入自我学习阶段。通过设定`mode`为`auto`,可要求安卓助理通过自动探索学习,通过设定`mode`为`manual`,可要求安卓助理通过人类手动演示学习。在使用章节,我们对脚本的参数进行了详细的说明。
|
||||
您可以尝试对“Messenger”应用程序进行自动探索和手动演示模式的实验,具体命令如下:
|
||||
|
||||
```bash
|
||||
cd examples/android_assistant
|
||||
python run_assistant.py "Send 'When will we release this feature? to +86 8888888'" --stage "learn" --mode "auto or manual" --app-name "Messenger"
|
||||
```
|
||||
|
||||
#### 基于人类演示的学习
|
||||
在要求安卓助理在自我学习阶段执行自我探索时,您可以解放您的双手,但在要求他根据您的指令进行学习时,你需要根据终端中的指令进行输入,以便安卓助理能够准确地学习您的操作方式。
|
||||
一个可能的例子如下:
|
||||
|
||||
```bash
|
||||
cd examples/android_assistant
|
||||
python run_assistant.py "Send 'When will we release this feature? to +86 8888888'" --stage "learn" --mode "manual" --app-name "Messenger"
|
||||
```
|
||||
|
||||
在运行这一指令后,你将首先看到一个在各个可交互的位置进行了标记的安卓屏幕的截图,如下图:
|
||||
|
||||
<img src="./resources/manual_example.png" width = 30%>
|
||||
|
||||
在记住你要操作的位置之后,终端中将会输出与下面类似的要求,回复它,进而指挥安卓助理学习你的演示行为:
|
||||
|
||||
```bash
|
||||
| INFO | examples.android_assistant.actions.manual_record:run:96 - Which element do you want to tap? Choose a numeric tag from 1 to 11:
|
||||
user_input: 8
|
||||
| INFO | examples.android_assistant.actions.manual_record:run:81 - Choose one of the following actions you want to perform on the current screen:
|
||||
tap, text, long_press, swipe, stop
|
||||
user_input: tap
|
||||
```
|
||||
### 自动执行阶段
|
||||
在安卓助理完成了自我学习阶段之后,您可以通过文本描述的方式,指挥安卓助理在手机中完成任务。通过为其配置自我学习阶段的操作文档,安卓助理具备了更丰富的前置知识,执行能力进一步得到提升。
|
||||
你可以通过以下指令,指挥安卓助理在“Messenger”应用中发送信息:
|
||||
```bash
|
||||
python run_assistant.py "Send 'When will we release this feature? to +86 8888888'" --stage "act" --mode "auto or manual" --app-name "Messenger"
|
||||
```
|
||||
其中,`mode`选择`auto`,安卓助理将使用自我探索中积累的操作文档;`mode`选择`manual`,安卓助理将使用人类演示学习中积累的操作文档。
|
||||
|
||||
## 安装
|
||||
为了使用安卓助理,你首先需要满足以下条件:
|
||||
1. 完成MetaGPT环境的安装
|
||||
2. 在你的PC上安装[Android Debug Bridge(ADB)](https://developer.android.com/tools/adb?hl=zh-cn),ADB可以使你的PC与安卓设备进行交互。
|
||||
3. 安装Android Studio,在其中安装Android模拟器,以为安卓助手提供学习与执行的环境。关于如何安装Android模拟器,可以参考[快速安装Android Studio & Emulator](https://dev.weixin.qq.com/docs/framework/dev/framework/env/android-simulator.html)。
|
||||
4. (Optional) 将你的安卓设备连接到PC的USB端口上,这同样可以为安卓助手提供学习与执行的环境。
|
||||
|
||||
注意 ⚠️:在使用Android模拟器进行操作时,我们使用的模拟器型号为Medium Phone,建议第一次尝试此类应用的用户使用这一型号完成操作。
|
||||
|
||||
在完成这一系列操作之后,你可以输入以下命令检查ADB是否安装成功,以及安卓设备是否连接
|
||||
```bash
|
||||
adb devices
|
||||
```
|
||||
## 使用
|
||||
MetaGPT 安卓助理在MetaGPT框架中被设计为一个`Role`与多个`Action`的集合,你可以通过运行`run_assistant.py`脚本来运行它。这一脚本具体的参数说明如下:
|
||||
```text
|
||||
用法:run_assistant.py [选项] 任务描述
|
||||
|
||||
运行一个安卓助手
|
||||
|
||||
参数:
|
||||
TASK_DESC 你希望安卓助手学习或执行的任务描述
|
||||
[必需]
|
||||
|
||||
选项:
|
||||
--n-round 整数 执行应用程序操作任务的最大轮数。
|
||||
[默认值:20]
|
||||
--stage 文本 阶段:learn/act [默认值:learn]
|
||||
--mode 文本 模式:auto/manual,当状态=learn时 [默认值:auto]
|
||||
--app-name 文本 你想要运行的应用程序名称 [默认值:
|
||||
演示]
|
||||
--investment 浮点数 投资于人工智能公司的美元金额。
|
||||
[默认值:5.0]
|
||||
--refine-doc / --no-refine-doc 如果为真,则根据最新的观察结果优化现有操作文档。
|
||||
[默认值:--no-refine-doc]
|
||||
--min-dist 整数 在标记过程中防止元素重叠的最小元素间距。
|
||||
[默认值:30]
|
||||
--android-screenshot-dir 文本 在安卓设备上存储截图的路径。确保其存在。
|
||||
[默认值:/sdcard/Pictures/Screenshots]
|
||||
--android-xml-dir 文本 存储用于确定UI元素位置的XML文件的路径。
|
||||
确保其存在。[默认值:/sdcard]
|
||||
--device-id 文本 安卓device_id [默认值:
|
||||
模拟器-5554]
|
||||
--help 显示此信息并退出。
|
||||
```
|
||||
|
||||
## 致谢
|
||||
MetaGPT 安卓助理参考了 [AppAgent](https://github.com/mnotgod96/AppAgent) 项目的部分思路与代码,感谢 Appagent 项目的开发者们。
|
||||
|
||||
### 引用
|
||||
|
||||
```bib
|
||||
@misc{yang2023appagent,
|
||||
title={AppAgent: Multimodal Agents as Smartphone Users},
|
||||
author={Chi Zhang and Zhao Yang and Jiaxuan Liu and Yucheng Han and Xin Chen and Zebiao Huang and Bin Fu and Gang Yu},
|
||||
year={2023},
|
||||
eprint={2312.13771},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
||||
3
metagpt/ext/android_assistant/__init__.py
Normal file
3
metagpt/ext/android_assistant/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
metagpt/ext/android_assistant/actions/__init__.py
Normal file
3
metagpt/ext/android_assistant/actions/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
168
metagpt/ext/android_assistant/actions/manual_record.py
Normal file
168
metagpt/ext/android_assistant/actions/manual_record.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : manual record user interaction in stage=learn & mode=manual, LIKE scripts/step_recorder.py
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.environment.android.android_env import AndroidEnv
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
)
|
||||
from metagpt.ext.android_assistant.utils.schema import (
|
||||
ActionOp,
|
||||
AndroidActionOutput,
|
||||
RunState,
|
||||
SwipeOp,
|
||||
)
|
||||
from metagpt.ext.android_assistant.utils.utils import (
|
||||
draw_bbox_multi,
|
||||
elem_list_from_xml_tree,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class ManualRecord(Action):
|
||||
"""do a human operation on the screen with human input"""
|
||||
|
||||
name: str = "ManualRecord"
|
||||
|
||||
useless_list: list[str] = [] # store useless elements uid
|
||||
record_path: Path = ""
|
||||
task_desc_path: Path = ""
|
||||
screenshot_before_path: Path = ""
|
||||
screenshot_after_path: Path = ""
|
||||
xml_path: Path = ""
|
||||
|
||||
async def run(self, task_desc: str, task_dir: Path, env: AndroidEnv):
|
||||
self.record_path = Path(task_dir) / "record.txt"
|
||||
self.task_desc_path = Path(task_dir) / "task_desc.txt"
|
||||
self.screenshot_before_path = Path(task_dir) / "raw_screenshots"
|
||||
self.screenshot_after_path = Path(task_dir) / "labeled_screenshots"
|
||||
self.xml_path = Path(task_dir) / "xml"
|
||||
for path in [self.screenshot_before_path, self.screenshot_after_path, self.xml_path]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.record_path.write_text("")
|
||||
record_file = open(self.record_path, "w")
|
||||
self.task_desc_path.write_text(task_desc)
|
||||
|
||||
step = 0
|
||||
extra_config = config.extra
|
||||
while True:
|
||||
step += 1
|
||||
screenshot_path: Path = env.observe(
|
||||
EnvObsParams(
|
||||
obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{step}", local_save_dir=self.screenshot_before_path
|
||||
)
|
||||
)
|
||||
xml_path: Path = env.observe(
|
||||
EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{step}", local_save_dir=self.xml_path)
|
||||
)
|
||||
if not screenshot_path.exists() or not xml_path.exists():
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
elem_list = elem_list_from_xml_tree(xml_path, self.useless_list, extra_config.get("min_dist", 30))
|
||||
|
||||
screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{step}_labeled.png")
|
||||
labeled_img = draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list)
|
||||
|
||||
cv2.namedWindow("image", cv2.WINDOW_NORMAL)
|
||||
cv2.imshow("image", labeled_img)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
user_input = "xxx"
|
||||
logger.info(
|
||||
"Choose one of the following actions you want to perform on the current screen:\n"
|
||||
"tap, text, long_press, swipe, stop"
|
||||
)
|
||||
|
||||
while (
|
||||
user_input.lower() != ActionOp.TAP.value
|
||||
and user_input.lower() != ActionOp.TEXT.value
|
||||
and user_input.lower() != ActionOp.LONG_PRESS.value
|
||||
and user_input.lower() != ActionOp.SWIPE.value
|
||||
and user_input.lower() != ActionOp.STOP.value
|
||||
):
|
||||
user_input = input("user_input: ")
|
||||
|
||||
if user_input.lower() == ActionOp.TAP.value:
|
||||
logger.info(f"Which element do you want to tap? Choose a numeric tag from 1 to {len(elem_list)}:")
|
||||
user_input = "xxx"
|
||||
while not user_input.isnumeric() or int(user_input) > len(elem_list) or int(user_input) < 1:
|
||||
user_input = input("user_input: ")
|
||||
tl, br = elem_list[int(user_input) - 1].bbox
|
||||
x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2
|
||||
action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y))
|
||||
log_str = f"tap({int(user_input)}):::{elem_list[int(user_input) - 1].uid}\n"
|
||||
elif user_input.lower() == ActionOp.TEXT.value:
|
||||
logger.info(
|
||||
f"Which element do you want to input the text string? Choose a numeric tag from 1 to "
|
||||
f"{len(elem_list)}:"
|
||||
)
|
||||
input_area = "xxx"
|
||||
while not input_area.isnumeric() or int(input_area) > len(elem_list) or int(input_area) < 1:
|
||||
input_area = input("user_input: ")
|
||||
logger.info("Enter your input text below:")
|
||||
user_input = ""
|
||||
while not user_input:
|
||||
user_input = input("user_input: ")
|
||||
action = EnvAction(action_type=EnvActionType.USER_INPUT, input_txt=user_input)
|
||||
log_str = f"text({input_area}:sep:'{user_input}'):::{elem_list[int(input_area) - 1].uid}\n"
|
||||
elif user_input.lower() == ActionOp.LONG_PRESS.value:
|
||||
logger.info(
|
||||
f"Which element do you want to long press? Choose a numeric tag from 1 to {len(elem_list)}:"
|
||||
)
|
||||
user_input = "xxx"
|
||||
while not user_input.isnumeric() or int(user_input) > len(elem_list) or int(user_input) < 1:
|
||||
user_input = input("user_input: ")
|
||||
tl, br = elem_list[int(user_input) - 1].bbox
|
||||
x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2
|
||||
action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y))
|
||||
log_str = f"long_press({int(user_input)}):::{elem_list[int(user_input) - 1].uid}\n"
|
||||
elif user_input.lower() == ActionOp.SWIPE.value:
|
||||
logger.info(
|
||||
"What is the direction of your swipe? Choose one from the following options:\n"
|
||||
"up, down, left, right"
|
||||
)
|
||||
user_input = ""
|
||||
while (
|
||||
user_input != SwipeOp.UP.value
|
||||
and user_input != SwipeOp.DOWN.value
|
||||
and user_input != SwipeOp.LEFT.value
|
||||
and user_input != SwipeOp.RIGHT.value
|
||||
):
|
||||
user_input = input("user_input: ")
|
||||
swipe_dir = user_input
|
||||
logger.info(f"Which element do you want to swipe? Choose a numeric tag from 1 to {len(elem_list)}:")
|
||||
while not user_input.isnumeric() or int(user_input) > len(elem_list) or int(user_input) < 1:
|
||||
user_input = input("user_input: ")
|
||||
tl, br = elem_list[int(user_input) - 1].bbox
|
||||
x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2
|
||||
|
||||
action = EnvAction(action_type=EnvActionType.USER_SWIPE, coord=(x, y), orient=swipe_dir)
|
||||
log_str = f"swipe({int(user_input)}:sep:{swipe_dir}):::{elem_list[int(user_input) - 1].uid}\n"
|
||||
elif user_input.lower() == ActionOp.STOP.value:
|
||||
record_file.write("stop\n")
|
||||
record_file.close()
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
obs, _, _, _, info = env.step(action)
|
||||
action_res = info["res"]
|
||||
if action_res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
record_file.write(log_str)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
return AndroidActionOutput(action_state=RunState.SUCCESS)
|
||||
137
metagpt/ext/android_assistant/actions/parse_record.py
Normal file
137
metagpt/ext/android_assistant/actions/parse_record.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : parse record to generate learned standard operations in stage=learn & mode=manual,
|
||||
# LIKE scripts/document_generation.py
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.ext.android_assistant.actions.parse_record_an import RECORD_PARSE_NODE
|
||||
from metagpt.ext.android_assistant.prompts.operation_prompt import (
|
||||
long_press_doc_template,
|
||||
refine_doc_suffix,
|
||||
swipe_doc_template,
|
||||
tap_doc_template,
|
||||
text_doc_template,
|
||||
)
|
||||
from metagpt.ext.android_assistant.utils.schema import (
|
||||
ActionOp,
|
||||
AndroidActionOutput,
|
||||
RecordLogItem,
|
||||
RunState,
|
||||
SwipeOp,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import encode_image
|
||||
|
||||
|
||||
class ParseRecord(Action):
|
||||
name: str = "ParseRecord"
|
||||
record_path: Path = ""
|
||||
task_desc_path: Path = ""
|
||||
screenshot_before_path: Path = ""
|
||||
screenshot_after_path: Path = ""
|
||||
|
||||
async def run(self, task_dir: Path, docs_dir: Path):
|
||||
doc_count = 0
|
||||
self.record_path = Path(task_dir) / "record.txt"
|
||||
self.task_desc_path = Path(task_dir) / "task_desc.txt"
|
||||
self.screenshot_before_path = Path(task_dir) / "raw_screenshots"
|
||||
self.screenshot_after_path = Path(task_dir) / "labeled_screenshots"
|
||||
for path in [self.screenshot_before_path, self.screenshot_after_path]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
task_desc = self.task_desc_path.read_text()
|
||||
extra_config = config.extra
|
||||
|
||||
with open(self.record_path, "r") as record_file:
|
||||
record_step_count = len(record_file.readlines()) - 1
|
||||
record_file.seek(0)
|
||||
for step in range(1, record_step_count + 1):
|
||||
img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step}_labeled.png"))
|
||||
img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step + 1}_labeled.png"))
|
||||
rec = record_file.readline().strip()
|
||||
action, resource_id = rec.split(":::")
|
||||
action_type = action.split("(")[0]
|
||||
# 构建Prompt
|
||||
action_param = re.findall(r"\((.*?)\)", action)[0]
|
||||
if action_type == ActionOp.TAP.value:
|
||||
prompt_template = tap_doc_template
|
||||
context = prompt_template.format(ui_element=action_param)
|
||||
elif action_type == ActionOp.TEXT.value:
|
||||
input_area, input_text = action_param.split(":sep:")
|
||||
prompt_template = text_doc_template
|
||||
context = prompt_template.format(ui_element=input_area)
|
||||
elif action_type == ActionOp.LONG_PRESS.value:
|
||||
prompt_template = long_press_doc_template
|
||||
context = prompt_template.format(ui_element=action_param)
|
||||
elif action_type == ActionOp.SWIPE.value:
|
||||
swipe_area, swipe_dir = action_param.split(":sep:")
|
||||
if swipe_dir == SwipeOp.UP.value or swipe_dir == SwipeOp.DOWN.value:
|
||||
action_type = ActionOp.VERTICAL_SWIPE.value
|
||||
elif swipe_dir == SwipeOp.LEFT.value or swipe_dir == SwipeOp.RIGHT.value:
|
||||
action_type = ActionOp.HORIZONTAL_SWIPE.value
|
||||
prompt_template = swipe_doc_template
|
||||
context = prompt_template.format(swipe_dir=swipe_dir, ui_element=swipe_area)
|
||||
else:
|
||||
break
|
||||
context = context.format(task_desc=task_desc)
|
||||
|
||||
doc_name = resource_id + ".txt"
|
||||
doc_path = docs_dir.joinpath(doc_name)
|
||||
|
||||
if doc_path.exists():
|
||||
try:
|
||||
doc_content = ast.literal_eval(doc_path.read_text())
|
||||
except Exception as exp:
|
||||
logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}")
|
||||
continue
|
||||
|
||||
if doc_content[action_type]:
|
||||
if extra_config.get("doc_refine", False):
|
||||
refine_context = refine_doc_suffix.format(old_doc=doc_content[action_type])
|
||||
context += refine_context
|
||||
logger.info(
|
||||
f"Documentation for the element {resource_id} already exists. The doc will be "
|
||||
f"refined based on the latest demo."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Documentation for the element {resource_id} already exists. Turn on DOC_REFINE "
|
||||
f"in the config file if needed."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
doc_content = {"tap": "", "text": "", "v_swipe": "", "h_swipe": "", "long_press": ""}
|
||||
|
||||
logger.info(f"Waiting for GPT-4V to generate documentation for the element {resource_id}")
|
||||
node = await RECORD_PARSE_NODE.fill(
|
||||
context=context, llm=self.llm, images=[img_before_base64, img_after_base64]
|
||||
)
|
||||
if "error" in node.content:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
log_path = task_dir.joinpath("log_parse_record.txt")
|
||||
prompt = node.compile(context=context, schema="json", mode="auto")
|
||||
msg = node.content
|
||||
doc_content[action_type] = msg
|
||||
|
||||
with open(log_path, "a") as logfile:
|
||||
log_item = RecordLogItem(
|
||||
step=step,
|
||||
prompt=prompt,
|
||||
image_before=img_before_base64,
|
||||
image_after=img_after_base64,
|
||||
response=node.content,
|
||||
)
|
||||
logfile.write(log_item.model_dump_json() + "\n")
|
||||
with open(doc_path, "w") as outfile:
|
||||
outfile.write(str(doc_content))
|
||||
doc_count += 1
|
||||
logger.info(f"Documentation generated and saved to {doc_path}")
|
||||
|
||||
logger.info(f"Documentation generation phase completed. {doc_count} docs generated.")
|
||||
|
||||
return AndroidActionOutput(action_state=RunState.FINISH)
|
||||
32
metagpt/ext/android_assistant/actions/parse_record_an.py
Normal file
32
metagpt/ext/android_assistant/actions/parse_record_an.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the ActionNode to parse record
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
OBSERVATION = ActionNode(
|
||||
key="Observation",
|
||||
expected_type=str,
|
||||
instruction="Provide a description of your observations of the two images. "
|
||||
"Subsequently, delineate the distinctions between the first image and the second one.",
|
||||
example="",
|
||||
)
|
||||
|
||||
THOUGHT = ActionNode(
|
||||
key="Thought",
|
||||
expected_type=str,
|
||||
instruction="Consider the impact of Action acting on UI elements.",
|
||||
example="",
|
||||
)
|
||||
|
||||
DESCRIPTION = ActionNode(
|
||||
key="Description",
|
||||
expected_type=str,
|
||||
instruction="Describe the functionality of the UI element concisely in one or two sentences Do not include "
|
||||
"the numeric tag in your description",
|
||||
example="",
|
||||
)
|
||||
|
||||
NODES = [OBSERVATION, THOUGHT, DESCRIPTION]
|
||||
|
||||
RECORD_PARSE_NODE = ActionNode.from_children("RecordParse", NODES)
|
||||
204
metagpt/ext/android_assistant/actions/screenshot_parse.py
Normal file
204
metagpt/ext/android_assistant/actions/screenshot_parse.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : LIKE scripts/task_executor.py in stage=act
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.environment.android.android_env import AndroidEnv
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
)
|
||||
from metagpt.ext.android_assistant.actions.screenshot_parse_an import (
|
||||
SCREENSHOT_PARSE_NODE,
|
||||
)
|
||||
from metagpt.ext.android_assistant.prompts.assistant_prompt import (
|
||||
screenshot_parse_template,
|
||||
screenshot_parse_with_grid_template,
|
||||
)
|
||||
from metagpt.ext.android_assistant.utils.schema import (
|
||||
AndroidActionOutput,
|
||||
AndroidElement,
|
||||
GridOpParam,
|
||||
LongPressGridOpParam,
|
||||
LongPressOpParam,
|
||||
OpLogItem,
|
||||
RunState,
|
||||
SwipeGridOpParam,
|
||||
SwipeOpParam,
|
||||
TapGridOpParam,
|
||||
TapOpParam,
|
||||
TextOpParam,
|
||||
)
|
||||
from metagpt.ext.android_assistant.utils.utils import (
|
||||
area_to_xy,
|
||||
draw_bbox_multi,
|
||||
draw_grid,
|
||||
elem_bbox_to_xy,
|
||||
screenshot_parse_extract,
|
||||
traverse_xml_tree,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import encode_image
|
||||
|
||||
|
||||
class ScreenshotParse(Action):
|
||||
name: str = "ScreenshotParse"
|
||||
|
||||
def _makeup_ui_document(self, elem_list: list[AndroidElement], docs_idr: Path, use_exist_doc: bool = True) -> str:
|
||||
if not use_exist_doc:
|
||||
return ""
|
||||
|
||||
ui_doc = """
|
||||
You also have access to the following documentations that describes the functionalities of UI
|
||||
elements you can interact on the screen. These docs are crucial for you to determine the target of your
|
||||
next action. You should always prioritize these documented elements for interaction: """
|
||||
for i, elem in enumerate(elem_list):
|
||||
doc_path = docs_idr.joinpath(f"{elem.uid}.txt")
|
||||
if not doc_path.exists():
|
||||
continue
|
||||
try:
|
||||
doc_content = ast.literal_eval(doc_path.read_text())
|
||||
except Exception as exp:
|
||||
logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}")
|
||||
continue
|
||||
|
||||
ui_doc += f"Documentation of UI element labeled with the numeric tag '{i + 1}':\n"
|
||||
if doc_content["tap"]:
|
||||
ui_doc += f"This UI element is clickable. {doc_content['tap']}\n\n"
|
||||
if doc_content["text"]:
|
||||
ui_doc += (
|
||||
f"This UI element can receive text input. The text input is used for the following "
|
||||
f"purposes: {doc_content['text']}\n\n"
|
||||
)
|
||||
if doc_content["long_press"]:
|
||||
ui_doc += f"This UI element is long clickable. {doc_content['long_press']}\n\n"
|
||||
if doc_content["v_swipe"]:
|
||||
ui_doc += (
|
||||
f"This element can be swiped directly without tapping. You can swipe vertically on "
|
||||
f"this UI element. {doc_content['v_swipe']}\n\n"
|
||||
)
|
||||
if doc_content["h_swipe"]:
|
||||
ui_doc += (
|
||||
f"This element can be swiped directly without tapping. You can swipe horizontally on "
|
||||
f"this UI element. {doc_content['h_swipe']}\n\n"
|
||||
)
|
||||
return ui_doc
|
||||
|
||||
async def run(
|
||||
self,
|
||||
round_count: int,
|
||||
task_desc: str,
|
||||
last_act: str,
|
||||
task_dir: Path,
|
||||
docs_dir: Path,
|
||||
grid_on: bool,
|
||||
env: AndroidEnv,
|
||||
):
|
||||
extra_config = config.extra
|
||||
for path in [task_dir, docs_dir]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
screenshot_path: Path = env.observe(
|
||||
EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_before", local_save_dir=task_dir)
|
||||
)
|
||||
xml_path: Path = env.observe(
|
||||
EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{round_count}", local_save_dir=task_dir)
|
||||
)
|
||||
if not screenshot_path.exists() or not xml_path.exists():
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
clickable_list = []
|
||||
focusable_list = []
|
||||
traverse_xml_tree(xml_path, clickable_list, "clickable", True)
|
||||
traverse_xml_tree(xml_path, focusable_list, "focusable", True)
|
||||
elem_list: list[AndroidElement] = clickable_list.copy()
|
||||
for elem in focusable_list:
|
||||
bbox = elem.bbox
|
||||
center = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
|
||||
close = False
|
||||
for e in clickable_list:
|
||||
bbox = e.bbox
|
||||
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
|
||||
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
|
||||
if dist <= extra_config.get("min_dist", 30):
|
||||
close = True
|
||||
break
|
||||
if not close:
|
||||
elem_list.append(elem)
|
||||
|
||||
screenshot_labeled_path = task_dir.joinpath(f"{round_count}_labeled.png")
|
||||
draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list)
|
||||
img_base64 = encode_image(screenshot_labeled_path)
|
||||
|
||||
parse_template = screenshot_parse_with_grid_template if grid_on else screenshot_parse_template
|
||||
|
||||
if grid_on:
|
||||
env.rows, env.cols = draw_grid(screenshot_path, task_dir / f"{round_count}_grid.png")
|
||||
|
||||
ui_doc = self._makeup_ui_document(elem_list, docs_dir)
|
||||
context = parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act)
|
||||
node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64])
|
||||
|
||||
if "error" in node.content:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
prompt = node.compile(context=context, schema="json", mode="auto")
|
||||
OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_labeled_path), response=node.content)
|
||||
|
||||
op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on)
|
||||
if op_param.param_state == RunState.FINISH:
|
||||
logger.info(f"op_param: {op_param}")
|
||||
return AndroidActionOutput(action_state=RunState.FINISH)
|
||||
if op_param.param_state == RunState.FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
last_act = op_param.last_act
|
||||
if isinstance(op_param, TapOpParam):
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y))
|
||||
elif isinstance(op_param, TextOpParam):
|
||||
action = EnvAction(action_type=EnvActionType.USER_INPUT, input_txt=op_param.input_str)
|
||||
elif isinstance(op_param, LongPressOpParam):
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y))
|
||||
elif isinstance(op_param, SwipeOpParam):
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
action = EnvAction(
|
||||
action_type=EnvActionType.USER_SWIPE, coord=(x, y), orient=op_param.swipe_orient, dist=op_param.dist
|
||||
)
|
||||
elif isinstance(op_param, GridOpParam):
|
||||
grid_on = True
|
||||
elif isinstance(op_param, TapGridOpParam) or isinstance(op_param, LongPressGridOpParam):
|
||||
x, y = area_to_xy(op_param.area, op_param.subarea, env.width, env.height, env.rows, env.cols)
|
||||
if isinstance(op_param, TapGridOpParam):
|
||||
action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y))
|
||||
else:
|
||||
# LongPressGridOpParam
|
||||
action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y))
|
||||
elif isinstance(op_param, SwipeGridOpParam):
|
||||
start_x, start_y = area_to_xy(
|
||||
op_param.start_area, op_param.start_subarea, env.width, env.height, env.rows, env.cols
|
||||
)
|
||||
end_x, end_y = area_to_xy(
|
||||
op_param.end_area, op_param.end_subarea, env.width, env.height, env.rows, env.cols
|
||||
)
|
||||
action = EnvAction(
|
||||
action_type=EnvActionType.USER_SWIPE_TO, coord=(start_x, start_y), tgt_coord=(end_x, end_y)
|
||||
)
|
||||
|
||||
if not grid_on:
|
||||
obs, _, _, _, info = env.step(action)
|
||||
action_res = info["res"]
|
||||
if action_res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
if op_param.act_name != "grid":
|
||||
grid_on = False
|
||||
|
||||
return AndroidActionOutput(data={"grid_on": grid_on, "last_act": last_act})
|
||||
48
metagpt/ext/android_assistant/actions/screenshot_parse_an.py
Normal file
48
metagpt/ext/android_assistant/actions/screenshot_parse_an.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the ActionNode to parse screenshot
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
OBSERVATION = ActionNode(
|
||||
key="Observation", expected_type=str, instruction="Describe what you observe in the image", example=""
|
||||
)
|
||||
|
||||
THOUGHT = ActionNode(
|
||||
key="Thought",
|
||||
expected_type=str,
|
||||
instruction="To complete the given task, what is the next step I should do",
|
||||
example="",
|
||||
)
|
||||
|
||||
ACTION = ActionNode(
|
||||
key="Action",
|
||||
expected_type=str,
|
||||
instruction="The function call with the correct parameters to proceed with the task. If you believe the task is "
|
||||
"completed or there is nothing to be done, you should output FINISH. You cannot output anything else "
|
||||
"except a function call or FINISH in this field.",
|
||||
example="",
|
||||
)
|
||||
|
||||
SUMMARY = ActionNode(
|
||||
key="Summary",
|
||||
expected_type=str,
|
||||
instruction="Summarize your past actions along with your latest action in one or two sentences. Do not include "
|
||||
"the numeric tag in your summary",
|
||||
example="",
|
||||
)
|
||||
|
||||
SUMMARY_GRID = ActionNode(
|
||||
key="Summary",
|
||||
expected_type=str,
|
||||
instruction="Summarize your past actions along with your latest action in one or two sentences. Do not include "
|
||||
"the grid area number in your summary",
|
||||
example="",
|
||||
)
|
||||
|
||||
NODES = [OBSERVATION, THOUGHT, ACTION, SUMMARY]
|
||||
|
||||
NODES_GRID = [OBSERVATION, THOUGHT, ACTION, SUMMARY_GRID]
|
||||
|
||||
SCREENSHOT_PARSE_NODE = ActionNode.from_children("ScreenshotParse", NODES)
|
||||
SCREENSHOT_PARSE_GRID_NODE = ActionNode.from_children("ScreenshotParseGrid", NODES_GRID)
|
||||
231
metagpt/ext/android_assistant/actions/self_learn_and_reflect.py
Normal file
231
metagpt/ext/android_assistant/actions/self_learn_and_reflect.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
# !/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : LIKE scripts/self_explorer.py in stage=learn & mode=auto self_explore_task stage
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.environment.android.android_env import AndroidEnv
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
)
|
||||
from metagpt.ext.android_assistant.actions.screenshot_parse_an import (
|
||||
SCREENSHOT_PARSE_NODE,
|
||||
)
|
||||
from metagpt.ext.android_assistant.actions.self_learn_reflect_an import (
|
||||
SELF_LEARN_REFLECT_NODE,
|
||||
)
|
||||
from metagpt.ext.android_assistant.prompts.assistant_prompt import (
|
||||
screenshot_parse_self_explore_reflect_template as reflect_template,
|
||||
)
|
||||
from metagpt.ext.android_assistant.prompts.assistant_prompt import (
|
||||
screenshot_parse_self_explore_template,
|
||||
)
|
||||
from metagpt.ext.android_assistant.utils.schema import (
|
||||
ActionOp,
|
||||
AndroidActionOutput,
|
||||
AndroidElement,
|
||||
Decision,
|
||||
DocContent,
|
||||
LongPressOpParam,
|
||||
OpLogItem,
|
||||
ReflectLogItem,
|
||||
RunState,
|
||||
SwipeOp,
|
||||
SwipeOpParam,
|
||||
TapOpParam,
|
||||
TextOpParam,
|
||||
)
|
||||
from metagpt.ext.android_assistant.utils.utils import (
|
||||
draw_bbox_multi,
|
||||
elem_bbox_to_xy,
|
||||
elem_list_from_xml_tree,
|
||||
reflect_parse_extarct,
|
||||
screenshot_parse_extract,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import encode_image
|
||||
|
||||
|
||||
class SelfLearnAndReflect(Action):
|
||||
name: str = "SelfLearnAndReflect"
|
||||
|
||||
useless_list: list[str] = [] # store useless elements uid
|
||||
|
||||
screenshot_before_path: str = ""
|
||||
screenshot_before_base64: str = ""
|
||||
elem_list: list[AndroidElement] = []
|
||||
swipe_orient: str = "up"
|
||||
act_name: str = ""
|
||||
ui_area: int = -1
|
||||
|
||||
async def run(
|
||||
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
|
||||
) -> AndroidActionOutput:
|
||||
for path in [task_dir, docs_dir]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
resp = await self.run_self_learn(round_count, task_desc, last_act, task_dir, env)
|
||||
if resp.action_state != RunState.SUCCESS:
|
||||
return resp
|
||||
|
||||
resp = await self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env)
|
||||
return resp
|
||||
|
||||
async def run_self_learn(
|
||||
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv
|
||||
) -> AndroidActionOutput:
|
||||
extra_config = config.extra
|
||||
screenshot_path: Path = env.observe(
|
||||
EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_before", local_save_dir=task_dir)
|
||||
)
|
||||
xml_path: Path = env.observe(
|
||||
EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{round_count}", local_save_dir=task_dir)
|
||||
)
|
||||
if not screenshot_path.exists() or not xml_path.exists():
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
elem_list = elem_list_from_xml_tree(xml_path, self.useless_list, extra_config.get("min_dist", 30))
|
||||
|
||||
screenshot_before_labeled_path = task_dir.joinpath(f"{round_count}_before_labeled.png")
|
||||
draw_bbox_multi(screenshot_path, screenshot_before_labeled_path, elem_list)
|
||||
img_base64 = encode_image(screenshot_before_labeled_path)
|
||||
self.screenshot_before_base64 = img_base64
|
||||
self.screenshot_before_path = screenshot_before_labeled_path
|
||||
|
||||
self_explore_template = screenshot_parse_self_explore_template
|
||||
context = self_explore_template.format(task_description=task_desc, last_act=last_act)
|
||||
|
||||
node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64])
|
||||
logger.debug(f"fill result:{node}")
|
||||
if "error" in node.content:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
prompt = node.compile(context=context, schema="json", mode="auto")
|
||||
# Modify WindowsPath to Str
|
||||
OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_before_labeled_path), response=node.content)
|
||||
op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on=False)
|
||||
# TODO Modify Op_param. When op_param.action is FINISH, how to solve this ?
|
||||
if op_param.param_state == RunState.FINISH:
|
||||
return AndroidActionOutput(action_state=RunState.FINISH)
|
||||
if op_param.param_state == RunState.FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
if isinstance(op_param, TapOpParam):
|
||||
self.ui_area = op_param.area
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y))
|
||||
elif isinstance(op_param, TextOpParam):
|
||||
action = EnvAction(action_type=EnvActionType.USER_INPUT, input_txt=op_param.input_str)
|
||||
elif isinstance(op_param, LongPressOpParam):
|
||||
self.ui_area = op_param.area
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y))
|
||||
elif isinstance(op_param, SwipeOpParam):
|
||||
self.ui_area = op_param.area
|
||||
self.swipe_orient = op_param.swipe_orient
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
action = EnvAction(
|
||||
action_type=EnvActionType.USER_SWIPE, coord=(x, y), orient=op_param.swipe_orient, dist=op_param.dist
|
||||
)
|
||||
|
||||
obs, _, _, _, info = env.step(action)
|
||||
action_res = info["res"]
|
||||
if action_res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
self.elem_list = elem_list
|
||||
self.act_name = op_param.act_name
|
||||
return AndroidActionOutput()
|
||||
|
||||
async def run_reflect(
|
||||
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
|
||||
) -> AndroidActionOutput:
|
||||
screenshot_path: Path = env.observe(
|
||||
EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_after", local_save_dir=task_dir)
|
||||
)
|
||||
if not screenshot_path.exists():
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
screenshot_after_labeled_path = task_dir.joinpath(f"{round_count}_after_labeled.png")
|
||||
draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list)
|
||||
img_base64 = encode_image(screenshot_after_labeled_path)
|
||||
if self.act_name == ActionOp.TAP.value:
|
||||
action = "tapping"
|
||||
elif self.act_name == ActionOp.LONG_PRESS.value:
|
||||
action = "long pressing"
|
||||
elif self.act_name == ActionOp.SWIPE.value:
|
||||
action = "swiping"
|
||||
if self.swipe_orient == SwipeOp.UP.value or self.swipe_orient == SwipeOp.DOWN.value:
|
||||
action = "v_swipe"
|
||||
elif self.swipe_orient == SwipeOp.LEFT.value or self.swipe_orient == SwipeOp.RIGHT.value:
|
||||
action = "h_swipe"
|
||||
else:
|
||||
# TODO Test for assignment, This error is eupiped with the next.
|
||||
logger.warning(f"Current action name parse failed, it's `{self.act_name}`")
|
||||
action = None
|
||||
context = reflect_template.format(
|
||||
action=action, ui_element=str(self.ui_area), task_desc=task_desc, last_act=last_act
|
||||
)
|
||||
node = await SELF_LEARN_REFLECT_NODE.fill(
|
||||
context=context, llm=self.llm, images=[self.screenshot_before_base64, img_base64]
|
||||
)
|
||||
|
||||
if "error" in node.content:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
prompt = node.compile(context=context, schema="json", mode="auto")
|
||||
ReflectLogItem(
|
||||
step=round_count,
|
||||
prompt=prompt,
|
||||
image_before=str(self.screenshot_before_path),
|
||||
image_after=str(screenshot_after_labeled_path),
|
||||
response=node.content,
|
||||
)
|
||||
|
||||
op_param = reflect_parse_extarct(node.instruct_content.model_dump())
|
||||
if op_param.param_state == RunState.FINISH:
|
||||
return AndroidActionOutput(action_state=RunState.FINISH)
|
||||
if op_param.param_state == RunState.FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
logger.info(
|
||||
f"reflect_parse_extarct decision: {op_param.decision}, "
|
||||
f"elem_list size: {len(self.elem_list)}, ui_area: {self.ui_area}"
|
||||
)
|
||||
# TODO here will cause `IndexError: list index out of range`.
|
||||
# Maybe you should clink back to the desktop in the simulator
|
||||
resource_id = self.elem_list[int(self.ui_area) - 1].uid
|
||||
if op_param.decision == Decision.INEFFECTIVE.value:
|
||||
self.useless_list.append(resource_id)
|
||||
last_act = "NONE" # TODO global
|
||||
elif op_param.decision in [Decision.BACK.value, Decision.CONTINUE.value, Decision.SUCCESS.value]:
|
||||
if op_param.decision in [Decision.BACK.value, Decision.CONTINUE.value]:
|
||||
self.useless_list.append(resource_id)
|
||||
last_act = "NONE"
|
||||
if op_param.decision == Decision.BACK.value:
|
||||
action = EnvAction(action_type=EnvActionType.SYSTEM_BACK)
|
||||
obs, _, _, _, info = env.step(action)
|
||||
if info["res"] == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
doc = op_param.documentation
|
||||
doc_path = docs_dir.joinpath(f"{resource_id}.txt")
|
||||
if doc_path.exists():
|
||||
try:
|
||||
doc_content = ast.literal_eval(doc_path.read_text())
|
||||
except Exception as exp:
|
||||
logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}")
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
if doc_content[self.act_name]:
|
||||
logger.info(f"Documentation for the element {resource_id} already exists.")
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
else:
|
||||
doc_content = DocContent()
|
||||
setattr(doc_content, self.act_name, doc)
|
||||
doc_path.write_text(str(doc_content))
|
||||
return AndroidActionOutput(data={"last_act": last_act})
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the ActionNode to parse Reflection
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
DECISION = ActionNode(
|
||||
key="Decision", expected_type=str, instruction="explain why you made this decision", example="BACK"
|
||||
)
|
||||
|
||||
|
||||
THOUGHT = ActionNode(key="Thought", expected_type=str, instruction="explain why you made this decision", example="")
|
||||
|
||||
|
||||
DOCUMENTATION = ActionNode(
|
||||
key="Documentation", expected_type=str, instruction="describe the function of the UI element", example=""
|
||||
)
|
||||
|
||||
|
||||
NODES = [DECISION, THOUGHT, DOCUMENTATION]
|
||||
SELF_LEARN_REFLECT_NODE = ActionNode.from_children("SelfLearnReflect", NODES)
|
||||
3
metagpt/ext/android_assistant/prompts/__init__.py
Normal file
3
metagpt/ext/android_assistant/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
168
metagpt/ext/android_assistant/prompts/assistant_prompt.py
Normal file
168
metagpt/ext/android_assistant/prompts/assistant_prompt.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the prompt templates of assistant learning and acting
|
||||
|
||||
screenshot_parse_template = """You are an agent that is trained to perform some basic tasks on a smartphone. You will be given a
|
||||
smartphone screenshot. The interactive UI elements on the screenshot are labeled with numeric tags starting from 1. The
|
||||
numeric tag of each interactive element is located in the center of the element.
|
||||
|
||||
You can call the following functions to control the smartphone:
|
||||
|
||||
1. tap(element: int)
|
||||
This function is used to tap an UI element shown on the smartphone screen.
|
||||
"element" is a numeric tag assigned to an UI element shown on the smartphone screen.
|
||||
A simple use case can be tap(5), which taps the UI element labeled with the number 5.
|
||||
|
||||
2. text(text_input: str)
|
||||
This function is used to insert text input in an input field/box. text_input is the string you want to insert and must
|
||||
be wrapped with double quotation marks. A simple use case can be text("Hello, world!"), which inserts the string
|
||||
"Hello, world!" into the input area on the smartphone screen. This function is usually callable when you see a keyboard
|
||||
showing in the lower half of the screen.
|
||||
|
||||
3. long_press(element: int)
|
||||
This function is used to long press an UI element shown on the smartphone screen.
|
||||
"element" is a numeric tag assigned to an UI element shown on the smartphone screen.
|
||||
A simple use case can be long_press(5), which long presses the UI element labeled with the number 5.
|
||||
|
||||
4. swipe(element: int, direction: str, dist: str)
|
||||
This function is used to swipe an UI element shown on the smartphone screen, usually a scroll view or a slide bar.
|
||||
"element" is a numeric tag assigned to an UI element shown on the smartphone screen. "direction" is a string that
|
||||
represents one of the four directions: up, down, left, right. "direction" must be wrapped with double quotation
|
||||
marks. "dist" determines the distance of the swipe and can be one of the three options: short, medium, long. You should
|
||||
choose the appropriate distance option according to your need.
|
||||
A simple use case can be swipe(21, "up", "medium"), which swipes up the UI element labeled with the number 21 for a
|
||||
medium distance.
|
||||
|
||||
5. grid()
|
||||
You should call this function when you find the element you want to interact with is not labeled with a numeric tag and
|
||||
other elements with numeric tags cannot help with the task. The function will bring up a grid overlay to divide the
|
||||
smartphone screen into small areas and this will give you more freedom to choose any part of the screen to tap, long
|
||||
press, or swipe.
|
||||
{ui_document}
|
||||
The task you need to complete is to: {task_description}. Your past actions to proceed with this task are summarized as
|
||||
follows: {last_act}
|
||||
Now, given the documentation and the following labeled screenshot, you need to think and call the function needed to
|
||||
proceed with the task. Your output should include three parts in the given format:
|
||||
|
||||
You can only take one action at a time, so please directly call the function."""
|
||||
|
||||
screenshot_parse_with_grid_template = """You are an agent that is trained to perform some basic tasks on a smartphone. You will be given
|
||||
a smartphone screenshot overlaid by a grid. The grid divides the screenshot into small square areas. Each area is
|
||||
labeled with an integer in the top-left corner.
|
||||
|
||||
You can call the following functions to control the smartphone:
|
||||
|
||||
1. tap(area: int, subarea: str)
|
||||
This function is used to tap a grid area shown on the smartphone screen. "area" is the integer label assigned to a grid
|
||||
area shown on the smartphone screen. "subarea" is a string representing the exact location to tap within the grid area.
|
||||
It can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left, bottom, and
|
||||
bottom-right.
|
||||
A simple use case can be tap(5, "center"), which taps the exact center of the grid area labeled with the number 5.
|
||||
|
||||
2. long_press(area: int, subarea: str)
|
||||
This function is used to long press a grid area shown on the smartphone screen. "area" is the integer label assigned to
|
||||
a grid area shown on the smartphone screen. "subarea" is a string representing the exact location to long press within
|
||||
the grid area. It can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left, bottom,
|
||||
and bottom-right.
|
||||
A simple use case can be long_press(7, "top-left"), which long presses the top left part of the grid area labeled with
|
||||
the number 7.
|
||||
|
||||
3. swipe(start_area: int, start_subarea: str, end_area: int, end_subarea: str)
|
||||
This function is used to perform a swipe action on the smartphone screen, especially when you want to interact with a
|
||||
scroll view or a slide bar. "start_area" is the integer label assigned to the grid area which marks the starting
|
||||
location of the swipe. "start_subarea" is a string representing the exact location to begin the swipe within the grid
|
||||
area. "end_area" is the integer label assigned to the grid area which marks the ending location of the swipe.
|
||||
"end_subarea" is a string representing the exact location to end the swipe within the grid area.
|
||||
The two subarea parameters can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left,
|
||||
bottom, and bottom-right.
|
||||
A simple use case can be swipe(21, "center", 25, "right"), which performs a swipe starting from the center of grid area
|
||||
21 to the right part of grid area 25.
|
||||
|
||||
The task you need to complete is to: {task_description}. Your past actions to proceed with this task are summarized as
|
||||
follows: {last_act}
|
||||
Now, given the following labeled screenshot, you need to think and call the function needed to proceed with the task.
|
||||
Your output should include three parts in the given format:
|
||||
|
||||
You can only take one action at a time, so please directly call the function."""
|
||||
|
||||
screenshot_parse_self_explore_template = """You are an agent that is trained to complete certain tasks on a smartphone. You will be
|
||||
given a screenshot of a smartphone app. The interactive UI elements on the screenshot are labeled with numeric tags
|
||||
starting from 1.
|
||||
|
||||
You can call the following functions to interact with those labeled elements to control the smartphone:
|
||||
|
||||
1. tap(element: int)
|
||||
This function is used to tap an UI element shown on the smartphone screen.
|
||||
"element" is a numeric tag assigned to an UI element shown on the smartphone screen.
|
||||
A simple use case can be tap(5), which taps the UI element labeled with the number 5.
|
||||
|
||||
2. text(text_input: str)
|
||||
This function is used to insert text input in an input field/box. text_input is the string you want to insert and must
|
||||
be wrapped with double quotation marks. A simple use case can be text("Hello, world!"), which inserts the string
|
||||
"Hello, world!" into the input area on the smartphone screen. This function is only callable when you see a keyboard
|
||||
showing in the lower half of the screen.
|
||||
|
||||
3. long_press(element: int)
|
||||
This function is used to long press an UI element shown on the smartphone screen.
|
||||
"element" is a numeric tag assigned to an UI element shown on the smartphone screen.
|
||||
A simple use case can be long_press(5), which long presses the UI element labeled with the number 5.
|
||||
|
||||
4. swipe(element: int, direction: str, dist: str)
|
||||
This function is used to swipe an UI element shown on the smartphone screen, usually a scroll view or a slide bar.
|
||||
"element" is a numeric tag assigned to an UI element shown on the smartphone screen. "direction" is a string that
|
||||
represents one of the four directions: up, down, left, right. "direction" must be wrapped with double quotation
|
||||
marks. "dist" determines the distance of the swipe and can be one of the three options: short, medium, long. You should
|
||||
choose the appropriate distance option according to your need.
|
||||
A simple use case can be swipe(21, "up", "medium"), which swipes up the UI element labeled with the number 21 for a
|
||||
medium distance.
|
||||
|
||||
The task you need to complete is to {task_description}. Your past actions to proceed with this task are summarized as
|
||||
follows: {last_act}
|
||||
Now, given the following labeled screenshot, you need to think and call the function needed to proceed with the task.
|
||||
Your output should include three parts in the given format:
|
||||
|
||||
You can only take one action at a time, so please directly call the function."""
|
||||
|
||||
screenshot_parse_self_explore_reflect_template = """I will give you screenshots of a mobile app before and after {action} the UI
|
||||
element labeled with the number '{ui_element}' on the first screenshot. The numeric tag of each element is located at
|
||||
the center of the element. The action of {action} this UI element was described as follows:
|
||||
{last_act}
|
||||
The action was also an attempt to proceed with a larger task, which is to {task_desc}. Your job is to carefully analyze
|
||||
the difference between the two screenshots to determine if the action is in accord with the description above and at
|
||||
the same time effectively moved the task forward. Your output should be determined based on the following situations:
|
||||
1. BACK
|
||||
If you think the action navigated you to a page where you cannot proceed with the given task, you should go back to the
|
||||
previous interface. At the same time, describe the functionality of the UI element concisely in one or two sentences by
|
||||
observing the difference between the two screenshots. Notice that your description of the UI element should focus on
|
||||
the general function. Never include the numeric tag of the UI element in your description. You can use pronouns such as
|
||||
"the UI element" to refer to the element. Your output should be in the following format:
|
||||
Decision: BACK
|
||||
Thought: <explain why you think the last action is wrong and you should go back to the previous interface>
|
||||
Documentation: <describe the function of the UI element>
|
||||
2. INEFFECTIVE
|
||||
If you find the action changed nothing on the screen (screenshots before and after the action are identical), you
|
||||
should continue to interact with other elements on the screen. Notice that if you find the location of the cursor
|
||||
changed between the two screenshots, then they are not identical. Your output should be in the following format:
|
||||
Decision: INEFFECTIVE
|
||||
Thought: <explain why you made this decision>
|
||||
Documentation: <None>
|
||||
3. CONTINUE
|
||||
If you find the action changed something on the screen but does not reflect the action description above and did not
|
||||
move the given task forward, you should continue to interact with other elements on the screen. At the same time,
|
||||
describe the functionality of the UI element concisely in one or two sentences by observing the difference between the
|
||||
two screenshots. Notice that your description of the UI element should focus on the general function. Never include the
|
||||
numeric tag of the UI element in your description. You can use pronouns such as "the UI element" to refer to the
|
||||
element. Your output should be in the following format:
|
||||
Decision: CONTINUE
|
||||
Thought: <explain why you think the action does not reflect the action description above and did not move the given
|
||||
task forward>
|
||||
Documentation: <describe the function of the UI element>
|
||||
4. SUCCESS
|
||||
If you think the action successfully moved the task forward (even though it did not completed the task), you should
|
||||
describe the functionality of the UI element concisely in one or two sentences. Notice that your description of the UI
|
||||
element should focus on the general function. Never include the numeric tag of the UI element in your description. You
|
||||
can use pronouns such as "the UI element" to refer to the element. Your output should be in the following format:
|
||||
Decision: SUCCESS
|
||||
Thought: <explain why you think the action successfully moved the task forward>
|
||||
Documentation: <describe the function of the UI element>
|
||||
"""
|
||||
45
metagpt/ext/android_assistant/prompts/operation_prompt.py
Normal file
45
metagpt/ext/android_assistant/prompts/operation_prompt.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the prompt templates of phone operation
|
||||
|
||||
tap_doc_template = """I will give you the screenshot of a mobile app before and after tapping the UI element labeled
|
||||
with the number {ui_element} on the screen. The numeric tag of each element is located at the center of the element.
|
||||
Tapping this UI element is a necessary part of proceeding with a larger task, which is to <task_desc>. Your task is to
|
||||
describe the functionality of the UI element concisely in one or two sentences. Notice that your description of the UI
|
||||
element should focus on the general function. For example, if the UI element is used to navigate to the chat window
|
||||
with John, your description should not include the name of the specific person. Just say: "Tapping this area will
|
||||
navigate the user to the chat window". Never include the numeric tag of the UI element in your description. You can use
|
||||
pronouns such as "the UI element" to refer to the element."""
|
||||
|
||||
text_doc_template = """I will give you the screenshot of a mobile app before and after typing in the input area labeled
|
||||
with the number {ui_element} on the screen. The numeric tag of each element is located at the center of the element.
|
||||
Typing in this UI element is a necessary part of proceeding with a larger task, which is to <task_desc>. Your task is
|
||||
to describe the functionality of the UI element concisely in one or two sentences. Notice that your description of the
|
||||
UI element should focus on the general function. For example, if the change of the screenshot shows that the user typed
|
||||
"How are you?" in the chat box, you do not need to mention the actual text. Just say: "This input area is used for the
|
||||
user to type a message to send to the chat window.". Never include the numeric tag of the UI element in your
|
||||
description. You can use pronouns such as "the UI element" to refer to the element."""
|
||||
|
||||
long_press_doc_template = """I will give you the screenshot of a mobile app before and after long pressing the UI
|
||||
element labeled with the number {ui_element} on the screen. The numeric tag of each element is located at the center of
|
||||
the element. Long pressing this UI element is a necessary part of proceeding with a larger task, which is to
|
||||
<task_desc>. Your task is to describe the functionality of the UI element concisely in one or two sentences. Notice
|
||||
that your description of the UI element should focus on the general function. For example, if long pressing the UI
|
||||
element redirects the user to the chat window with John, your description should not include the name of the specific
|
||||
person. Just say: "Long pressing this area will redirect the user to the chat window". Never include the numeric tag of
|
||||
the UI element in your description. You can use pronouns such as "the UI element" to refer to the element."""
|
||||
|
||||
swipe_doc_template = """I will give you the screenshot of a mobile app before and after swiping <swipe_dir> the UI
|
||||
element labeled with the number {ui_element} on the screen. The numeric tag of each element is located at the center of
|
||||
the element. Swiping this UI element is a necessary part of proceeding with a larger task, which is to <task_desc>.
|
||||
Your task is to describe the functionality of the UI element concisely in one or two sentences. Notice that your
|
||||
description of the UI element should be as general as possible. For example, if swiping the UI element increases the
|
||||
contrast ratio of an image of a building, your description should be just like this: "Swiping this area enables the
|
||||
user to tune a specific parameter of the image". Never include the numeric tag of the UI element in your description.
|
||||
You can use pronouns such as "the UI element" to refer to the element."""
|
||||
|
||||
refine_doc_suffix = """\nA documentation of this UI element generated from previous demos is shown below. Your
|
||||
generated description should be based on this previous doc and optimize it. Notice that it is possible that your
|
||||
understanding of the function of the UI element derived from the given screenshots conflicts with the previous doc,
|
||||
because the function of a UI element can be flexible. In this case, your generated description should combine both.
|
||||
Old documentation of this UI element: {old_doc}"""
|
||||
3
metagpt/ext/android_assistant/roles/__init__.py
Normal file
3
metagpt/ext/android_assistant/roles/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
146
metagpt/ext/android_assistant/roles/android_assistant.py
Normal file
146
metagpt/ext/android_assistant/roles/android_assistant.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : android assistant to learn from app operations and operate apps
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
from metagpt.ext.android_assistant.actions.manual_record import ManualRecord
|
||||
from metagpt.ext.android_assistant.actions.parse_record import ParseRecord
|
||||
from metagpt.ext.android_assistant.actions.screenshot_parse import ScreenshotParse
|
||||
from metagpt.ext.android_assistant.actions.self_learn_and_reflect import (
|
||||
SelfLearnAndReflect,
|
||||
)
|
||||
from metagpt.ext.android_assistant.utils.schema import AndroidActionOutput, RunState
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role, RoleReactMode
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class AndroidAssistant(Role):
|
||||
name: str = "Nick"
|
||||
profile: str = "AndroidAssistant"
|
||||
goal: str = "operate the mobile phone's apps with self-learn"
|
||||
|
||||
task_desc: str = ""
|
||||
round_count: int = 0
|
||||
last_act: str = "None"
|
||||
output_root_dir: Optional[Path] = Field(default=None)
|
||||
task_dir: Optional[Path] = Field(default=None)
|
||||
docs_dir: Optional[Path] = Field(default=None)
|
||||
grid_on: bool = Field(default=False)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
self._watch([UserRequirement, AndroidActionOutput])
|
||||
extra_config = config.extra
|
||||
self.task_desc = extra_config.get("task_desc", "Just explore any app in this phone!")
|
||||
app_name = extra_config.get("app_name", "demo")
|
||||
data_dir = self.output_root_dir.absolute().joinpath("output") or EXAMPLE_PATH.joinpath(
|
||||
"android_assistant/output"
|
||||
)
|
||||
cur_datetime = datetime.fromtimestamp(int(time.time())).strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
"""Firstly, we decide the state with user config, further, we can do it automatically, like if it's new app,
|
||||
run the learn first and then do the act stage or learn it during the action.
|
||||
"""
|
||||
stage = extra_config.get("stage")
|
||||
mode = extra_config.get("mode")
|
||||
if stage == "learn" and mode == "manual":
|
||||
# choose ManualRecord and then run ParseRecord
|
||||
# Remember, only run each action only one time, no need to run n_round.
|
||||
self.set_actions([ManualRecord, ParseRecord])
|
||||
self.task_dir = data_dir.joinpath(app_name, f"manual_learn_{cur_datetime}")
|
||||
self.docs_dir = data_dir.joinpath(app_name, "manual_docs")
|
||||
elif stage == "learn" and mode == "auto":
|
||||
# choose SelfLearnAndReflect to run
|
||||
self.set_actions([SelfLearnAndReflect])
|
||||
self.task_dir = data_dir.joinpath(app_name, f"auto_learn_{cur_datetime}")
|
||||
self.docs_dir = data_dir.joinpath(app_name, "auto_docs")
|
||||
elif stage == "act":
|
||||
# choose ScreenshotParse to run
|
||||
self.set_actions([ScreenshotParse])
|
||||
self.task_dir = data_dir.joinpath(app_name, f"act_{cur_datetime}")
|
||||
if mode == "manual":
|
||||
self.docs_dir = data_dir.joinpath(app_name, "manual_docs")
|
||||
else:
|
||||
self.docs_dir = data_dir.joinpath(app_name, "auto_docs")
|
||||
else:
|
||||
raise ValueError(f"invalid stage: {stage}, mode: {mode}")
|
||||
|
||||
self._check_dir()
|
||||
|
||||
self._set_react_mode(RoleReactMode.BY_ORDER)
|
||||
|
||||
def _check_dir(self):
|
||||
self.task_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.docs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def react(self) -> Message:
|
||||
self.round_count += 1
|
||||
result = await super().react()
|
||||
logger.debug(f"react result {result}")
|
||||
return result
|
||||
|
||||
async def _observe(self, ignore_memory=True) -> int:
|
||||
"""ignore old memory to make it run multi rounds inside a role"""
|
||||
newest_msgs = self.rc.memory.get(k=1)
|
||||
newest_msg = newest_msgs[0] if newest_msgs else None
|
||||
if newest_msg and (RunState.SUCCESS.value.upper() not in newest_msg.content):
|
||||
ignore_memory = False
|
||||
state_val = newest_msg.content.split(".")[-1] # RoundCount: 1, action_state: RunState.SUCCESS
|
||||
logger.warning(f"Latest action_state is {state_val}, will run in the remainder rounds without `react`")
|
||||
return await super()._observe(ignore_memory)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
|
||||
todo = self.rc.todo
|
||||
if isinstance(todo, ManualRecord):
|
||||
resp = await todo.run(task_dir=self.task_dir, task_desc=self.task_desc, env=self.rc.env)
|
||||
elif isinstance(todo, ParseRecord):
|
||||
resp = await todo.run(
|
||||
task_dir=self.task_dir,
|
||||
docs_dir=self.docs_dir,
|
||||
)
|
||||
elif isinstance(todo, SelfLearnAndReflect):
|
||||
resp = await todo.run(
|
||||
round_count=self.round_count,
|
||||
task_desc=self.task_desc,
|
||||
last_act=self.last_act,
|
||||
task_dir=self.task_dir,
|
||||
docs_dir=self.docs_dir,
|
||||
env=self.rc.env,
|
||||
)
|
||||
if resp.action_state == RunState.SUCCESS:
|
||||
self.last_act = resp.data.get("last_act")
|
||||
elif isinstance(todo, ScreenshotParse):
|
||||
resp = await todo.run(
|
||||
round_count=self.round_count,
|
||||
task_desc=self.task_desc,
|
||||
last_act=self.last_act,
|
||||
task_dir=self.task_dir,
|
||||
docs_dir=self.docs_dir,
|
||||
grid_on=self.grid_on,
|
||||
env=self.rc.env,
|
||||
)
|
||||
if resp.action_state == RunState.SUCCESS:
|
||||
logger.info(f"grid_on: {resp.data.get('grid_on')}")
|
||||
self.grid_on = resp.data.get("grid_on", False)
|
||||
self.last_act = resp.data.get("last_act", "None")
|
||||
msg = Message(
|
||||
content=f"RoundCount: {self.round_count}, action_state: {resp.action_state}",
|
||||
role=self.profile,
|
||||
cause_by=type(resp),
|
||||
send_from=self.name,
|
||||
send_to=self.name,
|
||||
)
|
||||
|
||||
self.rc.memory.add(msg)
|
||||
return msg
|
||||
3
metagpt/ext/android_assistant/utils/__init__.py
Normal file
3
metagpt/ext/android_assistant/utils/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
158
metagpt/ext/android_assistant/utils/schema.py
Normal file
158
metagpt/ext/android_assistant/utils/schema.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class ActionOp(Enum):
|
||||
TAP = "tap"
|
||||
LONG_PRESS = "long_press"
|
||||
TEXT = "text"
|
||||
SWIPE = "swipe"
|
||||
VERTICAL_SWIPE = "v_swipe"
|
||||
HORIZONTAL_SWIPE = "h_swipe"
|
||||
GRID = "grid"
|
||||
STOP = "stop"
|
||||
|
||||
|
||||
class SwipeOp(Enum):
|
||||
UP = "up"
|
||||
DOWN = "down"
|
||||
LEFT = "left"
|
||||
RIGHT = "right"
|
||||
|
||||
|
||||
class Decision(Enum):
|
||||
BACK = "BACK"
|
||||
INEFFECTIVE = "INEFFECTIVE"
|
||||
CONTINUE = "CONTINUE"
|
||||
SUCCESS = "SUCCESS"
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
|
||||
class AndroidElement(BaseModel):
|
||||
"""UI Element"""
|
||||
|
||||
uid: str = Field(default="")
|
||||
bbox: tuple[tuple[int, int], tuple[int, int]] = Field(default={})
|
||||
attrib: str = Field(default="")
|
||||
|
||||
|
||||
class OpLogItem(BaseModel):
|
||||
"""log content for self-learn or task act"""
|
||||
|
||||
step: int = Field(default=0)
|
||||
prompt: str = Field(default="")
|
||||
image: str = Field(default="")
|
||||
response: str = Field(default="")
|
||||
|
||||
|
||||
class ReflectLogItem(BaseModel):
|
||||
"""log content for self-learn-reflect"""
|
||||
|
||||
step: int = Field(default=0)
|
||||
prompt: str = Field(default="")
|
||||
image_before: str = Field(default="")
|
||||
image_after: str = Field(default="")
|
||||
response: str = Field(default="")
|
||||
|
||||
|
||||
class RecordLogItem(BaseModel):
|
||||
"""log content for record parse, same as ReflectLogItem"""
|
||||
|
||||
step: int = Field(default=0)
|
||||
prompt: str = Field(default="")
|
||||
image_before: str = Field(default="")
|
||||
image_after: str = Field(default="")
|
||||
response: str = Field(default="")
|
||||
|
||||
|
||||
class DocContent(BaseModel):
|
||||
tap: str = Field(default="")
|
||||
text: str = Field(default="")
|
||||
v_swipe: str = Field(default="")
|
||||
h_swipe: str = Field(default="")
|
||||
long_press: str = Field(default="")
|
||||
|
||||
|
||||
# start =================== define different Action Op and its params =============
|
||||
class RunState(Enum):
|
||||
"""run state"""
|
||||
|
||||
SUCCESS = "success"
|
||||
FINISH = "finish"
|
||||
FAIL = "fail"
|
||||
|
||||
|
||||
class BaseOpParam(BaseModel):
|
||||
act_name: str = Field(default="", validate_default=True)
|
||||
last_act: str = Field(default="None")
|
||||
param_state: RunState = Field(default=RunState.SUCCESS, description="return state when extract params")
|
||||
|
||||
|
||||
class TapOpParam(BaseOpParam):
|
||||
area: int = Field(default=-1)
|
||||
|
||||
|
||||
class TextOpParam(BaseOpParam):
|
||||
input_str: str = Field(default="")
|
||||
|
||||
|
||||
class LongPressOpParam(BaseOpParam):
|
||||
area: int = Field(default=-1)
|
||||
|
||||
|
||||
# Modify This SwipeOp to SwipeOpParam, Need better name
|
||||
class SwipeOpParam(BaseOpParam):
|
||||
area: int = Field(default=-1)
|
||||
swipe_orient: str = Field(default="up")
|
||||
dist: str = Field(default="")
|
||||
|
||||
|
||||
class GridOpParam(BaseOpParam):
|
||||
act_name: str = Field(default="")
|
||||
|
||||
|
||||
class BaseGridOpParam(BaseOpParam):
|
||||
@field_validator("act_name", mode="before")
|
||||
@classmethod
|
||||
def check_act_name(cls, act_name: str) -> str:
|
||||
return f"{act_name}_grid"
|
||||
|
||||
|
||||
class TapGridOpParam(BaseGridOpParam):
|
||||
area: int = Field(default=-1)
|
||||
subarea: str = Field(default="")
|
||||
|
||||
|
||||
class LongPressGridOpParam(BaseGridOpParam):
|
||||
area: int = Field(default=-1)
|
||||
subarea: str = Field(default="")
|
||||
|
||||
|
||||
class SwipeGridOpParam(BaseGridOpParam):
|
||||
start_area: int = Field(default=-1)
|
||||
start_subarea: str = Field(default="")
|
||||
end_area: int = Field(default=-1)
|
||||
end_subarea: str = Field(default="")
|
||||
|
||||
|
||||
# end =================== define different Action Op and its params =============
|
||||
|
||||
|
||||
class ReflectOp(BaseModel):
|
||||
decision: str = ""
|
||||
thought: str = ""
|
||||
documentation: str = ""
|
||||
param_state: RunState = RunState.SUCCESS
|
||||
|
||||
|
||||
class AndroidActionOutput(BaseModel):
|
||||
data: dict = Field(default=dict())
|
||||
action_state: RunState = Field(default=RunState.SUCCESS)
|
||||
329
metagpt/ext/android_assistant/utils/utils.py
Normal file
329
metagpt/ext/android_assistant/utils/utils.py
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from xml.etree.ElementTree import Element, iterparse
|
||||
|
||||
import cv2
|
||||
import pyshine as ps
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.ext.android_assistant.utils.schema import (
|
||||
ActionOp,
|
||||
AndroidElement,
|
||||
BaseGridOpParam,
|
||||
BaseOpParam,
|
||||
Decision,
|
||||
GridOpParam,
|
||||
LongPressGridOpParam,
|
||||
LongPressOpParam,
|
||||
ReflectOp,
|
||||
RunState,
|
||||
SwipeGridOpParam,
|
||||
SwipeOpParam,
|
||||
TapGridOpParam,
|
||||
TapOpParam,
|
||||
TextOpParam,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def get_id_from_element(elem: Element) -> str:
|
||||
bounds = elem.attrib["bounds"][1:-1].split("][")
|
||||
x1, y1 = map(int, bounds[0].split(","))
|
||||
x2, y2 = map(int, bounds[1].split(","))
|
||||
elem_w, elem_h = x2 - x1, y2 - y1
|
||||
if "resource-id" in elem.attrib and elem.attrib["resource-id"]:
|
||||
elem_id = elem.attrib["resource-id"].replace(":", ".").replace("/", "_")
|
||||
else:
|
||||
elem_id = f"{elem.attrib['class']}_{elem_w}_{elem_h}"
|
||||
if "content-desc" in elem.attrib and elem.attrib["content-desc"] and len(elem.attrib["content-desc"]) < 20:
|
||||
content_desc = elem.attrib["content-desc"].replace("/", "_").replace(" ", "").replace(":", "_")
|
||||
elem_id += f"_{content_desc}"
|
||||
return elem_id
|
||||
|
||||
|
||||
def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: str, add_index=False):
|
||||
path = []
|
||||
extra_config = config.extra
|
||||
for event, elem in iterparse(str(xml_path), ["start", "end"]):
|
||||
if event == "start":
|
||||
path.append(elem)
|
||||
if attrib in elem.attrib and elem.attrib[attrib] == "true":
|
||||
parent_prefix = ""
|
||||
if len(path) > 1:
|
||||
parent_prefix = get_id_from_element(path[-2])
|
||||
bounds = elem.attrib["bounds"][1:-1].split("][")
|
||||
x1, y1 = map(int, bounds[0].split(","))
|
||||
x2, y2 = map(int, bounds[1].split(","))
|
||||
center = (x1 + x2) // 2, (y1 + y2) // 2
|
||||
elem_id = get_id_from_element(elem)
|
||||
if parent_prefix:
|
||||
elem_id = parent_prefix + "_" + elem_id
|
||||
if add_index:
|
||||
elem_id += f"_{elem.attrib['index']}"
|
||||
close = False
|
||||
for e in elem_list:
|
||||
bbox = e.bbox
|
||||
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
|
||||
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
|
||||
if dist <= extra_config.get("min_dist", 30):
|
||||
close = True
|
||||
break
|
||||
if not close:
|
||||
elem_list.append(AndroidElement(uid=elem_id, bbox=((x1, y1), (x2, y2)), attrib=attrib))
|
||||
|
||||
if event == "end":
|
||||
path.pop()
|
||||
|
||||
|
||||
def elem_list_from_xml_tree(xml_path: Path, useless_list: list[str], min_dist: int) -> list[AndroidElement]:
|
||||
clickable_list = []
|
||||
focusable_list = []
|
||||
traverse_xml_tree(xml_path, clickable_list, "clickable", True)
|
||||
traverse_xml_tree(xml_path, focusable_list, "focusable", True)
|
||||
elem_list = []
|
||||
for elem in clickable_list:
|
||||
if elem.uid in useless_list:
|
||||
continue
|
||||
elem_list.append(elem)
|
||||
for elem in focusable_list:
|
||||
if elem.uid in useless_list:
|
||||
continue
|
||||
bbox = elem.bbox
|
||||
center = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
|
||||
close = False
|
||||
for e in clickable_list:
|
||||
bbox = e.bbox
|
||||
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
|
||||
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
|
||||
if dist <= min_dist:
|
||||
close = True
|
||||
break
|
||||
if not close:
|
||||
elem_list.append(elem)
|
||||
return elem_list
|
||||
|
||||
|
||||
def draw_bbox_multi(
|
||||
img_path: Path,
|
||||
output_path: Path,
|
||||
elem_list: list[AndroidElement],
|
||||
record_mode: bool = False,
|
||||
dark_mode: bool = False,
|
||||
):
|
||||
imgcv = cv2.imread(str(img_path))
|
||||
count = 1
|
||||
for elem in elem_list:
|
||||
try:
|
||||
top_left = elem.bbox[0]
|
||||
bottom_right = elem.bbox[1]
|
||||
left, top = top_left[0], top_left[1]
|
||||
right, bottom = bottom_right[0], bottom_right[1]
|
||||
label = str(count)
|
||||
if record_mode:
|
||||
if elem.attrib == "clickable":
|
||||
color = (250, 0, 0)
|
||||
elif elem.attrib == "focusable":
|
||||
color = (0, 0, 250)
|
||||
else:
|
||||
color = (0, 250, 0)
|
||||
imgcv = ps.putBText(
|
||||
imgcv,
|
||||
label,
|
||||
text_offset_x=(left + right) // 2 + 10,
|
||||
text_offset_y=(top + bottom) // 2 + 10,
|
||||
vspace=10,
|
||||
hspace=10,
|
||||
font_scale=1,
|
||||
thickness=2,
|
||||
background_RGB=color,
|
||||
text_RGB=(255, 250, 250),
|
||||
alpha=0.5,
|
||||
)
|
||||
else:
|
||||
text_color = (10, 10, 10) if dark_mode else (255, 250, 250)
|
||||
bg_color = (255, 250, 250) if dark_mode else (10, 10, 10)
|
||||
imgcv = ps.putBText(
|
||||
imgcv,
|
||||
label,
|
||||
text_offset_x=(left + right) // 2 + 10,
|
||||
text_offset_y=(top + bottom) // 2 + 10,
|
||||
vspace=10,
|
||||
hspace=10,
|
||||
font_scale=1,
|
||||
thickness=2,
|
||||
background_RGB=bg_color,
|
||||
text_RGB=text_color,
|
||||
alpha=0.5,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR: An exception occurs while labeling the image\n{e}")
|
||||
count += 1
|
||||
cv2.imwrite(str(output_path), imgcv)
|
||||
return imgcv
|
||||
|
||||
|
||||
def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]:
|
||||
def get_unit_len(n):
|
||||
for i in range(1, n + 1):
|
||||
if n % i == 0 and 120 <= i <= 180:
|
||||
return i
|
||||
return -1
|
||||
|
||||
image = cv2.imread(str(img_path))
|
||||
height, width, _ = image.shape
|
||||
color = (255, 116, 113)
|
||||
unit_height = get_unit_len(height)
|
||||
if unit_height < 0:
|
||||
unit_height = 120
|
||||
unit_width = get_unit_len(width)
|
||||
if unit_width < 0:
|
||||
unit_width = 120
|
||||
thick = int(unit_width // 50)
|
||||
rows = height // unit_height
|
||||
cols = width // unit_width
|
||||
for i in range(rows):
|
||||
for j in range(cols):
|
||||
label = i * cols + j + 1
|
||||
left = int(j * unit_width)
|
||||
top = int(i * unit_height)
|
||||
right = int((j + 1) * unit_width)
|
||||
bottom = int((i + 1) * unit_height)
|
||||
cv2.rectangle(image, (left, top), (right, bottom), color, thick // 2)
|
||||
cv2.putText(
|
||||
image,
|
||||
str(label),
|
||||
(left + int(unit_width * 0.05) + 3, top + int(unit_height * 0.3) + 3),
|
||||
0,
|
||||
int(0.01 * unit_width),
|
||||
(0, 0, 0),
|
||||
thick,
|
||||
)
|
||||
cv2.putText(
|
||||
image,
|
||||
str(label),
|
||||
(left + int(unit_width * 0.05), top + int(unit_height * 0.3)),
|
||||
0,
|
||||
int(0.01 * unit_width),
|
||||
color,
|
||||
thick,
|
||||
)
|
||||
cv2.imwrite(str(output_path), image)
|
||||
return rows, cols
|
||||
|
||||
|
||||
def area_to_xy(area: int, subarea: str, width: int, height: int, rows: int, cols: int) -> tuple[int, int]:
|
||||
area -= 1
|
||||
row, col = area // cols, area % cols
|
||||
x_0, y_0 = col * (width // cols), row * (height // rows)
|
||||
if subarea == "top-left":
|
||||
x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) // 4
|
||||
elif subarea == "top":
|
||||
x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) // 4
|
||||
elif subarea == "top-right":
|
||||
x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) // 4
|
||||
elif subarea == "left":
|
||||
x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) // 2
|
||||
elif subarea == "right":
|
||||
x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) // 2
|
||||
elif subarea == "bottom-left":
|
||||
x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) * 3 // 4
|
||||
elif subarea == "bottom":
|
||||
x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) * 3 // 4
|
||||
elif subarea == "bottom-right":
|
||||
x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) * 3 // 4
|
||||
else:
|
||||
x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) // 2
|
||||
return x, y
|
||||
|
||||
|
||||
def elem_bbox_to_xy(bbox: tuple[tuple[int, int], tuple[int, int]]) -> tuple[int, int]:
|
||||
tl, br = bbox
|
||||
x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2
|
||||
return x, y
|
||||
|
||||
|
||||
def reflect_parse_extarct(parsed_json: dict) -> ReflectOp:
|
||||
decision = parsed_json.get("Decision")
|
||||
if decision not in Decision.values():
|
||||
op = ReflectOp(param_state=RunState.FAIL)
|
||||
else:
|
||||
op = ReflectOp(
|
||||
decision=parsed_json.get("Decision"),
|
||||
thought=parsed_json.get("Thought"),
|
||||
documentation=parsed_json.get("Documentation"),
|
||||
)
|
||||
return op
|
||||
|
||||
|
||||
def screenshot_parse_extract(
|
||||
parsed_json: dict, grid_on: bool = False
|
||||
) -> Union[BaseOpParam, BaseGridOpParam, GridOpParam]:
|
||||
act = parsed_json.get("Action")
|
||||
last_act = parsed_json.get("Summary")
|
||||
act_name = act.split("(")[0]
|
||||
|
||||
if RunState.FINISH.value.upper() in act:
|
||||
return BaseOpParam(param_state=RunState.FINISH)
|
||||
|
||||
if grid_on:
|
||||
return screenshot_parse_extract_with_grid(act_name, act, last_act)
|
||||
else:
|
||||
return screenshot_parse_extract_without_grid(act_name, act, last_act)
|
||||
|
||||
|
||||
def op_params_clean(params: list[str]) -> list[Union[int, str]]:
|
||||
param_values = []
|
||||
for param_value in params:
|
||||
if '"' in param_value or "'" in param_value: # remove `"`
|
||||
param_values.append(param_value.strip()[1:-1])
|
||||
else:
|
||||
param_values.append(int(param_value))
|
||||
return param_values
|
||||
|
||||
|
||||
def screenshot_parse_extract_without_grid(act_name: str, act: str, last_act: str) -> Union[BaseOpParam, GridOpParam]:
|
||||
if act_name == ActionOp.TAP.value:
|
||||
area = int(re.findall(r"tap\((.*?)\)", act)[0])
|
||||
op = TapOpParam(act_name=act_name, area=area, last_act=last_act)
|
||||
elif act_name == ActionOp.TEXT.value:
|
||||
input_str = re.findall(r"text\((.*?)\)", act)[0][1:-1]
|
||||
op = TextOpParam(act_name=act_name, input_str=input_str, last_act=last_act)
|
||||
elif act_name == ActionOp.LONG_PRESS.value:
|
||||
area = int(re.findall(r"long_press\((.*?)\)", act)[0])
|
||||
op = LongPressOpParam(act_name=act_name, area=area, last_act=last_act)
|
||||
elif act_name == ActionOp.SWIPE.value:
|
||||
params = re.findall(r"swipe\((.*?)\)", act)[0].split(",")
|
||||
params = op_params_clean(params) # area, swipe_orient, dist
|
||||
op = SwipeOpParam(act_name=act_name, area=params[0], swipe_orient=params[1], dist=params[2], last_act=last_act)
|
||||
elif act_name == ActionOp.GRID.value:
|
||||
op = GridOpParam(act_name=act_name)
|
||||
else:
|
||||
op = BaseOpParam(param_state=RunState.FAIL)
|
||||
return op
|
||||
|
||||
|
||||
def screenshot_parse_extract_with_grid(act_name: str, act: str, last_act: str) -> Union[BaseGridOpParam, GridOpParam]:
|
||||
if act_name == ActionOp.TAP.value:
|
||||
params = re.findall(r"tap\((.*?)\)", act)[0].split(",")
|
||||
params = op_params_clean(params)
|
||||
op = TapGridOpParam(act_name=act_name, area=params[0], subarea=params[1], last_act=last_act)
|
||||
elif act_name == ActionOp.LONG_PRESS.value:
|
||||
params = re.findall(r"long_press\((.*?)\)", act)[0].split(",")
|
||||
params = op_params_clean(params)
|
||||
op = LongPressGridOpParam(act_name=act_name, area=params[0], subarea=params[1], last_act=last_act)
|
||||
elif act_name == ActionOp.SWIPE.value:
|
||||
params = re.findall(r"swipe\((.*?)\)", act)[0].split(",")
|
||||
params = op_params_clean(params)
|
||||
op = SwipeGridOpParam(
|
||||
act_name=act_name, start_area=params[0], start_subarea=params[1], end_area=params[2], end_subarea=params[3]
|
||||
)
|
||||
elif act_name == ActionOp.GRID.value:
|
||||
op = GridOpParam(act_name=act_name)
|
||||
else:
|
||||
op = BaseGridOpParam(param_state=RunState.FAIL)
|
||||
return op
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue