Adds basic metering infrastructure (#68)

* Basic metering module structure
* Token counting working for Bedrock
* Price calc using price list
* Added more models to pricelist
* Added Ollama token counts
----
Authored-by: JackColquitt <daniel@kalntera.ai>
cybermaggedon 2024-09-28 20:48:20 +01:00 committed by GitHub
parent 8085bb0118
commit 2a49365482
12 changed files with 176 additions and 8 deletions

@@ -0,0 +1,3 @@
from . counter import *

@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . counter import run

if __name__ == '__main__':
    run()

@@ -0,0 +1,71 @@
"""
Simple token counter for each LLM response.
"""

from prometheus_client import Histogram, Info

from . pricelist import price_list
from .. schema import TextCompletionResponse, Error
from .. schema import text_completion_response_queue
from .. log_level import LogLevel
from .. base import Consumer

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = text_completion_response_queue
default_subscriber = module

class Processor(Consumer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        subscriber = params.get("subscriber", default_subscriber)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "subscriber": subscriber,
                "input_schema": TextCompletionResponse,
            }
        )

    def get_prices(self, prices, modelname):
        for model in prices["price_list"]:
            if model["model_name"] == modelname:
                return model["input_price"], model["output_price"]
        return None, None  # Return None if model is not found

    def handle(self, msg):

        v = msg.value()
        modelname = v.model

        # Sender-produced ID
        id = msg.properties()["id"]

        print(f"Handling response {id}...", flush=True)

        num_in = v.in_token
        num_out = v.out_token

        model_input_price, model_output_price = self.get_prices(price_list, modelname)

        cost_in = num_in * model_input_price
        cost_out = num_out * model_output_price
        cost_per_call = cost_in + cost_out

        print(f"Input Tokens: {num_in}", flush=True)
        print(f"Output Tokens: {num_out}", flush=True)
        print(f"Cost for call: ${cost_per_call:.6f}", flush=True)

    @staticmethod
    def add_args(parser):
        Consumer.add_args(
            parser, default_input_queue, default_subscriber,
        )

def run():
    Processor.start(module, __doc__)
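
The cost arithmetic above can be exercised on its own, outside the Pulsar consumer. The following is a minimal standalone sketch, not part of the commit; the import path and sample token counts are illustrative:

# Standalone sketch of the counter's price lookup and cost arithmetic.
# The import path is an assumption; adjust to the actual package layout.
from pricelist import price_list

def get_prices(prices, modelname):
    for model in prices["price_list"]:
        if model["model_name"] == modelname:
            return model["input_price"], model["output_price"]
    return None, None

def cost_for_call(modelname, num_in, num_out):
    input_price, output_price = get_prices(price_list, modelname)
    if input_price is None:
        raise ValueError(f"No price entry for model: {modelname}")
    return num_in * input_price + num_out * output_price

# Example: 1,200 input tokens and 350 output tokens on Claude 3 Haiku
print(f"${cost_for_call('anthropic.claude-3-haiku-20240307-v1:0', 1200, 350):.6f}")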

@@ -0,0 +1,49 @@
price_list = {
    "price_list": [
        {
            "model_name": "mistral.mistral-large-2407-v1:0",
            "input_price": 0.000004,
            "output_price": 0.000012
        },
        {
            "model_name": "meta.llama3-1-405b-instruct-v1:0",
            "input_price": 0.00000532,
            "output_price": 0.000016
        },
        {
            "model_name": "mistral.mixtral-8x7b-instruct-v0:1",
            "input_price": 0.00000045,
            "output_price": 0.0000007
        },
        {
            "model_name": "meta.llama3-1-70b-instruct-v1:0",
            "input_price": 0.00000099,
            "output_price": 0.00000099
        },
        {
            "model_name": "meta.llama3-1-8b-instruct-v1:0",
            "input_price": 0.00000022,
            "output_price": 0.00000022
        },
        {
            "model_name": "anthropic.claude-3-haiku-20240307-v1:0",
            "input_price": 0.00000025,
            "output_price": 0.00000125
        },
        {
            "model_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
            "input_price": 0.000003,
            "output_price": 0.000015
        },
        {
            "model_name": "cohere.command-r-plus-v1:0",
            "input_price": 0.0000030,
            "output_price": 0.0000150
        },
        {
            "model_name": "ollama",
            "input_price": 0,
            "output_price": 0
        }
    ]
}
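
The prices read as US dollars per token, so multiplying by one million gives the more familiar per-million-token rates. A small sanity-check sketch, not part of the commit, assuming price_list is in scope:

# Convert per-token prices to per-million-token rates, e.g.
# anthropic.claude-3-haiku: $0.25 per million input tokens,
# $1.25 per million output tokens.
for entry in price_list["price_list"]:
    name = entry["model_name"]
    per_m_in = entry["input_price"] * 1_000_000
    per_m_out = entry["output_price"] * 1_000_000
    print(f"{name}: ${per_m_in:.2f} in / ${per_m_out:.2f} out per million tokens")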

@@ -209,14 +209,23 @@ class Processor(ConsumerProducer):

        # Use Mistral as default
        else:

            response_body = json.loads(response.get("body").read())
            outputtext = response_body['outputs'][0]['text']
+           metadata = response['ResponseMetadata']['HTTPHeaders']
+           inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
+           outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])

        print(outputtext, flush=True)
+       print(f"Input Tokens: {inputtokens}", flush=True)
+       print(f"Output Tokens: {outputtokens}", flush=True)

        print("Send response...", flush=True)

        r = TextCompletionResponse(
            error=None,
-           response=outputtext
+           response=outputtext,
+           in_token=inputtokens,
+           out_token=outputtokens,
+           model=str(self.model),
        )

        self.send(r, properties={"id": id})
@@ -236,6 +245,9 @@ class Processor(ConsumerProducer):
                message = str(e),
            ),
            response=None,
+           in_token=None,
+           out_token=None,
+           model=None,
        )

        self.producer.send(r, properties={"id": id})
@@ -254,6 +266,9 @@ class Processor(ConsumerProducer):
                message = str(e),
            ),
            response=None,
+           in_token=None,
+           out_token=None,
+           model=None,
        )

        self.consumer.acknowledge(msg)
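
The Bedrock token counts used above come back as HTTP headers on the invoke_model response, so no separate metering call is needed. A minimal standalone sketch, not part of the commit; the model ID and request body are illustrative (the body format is model-specific), and region/credentials are assumed to be configured:

# Read Bedrock token counts from the invoke_model response metadata,
# using the same headers as the diff above.
import json
import boto3

client = boto3.client("bedrock-runtime")

response = client.invoke_model(
    modelId="mistral.mistral-large-2407-v1:0",                # illustrative model ID
    body=json.dumps({"prompt": "Hello", "max_tokens": 64}),   # body format is model-specific
)

headers = response["ResponseMetadata"]["HTTPHeaders"]
in_tokens = int(headers["x-amzn-bedrock-input-token-count"])
out_tokens = int(headers["x-amzn-bedrock-output-token-count"])
print(f"Input Tokens: {in_tokens}, Output Tokens: {out_tokens}")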

@@ -4,7 +4,7 @@ Simple LLM service, performs text prompt completion using an Ollama service.
Input is prompt, output is response.
"""

-from langchain_community.llms import Ollama
+from ollama import Client

from prometheus_client import Histogram, Info

from .... schema import TextCompletionRequest, TextCompletionResponse, Error
@@ -67,7 +67,8 @@ class Processor(ConsumerProducer):
            "ollama": ollama,
        })

-       self.llm = Ollama(base_url=ollama, model=model)
+       self.model = model
+       self.llm = Client(host=ollama)

    def handle(self, msg):
@@ -83,11 +84,16 @@

        try:

            with __class__.text_completion_metric.time():
-               response = self.llm.invoke(prompt)
+               response = self.llm.generate(self.model, prompt)
+               response_text = response['response']

            print("Send response...", flush=True)
+           print(response_text, flush=True)

-           r = TextCompletionResponse(response=response, error=None)
+           inputtokens = int(response['prompt_eval_count'])
+           outputtokens = int(response['eval_count'])
+           r = TextCompletionResponse(response=response_text, error=None, in_token=inputtokens, out_token=outputtokens, model="ollama")

            self.send(r, properties={"id": id})
@@ -105,6 +111,9 @@
                message = str(e),
            ),
            response=None,
+           in_token=None,
+           out_token=None,
+           model=None,
        )

        self.producer.send(r, properties={"id": id})
@@ -123,6 +132,9 @@
                message = str(e),
            ),
            response=None,
+           in_token=None,
+           out_token=None,
+           model=None,
        )

        self.producer.send(r, properties={"id": id})
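
Ollama reports usage directly on each generate() response, which is where the prompt_eval_count and eval_count fields read above come from. A minimal standalone sketch, not part of the commit; host and model name are placeholders:

# The ollama client's generate() response carries the token counts
# read in the diff above (prompt_eval_count / eval_count).
from ollama import Client

client = Client(host="http://localhost:11434")             # placeholder host
resp = client.generate("gemma2", "Why is the sky blue?")   # placeholder model

print(resp["response"])                  # completion text
print(int(resp["prompt_eval_count"]))    # input token count
print(int(resp["eval_count"]))           # output token count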

@@ -1,5 +1,5 @@
-from pulsar.schema import Record, String, Array, Double
+from pulsar.schema import Record, String, Array, Double, Integer

from . topic import topic
from . types import Error

@@ -14,6 +14,9 @@ class TextCompletionRequest(Record):

class TextCompletionResponse(Record):
    error = Error()
    response = String()
+   in_token = Integer()
+   out_token = Integer()
+   model = String()

text_completion_request_queue = topic(
    'text-completion', kind='non-persistent', namespace='request'