mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
add ollama support
This commit is contained in:
parent
7e0a2fabc7
commit
4b0cb0084a
10 changed files with 284 additions and 36 deletions
|
|
@ -48,6 +48,10 @@ RPM: 10
|
|||
#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
|
||||
#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat
|
||||
|
||||
#### if using a self-hosted open LLM model served by ollama
|
||||
# OLLAMA_API_BASE: http://127.0.0.1:11434/api
|
||||
# OLLAMA_API_MODEL: llama2
|
||||
|
||||
#### for Search
|
||||
|
||||
## Supported values: serpapi/google/serper/ddg
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ class LLMProviderEnum(Enum):
|
|||
FIREWORKS = "fireworks"
|
||||
OPEN_LLM = "open_llm"
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
|
|
@ -78,7 +79,8 @@ class Config(metaclass=Singleton):
|
|||
(self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
|
||||
(self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
|
||||
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
|
||||
(self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key
|
||||
(self.gemini_api_key, LLMProviderEnum.GEMINI),
|
||||
(self.ollama_api_base, LLMProviderEnum.OLLAMA), # reuse logic. but not a key
|
||||
]:
|
||||
if self._is_valid_llm_key(k):
|
||||
# logger.debug(f"Use LLMProvider: {v.value}")
|
||||
|
|
@ -103,6 +105,8 @@ class Config(metaclass=Singleton):
|
|||
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
|
||||
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
|
||||
self.gemini_api_key = self._get("GEMINI_API_KEY")
|
||||
self.ollama_api_base = self._get("OLLAMA_API_BASE")
|
||||
self.ollama_api_model = self._get("OLLAMA_API_MODEL")
|
||||
_ = self.get_default_llm_provider_enum()
|
||||
|
||||
self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
|
|
|
|||
|
|
@ -102,3 +102,5 @@ CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
|
|||
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
|
||||
|
||||
YAPI_URL = "http://yapi.deepwisdomai.com/"
|
||||
|
||||
LLM_API_TIMEOUT = 300
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@
|
|||
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
|
||||
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"]
|
||||
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI", "OllamaGPTAPI"]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : refs to openai 0.x sdk
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
|
|
@ -43,8 +47,8 @@ MAX_CONNECTION_RETRIES = 2
|
|||
# Has one attribute per thread, 'session'.
|
||||
_thread_context = threading.local()
|
||||
|
||||
OPENAI_LOG = os.environ.get("OPENAI_LOG")
|
||||
OPENAI_LOG = "debug"
|
||||
LLM_LOG = os.environ.get("LLM_LOG")
|
||||
LLM_LOG = "debug"
|
||||
|
||||
|
||||
class ApiType(Enum):
|
||||
|
|
@ -74,8 +78,8 @@ api_key_to_header = (
|
|||
|
||||
|
||||
def _console_log_level():
|
||||
if OPENAI_LOG in ["debug", "info"]:
|
||||
return OPENAI_LOG
|
||||
if LLM_LOG in ["debug", "info"]:
|
||||
return LLM_LOG
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
@ -140,7 +144,7 @@ class OpenAIResponse:
|
|||
|
||||
@property
|
||||
def organization(self) -> Optional[str]:
|
||||
return self._headers.get("OpenAI-Organization")
|
||||
return self._headers.get("LLM-Organization")
|
||||
|
||||
@property
|
||||
def response_ms(self) -> Optional[int]:
|
||||
|
|
@ -478,7 +482,7 @@ class APIRequestor:
|
|||
error_data["message"] += "\n\n" + error_data["internal_message"]
|
||||
|
||||
log_info(
|
||||
"OpenAI API error received",
|
||||
"LLM API error received",
|
||||
error_code=error_data.get("code"),
|
||||
error_type=error_data.get("type"),
|
||||
error_message=error_data.get("message"),
|
||||
|
|
@ -516,7 +520,7 @@ class APIRequestor:
|
|||
)
|
||||
|
||||
def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
|
||||
user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,)
|
||||
user_agent = "LLM/v1 PythonBindings/%s" % (version.VERSION,)
|
||||
|
||||
uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node")
|
||||
ua = {
|
||||
|
|
@ -530,17 +534,17 @@ class APIRequestor:
|
|||
}
|
||||
|
||||
headers = {
|
||||
"X-OpenAI-Client-User-Agent": json.dumps(ua),
|
||||
"X-LLM-Client-User-Agent": json.dumps(ua),
|
||||
"User-Agent": user_agent,
|
||||
}
|
||||
|
||||
headers.update(api_key_to_header(self.api_type, self.api_key))
|
||||
|
||||
if self.organization:
|
||||
headers["OpenAI-Organization"] = self.organization
|
||||
headers["LLM-Organization"] = self.organization
|
||||
|
||||
if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
|
||||
headers["OpenAI-Version"] = self.api_version
|
||||
headers["LLM-Version"] = self.api_version
|
||||
if request_id is not None:
|
||||
headers["X-Request-Id"] = request_id
|
||||
headers.update(extra)
|
||||
|
|
@ -592,15 +596,14 @@ class APIRequestor:
|
|||
headers["Content-Type"] = "application/json"
|
||||
else:
|
||||
raise openai.APIConnectionError(
|
||||
"Unrecognized HTTP method %r. This may indicate a bug in the "
|
||||
"OpenAI bindings. Please contact us through our help center at help.openai.com for "
|
||||
"assistance." % (method,)
|
||||
message=f"Unrecognized HTTP method {method}. This may indicate a bug in the LLM bindings.",
|
||||
request=None,
|
||||
)
|
||||
|
||||
headers = self.request_headers(method, headers, request_id)
|
||||
|
||||
log_debug("Request to OpenAI API", method=method, path=abs_url)
|
||||
log_debug("Post details", data=data, api_version=self.api_version)
|
||||
# log_debug("Request to LLM API", method=method, path=abs_url)
|
||||
# log_debug("Post details", data=data, api_version=self.api_version)
|
||||
|
||||
return abs_url, headers, data
|
||||
|
||||
|
|
@ -639,14 +642,14 @@ class APIRequestor:
|
|||
except requests.exceptions.Timeout as e:
|
||||
raise openai.APITimeoutError("Request timed out: {}".format(e)) from e
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
|
||||
log_debug(
|
||||
"OpenAI API response",
|
||||
path=abs_url,
|
||||
response_code=result.status_code,
|
||||
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
|
||||
request_id=result.headers.get("X-Request-Id"),
|
||||
)
|
||||
raise openai.APIConnectionError(message="Error communicating with LLM: {}".format(e), request=None) from e
|
||||
# log_debug(
|
||||
# "LLM API response",
|
||||
# path=abs_url,
|
||||
# response_code=result.status_code,
|
||||
# processing_ms=result.headers.get("LLM-Processing-Ms"),
|
||||
# request_id=result.headers.get("X-Request-Id"),
|
||||
# )
|
||||
return result
|
||||
|
||||
async def arequest_raw(
|
||||
|
|
@ -685,18 +688,18 @@ class APIRequestor:
|
|||
}
|
||||
try:
|
||||
result = await session.request(**request_kwargs)
|
||||
log_info(
|
||||
"OpenAI API response",
|
||||
path=abs_url,
|
||||
response_code=result.status,
|
||||
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
|
||||
request_id=result.headers.get("X-Request-Id"),
|
||||
)
|
||||
# log_info(
|
||||
# "LLM API response",
|
||||
# path=abs_url,
|
||||
# response_code=result.status,
|
||||
# processing_ms=result.headers.get("LLM-Processing-Ms"),
|
||||
# request_id=result.headers.get("X-Request-Id"),
|
||||
# )
|
||||
return result
|
||||
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
|
||||
raise openai.APITimeoutError("Request timed out") from e
|
||||
except aiohttp.ClientError as e:
|
||||
raise openai.APIConnectionError("Error communicating with OpenAI") from e
|
||||
raise openai.APIConnectionError(message="Error communicating with LLM", request=None) from e
|
||||
|
||||
def _interpret_response(
|
||||
self, result: requests.Response, stream: bool
|
||||
|
|
|
|||
|
|
@ -3,14 +3,38 @@
|
|||
# @Desc : General Async API for http-based LLM model
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Tuple, Union
|
||||
from typing import AsyncGenerator, Generator, Iterator, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.general_api_base import APIRequestor
|
||||
|
||||
|
||||
def parse_stream_helper(line: bytes) -> Optional[str]:
    """Extract the payload of a single server-sent-event ``data:`` line.

    Args:
        line: One raw line read from the response body.

    Returns:
        The text following the ``data:`` prefix (a single space directly after
        the colon is dropped as well), decoded as UTF-8. ``None`` is returned
        for empty input, for lines that do not start with ``data:``, and for
        the ``[DONE]`` end-of-stream marker.
    """
    if not line or not line.startswith(b"data:"):
        return None

    # An SSE data field is valid with or without a space after the colon, so
    # strip the longer prefix whenever it is present.
    prefix = b"data: " if line.startswith(b"data: ") else b"data:"
    payload = line[len(prefix) :]

    # The terminator is mapped to None rather than ending the iteration.
    # Note carried over from the original code: stopping here would raise
    # GeneratorExit in urllib3 and close the HTTP connection with a TCP reset.
    if payload.strip() == b"[DONE]":
        return None
    return payload.decode("utf-8")
|
||||
|
||||
|
||||
def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
    """Lazily yield the decoded payload of each usable line in ``rbody``.

    Every line is handed to ``parse_stream_helper``; lines for which it
    returns ``None`` are skipped.

    Args:
        rbody: Iterator over the raw lines of a response body.

    Yields:
        The non-``None`` results of ``parse_stream_helper``, in input order.
    """
    payloads = map(parse_stream_helper, rbody)
    yield from (text for text in payloads if text is not None)
|
||||
|
||||
|
||||
class GeneralAPIRequestor(APIRequestor):
|
||||
"""
|
||||
usage
|
||||
|
|
@ -32,10 +56,34 @@ class GeneralAPIRequestor(APIRequestor):
|
|||
|
||||
return rbody
|
||||
|
||||
def _interpret_response(
|
||||
self, result: requests.Response, stream: bool
|
||||
) -> Tuple[Union[str, Iterator[Generator]], bool]:
|
||||
"""Returns the response(s) and a bool indicating whether it is a stream."""
|
||||
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
|
||||
return (
|
||||
self._interpret_response_line(line, result.status_code, result.headers, stream=True)
|
||||
for line in parse_stream(result.iter_lines())
|
||||
), True
|
||||
else:
|
||||
return (
|
||||
self._interpret_response_line(
|
||||
result.content, # let the caller to decode the msg
|
||||
result.status_code,
|
||||
result.headers,
|
||||
stream=False,
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
async def _interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]:
|
||||
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
|
||||
if stream and (
|
||||
"text/event-stream" in result.headers.get("Content-Type", "")
|
||||
or "application/x-ndjson" in result.headers.get("Content-Type", "")
|
||||
):
|
||||
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
|
||||
return (
|
||||
self._interpret_response_line(line, result.status, result.headers, stream=True)
|
||||
async for line in result.content
|
||||
|
|
|
|||
151
metagpt/provider/ollama_api.py
Normal file
151
metagpt/provider/ollama_api.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : self-hosted open LLM models served by ollama, whose API isn't openai-api-compatible
|
||||
|
||||
import json
|
||||
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, log_and_reraise
|
||||
|
||||
|
||||
class OllamaCostManager(CostManager):
    """Cost manager for the ollama provider that accumulates token counts."""

    def update_cost(self, prompt_tokens, completion_tokens, model):
        """
        Add one request's token usage to the running totals and log it.

        Args:
            prompt_tokens: Number of prompt tokens used by the request.
            completion_tokens: Number of completion tokens produced.
            model: Not used by this implementation.
        """
        self.total_completion_tokens += completion_tokens
        self.total_prompt_tokens += prompt_tokens

        usage_report = (
            f"Max budget: ${CONFIG.max_budget:.3f} | "
            f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        )
        logger.info(usage_report)

        # total_cost is not modified here; its current value is simply
        # mirrored onto the global config.
        CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OLLAMA)
class OllamaGPTAPI(BaseGPTAPI):
    """
    Chat-completion client for a self-hosted ollama server.

    Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
    """

    def __init__(self):
        self.__init_ollama(CONFIG)
        self.client = GeneralAPIRequestor(base_url=CONFIG.ollama_api_base)
        self.suffix_url = "/chat"
        self.http_method = "post"
        self.use_system_prompt = False
        self._cost_manager = OllamaCostManager()

    def __init_ollama(self, config: CONFIG):
        """Validate the ollama settings and remember the configured model.

        Raises:
            ValueError: If no ollama API base URL is configured.
        """
        # Raised explicitly (not `assert`) so the check survives `python -O`.
        if not config.ollama_api_base:
            raise ValueError("OLLAMA_API_BASE must be configured to use the ollama provider")

        self.model = config.ollama_api_model

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        """Build the JSON payload sent to the chat endpoint."""
        return {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}

    def _update_costs(self, usage: dict):
        """update each request's token cost"""
        if CONFIG.calc_usage:
            try:
                prompt_tokens = int(usage.get("prompt_tokens", 0))
                completion_tokens = int(usage.get("completion_tokens", 0))
                self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
            except Exception as e:
                # Usage accounting is best-effort: log and carry on.
                logger.error(f"ollama update costs failed! exp: {e}")

    def get_choice_text(self, resp: dict) -> str:
        """get the resp content from llm response

        Raises:
            ValueError: If the response message is not from the assistant role.
        """
        assist_msg = resp.get("message", {})
        role = assist_msg.get("role", None)
        if role != "assistant":
            raise ValueError(f"unexpected role in ollama response: {role!r}")
        return assist_msg.get("content")

    def get_usage(self, resp: dict) -> dict:
        """Map ollama's token counters onto prompt/completion token keys."""
        return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}

    def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
        """Decode a raw response body (or stream chunk) and parse it as JSON."""
        return json.loads(chunk.decode(encoding))

    def _load_and_account(self, raw: bytes) -> dict:
        """Parse a non-stream response body and record its token usage."""
        resp = self._decode_and_load(raw)
        self._update_costs(self.get_usage(resp))
        return resp

    def completion(self, messages: list[dict]) -> dict:
        """Send a blocking, non-stream chat request and return the parsed reply."""
        resp, _, _ = self.client.request(
            method=self.http_method,
            url=self.suffix_url,
            params=self._const_kwargs(messages),
            request_timeout=LLM_API_TIMEOUT,
        )
        return self._load_and_account(resp)

    async def _achat_completion(self, messages: list[dict]) -> dict:
        """Send an async, non-stream chat request and return the parsed reply."""
        resp, _, _ = await self.client.arequest(
            method=self.http_method,
            url=self.suffix_url,
            params=self._const_kwargs(messages),
            request_timeout=LLM_API_TIMEOUT,
        )
        return self._load_and_account(resp)

    async def acompletion(self, messages: list[dict]) -> dict:
        """Async counterpart of `completion`."""
        return await self._achat_completion(messages)

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        """Stream a chat reply, echo it to stdout as it arrives, and return the full text."""
        stream_resp, _, _ = await self.client.arequest(
            method=self.http_method,
            url=self.suffix_url,
            stream=True,
            params=self._const_kwargs(messages, stream=True),
            request_timeout=LLM_API_TIMEOUT,
        )

        collected_content = []
        usage = {}
        async for raw_chunk in stream_resp:
            # A blank line between JSON documents is not a chunk; parsing it
            # would raise. (Chunks are presumably raw bytes lines - confirm.)
            if not raw_chunk.strip():
                continue
            chunk = self._decode_and_load(raw_chunk)

            if chunk.get("done", False):
                # stream finished: the final chunk carries the token counters
                usage = self.get_usage(chunk)
                continue

            # Guard against a missing `content` so the join below cannot fail.
            content = self.get_choice_text(chunk) or ""
            collected_content.append(content)
            print(content, end="")

        self._update_costs(usage)
        return "".join(collected_content)

    # NOTE(review): the retry predicate is `requests.ConnectionError`, while the
    # async transport in `general_api_base.APIRequestor.arequest_raw` re-raises
    # aiohttp failures as `openai.APIConnectionError`, so this retry may never
    # trigger on the async path - confirm and widen the predicate if intended.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type(ConnectionError),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
        """response in async with stream or non-stream mode"""
        if stream:
            return await self._achat_completion_stream(messages)
        resp = await self._achat_completion(messages)
        return self.get_choice_text(resp)
|
||||
|
|
@ -196,6 +196,8 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
new_line = f'"{line}'
|
||||
elif '",' in line:
|
||||
new_line = line[:-2] + "',"
|
||||
else:
|
||||
new_line = line
|
||||
|
||||
arr[line_no] = new_line
|
||||
output = "\n".join(arr)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import pytest
|
|||
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
messages = [{"role": "user", "parts": "who are you"}]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
33
tests/metagpt/provider/test_ollama_api.py
Normal file
33
tests/metagpt/provider/test_ollama_api.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ollama api
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
|
||||
# Single-turn chat history passed to the provider under test.
messages = [{"role": "user", "content": "who are you"}]

# Canned chat reply returned by the mocks below. The role is spelled
# "assistant" because that is the exact value OllamaGPTAPI.get_choice_text
# checks for in a response message.
default_resp = {"message": {"role": "assistant", "content": "I'm ollama"}}
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
    """Synchronous stand-in for a provider call: ignore the input, return the canned reply."""
    canned_reply = default_resp
    return canned_reply
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
    """Check that the patched synchronous `completion` returns the canned content."""
    # NOTE(review): the "gemini" in the name looks like copy/paste residue;
    # this test exercises the ollama provider.
    mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask)

    llm = OllamaGPTAPI()
    resp = llm.completion(messages)

    expected_content = default_resp["message"]["content"]
    assert resp["message"]["content"] == expected_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messages: list[dict]) -> dict:
    """Async stand-in for `OllamaGPTAPI.acompletion`: ignore the input, return the canned reply.

    The parameter is named `messages` (it was misspelled before) so the mock
    accepts the same keyword as the method it replaces.
    """
    return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
    """Check that the patched async `acompletion` returns the canned content."""
    # NOTE(review): the "gemini" in the name looks like copy/paste residue;
    # this test exercises the ollama provider.
    mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask)

    llm = OllamaGPTAPI()
    resp = await llm.acompletion(messages)

    expected_content = default_resp["message"]["content"]
    assert resp["message"]["content"] == expected_content
|
||||
Loading…
Add table
Add a link
Reference in a new issue