doof-ferb commited on
Commit
2918f0f
1 Parent(s): 828e0c2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +35 -0
README.md CHANGED
@@ -64,3 +64,38 @@ manually evaluate WER on test set - vietnamese part:
64
  | this LoRA | 14.7% | 14.7% | 9.4% |
65
 
66
  all training + evaluation scripts are on my repo: https://github.com/phineas-pta/fine-tune-whisper-vi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  | this LoRA | 14.7% | 14.7% | 9.4% |
65
 
66
  all training + evaluation scripts are on my repo: https://github.com/phineas-pta/fine-tune-whisper-vi
67
+
68
+ usage example:
69
+ ```python
70
+ # pip install peft accelerate bitsandbytes
71
+ import torch
72
+ import torchaudio
73
+ from peft import PeftModel, PeftConfig
74
+ from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer
75
+
76
+ PEFT_MODEL_ID = "doof-ferb/whisper-large-peft-lora-vi"
77
+ BASE_MODEL_ID = PeftConfig.from_pretrained(PEFT_MODEL_ID).base_model_name_or_path
78
+
79
+ FEATURE_EXTRACTOR = WhisperFeatureExtractor.from_pretrained(BASE_MODEL_ID)
80
+ TOKENIZER = WhisperTokenizer.from_pretrained(BASE_MODEL_ID)
81
+
82
+ MODEL = PeftModel.from_pretrained(
83
+ WhisperForConditionalGeneration.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16).to("cuda:0"),
84
+ PEFT_MODEL_ID
85
+ ).merge_and_unload(progressbar=True)
86
+
87
+ DECODER_ID = torch.tensor(
88
+ TOKENIZER.convert_tokens_to_ids(["<|startoftranscript|>", "<|vi|>", "<|transcribe|>", "<|notimestamps|>"]),
89
+ device=MODEL.device
90
+ ).unsqueeze(dim=0)
91
+
92
+ waveform, sampling_rate = torchaudio.load("audio.mp3")
93
+ if waveform.size(0) > 1: # convert dual to mono channel
94
+ waveform = waveform.mean(dim=0, keepdim=True)
95
+
96
+ inputs = FEATURE_EXTRACTOR(waveform, sampling_rate=sampling_rate, return_tensors="pt").to(MODEL.device)
97
+ with torch.inference_mode(), torch.autocast(device_type="cuda"): # required by PEFT
98
+ predicted_ids = MODEL.generate(input_features=inputs.input_features, decoder_input_ids=DECODER_ID)
99
+
100
+ TOKENIZER.batch_decode(predicted_ids, skip_special_tokens=True)[0]
101
+ ```