feat: rename the VectorType
modeling_jina_embeddings_v4.py (+15 -15)
@@ -30,9 +30,9 @@ class PromptType(str, Enum):
     passage = "passage"
 
 
-class VectorType(str, Enum):
-    single = "single"
-    multiple = "multiple"
+class VectorOutputFormat(str, Enum):
+    SINGLE = "single"
+    MULTIPLE = "multiple"
 
 
 PREFIX_DICT = {"query": "Query", "passage": "Passage"}
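Note: because the renamed enum still subclasses str, each member compares equal to its raw string value, so call sites that pass "single" or "multiple" directly keep working. A minimal standalone sketch of that behavior (not part of the diff):

    from enum import Enum

    class VectorOutputFormat(str, Enum):
        SINGLE = "single"
        MULTIPLE = "multiple"

    assert VectorOutputFormat.SINGLE == "single"  # str subclass: member equals its value
    assert VectorOutputFormat("multiple") is VectorOutputFormat.MULTIPLE  # lookup by value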
@@ -324,7 +324,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-        output_format: Union[str, VectorType] = VectorType.single,
+        output_format: Union[str, VectorOutputFormat] = VectorOutputFormat.SINGLE,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -344,8 +344,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                    device_type=torch.device(self.device).type, dtype=torch.bfloat16
                ):
                    embeddings = self(**batch, task_label=task_label)
-                   output_format_str = output_format.value if isinstance(output_format, VectorType) else output_format
-                   if output_format_str == VectorType.single.value:
+                   output_format_str = output_format.value if isinstance(output_format, VectorOutputFormat) else output_format
+                   if output_format_str == VectorOutputFormat.SINGLE.value:
                        embeddings = embeddings.single_vec_emb
                        if truncate_dim is not None:
                            embeddings = embeddings[:, :truncate_dim]
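The one-liner added here normalizes enum or string input before the comparison, so the branch below it only ever sees plain strings. The same pattern in isolation (the helper name is illustrative; the model inlines this logic):

    def resolve_output_format(output_format):
        # Accept either a VectorOutputFormat member or its raw string value.
        if isinstance(output_format, VectorOutputFormat):
            return output_format.value
        return output_format

    assert resolve_output_format(VectorOutputFormat.SINGLE) == "single"
    assert resolve_output_format("multiple") == "multiple"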
@@ -362,7 +362,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-        output_format: Optional[Union[str, VectorType]] = VectorType.single,
+        output_format: Optional[Union[str, VectorOutputFormat]] = VectorOutputFormat.SINGLE,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -379,16 +379,16 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-        output_format = output_format or VectorType.single
-        if isinstance(output_format, VectorType):
+        output_format = output_format or VectorOutputFormat.SINGLE
+        if isinstance(output_format, VectorOutputFormat):
             encode_kwargs["output_format"] = output_format.value
         else:
             try:
-                output_format_enum = VectorType(output_format)
+                output_format_enum = VectorOutputFormat(output_format)
                 encode_kwargs["output_format"] = output_format_enum.value
             except ValueError:
                 raise ValueError(
-                    f"Invalid output_format: {output_format}. Must be one of {[v.value for v in VectorType]}."
+                    f"Invalid output_format: {output_format}. Must be one of {[v.value for v in VectorOutputFormat]}."
                 )
 
         truncate_dim = truncate_dim or self.config.truncate_dim
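The try/except around the enum constructor is what turns an unrecognized string into a readable error listing the valid values. The same validate-or-raise pattern as a hedged standalone sketch (hypothetical helper; the model performs this inside _validate_encoding_params):

    def validate_output_format(output_format=None) -> str:
        # Default to single-vector output when nothing is passed.
        output_format = output_format or VectorOutputFormat.SINGLE
        if isinstance(output_format, VectorOutputFormat):
            return output_format.value
        try:
            # Enum lookup by value; fails for unknown strings.
            return VectorOutputFormat(output_format).value
        except ValueError:
            raise ValueError(
                f"Invalid output_format: {output_format}. "
                f"Must be one of {[v.value for v in VectorOutputFormat]}."
            )

    validate_output_format("single")   # -> "single"
    validate_output_format("invalid")  # ValueError: ... ['single', 'multiple']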
@@ -422,7 +422,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-        output_format: Optional[Union[str, VectorType]] = VectorType.single,
+        output_format: Optional[Union[str, VectorOutputFormat]] = VectorOutputFormat.SINGLE,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
@@ -434,7 +434,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             texts: text or list of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
-            output_format: Type of embedding vector to generate (VectorType.single or VectorType.multiple)
+            output_format: Type of embedding vector to generate (VectorOutputFormat.SINGLE or VectorOutputFormat.MULTIPLE)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             prompt_name: Type of text being encoded ('query' or 'passage')
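With the renamed type, encode_text accepts either form of the parameter. A hypothetical invocation against the updated signature (model loading elided; argument values illustrative):

    embeddings = model.encode_text(
        texts=["What is deep learning?"],
        output_format=VectorOutputFormat.SINGLE,  # or simply "single"
        truncate_dim=512,
        prompt_name="query",
    )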
@@ -489,7 +489,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
-        output_format: Optional[Union[str, VectorType]] = VectorType.single,
+        output_format: Optional[Union[str, VectorOutputFormat]] = VectorOutputFormat.SINGLE,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
@@ -500,7 +500,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Args:
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
-            output_format: Type of embedding vector to generate (VectorType.single or VectorType.multiple)
+            output_format: Type of embedding vector to generate (VectorOutputFormat.SINGLE or VectorOutputFormat.MULTIPLE)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
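encode_image takes the same renamed parameter; a sketch requesting multi-vector output (the URL and max_pixels value are placeholders):

    image_embeddings = model.encode_image(
        images=["https://example.com/figure.png"],
        output_format="multiple",  # string form; VectorOutputFormat.MULTIPLE also works
        max_pixels=1024 * 1024,
    )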