import base64
import io
import os
from typing import Any, Dict, List

import torch
import torchvision
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms as T

# Number of predictions returned when the request does not specify `parameters.top_k`.
_DEFAULT_TOP_K = 5


class EndpointHandler:
    """Image-classification handler backed by an EfficientNetV2-S with a custom head.

    The handler loads fine-tuned weights from ``<path>/model.safetensors`` once at
    construction time. Each call classifies a single image into one of the labels
    in ``self.labels``.
    """

    def __init__(self, path: str = "."):
        """Build the model, load its weights and set up preprocessing.

        Args:
            path: Directory containing ``model.safetensors``.

        Raises:
            RuntimeError: from ``load_state_dict`` if the checkpoint's keys or
                shapes do not match the architecture built here.
        """
        # Output index i of the classifier corresponds to self.labels[i].
        self.labels = ['battery', 'biological', 'cardboard', 'clothes', 'glass',
                       'metal', 'paper', 'plastic', 'shoes', 'trash']

        self.model = torchvision.models.efficientnet_v2_s(weights=None)
        nf = self.model.classifier[1].in_features
        # Replace the stock head. This layout must match the one the checkpoint
        # was saved with, or the strict load_state_dict below fails.
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Linear(nf, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, len(self.labels)),
        )
        state = load_file(os.path.join(path, "model.safetensors"), device="cpu")
        self.model.load_state_dict(state)
        # Inference mode for layers like Dropout / BatchNorm.
        self.model.eval()

        # Mean/std are the standard ImageNet channel statistics. Presumably the
        # model was fine-tuned with the same normalization (confirm vs. training).
        self.preprocess = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _load_image(inputs: Any) -> Image.Image:
        """Decode a request payload into an RGB PIL image.

        Accepts a PIL image, raw encoded image bytes, or a base64 string. The
        base64 string may carry a ``data:<mime>;base64,`` prefix.

        Raises:
            TypeError: if ``inputs`` is none of the supported types.
        """
        if isinstance(inputs, Image.Image):
            # Some serving stacks deserialize image/* bodies into PIL images already.
            return inputs.convert("RGB")
        if isinstance(inputs, str):
            # Strip a data-URL header. Its characters are mostly valid base64,
            # so b64decode would otherwise silently yield corrupted bytes.
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
            inputs = base64.b64decode(inputs)
        if isinstance(inputs, (bytes, bytearray, memoryview)):
            # The context manager closes the lazily-read source after conversion.
            with Image.open(io.BytesIO(inputs)) as img:
                return img.convert("RGB")
        raise TypeError(
            f"Unsupported 'inputs' type {type(inputs).__name__}; "
            "expected PIL.Image, bytes, or base64 str"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one image and return the highest-scoring labels.

        Args:
            data: Request payload. ``data["inputs"]`` holds the image (see
                ``_load_image``). An optional ``data["parameters"]["top_k"]``
                sets how many predictions to return (default 5). It is clamped
                to ``[1, len(self.labels)]``.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts sorted by
            descending softmax probability.
        """
        img = self._load_image(data["inputs"])

        params = data.get("parameters") or {}
        top_k = int(params.get("top_k", _DEFAULT_TOP_K))
        top_k = max(1, min(top_k, len(self.labels)))

        x = self.preprocess(img).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            probs = self.model(x).softmax(dim=1)[0]

        scores, indices = probs.topk(top_k)
        return [
            {"label": self.labels[idx], "score": score}
            for score, idx in zip(scores.tolist(), indices.tolist())
        ]