---
library_name: transformers
license: apache-2.0
base_model: monsoon-nlp/dna-blockdiff-2
---

# DNA and Block Diffusion

This model uses the Block Diffusion architecture with AgroNT's six-nucleotide tokens.

Starting from the dna-blockdiff-2 weights, it was trained on the papaya genome for one epoch.

Training loss fluctuated, but the validation curve (on the human genome) improved consistently.
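To make the six-nucleotide tokenization concrete, here is an illustrative sketch of splitting a DNA sequence into non-overlapping 6-mers. The exact vocabulary and the handling of trailing bases are assumptions for illustration, not taken from the model's actual tokenizer:

```python
def kmerize(sequence: str, k: int = 6) -> list[str]:
    """Split a DNA sequence into non-overlapping k-mers.

    Leftover bases at the end (when the length is not a multiple of k)
    are emitted as single-nucleotide tokens -- an assumption here.
    """
    sequence = sequence.upper()
    cutoff = len(sequence) - len(sequence) % k
    tokens = [sequence[i:i + k] for i in range(0, cutoff, k)]
    tokens += list(sequence[cutoff:])
    return tokens

print(kmerize("ATGCGTACGTTAGC"))
# 14 bases -> two 6-mer tokens plus two single-base tokens
```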

## Loading the model

```python
from transformers import AutoModelForMaskedLM

m = AutoModelForMaskedLM.from_pretrained(
    "monsoon-nlp/dna-blockdiff-papaya",
    trust_remote_code=True,
)
```

## Perplexity of a sequence

```bash
cd bd3lms && python -u main.py \
    loader.eval_batch_size=1 \
    model=small \
    algo=bd3lm \
    algo.T=5000 \
    algo.backbone=hf_dit \
    data=instadeep \
    model.length=256 \
    block_size=4 \
    wandb=null \
    mode=ppl_eval \
    eval.checkpoint_path="monsoon-nlp/dna-blockdiff-papaya" \
    model.attn_backend=sdpa \
    sampling.nucleus_p=0.9 \
    sampling.kv_cache=true \
    sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize4 \
    data.tokenizer_name_or_path="monsoon-nlp/dna-blockdiff-papaya"
```

## Generating text

```bash
cd bd3lms && python -u main.py \
    loader.eval_batch_size=1 \
    model=small \
    algo=bd3lm \
    algo.T=5000 \
    algo.backbone=hf_dit \
    data=instadeep \
    model.length=256 \
    block_size=4 \
    wandb=null \
    mode=sample_eval \
    eval.checkpoint_path="monsoon-nlp/dna-blockdiff-papaya" \
    model.attn_backend=sdpa \
    sampling.nucleus_p=0.9 \
    sampling.kv_cache=true \
    sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize4 \
    data.tokenizer_name_or_path="monsoon-nlp/dna-blockdiff-papaya"
```

Currently this generates `<cls> N N N N N...`, but results could be improved by guiding the decoding process.
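One simple form of guided decoding is to mask out unwanted tokens (such as `N` or special tokens) before sampling at each step. A minimal sketch under that assumption, using a hypothetical toy vocabulary rather than the model's real tokenizer or logits:

```python
import math
import random

def sample_guided(logits: dict[str, float], banned: set[str]) -> str:
    """Sample one token from a softmax over logits, excluding banned tokens.

    Equivalent to setting banned tokens' logits to -inf before sampling.
    """
    allowed = {tok: lg for tok, lg in logits.items() if tok not in banned}
    z = max(allowed.values())  # subtract max for numerical stability
    weights = {tok: math.exp(lg - z) for tok, lg in allowed.items()}
    r = random.random() * sum(weights.values())
    for token, w in weights.items():
        r -= w
        if r <= 0:
            return token
    return token  # guard against floating-point round-off

# Hypothetical per-step logits: "N" and "<cls>" dominate, but are banned.
step_logits = {"AAAAAA": 1.2, "N": 3.0, "<cls>": 2.5, "ACGTAC": 0.8}
print(sample_guided(step_logits, banned={"N", "<cls>"}))
```

The same idea applies inside the diffusion sampler: zero out the probability mass of `N` and special tokens before each denoising step so the model is forced to commit to actual nucleotide tokens.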