mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-10 00:02:38 +02:00
fix ollama to add llava support
This commit is contained in:
parent
8b209d4e17
commit
3fefc48efa
3 changed files with 121 additions and 101 deletions
|
|
@ -320,49 +320,6 @@ class APIRequestor:
|
|||
resp, got_stream = self._interpret_response(result, stream)
|
||||
return resp, got_stream, self.api_key
|
||||
|
||||
@overload
|
||||
async def arequest(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params,
|
||||
headers,
|
||||
files,
|
||||
stream: Literal[True],
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
async def arequest(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=...,
|
||||
headers=...,
|
||||
files=...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
async def arequest(
|
||||
self,
|
||||
method,
|
||||
url,
|
||||
params=...,
|
||||
headers=...,
|
||||
files=...,
|
||||
stream: Literal[False] = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
|
||||
) -> Tuple[OpenAIResponse, bool, str]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
async def arequest(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -3,25 +3,24 @@
|
|||
# @Desc : General Async API for http-based LLM model
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
|
||||
from typing import AsyncGenerator, Iterator, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.general_api_base import APIRequestor
|
||||
from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse
|
||||
|
||||
|
||||
def parse_stream_helper(line: bytes) -> Union[bytes, None]:
|
||||
def parse_stream_helper(line: bytes) -> Optional[bytes]:
|
||||
if line and line.startswith(b"data:"):
|
||||
if line.startswith(b"data: "):
|
||||
# SSE event may be valid when it contain whitespace
|
||||
# SSE event may be valid when it contains whitespace
|
||||
line = line[len(b"data: ") :]
|
||||
else:
|
||||
line = line[len(b"data:") :]
|
||||
if line.strip() == b"[DONE]":
|
||||
# return here will cause GeneratorExit exception in urllib3
|
||||
# and it will close http connection with TCP Reset
|
||||
# Returning None to indicate end of stream
|
||||
return None
|
||||
else:
|
||||
return line
|
||||
|
|
@ -37,7 +36,7 @@ def parse_stream(rbody: Iterator[bytes]) -> Iterator[bytes]:
|
|||
|
||||
class GeneralAPIRequestor(APIRequestor):
|
||||
"""
|
||||
usage
|
||||
Usage example:
|
||||
# full_url = "{base_url}{url}"
|
||||
requester = GeneralAPIRequestor(base_url=base_url)
|
||||
result, _, api_key = await requester.arequest(
|
||||
|
|
@ -50,26 +49,47 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
)
|
||||
"""
|
||||
|
||||
def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders, stream: bool) -> bytes:
|
||||
# just do nothing to meet the APIRequestor process and return the raw data
|
||||
# due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases.
|
||||
def _interpret_response_line(self, rbody: bytes, rcode: int, rheaders: dict, stream: bool) -> OpenAIResponse:
|
||||
"""
|
||||
Process and return the response data wrapped in OpenAIResponse.
|
||||
|
||||
return rbody
|
||||
Args:
|
||||
rbody (bytes): The response body.
|
||||
rcode (int): The response status code.
|
||||
rheaders (dict): The response headers.
|
||||
stream (bool): Whether the response is a stream.
|
||||
|
||||
Returns:
|
||||
OpenAIResponse: The response data wrapped in OpenAIResponse.
|
||||
"""
|
||||
return OpenAIResponse(rbody, rheaders)
|
||||
|
||||
def _interpret_response(
|
||||
self, result: requests.Response, stream: bool
|
||||
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
|
||||
"""Returns the response(s) and a bool indicating whether it is a stream."""
|
||||
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
|
||||
"""
|
||||
Interpret a synchronous response.
|
||||
|
||||
Args:
|
||||
result (requests.Response): The response object.
|
||||
stream (bool): Whether the response is a stream.
|
||||
|
||||
Returns:
|
||||
Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
|
||||
"""
|
||||
content_type = result.headers.get("Content-Type", "")
|
||||
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
|
||||
return (
|
||||
self._interpret_response_line(line, result.status_code, result.headers, stream=True)
|
||||
for line in parse_stream(result.iter_lines())
|
||||
), True
|
||||
(
|
||||
self._interpret_response_line(line, result.status_code, result.headers, stream=True)
|
||||
for line in parse_stream(result.iter_lines())
|
||||
),
|
||||
True,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._interpret_response_line(
|
||||
result.content, # let the caller to decode the msg
|
||||
result.content, # let the caller decode the msg
|
||||
result.status_code,
|
||||
result.headers,
|
||||
stream=False,
|
||||
|
|
@ -79,26 +99,39 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
|
||||
async def _interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
|
||||
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
|
||||
"""
|
||||
Interpret an asynchronous response.
|
||||
|
||||
Args:
|
||||
result (aiohttp.ClientResponse): The response object.
|
||||
stream (bool): Whether the response is a stream.
|
||||
|
||||
Returns:
|
||||
Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: A tuple containing the response content and a boolean indicating if it is a stream.
|
||||
"""
|
||||
content_type = result.headers.get("Content-Type", "")
|
||||
if stream and (
|
||||
"text/event-stream" in content_type or "application/x-ndjson" in content_type or content_type == ""
|
||||
):
|
||||
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
|
||||
return (
|
||||
self._interpret_response_line(line, result.status, result.headers, stream=True)
|
||||
async for line in result.content
|
||||
), True
|
||||
(
|
||||
self._interpret_response_line(line, result.status, result.headers, stream=True)
|
||||
async for line in result.content
|
||||
),
|
||||
True,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
await result.read()
|
||||
response_content = await result.read()
|
||||
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
|
||||
raise TimeoutError("Request timed out") from e
|
||||
except aiohttp.ClientError as exp:
|
||||
logger.warning(f"response: {result.content}, exp: {exp}")
|
||||
logger.warning(f"response: {result}, exp: {exp}")
|
||||
response_content = b""
|
||||
return (
|
||||
self._interpret_response_line(
|
||||
await result.read(), # let the caller to decode the msg
|
||||
response_content, # let the caller decode the msg
|
||||
result.status,
|
||||
result.headers,
|
||||
stream=False,
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@
|
|||
# @Desc : self-host open llm model with ollama which isn't openai-api-compatible
|
||||
|
||||
import json
|
||||
from typing import AsyncGenerator, Tuple
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor, OpenAIResponse
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.cost_manager import TokenCostManager
|
||||
|
||||
|
|
@ -34,65 +35,68 @@ class OllamaLLM(BaseLLM):
|
|||
self.pricing_plan = self.model
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
return {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the resp content from llm response"""
|
||||
assist_msg = resp.get("message", {})
|
||||
assert assist_msg.get("role", None) == "assistant"
|
||||
return assist_msg.get("content")
|
||||
if assist_msg.get("role", None) == "assistant": # chat
|
||||
return assist_msg.get("content")
|
||||
else: # llava
|
||||
return resp["response"]
|
||||
|
||||
def get_usage(self, resp: dict) -> dict:
|
||||
return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
|
||||
|
||||
def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
|
||||
chunk = chunk.decode(encoding)
|
||||
return json.loads(chunk)
|
||||
def _decode_and_load(self, openai_resp: OpenAIResponse, encoding: str = "utf-8") -> dict:
|
||||
return json.loads(openai_resp.data.decode(encoding))
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
|
||||
headers = (
|
||||
None
|
||||
if not self.config.api_key or self.config.api_key == "sk-"
|
||||
else {
|
||||
"Authorization": f"Bearer {self.config.api_key}",
|
||||
}
|
||||
)
|
||||
messages, fixed_suffix_url = self._apply_llava(messages)
|
||||
resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
headers=headers,
|
||||
params=self._const_kwargs(messages),
|
||||
url=fixed_suffix_url,
|
||||
headers=self._get_headers(),
|
||||
params=messages,
|
||||
request_timeout=self.get_timeout(timeout),
|
||||
)
|
||||
resp = self._decode_and_load(resp)
|
||||
usage = self.get_usage(resp)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
if isinstance(resp, AsyncGenerator):
|
||||
return await self._processing_openai_response_async_generator(resp)
|
||||
elif isinstance(resp, OpenAIResponse):
|
||||
return self._processing_openai_response(resp)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
headers = (
|
||||
None
|
||||
if not self.config.api_key or self.config.api_key == "sk-"
|
||||
else {
|
||||
"Authorization": f"Bearer {self.config.api_key}",
|
||||
}
|
||||
)
|
||||
messages, fixed_suffix_url = self._apply_llava(messages, stream=True)
|
||||
stream_resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
headers=headers,
|
||||
url=fixed_suffix_url,
|
||||
headers=self._get_headers(),
|
||||
stream=True,
|
||||
params=self._const_kwargs(messages, stream=True),
|
||||
params=messages,
|
||||
request_timeout=self.get_timeout(timeout),
|
||||
)
|
||||
if isinstance(stream_resp, AsyncGenerator):
|
||||
return await self._processing_openai_response_async_generator(stream_resp)
|
||||
elif isinstance(stream_resp, OpenAIResponse):
|
||||
return self._processing_openai_response(stream_resp)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _processing_openai_response(self, openai_resp: OpenAIResponse):
|
||||
resp = self._decode_and_load(openai_resp)
|
||||
usage = self.get_usage(resp)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for raw_chunk in stream_resp:
|
||||
async for raw_chunk in ag_openai_resp:
|
||||
chunk = self._decode_and_load(raw_chunk)
|
||||
|
||||
if not chunk.get("done", False):
|
||||
|
|
@ -107,3 +111,29 @@ class OllamaLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
def _get_headers(self):
|
||||
return (
|
||||
None
|
||||
if not self.config.api_key or self.config.api_key == "sk-"
|
||||
else {"Authorization": f"Bearer {self.config.api_key}"}
|
||||
)
|
||||
|
||||
def _apply_llava(self, messages: list[dict], stream: bool = False) -> Tuple[dict, str]:
|
||||
llava = False
|
||||
if isinstance(messages[0]["content"], str):
|
||||
return self._const_kwargs(messages, stream), self.suffix_url
|
||||
|
||||
if any(len(msg["content"]) >= 2 for msg in messages):
|
||||
assert all(len(msg["content"]) >= 2 for msg in messages), "input should have the same api type"
|
||||
llava = True
|
||||
if not llava:
|
||||
return self._const_kwargs(messages, stream), self.suffix_url
|
||||
|
||||
assert len(messages) <= 1, "not support batch massages in llava calling images"
|
||||
contents = messages[0]["content"]
|
||||
return {
|
||||
"model": self.model,
|
||||
"prompt": contents[0]["text"],
|
||||
"images": [i["image_url"]["url"][23:] for i in contents[1:]],
|
||||
}, "/generate"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue