Tags: Text Generation · Transformers · Safetensors · English · ddllama · conversational · custom_code
xuan-luo committed · Commit ae99ccd (verified) · 1 Parent(s): 6943ab5

Upload sft.py with huggingface_hub

Files changed (1): sft.py +149 -0
sft.py ADDED
@@ -0,0 +1,149 @@
+ import sys
+ import logging
+
+ import datasets
+ from datasets import load_dataset
+ import torch
+ import transformers
+ from trl import SFTTrainer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
+ from typing import Dict, List
+
+ logger = logging.getLogger(__name__)
+
+ """
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --gradient_clipping=1.0 --multi_gpu --num_processes=8 --num_machines=1 --mixed_precision=bf16 --zero_stage=3 sft.py
+ """
+ ###################
+ # Hyper-parameters
+ ###################
+
+ training_config = {
+     "bf16": True,
+     "do_eval": False,
+     "learning_rate": 1e-04,
+     "log_level": "info",
+     "logging_steps": 20,
+     "logging_strategy": "steps",
+     "lr_scheduler_type": "cosine",
+     "num_train_epochs": 1,
+     "max_steps": -1,
+     "output_dir": "./ckpts",
+     "overwrite_output_dir": True,
+     "per_device_eval_batch_size": 8,
+     "per_device_train_batch_size": 8,
+     "remove_unused_columns": True,
+     "save_steps": 1000,
+     "save_total_limit": 1,
+     "seed": 0,
+     "gradient_checkpointing": True,
+     "gradient_checkpointing_kwargs": {"use_reentrant": False},
+     "gradient_accumulation_steps": 1,
+     "warmup_ratio": 0.03,
+ }
+ train_conf = TrainingArguments(**training_config)
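+ # With the 8-process launch above, the effective batch size is
+ # 8 (per device) * 8 (processes) * 1 (gradient accumulation) = 64 sequences per
+ # optimizer step; a checkpoint is written every 1000 steps and only the most
+ # recent one is kept (save_total_limit=1).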
+
+
+ ###############
+ # Setup logging
+ ###############
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+     handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ log_level = train_conf.get_process_log_level()
+ logger.setLevel(log_level)
+ datasets.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ # Log on each process a small summary
+ logger.warning(
+     f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
+     + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bits training: {train_conf.fp16}"
+ )
+ logger.info(f"Training/evaluation parameters {train_conf}")
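+ # get_process_log_level() returns the configured "info" level on the main process;
+ # replica ranks typically fall back to log_level_replica (warnings by default),
+ # which keeps the 8 workers from flooding stdout with duplicate logs.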
+
+
+ ################
+ # Model Loading
+ ################
+
+ checkpoint_path = "./"
+ model_kwargs = dict(
+     use_cache=False,
+     trust_remote_code=True,
+     attn_implementation="flash_attention_2",
+     torch_dtype=torch.bfloat16,
+     device_map=None
+ )
+ model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+ tokenizer.model_max_length = 2048
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
+ tokenizer.padding_side = 'right'
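+ # trust_remote_code=True is needed because the checkpoint ships custom modeling
+ # code (the ddllama architecture), and flash_attention_2 pairs with the bf16
+ # dtype above. The EOS token doubles as the pad token, a common choice for
+ # Llama-style tokenizers that ship without a dedicated pad token; right padding
+ # is appropriate for training.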
+
+
+ ##################
+ # Data Processing
+ ##################
+ def apply_chat_template(
+     example,
+     tokenizer,
+ ):
+     messages = example["messages"]
+     example["text"] = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=False)
+     return example
+
+ raw_dataset = load_dataset("allenai/tulu-v2-sft-mixture")
+ train_dataset = raw_dataset["train"]
+ column_names = list(train_dataset.features)
+
+ processed_dataset = train_dataset.map(
+     apply_chat_template,
+     fn_kwargs={"tokenizer": tokenizer},
+     num_proc=64,
+     remove_columns=column_names,
+     desc="Applying chat template to train_sft",
+ )
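+ # Each example's chat "messages" are flattened into a single "text" string via the
+ # model's chat template (no generation prompt is appended since the targets are
+ # full conversations); the original columns are dropped so the trainer only sees
+ # the "text" field named by dataset_text_field below.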
+
+
+ ###########
+ # Freeze Transformer
+ ###########
+ for param in model.parameters():
+     param.requires_grad = False
+
+ for name, param in model.named_parameters():
+     if 'router' in name.lower():
+         param.requires_grad = True
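+ # Only parameters whose names contain "router" stay trainable, so this SFT run
+ # tunes just the router modules (presumably defined by the repo's custom ddllama
+ # modeling code) while the rest of the transformer remains frozen.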
+
+ ###########
+ # Training
+ ###########
+ trainer = SFTTrainer(
+     model=model,
+     args=train_conf,
+     peft_config=None,
+     train_dataset=processed_dataset,
+     eval_dataset=None,
+     max_seq_length=2048,
+     dataset_text_field="text",
+     tokenizer=tokenizer,
+     packing=False
+ )
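+ # NOTE: passing max_seq_length, dataset_text_field, tokenizer and packing directly
+ # to SFTTrainer matches older trl releases; recent trl versions moved these options
+ # into SFTConfig, so pin trl accordingly when reproducing this script.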
+
+ train_result = trainer.train()
+ metrics = train_result.metrics
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+ # ############
+ # # Save model
+ # ############
+ trainer.save_model(train_conf.output_dir)
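+ # The final weights end up in ./ckpts (train_conf.output_dir); Trainer.save_model
+ # also writes the tokenizer there when one is attached to the trainer.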