mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 06:51:00 +02:00
feat: add API variant profiles and thinking support to OpenAI processor (#1007)
Add a --variant flag (openai, deepseek, qwen, mistral, llama) that encapsulates provider-specific API differences: output token parameter names, thinking/reasoning toggles, temperature rules, and thinking output extraction. Add --thinking flag (off, low, medium, high) to control reasoning effort.
This commit is contained in:
parent
01cc8dbc64
commit
f20b50cfb2
2 changed files with 233 additions and 16 deletions
|
|
@ -10,6 +10,7 @@ import logging
|
|||
|
||||
from .... exceptions import TooManyRequests, LlmError
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
from . variants import get_variant, DEFAULT_VARIANT, VARIANTS
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -21,6 +22,7 @@ default_temperature = 0.0
|
|||
default_max_output = 4096
|
||||
default_api_key = os.getenv("OPENAI_TOKEN")
|
||||
default_base_url = os.getenv("OPENAI_BASE_URL")
|
||||
default_thinking = "off"
|
||||
|
||||
if default_base_url is None or default_base_url == "":
|
||||
default_base_url = "https://api.openai.com/v1"
|
||||
|
|
@ -28,16 +30,21 @@ if default_base_url is None or default_base_url == "":
|
|||
class Processor(LlmService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key", default_api_key)
|
||||
base_url = params.get("url", default_base_url)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
thinking = params.get("thinking", default_thinking)
|
||||
variant_name = params.get("variant", DEFAULT_VARIANT)
|
||||
|
||||
if not api_key:
|
||||
api_key = "not-set"
|
||||
|
||||
self.variant = get_variant(variant_name)
|
||||
self.thinking = thinking
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"model": model,
|
||||
|
|
@ -56,13 +63,28 @@ class Processor(LlmService):
|
|||
else:
|
||||
self.openai = OpenAI(api_key=api_key)
|
||||
|
||||
logger.info("OpenAI LLM service initialized")
|
||||
logger.info(
|
||||
f"OpenAI LLM service initialized "
|
||||
f"(variant={self.variant.name}, thinking={self.thinking})"
|
||||
)
|
||||
|
||||
def _build_kwargs(self, model_name, temperature):
|
||||
"""Build API call kwargs using the active variant."""
|
||||
return self.variant.completion_kwargs(
|
||||
max_output=self.max_output,
|
||||
temperature=temperature,
|
||||
thinking=self.thinking,
|
||||
)
|
||||
|
||||
def _extract_content(self, message):
|
||||
"""Extract visible content from a response message."""
|
||||
if hasattr(self.variant, "extract_content"):
|
||||
return self.variant.extract_content(message)
|
||||
return message.content
|
||||
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
|
|
@ -72,6 +94,8 @@ class Processor(LlmService):
|
|||
|
||||
try:
|
||||
|
||||
api_kwargs = self._build_kwargs(model_name, effective_temperature)
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[
|
||||
|
|
@ -85,18 +109,23 @@ class Processor(LlmService):
|
|||
]
|
||||
}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_completion_tokens=self.max_output,
|
||||
**api_kwargs,
|
||||
)
|
||||
|
||||
|
||||
inputtokens = resp.usage.prompt_tokens
|
||||
outputtokens = resp.usage.completion_tokens
|
||||
logger.debug(f"LLM response: {resp.choices[0].message.content}")
|
||||
|
||||
content = self._extract_content(resp.choices[0].message)
|
||||
thinking = self.variant.extract_thinking(resp.choices[0].message)
|
||||
|
||||
logger.debug(f"LLM response: {content}")
|
||||
if thinking:
|
||||
logger.debug(f"LLM thinking: {thinking[:200]}...")
|
||||
logger.info(f"Input Tokens: {inputtokens}")
|
||||
logger.info(f"Output Tokens: {outputtokens}")
|
||||
|
||||
resp = LlmResult(
|
||||
text = resp.choices[0].message.content,
|
||||
text = content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = model_name
|
||||
|
|
@ -136,9 +165,7 @@ class Processor(LlmService):
|
|||
Stream content generation from OpenAI.
|
||||
Yields LlmChunk objects with is_final=True on the last chunk.
|
||||
"""
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model (streaming): {model_name}")
|
||||
|
|
@ -147,6 +174,8 @@ class Processor(LlmService):
|
|||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
api_kwargs = self._build_kwargs(model_name, effective_temperature)
|
||||
|
||||
response = self.openai.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[
|
||||
|
|
@ -160,16 +189,14 @@ class Processor(LlmService):
|
|||
]
|
||||
}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_completion_tokens=self.max_output,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True}
|
||||
stream_options={"include_usage": True},
|
||||
**api_kwargs,
|
||||
)
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
# Stream chunks
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield LlmChunk(
|
||||
|
|
@ -254,6 +281,20 @@ class Processor(LlmService):
|
|||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--thinking',
|
||||
choices=["off", "low", "medium", "high"],
|
||||
default=default_thinking,
|
||||
help=f'Thinking/reasoning effort level (default: {default_thinking})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--variant',
|
||||
choices=sorted(VARIANTS.keys()),
|
||||
default=DEFAULT_VARIANT,
|
||||
help=f'API variant (default: {DEFAULT_VARIANT})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,176 @@
|
|||
"""
|
||||
OpenAI API variant profiles.
|
||||
|
||||
Different providers expose OpenAI-compatible APIs with subtle differences
|
||||
in parameter names, thinking/reasoning support, and temperature handling.
|
||||
Each variant encapsulates those quirks so the processor doesn't need
|
||||
provider-specific conditionals.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Variant:
|
||||
"""Base variant — defines the interface all variants implement."""
|
||||
|
||||
name = None
|
||||
token_param = "max_completion_tokens"
|
||||
temperature_with_thinking = False
|
||||
|
||||
def completion_kwargs(self, max_output, temperature, thinking):
|
||||
"""Build provider-specific kwargs for chat.completions.create().
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_output : int
|
||||
Configured max output tokens.
|
||||
temperature : float
|
||||
Configured temperature.
|
||||
thinking : str
|
||||
Thinking effort level: "off", "low", "medium", "high".
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Extra kwargs to spread into the API call.
|
||||
"""
|
||||
kwargs = {self.token_param: max_output}
|
||||
|
||||
if thinking != "off":
|
||||
kwargs.update(self.thinking_kwargs(thinking))
|
||||
if not self.temperature_with_thinking:
|
||||
kwargs["temperature"] = 1.0
|
||||
else:
|
||||
kwargs["temperature"] = temperature
|
||||
else:
|
||||
kwargs["temperature"] = temperature
|
||||
|
||||
return kwargs
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
"""Return kwargs to enable thinking at the given effort level."""
|
||||
return {}
|
||||
|
||||
def extract_thinking(self, message):
|
||||
"""Extract thinking/reasoning content from a response message."""
|
||||
return getattr(message, "reasoning_content", None)
|
||||
|
||||
def extract_thinking_stream(self, delta):
|
||||
"""Extract thinking content from a streaming delta."""
|
||||
return getattr(delta, "reasoning_content", None)
|
||||
|
||||
|
||||
class OpenAIVariant(Variant):
|
||||
"""Standard OpenAI API (GPT-4o, o1, o3, etc.)."""
|
||||
|
||||
name = "openai"
|
||||
token_param = "max_completion_tokens"
|
||||
temperature_with_thinking = False
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {"reasoning_effort": effort}
|
||||
|
||||
|
||||
class DeepSeekVariant(Variant):
|
||||
"""DeepSeek API (R1, V3, etc.)."""
|
||||
|
||||
name = "deepseek"
|
||||
token_param = "max_completion_tokens"
|
||||
temperature_with_thinking = True
|
||||
|
||||
def completion_kwargs(self, max_output, temperature, thinking):
|
||||
enabled = "enabled" if thinking != "off" else "disabled"
|
||||
kwargs = {
|
||||
self.token_param: max_output,
|
||||
"temperature": temperature,
|
||||
"extra_body": {
|
||||
"thinking": {"type": enabled},
|
||||
},
|
||||
}
|
||||
return kwargs
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {}
|
||||
|
||||
|
||||
class QwenVariant(Variant):
|
||||
"""Qwen / Alibaba Cloud API."""
|
||||
|
||||
name = "qwen"
|
||||
token_param = "max_completion_tokens"
|
||||
temperature_with_thinking = True
|
||||
|
||||
def completion_kwargs(self, max_output, temperature, thinking):
|
||||
enabled = thinking != "off"
|
||||
kwargs = {
|
||||
self.token_param: max_output,
|
||||
"temperature": temperature,
|
||||
"extra_body": {
|
||||
"enable_thinking": enabled,
|
||||
},
|
||||
}
|
||||
return kwargs
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {}
|
||||
|
||||
|
||||
class MistralVariant(Variant):
|
||||
"""Mistral API (Mistral Large, etc.)."""
|
||||
|
||||
name = "mistral"
|
||||
token_param = "max_tokens"
|
||||
temperature_with_thinking = False
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {"reasoning_effort": effort}
|
||||
|
||||
|
||||
class LlamaVariant(Variant):
|
||||
"""Llama models via OpenAI-compatible servers (vLLM, Ollama, etc.).
|
||||
|
||||
Thinking is typically always-on or always-off depending on the model.
|
||||
When present, thinking appears inline as <think>...</think> tags.
|
||||
"""
|
||||
|
||||
name = "llama"
|
||||
token_param = "max_tokens"
|
||||
temperature_with_thinking = True
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {}
|
||||
|
||||
def extract_thinking(self, message):
|
||||
content = message.content or ""
|
||||
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
def extract_content(self, message):
|
||||
"""Strip think tags from visible content."""
|
||||
content = message.content or ""
|
||||
return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
||||
|
||||
|
||||
VARIANTS = {
|
||||
"openai": OpenAIVariant,
|
||||
"deepseek": DeepSeekVariant,
|
||||
"qwen": QwenVariant,
|
||||
"mistral": MistralVariant,
|
||||
"llama": LlamaVariant,
|
||||
}
|
||||
|
||||
DEFAULT_VARIANT = "openai"
|
||||
|
||||
|
||||
def get_variant(name):
|
||||
"""Look up a variant by name, raising ValueError if unknown."""
|
||||
cls = VARIANTS.get(name)
|
||||
if cls is None:
|
||||
raise ValueError(
|
||||
f"Unknown variant {name!r}. "
|
||||
f"Available: {', '.join(sorted(VARIANTS))}"
|
||||
)
|
||||
return cls()
|
||||
Loading…
Add table
Add a link
Reference in a new issue