vertigoq3 commited on
Commit
8c3fc6d
·
verified ·
1 Parent(s): 4383ca8

Add inference endpoint handler

Browse files
Files changed (1) hide show
  1. handler.py +85 -0
handler.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Handler para el Inference Endpoint del clasificador de emails
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ import pickle
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ class EndpointHandler:
12
+ def __init__(self):
13
+ self.model = None
14
+ self.tokenizer = None
15
+ self.encoder = None
16
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ self.load_model()
18
+
19
+ def load_model(self):
20
+ """Cargar el modelo"""
21
+ try:
22
+ # Cargar modelo y tokenizer
23
+ self.model = AutoModelForSequenceClassification.from_pretrained("vertigoq3/email-classifier-bert")
24
+ self.tokenizer = AutoTokenizer.from_pretrained("vertigoq3/email-classifier-bert")
25
+
26
+ # Mover al dispositivo
27
+ self.model.to(self.device)
28
+ self.model.eval()
29
+
30
+ # Cargar encoder
31
+ encoder_path = hf_hub_download(
32
+ repo_id="vertigoq3/email-classifier-bert",
33
+ filename="label_encoder.pkl"
34
+ )
35
+
36
+ with open(encoder_path, "rb") as f:
37
+ self.encoder = pickle.load(f)
38
+
39
+ except Exception as e:
40
+ print(f"Error al cargar modelo: {e}")
41
+ raise
42
+
43
+ def __call__(self, inputs):
44
+ """Procesar una solicitud de inferencia"""
45
+ try:
46
+ if isinstance(inputs, str):
47
+ text = inputs
48
+ elif isinstance(inputs, dict) and "inputs" in inputs:
49
+ text = inputs["inputs"]
50
+ else:
51
+ text = str(inputs)
52
+
53
+ # Tokenizar
54
+ tokenized = self.tokenizer(
55
+ text,
56
+ return_tensors="pt",
57
+ truncation=True,
58
+ padding=True,
59
+ max_length=512
60
+ )
61
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
62
+
63
+ # Clasificar
64
+ with torch.no_grad():
65
+ outputs = self.model(**tokenized)
66
+ logits = outputs.logits
67
+ probabilities = torch.softmax(logits, dim=-1)
68
+ predicted_class_id = torch.argmax(probabilities, dim=-1).item()
69
+ predicted_class = self.encoder.inverse_transform([predicted_class_id])[0]
70
+ confidence = float(probabilities[0][predicted_class_id])
71
+
72
+ return {
73
+ "predicted_class": predicted_class,
74
+ "confidence": confidence,
75
+ "all_probabilities": {
76
+ self.encoder.classes_[i]: float(probabilities[0][i])
77
+ for i in range(len(self.encoder.classes_))
78
+ }
79
+ }
80
+
81
+ except Exception as e:
82
+ return {"error": str(e)}
83
+
84
+ # Crear instancia global
85
+ handler = EndpointHandler()