---
base_model: OpenGVLab/InternVL2-4B
library_name: peft
---
# Model Details
- **Developed by:** Jian Chen
- **Model type:** MLLM-based encoder
- **Finetuned from model:** [OpenGVLab/InternVL2-4B](https://huggingface.co/OpenGVLab/InternVL2-4B)
## Model Sources
- **Repository:** [SV-RAG](https://github.com/puar-playground/SV-RAG)
- **Paper:** [SV-RAG: LoRA-Contextualizing Adaptation of Large Multimodal Models for Long Document Understanding](https://arxiv.org/abs/2411.01106)
## Uses
A demo script is available in the [GitHub repository](https://github.com/puar-playground/SV-RAG/blob/main/test_retrieval.py).
Alternatively, the code below gives a more detailed breakdown of the computation. It relies on a customized [`colpali_engine`](https://github.com/puar-playground/SV-RAG/tree/main/colpali_engine), which is also available in the GitHub repository.
```
from typing import List

import torch
from PIL import Image

# Requires the customized colpali_engine from the SV-RAG repository.
from colpali_engine.models import ColInternvl2_4b, ColInternProcessor

# The original snippet pins transformers==4.47.1. Install it beforehand with
#   pip install transformers==4.47.1
# rather than calling pip from inside the constructor.


class ColInternVL2Retriever:
    """Retriever class using ColInternVL2 for multimodal retrieval.

    The original subclasses a BaseRetriever that is not defined in this snippet;
    a plain class is used here so the example stands alone.
    """

    def __init__(self, model_name="puar-playground/Col-InternVL2-4B",
                 device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        Initializes the ColInternVL2 model.

        Args:
            model_name (str): The model identifier.
            device (str): Device to run the model on ('cuda' or 'cpu').
        """
        self.multimodel = True
        self.device = device
        self.model = ColInternvl2_4b.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device).eval()
        self.processor = ColInternProcessor('OpenGVLab/InternVL2-4B')

    def process_text(self, query_list: List[str], batch_size: int = 4):
        """
        Processes a list of text queries into embeddings using ColInternVL2 in batches.

        Args:
            query_list (List[str]): List of query texts.
            batch_size (int): Number of queries processed per batch.

        Returns:
            torch.Tensor: Concatenated embeddings for all queries.
        """
        all_embeddings = []
        for i in range(0, len(query_list), batch_size):
            batch_queries = query_list[i : i + batch_size]
            # Convert queries to model-compatible format
            batch_inputs = self.processor.process_queries(batch_queries).to(self.model.device)
            with torch.no_grad():
                batch_embeddings = self.model(**batch_inputs)
            all_embeddings.append(batch_embeddings.to("cpu"))
        # Pad to a common sequence length, then concatenate along the batch dimension
        return self.pad_and_cat_tensors(all_embeddings)

    @staticmethod
    def pad_and_cat_tensors(tensor_list):
        """Zero-pads tensors along dim 1 to a common length and concatenates them along dim 0."""
        # Find the maximum length of the second dimension across all tensors
        max_x = max(tensor.size(1) for tensor in tensor_list)
        padded_tensors = []
        for tensor in tensor_list:
            padding_size = max_x - tensor.size(1)
            # (0, 0) leaves the last dimension untouched; (0, padding_size)
            # appends zeros at the end of the second dimension
            padded_tensor = torch.nn.functional.pad(tensor, (0, 0, 0, padding_size))
            padded_tensors.append(padded_tensor)
        # Concatenate the padded tensors along the first dimension
        return torch.cat(padded_tensors, dim=0)

    def process_image(self, image_dir_list: List[str]):
        """Processes images into embeddings using ColInternVL2, one image at a time."""
        all_embeddings = []
        with torch.no_grad():
            for img_dir in image_dir_list:
                img = Image.open(img_dir)
                # Preprocess the image into model inputs
                batch_features = self.processor.process_images(img)
                # Move every tensor in the BatchFeature object to the model device
                batch_images = {k: v.to(self.model.device) for k, v in batch_features.items()}
                embeddings = self.model(**batch_images)
                # Move embeddings to CPU to free device memory
                all_embeddings.append(embeddings.to("cpu"))
        # Pad to a common sequence length, then concatenate along the batch dimension
        return self.pad_and_cat_tensors(all_embeddings)

    def compute_similarity(self, text_embeddings, image_embeddings):
        """Computes multi-vector similarity scores between text and image embeddings."""
        return self.processor.score_multi_vector(text_embeddings, image_embeddings)

    def retrieve(self, query_list: List[str], image_list: List[str]):
        """Returns, for each query, the scores and image indices sorted from best to worst."""
        text_embeddings = self.process_text(query_list)
        image_embeddings = self.process_image(image_list)
        similarity_score = self.compute_similarity(text_embeddings, image_embeddings)
        # Sort along the last dimension (images) in descending order of score
        values, top_indices = torch.as_tensor(similarity_score).sort(descending=True)
        return values, top_indices
```
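To make the scoring step in `compute_similarity` more concrete, here is a minimal pure-Python sketch of ColPali-style late-interaction (MaxSim) scoring followed by ranking. It assumes that `score_multi_vector` follows this scheme, which this README does not confirm; the helper names (`dot`, `maxsim_score`, `rank_pages`) and the toy embeddings are hypothetical and are not part of `colpali_engine`.

```python
from typing import List, Tuple

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query: List[Vector], page: List[Vector]) -> float:
    """Assumed scoring rule: for each query token take the best dot product
    against any page token, then sum over the query tokens."""
    return sum(max(dot(q, p) for p in page) for q in query)


def rank_pages(query: List[Vector], pages: List[List[Vector]]) -> Tuple[List[float], List[int]]:
    """Score every page against the query and return (scores, indices sorted best first)."""
    scores = [maxsim_score(query, page) for page in pages]
    order = sorted(range(len(pages)), key=lambda i: scores[i], reverse=True)
    return scores, order


# Toy data: one query with two token embeddings, two pages with
# different numbers of token embeddings (all values are made up).
query = [[1.0, 0.0], [0.0, 1.0]]
pages = [
    [[1.0, 0.0], [0.5, 0.5]],              # page 0
    [[0.0, 1.0], [1.0, 0.0], [0.0, 0.5]],  # page 1
]

scores, order = rank_pages(query, pages)
print(scores)  # [1.5, 2.0]
print(order)   # [1, 0]
```

Page 1 ranks first because each query token finds an exact match among its token embeddings. The `retrieve` method above performs the analogous ranking with `sort(descending=True)` on the score matrix.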
## Citation