liuyang committed
Commit 6d56dd1 · 1 Parent(s): 99ff812
Files changed (1):
  1. app.py +42 -30
app.py CHANGED
@@ -30,37 +30,44 @@ import tempfile
 import spaces
 from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions
-from pyannote.audio import Pipeline
 import requests
 import base64
 
-# Create global Whisper model
-print("Loading Whisper model...")
-model = WhisperModel(
-    "large-v3-turbo",
-    device="cuda",
-    compute_type="float16",
-)
-print("Whisper model loaded successfully")
-
-# Create global diarization pipeline
-diarization_pipe = None
-try:
-    print("Loading diarization model...")
-    diarization_pipe = Pipeline.from_pretrained(
-        "pyannote/speaker-diarization-3.1",
-        use_auth_token=os.getenv("HF_TOKEN"),
-        torch_dtype=torch.float16,
-    ).to(torch.device("cuda"))
-    print("Diarization model loaded successfully")
-except Exception as e:
-    print(f"Could not load diarization model: {e}")
-    diarization_pipe = None
+# Lazy global holder ----------------------------------------------------------
+_whisper = None
+_diarizer = None
+
+@spaces.GPU  # GPU is guaranteed to exist *inside* this function
+def _load_models():
+    global _whisper, _diarizer
+    if _whisper is None:
+        print("Loading Whisper model...")
+        _whisper = WhisperModel(
+            "large-v3-turbo",
+            device="cuda",
+            compute_type="float16",
+        )
+        print("Whisper model loaded successfully")
+    if _diarizer is None:
+        print("Loading diarization model...")
+        try:
+            from pyannote.audio import Pipeline
+            _diarizer = Pipeline.from_pretrained(
+                "pyannote/speaker-diarization-3.1",
+                use_auth_token=os.getenv("HF_TOKEN"),
+                torch_dtype=torch.float16,
+            ).to(torch.device("cuda"))
+            print("Diarization model loaded successfully")
+        except Exception as e:
+            print(f"Could not load diarization model: {e}")
+            _diarizer = None
+    return _whisper, _diarizer
 
+# -----------------------------------------------------------------------------
 class WhisperTranscriber:
     def __init__(self):
-        self.model = model  # Use global Whisper model
-        self.diarization_model = diarization_pipe  # Use global diarization pipeline
+        # do **not** create the models here!
+        pass
 
     def convert_audio_format(self, audio_path):
         """Convert audio to 16kHz mono WAV format"""
@@ -109,9 +116,11 @@ class WhisperTranscriber:
 
         return audio_segments
 
-    @spaces.GPU
+    @spaces.GPU  # each call gets a GPU slice
     def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None):
         """Transcribe multiple audio segments using faster_whisper"""
+        whisper, diarizer = _load_models()  # models live on the GPU
+
         print(f"Transcribing {len(audio_segments)} audio segments...")
         start_time = time.time()
 
@@ -121,7 +130,7 @@
             beam_size=5,
             vad_filter=True,
             vad_parameters=VadOptions(
-                max_speech_duration_s=self.model.feature_extractor.chunk_length,
+                max_speech_duration_s=whisper.feature_extractor.chunk_length,
                 min_speech_duration_ms=100,
                 speech_pad_ms=100,
                 threshold=0.25,
@@ -140,7 +149,7 @@
             print(f"Processing segment {i+1}/{len(audio_segments)}")
 
             # Transcribe this segment
-            segments, transcript_info = self.model.transcribe(segment["audio_path"], **options)
+            segments, transcript_info = whisper.transcribe(segment["audio_path"], **options)
             segments = list(segments)
 
             # Get detected language from first segment
@@ -181,9 +190,12 @@
 
         return results, detected_language
 
+    @spaces.GPU  # each call gets a GPU slice
     def perform_diarization(self, audio_path, num_speakers=None):
         """Perform speaker diarization"""
-        if self.diarization_model is None:
+        whisper, diarizer = _load_models()  # models live on the GPU
+
+        if diarizer is None:
             print("Diarization model not available, creating single speaker segment")
             # Load audio to get duration
             waveform, sample_rate = torchaudio.load(audio_path)
@@ -201,7 +213,7 @@
         waveform, sample_rate = torchaudio.load(audio_path)
 
         # Perform diarization
-        diarization = self.diarization_model(
+        diarization = diarizer(
            {"waveform": waveform, "sample_rate": sample_rate},
            num_speakers=num_speakers,
        )
@@ -266,7 +278,7 @@
 
         return grouped_segments
 
-    @spaces.GPU
+    @spaces.GPU  # each call gets a GPU slice
     def process_audio(self, audio_file, num_speakers=None, language=None,
                       translate=False, prompt=None, group_segments=True):
         """Main processing function - diarization first, then transcription"""