#!/usr/bin/env python
# finetune_whisper.py
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
import torch
from datasets import load_dataset, Audio
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
)
import ipdb  # debugging helper; only used by the commented-out set_trace() calls below
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels, since they have different lengths and need
        # different padding methods. First pad the audio inputs to torch tensors.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Get the tokenized label sequences and pad them to the batch maximum.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If a BOS token was prepended during tokenization, cut it here,
        # since the model prepends the decoder start token itself.
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
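
# Minimal sanity check of the collator (a sketch with hypothetical toy values;
# assumes `processor` and `model` are already loaded as below, and that
# whisper-large-v3 uses 128 mel bins). 50258 is Whisper's decoder start token
# (<|startoftranscript|>), so the first label column is cut and the short row
# is padded with -100:
#   collator = DataCollatorSpeechSeq2SeqWithPadding(processor, model.config.decoder_start_token_id)
#   toy = [{"input_features": torch.zeros(128, 3000), "labels": [50258, 123, 456]},
#          {"input_features": torch.zeros(128, 3000), "labels": [50258, 789]}]
#   out = collator(toy)  # out["labels"] == tensor([[123, 456], [789, -100]])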
# → Choose device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1. Configuration
LANGUAGE_WHISPER = "chinese" # target language for the Whisper tokenizer/decoder prompt; WenetSpeech is Mandarin
MODEL_CHECKPOINT = "openai/whisper-large-v3"
OUTPUT_DIR = "./whisper-wenetspeech-S"
TRAIN_SPLIT = "train"
VALID_SPLIT = "validation"
TEST_SPLIT = "test"
# 2. Load WenetSpeech subsets in streaming mode (audio at 16 kHz)
raw_datasets_train = load_dataset("pengyizhou/wenetspeech-subset-S", streaming=True)
raw_datasets_valid = load_dataset("wenet-e2e/wenetspeech", "DEV_fixed", split="validation", streaming=True)
raw_datasets_testnet = load_dataset("wenet-e2e/wenetspeech", "TEST_NET", split="test", streaming=True)
raw_datasets_testmeeting = load_dataset("wenet-e2e/wenetspeech", "TEST_MEETING", split="test", streaming=True)
# Cast "audio" column to 16 kHz (kept for reference; WenetSpeech audio is already 16 kHz):
# for split in ["train", "validation", "test"]:
#     raw_datasets[split] = raw_datasets[split].cast_column("audio", Audio(sampling_rate=16_000))
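# cast_column also works lazily on streaming datasets, e.g. (a sketch, not
# needed for these subsets):
# raw_datasets_valid = raw_datasets_valid.cast_column("audio", Audio(sampling_rate=16_000))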
# 3. Load Whisper Processor & Model
processor = WhisperProcessor.from_pretrained(MODEL_CHECKPOINT, language=LANGUAGE_WHISPER, task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
model.to(device)
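
# Optional (an assumption about the desired eval behavior, not part of the
# original setup): pin generation to Chinese transcription so evaluation does
# not depend on Whisper's automatic language detection.
# model.generation_config.language = "chinese"
# model.generation_config.task = "transcribe"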
# 4. Preprocessing Function
#    - Extract log-Mel features from audio
#    - Tokenize the target transcription
def preprocess_batch(batch):
    # batch["audio"] is a list of dicts; each ["array"] is a NumPy array @ 16 kHz
    audio_arrays = [example["array"] for example in batch["audio"]]
    # 4a. Feature extraction (log-Mel + normalization)
    inputs = processor.feature_extractor(
        audio_arrays,
        sampling_rate=16_000,
        return_tensors="pt",
    )
    # 4b. Tokenize the transcripts into label IDs. The processor was loaded
    # with language="chinese", so the tokenizer already emits the right
    # language-ID prefix token; no manual prefixing is needed.
    labels = processor.tokenizer(
        batch["text"],
        return_tensors="pt",
        padding="longest",
    )
    # ipdb.set_trace()
    inputs["labels"] = labels.input_ids
    return inputs
# 5. Apply preprocessing to train/validation/test
#    - Remove all non-audio columns after mapping
train_dataset = raw_datasets_train["train"].map(
    preprocess_batch,
    remove_columns=raw_datasets_train["train"].column_names,
    batched=True,
    batch_size=16,  # adjust batch_size to your memory
)
# ipdb.set_trace()
eval_dataset = raw_datasets_valid.map(
    preprocess_batch,
    remove_columns=raw_datasets_valid.column_names,
    batched=True,
    batch_size=8,
)
testnet_dataset = raw_datasets_testnet.map(
    preprocess_batch,
    remove_columns=raw_datasets_testnet.column_names,
    batched=True,
    batch_size=8,
)
testmeet_dataset = raw_datasets_testmeeting.map(
    preprocess_batch,
    remove_columns=raw_datasets_testmeeting.column_names,
    batched=True,
    batch_size=8,
)
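# After mapping, each example carries exactly two keys: "input_features"
# (a log-Mel spectrogram, padded/truncated to 30 s of audio by the feature
# extractor's defaults) and "labels" (token IDs ending in <|endoftext|>).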
# 6. Data Collator
#    Pads input_features and labels to the maximum length in the batch, and
#    replaces the padding token ID in labels with -100 so the loss ignores it.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)
# 7. Metrics: WER & CER (using Hugging Face Evaluate)
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")
def compute_metrics(pred):
    """
    pred.predictions: raw token IDs from generate()
    pred.label_ids:   token IDs used as labels
    """
    # 7a. Decode predictions into strings, skipping special tokens.
    pred_ids = pred.predictions
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # 7b. Decode references: replace -100 with pad_token_id first so the
    #     tokenizer does not crash on the ignore index.
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    ref_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    # Lowercase & strip (a no-op for Chinese characters, but harmless).
    pred_str = [s.lower().strip() for s in pred_str]
    ref_str = [s.lower().strip() for s in ref_str]
    wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
    cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
    return {"wer": wer_score, "cer": cer_score}
"""
# 8. Training Arguments
training_args = Seq2SeqTrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=4, # reduce if you OOM; or increase if large GPU
per_device_eval_batch_size=4,
gradient_accumulation_steps=2, # to simulate a larger batch
evaluation_strategy="steps",
eval_steps=500, # evaluate every 500 steps
logging_steps=250,
save_steps=1000,
num_train_epochs=3,
learning_rate=1e-5,
warmup_steps=500,
fp16=True, # use mixed precision if supported
predict_with_generate=True, # for computing WER/CER we need generate()
save_total_limit=2,
push_to_hub=False,
)
"""
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=30,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    warmup_steps=500,
    max_steps=6000,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=30,
    predict_with_generate=True,
    generation_max_length=200,
    save_steps=1500,
    eval_steps=500,
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="cer",
    greater_is_better=False,
    push_to_hub=True,
)
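
# Effective batch size is 30 x 1 accumulation step per device. Streaming
# datasets have no known length, so training is budgeted with max_steps rather
# than num_train_epochs; save_steps (1500) is a multiple of eval_steps (500),
# as load_best_model_at_end requires.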
# 9. Initialize Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,  # only the feature extractor; label padding is handled by the collator
    compute_metrics=compute_metrics,
)
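
# Note: recent transformers releases deprecate the `tokenizer` argument in
# favor of `processing_class` (e.g. `processing_class=processor`); the form
# above matches the older Seq2SeqTrainer signature.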
# 10. Fine-tune
if __name__ == "__main__":
    # 10a. Train
    trainer.train()

    # 10b. Evaluate on the two WenetSpeech test sets
    print("\n***** Evaluating on TEST_NET split *****")
    test_metrics = trainer.predict(testnet_dataset, metric_key_prefix="test")
    print(f"TEST_NET WER: {test_metrics.metrics['test_wer']*100:.2f}%")
    print(f"TEST_NET CER: {test_metrics.metrics['test_cer']*100:.2f}%")

    print("\n***** Evaluating on TEST_MEETING split *****")
    test_metrics = trainer.predict(testmeet_dataset, metric_key_prefix="test")
    print(f"TEST_MEETING WER: {test_metrics.metrics['test_wer']*100:.2f}%")
    print(f"TEST_MEETING CER: {test_metrics.metrics['test_cer']*100:.2f}%")