merve (HF Staff) committed
Commit 6862a3f · verified · 1 Parent(s): c10d94a

Upload modular_isaac.py


Hello, and congrats on the release!
This PR makes the model load with no additional dependencies. You can also add a small inference notebook I made (just replace the username with Perceptron): https://colab.research.google.com/drive/1BHl2ZT8cYZ0HlP_q4HllFuCXWIBX_R_2?usp=sharing

If you add it to the repository as "notebook.ipynb", it can be opened with one click from the repo, making it easier for people to try out your model as well!
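
For anyone who wants a quick start without opening the notebook, here is a minimal inference sketch. The repo id, auto classes, and chat-template call below are my assumptions rather than the official usage (the linked Colab has the exact recipe), so treat it as a starting point:

```python
# Minimal sketch -- repo id, auto classes and processor call are assumptions,
# not the official usage; see the Colab notebook above for the real thing.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

repo_id = "PerceptronAI/isaac"  # hypothetical placeholder, replace with the actual repo

processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assumed chat-style multimodal input: one image plus a text prompt.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.png"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))
```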

Files changed (1)
  1. modular_isaac.py +943 -21
modular_isaac.py CHANGED
@@ -1,7 +1,7 @@
1
  from __future__ import annotations
2
 
3
  from collections import defaultdict
4
- from typing import Any, Union, TypedDict
5
 
6
  import math
7
  import numpy as np
@@ -33,22 +33,944 @@ from transformers.models.siglip2.modeling_siglip2 import (
33
  Siglip2MLP,
34
  )
35
  from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
36
- from perceptron.tensorstream import (
37
- Event,
38
- Stream,
39
- TensorStream,
40
- TextType,
41
- VisionType,
42
- create_stream,
43
- group_streams,
44
- )
45
- from perceptron.tensorstream.ops import (
46
- compute_mrope_pos_tensor,
47
- modality_mask,
48
- reconstruct_tensor_stream_from_compact_dict,
49
- slice as ts_slice,
50
- tensor_stream_token_view,
51
- )
52
 
53
 
54
  class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
@@ -474,7 +1396,7 @@ class Siglip2SequenceVisionTransformer(nn.Module):
474
  # Configuration
475
  # ============================================================================
476
 
477
- MAX_PIXELS = 60_000_000 # 60megapixel ceiling ≈ 8200 × 7300 px
478
 
479
  # Vision preprocessing constants
480
  VISION_MEAN = (0.5, 0.5, 0.5)
@@ -491,13 +1413,13 @@ def _make_writeable(arr: np.ndarray) -> np.ndarray:
491
  if arr.flags.writeable:
492
  return arr
493
 
494
- # First, try the cheap path — inplace flag toggle (works for mmap'd arrays
495
  # and some shared memory buffers):
496
  try:
497
  arr.setflags(write=True)
498
  return arr # success: no data copy
499
  except ValueError:
500
- # Buffer is inherently readonly (e.g. backed by PyAV / PIL): make copy
501
  return arr.copy()
502
 
503
 
@@ -1623,4 +2545,4 @@ __all__ = [
1623
  "IsaacModel",
1624
  "IsaacForConditionalGeneration",
1625
  "IsaacProcessor",
1626
- ]
 
1
  from __future__ import annotations
2
 
3
  from collections import defaultdict
4
+ from typing import Any, NewType, Union, TypedDict
5
 
6
  import math
7
  import numpy as np
 
33
  Siglip2MLP,
34
  )
35
  from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
36
+
37
+ import itertools
38
+ from collections.abc import Callable, Iterable
39
+
40
+
41
+ import heapq
42
+ from collections.abc import Callable, Iterable
43
+ from dataclasses import dataclass, field, fields, replace
44
+ from enum import Enum
45
+
46
+ from torch.profiler import record_function
47
+
48
+
49
+ class ModalityType(Enum):
50
+ """
51
+ Base class for modality-type enumerations.
52
+ Each derived class (VisionType, TextType) holds
53
+ an integer value that identifies a specific modality.
54
+ Example usage:
55
+ If you have an object `my_event` of class `Event`,
56
+ you might write:
57
+ if my_event.type == VisionType.image:
58
+ # process an image frame
59
+ The methods below implement ordering and hashing
60
+ based on the integer `.value` of each enum member.
61
+ """
62
+
63
+ @property
64
+ def modality(self):
65
+ return self.__class__
66
+
67
+ def __lt__(self, other):
68
+ if isinstance(other, ModalityType):
69
+ return self.value < other.value
70
+ raise NotImplementedError()
71
+
72
+ def __eq__(self, other):
73
+ if isinstance(other, ModalityType):
74
+ return self.value == other.value
75
+ raise NotImplementedError()
76
+
77
+ def __hash__(self):
78
+ return hash(self.value)
79
+
80
+
81
+ # NOTE: modality types need to be unique
82
+ class VisionType(ModalityType):
83
+ """
84
+ Enum for vision modalities such as key video frames.
85
+ Typically used in video processing or image sequences.
86
+ Members:
87
+ image: A single image frame.
88
+ """
89
+
90
+ image = 0
91
+
92
+
93
+ class TextType(ModalityType):
94
+ """
95
+ Enum for text tokens and padding.
96
+ Members:
97
+ text: Actual textual tokens.
98
+ padding: Padding tokens used in sequence batching.
99
+ """
100
+
101
+ text = 1
102
+ padding = 2
103
+
104
+
105
+ # maps idx -> type
106
+ ALL_TYPES = [
107
+ tp
108
+ for types in [
109
+ list(VisionType),
110
+ list(TextType),
111
+ ]
112
+ for tp in types
113
+ ]
114
+
115
+
116
+ # @dataclass
117
+ @dataclass(slots=True)
118
+ class Event:
119
+ """
120
+ Represents a single data occurrence (with a specific type, time interval, and data payload).
121
+ Attributes:
122
+ data (Any): The actual data payload (e.g. a torch.Tensor, a string, etc.).
123
+ type (ModalityType): The modality type of the data (e.g., VisionType.image).
124
+ time (Tuple[float, float]): (start_time, end_time) indicating when this Event occurs.
125
+ role (Optional[str]): The role associated with this event (e.g., "user", "agent", "system").
126
+ If None, the event is always included in loss calculation.
127
+ Example usage:
128
+ evt = Event(data=torch.zeros((1, 224, 224, 3)), # e.g. a single image frame
129
+ type=VisionType.image,
130
+ time=(0.0, 0.04),
131
+ role="user")
132
+ """
133
+
134
+ # Descriptors
135
+ data: Any
136
+ time: tuple[float, float]
137
+ type: ModalityType
138
+ role: str | None = None
139
+
140
+ # Structure
141
+ dims_virtual: list[int] | None = None # virtual/processed dimensions (e.g., pixel-shuffled)
142
+ dims_real: list[int] | None = None # real/actual tensor dimensions
143
+ idx_range: tuple[int, int] | None = None
144
+
145
+ # Misc Tags (data source, shard idx, etc.)
146
+ tags: dict = field(default_factory=dict)
147
+
148
+ def dims(self, virtual: bool = True) -> list[int] | None:
149
+ """
150
+ Get the dimensions of this event.
151
+ Args:
152
+ virtual: If True (default), return virtual/processed dimensions (e.g., pixel-shuffled).
153
+ If False, return real/actual tensor dimensions.
154
+ Returns:
155
+ Dimensions list or None if not measured.
156
+ """
157
+ if virtual:
158
+ return self.dims_virtual
159
+ else:
160
+ return self.dims_real
161
+
162
+ @property
163
+ def is_measured(self):
164
+ return self.dims_virtual is not None
165
+
166
+ def slice_tokens(self, start: int | None = None, end: int | None = None):
167
+ """
168
+ Converts into a partial event where the only valid data is between start and end indices of the flattened data
169
+ """
170
+ assert self.is_measured
171
+ assert start is not None and end is not None
172
+ assert self.idx_range[0] <= start <= end <= self.idx_range[1]
173
+ self.idx_range = (start or 0, end or math.prod(self.dims()))
174
+
175
+ def num_tokens(self, partial=True, virtual=True) -> int:
176
+ if not virtual:
177
+ assert partial is False and isinstance(self.data, torch.Tensor)
178
+ return math.prod(self.dims(virtual=False))
179
+ return self.idx_range[1] - self.idx_range[0] if partial else math.prod(self.dims())
180
+
181
+ def shallow_copy(self) -> Event:
182
+ return replace(self)
183
+
184
+ def __hash__(self) -> int:
185
+ """Hash Event based on structure, excluding data."""
186
+
187
+ def make_hashable(obj):
188
+ """Convert any object to hashable form."""
189
+ if obj is None:
190
+ return None
191
+ elif isinstance(obj, str | int | float | bool | tuple):
192
+ return obj
193
+ elif isinstance(obj, list):
194
+ return tuple(make_hashable(item) for item in obj) if obj else None
195
+ elif isinstance(obj, dict):
196
+ return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) if obj else None
197
+ elif hasattr(obj, "value"): # Enum types
198
+ return obj.value
199
+ else:
200
+ return str(obj) # Fallback for other types
201
+
202
+ hash_values = []
203
+ for fld in fields(self):
204
+ if fld.name == "data":
205
+ continue # Skip tensor data
206
+
207
+ value = getattr(self, fld.name)
208
+ hash_values.append(make_hashable(value))
209
+
210
+ return hash(tuple(hash_values))
211
+
212
+ def __eq__(self, other) -> bool:
213
+ """
214
+ Compares two Event objects for strict equality,
215
+ allowing for float tolerances in torch.Tensors (via torch.allclose).
216
+ """
217
+ if not isinstance(other, Event):
218
+ return False
219
+
220
+ for fld in fields(self):
221
+ self_value = getattr(self, fld.name)
222
+ other_value = getattr(other, fld.name)
223
+
224
+ if fld.name == "data":
225
+ # Special handling for tensor data with float tolerance
226
+ if isinstance(self_value, torch.Tensor) and isinstance(other_value, torch.Tensor):
227
+ if not torch.allclose(self_value, other_value):
228
+ return False
229
+ else:
230
+ if self_value != other_value:
231
+ return False
232
+ elif fld.name == "role":
233
+ # Special handling for role: both must be None or both must be set and equal
234
+ if (self_value is None) != (other_value is None):
235
+ return False
236
+ if self_value is not None and self_value != other_value:
237
+ return False
238
+ else:
239
+ # Standard equality for all other fields
240
+ if self_value != other_value:
241
+ return False
242
+
243
+ return True
244
+
245
+
246
+ @dataclass
247
+ class Stream:
248
+ """
249
+ Represents an ordered sequence of Event objects, each with
250
+ a specific ModalityType and a time range.
251
+ Attributes:
252
+ events (List[Event]): The list of Event objects in the stream.
253
+ priority (List[ModalityType]): A list of modality types that define
254
+ how we might want to reorder or prioritize events if scheduling is needed.
255
+ Example usage:
256
+ # Create two events of different types
257
+ evt1 = Event(torch.zeros((1, 224, 224, 3)), VisionType.image, (0.0, 0.04))
258
+ evt2 = Event(torch.randint(0, 1000, (16, 1)), TextType.text, (0.0, 0.32))
259
+ # Make a stream with a given priority
260
+ s = Stream(events=[evt1, evt2],
261
+ priority=[VisionType.image, TextType.text])
262
+ print(s)
263
+ """
264
+
265
+ events: list[Event]
266
+ priority: list[ModalityType] # priority of stream ordering
267
+
268
+ def __len__(self):
269
+ """Returns the number of Event objects in this Stream."""
270
+ return len(self.events)
271
+
272
+ def __getitem__(self, key: int) -> Stream | Event:
273
+ return self.events[key]
274
+
275
+ def __iter__(self):
276
+ """
277
+ Yields each Event in the Stream, enabling iteration like:
278
+ for event in my_stream:
279
+ ...
280
+ """
281
+ yield from self.events
282
+
283
+ # --- after ------------------------------------------------------------
284
+ @record_function("Stream.map")
285
+ def map(
286
+ self,
287
+ func: Callable[[Event], dict[str, Any]],
288
+ *,
289
+ copy_unchanged: bool = False, # opt-in if you really need isolation
290
+ ) -> Stream:
291
+ """
292
+ Apply *func* to every event and return a new Stream.
293
+ *func* must return a **dict of fields that actually change**.
294
+ We create **one shallow copy** only when something changes;
295
+ unchanged events are reused directly, which is inexpensive and
296
+ keeps autograd graphs intact.
297
+ """
298
+ mapped: list[Event] = []
299
+ for ev in self.events:
300
+ delta = func(ev)
301
+ if not delta: # fast-path: nothing changes
302
+ mapped.append(ev if not copy_unchanged else ev.shallow_copy())
303
+ continue
304
+
305
+ new_ev = ev.shallow_copy() # ⚡ no tensor clone
306
+ for k, v in delta.items():
307
+ setattr(new_ev, k, v)
308
+ mapped.append(new_ev)
309
+
310
+ return create_stream(mapped, priority=self.priority, schedule=False)
311
+
312
+ @record_function("Stream.compact")
313
+ def compact(self) -> torch.Tensor:
314
+ assert all([(isinstance(ev.data, torch.Tensor) and ev.is_measured) for ev in self.events]), (
315
+ "Stream.compact only works for streams with events that have measured tensor data"
316
+ )
317
+ return torch.cat([ev.data for ev in self.events]).contiguous()
318
+
319
+ @record_function("Stream.map_compact")
320
+ def map_compact(self, event_tf: Callable[[Event], list[Any]]) -> torch.Tensor:
321
+ mapped_list = []
322
+ for event in self:
323
+ mapped_list.extend(event_tf(event))
324
+ tensor = torch.tensor(
325
+ mapped_list,
326
+ dtype=torch.long,
327
+ device=next(
328
+ (ev.data.device for ev in self.events if isinstance(ev.data, torch.Tensor)),
329
+ "cpu",
330
+ ),
331
+ ).contiguous()
332
+ return tensor
333
+
334
+ def flatten(self) -> Stream:
335
+ return self.map(lambda ev: {"data": ev.data.reshape(-1, ev.data.shape[-1])})
336
+
337
+ def shallow_copy(self) -> Stream:
338
+ events_copy = [ev.shallow_copy() for ev in self.events]
339
+ return create_stream(events=events_copy, priority=self.priority, schedule=False)
340
+
341
+ def __hash__(self) -> int:
342
+ """Hash Stream based on structure."""
343
+ return hash(
344
+ (
345
+ tuple(p.value for p in self.priority), # Convert enums to values
346
+ tuple(hash(event) for event in self.events), # Use Event.__hash__
347
+ )
348
+ )
349
+
350
+ def __eq__(self, other) -> bool:
351
+ """Compare Streams structurally."""
352
+ if not isinstance(other, Stream):
353
+ return False
354
+
355
+ return (
356
+ self.priority == other.priority
357
+ and len(self.events) == len(other.events)
358
+ and all(e1 == e2 for e1, e2 in zip(self.events, other.events, strict=False))
359
+ )
360
+
361
+
362
+ # TODO: implement all types of cool indexing, which can happen since TensorStream assumes Event.data = Tensor
363
+ @dataclass
364
+ class TensorStream:
365
+ streams: list[Stream]
366
+ _device: torch.device | None = None
367
+
368
+ def __post_init__(self):
369
+ for stream in self.streams:
370
+ for event in stream.events:
371
+ assert isinstance(event.data, torch.Tensor)
372
+ if self._device is None:
373
+ self._device = torch.device(event.data.device)
374
+
375
+ # TODO: implement non-strict compaction modes
376
+ @record_function("TensorStream.compact")
377
+ def compact(self, mode="strict") -> torch.Tensor:
378
+ compact_tensor_stream = torch.stack([stream.compact() for stream in self.streams]).contiguous()
379
+ return compact_tensor_stream
380
+
381
+ @record_function("TensorStream.map")
382
+ def map(self, event_tf: Callable[[Event], dict[str, Any]]) -> TensorStream:
383
+ mapped_streams = [stream.map(event_tf) for stream in self.streams]
384
+ return TensorStream(mapped_streams)
385
+
386
+ @record_function("TensorStream.map_compact")
387
+ def map_compact(self, event_tf: Callable[[Event], list[Any]]) -> torch.Tensor:
388
+ mapped_list = []
389
+ for stream in self.streams:
390
+ for event in stream:
391
+ mapped_list.extend(event_tf(event))
392
+ B, T = self.shape
393
+ tensor = torch.tensor(mapped_list, dtype=torch.long, device=self.device).reshape(B, T)
394
+ return tensor
395
+
396
+ def flat_stream(self) -> Stream:
397
+ if not self.streams:
398
+ return create_stream([], priority=[], schedule=False)
399
+ return create_stream(
400
+ [event for stream in self.streams for event in stream], priority=self.streams[0].priority, schedule=False
401
+ )
402
+
403
+ @property
404
+ def device(self):
405
+ return self._device
406
+
407
+ @property
408
+ def shape(self):
409
+ seq_lens = [sum([ev.num_tokens() for ev in stream]) for stream in self.streams]
410
+ assert all([sl == seq_lens[0] for sl in seq_lens]), (
411
+ f"each stream must have same token count to have a shape: {seq_lens}"
412
+ )
413
+ return (len(seq_lens), seq_lens[0])
414
+
415
+ @record_function("TensorStream.to")
416
+ def to(
417
+ self,
418
+ device: torch.device | str,
419
+ dtype: torch.dtype | None = None,
420
+ non_blocking: bool = True,
421
+ ) -> TensorStream:
422
+ """
423
+ Move **all** `Event.data` tensors to *device*.
424
+ We send each tensor individually instead of the
425
+ flatten → unflatten round-trip:
426
+ * one async H2D copy per tensor (still overlapped when
427
+ `pin_memory=True` is set on the DataLoader),
428
+ * no extra host-side concat, no extra device allocation,
429
+ * `requires_grad` flags are preserved.
430
+ NOTE: textual modalities are always cast to `torch.long`;
431
+ everything else keeps its original
432
+ dtype unless an explicit *dtype* argument is supplied.
433
+ """
434
+ target_device = torch.device(device)
435
+
436
+ for stream in self.streams:
437
+ for ev in stream:
438
+ # ------------------------------------------------------------------
439
+ # Decide the dtype for *this* event.
440
+ # ------------------------------------------------------------------
441
+ if ev.type in list(TextType):
442
+ tgt_dtype = torch.long
443
+ else:
444
+ tgt_dtype = dtype or ev.data.dtype
445
+
446
+ # ------------------------------------------------------------------
447
+ # Perform the device / dtype move.
448
+ # ------------------------------------------------------------------
449
+ # We clone no tensor here; torch will reuse storage
450
+ # if `dtype` and `device` are unchanged.
451
+ moved = ev.data.to(
452
+ device=target_device,
453
+ dtype=tgt_dtype,
454
+ non_blocking=non_blocking,
455
+ )
456
+
457
+ # Preserve autograd leaf & grad-enabled state.
458
+ moved.requires_grad_(ev.data.requires_grad)
459
+
460
+ ev.data = moved
461
+
462
+ # Remember where the whole TensorStream lives now.
463
+ self._device = target_device
464
+ return self
465
+
466
+ @record_function("TensorStream.pin_memory")
467
+ def pin_memory(self, non_blocking: bool = True) -> TensorStream:
468
+ """
469
+ Page-lock (aka *pin*) all **CPU** tensors contained in this
470
+ `TensorStream`. Pinned tensors make subsequent asynchronous
471
+ H2D copies (e.g. inside `TensorStream.to("cuda")`) faster and,
472
+ when used together with a `DataLoader(pin_memory=True)`,
473
+ enable overlap of host-to-device transfers with GPU execution.
474
+ The call is a no-op for tensors that are already on a CUDA /
475
+ MPS / other non-CPU device.
476
+ Parameters
477
+ ----------
478
+ non_blocking : bool, default = True
479
+ Forwarded to `Tensor.pin_memory()`; should almost always
480
+ stay *True* so later `to(device, non_blocking=True)` calls
481
+ can overlap.
482
+ Returns
483
+ -------
484
+ self : TensorStream
485
+ The same object (mutated in-place) to allow call chaining.
486
+ """
487
+ for stream in self.streams:
488
+ for ev in stream:
489
+ if ev.data.device.type == "cpu":
490
+ # `pin_memory()` clones only when needed
491
+ pinned = ev.data.pin_memory() # noqa: F841
492
+ # NB: pin_memory() preserves dtype/shape/grad/etc.
493
+ if not non_blocking:
494
+ # ensure the pinning work is done now
495
+ torch.cuda.current_stream().synchronize() # safe on CPU too
496
+ ev.data = pinned
497
+ # `_device` **stays** the same (still CPU) – no change needed
498
+ return self
499
+
500
+ def __hash__(self) -> int:
501
+ """Hash TensorStream based on structure."""
502
+ return hash(
503
+ (
504
+ tuple(hash(stream) for stream in self.streams), # Use Stream.__hash__
505
+ str(self._device) if self._device else None,
506
+ self.shape,
507
+ )
508
+ )
509
+
510
+ def __eq__(self, other) -> bool:
511
+ """Compare TensorStreams structurally."""
512
+ if not isinstance(other, TensorStream):
513
+ return False
514
+
515
+ return (
516
+ self._device == other._device
517
+ and self.shape == other.shape
518
+ and len(self.streams) == len(other.streams)
519
+ and all(s1 == s2 for s1, s2 in zip(self.streams, other.streams, strict=False))
520
+ )
521
+
522
+
523
+ def collate_tensor_stream(
524
+ tensor_streams: list[TensorStream],
525
+ ) -> TensorStream:
526
+ return TensorStream([stream for ts in tensor_streams for stream in ts.streams])
527
+
528
+
529
+ def _schedule_stream(stream: Stream) -> Stream:
530
+ """
531
+ Internal function that reorders (schedules) the events in a Stream
532
+ based on the stream's priority.
533
+ By default, this calls schedule_events(...) and reorders the events accordingly.
534
+ The new ordering is assigned in-place to stream.events.
535
+ Example usage (indirect):
536
+ new_stream = _schedule_stream(old_stream)
537
+ """
538
+ scheduled_inds = schedule_events(stream, priority=stream.priority)
539
+ stream.events = [stream.events[i] for i in scheduled_inds]
540
+ return stream
541
+
542
+
543
+ def create_stream(events: list[Event], priority: list[ModalityType], schedule: bool = True) -> Stream:
544
+ """
545
+ Creates a new Stream with the given events and priority.
546
+ If 'schedule' is True, the events are reordered by calling _schedule_stream.
547
+ Example usage:
548
+ evt1 = Event(torch.zeros(10), TextType.text, (0.0, 1.0))
549
+ evt2 = Event(torch.ones(10), TextType.text, (1.0, 2.0))
550
+ my_stream = create_stream(events=[evt1, evt2],
551
+ priority=[TextType.text],
552
+ schedule=False)
553
+ print(my_stream)
554
+ """
555
+ stream = Stream(events, priority)
556
+ if schedule:
557
+ stream = _schedule_stream(stream)
558
+ return stream
559
+
560
+
561
+ def merge_streams(streams: Iterable[Stream]) -> Stream:
562
+ """
563
+ Merges multiple Stream objects into one.
564
+ The priority of the merged stream is chosen from the longest priority list among the inputs.
565
+ Stream priorities must be consistent with the chosen priority.
566
+ All events are concatenated, and a new Stream is created (and scheduled).
567
+ Example usage:
568
+ merged = merge_streams([stream1, stream2])
569
+ """
570
+ chosen_priority = max([stream.priority for stream in streams], key=len)
571
+ assert all(
572
+ [str(stream.priority) in str([p for p in chosen_priority if p in stream.priority]) for stream in streams]
573
+ ), "One or more streams has a priority order that doesn't match the merged stream"
574
+ merged_event_list = [ev for stream in streams for ev in stream.events]
575
+ merged_stream = create_stream(merged_event_list, chosen_priority) # non-root stream creation
576
+ return merged_stream
577
+
578
+
579
+ EventDescriptor = NewType("EventDescriptor", Any)
580
+
581
+
582
+ # NOTE: actually not used now but thought it *might* be useful
583
+ def get_stream_descriptor(
584
+ stream: Stream, measure_fn: Callable[[Event], EventDescriptor] = lambda ev: ev.type
585
+ ) -> set[Any]:
586
+ """
587
+ Create a set of descriptors for each Event in a Stream based on measure_fn.
588
+ measure_fn maps an Event to a descriptive key.
589
+ For example, if events have different data shapes, one might use:
590
+ measure_fn = lambda ev: ev.data.shape
591
+ i.e.
592
+ stream of VisionTypes with tensors of shapes [(1, 3, 3), (1, 3, 3), (1, 4, 4)]
593
+ get_stream_descriptor(stream, measure_fn=lambda t: t.shape) = {(1, 3, 3), (1, 4, 4)}
594
+ now we can pass this into group_streams which will split out vision sub-streams which can be bundled
595
+ Returns:
596
+ A set of descriptors representing the Events in the stream.
597
+ Example usage:
598
+ descriptor = get_stream_descriptor(my_stream, lambda ev: ev.type)
599
+ """
600
+ stream_descriptor = set()
601
+ for ev in stream.events:
602
+ ev_measurement = measure_fn(ev)
603
+ stream_descriptor.add(ev_measurement)
604
+ return stream_descriptor
605
+
606
+
607
+ def group_streams(
608
+ stream: Stream, group_fn: Callable[[Event], EventDescriptor], schedule=True
609
+ ) -> dict[EventDescriptor, Stream]:
610
+ """
611
+ Splits a single Stream into multiple sub-Streams, grouped by the output of group_fn(event).
612
+ For example, group_fn could be:
613
+ - lambda ev: ev.type
614
+ - lambda ev: ev.type.modality
615
+ - lambda ev: (ev.type.modality, ev.data.shape)
616
+ Returns:
617
+ A dictionary mapping each group key to a Stream of events belonging to that group.
618
+ If 'schedule' is True, each sub-Stream is scheduled via create_stream(..., schedule=True).
619
+ Example usage:
620
+ substreams = group_streams(my_stream, lambda ev: ev.type)
621
+ """
622
+ split_streams: defaultdict[EventDescriptor, list[Event]] = defaultdict(list)
623
+ for ev in stream:
624
+ group = group_fn(ev)
625
+ split_streams[group].append(ev)
626
+ for g, events in split_streams.items():
627
+ split_streams[g] = create_stream(events, stream.priority, schedule=schedule)
628
+ return dict(split_streams)
629
+
630
+
631
+ # Define Category for clarity
632
+ Category = NewType("Category", Any)
633
+
634
+
635
+ def schedule_events(stream: Stream, priority: list[Category]) -> list[int]:
636
+ """
637
+ Schedule events based on their start time and priority using a topological sort algorithm.
638
+ The priority list defines the ordering of categories.
639
+ This function:
640
+ 1. Pairs each event with its original index.
641
+ 2. Sorts events by start time.
642
+ 3. Builds a dependency graph based on overlapping events.
643
+ 4. Uses a heap to perform a deterministic topological sort with tie-breakers.
644
+ Raises:
645
+ ValueError: If a cycle is detected in the events (i.e., no valid ordering exists).
646
+ Returns:
647
+ List[int]: A list of original indices representing the scheduled order of events.
648
+ """
649
+ priority_index: dict[Category, int] = {category: idx for idx, category in enumerate(priority)}
650
+
651
+ # Pair each event metadata with its original index
652
+ events = []
653
+ for i, event in enumerate(stream.events):
654
+ events.append(
655
+ (
656
+ i,
657
+ event.time[0],
658
+ event.time[1],
659
+ event.type,
660
+ )
661
+ )
662
+
663
+ sorted_events = sorted(events, key=lambda e: e[1]) # sort by start time
664
+ num_events = len(sorted_events)
665
+
666
+ # Build dependency graph
667
+ graph = defaultdict(set)
668
+ indegree = {i: 0 for i in range(num_events)}
669
+
670
+ for i in range(num_events):
671
+ idx_i, start_i, end_i, category_i = sorted_events[i]
672
+ prio_i = priority_index[category_i]
673
+ for j in range(i + 1, num_events):
674
+ idx_j, start_j, end_j, category_j = sorted_events[j]
675
+ if start_j >= end_i:
676
+ break
677
+ if end_i > start_j and end_j > start_i:
678
+ prio_j = priority_index[category_j]
679
+ if prio_i < prio_j:
680
+ graph[i].add(j)
681
+ indegree[j] += 1
682
+ elif prio_i > prio_j:
683
+ graph[j].add(i)
684
+ indegree[i] += 1
685
+
686
+ # Use heap for deterministic tie-breakers: (start_time, priority, original_index)
687
+ heap = [
688
+ (
689
+ sorted_events[i][1],
690
+ priority_index[sorted_events[i][3]],
691
+ sorted_events[i][0],
692
+ i,
693
+ )
694
+ for i in range(num_events)
695
+ if indegree[i] == 0
696
+ ]
697
+ heapq.heapify(heap)
698
+ resolved_order = []
699
+
700
+ while heap:
701
+ _, _, _, u = heapq.heappop(heap)
702
+ resolved_order.append(u)
703
+ for v in graph[u]:
704
+ indegree[v] -= 1
705
+ if indegree[v] == 0:
706
+ heapq.heappush(
707
+ heap,
708
+ (
709
+ sorted_events[v][1],
710
+ priority_index[sorted_events[v][3]],
711
+ sorted_events[v][0],
712
+ v,
713
+ ),
714
+ )
715
+
716
+ if len(resolved_order) != num_events:
717
+ raise ValueError("Cycle detected in events, cannot resolve order")
718
+
719
+ return [sorted_events[i][0] for i in resolved_order]
720
+
721
+ def compute_mrope_pos_tensor(ts: TensorStream, n_pos_dims: int = 3) -> torch.Tensor:
722
+ """
723
+ Create a (batch, T, n_pos_dims) position tensor in one sweep.
724
+ The first dim is the running “time” index, the rest are spatial (or 1-fillers).
725
+
726
+ Args:
727
+ ts : TensorStream
728
+ n_pos_dims : total coordinate dimensions (default 3)
729
+
730
+ Returns:
731
+ torch.LongTensor - shape (batch_size, seq_len, n_pos_dims)
732
+ """
733
+
734
+ # Manually iterate through streams and events like map_compact does,
735
+ # but maintain cumulative time offset for each stream
736
+ all_coords = []
737
+ for stream in ts.streams: # one Stream == one batch sample
738
+ cumulative_offset = 0 # running time index for this stream
739
+
740
+ for event in stream:
741
+ # --- build coordinate grid for THIS event using itertools (no tensor ops) ---
742
+ dims = (event.dims() or [1]) + [1] * (n_pos_dims - len(event.dims() or []))
743
+
744
+ # Create ranges for each dimension (similar to old _finalize implementation)
745
+ first_dim = range(cumulative_offset, cumulative_offset + dims[0])
746
+ cumulative_offset += dims[0] # advance time for the next event
747
+ other_dims = [range(d) for d in dims[1:]]
748
+
749
+ # Use itertools.product to create all coordinate combinations
750
+ full_coords = list(itertools.product(first_dim, *other_dims))
751
+
752
+ # Slice if the event is partial
753
+ s, e = event.idx_range
754
+ coords = full_coords[s:e]
755
+
756
+ # Extend the flattened coordinate list
757
+ all_coords.extend(coords)
758
+
759
+ # Convert to tensor and reshape to (B, T, n_pos_dims)
760
+ B, T = ts.shape
761
+ return torch.tensor(all_coords, dtype=torch.long, device=ts.device).reshape(B, T, n_pos_dims)
762
+
763
+
764
+ # ──────────────────────────────────────────────────────────────────────────
765
+ # Generic event-labelling helper
766
+ # ──────────────────────────────────────────────────────────────────────────
767
+ def event_mask(
768
+ ts: TensorStream,
769
+ tag_fn: Callable[[Event], int | None],
770
+ default: int = -1,
771
+ ) -> torch.Tensor:
772
+ """
773
+ Build a (batch, seq_len) LongTensor whose value for every *token*
774
+ is given by `tag_fn(event)`, falling back to `default` when the
775
+ function returns None.
776
+
777
+ The work is done in a single pass via `map → compact`.
778
+ """
779
+
780
+ def to_label(ev: Event) -> Any:
781
+ label = tag_fn(ev)
782
+ if label is None:
783
+ label = default
784
+ return [label] * ev.num_tokens()
785
+
786
+ return ts.map_compact(to_label).squeeze(-1)
787
+
788
+
789
+ def event_mask_by_key(
790
+ ts: TensorStream,
791
+ key: str,
792
+ tag_index: dict[str, int],
793
+ default: int = -1,
794
+ ) -> torch.Tensor:
795
+ """
796
+ Faster call-site syntax when you just want to look up
797
+ `event.tags[key]` and map it through `tag_index`.
798
+ """
799
+ return event_mask(
800
+ ts,
801
+ lambda ev: tag_index.get(ev.tags.get(key)) if key in ev.tags else None,
802
+ default=default,
803
+ )
804
+
805
+
806
+ def modality_mask(ts: TensorStream) -> torch.Tensor:
807
+ return event_mask(ts, lambda ev: ev.type.value)
808
+
809
+
810
+ ROLE_TO_IDX = {
811
+ None: -1,
812
+ "": -1,
813
+ "agent": 0,
814
+ "user": 1,
815
+ "system": 2,
816
+ # … add more if you like
817
+ }
818
+
819
+
820
+ def role_mask(ts: TensorStream) -> torch.Tensor:
821
+ return event_mask(ts, lambda ev: ROLE_TO_IDX.get(ev.role, -1))
822
+
823
+
824
+ def tensor_stream_token_view(ts: TensorStream) -> torch.Tensor:
825
+ """
826
+ Return a (B, T) token view by summing across the last dim of every
827
+ event and flattening over the selected token range.
828
+ """
829
+
830
+ def to_token_view(ev: Event) -> list[int]:
831
+ # collapse all but the last dim, cast to long
832
+ flat = ev.data.sum(dim=-1).long().reshape(-1)
833
+ if ev.idx_range is not None:
834
+ s, e = ev.idx_range
835
+ return flat[s:e].tolist()
836
+ else:
837
+ return flat.tolist()
838
+
839
+ return ts.map_compact(to_token_view) # shape (B, T)
840
+
841
+
842
+ def reconstruct_tensor_stream_from_compact_dict(
843
+ ts: TensorStream, compact_dict: dict[ModalityType, torch.Tensor]
844
+ ) -> TensorStream:
845
+ streams = []
846
+ for stream in ts.streams:
847
+ event_list = []
848
+ for event in stream:
849
+ new_event = event.shallow_copy()
850
+ new_event.data = compact_dict[event.type][event.idx_range[0] : event.idx_range[1]]
851
+ compact_dict[event.type] = compact_dict[event.type][event.num_tokens(partial=False) :]
852
+ event_list.append(new_event)
853
+ streams.append(Stream(event_list, priority=stream.priority))
854
+ return TensorStream(streams)
855
+
856
+
857
+ def set_data(
858
+ tensor_stream: TensorStream,
859
+ stream_types: Iterable[ModalityType],
860
+ roles: Iterable[str] = ROLE_TO_IDX.keys(),
861
+ ) -> tuple[torch.Tensor, torch.Tensor]:
862
+ """
863
+ Gathers data from a TensorStream according to the given stream types
864
+ and returns (data, mask) where 'data' has valid entries for
865
+ each requested stream type and 'mask' indicates which elements
866
+ in 'data' are valid.
867
+
868
+ NOTE: Currently assumes stream_types are text-based types, but can be extended.
869
+
870
+ Args:
871
+ tensor_stream (TensorStream):
872
+ The input TensorStream which contains data for multiple modalities.
873
+ stream_types (Iterable[ModalityType]):
874
+ A list or iterable of modality types (e.g., TextType, VisionType, etc.)
875
+ to retrieve from the TensorStream.
876
+ roles (Iterable[str], optional):
877
+ Roles whose tokens are marked valid in the returned mask.
878
+ Defaults to all roles in ROLE_TO_IDX.
879
+
880
+ Returns:
881
+ Tuple[torch.Tensor, torch.Tensor]:
882
+ - data: A tensor of the same shape as the internal metadata shape,
883
+ containing valid entries from the given stream types.
884
+ - mask: A boolean tensor of the same shape, where True indicates
885
+ the corresponding element in 'data' is valid/used.
886
+ """
887
+ # Retrieve indexing and shape metadata
888
+ st_tensor = modality_mask(tensor_stream) # (B, T) modality-ids
889
+ roles_tensor = role_mask(tensor_stream) # (B, T) role-ids
890
+
891
+ # Create output data placeholders on the same device
892
+ data = torch.zeros_like(st_tensor).to(tensor_stream.device)
893
+ set_data_mask = torch.zeros_like(st_tensor).bool().to(tensor_stream.device).bool()
894
+ per_modality_stream = group_streams(tensor_stream.flat_stream(), group_fn=lambda ev: ev.type, schedule=False)
895
+ per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()}
896
+
897
+ # Fill 'data' and 'set_data_mask' for each requested stream type
898
+ for st in stream_types:
899
+ data_mask = st_tensor == st.value
900
+ partial_mask = (
901
+ per_modality_stream[st]
902
+ .map_compact(
903
+ lambda ev: [int(ev.idx_range[0] <= i < ev.idx_range[1]) for i in range(ev.num_tokens(partial=False))]
904
+ )
905
+ .bool()
906
+ )
907
+ data[data_mask] = per_modality_compact_stream[st].reshape(-1)[partial_mask]
908
+
909
+ roles_mask = torch.zeros_like(st_tensor).bool().to(tensor_stream.device).bool()
910
+ for role in roles:
911
+ roles_mask |= roles_tensor == ROLE_TO_IDX[role]
912
+ data_mask = data_mask & roles_mask
913
+ set_data_mask[data_mask] = True
914
+
915
+ return data, set_data_mask
916
+
917
+
918
+ def ts_slice(tensor_stream: TensorStream, start: int, end: int) -> TensorStream:
919
+ """
920
+ Return a new TensorStream that contains *only* the tokens in the
921
+ half-open interval ``[start, end)`` (0-based, inclusive-exclusive).
922
+ """
923
+ B, T = tensor_stream.shape
924
+ assert 0 <= start <= end <= T, f"slice [{start}, {end}) is out of bounds for sequence length {T}"
925
+
926
+ sliced_streams: list[Stream] = []
927
+
928
+ for stream in tensor_stream.streams:
929
+ # current position in tensor stream token dims
930
+ curr_global_index = 0
931
+ new_events: list[Event] = []
932
+
933
+ # iterate over each of the events in the stream only selecting
934
+ # the events that fall within the range
935
+ for ev in stream:
936
+ ev_len = ev.num_tokens()
937
+
938
+ # global_ev_start, global_ev_end are the start and end indices of the
939
+ # event within the tensor stream token dim
940
+ global_ev_start, global_ev_end = curr_global_index, curr_global_index + ev_len
941
+
942
+ if global_ev_end <= start:
943
+ # The event occurs before the start: skip it and move the cursor
944
+ # forward
945
+ curr_global_index = global_ev_end
946
+ continue
947
+ if global_ev_start >= end:
948
+ # event occurs after the end we can exit
949
+ break
950
+
951
+ # only consider the part of the event that falls within the range
952
+ keep_from = max(0, start - global_ev_start)
953
+ keep_to = min(ev_len, end - global_ev_start)
954
+ part = ev.shallow_copy()
955
+
956
+ if keep_from == 0 and keep_to == ev_len:
957
+ # Event lies wholly inside the slice
958
+ new_events.append(part)
959
+ else:
960
+ # Partial overlap → trim.
961
+ assert ev.is_measured
962
+
963
+ # update the local event ranges for the slices
964
+ sliced_event_start = part.idx_range[0] + keep_from
965
+ sliced_event_end = part.idx_range[0] + keep_to
966
+ part.slice_tokens(sliced_event_start, sliced_event_end)
967
+ new_events.append(part)
968
+
969
+ curr_global_index = global_ev_end
970
+
971
+ sliced_streams.append(create_stream(new_events, stream.priority, schedule=False))
972
+
973
+ return TensorStream(sliced_streams)
974
 
975
 
976
  class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
 
1396
  # Configuration
1397
  # ============================================================================
1398
 
1399
+ MAX_PIXELS = 60_000_000 # 60-megapixel ceiling ≈ 8200 × 7300 px
1400
 
1401
  # Vision preprocessing constants
1402
  VISION_MEAN = (0.5, 0.5, 0.5)
 
1413
  if arr.flags.writeable:
1414
  return arr
1415
 
1416
+ # First, try the cheap path — in-place flag toggle (works for mmap'd arrays
1417
  # and some shared memory buffers):
1418
  try:
1419
  arr.setflags(write=True)
1420
  return arr # success: no data copy
1421
  except ValueError:
1422
+ # Buffer is inherently read-only (e.g. backed by PyAV / PIL): make copy
1423
  return arr.copy()
1424
 
1425
 
 
2545
  "IsaacModel",
2546
  "IsaacForConditionalGeneration",
2547
  "IsaacProcessor",
2548
+ ]
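
Side note for reviewers skimming the inlined TensorStream code above: here is a tiny, self-contained sketch of how the pieces compose, using toy shapes and values and assuming `Event`, `TextType`, `VisionType`, `create_stream`, `TensorStream`, `modality_mask` and `compute_mrope_pos_tensor` are imported from this modular_isaac.py. Events carry measured token dims, a Stream orders them by time and modality priority, and a TensorStream of equal-length streams yields per-token modality ids and m-RoPE positions:

```python
import torch

# Assumed imports from the uploaded module:
# from modular_isaac import (Event, TextType, VisionType, create_stream,
#                            TensorStream, modality_mask, compute_mrope_pos_tensor)

# A text event: 4 token ids, measured as a flat run of 4 tokens.
txt = Event(
    data=torch.tensor([[1], [2], [3], [4]]),
    time=(0.0, 1.0),
    type=TextType.text,
    dims_virtual=[4],
    dims_real=[4],
    idx_range=(0, 4),
)

# A toy "image" event: a 1x2x3 grid of patch embeddings -> 6 tokens of dim 8.
img = Event(
    data=torch.randn(6, 8),
    time=(1.0, 2.0),
    type=VisionType.image,
    dims_virtual=[1, 2, 3],
    dims_real=[1, 2, 3],
    idx_range=(0, 6),
)

# Priority puts vision before text whenever two events overlap in time.
stream = create_stream([txt, img], priority=[VisionType.image, TextType.text])
ts = TensorStream([stream])

print(ts.shape)                            # (1, 10): one stream, 4 text + 6 vision tokens
print(modality_mask(ts))                   # (1, 10) tensor of per-token modality ids
print(compute_mrope_pos_tensor(ts).shape)  # (1, 10, 3) multimodal RoPE coordinates
```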