Update game/game.py
Browse files- game/game.py +29 -2
game/game.py
CHANGED
@@ -173,6 +173,28 @@ def load_exported_model(model_dir, model_type):
|
|
173 |
session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
|
174 |
|
175 |
# Create a wrapper class to mimic the PyTorch model interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
class ONNXModelWrapper:
|
177 |
"""
|
178 |
- predict(image_path=..., threshold=...) -> {"refined_probabilities": np.ndarray[N, C]}
|
@@ -196,8 +218,13 @@ def load_exported_model(model_dir, model_type):
|
|
196 |
self.idx_to_tag = {}
|
197 |
|
198 |
self.tag_to_category = self.metadata.get("tag_to_category", {})
|
199 |
-
if not self.tag_to_category
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
# Compatibility shim for scan_handler: model.dataset.get_tag_info(...)
|
203 |
self.dataset = self
|
|
|
173 |
session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
|
174 |
|
175 |
# Create a wrapper class to mimic the PyTorch model interface
|
176 |
+
class ONNXModelWrapper:
|
177 |
+
"""
|
178 |
+
- predict(image_path=..., threshold=...) -> {"refined_probabilities": np.ndarray[N, C]}
|
179 |
+
- dataset.get_tag_info(idx) -> (tag, category)
|
180 |
+
- Keeps your signature compatible, ignores 'threshold' in the wrapper.
|
181 |
+
"""
|
182 |
+
def __init__(self, session, metadata: dict):
|
183 |
+
self.session = session
|
184 |
+
self.metadata = metadata or {}
|
185 |
+
dataset_info = self.metadata.get("dataset_info", {})
|
186 |
+
self.total_tags = dataset_info.get("total_tags", 0)
|
187 |
+
|
188 |
+
# idx <-> tag mapping
|
189 |
+
tag_mapping = dataset_info.get("tag_mapping", {})
|
190 |
+
if "idx_to_tag" in tag_mapping:
|
191 |
+
self.idx_to_tag = {int(k): v for k, v in tag_mapping["idx_to_tag"].items()}
|
192 |
+
elif "tag_to_idx" in tag_mapping:
|
193 |
+
t2i = tag_mapping["tag_to_idx"]
|
194 |
+
self.idx_to_tag = {v: k for k, v in t2i.items()}
|
195 |
+
else:
|
196 |
+
self.idx_to_tag = {}
|
197 |
+
|
198 |
class ONNXModelWrapper:
|
199 |
"""
|
200 |
- predict(image_path=..., threshold=...) -> {"refined_probabilities": np.ndarray[N, C]}
|
|
|
218 |
self.idx_to_tag = {}
|
219 |
|
220 |
self.tag_to_category = self.metadata.get("tag_to_category", {})
|
221 |
+
if not self.tag_to_category:
|
222 |
+
# Try to get from dataset_info.tag_mapping.tag_to_category (correct path)
|
223 |
+
if "tag_mapping" in dataset_info and "tag_to_category" in dataset_info["tag_mapping"]:
|
224 |
+
self.tag_to_category = dataset_info["tag_mapping"]["tag_to_category"]
|
225 |
+
# Fallback to direct path in case structure varies
|
226 |
+
elif "tag_to_category" in dataset_info:
|
227 |
+
self.tag_to_category = dataset_info["tag_to_category"]
|
228 |
|
229 |
# Compatibility shim for scan_handler: model.dataset.get_tag_info(...)
|
230 |
self.dataset = self
|