"""Custom inference handler that classifies a waste image into one of ten labels."""

import base64
import io
from pathlib import Path
from typing import Any, Dict, List

import torch
import torchvision
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms as T


class EndpointHandler:
    """Load an EfficientNetV2-S classifier and serve top-k predictions.

    The handler is constructed once with the repository root and then called
    with a request payload of the form ``{"inputs": <image>}``.
    """

    # NOTE(review): the weights filename is not visible in the original code
    # (it referenced an undefined ``pth``). ``model.safetensors`` is assumed;
    # confirm against the repository contents.
    WEIGHTS_FILENAME = "model.safetensors"

    # Maximum number of predictions returned per request.
    TOP_K = 5

    def __init__(self, path='.'):
        """Build the network, load its weights and prepare preprocessing.

        Args:
            path: Repository root in the container, or a direct path to a
                ``.safetensors`` weights file.

        Raises:
            FileNotFoundError: If no weights file can be located under
                ``path``.
        """
        # path = repo root in container
        self.labels = [
            'battery', 'biological', 'cardboard', 'clothes', 'glass',
            'metal', 'paper', 'plastic', 'shoes', 'trash',
        ]

        self.model = torchvision.models.efficientnet_v2_s(weights=None)
        nf = self.model.classifier[1].in_features
        # Replace the stock classifier with a head sized to our label set.
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Linear(nf, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, len(self.labels)),
        )

        pth = self._resolve_weights(path)
        state = load_file(str(pth), device="cpu")
        self.model.load_state_dict(state)
        self.model.eval()

        self.trans = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    @classmethod
    def _resolve_weights(cls, path) -> Path:
        """Return the safetensors weights file located at or under ``path``."""
        root = Path(path)
        if root.is_file():
            return root

        preferred = root / cls.WEIGHTS_FILENAME
        if preferred.is_file():
            return preferred

        # Fall back to a lone *.safetensors file so a differently named
        # checkpoint still loads; refuse to guess between several.
        candidates = sorted(root.glob("*.safetensors"))
        if len(candidates) == 1:
            return candidates[0]

        raise FileNotFoundError(
            f"Could not locate weights file '{cls.WEIGHTS_FILENAME}' in {root}"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the image found under ``data['inputs']``.

        Args:
            data: Request payload. ``data['inputs']`` may be raw image bytes,
                a base64-encoded string of those bytes, or an already decoded
                ``PIL.Image.Image``.

        Returns:
            Up to ``TOP_K`` dicts, highest score first, each shaped as
            ``{'label': str, 'score': float}``.
        """
        payload = data['inputs']
        if isinstance(payload, Image.Image):
            # The serving layer may already have decoded the image.
            img = payload
        else:
            if isinstance(payload, str):  # base64
                payload = base64.b64decode(payload)
            img = Image.open(io.BytesIO(payload))
        img = img.convert('RGB')

        x = self.trans(img).unsqueeze(0)
        with torch.no_grad():
            probs = self.model(x).softmax(1)[0]

        # Clamp k so a shorter label list cannot make topk() fail.
        topk = probs.topk(min(self.TOP_K, len(self.labels)))
        scores = topk.values.tolist()
        indices = topk.indices.tolist()
        return [
            {'label': self.labels[idx], 'score': score}
            for idx, score in zip(indices, scores)
        ]