File size: 1,595 Bytes
8c3c7d8
 
 
 
35f28e0
8c3c7d8
 
35f28e0
8c3c7d8
35f28e0
8c3c7d8
 
 
 
 
35f28e0
8c3c7d8
35f28e0
d373c9b
8c3c7d8
35f28e0
 
8c3c7d8
 
35f28e0
 
8c3c7d8
 
 
35f28e0
 
8c3c7d8
35f28e0
 
8c3c7d8
 
 
35f28e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from typing import Dict, Any
import io, base64, torch, torchvision
from PIL import Image
from torchvision import transforms as T
from safetensors.torch import load_file

class EndpointHandler:
    """Callable image-classification handler built on EfficientNetV2-S.

    The instance loads fine-tuned weights once at construction time and is
    then called with a request dictionary; each call returns the best
    scoring labels for the submitted image.
    """

    # Number of predictions returned when the request does not ask for a
    # specific amount via ``parameters.top_k``.
    DEFAULT_TOP_K = 5

    def __init__(self, path: str = "."):
        """Build the network and load its weights.

        Args:
            path: Directory that contains ``model.safetensors``.
        """
        self.labels = ['battery', 'biological', 'cardboard', 'clothes', 'glass', 'metal', 'paper', 'plastic', 'shoes', 'trash']
        # The architecture is created without pretrained weights because the
        # full state dict (backbone + custom head) is loaded right below.
        self.model = torchvision.models.efficientnet_v2_s(weights=None)
        nf = self.model.classifier[1].in_features
        # Replace the stock classifier with a two-layer head sized to the
        # label list; the layout must match the checkpoint being loaded.
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Linear(nf, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, len(self.labels)),
        )
        state = load_file(f"{path}/model.safetensors", device="cpu")
        self.model.load_state_dict(state)
        # Inference only: eval() disables dropout and freezes norm statistics.
        self.model.eval()

        # Fixed 224x224 resize followed by normalisation with the standard
        # ImageNet channel mean/std.
        self.preprocess = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    def _load_image(self, payload):
        """Turn the request payload into an RGB image.

        Accepted payloads:
            * an already decoded image object (anything exposing
              ``convert``, e.g. a ``PIL.Image.Image``);
            * a base64 string, optionally prefixed with a
              ``data:<mime>;base64,`` header;
            * raw encoded image bytes.

        Raises:
            TypeError: If the payload is none of the above.
        """
        # Some serving layers decode image requests before calling the
        # handler; re-opening such an object through BytesIO would fail.
        if hasattr(payload, "convert"):
            return payload.convert("RGB")
        if isinstance(payload, str):
            if payload.startswith("data:"):
                # Drop the data-URI header, keep only the base64 body.
                payload = payload.partition(",")[2]
            payload = base64.b64decode(payload)
        if not isinstance(payload, (bytes, bytearray, memoryview)):
            raise TypeError(
                "inputs must be an image, raw image bytes or a base64 string, "
                f"got {type(payload).__name__}"
            )
        return Image.open(io.BytesIO(payload)).convert("RGB")

    def __call__(self, data: Dict[str, Any]):
        """Classify one image.

        Args:
            data: Request dictionary. ``data["inputs"]`` carries the image
                (see ``_load_image`` for accepted forms). An optional
                ``data["parameters"]["top_k"]`` selects how many predictions
                to return; it defaults to ``DEFAULT_TOP_K`` and is clamped to
                the range ``1..len(self.labels)``.

        Returns:
            A list of ``{"label": str, "score": float}`` dictionaries sorted
            by descending score.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
            TypeError: If the payload type is not supported.
        """
        img = self._load_image(data["inputs"])

        params = data.get("parameters")
        requested = params.get("top_k") if isinstance(params, dict) else None
        top_k = self.DEFAULT_TOP_K if requested is None else int(requested)
        # torch.topk raises when k exceeds the number of classes.
        top_k = max(1, min(top_k, len(self.labels)))

        batch = self.preprocess(img).unsqueeze(0)  # add the batch dimension
        with torch.no_grad():
            probs = self.model(batch).softmax(1)[0]
        best = probs.topk(top_k)
        # tolist() yields plain Python numbers, so the result is directly
        # JSON-serialisable and labels are indexed with real ints.
        return [{"label": self.labels[idx], "score": score}
                for score, idx in zip(best.values.tolist(), best.indices.tolist())]