mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
add gpt-4v support for aask and AN.fill
This commit is contained in:
parent
9e49e2252d
commit
310687258e
8 changed files with 113 additions and 14 deletions
|
|
@ -6,9 +6,11 @@
|
|||
@File : llm_hello_world.py
|
||||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import encode_image
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
@ -27,6 +29,15 @@ async def main():
|
|||
if hasattr(llm, "completion"):
|
||||
logger.info(llm.completion(hello_msg))
|
||||
|
||||
# check llm-vision capacity if it supports
|
||||
invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
|
||||
img_base64 = encode_image(invoice_path)
|
||||
try:
|
||||
res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
|
||||
assert "true" in res.lower()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -370,12 +370,13 @@ class ActionNode:
|
|||
prompt: str,
|
||||
output_class_name: str,
|
||||
output_data_mapping: dict,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
schema="markdown", # compatible to original format
|
||||
timeout=3,
|
||||
) -> (str, BaseModel):
|
||||
"""Use ActionOutput to wrap the output of aask"""
|
||||
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
|
||||
content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
|
||||
logger.debug(f"llm raw output:\n{content}")
|
||||
output_class = self.create_model_class(output_class_name, output_data_mapping)
|
||||
|
||||
|
|
@ -404,13 +405,15 @@ class ActionNode:
|
|||
def set_context(self, context):
|
||||
self.set_recursive("context", context)
|
||||
|
||||
async def simple_fill(self, schema, mode, timeout=3, exclude=None):
|
||||
async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
|
||||
|
||||
if schema != "raw":
|
||||
mapping = self.get_mapping(mode, exclude=exclude)
|
||||
class_name = f"{self.key}_AN"
|
||||
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
|
||||
content, scontent = await self._aask_v1(
|
||||
prompt, class_name, mapping, images=images, schema=schema, timeout=timeout
|
||||
)
|
||||
self.content = content
|
||||
self.instruct_content = scontent
|
||||
else:
|
||||
|
|
@ -419,7 +422,17 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=3, exclude=[]):
|
||||
async def fill(
|
||||
self,
|
||||
context,
|
||||
llm,
|
||||
schema="json",
|
||||
mode="auto",
|
||||
strgy="simple",
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
exclude=[],
|
||||
):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
:param context: Everything we should know when filling node.
|
||||
|
|
@ -435,6 +448,7 @@ class ActionNode:
|
|||
:param strgy: simple/complex
|
||||
- simple: run only once
|
||||
- complex: run each node
|
||||
:param images: the list of image url or base64 for gpt4-v
|
||||
:param timeout: Timeout for llm invocation.
|
||||
:param exclude: The keys of ActionNode to exclude.
|
||||
:return: self
|
||||
|
|
@ -445,14 +459,14 @@ class ActionNode:
|
|||
schema = self.schema
|
||||
|
||||
if strgy == "simple":
|
||||
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
|
||||
elif strgy == "complex":
|
||||
# 这里隐式假设了拥有children
|
||||
tmp = {}
|
||||
for _, i in self.children.items():
|
||||
if exclude and i.key in exclude:
|
||||
continue
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
child = await i.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
|
||||
tmp.update(child.instruct_content.model_dump())
|
||||
cls = self.create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
|
|
|
|||
|
|
@ -34,8 +34,26 @@ class BaseLLM(ABC):
|
|||
def __init__(self, config: LLMConfig):
|
||||
pass
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "user", "content": msg}
|
||||
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, Union[str, dict]]:
|
||||
if images:
|
||||
# as gpt-4v, chat with image
|
||||
return self._user_msg_with_imgs(msg, images)
|
||||
else:
|
||||
return {"role": "user", "content": msg}
|
||||
|
||||
def _user_msg_with_imgs(self, msg: str, images: Optional[Union[str, list[str]]]):
|
||||
"""
|
||||
images: can be list of http(s) url or base64
|
||||
"""
|
||||
if isinstance(images, str):
|
||||
images = [images]
|
||||
content = [{"type": "text", "text": msg}]
|
||||
for image in images:
|
||||
# image url or image base64
|
||||
url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}"
|
||||
# it can with multiple-image inputs
|
||||
content.append({"type": "image_url", "image_url": url})
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
def _assistant_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "assistant", "content": msg}
|
||||
|
|
@ -54,6 +72,7 @@ class BaseLLM(ABC):
|
|||
msg: str,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
stream=True,
|
||||
) -> str:
|
||||
|
|
@ -65,7 +84,7 @@ class BaseLLM(ABC):
|
|||
message = []
|
||||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg))
|
||||
message.append(self._user_msg(msg, images=images))
|
||||
logger.debug(message)
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
return rsp
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class OpenAILLM(BaseLLM):
|
|||
"messages": messages,
|
||||
"max_tokens": self._get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
# "stop": None, # default it's None and gpt4-v can't have this one
|
||||
"temperature": 0.3,
|
||||
"model": self.model,
|
||||
"timeout": max(self.config.timeout, timeout),
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import base64
|
||||
import contextlib
|
||||
import csv
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
|
|
@ -465,6 +467,29 @@ def write_json_file(json_file: str, data: list, encoding=None):
|
|||
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
|
||||
|
||||
|
||||
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
|
||||
"""
|
||||
Reads in a csv file to a list of list. If header is True, it returns a
|
||||
tuple with (header row, all rows)
|
||||
ARGS:
|
||||
curr_file: path to the current csv file.
|
||||
RETURNS:
|
||||
List of list where the component lists are the rows of the file.
|
||||
"""
|
||||
logger.debug(f"start read csv: {curr_file}")
|
||||
analysis_list = []
|
||||
with open(curr_file) as f_analysis_file:
|
||||
data_reader = csv.reader(f_analysis_file, delimiter=",")
|
||||
for count, row in enumerate(data_reader):
|
||||
if strip_trail:
|
||||
row = [i.strip() for i in row]
|
||||
analysis_list += [row]
|
||||
if not header:
|
||||
return analysis_list
|
||||
else:
|
||||
return analysis_list[0], analysis_list[1:]
|
||||
|
||||
|
||||
def import_class(class_name: str, module_name: str) -> type:
|
||||
module = importlib.import_module(module_name)
|
||||
a_class = getattr(module, class_name)
|
||||
|
|
@ -573,3 +598,8 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
return files
|
||||
|
||||
|
||||
def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
|
||||
with open(str(image_path), "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode(encoding)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ TOKEN_COSTS = {
|
|||
"gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
|
||||
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
|
|
@ -54,6 +55,7 @@ TOKEN_MAX = {
|
|||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4-1106-vision-preview": 128000,
|
||||
"text-embedding-ada-002": 8192,
|
||||
"chatglm_turbo": 32768,
|
||||
|
|
@ -82,6 +84,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
"gpt-4-turbo-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4-1106-vision-preview",
|
||||
}:
|
||||
tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|>
|
||||
|
|
@ -112,7 +115,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
num_tokens += len(encoding.encode(value))
|
||||
content = value
|
||||
if isinstance(value, list):
|
||||
# for gpt-4v
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("type") in ["text"]:
|
||||
content = item.get("text", "")
|
||||
num_tokens += len(encoding.encode(content))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_action_node.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
|
@ -17,6 +18,7 @@ from metagpt.llm import LLM
|
|||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.team import Team
|
||||
from metagpt.utils.common import encode_image
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -241,6 +243,18 @@ def test_create_model_class_with_mapping():
|
|||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_with_image():
|
||||
invoice = ActionNode(
|
||||
key="invoice", expected_type=bool, instruction="if it's a invoice file, return True else False", example="False"
|
||||
)
|
||||
|
||||
invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png")
|
||||
img_base64 = encode_image(invoice_path)
|
||||
node = await invoice.fill(context="", llm=LLM(), images=[img_base64])
|
||||
assert node.instruct_content.invoice
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
|
|
@ -35,6 +35,7 @@ class MockLLM(OpenAILLM):
|
|||
msg: str,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
stream=True,
|
||||
):
|
||||
|
|
@ -47,7 +48,7 @@ class MockLLM(OpenAILLM):
|
|||
message = []
|
||||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg))
|
||||
message.append(self._user_msg(msg, images=images))
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
return rsp
|
||||
|
||||
|
|
@ -66,6 +67,7 @@ class MockLLM(OpenAILLM):
|
|||
msg: str,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
stream=True,
|
||||
) -> str:
|
||||
|
|
@ -73,7 +75,7 @@ class MockLLM(OpenAILLM):
|
|||
if system_msgs:
|
||||
joined_system_msg = "#MSG_SEP#".join(system_msgs) + "#SYSTEM_MSG_END#"
|
||||
msg_key = joined_system_msg + msg_key
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, timeout, stream)
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream)
|
||||
return rsp
|
||||
|
||||
async def aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue