Commit 96925c4 · 1 Parent(s): 085e2ed

refactor: rename vector_type to output_format

Files changed (1):
  1. modeling_jina_embeddings_v4.py +16 -16
modeling_jina_embeddings_v4.py CHANGED
@@ -324,7 +324,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-        vector_type: Union[str, VectorType] = VectorType.single,
+        output_format: Union[str, VectorType] = VectorType.single,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -344,8 +344,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
             ):
                 embeddings = self(**batch, task_label=task_label)
-                vector_type_str = vector_type.value if isinstance(vector_type, VectorType) else vector_type
-                if vector_type_str == VectorType.single.value:
+                output_format_str = output_format.value if isinstance(output_format, VectorType) else output_format
+                if output_format_str == VectorType.single.value:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
@@ -362,7 +362,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-        vector_type: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorType]] = None,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -379,16 +379,16 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-        vector_type = vector_type or VectorType.single
-        if isinstance(vector_type, VectorType):
-            encode_kwargs["vector_type"] = vector_type.value
+        output_format = output_format or VectorType.single
+        if isinstance(output_format, VectorType):
+            encode_kwargs["output_format"] = output_format.value
         else:
             try:
-                vector_type_enum = VectorType(vector_type)
-                encode_kwargs["vector_type"] = vector_type_enum.value
+                output_format_enum = VectorType(output_format)
+                encode_kwargs["output_format"] = output_format_enum.value
             except ValueError:
                 raise ValueError(
-                    f"Invalid vector_type: {vector_type}. Must be one of {[v.value for v in VectorType]}."
+                    f"Invalid output_format: {output_format}. Must be one of {[v.value for v in VectorType]}."
                 )
 
         truncate_dim = truncate_dim or self.config.truncate_dim
@@ -422,7 +422,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-        vector_type: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
@@ -434,7 +434,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             texts: text or list of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
-            vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
+            output_format: Type of embedding vector to generate (VectorType.single or VectorType.multi)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             prompt_name: Type of text being encoded ('query' or 'passage')
@@ -444,7 +444,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         """
         prompt_name = prompt_name or "query"
         encode_kwargs = self._validate_encoding_params(
-            vector_type, truncate_dim, prompt_name
+            output_format, truncate_dim, prompt_name
        )
 
         task = self._validate_task(task)
@@ -489,7 +489,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
-        vector_type: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
@@ -500,7 +500,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Args:
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
-            vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
+            output_format: Type of embedding vector to generate (VectorType.single or VectorType.multi)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
@@ -513,7 +513,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             self.processor.image_processor.max_pixels = (
                 max_pixels  # change during encoding
             )
-        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
+        encode_kwargs = self._validate_encoding_params(output_format, truncate_dim)
         task = self._validate_task(task)
 
         # Convert single image to list
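
The diff never shows the VectorType definition itself, but the usage above pins down its shape: it has single and multi members (per the docstrings), a string .value, and VectorType(<str>) raises ValueError on unknown values. A minimal sketch consistent with that usage; the member values are an assumption, and the real definition lives elsewhere in the repo:

from enum import Enum

# Hypothetical reconstruction, inferred from usage in the diff. A str-valued
# Enum makes VectorType("single") succeed and VectorType("bogus") raise
# ValueError, exactly as _validate_encoding_params expects.
class VectorType(str, Enum):
    single = "single"  # one pooled vector per input (single_vec_emb)
    multi = "multi"    # one vector per token/patch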
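
For callers, the commit is a pure keyword rename with unchanged defaults and unchanged accepted values. A hedged before/after sketch; the method name encode_text, the repo id, and the loading pattern are assumptions, since all of them sit outside the diff context:

from transformers import AutoModel

# Assumptions (not visible in the diff): the model loads via AutoModel with
# trust_remote_code=True, and the text-encoding method is named encode_text.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

texts = ["What is the capital of France?"]

# Before this commit:
# emb = model.encode_text(texts=texts, vector_type="single", truncate_dim=512)

# After this commit: same behavior, renamed keyword; "single" or "multi"
# (or the corresponding VectorType members) are accepted.
emb = model.encode_text(texts=texts, output_format="single", truncate_dim=512)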