
SentenceTransformer based on NeuML/pubmedbert-base-embeddings

This is a sentence-transformers model finetuned from NeuML/pubmedbert-base-embeddings on the cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation and geo_70k_multiplets_natural_language_annotation datasets. It maps sentences & paragraphs to a 1024-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.

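Model Details

Model Description

  • Model Type: Sentence Transformer
  • Base model: NeuML/pubmedbert-base-embeddings
  • Output Dimensionality: 1024 dimensions
  • Similarity Function: Cosine Similarity
  • Training Datasets:
    • cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation
    • geo_70k_multiplets_natural_language_annotation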

Full Model Architecture

SentenceTransformer(
  (0): MMContextEncoder(
    (text_encoder): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
              (intermediate_act_fn): GELUActivation()
            )
            (output): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
    (text_adapter): AdapterModule(
      (net): Sequential(
        (0): Linear(in_features=768, out_features=512, bias=True)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=512, out_features=1024, bias=True)
        (3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (pooling): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  )
)
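The Pooling entry above has pooling_mode_mean_tokens set to True. The sentence embedding is therefore the average of the token vectors, with padding positions masked out. The pure-Python sketch below shows only that averaging step on toy 4-dimensional vectors. It is not the Sentence Transformers implementation, which operates on batched tensors. The printout alone does not show exactly where MMContextEncoder applies text_adapter relative to pooling. A word_embedding_dimension of 1024, however, suggests that the pooled vectors are the adapter's 1024-dimensional outputs.

```python
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose attention-mask entry is 1.

    Padding positions (mask 0) are excluded, so the result does not depend on
    how much padding the batch happened to add.
    """
    if len(token_embeddings) != len(attention_mask):
        raise ValueError("one mask entry is required per token")
    kept = [vec for vec, m in zip(token_embeddings, attention_mask) if m]
    if not kept:
        raise ValueError("attention mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


# Toy input: 3 real tokens plus 1 padding token, 4-dim vectors
# (the model itself would pool 1024-dim token vectors).
tokens = [
    [1.0, 0.0, 2.0, 4.0],
    [3.0, 2.0, 0.0, 0.0],
    [2.0, 4.0, 1.0, 2.0],
    [9.0, 9.0, 9.0, 9.0],  # padding, ignored by the mask
]
mask = [1, 1, 1, 0]
print(mean_pool(tokens, mask))  # [2.0, 2.0, 1.0, 2.0]
```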

Usage

Direct Usage (Sentence Transformers)

First install the Sentence Transformers library:

pip install -U sentence-transformers

Then you can load this model and run inference.

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("jo-mengr/mmcontext-pubmedbert-v2")
# Run inference
sentences = [
    'census_1b9d8702-5af8-4142-85ed-020eb06ec4f6_20229',
    "This measurement was conducted with 10x 3' v3. Terminally differentiated CD8+ T cells from the lung tissue of a male individual in his sixties.",
    "This measurement was conducted with 10x 5' v2. Sample contains regulatory T cells (Tregs), specifically T cells, from a female individual in her eighth decade, isolated from a thoracic lymph node.",
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# (3, 1024)

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[1.0000, 0.4047, 0.4202],
#         [0.4047, 1.0000, 0.9129],
#         [0.4202, 0.9129, 1.0000]])
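model.similarity returns pairwise scores computed with the model's similarity function. The 1.0000 diagonal is consistent with cosine similarity, which is also the similarity_fct used by the training loss. The sketch below reproduces that computation on toy 3-dimensional vectors. It also runs a brute-force top-k search, which is the idea behind the semantic-search use case. It is an illustration in pure Python. With real embeddings you would more likely pass the array returned by model.encode to sentence_transformers.util.cos_sim or util.semantic_search.

```python
import math
from typing import List, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity: dot product divided by the product of the L2 norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def similarity_matrix(embeddings: List[Vector]) -> List[List[float]]:
    """All pairwise cosine similarities, like the 3 x 3 matrix printed above."""
    return [[cosine(a, b) for b in embeddings] for a in embeddings]


def top_k(query: Vector, corpus: List[Vector], k: int = 2) -> List[Tuple[int, float]]:
    """(corpus index, score) pairs for the k corpus vectors most similar to query."""
    scored = [(i, cosine(query, doc)) for i, doc in enumerate(corpus)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy vectors standing in for 1024-dim model embeddings.
corpus = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
matrix = similarity_matrix(corpus)
print([round(s, 4) for s in matrix[0]])                # [1.0, 0.0, 0.7071]
print([i for i, _ in top_k([1.0, 0.2, 0.0], corpus)])  # [0, 2]
```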

Evaluation

Metrics

Triplet

  • Datasets: cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation_cell_sentence_1 and geo_70k_multiplets_natural_language_annotation_cell_sentence_1
  • Evaluated with TripletEvaluator
| Metric | cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation_cell_sentence_1 | geo_70k_multiplets_natural_language_annotation_cell_sentence_1 |
|---|---|---|
| cosine_accuracy | 0.5292 | 0.7119 |
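TripletEvaluator's cosine_accuracy counts a triplet as correct when the anchor is more similar, by cosine, to its positive than to its negative. A minimal pure-Python version of that metric on made-up 2-dimensional vectors is shown below. Because each triplet is a binary comparison, an uninformative model would score about 0.5, which is a useful reference point when reading 0.5292 and 0.7119. The card does not state which dataset columns were passed to the evaluator as anchor, positive, and negative.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def triplet_cosine_accuracy(anchors: List[Vector], positives: List[Vector],
                            negatives: List[Vector]) -> float:
    """Share of triplets where cos(anchor, positive) > cos(anchor, negative)."""
    hits = sum(
        cosine(a, p) > cosine(a, n) for a, p, n in zip(anchors, positives, negatives)
    )
    return hits / len(anchors)


anchors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]
positives = [[0.9, 0.1], [0.1, 0.9], [1.0, 0.9], [0.0, 1.0]]
negatives = [[0.0, 1.0], [1.0, 0.0], [-1.0, 0.0], [1.0, 0.1]]
# The last triplet is ranked wrongly, so 3 of 4 are correct.
print(triplet_cosine_accuracy(anchors, positives, negatives))  # 0.75
```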

Training Details

Training Datasets

cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation

  • Dataset: cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation at b141493
  • Size: 81,143 training samples
  • Columns: anchor, positive, negative_1, and negative_2
  • Approximate statistics based on the first 1000 samples:
    | | anchor | positive | negative_1 | negative_2 |
    |---|---|---|---|---|
    | type | string | string | string | string |
    | details (characters) | min: 45, mean: 47.72, max: 49 | min: 92, mean: 216.13, max: 900 | min: 101, mean: 215.14, max: 870 | min: 45, mean: 47.75, max: 49 |
  • Samples:
    | anchor | positive | negative_1 | negative_2 |
    |---|---|---|---|
    | census_218acb0f-9f2f-4f76-b90b-15a4b7c7f629_26009 | This measurement was conducted with 10x 3' v2. A proliferating lymphocyte cell sample, obtained from a 34-year-old female Asian individual, derived from peripheral blood mononuclear cells. | This measurement was conducted with 10x 3' v2. Sample is a 25-year-old female with European ethnicity, having CD8-positive, alpha-beta T cell type. This cell type exhibits elevated expression of type 1 interferon-stimulated genes (ISGs) in monocytes, reduction of naïve CD4+ T cells correlating with monocyte ISG expression, and expansion of repertoire-restricted cytotoxic GZMH+ CD8+ T cells. | census_218acb0f-9f2f-4f76-b90b-15a4b7c7f629_14165 |
    | census_1b9d8702-5af8-4142-85ed-020eb06ec4f6_6333 | This measurement was conducted with 10x 5' v1. Sample is a cell from the omentum tissue, specifically an effector memory CD4-positive, alpha-beta T cell, from a female in her sixth decade. | This measurement was conducted with 10x 5' v2. Conventional dendritic cell from the jejunal epithelium of a female in her eighth decade. | census_1b9d8702-5af8-4142-85ed-020eb06ec4f6_2714 |
    | census_adda0684-f8ea-4403-b393-2a25607077c4_271 | This measurement was conducted with 10x 3' v3. Neuron cell type from a 29-year-old male, specifically from the thalamic complex, specifically the thalamus (THM) - posterior nuclear complex of thalamus (PoN) - medial geniculate nuclei (MG). | This measurement was conducted with 10x 3' v3. Neuron from the thalamic complex (thalamus, posterior nuclear complex of thalamus, medial geniculate nuclei) of a 42-year-old male, identified as a midbrain-derived inhibitory neuron. | census_adda0684-f8ea-4403-b393-2a25607077c4_425 |
  • Loss: MultipleNegativesRankingLoss with these parameters:
    {
        "scale": 20.0,
        "similarity_fct": "cos_sim"
    }
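MultipleNegativesRankingLoss treats each batch as a classification problem. Anchor i is scored against every positive in the batch and every row of the hard-negative columns (here negative_1 and negative_2). The scores use cosine similarity multiplied by scale (20.0), and the loss is the cross-entropy of picking positive i. The pure-Python sketch below illustrates that objective on toy vectors. It is not the library code. Assuming full batches of 256 rows, each anchor faces 3 x 256 = 768 candidates. A model that cannot distinguish them would incur about ln(768) ≈ 6.64 per anchor, compared with the logged training loss of 5.4042 at step 100 and 2.936 at step 2200.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity (the card's similarity_fct is cos_sim)."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mnrl_loss(anchors: List[Vector], positives: List[Vector],
              *negative_columns: List[Vector], scale: float = 20.0) -> float:
    """Mean cross-entropy where anchor i must pick positive i out of all candidates.

    Candidates are every positive in the batch followed by every row of each
    hard-negative column, so other rows' positives act as in-batch negatives.
    """
    candidates = list(positives)
    for column in negative_columns:
        candidates.extend(column)
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine(anchor, c) for c in candidates]
        # log-sum-exp with the max subtracted, for numerical stability
        top = max(logits)
        log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
        total += log_sum_exp - logits[i]  # equals -log softmax(logits)[i]
    return total / len(anchors)


# Toy batch: 2 anchors, their positives, and one hard-negative column.
anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[0.9, 0.1], [0.1, 0.9]]
negatives = [[-1.0, 0.0], [0.0, -1.0]]
print(mnrl_loss(anchors, positives, negatives))  # close to 0
print(mnrl_loss(anchors, negatives, positives))  # large: the targets point away from the anchors
```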
    

geo_70k_multiplets_natural_language_annotation

  • Dataset: geo_70k_multiplets_natural_language_annotation at 4c62cd1
  • Size: 61,911 training samples
  • Columns: anchor, positive, negative_1, and negative_2
  • Approximate statistics based on the first 1000 samples:
    | | anchor | positive | negative_1 | negative_2 |
    |---|---|---|---|---|
    | type | string | string | string | string |
    | details (characters) | min: 9, mean: 9.29, max: 10 | min: 83, mean: 189.5, max: 698 | min: 100, mean: 165.46, max: 465 | min: 9, mean: 9.1, max: 10 |
  • Samples:
    | anchor | positive | negative_1 | negative_2 |
    |---|---|---|---|
    | SRX083304 | This measurement was conducted with Illumina HiSeq 2000. 5-day HeLa cell line with ELAVL1/HuR siRNA1 knockdown, 120 hours post-transfection. | This measurement was conducted with Illumina HiSeq 2000. BJ fibroblast cells in a proliferative stage, with polyA RNA subtype. | SRX105303 |
    | SRX105302 | This measurement was conducted with Illumina HiSeq 2000. BJ fibroblast cells in a proliferative stage, with polyA RNA subtype. | This measurement was conducted with Illumina HiSeq 2000. 5-day HeLa cell line with ELAVL1/HuR siRNA1 knockdown, 120 hours post-transfection. | SRX105303 |
    | SRX105303 | This measurement was conducted with Illumina HiSeq 2000. BJ fibroblast cells at a confluent growth stage, with polyA RNA subtype. | This measurement was conducted with Illumina HiSeq 2000. 5-day HeLa cell line with ELAVL1/HuR siRNA1 knockdown, 120 hours post-transfection. | SRX105302 |
  • Loss: MultipleNegativesRankingLoss with these parameters:
    {
        "scale": 20.0,
        "similarity_fct": "cos_sim"
    }
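The per-column "details" figures are character counts, not token counts, computed over the first 1000 rows. A sketch of that computation follows. It assumes each row is a dict mapping column name to string, which is what iterating a Hugging Face datasets.Dataset yields. It is not the code that generated this card. For brevity, the example rows keep only the two ID columns from the samples above.

```python
from statistics import mean
from typing import Dict, List


def length_stats(rows: List[Dict[str, str]], limit: int = 1000) -> Dict[str, Dict[str, float]]:
    """Min / mean / max character length per column over the first `limit` rows."""
    sample = rows[:limit]
    stats = {}
    for column in sample[0]:
        lengths = [len(row[column]) for row in sample]
        stats[column] = {
            "min": min(lengths),
            "mean": round(mean(lengths), 2),
            "max": max(lengths),
        }
    return stats


# ID columns of the three geo_70k training samples shown above.
rows = [
    {"anchor": "SRX083304", "negative_2": "SRX105303"},
    {"anchor": "SRX105302", "negative_2": "SRX105303"},
    {"anchor": "SRX105303", "negative_2": "SRX105302"},
]
print(length_stats(rows)["anchor"])  # {'min': 9, 'mean': 9, 'max': 9}
```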
    

Evaluation Datasets

cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation

  • Dataset: cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation at b141493
  • Size: 9,011 evaluation samples
  • Columns: anchor, positive, negative_1, and negative_2
  • Approximate statistics based on the first 1000 samples:
    | | anchor | positive | negative_1 | negative_2 |
    |---|---|---|---|---|
    | type | string | string | string | string |
    | details (characters) | min: 45, mean: 47.73, max: 49 | min: 99, mean: 209.99, max: 941 | min: 102, mean: 213.87, max: 981 | min: 45, mean: 47.73, max: 49 |
  • Samples:
    | anchor | positive | negative_1 | negative_2 |
    |---|---|---|---|
    | census_0b4a15a7-4e9e-4555-9733-2423e5c66469_490 | This measurement was conducted with 10x 3' v3. Cell sample from the cortex of kidney, taken from a 43-year-old male of European ethnicity with a reported history of kidney cancer. The cell type is identified as a kidney collecting duct intercalated cell. | This measurement was conducted with 10x 3' v3. Kidney collecting duct intercalated cell from a 43-year old European male with kidney cancer, taken from the cortex of kidney and cryopreserved for further analysis. | census_0b4a15a7-4e9e-4555-9733-2423e5c66469_9 |
    | census_4976b234-9028-4b4b-8a2f-8ac59d636219_269 | This measurement was conducted with 10x 3' v3. Neuron cell type from a 29-year-old male cerebellum, specifically from the Cerebellar Vermis - CBV region, with European self-reported ethnicity, analyzed at the nucleus level. | This measurement was conducted with 10x 3' v3. Endothelial cells derived from the cerebellum (specifically, cerebellar vermis) of a 42-year-old male, classified under the vascular supercluster term. | census_4976b234-9028-4b4b-8a2f-8ac59d636219_923 |
    | census_44882825-0da1-4547-b721-2c6105d4a9d1_10258 | This measurement was conducted with 10x 5' v1. Cell sample from the tonsil of a 9-year-old female with recurrent tonsillitis, characterized as a centroblast B cell with IGLC2, IGLV7-43, IGLJ3 immunoglobulin genes expressed. | This measurement was conducted with 10x 5' v1. Centroblast cells derived from a 3-year-old male human tonsil sample, with obstructive sleep apnea and recurrent tonsillitis, undergoing affinity maturation and differentiation into memory or plasma cells. | census_44882825-0da1-4547-b721-2c6105d4a9d1_9654 |
  • Loss: MultipleNegativesRankingLoss with these parameters:
    {
        "scale": 20.0,
        "similarity_fct": "cos_sim"
    }
    

geo_70k_multiplets_natural_language_annotation

  • Dataset: geo_70k_multiplets_natural_language_annotation at 4c62cd1
  • Size: 6,901 evaluation samples
  • Columns: anchor, positive, negative_1, and negative_2
  • Approximate statistics based on the first 1000 samples:
    | | anchor | positive | negative_1 | negative_2 |
    |---|---|---|---|---|
    | type | string | string | string | string |
    | details (characters) | min: 9, mean: 10.35, max: 11 | min: 78, mean: 191.46, max: 983 | min: 90, mean: 217.63, max: 702 | min: 10, mean: 10.04, max: 11 |
  • Samples:
    | anchor | positive | negative_1 | negative_2 |
    |---|---|---|---|
    | SRX2244363 | This measurement was conducted with Illumina HiSeq 2000. 15-year-old male HepG2 immortalized cell line with hepatocellular carcinoma, transiently expressing shRNA targeting PKM2 for RNA-seq study. | This measurement was conducted with Illumina HiSeq 2000. 15-year-old male patient with hepatocellular carcinoma; HNRNPC knocked down via shRNA in HepG2 (immortalized cell line) for RNA-seq analysis. | SRX5457055 |
    | SRX3136447 | This measurement was conducted with Illumina HiSeq 2000. 16-year-old female's T cells from a control group, stimulated with ag85 at timepoint 0, and primary cells. | This measurement was conducted with Illumina HiSeq 2000. 17-year-old male's monocytes stimulated with mTb, taken at 180 days post-stimulation, as part of the control group in a study. | SRX3137689 |
    | SRX2734845 | This measurement was conducted with Illumina HiSeq 2500. UM-UC18 bladder cancer cell line, a type of urinary bladder cancer cell line, cultured for study of bladder disease, cancer cell proliferation, and neoplasm. | This measurement was conducted with NextSeq 500. HeLa cells with PARP knockdown treatment. | SRX3130770 |
  • Loss: MultipleNegativesRankingLoss with these parameters:
    {
        "scale": 20.0,
        "similarity_fct": "cos_sim"
    }
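The evaluation splits are small relative to the training splits. Working through the sizes listed above shows that each evaluation split is about 10% of its dataset and that the two training splits add up to 143,054 rows. The card does not say how the splits were drawn, so the snippet below only checks proportions.

```python
# Training / evaluation row counts copied from this card.
SIZES = {
    "cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation": (81_143, 9_011),
    "geo_70k_multiplets_natural_language_annotation": (61_911, 6_901),
}


def held_out_fraction(n_train: int, n_eval: int) -> float:
    """Share of a dataset's rows that are in the evaluation split."""
    return n_eval / (n_train + n_eval)


for name, (n_train, n_eval) in SIZES.items():
    share = held_out_fraction(n_train, n_eval)
    print(f"{name}: {n_train + n_eval} rows, {share:.1%} for evaluation")  # 10.0% for both

total_train = sum(n_train for n_train, _ in SIZES.values())
print("total training rows:", total_train)  # 143054
```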
    

Training Hyperparameters

Non-Default Hyperparameters

  • eval_strategy: steps
  • per_device_train_batch_size: 256
  • per_device_eval_batch_size: 256
  • learning_rate: 0.05
  • num_train_epochs: 4
  • warmup_ratio: 0.1
  • bf16: True
  • gradient_checkpointing: True
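With lr_scheduler_type linear and warmup_ratio 0.1, the learning rate rises linearly from 0 to 0.05 over the first 10% of optimizer steps and then decays linearly to 0. The sketch below derives the step counts. It assumes a single device, no gradient accumulation (gradient_accumulation_steps is 1), and a proportional multi-dataset sampler that yields every batch of every dataset once per epoch, with the last partial batch kept (dataloader_drop_last is False). Under those assumptions an epoch is 559 steps, which matches the Epoch column in the training logs (0.1789 at step 100, 3.9356 at step 2200).

```python
import math

# Values taken from this card.
TRAIN_SIZES = [81_143, 61_911]  # training rows per dataset
BATCH_SIZE = 256                # per_device_train_batch_size
EPOCHS = 4                      # num_train_epochs
PEAK_LR = 0.05                  # learning_rate
WARMUP_RATIO = 0.1              # warmup_ratio

# Assumption: one device; each dataset contributes ceil(n / batch) batches per epoch.
steps_per_epoch = sum(math.ceil(n / BATCH_SIZE) for n in TRAIN_SIZES)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)


def linear_lr(step: int) -> float:
    """Linear warmup to PEAK_LR, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return PEAK_LR * step / max(1, warmup_steps)
    return PEAK_LR * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


print(steps_per_epoch, total_steps, warmup_steps)        # 559 2236 224
print(round(100 / steps_per_epoch, 4))                   # 0.1789, the Epoch logged at step 100
print(linear_lr(warmup_steps), linear_lr(total_steps))   # 0.05 0.0
```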

All Hyperparameters

  • overwrite_output_dir: False
  • do_predict: False
  • eval_strategy: steps
  • prediction_loss_only: True
  • per_device_train_batch_size: 256
  • per_device_eval_batch_size: 256
  • per_gpu_train_batch_size: None
  • per_gpu_eval_batch_size: None
  • gradient_accumulation_steps: 1
  • eval_accumulation_steps: None
  • torch_empty_cache_steps: None
  • learning_rate: 0.05
  • weight_decay: 0.0
  • adam_beta1: 0.9
  • adam_beta2: 0.999
  • adam_epsilon: 1e-08
  • max_grad_norm: 1.0
  • num_train_epochs: 4
  • max_steps: -1
  • lr_scheduler_type: linear
  • lr_scheduler_kwargs: {}
  • warmup_ratio: 0.1
  • warmup_steps: 0
  • log_level: passive
  • log_level_replica: warning
  • log_on_each_node: True
  • logging_nan_inf_filter: True
  • save_safetensors: True
  • save_on_each_node: False
  • save_only_model: False
  • restore_callback_states_from_checkpoint: False
  • no_cuda: False
  • use_cpu: False
  • use_mps_device: False
  • seed: 42
  • data_seed: None
  • jit_mode_eval: False
  • use_ipex: False
  • bf16: True
  • fp16: False
  • fp16_opt_level: O1
  • half_precision_backend: auto
  • bf16_full_eval: False
  • fp16_full_eval: False
  • tf32: None
  • local_rank: 0
  • ddp_backend: None
  • tpu_num_cores: None
  • tpu_metrics_debug: False
  • debug: []
  • dataloader_drop_last: False
  • dataloader_num_workers: 0
  • dataloader_prefetch_factor: None
  • past_index: -1
  • disable_tqdm: False
  • remove_unused_columns: True
  • label_names: None
  • load_best_model_at_end: False
  • ignore_data_skip: False
  • fsdp: []
  • fsdp_min_num_params: 0
  • fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
  • fsdp_transformer_layer_cls_to_wrap: None
  • accelerator_config: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
  • deepspeed: None
  • label_smoothing_factor: 0.0
  • optim: adamw_torch
  • optim_args: None
  • adafactor: False
  • group_by_length: False
  • length_column_name: length
  • ddp_find_unused_parameters: None
  • ddp_bucket_cap_mb: None
  • ddp_broadcast_buffers: False
  • dataloader_pin_memory: True
  • dataloader_persistent_workers: False
  • skip_memory_metrics: True
  • use_legacy_prediction_loop: False
  • push_to_hub: False
  • resume_from_checkpoint: None
  • hub_model_id: None
  • hub_strategy: every_save
  • hub_private_repo: None
  • hub_always_push: False
  • hub_revision: None
  • gradient_checkpointing: True
  • gradient_checkpointing_kwargs: None
  • include_inputs_for_metrics: False
  • include_for_metrics: []
  • eval_do_concat_batches: True
  • fp16_backend: auto
  • push_to_hub_model_id: None
  • push_to_hub_organization: None
  • mp_parameters:
  • auto_find_batch_size: False
  • full_determinism: False
  • torchdynamo: None
  • ray_scope: last
  • ddp_timeout: 1800
  • torch_compile: False
  • torch_compile_backend: None
  • torch_compile_mode: None
  • include_tokens_per_second: False
  • include_num_input_tokens_seen: False
  • neftune_noise_alpha: None
  • optim_target_modules: None
  • batch_eval_metrics: False
  • eval_on_start: False
  • use_liger_kernel: False
  • liger_kernel_config: None
  • eval_use_gather_object: False
  • average_tokens_across_devices: False
  • prompts: None
  • batch_sampler: batch_sampler
  • multi_dataset_batch_sampler: proportional
  • router_mapping: {}
  • learning_rate_mapping: {}

Training Logs

| Epoch | Step | Training Loss | cellxgene pseudo bulk 100k multiplets natural language annotation loss | geo 70k multiplets natural language annotation loss | cellxgene_pseudo_bulk_100k_multiplets_natural_language_annotation_cell_sentence_1_cosine_accuracy | geo_70k_multiplets_natural_language_annotation_cell_sentence_1_cosine_accuracy |
|---|---|---|---|---|---|---|
| 0.1789 | 100 | 5.4042 | 12.7258 | 18.5172 | 0.5005 | 0.4127 |
| 0.3578 | 200 | 4.4018 | 21.3994 | 27.5368 | 0.5012 | 0.4662 |
| 0.5367 | 300 | 4.274 | 13.9052 | 15.7111 | 0.5054 | 0.5134 |
| 0.7156 | 400 | 3.9977 | 17.2145 | 18.8384 | 0.5060 | 0.6522 |
| 0.8945 | 500 | 3.8001 | 18.0511 | 20.1693 | 0.5058 | 0.3982 |
| 1.0733 | 600 | 3.7527 | 15.4862 | 20.6695 | 0.5064 | 0.4026 |
| 1.2522 | 700 | 3.7414 | 15.5879 | 13.7452 | 0.5089 | 0.3620 |
| 1.4311 | 800 | 3.4425 | 14.7486 | 11.9465 | 0.5069 | 0.5199 |
| 1.6100 | 900 | 3.3452 | 13.9171 | 14.1143 | 0.5113 | 0.4123 |
| 1.7889 | 1000 | 3.2576 | 15.2234 | 12.7155 | 0.5143 | 0.4120 |
| 1.9678 | 1100 | 3.322 | 14.6208 | 10.3553 | 0.5262 | 0.4456 |
| 2.1467 | 1200 | 3.1823 | 12.7034 | 12.2282 | 0.5236 | 0.5434 |
| 2.3256 | 1300 | 3.1449 | 11.2867 | 9.9116 | 0.5292 | 0.6111 |
| 2.5045 | 1400 | 3.0859 | 10.7462 | 9.3380 | 0.5349 | 0.6641 |
| 2.6834 | 1500 | 3.0582 | 12.2004 | 8.9558 | 0.5298 | 0.6866 |
| 2.8623 | 1600 | 2.9614 | 11.8808 | 8.8887 | 0.5317 | 0.6728 |
| 3.0411 | 1700 | 3.008 | 12.4199 | 8.4042 | 0.5290 | 0.6911 |
| 3.2200 | 1800 | 2.9739 | 10.9099 | 8.9717 | 0.5379 | 0.6656 |
| 3.3989 | 1900 | 2.9152 | 11.6201 | 8.5289 | 0.5314 | 0.6954 |
| 3.5778 | 2000 | 2.9668 | 11.9039 | 8.4831 | 0.5318 | 0.6860 |
| 3.7567 | 2100 | 3.0303 | 11.2059 | 8.9941 | 0.5368 | 0.6696 |
| 3.9356 | 2200 | 2.936 | 12.0965 | 7.9045 | 0.5292 | 0.7119 |
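The metrics reported under Evaluation (0.5292 and 0.7119) match the row logged at step 2200, the last evaluation in the table. Because load_best_model_at_end is False, training does not revert to the best-scoring checkpoint. Those figures are therefore not necessarily the per-column best. The short script below over the two accuracy columns, copied from the log, makes this explicit. The cellxgene accuracy peaked at step 1800, while the geo accuracy peaked at the final step.

```python
# (step, cellxgene cosine_accuracy, geo cosine_accuracy), copied from the log above.
LOG = [
    (100, 0.5005, 0.4127), (200, 0.5012, 0.4662), (300, 0.5054, 0.5134),
    (400, 0.5060, 0.6522), (500, 0.5058, 0.3982), (600, 0.5064, 0.4026),
    (700, 0.5089, 0.3620), (800, 0.5069, 0.5199), (900, 0.5113, 0.4123),
    (1000, 0.5143, 0.4120), (1100, 0.5262, 0.4456), (1200, 0.5236, 0.5434),
    (1300, 0.5292, 0.6111), (1400, 0.5349, 0.6641), (1500, 0.5298, 0.6866),
    (1600, 0.5317, 0.6728), (1700, 0.5290, 0.6911), (1800, 0.5379, 0.6656),
    (1900, 0.5314, 0.6954), (2000, 0.5318, 0.6860), (2100, 0.5368, 0.6696),
    (2200, 0.5292, 0.7119),
]

final = LOG[-1]
best_cellxgene = max(LOG, key=lambda row: row[1])
best_geo = max(LOG, key=lambda row: row[2])

print("last logged row:", final)                     # (2200, 0.5292, 0.7119)
print("best cellxgene step:", best_cellxgene[0])     # 1800
print("best geo step:", best_geo[0])                 # 2200
```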

Framework Versions

  • Python: 3.11.6
  • Sentence Transformers: 5.0.0
  • Transformers: 4.55.0.dev0
  • PyTorch: 2.5.1+cu121
  • Accelerate: 1.9.0
  • Datasets: 2.19.1
  • Tokenizers: 0.21.4
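To reproduce the results, compare your environment with the versions above. The helper below uses only the standard library (importlib.metadata) to report installed versions next to the ones listed in this card. A mismatch does not necessarily prevent the model from loading. The .dev0 suffix marks a development build of Transformers, so an exact match may require installing it from source.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional

# Versions listed in this card's "Framework Versions" section.
EXPECTED = {
    "sentence-transformers": "5.0.0",
    "transformers": "4.55.0.dev0",
    "torch": "2.5.1+cu121",
    "accelerate": "1.9.0",
    "datasets": "2.19.1",
    "tokenizers": "0.21.4",
}


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print(f"Python {platform.python_version()} (card: 3.11.6)")
    for name, have in installed_versions(EXPECTED).items():
        status = "match" if have == EXPECTED[name] else "differs"
        print(f"{name}: installed={have} card={EXPECTED[name]} -> {status}")
```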

Citation

BibTeX

Sentence Transformers

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}

MultipleNegativesRankingLoss

@misc{henderson2017efficient,
    title={Efficient Natural Language Response Suggestion for Smart Reply},
    author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
    year={2017},
    eprint={1705.00652},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
