Finetune script #11

Open
JinyuanSun opened this issue Feb 28, 2024 · 9 comments

@JinyuanSun

Could you provide a script/notebook demo to show how to finetune this model?

@philippbayer

Here's how I just did it; I'd be curious to see if there's anything model-specific I could use to make the training go faster. I ran into a bunch of issues, see the comments in the code.

I had to use a single letter for the padding token (N instead of [PAD]), and I had to split my FASTAs into pieces of 500 bp or the A100 would run out of memory.

I'm using FASTA pieces of fish mitogenomes; not prokaryotes, but of prokaryotic origin :)

import torch
print(torch.cuda.is_available())
from transformers import AutoConfig, AutoModelForCausalLM, TrainingArguments, Trainer
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from datasets import Dataset, load_dataset, DatasetDict
from Bio import SeqIO

# they used the 1-8k base model as a start for finetuning, so let's do the same for now
model_name = 'togethercomputer/evo-1-8k-base'

model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=model_config,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained("togethercomputer/evo-1-8k-base",
        trust_remote_code=True)

## make a Dataset from my local data

with open('all_fastas.csv', 'w') as out_fh:
    out_fh.write('ID,Seq\n')
    for x in SeqIO.parse('12S.v0.10.16S.v0.4.Mitogenomes.v0.1.fasta', 'fasta'):
        full_seq = str(x.seq)
        # splitting the sequences in pieces of 500bp so the A100 doesn't run out of memory
        full_seq_split = [full_seq[i:i+500] for i in range(0, len(full_seq), 500)]
        for counter, i in enumerate(full_seq_split):
            out_fh.write(f'{x.id}_{counter},{i}\n')

train_ds = load_dataset('csv', data_files = 'all_fastas.csv')

def preprocess_function(text):
    return tokenizer(text['Seq'])

tokenized_ds = train_ds.map(
    preprocess_function,
    batched=True,
    num_proc=4,
)

train_testvalid = tokenized_ds['train'].train_test_split(test_size=0.2)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# The Trainer crashes without a padding token, but Evo runs ord() on the padding token
# and ord() can only take single characters. Normally people use [PAD]; let's just use N
tokenizer.pad_token = "N"

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    # runs out of memory with the default batch_size
    per_device_train_batch_size = 4
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_testvalid["train"],
    eval_dataset=train_testvalid["test"],
    data_collator=data_collator,
)

trainer.train()

print(trainer.evaluate())

trainer.save_model('./finetuned.model')

Load the model later via from_pretrained('./finetuned.model'). It's still finetuning and a bit slow; it will take about 50 hours for 304,354 pieces of DNA of length 500 bp that come from 64,573 mitogenomes and 12S/16S pieces (total size: 140 MB).
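
For completeness, a minimal sketch of loading the result afterwards (same trust_remote_code pattern as above; the path is just the one passed to trainer.save_model, and the tokenizer still comes from the base checkpoint since the Trainer here doesn't save it):

from transformers import AutoModelForCausalLM, AutoTokenizer

# load the finetuned weights from the directory written by trainer.save_model()
finetuned_model = AutoModelForCausalLM.from_pretrained(
    './finetuned.model',
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained('togethercomputer/evo-1-8k-base',
        trust_remote_code=True)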

@JinyuanSun
Author

@philippbayer Nice job! I also found that there are no special tokens in this tokenizer, but the paper says they used EOS tokens to separate individual DNA sequences.
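
One possible workaround, purely as an untested sketch (the '|' separator is an assumption, not something from the paper or the tokenizer): append a single-character separator of your own choosing to each sequence before tokenization, to mimic the EOS-between-sequences setup.

SEP = "|"  # hypothetical single-character separator standing in for EOS

def preprocess_function(text):
    # append the separator to each sequence before tokenization
    return tokenizer([s + SEP for s in text['Seq']])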

@JinyuanSun
Author

@philippbayer I tried using bf16, which seems to work. Another option is freezing some layers. I am using human CDS sequences to finetune only the last layer; I'm not sure whether Hyena works like Transformers, where finetuning just some layers can be effective, but this lets me finetune Evo on a single RTX 4090 and largely increase the sequence length, to 3000 nt.

import transformers
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
from transformers import AutoConfig, TrainingArguments, Trainer
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"

model_name = 'togethercomputer/evo-1-8k-base'

model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model_config.use_cache = False

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=model_config,
    trust_remote_code=True,
    device_map={"":0},
    torch_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = "X"

# freeze most parameters
for p in model.parameters():
    p.requires_grad = False

for p in model.backbone.blocks[-1].parameters():
    p.requires_grad = True

from datasets import load_dataset

dataset = load_dataset("gonzalobenegas/human-genome-cds")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

def preprocess_function(sample):
    return tokenizer(sample['seq'], padding="longest", truncation=True, max_length=3000)

tokenized_ds = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=12,
)

training_args = TrainingArguments(
    output_dir="./evo_results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    gradient_accumulation_steps=2,
    per_device_train_batch_size=4,
    warmup_steps=10,
    max_steps=100, # only a demo
    logging_steps=10,
    eval_steps=100,
    bf16=True
    # fp16=True, # This didn't work.
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    data_collator=data_collator,
)

trainer.train()

@philippbayer

philippbayer commented Mar 1, 2024

Thanks!!! The freezing is great, it cut my finetuning time down from ~50 hours to ~14 hours, which makes sense considering I was retraining the entire thing before :) It also lets me increase my per_device_train_batch_size and per_device_eval_batch_size to 128 instead of 4 for a few more hours shaved off; now the total training time is around 10 hours.

I did run out of disk space overnight, so don't forget to set save_total_limit = 3 (sketched below).
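
For reference, a sketch of what those TrainingArguments changes might look like; the batch sizes and save_total_limit are the ones mentioned here, the rest is copied from the script above and may need tuning:

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    save_total_limit=3,  # keep only the 3 most recent checkpoints so the disk doesn't fill up overnight
)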

fp16 is very hardware-dependent; I just always turn it on since it promises faster training and less memory usage (https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/). It might just be that your GPU doesn't support it.

Even with most parameters frozen, the 'full' sequences that include entire mitogenomes still make my A100 run out of memory. I haven't run any experiments on what the longest possible length is, but I can use 5000 bp pieces without crashing.

@Zymrael
Collaborator

Zymrael commented Mar 1, 2024

This is all great. In the future (if there is sufficient interest) we are also going to support finetuning of Evo on the Together API, which hopefully will make it a lot easier to perform full model finetunes at full context (131k and beyond).

There is also planned support of the architecture on other open frameworks for LLM finetuning.

@philippbayer

I'd definitely be interested in that!!

@cclough

cclough commented Mar 6, 2024

This is all great. In the future (if there is sufficient interest) we are also going to support finetuning of Evo on the Together API, which hopefully will make it a lot easier to perform full model finetunes at full context (131k and beyond).

There is also planned support of the architecture on other open frameworks for LLM finetuning.

interested!

@philippbayer

Interestingly, freezing everything but the last layer as above in @JinyuanSun's script doesn't work well with my data:

[screenshot: training loss curve with only the last block unfrozen]

As you can see, the training loss stays essentially identical across steps.

Doing it 'my' way works better, but takes about three times as long (so far...)

[screenshot: training loss curve with the full model finetuned]

I'm sure there's some optimal middle ground where some, but not all, layers are finetuned.
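
A rough sketch of that middle ground, assuming the same model.backbone.blocks layout as in the script above (the number of unfrozen blocks is an illustrative knob, not something tested here):

N_UNFROZEN = 4  # hypothetical: how many of the final blocks to finetune

for p in model.parameters():
    p.requires_grad = False

# unfreeze only the last N_UNFROZEN blocks instead of just the last one
for block in model.backbone.blocks[-N_UNFROZEN:]:
    for p in block.parameters():
        p.requires_grad = True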

@leonwyang

Has anyone explored FSDP finetuning on multiple GPUs? I got the error "ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32". It seems to be because FSDP requires a uniform tensor dtype, but StripedHyena is a mixed-precision model?
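
One untested idea, since FSDP flattens parameters into uniform-dtype buckets: cast every parameter to a single dtype before wrapping. Whether StripedHyena still behaves well after losing its per-layer precision choices is an open question.

import torch

# untested workaround: force a single dtype across all parameters before FSDP flattens them
for p in model.parameters():
    p.data = p.data.to(torch.bfloat16)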
