import os

import numpy as np
import onnxruntime
import torch
from optimum.onnxruntime import ORTModelForFeatureExtraction
from sentence_transformers import models
from transformers import AutoTokenizer

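# Load the tokenizer from the Hub and the ONNX-exported encoder from a local
# directory, assumed to contain a model.onnx produced by an optimum export.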
model_dir = "embeddinggemma-300m"
tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m-qat-q4_0-unquantized")
device = "cuda" if torch.cuda.is_available() else "cpu"
onnx_model = ORTModelForFeatureExtraction.from_pretrained(
    model_dir,
    file_name="model.onnx",
).to(device)

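# Stands in for the transformer module of a sentence-transformers pipeline:
# tokenize, run the ONNX graph, and return the token embeddings together with
# the attention mask so that pooling can ignore padding.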
class ONNXTransformer:
    def __init__(self, onnx_model, tokenizer, max_seq_length=2048):
        self.onnx_model = onnx_model
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def encode(self, sentences):
        inputs = self.tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
        )
        input_ids = inputs["input_ids"]
        # The exported graph takes explicit position_ids, so build them per batch.
        sequence_length = input_ids.shape[1]
        position_ids = torch.arange(sequence_length)[None, :].expand(input_ids.shape[0], sequence_length)
        inputs["position_ids"] = position_ids.to(input_ids.device)
        with torch.no_grad():
            outputs = self.onnx_model(**inputs)
        # Return the attention mask as well, so pooling can exclude padding tokens.
        return outputs.last_hidden_state, inputs["attention_mask"]

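# Rebuild the remaining pipeline stages by hand. The export is assumed to keep
# the usual sentence-transformers layout (1_Pooling, 2_Dense, 3_Dense,
# 4_Normalize), with the two Dense projections shipped as standalone ONNX
# graphs under model_dir/onnx/.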
modules = []
onnx_transformer = ONNXTransformer(onnx_model, tokenizer, max_seq_length=2048)
modules.append(onnx_transformer)
for idx, name in [(1, "Pooling"), (2, "Dense"), (3, "Dense"), (4, "Normalize")]:
    module_path = os.path.join(model_dir, f"{idx}_{name}")
    if name == "Pooling":
        # Pooling.load reads the pooling config from disk; the Pooling
        # constructor itself takes embedding dimensions, not a path.
        modules.append(models.Pooling.load(module_path))
    elif name == "Dense":
        dense_onnx_path = os.path.join(model_dir, "onnx", f"dense{idx - 1}.onnx")
        modules.append(onnxruntime.InferenceSession(dense_onnx_path, providers=["CPUExecutionProvider"]))
    elif name == "Normalize":
        modules.append(models.Normalize())

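# Chains the modules: transformer -> pooling -> dense ONNX heads -> normalize.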
class ONNXSentenceTransformer:
    def __init__(self, modules):
        self.modules = modules

    def encode(self, sentences):
        features, attention_mask = self.modules[0].encode(sentences)
        for module in self.modules[1:]:
            if isinstance(module, models.Pooling):
                # Pool with the real attention mask so padding tokens are
                # excluded from the sentence embedding.
                features = module({
                    "token_embeddings": features,
                    "attention_mask": attention_mask.to(features.device),
                })["sentence_embedding"]
            elif isinstance(module, onnxruntime.InferenceSession):
                # The Dense ONNX graphs consume NumPy arrays, one vector at a time.
                if isinstance(features, torch.Tensor):
                    features = features.detach().cpu().numpy()
                outputs = []
                for vec in features:
                    ort_inputs = {module.get_inputs()[0].name: vec.reshape(1, -1)}
                    out = module.run(None, ort_inputs)[0]
                    outputs.append(out.squeeze(0))
                features = np.stack(outputs, axis=0)
            elif isinstance(module, models.Normalize):
                # Normalize operates on torch tensors, so convert back if needed.
                if not isinstance(features, torch.Tensor):
                    features = torch.from_numpy(features)
                features = module({"sentence_embedding": features})["sentence_embedding"]
        if isinstance(features, torch.Tensor):
            return features.detach().cpu().numpy()
        return features

onnx_st = ONNXSentenceTransformer(modules)

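# The pipeline ends in Normalize, so embeddings are unit-length and cosine
# similarity reduces to a dot product; the explicit form is kept for clarity.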
def cosine_similarity(a, b):
    a = a.flatten()
    b = b.flatten()
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

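# Quick smoke test: embed three words and compare them pairwise.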
if __name__ == "__main__":
    words = ["apple", "banana", "car"]
    embeddings = onnx_st.encode(words)
    print(embeddings)
    for idx, embedding in enumerate(embeddings):
        print(f"Embedding {idx + 1}: {embedding.shape}")

    print("\nCosine similarities:")
    print(f"apple vs banana: {cosine_similarity(embeddings[0], embeddings[1]):.4f}")
    print(f"apple vs car: {cosine_similarity(embeddings[0], embeddings[2]):.4f}")
    print(f"banana vs car: {cosine_similarity(embeddings[1], embeddings[2]):.4f}")