# load the requirements
import torch
import os
from transformers import (
    WhisperFeatureExtractor, 
    WhisperTokenizer, WhisperProcessor, 
    Seq2SeqTrainingArguments, 
    WhisperForConditionalGeneration, 
    TrainerCallback, 
    Seq2SeqTrainer,
)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from torch.utils.data import IterableDataset
import evaluate
from datasets import load_dataset, Audio
from dataclasses import dataclass
from typing import Any
import pandas as pd
import subprocess
import datetime
import csv

# define the model id
model_id = "openai/insert_model_id"
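# e.g. "openai/whisper-small"; any Whisper checkpoint from the Hugging Face Hub should work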

# specify the output file path for the wrong predictions
output_file_path = "path/to/your/output/wrong_predictions.csv"

# specify the output file path for the computational-resources data
output_file_path_gpu = "path/to/your/output/efficiency_data.csv"

# load and define the feature extractor and the tokenizer
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)

tokenizer = WhisperTokenizer.from_pretrained(model_id, language="English", task="transcribe")

# load the audio datasets from local folders
audio_dataset_train = load_dataset("audiofolder", data_dir="/path/to/dataset/train")
audio_dataset_test = load_dataset("audiofolder", data_dir="/path/to/dataset/test")

# load the processor
processor = WhisperProcessor.from_pretrained(model_id, language="English", task="transcribe")

# preprocess the data
audio_dataset_train = audio_dataset_train.cast_column("audio", Audio(sampling_rate=16000))
audio_dataset_test = audio_dataset_test.cast_column("audio", Audio(sampling_rate=16000))
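# Whisper expects 16 kHz audio; cast_column makes datasets resample lazily whenever a sample is accessed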

do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()
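# BasicTextNormalizer lower-cases text and strips punctuation; it is reused below to
# normalise predictions and references before computing the WER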

def prepare_dataset(batch):
    # compute log-Mel input features from the raw waveform
    audio = batch["audio"]
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # record the sample length in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
    # optionally lower-case and/or normalise the transcription, then tokenize it
    transcription = batch["transcription"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch

# apply the 'prepare_dataset' function to each sample in the dataset
vectorized_audio_dataset_train = audio_dataset_train.map(
    prepare_dataset,
    remove_columns=list(next(iter(audio_dataset_train.values())).features)).with_format("torch")
vectorized_audio_dataset_test = audio_dataset_test.map(
    prepare_dataset,
    remove_columns=list(next(iter(audio_dataset_test.values())).features)).with_format("torch")

# shuffle the dataset; shard with num_shards=1 selects the whole dataset, and the fixed seed plus contiguous=True make the shuffling order reproducible
vectorized_audio_dataset_train["train"] = vectorized_audio_dataset_train["train"].shuffle(
    seed=0,
    load_from_cache_file=False).shard(
    num_shards=1, index=0, contiguous=True)

# training and evaluation

# define a data collator that pads the audio features and the label sequences independently
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features):
        # collate the log-Mel input features into a single padded tensor
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # pad the tokenized labels and replace padding with -100 so it is ignored by the loss
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # if a bos token was prepended during tokenization, cut it here; it is added back at generation time
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
    
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
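# optional sanity check (assumes the train split has at least two samples):
#   batch = data_collator([vectorized_audio_dataset_train["train"][i] for i in range(2)])
#   print(batch["input_features"].shape, batch["labels"].shape)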
    
# evaluation metric: word error rate (WER)
metric = evaluate.load("wer")
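# WER = (substitutions + deletions + insertions) / number of words in the reference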
do_normalize_eval = True

# store filenames, predictions and references
predicted_words_list = []
target_words_list = []
filenames = []

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]

        # filter out samples with empty references so they do not distort the WER
        pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
        label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    
    # collect mismatched predictions together with their source files; note that if the
    # empty-reference filter above dropped samples, the zip below may misalign
    # predictions with dataset entries
    for pred_text, target_text, audio_sample in zip(pred_str, label_str, audio_dataset_test["train"]["audio"]):
        if pred_text.strip() != "" and pred_text != target_text:
            predicted_words_list.append(pred_text)
            target_words_list.append(target_text)
            # the decoded audio column is a dict whose "path" key holds the source file
            filenames.append(os.path.basename(audio_sample["path"]))

    print(f"WER: {wer}")
    return {"wer": wer}

# load a pre-trained checkpoint
model = WhisperForConditionalGeneration.from_pretrained(model_id).to("cuda")

# disable forced decoder ids and token suppression for fine-tuning; the cache must be
# disabled because it is incompatible with gradient checkpointing
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False

# freeze the encoder
for param in model.get_encoder().parameters():
    param.requires_grad = False
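# with the encoder frozen, only decoder parameters are updated, which reduces the
# number of trainable parameters and the GPU memory needed for training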

# define the training parameters
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    save_total_limit=2,
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=1000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=25,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)
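# note: with max_steps=1000 and save_steps=1000 only the final checkpoint is written,
# while evaluation and logging run every 25 steps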

# trainer callback to reshuffle a streaming dataset at the beginning of each epoch;
# map-style datasets are already reshuffled every epoch by the Trainer's sampler,
# and Dataset.shuffle() returns a new object, so calling it without reassignment is a no-op
class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        if isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)


trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vectorized_audio_dataset_train["train"],
    eval_dataset=vectorized_audio_dataset_test["train"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    callbacks=[ShuffleCallback()],
)
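# passing the processor as `tokenizer` makes the Trainer save it alongside each checkpoint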

# log start and end time of the training
start_time = datetime.datetime.now()

# launch training
trainer.train()

end_time = datetime.datetime.now()

# save the fine-tuned model and processor; with load_best_model_at_end=True the
# in-memory model is the best checkpoint by WER at this point
model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)

# determine the maximum length among the lists
max_length = max(len(filenames), len(predicted_words_list), len(target_words_list))

# fill in missing values with empty strings to ensure equal lengths
filenames += [""] * (max_length - len(filenames))
predicted_words_list += [""] * (max_length - len(predicted_words_list))
target_words_list += [""] * (max_length - len(target_words_list))

# collect the wrong predictions in a DataFrame
df_wrong_predictions = pd.DataFrame({
    "File Name": filenames,
    "Predictions": predicted_words_list,
    "References": target_words_list
})

# split predictions and references into words and keep only the words that differ
pred_words_split = [pred.split() for pred in predicted_words_list]
target_words_split = [target.split() for target in target_words_list]
filtered_pred_words = [" ".join(word for word in pred if word not in target) for pred, target in zip(pred_words_split, target_words_split)]
filtered_target_words = [" ".join(word for word in target if word not in pred) for pred, target in zip(pred_words_split, target_words_split)]

# update the DataFrame with the filtered word lists
df_wrong_predictions["Predictions"] = filtered_pred_words
df_wrong_predictions["References"] = filtered_target_words
df_wrong_predictions = df_wrong_predictions[df_wrong_predictions["Predictions"] != df_wrong_predictions["References"]]

# save the DataFrame as a CSV file
df_wrong_predictions.to_csv(output_file_path, index=False)

# get training speed
duration = end_time - start_time
duration_hours = duration.total_seconds() / 3600  # Convert duration to hours

# get the GPU info via nvidia-smi
def get_gpu_info():
    try:
        output = subprocess.check_output(["nvidia-smi", "--query-gpu=index,name,memory.used", "--format=csv,noheader,nounits"])
        gpu_info = [line.strip().split(", ") for line in output.decode("utf-8").split("\n") if line.strip()]
        return gpu_info
    except Exception:
        return []
    
gpu_info = get_gpu_info()
if gpu_info:
    gpu_name = gpu_info[0][1]
    gpu_memory_used = int(gpu_info[0][2])
else:
    # fall back to placeholders so the CSV write below cannot fail with a NameError
    gpu_name = "N/A"
    gpu_memory_used = 0

with open(output_file_path_gpu, mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["Training Duration (hours)", "GPU Name", "GPU Memory Used (MB)"])
    writer.writerow([duration_hours, gpu_name, gpu_memory_used])