teoha committed on
Commit e953bef · 1 Parent(s): 71cd364

Migrated business logic out and replaced it with a FastAPI + Celery API endpoint

Files changed (7)
  1. .env +0 -2
  2. Dockerfile +0 -10
  3. holosubs.py +0 -119
  4. main.py +19 -9
  5. requirements.txt +3 -16
  6. transcribe.py +0 -78
  7. youtubeaudio.py +0 -51
.env DELETED
@@ -1,2 +0,0 @@
- peft_model_id ="teoha/openai-whisper-medium-LORA-ja"
- install_location = "/tmp/elite_understanding"

Dockerfile CHANGED
@@ -1,17 +1,7 @@
  FROM pytorch/pytorch

  WORKDIR /code
- RUN mkdir /.cache
- RUN chmod 1777 /.cache
  COPY ./requirements.txt /code/requirements.txt
- RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib/libcudart.so' >> ~/.bashrc
- RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
- RUN /opt/conda/bin/pip install peft
- RUN /opt/conda/bin/pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
- # Expose the secret SECRET_EXAMPLE at buildtime and use its value as git remote URL
- RUN --mount=type=secret,id=HUGGINGFACE_TOKEN,mode=0444,required=true \
-     huggingface-cli login --token $(cat /run/secrets/HUGGINGFACE_TOKEN) && \
-     echo "HUGGINGFACE_TOKEN=$( cat /run/secrets/HUGGINGFACE_TOKEN )" >> .env

  COPY . .

holosubs.py DELETED
@@ -1,119 +0,0 @@
- """"
- Entry point and main execution block of the video transcription job
- """
- import re
-
- from dotenv import load_dotenv
- from youtubeaudio import YoutubeAudio
- from transcribe import Transcriber
- import torchaudio
- from pyannote.audio import Pipeline
- from webvtt import WebVTT, Caption
- import torch
- import logging
- from huggingface_hub._login import _login
- import os
-
- load_dotenv()
- WHISPER_SAMPLE_RATE=16000
- TIMESTAMP_PATTERN='[0-9]+:[0-9]+:[0-9]+\.[0-9]+'
- MAX_CHUNK_DURATION=30000 # ms
-
- format = "%(asctime)s: %(message)s"
- logging.basicConfig(format=format, level=logging.DEBUG,
-                     datefmt="%H:%M:%S")
- _login(token=os.getenv('HUGGINGFACE_TOKEN'), add_to_git_credential=False)
-
- def get_video_vtt(url) -> str:
-     # Download wav file
-     ytaudio=YoutubeAudio(url)
-     ytaudio.download_audio()
-     # Load audio
-     audio, sample_rate = torchaudio.load(ytaudio.filename)
-     audio_dict={"waveform": audio, "sample_rate": sample_rate}
-     # Diarization
-     pipeline = Pipeline.from_pretrained('pyannote/speaker-diarization@2.1', use_auth_token=True)
-     dzs = pipeline(audio_dict)
-     groups = group_segments(str(dzs).splitlines())
-     # Preprocess audio segments for translation
-     audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=WHISPER_SAMPLE_RATE)
-     audio_segments, timestamps = get_segments(groups, audio)
-     # Decoding audio segments into subtitles
-     transcriber = Transcriber(task="translate")
-     captions = decode_segments(audio_segments, timestamps, transcriber)
-     vtt = create_vtt(captions)
-     ytaudio.clean()
-     return vtt.content
-
- def decode_segments(audio_segments, timestamps, transcriber):
-     captions = []
-     for i, segment in enumerate(audio_segments):
-         result = transcriber.decode(segment)
-         captions.append(Caption(timestamps[i][0], timestamps[i][1], result))
-         logging.info(f"Chunk output no.{i+1}: {result}")
-     return captions
-
- def millisec(timeStr):
-     spl = timeStr.split(":")
-     s = (int)((int(spl[0]) * 60 * 60 + int(spl[1]) * 60 + float(spl[2]) )* 1000)
-     return s
-
- def group_segments(dzs):
-     groups = []
-     g = []
-     lastend = 0
-
-     for d in dzs:
-         if g and (g[0].split()[-1] != d.split()[-1]): #same speaker
-             groups.append(g)
-             g = []
-
-         g.append(d)
-
-         end = re.findall('[0-9]+:[0-9]+:[0-9]+\.[0-9]+', string=d)[1]
-         end = millisec(end)
-         if (lastend > end): #segment engulfed by a previous segment
-             groups.append(g)
-             g = []
-         else:
-             lastend = end
-     if g:
-         groups.append(g)
-     logging.debug(groups)
-     return groups
-
- def create_vtt(captions):
-     vtt = WebVTT()
-     for caption in captions:
-         vtt.captions.append(caption)
-     return vtt
-     # vtt.save(path)
-
- def get_segments(groups, audio):
-     monoaudio=torch.mean(input=audio,dim=0).numpy()
-     audio_segments = []
-     timestamps = []
-     for g in groups:
-         cur_start_time, cur_end_time = re.findall(TIMESTAMP_PATTERN, string=g[0])
-         cur_start_millisec = millisec(cur_start_time) #- spacermilli
-         cur_end_millisec = millisec(cur_end_time) #- spacermilli
-         for window in g[1:]:
-             start_time, end_time = re.findall(TIMESTAMP_PATTERN, string=window)
-             start_millisec = millisec(start_time) #- spacermilli
-             end_millisec = millisec(end_time) #- spacermilli
-             # Check if new window exceeds chunk size
-             seg_duration_with_window=end_millisec-cur_start_millisec
-             if seg_duration_with_window>MAX_CHUNK_DURATION: # Segment with window exceeds max chunk duration
-                 start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
-                 audio_segments.append(monoaudio[start_frame:end_frame])
-                 timestamps.append((cur_start_time, cur_end_time))
-                 cur_start_time, cur_end_time = start_time, end_time
-                 cur_start_millisec, cur_end_millisec = start_millisec, end_millisec
-             else:
-                 cur_end_time=end_time
-                 cur_end_millisec=end_millisec
-         # Final update
-         start_frame, end_frame = cur_start_millisec*WHISPER_SAMPLE_RATE//1000, cur_end_millisec*WHISPER_SAMPLE_RATE//1000
-         audio_segments.append(monoaudio[start_frame:end_frame])
-         timestamps.append((cur_start_time, cur_end_time))
-     return audio_segments, timestamps

main.py CHANGED
@@ -1,14 +1,24 @@
  from fastapi import FastAPI
- from holosubs import get_video_vtt
- from pydantic import BaseModel
+ from celery import Celery

- class Url(BaseModel):
-     url: str
-
-
- app = FastAPI()
-
- @app.post("/captions/")
- def read_root(url: Url):
-     vtt_captions = get_video_vtt(url.url)
-     return {"captions": vtt_captions}
+ app = FastAPI()
+
+ BROKER_URL = 'redis://139.59.127.180:6379/0'
+ BACKEND_URL = 'redis://139.59.127.180:6379/0'
+
+ celery = Celery(
+     __name__,
+     broker=BROKER_URL,
+     backend=BACKEND_URL
+ )
+
+
+ @app.get("/")
+ async def root():
+     return {"message": "Hello World"}
+
+ @celery.task
+ def divide(x, y):
+     import time
+     time.sleep(5)
+     return x / y

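For context, a minimal sketch (not part of this commit) of how the new `divide` task could be exposed through the FastAPI app. The `/divide` and `/tasks/{task_id}` routes and their handler names are hypothetical, and the sketch assumes the Redis broker configured above is reachable and a Celery worker is running:

# Hypothetical routes: enqueue the divide task and poll for its result.
from celery.result import AsyncResult

@app.post("/divide")
async def enqueue_divide(x: float, y: float):
    # .delay() pushes the task onto the broker and returns immediately
    task = divide.delay(x, y)
    return {"task_id": task.id}

@app.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
    # Look up the task in the result backend without blocking the request
    result = AsyncResult(task_id, app=celery)
    return {
        "task_id": task_id,
        "status": result.status,
        "result": result.result if result.ready() else None,
    }

This is the usual FastAPI + Celery pattern: the request returns a task id right away and the client polls for the result, so long-running work never blocks the web process.
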
requirements.txt CHANGED
@@ -1,16 +1,3 @@
- fastapi==0.74.*
- requests==2.27.*
- sentencepiece==0.1.*
- uvicorn[standard]==0.17.*
- numpy==1.24.4
- pyannote.audio==1.1.2
- pyannote.core==5.0.0
- pyannote.database==5.0.1
- pyannote.metrics==3.2.1
- pyannote.pipeline==1.5.2
- python-dotenv==1.0.0
- torch==2.0.1
- torchaudio==2.0.2
- transformers==4.31.0
- webvtt_py==0.4.6
- yt_dlp==2023.7.6
+ celery==5.1.2
+ fastapi==0.103.2
+ pydantic==1.10.12

transcribe.py DELETED
@@ -1,78 +0,0 @@
- """
- Represents a model that transcribes and translates audio.
- """
-
- import logging
- import os
- from typing import Union
-
- import numpy as np
- import torch
- from dotenv import load_dotenv
- from peft import PeftConfig, PeftModel
- from transformers import (AutomaticSpeechRecognitionPipeline,
-                           WhisperForConditionalGeneration, WhisperProcessor,
-                           WhisperTokenizer)
-
- load_dotenv()
- format = "%(asctime)s: %(message)s"
- logging.basicConfig(format=format, level=logging.DEBUG,
-                     datefmt="%H:%M:%S")
-
- class Transcriber:
-     def __init__(self, model_id="teoha/openai-whisper-medium-LORA-ja", language="Japanese", task="translate"):
-         self.language=language
-         self.task=task
-         peft_model_id = model_id if model_id else os.getenv('peft_model_id')
-         # TODO: Fix Download and install model locally
-         # self.install_model(peft_model_id)
-         self.initialize_pipe(peft_model_id) #initialize pipe
-
-     def install_model(self, peft_model_id:str) -> None:
-         save_location = os.path.join(os.getenv('install_location'), peft_model_id)
-         offload_location = os.path.join(os.getenv('install_location'), "offload")
-         #Save Model
-         peft_config = PeftConfig.from_pretrained(peft_model_id)
-         model = WhisperForConditionalGeneration.from_pretrained(
-             peft_config.base_model_name_or_path,
-             load_in_8bit=False, device_map="auto"
-         )
-         model = PeftModel.from_pretrained(model, peft_model_id, offload_folder="offload_location")
-         model.save_pretrained(save_location)
-
-         #Save tokenizer/processor
-         tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
-         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
-         tokenizer.save_pretrained(save_location)
-         processor.save_pretrained(save_location)
-         logging.info("Installation Completed successfully")
-
-     def initialize_pipe(self, peft_model_id: str) -> None:
-         offload_location = os.path.join(os.getenv('install_location'), "offload")
-         # Initalize model configs
-         peft_config = PeftConfig.from_pretrained(peft_model_id)
-         model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, load_in_8bit=False, device_map="auto")
-         model = PeftModel.from_pretrained(model, peft_model_id, offload_folder=offload_location)
-         tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
-         processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=self.language, task=self.task)
-         feature_extractor = processor.feature_extractor
-         # Initialize class variables
-         self.forced_decoder_ids = processor.get_decoder_prompt_ids(language=self.language, task=self.task)
-         self.pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
-         logging.info("Pipe successfully initialized")
-
-
-     def decode(self, audio: Union[np.ndarray, bytes, str]) -> str:
-         '''
-         Transcribes a sequence of floats representing an audio snippet.
-         Args:
-             inputs (:obj:`np.ndarray` or :obj:`bytes` or :obj:`str`):
-                 The inputs is either a raw waveform (:obj:`np.ndarray` of shape (n, ) of type :obj:`np.float32` or
-                 :obj:`np.float64`) at the correct sampling rate (no further check will be done) or a :obj:`str` that is
-                 the filename of the audio file, the file will be read at the correct sampling rate to get the waveform
-                 using `ffmpeg`. This requires `ffmpeg` to be installed on the system. If `inputs` is :obj:`bytes` it is
-                 supposed to be the content of an audio file and is interpreted by `ffmpeg` in the same way.
-         '''
-         with torch.cuda.amp.autocast():
-             text = self.pipe(audio, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids})["text"]
-             return text

youtubeaudio.py DELETED
@@ -1,51 +0,0 @@
1
- """
2
- Represents a Youtube video
3
- """
4
-
5
- from dotenv import load_dotenv
6
- import logging
7
- from yt_dlp import YoutubeDL
8
- import os
9
- from pathlib import Path
10
-
11
- load_dotenv()
12
- format = "%(asctime)s: %(message)s"
13
- logging.basicConfig(format=format, level=logging.DEBUG,
14
- datefmt="%H:%M:%S")
15
-
16
- class YoutubeAudio:
17
- def __init__(self, url, dir="/tmp/holosubs/audio"):
18
- self.url=url
19
- self.dir=dir
20
-
21
- def download_audio(self):
22
- ydl_opts = {
23
- 'outtmpl': os.path.join(self.dir, "%(id)s_%(epoch)s.%(ext)s"),
24
- 'logger': logging,
25
- 'progress_hooks': [self.progress_hook],
26
- 'format': 'm4a/bestaudio/best',
27
- 'postprocessors': [{ # Extract audio using ffmpeg
28
- 'key': 'FFmpegExtractAudio',
29
- 'preferredcodec': 'wav',
30
- }]
31
- }
32
- with YoutubeDL(ydl_opts) as ydl:
33
- error_code = ydl.download([self.url])
34
-
35
- def clean(self):
36
- if not self.filename:
37
- logging.error("Audio not downloaded")
38
- return
39
- location=os.path.join(self.dir, self.filename)
40
- if os.path.exists(self.filename):
41
- os.remove(self.filename)
42
- logging.info(f"File {self.filename} successfully removed")
43
- self.filename=None
44
- else:
45
- print(f"File {self.filename} does not exist")
46
-
47
- def progress_hook(self, d):
48
- if d['status'] == 'finished':
49
- self.filename=os.path.join(self.dir, Path(d.get('info_dict').get('_filename')).stem + ".wav")
50
- print(f'Done downloading {self.filename}, now post-processing ...')
51
-