diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index 01035bc9..57958bc0 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -96,20 +96,20 @@ class Processor(LlmService): api_kwargs = self._build_kwargs(model_name, effective_temperature) - resp = self.openai.chat.completions.create( - model=model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] - } - ], - **api_kwargs, + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ] + + resp = self.variant.create_completion( + self.openai, model_name, messages, **api_kwargs, ) inputtokens = resp.usage.prompt_tokens @@ -176,28 +176,24 @@ class Processor(LlmService): try: api_kwargs = self._build_kwargs(model_name, effective_temperature) - response = self.openai.chat.completions.create( - model=model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] - } - ], - stream=True, - stream_options={"include_usage": True}, - **api_kwargs, - ) + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ] total_input_tokens = 0 total_output_tokens = 0 - for chunk in response: + async for chunk in self.variant.create_completion_stream( + self.openai, model_name, messages, **api_kwargs, + ): if chunk.choices and chunk.choices[0].delta.content: yield LlmChunk( text=chunk.choices[0].delta.content, diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py index 7e650e37..87de725d 100644 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py @@ -62,6 +62,20 @@ class Variant: """Extract thinking content from a streaming delta.""" return getattr(delta, "reasoning_content", None) + def create_completion(self, client, model, messages, **kwargs): + """Call the completions API. Override for non-standard SDKs.""" + return client.chat.completions.create( + model=model, messages=messages, **kwargs, + ) + + async def create_completion_stream(self, client, model, messages, **kwargs): + """Call the streaming completions API. Override for non-standard SDKs.""" + for chunk in client.chat.completions.create( + model=model, messages=messages, stream=True, + stream_options={"include_usage": True}, **kwargs, + ): + yield chunk + class OpenAIVariant(Variant): """Standard OpenAI API (GPT-4o, o1, o3, etc.).""" @@ -96,30 +110,8 @@ class DeepSeekVariant(Variant): return {} -class QwenVariant(Variant): - """Qwen / Alibaba Cloud API.""" - - name = "qwen" - token_param = "max_completion_tokens" - temperature_with_thinking = True - - def completion_kwargs(self, max_output, temperature, thinking): - enabled = thinking != "off" - kwargs = { - self.token_param: max_output, - "temperature": temperature, - "extra_body": { - "enable_thinking": enabled, - }, - } - return kwargs - - def thinking_kwargs(self, effort): - return {} - - class DashScopeVariant(Variant): - """Alibaba Cloud DashScope API (Qwen models via DashScope).""" + """Alibaba Cloud DashScope API (Qwen models).""" name = "dashscope" token_param = "max_completion_tokens" @@ -127,17 +119,24 @@ class DashScopeVariant(Variant): def completion_kwargs(self, max_output, temperature, thinking): enabled = thinking != "off" - kwargs = { + return { self.token_param: max_output, "temperature": temperature, - "enable_thinking": enabled, + "extra_body": { + "enable_thinking": enabled, + }, } - return kwargs def thinking_kwargs(self, effort): return {} +class QwenVariant(DashScopeVariant): + """Qwen — alias for DashScope.""" + + name = "qwen" + + class MistralVariant(Variant): """Mistral API (Mistral Large, etc.)."""