mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
attach images to message
This commit is contained in:
parent
dce5502c07
commit
e9984f2bf8
6 changed files with 83 additions and 33 deletions
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from metagpt.actions import (
|
||||
UserRequirement,
|
||||
WriteDesign,
|
||||
|
|
@ -6,12 +8,12 @@ from metagpt.actions import (
|
|||
WriteTest,
|
||||
)
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.const import AGENT
|
||||
from metagpt.const import AGENT, IMAGES
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.logs import get_human_input
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
|
||||
from metagpt.schema import Message, SerializationMixin
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set, extract_and_encode_images
|
||||
|
||||
|
||||
class MGXEnv(Environment, SerializationMixin):
|
||||
|
|
@ -27,6 +29,8 @@ class MGXEnv(Environment, SerializationMixin):
|
|||
|
||||
def publish_message(self, message: Message, user_defined_recipient: str = "", publicer: str = "") -> bool:
|
||||
"""let the team leader take over message publishing"""
|
||||
message = self.attach_images(message) # for multi-modal message
|
||||
|
||||
tl = self.get_role("Mike") # TeamLeader's name is Mike
|
||||
|
||||
if user_defined_recipient:
|
||||
|
|
@ -119,9 +123,16 @@ class MGXEnv(Environment, SerializationMixin):
|
|||
converted_msg.role = "assistant"
|
||||
sent_from = converted_msg.metadata[AGENT] if AGENT in converted_msg.metadata else converted_msg.sent_from
|
||||
converted_msg.content = (
|
||||
f"[Message] from {sent_from if sent_from else 'User'} to {converted_msg.send_to}: {converted_msg.content}"
|
||||
f"[Message] from {sent_from or 'User'} to {converted_msg.send_to}: {converted_msg.content}"
|
||||
)
|
||||
return converted_msg
|
||||
|
||||
def attach_images(self, message: Message) -> Message:
|
||||
if message.role == "user":
|
||||
images = extract_and_encode_images(message.content)
|
||||
if images:
|
||||
message.add_metadata(IMAGES, images)
|
||||
return message
|
||||
|
||||
def __repr__(self):
|
||||
return "MGXEnv()"
|
||||
|
|
|
|||
|
|
@ -24,8 +24,9 @@ from tenacity import (
|
|||
|
||||
from metagpt.configs.compress_msg_config import CompressType
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
|
||||
from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.constant import MULTI_MODAL_MODELS
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
|
@ -50,7 +51,7 @@ class BaseLLM(ABC):
|
|||
pass
|
||||
|
||||
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]:
|
||||
if images:
|
||||
if images and self.support_image_input():
|
||||
# as gpt-4v, chat with image
|
||||
return self._user_msg_with_imgs(msg, images)
|
||||
else:
|
||||
|
|
@ -76,6 +77,9 @@ class BaseLLM(ABC):
|
|||
def _system_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "system", "content": msg}
|
||||
|
||||
def support_image_input(self) -> bool:
|
||||
return any([m in self.config.model for m in MULTI_MODAL_MODELS])
|
||||
|
||||
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -91,7 +95,9 @@ class BaseLLM(ABC):
|
|||
assert set(msg.keys()) == set(["role", "content"])
|
||||
processed_messages.append(msg)
|
||||
elif isinstance(msg, Message):
|
||||
processed_messages.append(msg.to_dict())
|
||||
images = msg.metadata.get(IMAGES)
|
||||
processed_msg = self._user_msg(msg=msg.content, images=images) if images else msg.to_dict()
|
||||
processed_messages.append(processed_msg)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
|
@ -13,6 +12,7 @@ from metagpt.actions import Action, UserRequirement
|
|||
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
|
||||
from metagpt.actions.di.run_command import RunCommand
|
||||
from metagpt.actions.search_enhanced_qa import SearchEnhancedQA
|
||||
from metagpt.const import IMAGES
|
||||
from metagpt.exp_pool import exp_cache
|
||||
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
|
||||
from metagpt.exp_pool.serializers import RoleZeroSerializer
|
||||
|
|
@ -35,13 +35,7 @@ from metagpt.tools.libs.browser import Browser
|
|||
from metagpt.tools.libs.editor import Editor
|
||||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import (
|
||||
CodeParser,
|
||||
any_to_str,
|
||||
encode_image,
|
||||
extract_image_paths,
|
||||
is_support_image_input,
|
||||
)
|
||||
from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images
|
||||
from metagpt.utils.repair_llm_raw_output import (
|
||||
RepairType,
|
||||
repair_escape_error,
|
||||
|
|
@ -219,15 +213,14 @@ class RoleZero(Role):
|
|||
return memory
|
||||
|
||||
def parse_images(self, memory: list[Message]) -> list[Message]:
|
||||
if not is_support_image_input(self.llm.model):
|
||||
if not self.llm.support_image_input():
|
||||
return memory
|
||||
for i, msg in enumerate(memory):
|
||||
if msg.role == "user" and isinstance(msg.content, str) and extract_image_paths(msg.content):
|
||||
images = []
|
||||
for path in extract_image_paths(msg.content):
|
||||
if os.path.exists(path):
|
||||
images.append(encode_image(path))
|
||||
memory[i] = self.llm._user_msg_with_imgs(msg.content, images=images)
|
||||
for msg in memory:
|
||||
if IMAGES in msg.metadata or msg.role != "user":
|
||||
continue
|
||||
images = extract_and_encode_images(msg.content)
|
||||
if images:
|
||||
msg.add_metadata(IMAGES, images)
|
||||
return memory
|
||||
|
||||
async def _act(self) -> Message:
|
||||
|
|
|
|||
|
|
@ -840,12 +840,6 @@ def decode_image(img_url_or_b64: str) -> Image:
|
|||
return img
|
||||
|
||||
|
||||
def is_support_image_input(model_name: str) -> bool:
|
||||
# model name can be gpt-4o-2024-08-06
|
||||
support_models = ["gpt-4o", "gpt-4o-mini"] # FIXME: hard code for now
|
||||
return any([m in model_name for m in support_models])
|
||||
|
||||
|
||||
def extract_image_paths(content: str) -> bool:
|
||||
# We require that the path must have a space preceding it, like "xxx /an/absolute/path.jpg xxx"
|
||||
pattern = r"[^\s]+\.(?:png|jpe?g|gif|bmp|tiff)"
|
||||
|
|
@ -853,6 +847,14 @@ def extract_image_paths(content: str) -> bool:
|
|||
return image_paths
|
||||
|
||||
|
||||
def extract_and_encode_images(content: str) -> list[str]:
|
||||
images = []
|
||||
for path in extract_image_paths(content):
|
||||
if os.path.exists(path):
|
||||
images.append(encode_image(path))
|
||||
return images
|
||||
|
||||
|
||||
def log_and_reraise(retry_state: RetryCallState):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -10,8 +10,9 @@ import pytest
|
|||
|
||||
from metagpt.configs.compress_msg_config import CompressType
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import IMAGES
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import AIMessage, Message, UserMessage
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
default_resp_cont,
|
||||
|
|
@ -163,3 +164,41 @@ def test_compress_messages_long_no_sys_msg(compress_type):
|
|||
print(compressed)
|
||||
assert compressed
|
||||
assert len(compressed[0]["content"]) < len(messages[0]["content"])
|
||||
|
||||
|
||||
def test_format_msg(mocker):
|
||||
base_llm = MockBaseLLM()
|
||||
messages = [UserMessage(content="req"), AIMessage(content="rsp")]
|
||||
formatted_msgs = base_llm.format_msg(messages)
|
||||
assert formatted_msgs == [{"role": "user", "content": "req"}, {"role": "assistant", "content": "rsp"}]
|
||||
|
||||
|
||||
def test_format_msg_w_images(mocker):
|
||||
base_llm = MockBaseLLM()
|
||||
base_llm.config.model = "gpt-4o"
|
||||
msg_w_images = UserMessage(content="req1")
|
||||
msg_w_images.add_metadata(IMAGES, ["base64 string 1", "base64 string 2"])
|
||||
msg_w_empty_images = UserMessage(content="req2")
|
||||
msg_w_empty_images.add_metadata(IMAGES, [])
|
||||
messages = [
|
||||
msg_w_images, # should be converted
|
||||
AIMessage(content="rsp"),
|
||||
msg_w_empty_images, # should not be converted
|
||||
]
|
||||
formatted_msgs = base_llm.format_msg(messages)
|
||||
assert formatted_msgs == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "req1"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 1"}},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 2"}},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": "rsp"},
|
||||
{"role": "user", "content": "req2"},
|
||||
]
|
||||
|
||||
|
||||
if name == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -29,9 +29,9 @@ from metagpt.utils.common import (
|
|||
awrite,
|
||||
check_cmd_exists,
|
||||
concat_namespace,
|
||||
extract_and_encode_images,
|
||||
extract_image_paths,
|
||||
import_class_inst,
|
||||
is_support_image_input,
|
||||
parse_recipient,
|
||||
print_members,
|
||||
read_file_block,
|
||||
|
|
@ -231,9 +231,8 @@ def test_extract_image_paths():
|
|||
assert not extract_image_paths(content)
|
||||
|
||||
|
||||
def test_is_support_image_input():
|
||||
assert is_support_image_input("gpt-4o-2024-08-06")
|
||||
assert not is_support_image_input("deepseek-coder")
|
||||
def test_extract_and_encode_images():
|
||||
assert not extract_and_encode_images("a non-existing.jpg")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue