attach images to message

This commit is contained in:
garylin2099 2024-08-14 20:12:17 +08:00
parent dce5502c07
commit e9984f2bf8
6 changed files with 83 additions and 33 deletions

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from metagpt.actions import (
UserRequirement,
WriteDesign,
@ -6,12 +8,12 @@ from metagpt.actions import (
WriteTest,
)
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import AGENT
from metagpt.const import AGENT, IMAGES
from metagpt.environment.base_env import Environment
from metagpt.logs import get_human_input
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
from metagpt.schema import Message, SerializationMixin
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.common import any_to_str, any_to_str_set, extract_and_encode_images
class MGXEnv(Environment, SerializationMixin):
@ -27,6 +29,8 @@ class MGXEnv(Environment, SerializationMixin):
def publish_message(self, message: Message, user_defined_recipient: str = "", publicer: str = "") -> bool:
"""let the team leader take over message publishing"""
message = self.attach_images(message) # for multi-modal message
tl = self.get_role("Mike") # TeamLeader's name is Mike
if user_defined_recipient:
@ -119,9 +123,16 @@ class MGXEnv(Environment, SerializationMixin):
converted_msg.role = "assistant"
sent_from = converted_msg.metadata[AGENT] if AGENT in converted_msg.metadata else converted_msg.sent_from
converted_msg.content = (
f"[Message] from {sent_from if sent_from else 'User'} to {converted_msg.send_to}: {converted_msg.content}"
f"[Message] from {sent_from or 'User'} to {converted_msg.send_to}: {converted_msg.content}"
)
return converted_msg
def attach_images(self, message: Message) -> Message:
if message.role == "user":
images = extract_and_encode_images(message.content)
if images:
message.add_metadata(IMAGES, images)
return message
def __repr__(self):
return "MGXEnv()"

View file

@ -24,8 +24,9 @@ from tenacity import (
from metagpt.configs.compress_msg_config import CompressType
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.const import IMAGES, LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.constant import MULTI_MODAL_MODELS
from metagpt.schema import Message
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
@ -50,7 +51,7 @@ class BaseLLM(ABC):
pass
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]:
if images:
if images and self.support_image_input():
# as gpt-4v, chat with image
return self._user_msg_with_imgs(msg, images)
else:
@ -76,6 +77,9 @@ class BaseLLM(ABC):
def _system_msg(self, msg: str) -> dict[str, str]:
return {"role": "system", "content": msg}
def support_image_input(self) -> bool:
return any([m in self.config.model for m in MULTI_MODAL_MODELS])
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
"""convert messages to list[dict]."""
from metagpt.schema import Message
@ -91,7 +95,9 @@ class BaseLLM(ABC):
assert set(msg.keys()) == set(["role", "content"])
processed_messages.append(msg)
elif isinstance(msg, Message):
processed_messages.append(msg.to_dict())
images = msg.metadata.get(IMAGES)
processed_msg = self._user_msg(msg=msg.content, images=images) if images else msg.to_dict()
processed_messages.append(processed_msg)
else:
raise ValueError(
f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import inspect
import json
import os
import re
import traceback
from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple
@ -13,6 +12,7 @@ from metagpt.actions import Action, UserRequirement
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
from metagpt.actions.di.run_command import RunCommand
from metagpt.actions.search_enhanced_qa import SearchEnhancedQA
from metagpt.const import IMAGES
from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.exp_pool.serializers import RoleZeroSerializer
@ -35,13 +35,7 @@ from metagpt.tools.libs.browser import Browser
from metagpt.tools.libs.editor import Editor
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import (
CodeParser,
any_to_str,
encode_image,
extract_image_paths,
is_support_image_input,
)
from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images
from metagpt.utils.repair_llm_raw_output import (
RepairType,
repair_escape_error,
@ -219,15 +213,14 @@ class RoleZero(Role):
return memory
def parse_images(self, memory: list[Message]) -> list[Message]:
if not is_support_image_input(self.llm.model):
if not self.llm.support_image_input():
return memory
for i, msg in enumerate(memory):
if msg.role == "user" and isinstance(msg.content, str) and extract_image_paths(msg.content):
images = []
for path in extract_image_paths(msg.content):
if os.path.exists(path):
images.append(encode_image(path))
memory[i] = self.llm._user_msg_with_imgs(msg.content, images=images)
for msg in memory:
if IMAGES in msg.metadata or msg.role != "user":
continue
images = extract_and_encode_images(msg.content)
if images:
msg.add_metadata(IMAGES, images)
return memory
async def _act(self) -> Message:

View file

@ -840,12 +840,6 @@ def decode_image(img_url_or_b64: str) -> Image:
return img
def is_support_image_input(model_name: str) -> bool:
# model name can be gpt-4o-2024-08-06
support_models = ["gpt-4o", "gpt-4o-mini"] # FIXME: hard code for now
return any([m in model_name for m in support_models])
def extract_image_paths(content: str) -> bool:
# We require that the path must have a space preceding it, like "xxx /an/absolute/path.jpg xxx"
pattern = r"[^\s]+\.(?:png|jpe?g|gif|bmp|tiff)"
@ -853,6 +847,14 @@ def extract_image_paths(content: str) -> bool:
return image_paths
def extract_and_encode_images(content: str) -> list[str]:
images = []
for path in extract_image_paths(content):
if os.path.exists(path):
images.append(encode_image(path))
return images
def log_and_reraise(retry_state: RetryCallState):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning(

View file

@ -10,8 +10,9 @@ import pytest
from metagpt.configs.compress_msg_config import CompressType
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import IMAGES
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from metagpt.schema import AIMessage, Message, UserMessage
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
default_resp_cont,
@ -163,3 +164,41 @@ def test_compress_messages_long_no_sys_msg(compress_type):
print(compressed)
assert compressed
assert len(compressed[0]["content"]) < len(messages[0]["content"])
def test_format_msg(mocker):
base_llm = MockBaseLLM()
messages = [UserMessage(content="req"), AIMessage(content="rsp")]
formatted_msgs = base_llm.format_msg(messages)
assert formatted_msgs == [{"role": "user", "content": "req"}, {"role": "assistant", "content": "rsp"}]
def test_format_msg_w_images(mocker):
base_llm = MockBaseLLM()
base_llm.config.model = "gpt-4o"
msg_w_images = UserMessage(content="req1")
msg_w_images.add_metadata(IMAGES, ["base64 string 1", "base64 string 2"])
msg_w_empty_images = UserMessage(content="req2")
msg_w_empty_images.add_metadata(IMAGES, [])
messages = [
msg_w_images, # should be converted
AIMessage(content="rsp"),
msg_w_empty_images, # should not be converted
]
formatted_msgs = base_llm.format_msg(messages)
assert formatted_msgs == [
{
"role": "user",
"content": [
{"type": "text", "text": "req1"},
{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 1"}},
{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,base64 string 2"}},
],
},
{"role": "assistant", "content": "rsp"},
{"role": "user", "content": "req2"},
]
if name == "__main__":
pytest.main([__file__, "-s"])

View file

@ -29,9 +29,9 @@ from metagpt.utils.common import (
awrite,
check_cmd_exists,
concat_namespace,
extract_and_encode_images,
extract_image_paths,
import_class_inst,
is_support_image_input,
parse_recipient,
print_members,
read_file_block,
@ -231,9 +231,8 @@ def test_extract_image_paths():
assert not extract_image_paths(content)
def test_is_support_image_input():
assert is_support_image_input("gpt-4o-2024-08-06")
assert not is_support_image_input("deepseek-coder")
def test_extract_and_encode_images():
assert not extract_and_encode_images("a non-existing.jpg")
if __name__ == "__main__":