nan committed on
Commit 669c42a · 1 Parent(s): 96925c4

feat: rename the VectorType

Files changed (1)
  1. modeling_jina_embeddings_v4.py +15 -15
modeling_jina_embeddings_v4.py CHANGED
@@ -30,9 +30,9 @@ class PromptType(str, Enum):
     passage = "passage"
 
 
-class VectorType(str, Enum):
-    single = "single"
-    multi = "multi"
+class VectorOutputFormat(str, Enum):
+    SINGLE = "single"
+    MULTIPLE = "multiple"
 
 
 PREFIX_DICT = {"query": "Query", "passage": "Passage"}
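
Worth noting: the rename changes not just the class and member names but also one public string value, "multi" becoming "multiple", so callers that pass plain strings are affected too. A minimal standalone sketch of the resulting behavior, mirroring the enum as defined in this diff:

from enum import Enum

class VectorOutputFormat(str, Enum):
    SINGLE = "single"
    MULTIPLE = "multiple"

# String values round-trip through the constructor; _validate_encoding_params
# relies on this to normalize string inputs (see the hunk at line 379 below).
assert VectorOutputFormat("single") is VectorOutputFormat.SINGLE
assert VectorOutputFormat("multiple") is VectorOutputFormat.MULTIPLE

# The pre-rename string "multi" no longer matches any member, so it raises
# ValueError and surfaces as the "Invalid output_format" error in this diff.
try:
    VectorOutputFormat("multi")
except ValueError:
    print("'multi' is no longer accepted; pass 'multiple' instead")
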
@@ -324,7 +324,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-        output_format: Union[str, VectorType] = VectorType.single,
+        output_format: Union[str, VectorOutputFormat] = VectorOutputFormat.SINGLE,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -344,8 +344,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             device_type=torch.device(self.device).type, dtype=torch.bfloat16
         ):
             embeddings = self(**batch, task_label=task_label)
-            output_format_str = output_format.value if isinstance(output_format, VectorType) else output_format
-            if output_format_str == VectorType.single.value:
+            output_format_str = output_format.value if isinstance(output_format, VectorOutputFormat) else output_format
+            if output_format_str == VectorOutputFormat.SINGLE.value:
                 embeddings = embeddings.single_vec_emb
                 if truncate_dim is not None:
                     embeddings = embeddings[:, :truncate_dim]
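
For the single-vector path above, truncation is plain prefix slicing, matching the 128/256/512/1024 options listed in the docstrings further down. A small illustration; the full embedding width here is an assumed placeholder, not taken from the diff:

import torch

embeddings = torch.randn(4, 2048)    # assumed full single-vector width
truncate_dim = 128                   # documented options: 128, 256, 512, 1024
embeddings = embeddings[:, :truncate_dim]  # keep only the leading dimensions
assert embeddings.shape == (4, 128)
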
@@ -362,7 +362,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-        output_format: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorOutputFormat]] = VectorOutputFormat.SINGLE,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -379,16 +379,16 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-        output_format = output_format or VectorType.single
-        if isinstance(output_format, VectorType):
+        output_format = output_format or VectorOutputFormat.SINGLE
+        if isinstance(output_format, VectorOutputFormat):
             encode_kwargs["output_format"] = output_format.value
         else:
             try:
-                output_format_enum = VectorType(output_format)
+                output_format_enum = VectorOutputFormat(output_format)
                 encode_kwargs["output_format"] = output_format_enum.value
             except ValueError:
                 raise ValueError(
-                    f"Invalid output_format: {output_format}. Must be one of {[v.value for v in VectorType]}."
+                    f"Invalid output_format: {output_format}. Must be one of {[v.value for v in VectorOutputFormat]}."
                 )
 
         truncate_dim = truncate_dim or self.config.truncate_dim
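
The normalization above accepts either the enum or its string value, applies the new SINGLE default when nothing is passed, and always stores a plain string in encode_kwargs. The same pattern as a standalone sketch; the helper name is hypothetical, not part of the module:

from enum import Enum
from typing import Optional, Union

class VectorOutputFormat(str, Enum):
    SINGLE = "single"
    MULTIPLE = "multiple"

def normalize_output_format(value: Optional[Union[str, VectorOutputFormat]]) -> str:
    # None falls back to the default, mirroring `output_format or VectorOutputFormat.SINGLE`
    value = value or VectorOutputFormat.SINGLE
    if isinstance(value, VectorOutputFormat):
        return value.value
    try:
        return VectorOutputFormat(value).value
    except ValueError:
        raise ValueError(
            f"Invalid output_format: {value}. "
            f"Must be one of {[v.value for v in VectorOutputFormat]}."
        )

assert normalize_output_format(None) == "single"
assert normalize_output_format("multiple") == "multiple"
assert normalize_output_format(VectorOutputFormat.SINGLE) == "single"
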
@@ -422,7 +422,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-        output_format: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorOutputFormat]] = VectorOutputFormat.SINGLE,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
@@ -434,7 +434,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             texts: text or list of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
-            output_format: Type of embedding vector to generate (VectorType.single or VectorType.multi)
+            output_format: Type of embedding vector to generate (VectorOutputFormat.SINGLE or VectorOutputFormat.MULTIPLE)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             prompt_name: Type of text being encoded ('query' or 'passage')
@@ -489,7 +489,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
-        output_format: Optional[Union[str, VectorType]] = None,
+        output_format: Optional[Union[str, VectorOutputFormat]] = VectorOutputFormat.SINGLE,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
@@ -500,7 +500,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Args:
             images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
             batch_size: Number of images to process at once
-            output_format: Type of embedding vector to generate (VectorType.single or VectorType.multi)
+            output_format: Type of embedding vector to generate (VectorOutputFormat.SINGLE or VectorOutputFormat.MULTIPLE)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
 
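
Taken together, a hedged end-to-end sketch of calling the renamed API. The repository id, the trust_remote_code loading path, and the task name are assumptions about how this modeling file is typically consumed, not facts from the diff:

from transformers import AutoModel

# Assumed loading path for this remote-code modeling file.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",  # assumed repository id
    trust_remote_code=True,
)

# Single-vector embeddings, the new default (was VectorType.single before this commit).
single_vecs = model.encode_text(
    texts=["A blue cat"],
    task="retrieval",             # assumed task name
    prompt_name="query",
    output_format="single",
)

# Multi-vector embeddings; note the string value is now "multiple", not "multi".
multi_vecs = model.encode_text(
    texts=["A blue cat"],
    task="retrieval",
    prompt_name="query",
    output_format="multiple",
)
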