Dionyssos commited on
Commit
0213c67
·
1 Parent(s): 61d3afa
Files changed (2) hide show
  1. app.py +558 -0
  2. audionar.py +623 -0
app.py CHANGED
@@ -16,6 +16,538 @@ import textwrap
16
  from tts import StyleTTS2
17
  import audresample
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  device = 0 if torch.cuda.is_available() else "cpu"
21
  duration = 2 # limit processing of audio
@@ -582,4 +1114,30 @@ with gr.Blocks(theme='huggingface', css=css_buttons) as demo:
582
  submit_btn.click(recognize, input, outputs)
583
 
584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
  demo.launch(debug=True)
 
16
  from tts import StyleTTS2
17
  import audresample
18
 
19
+ # --
20
+ # -*- coding: utf-8 -*-
21
+
22
+ # https://huggingface.co/spaces/dpc/mmstts/tree/main
23
+ # https://huggingface.co/spaces/mms-meta/MMS/blob/main/tts.py
24
+
25
+ import json
26
+ import soundfile
27
+ import re
28
+ import unicodedata
29
+ import gradio as gr
30
+ import textwrap
31
+ import numpy as np
32
+ import torch
33
+ import nltk
34
+ from num2words import num2words
35
+ from num2word_greek.numbers2words import convert_numbers
36
+ from vits import VitsModel, VitsTokenizer
37
+
38
+ nltk.download('punkt', download_dir='./')
39
+ nltk.download('punkt_tab', download_dir='./')
40
+ nltk.data.path.append('.')
41
+
42
+ device = 'cpu'
43
+
44
+
45
+ def fix_vocals(text, lang='ron'):
46
+
47
+ # Longer phrases should come before shorter ones to prevent partial matches.
48
+
49
+ ron_replacements = {
50
+ 'ţ': 'ț',
51
+ 'ț': 'ts',
52
+ 'î': 'u',
53
+ 'â': 'a',
54
+ 'ş': 's',
55
+ 'w': 'oui',
56
+ 'k': 'c',
57
+ 'l': 'll',
58
+ # Math symbols
59
+ 'sqrt': ' rădăcina pătrată din ',
60
+ '^': ' la puterea ',
61
+ '+': ' plus ',
62
+ ' - ': ' minus ', # only replace if standalone so to not say minus if is a-b-c
63
+ '*': ' ori ', # times
64
+ '/': ' împărțit la ', # divided by
65
+ '=': ' egal cu ', # equals
66
+ 'pi': ' pi ',
67
+ '<': ' mai mic decât ',
68
+ '>': ' mai mare decât',
69
+ '%': ' la sută ', # percent (from previous)
70
+ '(': ' paranteză deschisă ',
71
+ ')': ' paranteză închisă ',
72
+ '[': ' paranteză pătrată deschisă ',
73
+ ']': ' paranteză pătrată închisă ',
74
+ '{': ' acoladă deschisă ',
75
+ '}': ' acoladă închisă ',
76
+ '≠': ' nu este egal cu ',
77
+ '≤': ' mai mic sau egal cu ',
78
+ '≥': ' mai mare sau egal cu ',
79
+ '≈': ' aproximativ ',
80
+ '∞': ' infinit ',
81
+ '€': ' euro ',
82
+ '$': ' dolar ',
83
+ '£': ' liră ',
84
+ '&': ' și ', # and
85
+ '@': ' la ', # at
86
+ '#': ' diez ', # hash
87
+ '∑': ' sumă ',
88
+ '∫': ' integrală ',
89
+ '√': ' rădăcina pătrată a ', # more generic square root
90
+ }
91
+
92
+ eng_replacements = {
93
+ 'wik': 'weaky',
94
+ 'sh': 'ss',
95
+ 'ch': 'ttss',
96
+ 'oo': 'oeo',
97
+ # Math symbols for English
98
+ 'sqrt': ' square root of ',
99
+ '^': ' to the power of ',
100
+ '+': ' plus ',
101
+ ' - ': ' minus ',
102
+ '*': ' times ',
103
+ ' / ': ' divided by ',
104
+ '=': ' equals ',
105
+ 'pi': ' pi ',
106
+ '<': ' less than ',
107
+ '>': ' greater than ',
108
+ # Additional common math symbols from previous list
109
+ '%': ' percent ',
110
+ '(': ' open parenthesis ',
111
+ ')': ' close parenthesis ',
112
+ '[': ' open bracket ',
113
+ ']': ' close bracket ',
114
+ '{': ' open curly brace ',
115
+ '}': ' close curly brace ',
116
+ '∑': ' sum ',
117
+ '∫': ' integral ',
118
+ '√': ' square root of ',
119
+ '≠': ' not equals ',
120
+ '≤': ' less than or equals ',
121
+ '≥': ' greater than or equals ',
122
+ '≈': ' approximately ',
123
+ '∞': ' infinity ',
124
+ '€': ' euro ',
125
+ '$': ' dollar ',
126
+ '£': ' pound ',
127
+ '&': ' and ',
128
+ '@': ' at ',
129
+ '#': ' hash ',
130
+ }
131
+
132
+ serbian_replacements = {
133
+ 'rn': 'rrn',
134
+ 'ć': 'č',
135
+ 'c': 'č',
136
+ 'đ': 'd',
137
+ 'j': 'i',
138
+ 'l': 'lll',
139
+ 'w': 'v',
140
+ # https://huggingface.co/facebook/mms-tts-rmc-script_latin
141
+ 'sqrt': 'kvadratni koren iz',
142
+ '^': ' na stepen ',
143
+ '+': ' plus ',
144
+ ' - ': ' minus ',
145
+ '*': ' puta ',
146
+ ' / ': ' podeljeno sa ',
147
+ '=': ' jednako ',
148
+ 'pi': ' pi ',
149
+ '<': ' manje od ',
150
+ '>': ' veće od ',
151
+ '%': ' procenat ',
152
+ '(': ' otvorena zagrada ',
153
+ ')': ' zatvorena zagrada ',
154
+ '[': ' otvorena uglasta zagrada ',
155
+ ']': ' zatvorena uglasta zagrada ',
156
+ '{': ' otvorena vitičasta zagrada ',
157
+ '}': ' zatvorena vitičasta zagrada ',
158
+ '∑': ' suma ',
159
+ '∫': ' integral ',
160
+ '√': ' kvadratni koren ',
161
+ '≠': ' nije jednako ',
162
+ '≤': ' manje ili jednako od ',
163
+ '≥': ' veće ili jednako od ',
164
+ '≈': ' približno ',
165
+ '∞': ' beskonačnost ',
166
+ '€': ' evro ',
167
+ '$': ' dolar ',
168
+ '£': ' funta ',
169
+ '&': ' i ',
170
+ '@': ' et ',
171
+ '#': ' taraba ',
172
+ # Others
173
+ # 'rn': 'rrn',
174
+ # 'ć': 'č',
175
+ # 'c': 'č',
176
+ # 'đ': 'd',
177
+ # 'l': 'le',
178
+ # 'ij': 'i',
179
+ # 'ji': 'i',
180
+ # 'j': 'i',
181
+ # 'služ': 'sloooozz', # 'službeno'
182
+ # 'suver': 'siuveeerra', # 'suverena'
183
+ # 'država': 'dirrezav', # 'država'
184
+ # 'iči': 'ici', # 'Graniči'
185
+ # 's ': 'se', # a s with space
186
+ # 'q': 'ku',
187
+ # 'w': 'aou',
188
+ # 'z': 's',
189
+ # "š": "s",
190
+ # 'th': 'ta',
191
+ # 'v': 'vv',
192
+ # "ć": "č",
193
+ # "đ": "ď",
194
+ # "lj": "ľ",
195
+ # "nj": "ň",
196
+ # "ž": "z",
197
+ # "c": "č"
198
+ }
199
+
200
+ deu_replacements = {
201
+ 'sch': 'sh',
202
+ 'ch': 'kh',
203
+ 'ie': 'ee',
204
+ 'ei': 'ai',
205
+ 'ä': 'ae',
206
+ 'ö': 'oe',
207
+ 'ü': 'ue',
208
+ 'ß': 'ss',
209
+ # Math symbols for German
210
+ 'sqrt': ' Quadratwurzel aus ',
211
+ '^': ' hoch ',
212
+ '+': ' plus ',
213
+ ' - ': ' minus ',
214
+ '*': ' mal ',
215
+ ' / ': ' geteilt durch ',
216
+ '=': ' gleich ',
217
+ 'pi': ' pi ',
218
+ '<': ' kleiner als ',
219
+ '>': ' größer als',
220
+ # Additional common math symbols from previous list
221
+ '%': ' prozent ',
222
+ '(': ' Klammer auf ',
223
+ ')': ' Klammer zu ',
224
+ '[': ' eckige Klammer auf ',
225
+ ']': ' eckige Klammer zu ',
226
+ '{': ' geschweifte Klammer auf ',
227
+ '}': ' geschweifte Klammer zu ',
228
+ '∑': ' Summe ',
229
+ '∫': ' Integral ',
230
+ '√': ' Quadratwurzel ',
231
+ '≠': ' ungleich ',
232
+ '≤': ' kleiner oder gleich ',
233
+ '≥': ' größer oder gleich ',
234
+ '≈': ' ungefähr ',
235
+ '∞': ' unendlich ',
236
+ '€': ' euro ',
237
+ '$': ' dollar ',
238
+ '£': ' pfund ',
239
+ '&': ' und ',
240
+ '@': ' at ', # 'Klammeraffe' is also common but 'at' is simpler
241
+ '#': ' raute ',
242
+ }
243
+
244
+ fra_replacements = {
245
+ # French specific phonetic replacements (add as needed)
246
+ # e.g., 'ç': 's', 'é': 'e', etc.
247
+ 'w': 'v',
248
+ # Math symbols for French
249
+ 'sqrt': ' racine carrée de ',
250
+ '^': ' à la puissance ',
251
+ '+': ' plus ',
252
+ ' - ': ' moins ', # tiré ;
253
+ '*': ' fois ',
254
+ ' / ': ' divisé par ',
255
+ '=': ' égale ',
256
+ 'pi': ' pi ',
257
+ '<': ' inférieur à ',
258
+ '>': ' supérieur à ',
259
+ # Add more common math symbols as needed for French
260
+ '%': ' pour cent ',
261
+ '(': ' parenthèse ouverte ',
262
+ ')': ' parenthèse fermée ',
263
+ '[': ' crochet ouvert ',
264
+ ']': ' crochet fermé ',
265
+ '{': ' accolade ouverte ',
266
+ '}': ' accolade fermée ',
267
+ '∑': ' somme ',
268
+ '∫': ' intégrale ',
269
+ '√': ' racine carrée ',
270
+ '≠': ' n\'égale pas ',
271
+ '≤': ' inférieur ou égal à ',
272
+ '≥': ' supérieur ou égal à ',
273
+ '≈': ' approximativement ',
274
+ '∞': ' infini ',
275
+ '€': ' euro ',
276
+ '$': ' dollar ',
277
+ '£': ' livre ',
278
+ '&': ' et ',
279
+ '@': ' arobase ',
280
+ '#': ' dièse ',
281
+ }
282
+
283
+ hun_replacements = {
284
+ # Hungarian specific phonetic replacements (add as needed)
285
+ # e.g., 'á': 'a', 'é': 'e', etc.
286
+ 'ch': 'ts',
287
+ 'cs': 'tz',
288
+ 'g': 'gk',
289
+ 'w': 'v',
290
+ 'z': 'zz',
291
+ # Math symbols for Hungarian
292
+ 'sqrt': ' négyzetgyök ',
293
+ '^': ' hatvány ',
294
+ '+': ' plusz ',
295
+ ' - ': ' mínusz ',
296
+ '*': ' szorozva ',
297
+ ' / ': ' osztva ',
298
+ '=': ' egyenlő ',
299
+ 'pi': ' pi ',
300
+ '<': ' kisebb mint ',
301
+ '>': ' nagyobb mint ',
302
+ # Add more common math symbols as needed for Hungarian
303
+ '%': ' százalék ',
304
+ '(': ' nyitó zárójel ',
305
+ ')': ' záró zárójel ',
306
+ '[': ' nyitó szögletes zárójel ',
307
+ ']': ' záró szögletes zárójel ',
308
+ '{': ' nyitó kapcsos zárójel ',
309
+ '}': ' záró kapcsos zárójel ',
310
+ '∑': ' szumma ',
311
+ '∫': ' integrál ',
312
+ '√': ' négyzetgyök ',
313
+ '≠': ' nem egyenlő ',
314
+ '≤': ' kisebb vagy egyenlő ',
315
+ '≥': ' nagyobb vagy egyenlő ',
316
+ '≈': ' körülbelül ',
317
+ '∞': ' végtelen ',
318
+ '€': ' euró ',
319
+ '$': ' dollár ',
320
+ '£': ' font ',
321
+ '&': ' és ',
322
+ '@': ' kukac ',
323
+ '#': ' kettőskereszt ',
324
+ }
325
+
326
+ grc_replacements = {
327
+ # Ancient Greek specific phonetic replacements (add as needed)
328
+ # These are more about transliterating Greek letters if they are in the input text.
329
+ # Math symbols for Ancient Greek (literal translations)
330
+ 'sqrt': ' τετραγωνικὴ ῥίζα ',
331
+ '^': ' εἰς τὴν δύναμιν ',
332
+ '+': ' σὺν ',
333
+ ' - ': ' χωρὶς ',
334
+ '*': ' πολλάκις ',
335
+ ' / ': ' διαιρέω ',
336
+ '=': ' ἴσον ',
337
+ 'pi': ' πῖ ',
338
+ '<': ' ἔλαττον ',
339
+ '>': ' μεῖζον ',
340
+ # Add more common math symbols as needed for Ancient Greek
341
+ '%': ' τοῖς ἑκατόν ', # tois hekaton - 'of the hundred'
342
+ '(': ' ἀνοικτὴ παρένθεσις ',
343
+ ')': ' κλειστὴ παρένθεσις ',
344
+ '[': ' ἀνοικτὴ ἀγκύλη ',
345
+ ']': ' κλειστὴ ἀγκύλη ',
346
+ '{': ' ἀνοικτὴ σγουρὴ ἀγκύλ�� ',
347
+ '}': ' κλειστὴ σγουρὴ ἀγκύλη ',
348
+ '∑': ' ἄθροισμα ',
349
+ '∫': ' ὁλοκλήρωμα ',
350
+ '√': ' τετραγωνικὴ ῥίζα ',
351
+ '≠': ' οὐκ ἴσον ',
352
+ '≤': ' ἔλαττον ἢ ἴσον ',
353
+ '≥': ' μεῖζον ἢ ἴσον ',
354
+ '≈': ' περίπου ',
355
+ '∞': ' ἄπειρον ',
356
+ '€': ' εὐρώ ',
357
+ '$': ' δολάριον ',
358
+ '£': ' λίρα ',
359
+ '&': ' καὶ ',
360
+ '@': ' ἀτ ', # at
361
+ '#': ' δίεση ', # hash
362
+ }
363
+
364
+
365
+ # Select the appropriate replacement dictionary based on the language
366
+ replacements_map = {
367
+ 'grc': grc_replacements,
368
+ 'ron': ron_replacements,
369
+ 'eng': eng_replacements,
370
+ 'deu': deu_replacements,
371
+ 'fra': fra_replacements,
372
+ 'hun': hun_replacements,
373
+ 'rmc-script_latin': serbian_replacements,
374
+ }
375
+
376
+ current_replacements = replacements_map.get(lang)
377
+ if current_replacements:
378
+ # Sort replacements by length of the key in descending order.
379
+ # This is crucial for correctly replacing multi-character strings (like 'sqrt', 'sch')
380
+ # before their shorter substrings ('s', 'ch', 'q', 'r', 't').
381
+ sorted_replacements = sorted(current_replacements.items(), key=lambda item: len(item[0]), reverse=True)
382
+ for old, new in sorted_replacements:
383
+ text = text.replace(old, new)
384
+ return text
385
+ else:
386
+ # If the language is not supported, return the original text
387
+ print(f"Warning: Language '{lang}' not supported for text replacement. Returning original text.")
388
+ return text
389
+
390
+
391
+ def _num2words(text='01234', lang=None):
392
+ if lang == 'grc':
393
+ return convert_numbers(text)
394
+ return num2words(text, lang=lang) # HAS TO BE kwarg lang=lang
395
+
396
+
397
+ def transliterate_number(number_string,
398
+ lang=None):
399
+ if lang == 'rmc-script_latin':
400
+ lang = 'sr'
401
+ exponential_pronoun = ' puta deset na stepen od '
402
+ comma = ' tačka '
403
+ elif lang == 'ron':
404
+ lang = 'ro'
405
+ exponential_pronoun = ' tízszer a erejéig '
406
+ comma = ' virgulă '
407
+ elif lang == 'hun':
408
+ lang = 'hu'
409
+ exponential_pronoun = ' tízszer a erejéig '
410
+ comma = ' virgula '
411
+ elif lang == 'deu':
412
+ exponential_pronoun = ' mal zehn hoch '
413
+ comma = ' komma '
414
+ elif lang == 'fra':
415
+ lang = 'fr'
416
+ exponential_pronoun = ' puissance '
417
+ comma = 'virgule'
418
+ elif lang == 'grc':
419
+ exponential_pronoun = ' εις την δυναμην του '
420
+ comma = 'κομμα'
421
+ else:
422
+ lang = lang[:2]
423
+ exponential_pronoun = ' times ten to the power of '
424
+ comma = ' point '
425
+
426
+ def replace_number(match):
427
+ prefix = match.group(1) or ""
428
+ number_part = match.group(2)
429
+ suffix = match.group(5) or ""
430
+
431
+ try:
432
+ if 'e' in number_part.lower():
433
+ base, exponent = number_part.lower().split('e')
434
+ words = _num2words(base, lang=lang) + exponential_pronoun + _num2words(exponent, lang=lang)
435
+ elif '.' in number_part:
436
+ integer_part, decimal_part = number_part.split('.')
437
+ words = _num2words(integer_part, lang=lang) + comma + " ".join(
438
+ [_num2words(digit, lang=lang) for digit in decimal_part])
439
+ else:
440
+ words = _num2words(number_part, lang=lang)
441
+ return prefix + words + suffix
442
+ except ValueError:
443
+ return match.group(0) # Return original if conversion fails
444
+
445
+ pattern = r'([^\d]*)(\d+(\.\d+)?([Ee][+-]?\d+)?)([^\d]*)'
446
+ return re.sub(pattern, replace_number, number_string)
447
+
448
+
449
+ language_names = ['Ancient greek',
450
+ 'English',
451
+ 'Deutsch',
452
+ 'French',
453
+ 'Hungarian',
454
+ 'Romanian',
455
+ 'Serbian (Approx.)']
456
+
457
+
458
+ def audionar_tts(text=None,
459
+ lang='romanian'):
460
+
461
+ # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
462
+
463
+ lang = lang.lower()
464
+
465
+ # https://huggingface.co/spaces/mms-meta/MMS
466
+
467
+ if 'hun' in lang:
468
+
469
+ lang_code = 'hun'
470
+
471
+ elif any([i in lang for i in ['ser', 'bosn', 'herzegov', 'montenegr', 'macedon']]):
472
+
473
+ # romani carpathian (has also Vlax) - cooler voice
474
+ lang_code = 'rmc-script_latin'
475
+
476
+ elif 'rom' in lang:
477
+
478
+ lang_code = 'ron'
479
+
480
+ elif 'ger' in lang or 'deu' in lang or 'allem' in lang:
481
+
482
+ lang_code = 'deu'
483
+
484
+ elif 'french' in lang:
485
+
486
+ lang_code = 'fra'
487
+
488
+ elif 'eng' in lang:
489
+
490
+ lang_code = 'eng'
491
+
492
+ elif 'ancient greek' in lang:
493
+
494
+ lang_code = 'grc'
495
+
496
+ else:
497
+
498
+ lang_code = lang.split()[0].strip() # latin & future option
499
+
500
+ # LATIN / GRC / CYRILLIC
501
+
502
+ text = only_greek_or_only_latin(text, lang=lang_code) # assure gr-chars if lang=='grc' / latin if lang!='grc'
503
+
504
+ # NUMERALS (^ in math expression found & substituted here before arriving to fix_vocals)
505
+
506
+ text = transliterate_number(text, lang=lang_code)
507
+
508
+ # PRONOUNC.
509
+
510
+ text = fix_vocals(text, lang=lang_code)
511
+
512
+ # VITS
513
+
514
+ global cached_lang_code, cached_net_g, cached_tokenizer
515
+
516
+ if 'cached_lang_code' not in globals() or cached_lang_code != lang_code:
517
+ cached_lang_code = lang_code
518
+ cached_net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
519
+ cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
520
+
521
+ net_g = cached_net_g
522
+ tokenizer = cached_tokenizer
523
+
524
+ total_audio = []
525
+
526
+ if not isinstance(text, list):
527
+ text = textwrap.wrap(text, width=439)
528
+
529
+ for _t in text:
530
+ inputs = tokenizer(_t, return_tensors="pt")
531
+ with torch.no_grad():
532
+ x = net_g(input_ids=inputs.input_ids.to(device),
533
+ attention_mask=inputs.attention_mask.to(device),
534
+ lang_code=lang_code,
535
+ )[0, :]
536
+ total_audio.append(x)
537
+
538
+ print(f'\n\n_______________________________ {_t} {x.shape=}')
539
+
540
+ x = torch.cat(total_audio).cpu().numpy()
541
+
542
+ tmp_file = f'_speech.wav'
543
+
544
+ soundfile.write(tmp_file, x, 16000)
545
+
546
+ return tmp_file
547
+
548
+
549
+ # --
550
+
551
 
552
  device = 0 if torch.cuda.is_available() else "cpu"
553
  duration = 2 # limit processing of audio
 
1114
  submit_btn.click(recognize, input, outputs)
1115
 
1116
 
1117
+ with gr.Tab("audionar TTS"):
1118
+ with gr.Row():
1119
+ text_input = gr.Textbox(
1120
+ lines=4,
1121
+ value='Η γρηγορη καφετι αλεπου πειδαει πανω απο τον τεμπελη σκυλο.',
1122
+ label="Type text for TTS"
1123
+ )
1124
+ lang_dropdown = gr.Dropdown(
1125
+ choices=language_names,
1126
+ label="TTS language",
1127
+ value="Ancient greek",
1128
+ )
1129
+
1130
+ # Create a button to trigger the TTS function
1131
+ tts_button = gr.Button("Generate Audio")
1132
+
1133
+ # Create the output audio component
1134
+ audio_output = gr.Audio(label="Generated Audio")
1135
+
1136
+ # Link the button click event to the mms_tts function
1137
+ tts_button.click(
1138
+ fn=audionar_tts,
1139
+ inputs=[text_input, lang_dropdown],
1140
+ outputs=audio_output
1141
+ )
1142
+
1143
  demo.launch(debug=True)
audionar.py ADDED
@@ -0,0 +1,623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from transformers.modeling_utils import PreTrainedModel
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ import json
8
+ import os
9
+ import re
10
+ from transformers.tokenization_utils import PreTrainedTokenizer
11
+ import phonemizer
12
+ import torch.nn.functional as F
13
+
14
+
15
+
16
+ OSCILLATION = {
17
+ 'deu': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
18
+ 'rmc-script_latin': [2, 2, 1, 2, 2],
19
+ 'hun': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
20
+ 'fra': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
21
+ 'eng': [1, 2, 2, 1, 2, 2],
22
+ 'grc': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1],
23
+ 'ron': [1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2],
24
+ }
25
+
26
+
27
+ def has_non_roman_characters(input_string):
28
+ # Find any character outside the ASCII range
29
+ non_roman_pattern = re.compile(r"[^\x00-\x7F]")
30
+
31
+ # Search the input string for non-Roman characters
32
+ match = non_roman_pattern.search(input_string)
33
+ has_non_roman = match is not None
34
+ return has_non_roman
35
+
36
+
37
+ class VitsConfig(PretrainedConfig):
38
+
39
+ model_type = "vits"
40
+
41
+ def __init__(
42
+ self,
43
+ vocab_size=38,
44
+ hidden_size=192,
45
+ num_hidden_layers=6,
46
+ num_attention_heads=2,
47
+ use_bias=True,
48
+ ffn_dim=768,
49
+ ffn_kernel_size=3,
50
+ flow_size=192,
51
+ # hidden_act="relu",
52
+ upsample_initial_channel=512,
53
+ upsample_rates=[8, 8, 2, 2],
54
+ upsample_kernel_sizes=[16, 16, 4, 4],
55
+ resblock_kernel_sizes=[3, 7, 11],
56
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
57
+ prior_encoder_num_flows=4,
58
+ prior_encoder_num_wavenet_layers=4,
59
+ wavenet_kernel_size=5,
60
+ **kwargs,
61
+ ):
62
+ self.vocab_size = vocab_size
63
+ self.hidden_size = hidden_size
64
+ self.num_hidden_layers = num_hidden_layers
65
+ self.num_attention_heads = num_attention_heads
66
+ self.use_bias = use_bias
67
+ self.ffn_dim = ffn_dim
68
+ self.ffn_kernel_size = ffn_kernel_size
69
+ self.flow_size = flow_size
70
+ self.upsample_initial_channel = upsample_initial_channel
71
+ self.upsample_rates = upsample_rates
72
+ self.upsample_kernel_sizes = upsample_kernel_sizes
73
+ self.resblock_kernel_sizes = resblock_kernel_sizes
74
+ self.resblock_dilation_sizes = resblock_dilation_sizes
75
+ self.prior_encoder_num_flows = prior_encoder_num_flows
76
+ self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
77
+ self.wavenet_kernel_size = wavenet_kernel_size
78
+ super().__init__()
79
+
80
+
81
+ class VitsWaveNet(torch.nn.Module):
82
+ def __init__(self, config, num_layers):
83
+ super().__init__()
84
+ self.hidden_size = config.hidden_size
85
+ self.num_layers = num_layers
86
+ self.in_layers = torch.nn.ModuleList()
87
+ self.res_skip_layers = torch.nn.ModuleList()
88
+ # if hasattr(nn.utils.parametrizations, "weight_norm"):
89
+ # # raise ValueError
90
+ weight_norm = nn.utils.parametrizations.weight_norm
91
+ # else:
92
+ # raise ValueError
93
+ # # weight_norm = nn.utils.weight_norm
94
+ for i in range(num_layers):
95
+
96
+ in_layer = torch.nn.Conv1d(
97
+ in_channels=config.hidden_size,
98
+ out_channels=2 * config.hidden_size,
99
+ kernel_size=config.wavenet_kernel_size,
100
+ dilation=1,
101
+ padding=2,
102
+ )
103
+ in_layer = weight_norm(in_layer, name="weight")
104
+ self.in_layers.append(in_layer)
105
+
106
+ # last one is not necessary
107
+ if i < num_layers - 1:
108
+ res_skip_channels = 2 * config.hidden_size
109
+ else:
110
+ res_skip_channels = config.hidden_size
111
+ res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
112
+ res_skip_layer = weight_norm(res_skip_layer, name="weight")
113
+ self.res_skip_layers.append(res_skip_layer)
114
+
115
+ def forward(self,
116
+ inputs):
117
+ outputs = torch.zeros_like(inputs)
118
+ num_channels = torch.IntTensor([self.hidden_size])[0]
119
+ for i in range(self.num_layers):
120
+ in_act = self.in_layers[i](inputs)
121
+ # global_states = torch.zeros_like(hidden_states) # style ?
122
+ # acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
123
+ # --
124
+ # def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
125
+ # in_act = input_a # + input_b
126
+ t_act = torch.tanh(in_act[:, :num_channels, :])
127
+ s_act = torch.sigmoid(in_act[:, num_channels:, :])
128
+ acts = t_act * s_act
129
+ res_skip_acts = self.res_skip_layers[i](acts)
130
+ if i < self.num_layers - 1:
131
+ res_acts = res_skip_acts[:, : self.hidden_size, :]
132
+ inputs = inputs + res_acts
133
+ outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
134
+ else:
135
+ outputs = outputs + res_skip_acts
136
+ return outputs
137
+
138
+ # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
139
+ class HifiGanResidualBlock(nn.Module):
140
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
141
+ super().__init__()
142
+ self.leaky_relu_slope = leaky_relu_slope
143
+
144
+ self.convs1 = nn.ModuleList(
145
+ [
146
+ nn.Conv1d(
147
+ channels,
148
+ channels,
149
+ kernel_size,
150
+ stride=1,
151
+ dilation=dilation[i],
152
+ padding=self.get_padding(kernel_size, dilation[i]),
153
+ )
154
+ for i in range(len(dilation))
155
+ ]
156
+ )
157
+ self.convs2 = nn.ModuleList(
158
+ [
159
+ nn.Conv1d(
160
+ channels,
161
+ channels,
162
+ kernel_size,
163
+ stride=1,
164
+ dilation=1,
165
+ padding=self.get_padding(kernel_size, 1),
166
+ )
167
+ for _ in range(len(dilation))
168
+ ]
169
+ )
170
+
171
+ def get_padding(self, kernel_size, dilation=1):
172
+ # 1, 3, 5, 15
173
+ return (kernel_size * dilation - dilation) // 2
174
+
175
+ def forward(self, hidden_states):
176
+ for conv1, conv2 in zip(self.convs1, self.convs2):
177
+ residual = hidden_states
178
+ hidden_states = nn.functional.leaky_relu(hidden_states, negative_slope=self.leaky_relu_slope)
179
+ hidden_states = conv1(hidden_states)
180
+ hidden_states = nn.functional.leaky_relu(hidden_states, negative_slope=self.leaky_relu_slope)
181
+ hidden_states = conv2(hidden_states)
182
+ hidden_states = hidden_states + residual
183
+ return hidden_states
184
+
185
+
186
+ class VitsHifiGan(nn.Module):
187
+ def __init__(self, config):
188
+ super().__init__()
189
+ self.config = config
190
+ self.num_kernels = len(config.resblock_kernel_sizes)
191
+ self.num_upsamples = len(config.upsample_rates)
192
+ self.conv_pre = nn.Conv1d(
193
+ config.flow_size,
194
+ config.upsample_initial_channel,
195
+ kernel_size=7,
196
+ stride=1,
197
+ padding=3,
198
+ )
199
+
200
+ self.upsampler = nn.ModuleList()
201
+ for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
202
+ self.upsampler.append(
203
+ nn.ConvTranspose1d(
204
+ config.upsample_initial_channel // (2**i),
205
+ config.upsample_initial_channel // (2 ** (i + 1)),
206
+ kernel_size=kernel_size,
207
+ stride=upsample_rate,
208
+ padding=(kernel_size - upsample_rate) // 2,
209
+ )
210
+ )
211
+
212
+ self.resblocks = nn.ModuleList()
213
+ for i in range(len(self.upsampler)):
214
+ channels = config.upsample_initial_channel // (2 ** (i + 1))
215
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
216
+ self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation))
217
+ self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
218
+
219
+ def forward(self,
220
+ spectrogram):
221
+ hidden_states = self.conv_pre(spectrogram)
222
+ for i in range(self.num_upsamples):
223
+ hidden_states = F.leaky_relu(hidden_states, negative_slope=.1, inplace=True)
224
+ hidden_states = self.upsampler[i](hidden_states)
225
+ res_state = self.resblocks[i * self.num_kernels](hidden_states)
226
+ for j in range(1, self.num_kernels):
227
+ res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
228
+ hidden_states = res_state / self.num_kernels
229
+ hidden_states = F.leaky_relu(hidden_states, negative_slope=.01, inplace=True)
230
+ hidden_states = self.conv_post(hidden_states)
231
+ waveform = torch.tanh(hidden_states)
232
+ return waveform
233
+
234
+
235
+ class VitsResidualCouplingLayer(nn.Module):
236
+ def __init__(self, config):
237
+ super().__init__()
238
+ self.half_channels = config.flow_size // 2
239
+ self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
240
+ self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
241
+ self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
242
+
243
+ def forward(self,
244
+ x,
245
+ reverse=False):
246
+ first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
247
+ hidden_states = self.conv_pre(first_half)
248
+ hidden_states = self.wavenet(hidden_states)
249
+ mean = self.conv_post(hidden_states)
250
+ second_half = (second_half - mean)
251
+ outputs = torch.cat([first_half, second_half], dim=1)
252
+ return outputs
253
+
254
+
255
+ class VitsResidualCouplingBlock(nn.Module):
256
+ def __init__(self, config):
257
+ super().__init__()
258
+ self.flows = nn.ModuleList()
259
+ for _ in range(config.prior_encoder_num_flows):
260
+ self.flows.append(VitsResidualCouplingLayer(config))
261
+
262
+ def forward(self, x, reverse=False):
263
+ # x L [1, 192, 481]
264
+ for flow in reversed(self.flows):
265
+ x = torch.flip(x, [1]) # flipud CHANNELs
266
+ x = flow(x, reverse=True)
267
+ return x
268
+
269
+
270
+ class VitsAttention(nn.Module):
271
+ """has no positional info"""
272
+
273
+ def __init__(self, config):
274
+ super().__init__()
275
+ self.embed_dim = config.hidden_size
276
+ self.num_heads = config.num_attention_heads
277
+
278
+
279
+
280
+ self.head_dim = self.embed_dim // self.num_heads
281
+ self.scaling = self.head_dim**-0.5
282
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
283
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
284
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
285
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
286
+
287
+ def _shape(self, tensor, seq_len, bsz):
288
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
289
+
290
+ def forward(
291
+ self,
292
+ hidden_states,
293
+ layer_head_mask = None,
294
+ output_attentions = False,
295
+ ):
296
+
297
+
298
+ bsz, tgt_len, _ = hidden_states.size()
299
+
300
+ # Q
301
+
302
+ query_states = self.q_proj(hidden_states) * self.scaling
303
+
304
+ # K/V
305
+ hidden_states = hidden_states[:, :40, :] # drop time-frames from k/v [bs*2, time, 96=ch]
306
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
307
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
308
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
309
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
310
+ key_states = key_states.view(*proj_shape)
311
+ value_states = value_states.view(*proj_shape)
312
+
313
+
314
+
315
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
316
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
317
+ attn_output = torch.bmm(attn_weights,
318
+ value_states)
319
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
320
+ attn_output = attn_output.transpose(1, 2)
321
+
322
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
323
+ # partitioned aross GPUs when using tensor-parallelism.
324
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
325
+
326
+ attn_output = self.out_proj(attn_output)
327
+
328
+ return attn_output
329
+
330
+
331
+ class VitsFeedForward(nn.Module):
332
+ def __init__(self, config):
333
+ super().__init__()
334
+ self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size, padding=1)
335
+ self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size, padding=1)
336
+
337
+ def forward(self, hidden_states):
338
+ hidden_states = hidden_states.permute(0, 2, 1)
339
+ hidden_states = F.relu(self.conv_1(hidden_states)) # inplace changes sound ;
340
+ hidden_states = self.conv_2(hidden_states)
341
+ hidden_states = hidden_states.permute(0, 2, 1)
342
+ return hidden_states
343
+
344
+
345
+ class VitsEncoderLayer(nn.Module):
346
+ def __init__(self, config):
347
+ super().__init__()
348
+ self.attention = VitsAttention(config)
349
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-5)
350
+ self.feed_forward = VitsFeedForward(config)
351
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-5)
352
+
353
+ def forward(
354
+ self,
355
+ hidden_states,
356
+ output_attentions = False,
357
+ ):
358
+ residual = hidden_states
359
+ hidden_states = self.attention(
360
+ hidden_states=hidden_states,
361
+ # attention_mask=attention_mask,
362
+ output_attentions=output_attentions,
363
+ )
364
+
365
+
366
+ hidden_states = self.layer_norm(residual + hidden_states)
367
+
368
+ residual = hidden_states
369
+ hidden_states = self.feed_forward(hidden_states)
370
+
371
+ hidden_states = self.final_layer_norm(residual + hidden_states)
372
+
373
+ outputs = (hidden_states,)
374
+
375
+ return outputs
376
+
377
+
378
+ class VitsEncoder(nn.Module):
379
+ def __init__(self, config):
380
+ super().__init__()
381
+ self.config = config
382
+ self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
383
+
384
+ def forward(
385
+ self,
386
+ hidden_states):
387
+ for _layer in self.layers:
388
+ layer_outputs = _layer(hidden_states)
389
+ hidden_states = layer_outputs[0]
390
+ return hidden_states
391
+
392
+
393
+
394
+ class VitsTextEncoder(nn.Module):
395
+ """
396
+ Has VitsEncoder
397
+ """
398
+
399
+ def __init__(self, config):
400
+ super().__init__()
401
+ self.config = config
402
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
403
+ self.encoder = VitsEncoder(config) # 6 Layers of VitsAttention
404
+ self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
405
+
406
+ def forward(self,
407
+ input_ids
408
+ ):
409
+ hidden_states = self.embed_tokens(input_ids) * 4 #Actually4-or-4.856406460551018-@-845-len-ids-deu
410
+ stats = self.project(self.encoder(hidden_states=hidden_states).transpose(1, 2)).transpose(1, 2)
411
+ return stats[:, :, :self.config.flow_size] # prior_means
412
+
413
+
414
+ class VitsPreTrainedModel(PreTrainedModel):
415
+ config_class = VitsConfig
416
+ base_model_prefix = "vits"
417
+ main_input_name = "input_ids"
418
+ supports_gradient_checkpointing = True
419
+
420
+
421
+
422
+ class VitsModel(VitsPreTrainedModel):
423
+ def __init__(self, config):
424
+ super().__init__(config)
425
+ self.config = config
426
+ self.text_encoder = VitsTextEncoder(config) # has VitsEncoder that includes 6L of VitsAttention
427
+ self.flow = VitsResidualCouplingBlock(config)
428
+ self.decoder = VitsHifiGan(config)
429
+
430
+ def forward(
431
+ self,
432
+ input_ids = None,
433
+ attention_mask = None,
434
+ speaker_id = None,
435
+ output_attentions = None,
436
+ output_hidden_states = None,
437
+ return_dict = None,
438
+ labels = None,
439
+ speed = None,
440
+ lang_code = 'deu', # speed oscillation pattern per voice/lang
441
+ ):
442
+ mask_dtype = self.text_encoder.embed_tokens.weight.dtype
443
+ if attention_mask is not None:
444
+ input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
445
+ else:
446
+ raise ValueError
447
+ input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
448
+ prior_means = self.text_encoder(input_ids=input_ids)
449
+
450
+ input_padding_mask = input_padding_mask.transpose(1, 2)
451
+
452
+
453
+ bs, in_len, _ = prior_means.shape
454
+ # VITS Duration Oscillation
455
+ pattern = OSCILLATION.get(lang_code, [1, 2, 1])
456
+
457
+ duration = torch.tensor(pattern,
458
+ device=prior_means.device).repeat(int(in_len / len(pattern)) + 2)[None, None, :in_len] # perhaps define [1, 2, 1] per voice or language
459
+ duration[:, :, 0] = 4
460
+ duration[:, :, -1] = 3
461
+ # ATTN
462
+ predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
463
+ indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
464
+ output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
465
+ output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
466
+ attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
467
+ batch_size, _, output_length, input_length = attn_mask.shape
468
+ cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
469
+ indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
470
+ valid_indices = indices.unsqueeze(0) < cum_duration
471
+ valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
472
+ padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
473
+ attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
474
+ attn = attn[:, 0, :, :]
475
+
476
+
477
+ attn = attn + 1e-4 * torch.rand_like(attn)
478
+ attn /= attn.sum(2, keepdims=True)
479
+ #print(attn)
480
+ prior_means = torch.matmul(attn, prior_means) # try attn to contain .5/.5 instead of 1/0 so it smoothly interpolates repeated prior_means
481
+
482
+ #prior_means = F.interpolate(prior_means.transpose(1,2), int(1.74 * prior_means.shape[1]), mode='linear').transpose(1,2) # extend for slow speed
483
+
484
+
485
+
486
+ # prior means have now been replicated x duration of each prior mean
487
+
488
+ latents = self.flow(prior_means.transpose(1, 2), # + torch.randn_like(prior_means) * .94,
489
+ reverse=True)
490
+
491
+ waveform = self.decoder(latents) # [bs, 1, 16000]
492
+
493
+ return waveform[:, 0, :]
494
+
495
+
496
+ class VitsTokenizer(PreTrainedTokenizer):
497
+ vocab_files_names = {"vocab_file": "vocab.json"}
498
+ model_input_names = ["input_ids", "attention_mask"]
499
+
500
+ def __init__(
501
+ self,
502
+ vocab_file,
503
+ pad_token="<pad>",
504
+ unk_token="<unk>",
505
+ language=None,
506
+ add_blank=True,
507
+ normalize=True,
508
+ phonemize=True,
509
+ is_uroman=False,
510
+ **kwargs,
511
+ ):
512
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
513
+ self.encoder = json.load(vocab_handle)
514
+
515
+ self.decoder = {v: k for k, v in self.encoder.items()}
516
+ self.language = language
517
+ self.add_blank = add_blank
518
+ self.normalize = normalize
519
+ self.phonemize = phonemize
520
+
521
+ self.is_uroman = is_uroman
522
+
523
+ super().__init__(
524
+ pad_token=pad_token,
525
+ unk_token=unk_token,
526
+ language=language,
527
+ add_blank=add_blank,
528
+ normalize=normalize,
529
+ phonemize=phonemize,
530
+ is_uroman=is_uroman,
531
+ **kwargs,
532
+ )
533
+
534
+ @property
535
+ def vocab_size(self):
536
+ return len(self.encoder)
537
+
538
+ def get_vocab(self):
539
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
540
+ vocab.update(self.added_tokens_encoder)
541
+ return vocab
542
+
543
+ def normalize_text(self, input_string):
544
+ """Lowercase the input string, respecting any special token ids that may be part or entirely upper-cased."""
545
+ all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
546
+ filtered_text = ""
547
+
548
+ i = 0
549
+ while i < len(input_string):
550
+ found_match = False
551
+ for word in all_vocabulary:
552
+ if input_string[i : i + len(word)] == word:
553
+ filtered_text += word
554
+ i += len(word)
555
+ found_match = True
556
+ break
557
+
558
+ if not found_match:
559
+ filtered_text += input_string[i].lower()
560
+ i += 1
561
+
562
+ return filtered_text
563
+
564
+ def _preprocess_char(self, text):
565
+ """Special treatment of characters in certain languages"""
566
+ if self.language == "ron":
567
+ text = text.replace("ț", "ţ")
568
+ return text
569
+
570
+ def prepare_for_tokenization(
571
+ self, text: str, is_split_into_words: bool = False, normalize = None, **kwargs):
572
+
573
+ normalize = normalize if normalize is not None else self.normalize
574
+
575
+ if normalize:
576
+ # normalise for casing
577
+ text = self.normalize_text(text)
578
+
579
+ filtered_text = self._preprocess_char(text)
580
+
581
+ if has_non_roman_characters(filtered_text) and self.is_uroman:
582
+ # 7 langs - For now replace all to romans in app.py
583
+ raise ValueError
584
+
585
+ if self.phonemize:
586
+ if not is_phonemizer_available():
587
+ raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")
588
+
589
+ filtered_text = phonemizer.phonemize(
590
+ filtered_text,
591
+ language="en-us",
592
+ backend="espeak",
593
+ strip=True,
594
+ preserve_punctuation=True,
595
+ with_stress=True,
596
+ )
597
+ filtered_text = re.sub(r"\s+", " ", filtered_text)
598
+ elif normalize:
599
+ # strip any chars outside of the vocab (punctuation)
600
+ filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()
601
+
602
+ return filtered_text, kwargs
603
+
604
+ def _tokenize(self, text):
605
+ """Tokenize a string by inserting the `<pad>` token at the boundary between adjacent characters."""
606
+ tokens = list(text)
607
+
608
+ if self.add_blank:
609
+ # sounds dyslexi if no space between letters
610
+ # sounds disconnected if >2 spaces between letters
611
+ interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2) # + 1) # +1 rises slice index error if tokens odd
612
+ interspersed[::2] = tokens
613
+ tokens = interspersed + [self._convert_id_to_token(0)] # append one last space (it has indexing error ::2 mismatch if tokens is odd)
614
+
615
+ return tokens
616
+
617
+ def _convert_token_to_id(self, token):
618
+ """Converts a token (str) in an id using the vocab."""
619
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
620
+
621
+ def _convert_id_to_token(self, index):
622
+ """Converts an index (integer) in a token (str) using the vocab."""
623
+ return self.decoder.get(index)