refactor: rename vector_type to output_format
modeling_jina_embeddings_v4.py  CHANGED  (+16 −16)
@@ -324,7 +324,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-        vector_type: Union[str, VectorType] = VectorType.single,
+        output_format: Union[str, VectorType] = VectorType.single,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -344,8 +344,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
             ):
                 embeddings = self(**batch, task_label=task_label)
-                vector_type_str = vector_type.value if isinstance(vector_type, VectorType) else vector_type
-                if vector_type_str == VectorType.single.value:
+                output_format_str = output_format.value if isinstance(output_format, VectorType) else output_format
+                if output_format_str == VectorType.single.value:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
@@ -362,7 +362,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-        vector_type: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorType]] = None,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -379,16 +379,16 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-        vector_type = vector_type or VectorType.single
-        if isinstance(vector_type, VectorType):
-            encode_kwargs["vector_type"] = vector_type.value
+        output_format = output_format or VectorType.single
+        if isinstance(output_format, VectorType):
+            encode_kwargs["output_format"] = output_format.value
         else:
             try:
-                vector_type_enum = VectorType(vector_type)
-                encode_kwargs["vector_type"] = vector_type_enum.value
+                output_format_enum = VectorType(output_format)
+                encode_kwargs["output_format"] = output_format_enum.value
             except ValueError:
                 raise ValueError(
-                    f"Invalid vector_type: {vector_type}. Must be one of {[v.value for v in VectorType]}."
+                    f"Invalid output_format: {output_format}. Must be one of {[v.value for v in VectorType]}."
                 )
 
         truncate_dim = truncate_dim or self.config.truncate_dim
@@ -422,7 +422,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-        vector_type: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
@@ -434,7 +434,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             texts: text or list of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
-            vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
+            output_format: Type of embedding vector to generate (VectorType.single or VectorType.multi)
            return_numpy: Whether to return numpy arrays instead of torch tensors
            truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
            prompt_name: Type of text being encoded ('query' or 'passage')
@@ -444,7 +444,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         """
         prompt_name = prompt_name or "query"
         encode_kwargs = self._validate_encoding_params(
-            vector_type, truncate_dim, prompt_name
+            output_format, truncate_dim, prompt_name
         )
 
         task = self._validate_task(task)
@@ -489,7 +489,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
-        vector_type: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
@@ -500,7 +500,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Args:
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
-            vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
+            output_format: Type of embedding vector to generate (VectorType.single or VectorType.multi)
            return_numpy: Whether to return numpy arrays instead of torch tensors
            truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
            max_pixels: Maximum number of pixels to process per image
@@ -513,7 +513,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         self.processor.image_processor.max_pixels = (
            max_pixels  # change during encoding
        )
-        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
+        encode_kwargs = self._validate_encoding_params(output_format, truncate_dim)
         task = self._validate_task(task)
 
        # Convert single image to list
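Note (not part of the commit): the sketch below illustrates how the renamed parameter is expected to behave. It assumes a VectorType enum with single and multi members whose values are the strings "single" and "multi"; only the member names and the .value accesses are visible in the diff. The normalize_output_format helper and the commented call at the bottom are hypothetical, mirroring the validation logic in _validate_encoding_params and the encode_text signature shown above.

from enum import Enum
from typing import Optional, Union


class VectorType(str, Enum):
    # Assumed values; the diff only shows the member names and .value accesses.
    single = "single"
    multi = "multi"


def normalize_output_format(output_format: Optional[Union[str, VectorType]] = None) -> str:
    # Mirrors _validate_encoding_params: default to single-vector output, accept
    # either the enum member or its string value, and reject anything else.
    output_format = output_format or VectorType.single
    if isinstance(output_format, VectorType):
        return output_format.value
    try:
        return VectorType(output_format).value
    except ValueError:
        raise ValueError(
            f"Invalid output_format: {output_format}. "
            f"Must be one of {[v.value for v in VectorType]}."
        )


if __name__ == "__main__":
    assert normalize_output_format() == "single"
    assert normalize_output_format(VectorType.multi) == "multi"
    assert normalize_output_format("single") == "single"

    # Hypothetical caller-side usage after the rename (model loading elided);
    # before this commit the keyword argument was named vector_type.
    # embeddings = model.encode_text(
    #     texts=["hello world"],
    #     output_format=VectorType.multi,
    #     return_numpy=True,
    # )

Callers can pass either the enum member or the plain string; both are normalized to the string value before being forwarded in encode_kwargs.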