from typing import Any, Dict

import base64
import io

import torch
import torchvision
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms as T


class EndpointHandler:
    """Custom inference handler for an EfficientNetV2-S waste classifier."""

    def __init__(self, path: str = "."):
        # Class names, in the same order as the model's output logits.
        self.labels = [
            "battery", "biological", "cardboard", "clothes", "glass",
            "metal", "paper", "plastic", "shoes", "trash",
        ]

        # Rebuild the fine-tuned architecture: an EfficientNetV2-S backbone
        # with a custom two-layer classification head.
        self.model = torchvision.models.efficientnet_v2_s(weights=None)
        nf = self.model.classifier[1].in_features
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Linear(nf, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, len(self.labels)),
        )

        # Load the fine-tuned weights shipped alongside this handler.
        state = load_file(f"{path}/model.safetensors", device="cpu")
        self.model.load_state_dict(state)
        self.model.eval()

        # Resize to 224x224 and apply standard ImageNet normalization.
        self.preprocess = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Dict[str, Any]):
        # Request payload: "inputs" may be raw image bytes or a
        # base64-encoded string.
        img_bytes = data["inputs"]
        if isinstance(img_bytes, str):
            img_bytes = base64.b64decode(img_bytes)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        # Run a single-image batch through the network without tracking gradients.
        x = self.preprocess(img).unsqueeze(0)
        with torch.no_grad():
            probs = self.model(x).softmax(dim=1)[0]

        # Return the five most likely classes with their probabilities.
        topk = probs.topk(5)
        return [
            {"label": self.labels[int(idx)], "score": float(score)}
            for score, idx in zip(topk.values, topk.indices)
        ]
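

# --- Local smoke test (illustrative only, not used by the endpoint runtime) ---
# A minimal sketch of how the handler can be exercised locally, assuming
# model.safetensors is in the current directory and a sample image exists at
# the hypothetical path "sample.jpg". The payload shape mirrors what __call__
# expects: {"inputs": <raw bytes or base64-encoded string>}.
if __name__ == "__main__":
    with open("sample.jpg", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}
    handler = EndpointHandler(".")
    for prediction in handler(payload):
        print(f"{prediction['label']}: {prediction['score']:.4f}")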