Camais03 commited on
Commit
3e5d145
·
verified ·
1 Parent(s): 6393fb5

Update game/game.py

Browse files
Files changed (1) hide show
  1. game/game.py +29 -2
game/game.py CHANGED
@@ -173,6 +173,28 @@ def load_exported_model(model_dir, model_type):
173
  session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
174
 
175
  # Create a wrapper class to mimic the PyTorch model interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  class ONNXModelWrapper:
177
  """
178
  - predict(image_path=..., threshold=...) -> {"refined_probabilities": np.ndarray[N, C]}
@@ -196,8 +218,13 @@ def load_exported_model(model_dir, model_type):
196
  self.idx_to_tag = {}
197
 
198
  self.tag_to_category = self.metadata.get("tag_to_category", {})
199
- if not self.tag_to_category and "tag_to_category" in dataset_info:
200
- self.tag_to_category = dataset_info["tag_to_category"]
 
 
 
 
 
201
 
202
  # Compatibility shim for scan_handler: model.dataset.get_tag_info(...)
203
  self.dataset = self
 
173
  session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
174
 
175
  # Create a wrapper class to mimic the PyTorch model interface
176
+ class ONNXModelWrapper:
177
+ """
178
+ - predict(image_path=..., threshold=...) -> {"refined_probabilities": np.ndarray[N, C]}
179
+ - dataset.get_tag_info(idx) -> (tag, category)
180
+ - Keeps your signature compatible, ignores 'threshold' in the wrapper.
181
+ """
182
+ def __init__(self, session, metadata: dict):
183
+ self.session = session
184
+ self.metadata = metadata or {}
185
+ dataset_info = self.metadata.get("dataset_info", {})
186
+ self.total_tags = dataset_info.get("total_tags", 0)
187
+
188
+ # idx <-> tag mapping
189
+ tag_mapping = dataset_info.get("tag_mapping", {})
190
+ if "idx_to_tag" in tag_mapping:
191
+ self.idx_to_tag = {int(k): v for k, v in tag_mapping["idx_to_tag"].items()}
192
+ elif "tag_to_idx" in tag_mapping:
193
+ t2i = tag_mapping["tag_to_idx"]
194
+ self.idx_to_tag = {v: k for k, v in t2i.items()}
195
+ else:
196
+ self.idx_to_tag = {}
197
+
198
  class ONNXModelWrapper:
199
  """
200
  - predict(image_path=..., threshold=...) -> {"refined_probabilities": np.ndarray[N, C]}
 
218
  self.idx_to_tag = {}
219
 
220
  self.tag_to_category = self.metadata.get("tag_to_category", {})
221
+ if not self.tag_to_category:
222
+ # Try to get from dataset_info.tag_mapping.tag_to_category (correct path)
223
+ if "tag_mapping" in dataset_info and "tag_to_category" in dataset_info["tag_mapping"]:
224
+ self.tag_to_category = dataset_info["tag_mapping"]["tag_to_category"]
225
+ # Fallback to direct path in case structure varies
226
+ elif "tag_to_category" in dataset_info:
227
+ self.tag_to_category = dataset_info["tag_to_category"]
228
 
229
  # Compatibility shim for scan_handler: model.dataset.get_tag_info(...)
230
  self.dataset = self