mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
add gpt4-v
This commit is contained in:
parent
fe0d27dde1
commit
471871b827
4 changed files with 51 additions and 0 deletions
|
|
@ -6,9 +6,11 @@
|
|||
@File : llm_hello_world.py
|
||||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import encode_image
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
@ -38,6 +40,17 @@ async def main():
|
|||
if hasattr(llm, "completion"):
|
||||
logger.info(llm.completion(hello_msg))
|
||||
|
||||
# check llm-vision capacity if it supports
|
||||
invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
|
||||
img_base64 = encode_image(invoice_path)
|
||||
try:
|
||||
res = await llm.aask(msg="if this is a invoice, just return True else return False",
|
||||
images=[img_base64])
|
||||
assert "true" in res.lower()
|
||||
except Exception as exp:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -60,12 +60,25 @@ class BaseLLM(ABC):
|
|||
"""
|
||||
if isinstance(images, str):
|
||||
images = [images]
|
||||
<<<<<<< HEAD
|
||||
content = [{"type": "text", "text": msg}]
|
||||
=======
|
||||
content = [
|
||||
{"type": "text", "text": msg}
|
||||
]
|
||||
>>>>>>> 9cbc3466 (add gpt4-v)
|
||||
for image in images:
|
||||
# image url or image base64
|
||||
url = image if image.startswith("http") else f"data:image/jpeg;base64,{image}"
|
||||
# it can with multiple-image inputs
|
||||
<<<<<<< HEAD
|
||||
content.append({"type": "image_url", "image_url": url})
|
||||
=======
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": url
|
||||
})
|
||||
>>>>>>> 9cbc3466 (add gpt4-v)
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
def _assistant_msg(self, msg: str) -> dict[str, str]:
|
||||
|
|
@ -131,7 +144,11 @@ class BaseLLM(ABC):
|
|||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
<<<<<<< HEAD
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
=======
|
||||
timeout=3,
|
||||
>>>>>>> 9cbc3466 (add gpt4-v)
|
||||
stream=True,
|
||||
) -> str:
|
||||
if system_msgs:
|
||||
|
|
@ -142,10 +159,14 @@ class BaseLLM(ABC):
|
|||
message = []
|
||||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
<<<<<<< HEAD
|
||||
if isinstance(msg, str):
|
||||
message.append(self._user_msg(msg, images=images))
|
||||
else:
|
||||
message.extend(msg)
|
||||
=======
|
||||
message.append(self._user_msg(msg, images=images))
|
||||
>>>>>>> 9cbc3466 (add gpt4-v)
|
||||
logger.debug(message)
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
|
||||
return rsp
|
||||
|
|
|
|||
|
|
@ -109,9 +109,15 @@ class OpenAILLM(BaseLLM):
|
|||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self._get_max_tokens(messages),
|
||||
<<<<<<< HEAD
|
||||
# "n": 1, # Some services do not provide this parameter, such as mistral
|
||||
# "stop": None, # default it's None and gpt4-v can't have this one
|
||||
"temperature": self.config.temperature,
|
||||
=======
|
||||
"n": 1,
|
||||
# "stop": None, # default it's None and gpt4-v can't have this one
|
||||
"temperature": 0.3,
|
||||
>>>>>>> 9cbc3466 (add gpt4-v)
|
||||
"model": self.model,
|
||||
"timeout": self.get_timeout(timeout),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,7 +24,12 @@ import platform
|
|||
import re
|
||||
import sys
|
||||
import traceback
|
||||
<<<<<<< HEAD
|
||||
from io import BytesIO
|
||||
=======
|
||||
import typing
|
||||
import base64
|
||||
>>>>>>> 9cbc3466 (add gpt4-v)
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Literal, Tuple, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
|
@ -744,6 +749,7 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
return files
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def parse_json_code_block(markdown_text: str) -> List[str]:
|
||||
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
|
||||
return [v.strip() for v in json_blocks]
|
||||
|
|
@ -861,3 +867,8 @@ def get_markdown_codeblock_type(filename: str) -> str:
|
|||
"application/sql": "sql",
|
||||
}
|
||||
return mappings.get(mime_type, "text")
|
||||
=======
|
||||
def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
|
||||
with open(str(image_path), "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode(encoding)
|
||||
>>>>>>> 9cbc3466 (add gpt4-v)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue