add ollama support

This commit is contained in:
better629 2023-12-22 02:20:43 +08:00
parent 7e0a2fabc7
commit 4b0cb0084a
10 changed files with 284 additions and 36 deletions

View file

@ -48,6 +48,10 @@ RPM: 10
#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat
#### if using a self-hosted open LLM model served by ollama
# OLLAMA_API_BASE: http://127.0.0.1:11434/api
# OLLAMA_API_MODEL: llama2
#### for Search
## Supported values: serpapi/google/serper/ddg

View file

@ -42,6 +42,7 @@ class LLMProviderEnum(Enum):
FIREWORKS = "fireworks"
OPEN_LLM = "open_llm"
GEMINI = "gemini"
OLLAMA = "ollama"
class Config(metaclass=Singleton):
@ -78,7 +79,8 @@ class Config(metaclass=Singleton):
(self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
(self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
(self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key
(self.gemini_api_key, LLMProviderEnum.GEMINI),
(self.ollama_api_base, LLMProviderEnum.OLLAMA), # reuse logic. but not a key
]:
if self._is_valid_llm_key(k):
# logger.debug(f"Use LLMProvider: {v.value}")
@ -103,6 +105,8 @@ class Config(metaclass=Singleton):
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
self.gemini_api_key = self._get("GEMINI_API_KEY")
self.ollama_api_base = self._get("OLLAMA_API_BASE")
self.ollama_api_model = self._get("OLLAMA_API_MODEL")
_ = self.get_default_llm_provider_enum()
self.openai_base_url = self._get("OPENAI_BASE_URL")

View file

@ -102,3 +102,5 @@ CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
YAPI_URL = "http://yapi.deepwisdomai.com/"
# Timeout handed to LLM HTTP calls as `request_timeout`; presumably in seconds — TODO confirm.
LLM_API_TIMEOUT = 300

View file

@ -8,8 +8,9 @@
from metagpt.provider.fireworks_api import FireWorksGPTAPI
from metagpt.provider.google_gemini_api import GeminiGPTAPI
from metagpt.provider.ollama_api import OllamaGPTAPI
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"]
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI", "OllamaGPTAPI"]

View file

@ -1,3 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : refs to openai 0.x sdk
import asyncio
import json
import os
@ -43,8 +47,8 @@ MAX_CONNECTION_RETRIES = 2
# Has one attribute per thread, 'session'.
_thread_context = threading.local()
OPENAI_LOG = os.environ.get("OPENAI_LOG")
OPENAI_LOG = "debug"
# Console log level for the LLM bindings, first read from the `LLM_LOG` environment variable.
LLM_LOG = os.environ.get("LLM_LOG")
# NOTE(review): this unconditional rebind discards the environment value read on the
# previous line, so `LLM_LOG` is always "debug" regardless of the environment — confirm
# this is intended.
LLM_LOG = "debug"
class ApiType(Enum):
@ -74,8 +78,8 @@ api_key_to_header = (
def _console_log_level():
if OPENAI_LOG in ["debug", "info"]:
return OPENAI_LOG
if LLM_LOG in ["debug", "info"]:
return LLM_LOG
else:
return None
@ -140,7 +144,7 @@ class OpenAIResponse:
@property
def organization(self) -> Optional[str]:
return self._headers.get("OpenAI-Organization")
return self._headers.get("LLM-Organization")
@property
def response_ms(self) -> Optional[int]:
@ -478,7 +482,7 @@ class APIRequestor:
error_data["message"] += "\n\n" + error_data["internal_message"]
log_info(
"OpenAI API error received",
"LLM API error received",
error_code=error_data.get("code"),
error_type=error_data.get("type"),
error_message=error_data.get("message"),
@ -516,7 +520,7 @@ class APIRequestor:
)
def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,)
user_agent = "LLM/v1 PythonBindings/%s" % (version.VERSION,)
uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node")
ua = {
@ -530,17 +534,17 @@ class APIRequestor:
}
headers = {
"X-OpenAI-Client-User-Agent": json.dumps(ua),
"X-LLM-Client-User-Agent": json.dumps(ua),
"User-Agent": user_agent,
}
headers.update(api_key_to_header(self.api_type, self.api_key))
if self.organization:
headers["OpenAI-Organization"] = self.organization
headers["LLM-Organization"] = self.organization
if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
headers["OpenAI-Version"] = self.api_version
headers["LLM-Version"] = self.api_version
if request_id is not None:
headers["X-Request-Id"] = request_id
headers.update(extra)
@ -592,15 +596,14 @@ class APIRequestor:
headers["Content-Type"] = "application/json"
else:
raise openai.APIConnectionError(
"Unrecognized HTTP method %r. This may indicate a bug in the "
"OpenAI bindings. Please contact us through our help center at help.openai.com for "
"assistance." % (method,)
message=f"Unrecognized HTTP method {method}. This may indicate a bug in the LLM bindings.",
request=None,
)
headers = self.request_headers(method, headers, request_id)
log_debug("Request to OpenAI API", method=method, path=abs_url)
log_debug("Post details", data=data, api_version=self.api_version)
# log_debug("Request to LLM API", method=method, path=abs_url)
# log_debug("Post details", data=data, api_version=self.api_version)
return abs_url, headers, data
@ -639,14 +642,14 @@ class APIRequestor:
except requests.exceptions.Timeout as e:
raise openai.APITimeoutError("Request timed out: {}".format(e)) from e
except requests.exceptions.RequestException as e:
raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
log_debug(
"OpenAI API response",
path=abs_url,
response_code=result.status_code,
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
request_id=result.headers.get("X-Request-Id"),
)
raise openai.APIConnectionError(message="Error communicating with LLM: {}".format(e), request=None) from e
# log_debug(
# "LLM API response",
# path=abs_url,
# response_code=result.status_code,
# processing_ms=result.headers.get("LLM-Processing-Ms"),
# request_id=result.headers.get("X-Request-Id"),
# )
return result
async def arequest_raw(
@ -685,18 +688,18 @@ class APIRequestor:
}
try:
result = await session.request(**request_kwargs)
log_info(
"OpenAI API response",
path=abs_url,
response_code=result.status,
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
request_id=result.headers.get("X-Request-Id"),
)
# log_info(
# "LLM API response",
# path=abs_url,
# response_code=result.status,
# processing_ms=result.headers.get("LLM-Processing-Ms"),
# request_id=result.headers.get("X-Request-Id"),
# )
return result
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
raise openai.APITimeoutError("Request timed out") from e
except aiohttp.ClientError as e:
raise openai.APIConnectionError("Error communicating with OpenAI") from e
raise openai.APIConnectionError(message="Error communicating with LLM", request=None) from e
def _interpret_response(
self, result: requests.Response, stream: bool

View file

@ -3,14 +3,38 @@
# @Desc : General Async API for http-based LLM model
import asyncio
from typing import AsyncGenerator, Tuple, Union
from typing import AsyncGenerator, Generator, Iterator, Optional, Tuple, Union
import aiohttp
import requests
from metagpt.logs import logger
from metagpt.provider.general_api_base import APIRequestor
def parse_stream_helper(line: bytes) -> Optional[str]:
    """Extract the payload of one server-sent-event ``data:`` line.

    Args:
        line: A raw line from the response body.

    Returns:
        The UTF-8 decoded text that follows the ``data:`` prefix (a single
        space directly after the colon is dropped as part of the prefix), or
        ``None`` when the line is empty, does not start with ``data:``, or
        carries the ``[DONE]`` terminator.
    """
    if not line or not line.startswith(b"data:"):
        return None

    # The payload may legitimately contain whitespace, so only the prefix
    # (with or without its single trailing space) is removed.
    prefix = b"data: " if line.startswith(b"data: ") else b"data:"
    payload = line[len(prefix) :]

    if payload.strip() == b"[DONE]":
        # Original author's note: returning here causes a GeneratorExit
        # exception in urllib3, which closes the http connection with TCP Reset.
        return None
    return payload.decode("utf-8")
def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
    """Lazily yield the payload of every line in ``rbody`` accepted by ``parse_stream_helper``.

    Lines for which ``parse_stream_helper`` returns ``None`` are skipped.
    """
    payloads = map(parse_stream_helper, rbody)
    yield from (payload for payload in payloads if payload is not None)
class GeneralAPIRequestor(APIRequestor):
"""
usage
@ -32,10 +56,34 @@ class GeneralAPIRequestor(APIRequestor):
return rbody
def _interpret_response(
    self, result: requests.Response, stream: bool
) -> Tuple[Union[str, Iterator[Generator]], bool]:
    """Interpret a synchronous HTTP response.

    Returns a pair ``(response, is_stream)``. When ``stream`` is requested and
    the response's ``Content-Type`` contains ``text/event-stream``, the first
    item is a lazy generator with one interpreted entry per line produced by
    ``parse_stream``; otherwise it is the interpreted raw body.
    """
    # `bool(stream) and ...` keeps the header lookup short-circuited when no stream was requested.
    is_event_stream = bool(stream) and "text/event-stream" in result.headers.get("Content-Type", "")

    if not is_event_stream:
        body = self._interpret_response_line(
            result.content,  # let the caller to decode the msg
            result.status_code,
            result.headers,
            stream=False,
        )
        return body, False

    event_lines = parse_stream(result.iter_lines())
    interpreted = (
        self._interpret_response_line(line, result.status_code, result.headers, stream=True)
        for line in event_lines
    )
    return interpreted, True
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]:
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
if stream and (
"text/event-stream" in result.headers.get("Content-Type", "")
or "application/x-ndjson" in result.headers.get("Content-Type", "")
):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
self._interpret_response_line(line, result.status, result.headers, stream=True)
async for line in result.content

View file

@ -0,0 +1,151 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : self-host open llm model with ollama which isn't openai-api-compatible
import json
from requests import ConnectionError
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, log_and_reraise
class OllamaCostManager(CostManager):
    """Cost manager for ollama that accumulates token counts only."""

    def update_cost(self, prompt_tokens, completion_tokens, model):
        """Add this request's token counts to the running totals and log them.

        ``self.total_cost`` is not modified here; its current value is only
        copied into ``CONFIG.total_cost``. ``model`` is accepted but unused.
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

        budget_part = f"Max budget: ${CONFIG.max_budget:.3f} | "
        tokens_part = f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        logger.info(budget_part + tokens_part)

        CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.OLLAMA)
class OllamaGPTAPI(BaseGPTAPI):
    """LLM provider for a self-hosted ollama server, using its chat endpoint.

    Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
    """

    def __init__(self):
        self.__init_ollama(CONFIG)
        self.client = GeneralAPIRequestor(base_url=CONFIG.ollama_api_base)
        # Passed as `url=` on every request; presumably joined to `base_url` by the requestor — TODO confirm.
        self.suffix_url = "/chat"
        self.http_method = "post"
        self.use_system_prompt = False
        self._cost_manager = OllamaCostManager()

    def __init_ollama(self, config: CONFIG):
        """Check that an ollama base url is configured and remember the model name.

        Raises:
            AssertionError: if ``config.ollama_api_base`` is empty or unset.
        """
        if not config.ollama_api_base:
            # Explicit raise (same exception type as a plain `assert`) so the
            # check is not stripped when Python runs with `-O`.
            raise AssertionError("OLLAMA_API_BASE is not configured")
        self.model = config.ollama_api_model

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        """Build the request payload sent to the chat endpoint."""
        kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
        return kwargs

    def _update_costs(self, usage: dict):
        """update each request's token cost"""
        if not CONFIG.calc_usage:
            return
        # Best-effort bookkeeping: a failure here is logged and never propagated to the caller.
        try:
            prompt_tokens = int(usage.get("prompt_tokens", 0))
            completion_tokens = int(usage.get("completion_tokens", 0))
            self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
        except Exception as e:
            logger.error(f"ollama updats costs failed! exp: {e}")

    def get_choice_text(self, resp: dict) -> str:
        """get the resp content from llm response

        Returns:
            The assistant message content, or ``""`` when the message has no
            (or a null) ``content`` field, so callers can always join strings.

        Raises:
            AssertionError: if the message role is not ``"assistant"``.
        """
        assist_msg = resp.get("message", {})
        role = assist_msg.get("role", None)
        if role != "assistant":
            # Explicit raise keeps the AssertionError type but survives `python -O`.
            raise AssertionError(f"unexpected message role in ollama response: {role!r}")
        return assist_msg.get("content") or ""

    def get_usage(self, resp: dict) -> dict:
        """Map ollama's `prompt_eval_count` / `eval_count` fields to prompt/completion token counts."""
        return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}

    def _decode_and_load(self, chunk: bytes, encoding: str = "utf-8") -> dict:
        """Decode a raw response body (or one streamed line) and parse it as JSON."""
        chunk = chunk.decode(encoding)
        return json.loads(chunk)

    def _load_and_update_costs(self, raw_resp: bytes) -> dict:
        """Parse a non-stream response body and record its token usage."""
        resp = self._decode_and_load(raw_resp)
        self._update_costs(self.get_usage(resp))
        return resp

    def completion(self, messages: list[dict]) -> dict:
        """Send a blocking, non-stream chat request and return the parsed response dict."""
        resp, _, _ = self.client.request(
            method=self.http_method,
            url=self.suffix_url,
            params=self._const_kwargs(messages),
            request_timeout=LLM_API_TIMEOUT,
        )
        return self._load_and_update_costs(resp)

    async def _achat_completion(self, messages: list[dict]) -> dict:
        """Send an async, non-stream chat request and return the parsed response dict."""
        resp, _, _ = await self.client.arequest(
            method=self.http_method,
            url=self.suffix_url,
            params=self._const_kwargs(messages),
            request_timeout=LLM_API_TIMEOUT,
        )
        return self._load_and_update_costs(resp)

    async def acompletion(self, messages: list[dict]) -> dict:
        """Async counterpart of `completion`."""
        return await self._achat_completion(messages)

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        """Send an async stream chat request, echo each piece to stdout and return the full text."""
        stream_resp, _, _ = await self.client.arequest(
            method=self.http_method,
            url=self.suffix_url,
            stream=True,
            params=self._const_kwargs(messages, stream=True),
            request_timeout=LLM_API_TIMEOUT,
        )

        collected_content = []
        usage = {}
        async for raw_chunk in stream_resp:
            chunk = self._decode_and_load(raw_chunk)

            if not chunk.get("done", False):
                content = self.get_choice_text(chunk)
                collected_content.append(content)
                print(content, end="")
            else:
                # stream finished
                usage = self.get_usage(chunk)

        self._update_costs(usage)
        return "".join(collected_content)

    # NOTE(review): `ConnectionError` is the one imported from `requests`, while this
    # coroutine goes through the requestor's async path; if that path reports failures
    # with a different exception type, this retry never fires — TODO confirm.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type(ConnectionError),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
        """response in async with stream or non-stream mode"""
        if stream:
            return await self._achat_completion_stream(messages)
        resp = await self._achat_completion(messages)
        return self.get_choice_text(resp)

View file

@ -196,6 +196,8 @@ def repair_invalid_json(output: str, error: str) -> str:
new_line = f'"{line}'
elif '",' in line:
new_line = line[:-2] + "',"
else:
new_line = line
arr[line_no] = new_line
output = "\n".join(arr)

View file

@ -9,7 +9,7 @@ import pytest
from metagpt.provider.google_gemini_api import GeminiGPTAPI
messages = [{"role": "user", "content": "who are you"}]
messages = [{"role": "user", "parts": "who are you"}]
@dataclass

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of ollama api
import pytest
from metagpt.provider.ollama_api import OllamaGPTAPI
# Chat history sent to the (mocked) provider.
messages = [{"role": "user", "content": "who are you"}]

# Canned non-stream chat reply; the role is spelled "assistant" because that is the
# value the ollama chat API uses for model replies.
default_resp = {"message": {"role": "assistant", "content": "I'm ollama"}}
def mock_llm_ask(self, messages: list[dict]) -> dict:
    """Synchronous stand-in for a completion call: ignores ``messages`` and returns the canned reply."""
    canned_reply = default_resp
    return canned_reply
def test_gemini_completion(mocker):
    """With `OllamaGPTAPI.completion` patched, the call returns the canned reply content."""
    mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask)

    llm = OllamaGPTAPI()
    resp = llm.completion(messages)

    expected_content = default_resp["message"]["content"]
    assert resp["message"]["content"] == expected_content
async def mock_llm_aask(self, messgaes: list[dict]) -> dict:
    """Async stand-in for a completion call: ignores its input and returns the canned reply."""
    canned_reply = default_resp
    return canned_reply
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
    """With `OllamaGPTAPI.acompletion` patched, the awaited call returns the canned reply content."""
    mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask)

    llm = OllamaGPTAPI()
    resp = await llm.acompletion(messages)

    expected_content = default_resp["message"]["content"]
    assert resp["message"]["content"] == expected_content