Migrated business logic away and replaced it with a FastAPI + Celery API endpoint
Files changed:
- .env +0 -2
- Dockerfile +0 -10
- holosubs.py +0 -119
- main.py +19 -9
- requirements.txt +3 -16
- transcribe.py +0 -78
- youtubeaudio.py +0 -51
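
The new main.py wires a Celery app to a Redis broker and registers a toy divide task. A quick end-to-end check might look like the sketch below (an assumption, not part of this commit; it presumes the Redis broker at BROKER_URL is reachable and a worker was started with `celery -A main worker --loglevel=info`):

    # Hypothetical smoke test for the new task queue (not part of this commit).
    from main import divide

    result = divide.delay(4, 2)    # enqueue the task on the Redis broker
    print(result.get(timeout=10))  # blocks until a worker returns 2.0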
.env
DELETED
@@ -1,2 +0,0 @@
-peft_model_id ="teoha/openai-whisper-medium-LORA-ja"
-install_location = "/tmp/elite_understanding"
Dockerfile
CHANGED
@@ -1,17 +1,7 @@
 FROM pytorch/pytorch
 
 WORKDIR /code
-RUN mkdir /.cache
-RUN chmod 1777 /.cache
 COPY ./requirements.txt /code/requirements.txt
-RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib/libcudart.so' >> ~/.bashrc
-RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
-RUN /opt/conda/bin/pip install peft
-RUN /opt/conda/bin/pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
-# Expose the secret SECRET_EXAMPLE at buildtime and use its value as git remote URL
-RUN --mount=type=secret,id=HUGGINGFACE_TOKEN,mode=0444,required=true \
-    huggingface-cli login --token $(cat /run/secrets/HUGGINGFACE_TOKEN) && \
-    echo "HUGGINGFACE_TOKEN=$( cat /run/secrets/HUGGINGFACE_TOKEN )" >> .env
 
 COPY . .
 
holosubs.py
DELETED
@@ -1,119 +0,0 @@
-""""
-Entry point and main execution block of the video transcription job
-"""
-import re
-
-from dotenv import load_dotenv
-from youtubeaudio import YoutubeAudio
-from transcribe import Transcriber
-import torchaudio
-from pyannote.audio import Pipeline
-from webvtt import WebVTT, Caption
-import torch
-import logging
-from huggingface_hub._login import _login
-import os
-
-load_dotenv()
-WHISPER_SAMPLE_RATE=16000
-TIMESTAMP_PATTERN='[0-9]+:[0-9]+:[0-9]+\.[0-9]+'
-MAX_CHUNK_DURATION=30000 # ms
-
-format = "%(asctime)s: %(message)s"
-logging.basicConfig(format=format, level=logging.DEBUG,
-                    datefmt="%H:%M:%S")
-_login(token=os.getenv('HUGGINGFACE_TOKEN'), add_to_git_credential=False)
-
-def get_video_vtt(url) -> str:
-    # Download wav file
-    ytaudio=YoutubeAudio(url)
-    ytaudio.download_audio()
-    # Load audio
-    audio, sample_rate = torchaudio.load(ytaudio.filename)
-    audio_dict={"waveform": audio, "sample_rate": sample_rate}
-    # Diarization
-    pipeline = Pipeline.from_pretrained('pyannote/speaker-diarization@2.1', use_auth_token=True)
-    dzs = pipeline(audio_dict)
-    groups = group_segments(str(dzs).splitlines())
-    # Preprocess audio segments for translation
-    audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=WHISPER_SAMPLE_RATE)
-    audio_segments, timestamps = get_segments(groups, audio)
-    # Decoding audio segments into subtitles
-    transcriber = Transcriber(task="translate")
-    captions = decode_segments(audio_segments, timestamps, transcriber)
-    vtt = create_vtt(captions)
-    ytaudio.clean()
-    return vtt.content
-
-def decode_segments(audio_segments, timestamps, transcriber):
-    captions = []
-    for i, segment in enumerate(audio_segments):
-        result = transcriber.decode(segment)
-        captions.append(Caption(timestamps[i][0], timestamps[i][1], result))
-        logging.info(f"Chunk output no.{i+1}: {result}")
-    return captions
-
-def millisec(timeStr):
-    spl = timeStr.split(":")
-    s = (int)((int(spl[0]) * 60 * 60 + int(spl[1]) * 60 + float(spl[2]) )* 1000)
-    return s
-
-def group_segments(dzs):
-    groups = []
-    g = []
-    lastend = 0
-
-    for d in dzs:
-        if g and (g[0].split()[-1] != d.split()[-1]): #same speaker
-            groups.append(g)
-            g = []
-
-        g.append(d)
-
-        end = re.findall('[0-9]+:[0-9]+:[0-9]+\.[0-9]+', string=d)[1]
-        end = millisec(end)
-        if (lastend > end): #segment engulfed by a previous segment
-            groups.append(g)
-            g = []
-        else:
-            lastend = end
-    if g:
-        groups.append(g)
-    logging.debug(groups)
-    return groups
-
-def create_vtt(captions):
-    vtt = WebVTT()
-    for caption in captions:
-        vtt.captions.append(caption)
-    return vtt
-    # vtt.save(path)
-
-def get_segments(groups, audio):
-    monoaudio=torch.mean(input=audio,dim=0).numpy()
-    audio_segments = []
-    timestamps = []
-    for g in groups:
-        cur_start_time, cur_end_time = re.findall(TIMESTAMP_PATTERN, string=g[0])
-        cur_start_millisec = millisec(cur_start_time) #- spacermilli
-        cur_end_millisec = millisec(cur_end_time) #- spacermilli
-        for window in g[1:]:
-            start_time, end_time = re.findall(TIMESTAMP_PATTERN, string=window)
-            start_millisec = millisec(start_time) #- spacermilli
-            end_millisec = millisec(end_time) #- spacermilli
-            # Check if new window exceeds chunk size
-            seg_duration_with_window=end_millisec-cur_start_millisec
-            if seg_duration_with_window>MAX_CHUNK_DURATION: # Segment with window exceeds max chunk duration
-                start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
-                audio_segments.append(monoaudio[start_frame:end_frame])
-                timestamps.append((cur_start_time, cur_end_time))
-                cur_start_time, cur_end_time = start_time, end_time
-                cur_start_millisec, cur_end_millisec = start_millisec, end_millisec
-            else:
-                cur_end_time=end_time
-                cur_end_millisec=end_millisec
-        # Final update
-        start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
-        audio_segments.append(monoaudio[start_frame:end_frame])
-        timestamps.append((cur_start_time, cur_end_time))
-    return audio_segments, timestamps
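
The deleted get_segments() above packed consecutive diarization windows into chunks of at most MAX_CHUNK_DURATION and converted millisecond timestamps into sample indices. A standalone sketch of that conversion (not in the commit), using the same constants as holosubs.py:

    # At 16 kHz, one millisecond of audio spans 16 samples, so a
    # millisecond timestamp maps to a waveform index like this:
    WHISPER_SAMPLE_RATE = 16000
    MAX_CHUNK_DURATION = 30000  # ms, matching Whisper's 30-second input window

    def to_frame(ms: int) -> int:
        return ms * WHISPER_SAMPLE_RATE // 1000

    print(to_frame(3500))                # 0:00:03.500 -> sample index 56000
    print(to_frame(MAX_CHUNK_DURATION))  # one full chunk -> 480000 samples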
main.py
CHANGED
@@ -1,14 +1,24 @@
 from fastapi import FastAPI
-from
-from pydantic import BaseModel
-
-    url: str
-
-@
-def
-
-
+from celery import Celery
+
+app = FastAPI()
+
+BROKER_URL = 'redis://139.59.127.180:6379/0'
+BACKEND_URL = 'redis://139.59.127.180:6379/0'
+
+celery = Celery(
+    __name__,
+    broker=BROKER_URL,
+    backend=BACKEND_URL
+)
+
+
+@app.get("/")
+async def root():
+    return {"message": "Hello World"}
+
+@celery.task
+def divide(x, y):
+    import time
+    time.sleep(5)
+    return x / y
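
The committed main.py only exposes a hello-world route next to the divide task. A plausible follow-up (an assumption, not part of this diff; route names and payloads are illustrative only) is to submit jobs over HTTP and poll for results through Celery's AsyncResult:

    # Hypothetical endpoints building on main.py (not in this commit).
    from celery.result import AsyncResult

    from main import app, celery, divide

    @app.post("/divide")
    async def enqueue_divide(x: float, y: float):
        task = divide.delay(x, y)        # hand the work to a Celery worker
        return {"task_id": task.id}

    @app.get("/result/{task_id}")
    async def read_result(task_id: str):
        res = AsyncResult(task_id, app=celery)
        if not res.ready():
            return {"status": res.state}  # e.g. "PENDING" or "STARTED"
        return {"status": res.state, "result": res.get()}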
requirements.txt
CHANGED
@@ -1,16 +1,3 @@
-
-
-
-uvicorn[standard]==0.17.*
-numpy==1.24.4
-pyannote.audio==1.1.2
-pyannote.core==5.0.0
-pyannote.database==5.0.1
-pyannote.metrics==3.2.1
-pyannote.pipeline==1.5.2
-python-dotenv==1.0.0
-torch==2.0.1
-torchaudio==2.0.2
-transformers==4.31.0
-webvtt_py==0.4.6
-yt_dlp==2023.7.6
+celery==5.1.2
+fastapi==0.103.2
+pydantic==1.10.12
transcribe.py
DELETED
@@ -1,78 +0,0 @@
-"""
-Represents a model that transcribes and translates audio.
-"""
-
-import logging
-import os
-from typing import Union
-
-import numpy as np
-import torch
-from dotenv import load_dotenv
-from peft import PeftConfig, PeftModel
-from transformers import (AutomaticSpeechRecognitionPipeline,
-                          WhisperForConditionalGeneration, WhisperProcessor,
-                          WhisperTokenizer)
-
-load_dotenv()
-format = "%(asctime)s: %(message)s"
-logging.basicConfig(format=format, level=logging.DEBUG,
-                    datefmt="%H:%M:%S")
-
-class Transcriber:
-    def __init__(self, model_id="teoha/openai-whisper-medium-LORA-ja", language="Japanese", task="translate"):
-        self.language=language
-        self.task=task
-        peft_model_id = model_id if model_id else os.getenv('peft_model_id')
-        # TODO: Fix Download and install model locally
-        # self.install_model(peft_model_id)
-        self.initialize_pipe(peft_model_id) #initialize pipe
-
-    def install_model(self, peft_model_id:str) -> None:
-        save_location = os.path.join(os.getenv('install_location'), peft_model_id)
-        offload_location = os.path.join(os.getenv('install_location'), "offload")
-        #Save Model
-        peft_config = PeftConfig.from_pretrained(peft_model_id)
-        model = WhisperForConditionalGeneration.from_pretrained(
-            peft_config.base_model_name_or_path,
-            load_in_8bit=False, device_map="auto"
-        )
-        model = PeftModel.from_pretrained(model, peft_model_id, offload_folder="offload_location")
-        model.save_pretrained(save_location)
-
-        #Save tokenizer/processor
-        tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
-        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
-        tokenizer.save_pretrained(save_location)
-        processor.save_pretrained(save_location)
-        logging.info("Installation Completed successfully")
-
-    def initialize_pipe(self, peft_model_id: str) -> None:
-        offload_location = os.path.join(os.getenv('install_location'), "offload")
-        # Initalize model configs
-        peft_config = PeftConfig.from_pretrained(peft_model_id)
-        model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, load_in_8bit=False, device_map="auto")
-        model = PeftModel.from_pretrained(model, peft_model_id, offload_folder=offload_location)
-        tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
-        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
-        feature_extractor = processor.feature_extractor
-        # Initialize class variables
-        self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=self.language, task=self.task)
-        self.pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
-        logging.info("Pipe successfully initialized")
-
-
-    def decode(self, audio: Union[np.ndarray, bytes, str]) -> str:
-        '''
-        Transcribes a sequence of floats representing an audio snippet.
-        Args:
-            inputs (:obj:`np.ndarray` or :obj:`bytes` or :obj:`str`):
-                The inputs is either a raw waveform (:obj:`np.ndarray` of shape (n, ) of type :obj:`np.float32` or
-                :obj:`np.float64`) at the correct sampling rate (no further check will be done) or a :obj:`str` that is
-                the filename of the audio file, the file will be read at the correct sampling rate to get the waveform
-                using `ffmpeg`. This requires `ffmpeg` to be installed on the system. If `inputs` is :obj:`bytes` it is
-                supposed to be the content of an audio file and is interpreted by `ffmpeg` in the same way.
-        '''
-        with torch.cuda.amp.autocast():
-            text = self.pipe(audio, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids})["text"]
-        return text
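
For reference, the removed Transcriber exposed decode() as its single entry point. Typical usage looked like the sketch below (an assumption based on the deleted code, presuming the LoRA checkpoint and its dependencies were installed; the silent waveform is a stand-in for real 16 kHz audio):

    # Sketch of how holosubs.py drove the now-deleted Transcriber.
    import numpy as np
    from transcribe import Transcriber

    transcriber = Transcriber(task="translate")       # base Whisper + LoRA adapter
    waveform = np.zeros(16000 * 5, dtype=np.float32)  # 5 s of audio at 16 kHz
    print(transcriber.decode(waveform))               # English text for Japanese speech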
youtubeaudio.py
DELETED
@@ -1,51 +0,0 @@
-"""
-Represents a Youtube video
-"""
-
-from dotenv import load_dotenv
-import logging
-from yt_dlp import YoutubeDL
-import os
-from pathlib import Path
-
-load_dotenv()
-format = "%(asctime)s: %(message)s"
-logging.basicConfig(format=format, level=logging.DEBUG,
-                    datefmt="%H:%M:%S")
-
-class YoutubeAudio:
-    def __init__(self, url, dir="/tmp/holosubs/audio"):
-        self.url=url
-        self.dir=dir
-
-    def download_audio(self):
-        ydl_opts = {
-            'outtmpl': os.path.join(self.dir, "%(id)s_%(epoch)s.%(ext)s"),
-            'logger': logging,
-            'progress_hooks': [self.progress_hook],
-            'format': 'm4a/bestaudio/best',
-            'postprocessors': [{ # Extract audio using ffmpeg
-                'key': 'FFmpegExtractAudio',
-                'preferredcodec': 'wav',
-            }]
-        }
-        with YoutubeDL(ydl_opts) as ydl:
-            error_code = ydl.download([self.url])
-
-    def clean(self):
-        if not self.filename:
-            logging.error("Audio not downloaded")
-            return
-        location=os.path.join(self.dir, self.filename)
-        if os.path.exists(self.filename):
-            os.remove(self.filename)
-            logging.info(f"File {self.filename} successfully removed")
-            self.filename=None
-        else:
-            print(f"File {self.filename} does not exist")
-
-    def progress_hook(self, d):
-        if d['status'] == 'finished':
-            self.filename=os.path.join(self.dir, Path(d.get('info_dict').get('_filename')).stem + ".wav")
-            print(f'Done downloading {self.filename}, now post-processing ...')
-