---
library_name: transformers
license: apache-2.0
base_model: monsoon-nlp/dna-blockdiff-2
---

# DNA and Block Diffusion

This model uses the [Block Diffusion](https://github.com/kuleshov-group/bd3lms) architecture with [AgroNT](https://huggingface.co/InstaDeepAI/agro-nucleotide-transformer-1b)'s six-nucleotide-length tokens.
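
As a rough illustration of what six-nucleotide tokens look like, the sketch below splits a DNA string into non-overlapping 6-mers. It assumes, based on our understanding of the Nucleotide Transformer family of tokenizers that AgroNT belongs to, that leftover bases at the end of an `N`-free run become single-nucleotide tokens and that each `N` is its own token. The function `tokenize_6mers` is hypothetical; it is not the actual AgroNT tokenizer, which should be loaded from the Hub for real use.

```python
def tokenize_6mers(seq: str, k: int = 6) -> list:
    """Hypothetical k-mer tokenizer sketch (not the real AgroNT tokenizer).

    Assumptions: N-free runs are cut into non-overlapping k-mers, trailing
    bases shorter than k become single-nucleotide tokens, and each 'N' is
    emitted as its own token.
    """
    tokens = []
    chunks = seq.upper().split("N")
    for i, chunk in enumerate(chunks):
        # Full k-mers from the start of this N-free chunk.
        n_full = len(chunk) // k
        tokens.extend(chunk[j * k:(j + 1) * k] for j in range(n_full))
        # Fewer than k bases left over: one token per nucleotide.
        tokens.extend(chunk[n_full * k:])
        # Re-insert the 'N' that separated this chunk from the next one.
        if i < len(chunks) - 1:
            tokens.append("N")
    return tokens


print(tokenize_6mers("ACGTACGTACGTAC"))  # → ['ACGTAC', 'GTACGT', 'A', 'C']
```

Splitting on `N` first keeps the k-mer boundaries simple; a real tokenizer also adds special tokens such as `<cls>` and maps each string token to an integer id.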

Starting from the [dna-blockdiff-2](https://huggingface.co/monsoon-nlp/dna-blockdiff-2) weights, the model was trained on the [Papaya genome](https://huggingface.co/datasets/monsoon-nlp/wheat-bees) for one epoch.

Training loss fluctuated, but the validation curve (on the [human genome](https://huggingface.co/datasets/dnagpt/human_genome_GCF_009914755.1)) improved consistently.

### Loading the model

```python
from transformers import AutoModelForMaskedLM

# trust_remote_code=True lets transformers run the custom model code stored in the repository
m = AutoModelForMaskedLM.from_pretrained(
    "monsoon-nlp/dna-blockdiff-papaya",
    trust_remote_code=True,
)
```

### Perplexity of a sequence

The commands below are run from a clone of the [bd3lms](https://github.com/kuleshov-group/bd3lms) repository.

```bash
cd bd3lms && python -u main.py \
  loader.eval_batch_size=1 \
  model=small \
  algo=bd3lm \
  algo.T=5000 \
  algo.backbone=hf_dit \
  data=instadeep \
  model.length=256 \
  block_size=4 \
  wandb=null \
  mode=ppl_eval \
  eval.checkpoint_path="monsoon-nlp/dna-blockdiff-papaya" \
  model.attn_backend=sdpa \
  sampling.nucleus_p=0.9 \
  sampling.kv_cache=true \
  sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize4 \
  data.tokenizer_name_or_path="monsoon-nlp/dna-blockdiff-papaya"
```
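
Perplexity is the exponential of the average per-token negative log-likelihood (NLL, in nats). For a diffusion model such as BD3-LM the likelihood is not computed exactly; as we understand it, `ppl_eval` reports a perplexity derived from the model's variational bound, so the value is an upper bound on the true perplexity. The minimal sketch below shows only the final conversion from per-token NLLs to perplexity; `perplexity_from_nll` is a hypothetical helper and the input numbers are made up.

```python
import math
from typing import Sequence


def perplexity_from_nll(token_nlls: Sequence[float]) -> float:
    """Return exp(mean per-token negative log-likelihood), with NLLs in nats."""
    if len(token_nlls) == 0:
        raise ValueError("need at least one token")
    mean_nll = sum(token_nlls) / len(token_nlls)
    return math.exp(mean_nll)


# Made-up per-token NLLs for a 4-token sequence (mean NLL = 1.0 nat).
nlls = [1.2, 0.8, 1.5, 0.5]
print(round(perplexity_from_nll(nlls), 3))  # → 2.718
```

As a sanity check, a model that guesses uniformly over V tokens has a per-token NLL of ln V and therefore a perplexity of exactly V.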

### Generating sequences

```bash
cd bd3lms && python -u main.py \
  loader.eval_batch_size=1 \
  model=small \
  algo=bd3lm \
  algo.T=5000 \
  algo.backbone=hf_dit \
  data=instadeep \
  model.length=256 \
  block_size=4 \
  wandb=null \
  mode=sample_eval \
  eval.checkpoint_path="monsoon-nlp/dna-blockdiff-papaya" \
  model.attn_backend=sdpa \
  sampling.nucleus_p=0.9 \
  sampling.kv_cache=true \
  sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize4 \
  data.tokenizer_name_or_path="monsoon-nlp/dna-blockdiff-papaya"
```
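
To make the roles of `model.length` and `block_size` concrete: block diffusion generates blocks left to right, and within each block it fills in masked positions over several denoising steps, conditioned on the blocks already finished (which, as we understand it, the real sampler caches when `sampling.kv_cache=true`). With `model.length=256` and `block_size=4`, that is 256 / 4 = 64 blocks per sample. The toy sketch below illustrates this loop under simplifying assumptions: `sample_block_diffusion`, `toy_denoiser`, the `MASK` stand-in, and the even reveal schedule are all hypothetical, and the actual bd3lms sampler uses the trained network's probabilities and its own noise schedule instead.

```python
import random
from typing import Callable, List, Optional

MASK: Optional[str] = None  # hypothetical stand-in for the [MASK] token


def sample_block_diffusion(
    denoise: Callable[[List[str], List[Optional[str]]], List[str]],
    length: int,
    block_size: int,
    steps_per_block: int,
    seed: int = 0,
) -> List[str]:
    """Toy block-diffusion sampler (hypothetical, for illustration only)."""
    if length % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    if steps_per_block < 1:
        raise ValueError("steps_per_block must be at least 1")
    rng = random.Random(seed)
    prefix: List[str] = []  # finished blocks that later blocks condition on
    for _ in range(length // block_size):
        block: List[Optional[str]] = [MASK] * block_size
        for step in range(steps_per_block):
            masked = [i for i, tok in enumerate(block) if tok is MASK]
            if not masked:
                break
            # Reveal an even share of what is still masked; the last step reveals the rest.
            n_reveal = max(1, len(masked) // (steps_per_block - step))
            proposal = denoise(prefix, block)  # one candidate token per block position
            for i in rng.sample(masked, n_reveal):
                block[i] = proposal[i]
        # Every position is filled by the last step, so nothing is dropped here.
        prefix.extend(tok for tok in block if tok is not MASK)
    return prefix


def toy_denoiser(prefix: List[str], block: List[Optional[str]]) -> List[str]:
    # Hypothetical stand-in for the network: cycles through a few 6-mers.
    kmers = ["ACGTAC", "TTGACA", "GGCCTA", "ATATAT"]
    return [kmers[(len(prefix) + i) % len(kmers)] for i in range(len(block))]


tokens = sample_block_diffusion(toy_denoiser, length=8, block_size=4, steps_per_block=2)
print(tokens)
# → ['ACGTAC', 'TTGACA', 'GGCCTA', 'ATATAT', 'ACGTAC', 'TTGACA', 'GGCCTA', 'ATATAT']
```

Generating block by block keeps the left-to-right structure of an autoregressive model, while the denoising steps inside a block let several tokens be proposed in parallel.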

Currently, this generates `<cls> N N N N N...`; the output could be improved by guiding the decoding.
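
One possible way to guide decoding, sketched here under assumptions, is to remove the `N` token from the candidates before sampling and then apply nucleus (top-p) filtering, matching `sampling.nucleus_p=0.9`. The function below works on a plain token-to-logit dictionary with made-up values. `sample_with_banned_tokens` is hypothetical and not part of bd3lms; wiring it in would presumably mean masking the network's logits inside the bd3lms sampling loop.

```python
import math
import random
from typing import Dict, Iterable, List, Optional, Tuple


def sample_with_banned_tokens(
    logits: Dict[str, float],
    banned: Iterable[str],
    top_p: float = 0.9,
    rng: Optional[random.Random] = None,
) -> str:
    """Sample one token after removing banned tokens and applying top-p filtering."""
    if rng is None:
        rng = random.Random(0)
    banned_set = set(banned)
    allowed = {tok: x for tok, x in logits.items() if tok not in banned_set}
    if not allowed:
        raise ValueError("every candidate token is banned")
    # Softmax over the allowed tokens only (subtract the max for numerical stability).
    max_logit = max(allowed.values())
    weights = {tok: math.exp(x - max_logit) for tok, x in allowed.items()}
    total = sum(weights.values())
    ranked: List[Tuple[float, str]] = sorted(
        ((w / total, tok) for tok, w in weights.items()), reverse=True
    )
    # Nucleus: smallest set of most likely tokens whose probability mass reaches top_p.
    nucleus: List[Tuple[float, str]] = []
    mass = 0.0
    for prob, tok in ranked:
        nucleus.append((prob, tok))
        mass += prob
        if mass >= top_p:
            break
    return rng.choices(
        [tok for _, tok in nucleus], weights=[p for p, _ in nucleus], k=1
    )[0]


# Made-up logits in which "N" dominates, mimicking the all-N output above.
logits = {"N": 5.0, "ACGTAC": 1.0, "TTGACA": 0.5}
print(sample_with_banned_tokens(logits, banned=["N"]))
```

Banning a token outright is the bluntest form of guidance; a softer variant would subtract a penalty from the `N` logit instead of removing it.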