diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 219a303c8..dfc2603aa 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -6,9 +6,11 @@ @File : llm_hello_world.py """ import asyncio +from pathlib import Path from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.utils.common import encode_image async def main(): @@ -27,6 +29,15 @@ async def main(): if hasattr(llm, "completion"): logger.info(llm.completion(hello_msg)) + # check llm-vision capacity if it supports + invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png") + img_base64 = encode_image(invoice_path) + try: + res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64]) + assert "true" in res.lower() + except Exception: + pass + if __name__ == "__main__": asyncio.run(main()) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 162ab90eb..bd2f0d11f 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -370,12 +370,13 @@ class ActionNode: prompt: str, output_class_name: str, output_data_mapping: dict, + images: Optional[Union[str, list[str]]] = None, system_msgs: Optional[list[str]] = None, schema="markdown", # compatible to original format timeout=3, ) -> (str, BaseModel): """Use ActionOutput to wrap the output of aask""" - content = await self.llm.aask(prompt, system_msgs, timeout=timeout) + content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout) logger.debug(f"llm raw output:\n{content}") output_class = self.create_model_class(output_class_name, output_data_mapping) @@ -404,13 +405,15 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, schema, mode, timeout=3, exclude=None): + async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None): prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) if schema != "raw": mapping = self.get_mapping(mode, exclude=exclude) class_name = f"{self.key}_AN" - content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) + content, scontent = await self._aask_v1( + prompt, class_name, mapping, images=images, schema=schema, timeout=timeout + ) self.content = content self.instruct_content = scontent else: @@ -419,7 +422,17 @@ class ActionNode: return self - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]): + async def fill( + self, + context, + llm, + schema="json", + mode="auto", + strgy="simple", + images: Optional[Union[str, list[str]]] = None, + timeout=3, + exclude=[], + ): """Fill the node(s) with mode. :param context: Everything we should know when filling node. @@ -435,6 +448,7 @@ class ActionNode: :param strgy: simple/complex - simple: run only once - complex: run each node + :param images: the list of image url or base64 for gpt4-v :param timeout: Timeout for llm invocation. :param exclude: The keys of ActionNode to exclude. :return: self @@ -445,14 +459,14 @@ class ActionNode: schema = self.schema if strgy == "simple": - return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): if exclude and i.key in exclude: continue - child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) + child = await i.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude) tmp.update(child.instruct_content.model_dump()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 5fe9d1c3a..7c5892018 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -34,8 +34,26 @@ class BaseLLM(ABC): def __init__(self, config: LLMConfig): pass - def _user_msg(self, msg: str) -> dict[str, str]: - return {"role": "user", "content": msg} + def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]: + if images: + # as gpt-4v, chat with image + return self._user_msg_with_imgs(msg, images) + else: + return {"role": "user", "content": msg} + + def _user_msg_with_imgs(self, msg: str, images: Optional[Union[str, list[str]]]): + """ + images: can be list of http(s) url or base64 + """ + if isinstance(images, str): + images = [images] + content = [{"type": "text", "text": msg}] + for image in images: + # image url or image base64 + url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}" + # it can with multiple-image inputs + content.append({"type": "image_url", "image_url": url}) + return {"role": "user", "content": content} def _assistant_msg(self, msg: str) -> dict[str, str]: return {"role": "assistant", "content": msg} @@ -54,6 +72,7 @@ class BaseLLM(ABC): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ) -> str: @@ -65,7 +84,7 @@ class BaseLLM(ABC): message = [] if format_msgs: message.extend(format_msgs) - message.append(self._user_msg(msg)) + message.append(self._user_msg(msg, images=images)) logger.debug(message) rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index d6944eae6..2ec78317a 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -99,7 +99,7 @@ class OpenAILLM(BaseLLM): "messages": messages, "max_tokens": self._get_max_tokens(messages), "n": 1, - "stop": None, + # "stop": None, # default it's None and gpt4-v can't have this one "temperature": 0.3, "model": self.model, "timeout": max(self.config.timeout, timeout), diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index f1bd1a8e5..73017cf77 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -12,7 +12,9 @@ from __future__ import annotations import ast +import base64 import contextlib +import csv import importlib import inspect import json @@ -465,6 +467,29 @@ def write_json_file(json_file: str, data: list, encoding=None): json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python) +def read_csv_to_list(curr_file: str, header=False, strip_trail=True): + """ + Reads in a csv file to a list of list. If header is True, it returns a + tuple with (header row, all rows) + ARGS: + curr_file: path to the current csv file. + RETURNS: + List of list where the component lists are the rows of the file. + """ + logger.debug(f"start read csv: {curr_file}") + analysis_list = [] + with open(curr_file) as f_analysis_file: + data_reader = csv.reader(f_analysis_file, delimiter=",") + for count, row in enumerate(data_reader): + if strip_trail: + row = [i.strip() for i in row] + analysis_list += [row] + if not header: + return analysis_list + else: + return analysis_list[0], analysis_list[1:] + + def import_class(class_name: str, module_name: str) -> type: module = importlib.import_module(module_name) a_class = getattr(module, class_name) @@ -573,3 +598,8 @@ def list_files(root: str | Path) -> List[Path]: except Exception as e: logger.error(f"Error: {e}") return files + + +def encode_image(image_path: Path, encoding: str = "utf-8") -> str: + with open(str(image_path), "rb") as image_file: + return base64.b64encode(image_file.read()).decode(encoding) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 94506e373..a0fb3b70d 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -29,6 +29,7 @@ TOKEN_COSTS = { "gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens @@ -54,6 +55,7 @@ TOKEN_MAX = { "gpt-4-turbo-preview": 128000, "gpt-4-0125-preview": 128000, "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, "gpt-4-1106-vision-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, @@ -82,6 +84,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): "gpt-4-turbo-preview", "gpt-4-0125-preview", "gpt-4-1106-preview", + "gpt-4-vision-preview", "gpt-4-1106-vision-preview", }: tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|> @@ -112,7 +115,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): for message in messages: num_tokens += tokens_per_message for key, value in message.items(): - num_tokens += len(encoding.encode(value)) + content = value + if isinstance(value, list): + # for gpt-4v + for item in value: + if isinstance(item, dict) and item.get("type") in ["text"]: + content = item.get("text", "") + num_tokens += len(encoding.encode(content)) if key == "name": num_tokens += tokens_per_name num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 53de9cc75..8aee071d4 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -5,6 +5,7 @@ @Author : alexanderwu @File : test_action_node.py """ +from pathlib import Path from typing import List, Tuple import pytest @@ -17,6 +18,7 @@ from metagpt.llm import LLM from metagpt.roles import Role from metagpt.schema import Message from metagpt.team import Team +from metagpt.utils.common import encode_image @pytest.mark.asyncio @@ -241,6 +243,18 @@ def test_create_model_class_with_mapping(): assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] +@pytest.mark.asyncio +async def test_action_node_with_image(): + invoice = ActionNode( + key="invoice", expected_type=bool, instruction="if it's a invoice file, return True else False", example="False" + ) + + invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png") + img_base64 = encode_image(invoice_path) + node = await invoice.fill(context="", llm=LLM(), images=[img_base64]) + assert node.instruct_content.invoice + + if __name__ == "__main__": test_create_model_class() test_create_model_class_with_mapping() diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index bef380c83..f093d9ce1 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from metagpt.config2 import config from metagpt.logs import log_llm_stream, logger @@ -35,6 +35,7 @@ class MockLLM(OpenAILLM): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ): @@ -47,7 +48,7 @@ class MockLLM(OpenAILLM): message = [] if format_msgs: message.extend(format_msgs) - message.append(self._user_msg(msg)) + message.append(self._user_msg(msg, images=images)) rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp @@ -66,6 +67,7 @@ class MockLLM(OpenAILLM): msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, timeout=3, stream=True, ) -> str: @@ -73,7 +75,7 @@ class MockLLM(OpenAILLM): if system_msgs: joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#" msg_key = joined_system_msg + msg_key - rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, timeout, stream) + rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream) return rsp async def aask_batch(self, msgs: list, timeout=3) -> str: