This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.
- inference-cli.py +77 -116
- inference-cli.toml +1 -1
inference-cli.py
CHANGED
@@ -93,17 +93,6 @@ wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
 vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
 
-SPLIT_WORDS = [
-    "but", "however", "nevertheless", "yet", "still",
-    "therefore", "thus", "hence", "consequently",
-    "moreover", "furthermore", "additionally",
-    "meanwhile", "alternatively", "otherwise",
-    "namely", "specifically", "for example", "such as",
-    "in fact", "indeed", "notably",
-    "in contrast", "on the other hand", "conversely",
-    "in conclusion", "to summarize", "finally"
-]
-
 device = (
     "cuda"
     if torch.cuda.is_available()
@@ -167,103 +156,36 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
[old lines 170-183 removed: the start of the previous text-splitting helpers was not captured in this view]
-        current_word_part = ""
-        word_batches = []
-        for word in words:
-            if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
-                current_word_part += word + ' '
-            else:
-                if current_word_part:
-                    # Try to find a suitable split word
-                    for split_word in split_words:
-                        split_index = current_word_part.rfind(' ' + split_word + ' ')
-                        if split_index != -1:
-                            word_batches.append(current_word_part[:split_index].strip())
-                            current_word_part = current_word_part[split_index:].strip() + ' '
-                            break
-                    else:
-                        # If no suitable split word found, just append the current part
-                        word_batches.append(current_word_part.strip())
-                        current_word_part = ""
-                current_word_part += word + ' '
-        if current_word_part:
-            word_batches.append(current_word_part.strip())
-        return word_batches
-
-    for sentence in sentences:
-        if len(  [rest of old line 208 not captured in this view]
-        else:
[old lines 209 and 211-219 removed: not captured in this view]
-            if len(colon_parts) > 1:
-                for part in colon_parts:
-                    if len(part.encode('utf-8')) <= max_chars:
-                        batches.append(part)
-                    else:
-                        # If colon part is still too long, split by comma
-                        comma_parts = re.split('[,，]', part)
-                        if len(comma_parts) > 1:
-                            current_comma_part = ""
-                            for comma_part in comma_parts:
-                                if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                    current_comma_part += comma_part + ','
-                                else:
-                                    if current_comma_part:
-                                        batches.append(current_comma_part.rstrip(','))
-                                    current_comma_part = comma_part + ','
-                            if current_comma_part:
-                                batches.append(current_comma_part.rstrip(','))
-                        else:
-                            # If no comma, split by words
-                            batches.extend(split_by_words(part))
-            else:
-                # If no colon, split by comma
-                comma_parts = re.split('[,，]', sentence)
-                if len(comma_parts) > 1:
-                    current_comma_part = ""
-                    for comma_part in comma_parts:
-                        if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                            current_comma_part += comma_part + ','
-                        else:
-                            if current_comma_part:
-                                batches.append(current_comma_part.rstrip(','))
-                            current_comma_part = comma_part + ','
-                    if current_comma_part:
-                        batches.append(current_comma_part.rstrip(','))
-                else:
-                    # If no comma, split by words
-                    batches.extend(split_by_words(sentence))
-        else:
-            current_batch = sentence
-
-    if current_batch:
-        batches.append(current_batch)
-
-    return batches
-
-def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
+
+def chunk_text(text, max_chars=135):
+    """
+    Splits the input text into chunks, each with a maximum number of characters.
+    Args:
+        text (str): The text to be split.
+        max_chars (int): The maximum number of characters per chunk.
+    Returns:
+        List[str]: A list of text chunks.
+    """
+    chunks = []
+    current_chunk = ""
+    # Split the text into sentences based on punctuation followed by whitespace
+    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。！？])', text)
+
+    for sentence in sentences:
+        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
+            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+        else:
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
+
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     if model == "F5-TTS":
         ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
     elif model == "E2-TTS":
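A note on the new splitter: the non-obvious part is the sentence regex, which breaks after ASCII punctuation followed by whitespace, or immediately after full-width CJK punctuation. A small, self-contained illustration (the sample strings are made up; chunk_text then packs these pieces into chunks of at most max_chars UTF-8 bytes):

import re

# Same pattern as in chunk_text above: split after ASCII punctuation followed by
# whitespace, or immediately after full-width CJK punctuation (no space needed).
pattern = r'(?<=[;:,.!?])\s+|(?<=[;:,。！？])'

print(re.split(pattern, "Hello there. How are you? Fine, thanks."))
# ['Hello there.', 'How are you?', 'Fine,', 'thanks.']

print(re.split(pattern, "你好。今天怎么样？很好！"))
# ['你好。', '今天怎么样？', '很好！', '']  -- the trailing empty piece is harmless to chunk_text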
@@ -321,8 +243,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
     generated_waves.append(generated_wave)
     spectrograms.append(generated_mel_spec[0].cpu().numpy())
 
-    # Combine all generated waves
[old line 325 removed: its content was not captured in this view]
+    # Combine all generated waves with cross-fading
+    if cross_fade_duration <= 0:
+        # Simply concatenate
+        final_wave = np.concatenate(generated_waves)
+    else:
+        final_wave = generated_waves[0]
+        for i in range(1, len(generated_waves)):
+            prev_wave = final_wave
+            next_wave = generated_waves[i]
+
+            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+            if cross_fade_samples <= 0:
+                # No overlap possible, concatenate
+                final_wave = np.concatenate([prev_wave, next_wave])
+                continue
+
+            # Overlapping parts
+            prev_overlap = prev_wave[-cross_fade_samples:]
+            next_overlap = next_wave[:cross_fade_samples]
+
+            # Fade out and fade in
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+
+            # Cross-faded overlap
+            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+            # Combine
+            new_wave = np.concatenate([
+                prev_wave[:-cross_fade_samples],
+                cross_faded_overlap,
+                next_wave[cross_fade_samples:]
+            ])
+
+            final_wave = new_wave
 
     with open(wave_path, "wb") as f:
         sf.write(f.name, final_wave, target_sample_rate)
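A note on the cross-fade above: each seam blends the tail of the previous chunk with the head of the next one instead of butt-joining them. A minimal, self-contained numpy sketch of the same linear fade (the 24 kHz rate and dummy arrays are illustrative; the script uses its own target_sample_rate and generated waves):

import numpy as np

sample_rate = 24000           # illustrative; the script uses its target_sample_rate
cross_fade_duration = 0.15    # seconds of overlap, matching the new default

prev_wave = np.ones(48000, dtype=np.float32)    # 2 s dummy "previous" chunk
next_wave = np.zeros(48000, dtype=np.float32)   # 2 s dummy "next" chunk

# Overlap length in samples, clamped so it never exceeds either chunk.
n = min(int(cross_fade_duration * sample_rate), len(prev_wave), len(next_wave))

fade_out = np.linspace(1.0, 0.0, n)
fade_in = np.linspace(0.0, 1.0, n)

# Blend the overlapping regions, then stitch the three pieces together.
overlap = prev_wave[-n:] * fade_out + next_wave[:n] * fade_in
stitched = np.concatenate([prev_wave[:-n], overlap, next_wave[n:]])

# The overlap is shared rather than appended, so the result is n samples
# shorter than plain concatenation would be.
print(len(stitched), 2 * 48000 - n)   # 92400 92400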
@@ -343,11 +301,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
     print(spectrogram_path)
 
 
-def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence,  [rest of old line 346 not captured in this view]
-    if not custom_split_words.strip():
-        custom_words = [word.strip() for word in custom_split_words.split(',')]
-        global SPLIT_WORDS
-        SPLIT_WORDS = custom_words
+def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
 
     print(gen_text)

@@ -355,7 +309,7 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli…
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
 
-        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=  [rest of old line 358 not captured in this view]
+        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
         non_silent_wave = AudioSegment.silent(duration=0)
         for non_silent_seg in non_silent_segs:
             non_silent_wave += non_silent_seg
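Of the two hunks above, the first swaps the old custom_split_words handling for the new cross_fade_duration parameter; the second only changes the keep_silence padding used when the reference clip is cleaned up. A standalone sketch of that clean-up step, assuming pydub is installed and a local ref.wav exists (both file names are illustrative):

from pydub import AudioSegment, silence

aseg = AudioSegment.from_file("ref.wav")   # illustrative input path

# Split on pauses of at least 1 s below -50 dBFS, keeping up to 1 s of padding
# around each non-silent chunk (the new keep_silence=1000), then re-join.
segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
cleaned = AudioSegment.silent(duration=0)
for seg in segs:
    cleaned += seg

cleaned.export("ref_trimmed.wav", format="wav")   # illustrative output path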
@@ -387,16 +341,23 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli…
     else:
         print("Using custom reference text...")
 
+    # Add the functionality to ensure it ends with ". "
+    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
+        if ref_text.endswith("."):
+            ref_text += " "
+        else:
+            ref_text += ". "
+
     # Split the input text into batches
     audio, sr = torchaudio.load(ref_audio)
-    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (  [rest of old line 392 not captured in this view]
-    gen_text_batches =  [rest of old line 393 not captured in this view]
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-infer(ref_audio, ref_text, gen_text, model, remove_silence  [rest of old line 402 not captured in this view]
+infer(ref_audio, ref_text, gen_text, model, remove_silence)
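The new max_chars heuristic is worth unpacking: it measures how many UTF-8 bytes of text the reference speaker covers per second and budgets text for whatever remains of a roughly 25-second window. A worked example with made-up numbers (the reference text is the sample from inference-cli.toml; the 7-second clip length is illustrative):

ref_text = "Some call me nature, others call me mother nature."
ref_audio_seconds = 7.0   # illustrative reference clip duration

byte_rate = len(ref_text.encode('utf-8')) / ref_audio_seconds   # bytes of text per second of speech
max_chars = int(byte_rate * (25 - ref_audio_seconds))           # budget for the remaining ~18 s

print(len(ref_text.encode('utf-8')), round(byte_rate, 1), max_chars)  # 50 7.1 128
# Longer reference clips leave less headroom, so gen_text is cut into smaller chunks.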
inference-cli.toml
CHANGED
@@ -6,5 +6,5 @@ ref_text = "Some call me nature, others call me mother nature."
 gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
 # File with text to generate. Ignores the text above.
 gen_file = ""
-remove_silence =  [old value not captured in this view]
+remove_silence = false
 output_dir = "tests"
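For completeness, a minimal sketch of reading this config from Python with the standard-library tomllib (available since Python 3.11; the inference script itself may load the file differently):

import tomllib  # Python 3.11+; earlier versions can use the third-party tomli package

with open("inference-cli.toml", "rb") as f:
    config = tomllib.load(f)

print(config["remove_silence"])   # False -- a proper boolean now that the value is set
print(config["output_dir"])       # tests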