mirror of
https://github.com/katanemo/plano.git
synced 2026-05-04 13:23:00 +02:00
Refactor model server hardware config + add unit tests to load/request to the server (#189)
* remove mode/hardware * add test and pre commit hook * add pytest dependencies * fix format * fix lint * fix precommit * fix pre commit * fix pre commit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit
This commit is contained in:
parent
3bd2ffe9fb
commit
8e54ac20d8
13 changed files with 480 additions and 43 deletions
|
|
@ -11,15 +11,14 @@ class ArchGuardHanlder:
|
|||
self.model = model_dict["model"]
|
||||
self.tokenizer = model_dict["tokenizer"]
|
||||
self.device = model_dict["device"]
|
||||
self.hardware_config = model_dict["hardware_config"]
|
||||
|
||||
self.threshold = threshold
|
||||
|
||||
def guard_predict(self, input_text):
|
||||
def guard_predict(self, input_text, max_length=512):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
inputs = self.tokenizer(
|
||||
input_text, truncation=True, max_length=512, return_tensors="pt"
|
||||
input_text, truncation=True, max_length=max_length, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue