From 9cbc34662aa37adbd9f8dc9528dc78f0d3fce1cd Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 26 Jan 2024 15:11:43 +0800 Subject: [PATCH] add gpt4-v --- examples/llm_hello_world.py | 13 ++++++++++ metagpt/actions/action_node.py | 14 ++++++----- metagpt/provider/base_llm.py | 30 ++++++++++++++++++++--- metagpt/provider/openai_api.py | 2 +- metagpt/utils/common.py | 6 +++++ metagpt/utils/token_counter.py | 6 +++++ tests/metagpt/actions/test_action_node.py | 18 ++++++++++++++ tests/mock/mock_llm.py | 8 +++--- 8 files changed, 84 insertions(+), 13 deletions(-) diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 219a303c8..4baeaa01e 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -6,9 +6,11 @@ @File : llm_hello_world.py """ import asyncio +from pathlib import Path from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.utils.common import encode_image async def main(): @@ -27,6 +29,17 @@ async def main(): if hasattr(llm, "completion"): logger.info(llm.completion(hello_msg)) + # check llm-vision capacity if it supports + invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") + img_base64 = encode_image(invoice_path) + try: + res = await llm.aask(msg="if this is a invoice, just return True else return False", + images=[img_base64]) + assert "true" in res.lower() + except Exception as exp: + pass + + if __name__ == "__main__": asyncio.run(main()) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index ca41c76a5..dfe8b0aae 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -354,12 +354,13 @@ class ActionNode: prompt: str, output_class_name: str, output_data_mapping: dict, + images: Optional[Union[str, list[str]]] = None, system_msgs: Optional[list[str]] = None, schema="markdown", # compatible to original format timeout=3, ) -> (str, BaseModel): """Use ActionOutput to wrap the output of aask""" - content = await self.llm.aask(prompt, system_msgs, timeout=timeout) + content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout) logger.debug(f"llm raw output:\n{content}") output_class = self.create_model_class(output_class_name, output_data_mapping) @@ -388,13 +389,13 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, schema, mode, timeout=3, exclude=None): + async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None): prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) if schema != "raw": mapping = self.get_mapping(mode, exclude=exclude) class_name = f"{self.key}_AN" - content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) + content, scontent = await self._aask_v1(prompt, class_name, mapping, images=images, schema=schema, timeout=timeout) self.content = content self.instruct_content = scontent else: @@ -403,7 +404,7 @@ class ActionNode: return self - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=[]): """Fill the node(s) with mode. :param context: Everything we should know when filling node. @@ -419,6 +420,7 @@ class ActionNode: :param strgy: simple/complex - simple: run only once - complex: run each node + :param images: the list of image url or base64 for gpt4-v :param timeout: Timeout for llm invocation. :param exclude: The keys of ActionNode to exclude. :return: self @@ -429,14 +431,14 @@ class ActionNode: schema = self.schema if strgy == "simple": - return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): if exclude and i.key in exclude: continue - child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + child = await i.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude) tmp.update(child.instruct_content.model_dump()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 5fe9d1c3a..f9e9cddc9 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -34,8 +34,31 @@ class BaseLLM(ABC): def __init__(self, config: LLMConfig): pass - def _user_msg(self, msg: str) -> dict[str, str]: - return {"role": "user", "content": msg} + def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]: + if images: + # as gpt-4v, chat with image + return self._user_msg_with_imgs(msg, images) + else: + return {"role": "user", "content": msg} + + def _user_msg_with_imgs(self, msg: str, images: Optional[Union[str, list[str]]]): + """ + images: can be list of http(s) url or base64 + """ + if isinstance(images, str): + images = [images] + content = [ + {"type": "text", "text": msg} + ] + for image in images: + # image url or image base64 + url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}" + # it can with multiple-image inputs + content.append({ + "type": "image_url", + "image_url": url + }) + return {"role": "user", "content": content} def _assistant_msg(self, msg: str) -> dict[str, str]: return {"role": "assistant", "content": msg} @@ -54,6 +77,7 @@ class BaseLLM(ABC): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ) -> str: @@ -65,7 +89,7 @@ class BaseLLM(ABC): message = [] if format_msgs: message.extend(format_msgs) - message.append(self._user_msg(msg)) + message.append(self._user_msg(msg, images=images)) logger.debug(message) rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index d6944eae6..2ec78317a 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -99,7 +99,7 @@ class OpenAILLM(BaseLLM): "messages": messages, "max_tokens": self._get_max_tokens(messages), "n": 1, - "stop": None, + # "stop": None, # default it's None and gpt4-v can't have this one "temperature": 0.3, "model": self.model, "timeout": max(self.config.timeout, timeout), diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 74024fdd6..1cc482852 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -23,6 +23,7 @@ import re import sys import traceback import typing +import base64 from pathlib import Path from typing import Any, List, Tuple, Union @@ -591,3 +592,8 @@ def list_files(root: str | Path) -> List[Path]: except Exception as e: logger.error(f"Error: {e}") return files + + +def encode_image(image_path: Path, encoding: str = "utf-8") -> str: + with open(str(image_path), "rb") as image_file: + return base64.b64encode(image_file.read()).decode(encoding) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 885eb37d7..1f6622d28 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -26,6 +26,8 @@ TOKEN_COSTS = { "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12}, "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator + "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens "glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens @@ -48,6 +50,8 @@ TOKEN_MAX = { "gpt-4-32k-0314": 32768, "gpt-4-0613": 8192, "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, + "gpt-4-1106-vision-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, "gemini-pro": 32768, @@ -73,6 +77,8 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): "gpt-4-0613", "gpt-4-32k-0613", "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4-1106-vision-preview" }: tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|> tokens_per_name = 1 diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 53de9cc75..ccda665a1 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -8,6 +8,8 @@ from typing import List, Tuple import pytest +import base64 +from pathlib import Path from pydantic import ValidationError from metagpt.actions import Action @@ -17,6 +19,7 @@ from metagpt.llm import LLM from metagpt.roles import Role from metagpt.schema import Message from metagpt.team import Team +from metagpt.utils.common import encode_image @pytest.mark.asyncio @@ -241,6 +244,21 @@ def test_create_model_class_with_mapping(): assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] +@pytest.mark.asyncio +async def test_action_node_with_image(): + invoice = ActionNode( + key="invoice", + expected_type=bool, + instruction="if it's a invoice file, return True else False", + example="False" + ) + + invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png") + img_base64 = encode_image(invoice_path) + node = await invoice.fill(context="", llm=LLM(), images=[img_base64]) + assert node.instruct_content.invoice + + if __name__ == "__main__": test_create_model_class() test_create_model_class_with_mapping() diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index bef380c83..f093d9ce1 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from metagpt.config2 import config from metagpt.logs import log_llm_stream, logger @@ -35,6 +35,7 @@ class MockLLM(OpenAILLM): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ): @@ -47,7 +48,7 @@ class MockLLM(OpenAILLM): message = [] if format_msgs: message.extend(format_msgs) - message.append(self._user_msg(msg)) + message.append(self._user_msg(msg, images=images)) rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp @@ -66,6 +67,7 @@ class MockLLM(OpenAILLM): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ) -> str: @@ -73,7 +75,7 @@ class MockLLM(OpenAILLM): if system_msgs: joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#" msg_key = joined_system_msg + msg_key - rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, timeout, stream) + rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream) return rsp async def aask_batch(self, msgs: list, timeout=3) -> str: