Essence 3B V1

One finetuned copy of SmolLM3-3B-Base, the 'encoder', compresses a text into a small set of 'embedding tokens'; a second finetuned copy, the 'decoder', reconstructs the original text from those tokens.

We use LoRA at rank 64 on the Q, K, V, and O attention projections, along with trainable LayerNorms; the encoder additionally uses LoRA on all MLP layers and trainable token embeddings.
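For reference, here is a minimal sketch of what the encoder's adapter configuration could look like with peft.LoraConfig. The module names assume SmolLM3's Llama-style layout, and lora_alpha and the LayerNorm module names are assumptions; the card only states the rank and which parts are trainable.

from peft import LoraConfig

encoder_lora_config = LoraConfig(
    r=64,
    lora_alpha=64,  # assumed; only rank 64 is stated above
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # QKVO attention projections
        "gate_proj", "up_proj", "down_proj",     # all MLP layers (encoder only)
    ],
    # fully trainable (non-LoRA) parameters: token embeddings (encoder only)
    # and LayerNorms; the LayerNorm module names here are assumptions
    modules_to_save=["embed_tokens", "input_layernorm", "post_attention_layernorm"],
)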

Beyond the tokens seen in base-model pretraining, the system was finetuned on 138 million tokens over the course of 9,000 training steps.

The model was trained to encode text into 4, 8, 16, 32, or 64 embedding tokens, and exhibits some limited generalization to other embedding lengths.
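For a sense of scale, the compression ratio is simply the document's token count divided by the number of embedding tokens. A quick back-of-the-envelope for a hypothetical 100-token document:

doc_tokens = 100  # hypothetical document length, in tokens
for n_embed in (4, 8, 16, 32, 64):
    print(f"{n_embed:>2} embedding tokens -> {doc_tokens / n_embed:5.1f}x compression")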

Simple Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch import nn
import torch
from huggingface_hub import hf_hub_download

device = torch.device("cuda:0")
dtype = torch.bfloat16
base_model_id = "HuggingFaceTB/SmolLM3-3B-Base"
compressor_id = "midwestern-simulation/essence-3b-v1"

# === MODEL LOADING ===

tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side='left')
encoder = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"":device}, torch_dtype=dtype)
decoder = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"":device}, torch_dtype=dtype)

encoder = PeftModel.from_pretrained(encoder, compressor_id, subfolder="encoder")
decoder = PeftModel.from_pretrained(decoder, compressor_id, subfolder="decoder")

# linear projector mapping encoder hidden states into the decoder's input
# embedding space (2048 is SmolLM3-3B's hidden size)
projector = nn.Linear(2048, 2048).to(device).to(dtype)
projector.load_state_dict(torch.load(hf_hub_download(repo_id=compressor_id, filename="projector.pt"), map_location=device))


# === MODEL INFERENCE ===

text = "mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow"
n_embed_tokens = 4  # trained lengths are 4, 8, 16, 32, and 64; use these for best performance

# encoder prompt: the document, an end-of-document marker, a summary header
# declaring the token budget, then n placeholder tokens whose final hidden
# states become the embedding tokens
encoder_input = text.strip() + f"\n[[/END DOCUMENT]]\n[[START SUMMARY ntoks={n_embed_tokens}]]" + "<|im_end|>" * n_embed_tokens

tokenized = tokenizer(encoder_input, return_tensors='pt', add_special_tokens=False)
tokenized = {k: v.to(device) for k, v in tokenized.items()}
# run the base transformer (bypassing the LM head), keep the hidden states at the
# placeholder positions, and project them into the decoder's embedding space
encoding = encoder.model.model(**tokenized).last_hidden_state[:, -n_embed_tokens:, :]
encoding = projector(encoding)

# prepend the embedding tokens to a prefix that closes the summary and reopens
# the document, then let the decoder reconstruct the text from there
tokenized_prefix = tokenizer("\n[[/END SUMMARY]]\n[[START DOCUMENT]]\n", return_tensors="pt", add_special_tokens=False)
prefix_embeds = decoder.model.model.embed_tokens(tokenized_prefix['input_ids'].to(device))
inputs_embeds = torch.cat([encoding, prefix_embeds], 1)
output = decoder.generate(
    inputs_embeds=inputs_embeds,
    temperature=0.7,
    max_new_tokens=1024,
    do_sample=True,
    top_k=128,
    min_new_tokens=8,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id
)
print(tokenizer.decode(output[0]))
# mary had a little lamb, little lamb, little lamb, mary had a little lamb whose fleece was white as snow
# [[/END DOCUMENT]]<|end_of_text|>
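To sweep the trained embedding lengths (or probe the limited generalization beyond them), it is convenient to wrap the steps above into two helpers. This is a minimal sketch reusing the tokenizer, encoder, decoder, and projector loaded above; compress and reconstruct are hypothetical names, not part of the released code.

def compress(text, n_embed_tokens=8):
    # build the encoder prompt and return the projected embedding tokens
    encoder_input = (
        text.strip()
        + f"\n[[/END DOCUMENT]]\n[[START SUMMARY ntoks={n_embed_tokens}]]"
        + "<|im_end|>" * n_embed_tokens
    )
    tokenized = tokenizer(encoder_input, return_tensors="pt", add_special_tokens=False)
    tokenized = {k: v.to(device) for k, v in tokenized.items()}
    with torch.no_grad():
        hidden = encoder.model.model(**tokenized).last_hidden_state
        return projector(hidden[:, -n_embed_tokens:, :])

def reconstruct(encoding):
    # decode the embedding tokens back into text
    prefix = tokenizer("\n[[/END SUMMARY]]\n[[START DOCUMENT]]\n", return_tensors="pt", add_special_tokens=False)
    prefix_embeds = decoder.model.model.embed_tokens(prefix["input_ids"].to(device))
    output = decoder.generate(
        inputs_embeds=torch.cat([encoding, prefix_embeds], 1),
        temperature=0.7, top_k=128, do_sample=True,
        max_new_tokens=1024, min_new_tokens=8,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(output[0])

# fidelity generally improves as the embedding budget grows
for n in (4, 8, 16, 32, 64):
    print(n, reconstruct(compress(text, n))[:80])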