# Author: attilaultzindur
# Commit 35f28e0: "model & handler update"
from typing import Dict, Any
import io, base64, torch, torchvision
from PIL import Image
from torchvision import transforms as T
from safetensors.torch import load_file
class EndpointHandler:
    """Inference handler that classifies an image into one of ten waste categories.

    The network is a torchvision EfficientNetV2-S whose stock classifier is
    replaced by a small MLP head. Weights are loaded from
    ``<path>/model.safetensors``. Each call returns the top-k labels together
    with their softmax probabilities.
    """

    # Number of predictions returned when the request does not set
    # ``parameters["top_k"]`` (the value that used to be hard-coded).
    DEFAULT_TOP_K = 5

    def __init__(self, path: str = "."):
        """Build the model, load its weights and set up preprocessing.

        Args:
            path: Directory that contains ``model.safetensors``.
        """
        # Order of the classifier head's output units.
        # NOTE(review): presumably the class order used in training — confirm.
        self.labels = [
            'battery', 'biological', 'cardboard', 'clothes', 'glass',
            'metal', 'paper', 'plastic', 'shoes', 'trash',
        ]
        # Architecture only (weights=None); the real parameters come from the
        # checkpoint loaded below.
        self.model = torchvision.models.efficientnet_v2_s(weights=None)
        # Feature width feeding the stock classifier's Linear layer (index 1).
        nf = self.model.classifier[1].in_features
        # Custom head. Its layer layout must match the keys stored in
        # model.safetensors, because load_state_dict() below is strict by default.
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Linear(nf, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, len(self.labels)),
        )
        state = load_file(f"{path}/model.safetensors", device="cpu")
        self.model.load_state_dict(state)
        # Evaluation mode: dropout becomes a no-op at prediction time.
        self.model.eval()
        # Fixed 224x224 resize, then normalization with the standard ImageNet
        # per-channel mean/std.
        self.preprocess = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Dict[str, Any]):
        """Classify one image and return the top-k predictions.

        Args:
            data: Request payload. ``data["inputs"]`` may be one of:
                - raw image bytes;
                - a base64 string, optionally with a ``data:<mime>;base64,``
                  prefix;
                - an already-decoded ``PIL.Image.Image``.
                An optional ``data["parameters"]["top_k"]`` overrides the
                default number of predictions (``DEFAULT_TOP_K``).

        Returns:
            A list of ``{"label": str, "score": float}`` dicts in descending
            score order, at most ``len(self.labels)`` entries long.

        Raises:
            KeyError: if ``"inputs"`` is missing from *data*.
            ValueError: if ``top_k`` is not a positive integer.
        """
        img = self._load_image(data["inputs"])
        k = self._resolve_top_k(data.get("parameters"), len(self.labels),
                                self.DEFAULT_TOP_K)
        x = self.preprocess(img).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            probs = self.model(x).softmax(dim=1)[0]
        top = probs.topk(k)
        return [
            {"label": self.labels[idx], "score": score}
            for score, idx in zip(top.values.tolist(), top.indices.tolist())
        ]

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode a base64 string, tolerating a leading ``data:<mime>;base64,`` prefix."""
        if payload.startswith("data:"):
            # Without stripping, the header's letters and '/' are valid base64
            # characters and would be decoded into garbage bytes.
            _header, sep, encoded = payload.partition(",")
            if sep:
                payload = encoded
        return base64.b64decode(payload)

    @staticmethod
    def _load_image(inputs: Any):
        """Turn a request's ``inputs`` value into an RGB ``PIL.Image.Image``."""
        if isinstance(inputs, Image.Image):
            # The HF inference toolkit may already have decoded image payloads.
            img = inputs
        else:
            if isinstance(inputs, str):
                inputs = EndpointHandler._decode_base64(inputs)
            img = Image.open(io.BytesIO(inputs))
        return img.convert("RGB")

    @staticmethod
    def _resolve_top_k(parameters: Any, num_labels: int, default: int) -> int:
        """Return the requested top-k (or *default*), capped at *num_labels*.

        Raises:
            ValueError: if the requested value is smaller than 1.
        """
        k = default
        if isinstance(parameters, dict) and parameters.get("top_k") is not None:
            k = int(parameters["top_k"])
        if k < 1:
            raise ValueError("top_k must be a positive integer")
        # torch.topk raises if k exceeds the number of classes.
        return min(k, num_labels)