fix-task-setting-and-st-load (#46)
Browse files- fix: load in new st, task setting (5a13d0f29dd4f13b4a0d82f530acd43d189d44fc)
- README.md +1 -1
- custom_st.py +18 -5
- modeling_jina_embeddings_v4.py +23 -20
README.md
CHANGED
|
@@ -155,7 +155,7 @@ from transformers import AutoModel
|
|
| 155 |
import torch
|
| 156 |
|
| 157 |
# Initialize the model
|
| 158 |
-
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)
|
| 159 |
|
| 160 |
model.to("cuda")
|
| 161 |
|
|
|
|
| 155 |
import torch
|
| 156 |
|
| 157 |
# Initialize the model
|
| 158 |
+
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True, torch_dtype=torch.float16)
|
| 159 |
|
| 160 |
model.to("cuda")
|
| 161 |
|
custom_st.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
from io import BytesIO
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Any, Dict, List, Literal, Optional, Union
|
|
@@ -104,7 +106,10 @@ class Transformer(nn.Module):
|
|
| 104 |
return encoding
|
| 105 |
|
| 106 |
def forward(
|
| 107 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 108 |
) -> Dict[str, torch.Tensor]:
|
| 109 |
self.model.eval()
|
| 110 |
|
|
@@ -138,8 +143,10 @@ class Transformer(nn.Module):
|
|
| 138 |
**text_batch, task_label=task
|
| 139 |
).single_vec_emb
|
| 140 |
if truncate_dim:
|
| 141 |
-
text_embeddings = text_embeddings[:, :
|
| 142 |
-
text_embeddings = torch.nn.functional.normalize(
|
|
|
|
|
|
|
| 143 |
for i, embedding in enumerate(text_embeddings):
|
| 144 |
all_embeddings.append((text_indices[i], embedding))
|
| 145 |
|
|
@@ -156,8 +163,10 @@ class Transformer(nn.Module):
|
|
| 156 |
**image_batch, task_label=task
|
| 157 |
).single_vec_emb
|
| 158 |
if truncate_dim:
|
| 159 |
-
img_embeddings = img_embeddings[:, :
|
| 160 |
-
img_embeddings = torch.nn.functional.normalize(
|
|
|
|
|
|
|
| 161 |
|
| 162 |
for i, embedding in enumerate(img_embeddings):
|
| 163 |
all_embeddings.append((image_indices[i], embedding))
|
|
@@ -170,3 +179,7 @@ class Transformer(nn.Module):
|
|
| 170 |
features["sentence_embedding"] = combined_embeddings
|
| 171 |
|
| 172 |
return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
from io import BytesIO
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Any, Dict, List, Literal, Optional, Union
|
|
|
|
| 106 |
return encoding
|
| 107 |
|
| 108 |
def forward(
|
| 109 |
+
self,
|
| 110 |
+
features: Dict[str, torch.Tensor],
|
| 111 |
+
task: Optional[str] = None,
|
| 112 |
+
truncate_dim: Optional[int] = None,
|
| 113 |
) -> Dict[str, torch.Tensor]:
|
| 114 |
self.model.eval()
|
| 115 |
|
|
|
|
| 143 |
**text_batch, task_label=task
|
| 144 |
).single_vec_emb
|
| 145 |
if truncate_dim:
|
| 146 |
+
text_embeddings = text_embeddings[:, :truncate_dim]
|
| 147 |
+
text_embeddings = torch.nn.functional.normalize(
|
| 148 |
+
text_embeddings, p=2, dim=-1
|
| 149 |
+
)
|
| 150 |
for i, embedding in enumerate(text_embeddings):
|
| 151 |
all_embeddings.append((text_indices[i], embedding))
|
| 152 |
|
|
|
|
| 163 |
**image_batch, task_label=task
|
| 164 |
).single_vec_emb
|
| 165 |
if truncate_dim:
|
| 166 |
+
img_embeddings = img_embeddings[:, :truncate_dim]
|
| 167 |
+
img_embeddings = torch.nn.functional.normalize(
|
| 168 |
+
img_embeddings, p=2, dim=-1
|
| 169 |
+
)
|
| 170 |
|
| 171 |
for i, embedding in enumerate(img_embeddings):
|
| 172 |
all_embeddings.append((image_indices[i], embedding))
|
|
|
|
| 179 |
features["sentence_embedding"] = combined_embeddings
|
| 180 |
|
| 181 |
return features
|
| 182 |
+
|
| 183 |
+
@classmethod
|
| 184 |
+
def load(cls, input_path: str) -> "Transformer":
|
| 185 |
+
return cls(model_name_or_path=input_path)
|
modeling_jina_embeddings_v4.py
CHANGED
|
@@ -242,7 +242,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 242 |
pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(
|
| 243 |
dim=1, keepdim=True
|
| 244 |
)
|
| 245 |
-
|
| 246 |
else: # got query text
|
| 247 |
pooled_output = torch.sum(
|
| 248 |
hidden_states * attention_mask.unsqueeze(-1), dim=1
|
|
@@ -332,7 +331,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 332 |
collate_fn=processor_fn,
|
| 333 |
)
|
| 334 |
if return_multivector and len(data) > 1:
|
| 335 |
-
assert
|
|
|
|
|
|
|
| 336 |
results = []
|
| 337 |
self.eval()
|
| 338 |
for batch in tqdm(dataloader, desc=desc):
|
|
@@ -346,10 +347,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 346 |
embeddings = embeddings.single_vec_emb
|
| 347 |
if truncate_dim is not None:
|
| 348 |
embeddings = embeddings[:, :truncate_dim]
|
| 349 |
-
embeddings = torch.nn.functional.normalize(
|
|
|
|
|
|
|
| 350 |
else:
|
| 351 |
embeddings = embeddings.multi_vec_emb
|
| 352 |
-
|
| 353 |
if return_multivector and not return_numpy:
|
| 354 |
valid_tokens = batch["attention_mask"].bool()
|
| 355 |
embeddings = [
|
|
@@ -436,7 +439,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 436 |
List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
|
| 437 |
"""
|
| 438 |
prompt_name = prompt_name or "query"
|
| 439 |
-
encode_kwargs = self._validate_encoding_params(
|
|
|
|
|
|
|
| 440 |
|
| 441 |
task = self._validate_task(task)
|
| 442 |
|
|
@@ -451,9 +456,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 451 |
# If return_multivector is True and encoding multiple texts, ignore return_numpy
|
| 452 |
if return_multivector and return_list and len(texts) > 1:
|
| 453 |
if return_numpy:
|
| 454 |
-
print(
|
|
|
|
|
|
|
| 455 |
return_numpy = False
|
| 456 |
-
|
| 457 |
if isinstance(texts, str):
|
| 458 |
texts = [texts]
|
| 459 |
|
|
@@ -468,7 +475,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 468 |
**encode_kwargs,
|
| 469 |
)
|
| 470 |
|
| 471 |
-
return embeddings if return_list else embeddings[0]
|
| 472 |
|
| 473 |
def _load_images_if_needed(
|
| 474 |
self, images: List[Union[str, Image.Image]]
|
|
@@ -515,19 +522,21 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 515 |
)
|
| 516 |
encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
|
| 517 |
task = self._validate_task(task)
|
| 518 |
-
|
| 519 |
return_list = isinstance(images, list)
|
| 520 |
|
| 521 |
# If return_multivector is True and encoding multiple images, ignore return_numpy
|
| 522 |
if return_multivector and return_list and len(images) > 1:
|
| 523 |
if return_numpy:
|
| 524 |
-
print(
|
|
|
|
|
|
|
| 525 |
return_numpy = False
|
| 526 |
|
| 527 |
# Convert single image to list
|
| 528 |
if isinstance(images, (str, Image.Image)):
|
| 529 |
images = [images]
|
| 530 |
-
|
| 531 |
images = self._load_images_if_needed(images)
|
| 532 |
embeddings = self._process_batches(
|
| 533 |
data=images,
|
|
@@ -588,18 +597,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 588 |
config=lora_config,
|
| 589 |
)
|
| 590 |
|
| 591 |
-
|
| 592 |
-
def task(self):
|
| 593 |
return self.model.task
|
| 594 |
|
| 595 |
-
|
| 596 |
-
def task(self, value):
|
| 597 |
self.model.task = value
|
| 598 |
|
| 599 |
-
peft_model.task = property(
|
| 600 |
-
peft_model.__class__.task = property(
|
| 601 |
-
lambda self: self.model.task,
|
| 602 |
-
lambda self, value: setattr(self.model, "task", value),
|
| 603 |
-
)
|
| 604 |
|
| 605 |
return peft_model
|
|
|
|
| 242 |
pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(
|
| 243 |
dim=1, keepdim=True
|
| 244 |
)
|
|
|
|
| 245 |
else: # got query text
|
| 246 |
pooled_output = torch.sum(
|
| 247 |
hidden_states * attention_mask.unsqueeze(-1), dim=1
|
|
|
|
| 331 |
collate_fn=processor_fn,
|
| 332 |
)
|
| 333 |
if return_multivector and len(data) > 1:
|
| 334 |
+
assert (
|
| 335 |
+
not return_numpy
|
| 336 |
+
), "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
|
| 337 |
results = []
|
| 338 |
self.eval()
|
| 339 |
for batch in tqdm(dataloader, desc=desc):
|
|
|
|
| 347 |
embeddings = embeddings.single_vec_emb
|
| 348 |
if truncate_dim is not None:
|
| 349 |
embeddings = embeddings[:, :truncate_dim]
|
| 350 |
+
embeddings = torch.nn.functional.normalize(
|
| 351 |
+
embeddings, p=2, dim=-1
|
| 352 |
+
)
|
| 353 |
else:
|
| 354 |
embeddings = embeddings.multi_vec_emb
|
| 355 |
+
|
| 356 |
if return_multivector and not return_numpy:
|
| 357 |
valid_tokens = batch["attention_mask"].bool()
|
| 358 |
embeddings = [
|
|
|
|
| 439 |
List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
|
| 440 |
"""
|
| 441 |
prompt_name = prompt_name or "query"
|
| 442 |
+
encode_kwargs = self._validate_encoding_params(
|
| 443 |
+
truncate_dim=truncate_dim, prompt_name=prompt_name
|
| 444 |
+
)
|
| 445 |
|
| 446 |
task = self._validate_task(task)
|
| 447 |
|
|
|
|
| 456 |
# If return_multivector is True and encoding multiple texts, ignore return_numpy
|
| 457 |
if return_multivector and return_list and len(texts) > 1:
|
| 458 |
if return_numpy:
|
| 459 |
+
print(
|
| 460 |
+
"Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`"
|
| 461 |
+
)
|
| 462 |
return_numpy = False
|
| 463 |
+
|
| 464 |
if isinstance(texts, str):
|
| 465 |
texts = [texts]
|
| 466 |
|
|
|
|
| 475 |
**encode_kwargs,
|
| 476 |
)
|
| 477 |
|
| 478 |
+
return embeddings if return_list else embeddings[0]
|
| 479 |
|
| 480 |
def _load_images_if_needed(
|
| 481 |
self, images: List[Union[str, Image.Image]]
|
|
|
|
| 522 |
)
|
| 523 |
encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
|
| 524 |
task = self._validate_task(task)
|
| 525 |
+
|
| 526 |
return_list = isinstance(images, list)
|
| 527 |
|
| 528 |
# If return_multivector is True and encoding multiple images, ignore return_numpy
|
| 529 |
if return_multivector and return_list and len(images) > 1:
|
| 530 |
if return_numpy:
|
| 531 |
+
print(
|
| 532 |
+
"Warning: `return_numpy` is ignored when `return_multivector=True` and `len(images) > 1`"
|
| 533 |
+
)
|
| 534 |
return_numpy = False
|
| 535 |
|
| 536 |
# Convert single image to list
|
| 537 |
if isinstance(images, (str, Image.Image)):
|
| 538 |
images = [images]
|
| 539 |
+
|
| 540 |
images = self._load_images_if_needed(images)
|
| 541 |
embeddings = self._process_batches(
|
| 542 |
data=images,
|
|
|
|
| 597 |
config=lora_config,
|
| 598 |
)
|
| 599 |
|
| 600 |
+
def task_getter(self):
|
|
|
|
| 601 |
return self.model.task
|
| 602 |
|
| 603 |
+
def task_setter(self, value):
|
|
|
|
| 604 |
self.model.task = value
|
| 605 |
|
| 606 |
+
peft_model.__class__.task = property(task_getter, task_setter)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
|
| 608 |
return peft_model
|