# Hugging Face Space: "Running on Zero" (ZeroGPU hardware)
import gradio as gr
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, pipeline
import numpy as np
import io
import tempfile
import os
from datetime import datetime
import spaces
import threading
import queue
import time
from huggingface_hub import login
# Get HuggingFace token from environment (Spaces secret)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Global variables to store loaded models to avoid reloading
loaded_whisper_model = None
loaded_whisper_processor = None
loaded_whisper_model_name = None
loaded_llm_model = None
loaded_llm_tokenizer = None
loaded_llm_model_name = None
def chunk_audio(audio, chunk_length_seconds=30, overlap_seconds=5, sample_rate=16000):
    """Split audio into overlapping chunks for better transcription"""
    chunk_length_samples = chunk_length_seconds * sample_rate
    overlap_samples = overlap_seconds * sample_rate
    chunks = []
    start = 0
    while start < len(audio):
        end = min(start + chunk_length_samples, len(audio))
        chunk = audio[start:end]
        chunks.append((chunk, start / sample_rate))  # Include timestamp
        if end >= len(audio):
            break
        start += chunk_length_samples - overlap_samples
    return chunks
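
# Illustrative example of the chunking arithmetic above: with chunk_length_seconds=30 and
# overlap_seconds=2 (the values used by transcribe_audio_streaming below) the window
# advances 28 s per step, so a 70 s clip at 16 kHz (1,120,000 samples) yields three chunks
# with start timestamps 0.0 s, 28.0 s and 56.0 s, covering 0-30 s, 28-58 s and 56-70 s.
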
def transcribe_audio_streaming(audio_file, model_name, progress=gr.Progress()):
    """Transcribe audio using Whisper model with streaming output"""
    global loaded_whisper_model, loaded_whisper_processor, loaded_whisper_model_name
    try:
        if audio_file is None:
            yield "No audio file provided"
            return
        progress(0.05, desc="Processing audio input...")
        # Load and process audio
        if isinstance(audio_file, str):
            waveform, sample_rate = torchaudio.load(audio_file)
        else:
            sample_rate, audio_data = audio_file
            if isinstance(audio_data, np.ndarray):
                if len(audio_data.shape) > 1:
                    audio_data = audio_data.mean(axis=1)  # Convert to mono
                if audio_data.dtype == np.int16:
                    audio_data = audio_data.astype(np.float32) / 32768.0
                elif audio_data.dtype == np.int32:
                    audio_data = audio_data.astype(np.float32) / 2147483648.0
                waveform = torch.from_numpy(audio_data).unsqueeze(0)
            else:
                waveform = torch.tensor(audio_data).unsqueeze(0)
        # Resample if necessary
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
        # Ensure mono audio
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        audio_length = waveform.shape[1] / 16000  # Length in seconds
        progress(0.1, desc=f"Audio loaded: {audio_length:.1f} seconds")
        # Load model if not already loaded or if model changed
        if (loaded_whisper_model is None or
                loaded_whisper_processor is None or
                loaded_whisper_model_name != model_name):
            progress(0.15, desc="Loading Whisper model...")
            loaded_whisper_processor = WhisperProcessor.from_pretrained(model_name)
            loaded_whisper_model = WhisperForConditionalGeneration.from_pretrained(model_name)
            loaded_whisper_model_name = model_name
            # Move to GPU if available
            device = "cuda" if torch.cuda.is_available() else "cpu"
            loaded_whisper_model = loaded_whisper_model.to(device)
            progress(0.25, desc=f"Model loaded on {device}")
        device = next(loaded_whisper_model.parameters()).device
        # Split audio into chunks for long audio files
        audio_numpy = waveform.squeeze().numpy()
        chunks = chunk_audio(audio_numpy, chunk_length_seconds=30, overlap_seconds=2)
        progress(0.3, desc=f"Processing {len(chunks)} audio chunks...")
        full_transcription = ""
        overlap_words = []
        for i, (chunk, timestamp) in enumerate(chunks):
            chunk_progress = 0.3 + (0.6 * i / len(chunks))
            progress(chunk_progress, desc=f"Transcribing chunk {i+1}/{len(chunks)} (at {timestamp:.1f}s)")
            # Process chunk
            input_features = loaded_whisper_processor(
                chunk,
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features
            input_features = input_features.to(device)
            # Create prompt tokens for medical context
prompt = """ Endocrinology clinic visit transcript. | |
Previous discussion mentioned levothyroxine dose adjustment for Hashimotoβs thyroiditis and an Hemoglobin A1C review. | |
Terms we frequently use: levothyroxine, Synthroid, Cytomel, methimazole, propylthiouracil, PTU semaglutide, Ozempic, Mounjaro, insulin glargine, Lantus, Tresiba, Toujeo, Metformin, Glipizide, Januvia, Jardiance, Atorvastatin, Rosuvastatin, Pravastatin. | |
Common labs and hormones: TSH, free T4, free T3, HbA1C, cortisol, ACTH, parathyroid hormone, LDL, microalbumin. | |
Typical procedures: thyroid ultrasound, fine-needle aspiration biopsy, FNA, medtronics pump, Tandem pump, OmniPod pump, radioactive iodine ablation, continuous glucose monitoring, CGM, DEXA scan.""" | |
            prompt_ids = loaded_whisper_processor.get_prompt_ids(prompt, return_tensors="pt").to(device)
            # Generate transcription for chunk
            with torch.no_grad():
                predicted_ids = loaded_whisper_model.generate(
                    input_features,
                    max_length=448,
                    num_beams=1,  # Faster generation
                    prompt_ids=prompt_ids,
                    do_sample=False
                )
            # Decode chunk transcription
            chunk_transcription = loaded_whisper_processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0].strip()
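            # Illustrative example of the overlap heuristic below: if the previous chunk
            # ended "... will increase the levothyroxine dose" (so overlap_words[-3:] is
            # ["the", "levothyroxine", "dose"]) and this chunk begins
            # "increase the levothyroxine dose to 100 micrograms ...", the matching
            # three-word window is found within the first 10 words and everything up to
            # and including it is dropped, leaving "to 100 micrograms ...".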
            # Handle overlap removal for better continuity
            if i > 0 and overlap_words:
                chunk_words = chunk_transcription.split()
                # Remove potential overlap by checking first few words
                for j in range(min(10, len(chunk_words))):
                    if chunk_words[j:j+3] == overlap_words[-3:]:
                        chunk_transcription = " ".join(chunk_words[j+3:])
                        break
            # Store last few words for next chunk overlap detection
            chunk_words = chunk_transcription.split()
            overlap_words = chunk_words[-5:] if len(chunk_words) >= 5 else chunk_words
            # Add to full transcription
            if chunk_transcription:
                if full_transcription and not full_transcription.endswith(" "):
                    full_transcription += " "
                full_transcription += chunk_transcription
            # Yield intermediate result for streaming
            yield full_transcription
            time.sleep(0.1)  # Small delay for visual streaming effect

        progress(0.95, desc="Finalizing transcription...")
        # Final cleanup
        full_transcription = full_transcription.strip()
        if not full_transcription:
            full_transcription = "No speech detected in the audio."
        progress(1.0, desc="Transcription complete!")
        yield full_transcription
    except Exception as e:
        yield f"Error during transcription: {str(e)}"
def generate_medical_note(transcription, model_name, note_type, progress=gr.Progress()):
    """Generate medical note using selected LLM with proper chat template formatting"""
    global loaded_llm_model, loaded_llm_tokenizer, loaded_llm_model_name
    try:
        if not transcription or transcription.strip() == "":
            return "No transcription provided"
        progress(0.1, desc="Loading medical note model...")
        # Load model if not already loaded or if model changed
        if (loaded_llm_model is None or
                loaded_llm_tokenizer is None or
                loaded_llm_model_name != model_name):
            # Load tokenizer and model with token if needed
            model_kwargs = {
                "trust_remote_code": True,
                "torch_dtype": torch.float16,
                "device_map": "auto"
            }
            tokenizer_kwargs = {
                "trust_remote_code": True
            }
            # Add token if available for private models
            if HF_TOKEN:
                model_kwargs["token"] = HF_TOKEN  # Updated parameter name
                tokenizer_kwargs["token"] = HF_TOKEN
            loaded_llm_tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
            loaded_llm_model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
            loaded_llm_model_name = model_name
            # Add padding token if not present
            if loaded_llm_tokenizer.pad_token is None:
                loaded_llm_tokenizer.pad_token = loaded_llm_tokenizer.eos_token
        progress(0.3, desc="Preparing prompt...")
        # Define system prompts based on note type (matching your local code)
        if note_type == "SOAP note":
            system_prompt = """You are an expert medical professor assisting in the creation of medically accurate SOAP summaries.
Please ensure the response follows the structured format: S:, O:, A:, P: without using markdown or special formatting.
Create a Medical SOAP note summary from the dialogue, following these guidelines:
S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history.
Rely on the patient's statements as the primary source and ensure standardized terminology.
O (Objective): Highlight critical findings such as vital signs, lab results, and imaging, emphasizing important details like the side of the body affected and specific dosages.
Include normal ranges where relevant.
A (Assessment): Offer a concise assessment combining subjective and objective data. State the primary diagnosis and any differential diagnoses, noting potential complications and the prognostic outlook.
P (Plan): Outline the management plan, covering medication, diet, consultations, and education. Ensure to mention necessary referrals to other specialties and address compliance challenges.
Considerations: Compile the report based solely on the transcript provided. Use concise medical jargon and abbreviations for effective doctor communication.
Please format the summary in a clean, simple list format without using markdown or bullet points. Use 'S:', 'O:', 'A:', 'P:' directly followed by the text. Avoid any styling or special characters."""
            prompt_prefix = '''Convert the following medical transcript to a SOAP note. Transcript: \n'''
        else:
            # Structured note system prompt (from your local code)
system_prompt = """Your an expert medical transcriptionist. You convert medical transcript to a structured medical note with these sections in this order: | |
1. Presenting Illness | |
(Bullet point statements of the main problem) | |
2. History of Presenting Illness | |
(Chronological narrative: symptom onset, progression, modifiers, associated factors) | |
3. Past Medical History | |
(List chronic illnesses and past medical diagnoses mentioned in the transcript. Do not include surgeries) | |
4. Surgical History | |
(List prior surgeries with year if known mentioned in the transcript) | |
5. Family History | |
(Relevant family history mentioned in the transcript) | |
6. Social History | |
(Occupation, tobacco/alcohol/drug use, exercise, living situation if mentioned in the transcript) | |
7. Allergy History | |
(Drug/food/environmental allergies + reactions - if mentioned in the transcript) | |
8. Medication History | |
(List medications the patient is already taking. Do not place any medication the patient is currently not taking.) | |
9. Dietary History | |
("Not applicable" if unrelated, otherwise summarize diet pattern) | |
10. Review of Systems | |
(Head-to-toe -ordered bullets; note positives and pertinent negatives- mentioned in the transcript) | |
11. Physical Exam Findings | |
Vital Signs (BP, HR, RR, Temp, SpOβ, HT, WT, BMI) - if mentioned in the transcript | |
(Structured by system: General, HEENT, CV, Resp, Abd, Neuro, MSK, Skin, Psych) - if mentioned in the transcript | |
12. Labs and Imaging | |
(labs, imaging results) | |
13. Assessment and Plan | |
(List each diagnoses and treatment plan. No other information needed in this section.Do not generate new diagnoses)""" | |
            prompt_prefix = '''Convert the following medical transcript to a structured medical note. Transcript: \n'''
        # Create the user prompt
        user_prompt = prompt_prefix + transcription
        # Format using chat template (matching your local approach)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        progress(0.4, desc="Applying chat template...")
        # Apply chat template
        input_ids = loaded_llm_tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,  # Adds the assistant turn prompt
            truncation=True,
            max_length=4096  # Adjust based on model's context window
        ).to(loaded_llm_model.device)
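        # For a Llama-3-style chat template (an assumption; the exact tokens depend on the
        # selected model's tokenizer), the rendered prompt looks roughly like:
        #   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
        #   {system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
        #   {user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
        # which is why the decoding step below looks for the assistant header marker.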
        progress(0.5, desc="Generating medical note...")
        # Generate with parameters matching your local setup
        with torch.no_grad():
            outputs = loaded_llm_model.generate(
                input_ids,
                max_new_tokens=4096,  # Matching your local setup
                do_sample=True,  # Matching your local setup
                temperature=0.01,  # Matching your local setup (very low temperature)
                top_p=0.95,  # Matching your local setup
                num_return_sequences=1,
                eos_token_id=loaded_llm_tokenizer.eos_token_id,
                pad_token_id=loaded_llm_tokenizer.pad_token_id,
                use_cache=True  # Matching your local setup
            )
        progress(0.9, desc="Processing response...")
        # Decode the generated text (matching your local decoding logic)
        decoded_output_full = loaded_llm_tokenizer.decode(outputs[0], skip_special_tokens=False)
        # Extract generated text using the same logic as your local code
        assistant_prompt_end_marker = "<|start_header_id|>assistant<|end_header_id|>"
        assistant_prompt_end_index = decoded_output_full.rfind(assistant_prompt_end_marker)
        if assistant_prompt_end_index != -1:
            # Find the newline character immediately following the assistant prompt marker
            start_of_generation = decoded_output_full.find('\n', assistant_prompt_end_index + len(assistant_prompt_end_marker))
            if start_of_generation != -1:
                # Extract text after the newline, skipping special tokens
                medical_note = loaded_llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
            else:
                # Fallback if structure is unexpected
                medical_note = loaded_llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
        else:
            # Fallback if assistant prompt marker is not found
            medical_note = loaded_llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
        # Clean up the response
        if not medical_note:
            medical_note = "Unable to generate medical note. Please check the transcription and try again."
        progress(1.0, desc="Complete!")
        return medical_note
    except Exception as e:
        return f"Error generating medical note: {str(e)}"
# Global variable to store original audio for download
original_audio_data = None

def store_original_audio(audio_data):
    """Store the original audio data for download"""
    global original_audio_data
    original_audio_data = audio_data
    return audio_data

def save_original_audio():
    """Save the stored original audio without any processing"""
    global original_audio_data
    if original_audio_data is None:
        return None
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        sample_rate, audio_array = original_audio_data
        # Create temporary file with timestamp
        temp_dir = tempfile.gettempdir()
        filename = f"medical_recording_{timestamp}.wav"
        filepath = os.path.join(temp_dir, filename)
        # Method 1: Try using scipy.io.wavfile (preserves original format best)
        try:
            import scipy.io.wavfile as wavfile
            wavfile.write(filepath, sample_rate, audio_array)
            return filepath
        except Exception:
            pass
        # Method 2: Fallback to torchaudio with minimal processing
        try:
            if isinstance(audio_array, np.ndarray):
                audio_tensor = torch.from_numpy(audio_array.copy())  # Copy to avoid modifications
            else:
                audio_tensor = torch.tensor(audio_array)
            # Handle tensor dimensions
            if len(audio_tensor.shape) == 1:
                audio_tensor = audio_tensor.unsqueeze(0)  # Add channel dimension
            elif len(audio_tensor.shape) == 2:
                if audio_tensor.shape[0] > audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T  # Transpose if needed
            # Save with original data type preservation
            torchaudio.save(
                filepath,
                audio_tensor,
                sample_rate,
                encoding="PCM_S",
                bits_per_sample=16
            )
            return filepath
        except Exception as e:
            print(f"Torchaudio save failed: {e}")
        # Method 3: Last resort - use soundfile
        try:
            import soundfile as sf
            sf.write(filepath, audio_array, sample_rate)
            return filepath
        except Exception as e:
            print(f"Soundfile save failed: {e}")
        return None
    except Exception as e:
        print(f"Error saving original audio: {e}")
        return None
# Available models
WHISPER_MODELS = [
    "openai/whisper-tiny",
    "openai/whisper-base",
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-large-v2",
    "openai/whisper-large-v3"
]

MEDICAL_LLM_MODELS = [
    "OnDeviceMedNotes/OnDevice_med_note_v0_5",  # Updated to match your local model
    "OnDeviceMedNotes/Struct_Med_Note_v01",
    "OnDeviceMedNotes/JT_latest_model",
    "OnDeviceMedNotes/Medical_Summary_Notes",
    "Johnyquest7/combined_hpi",
    "meta-llama/Llama-3.2-1B-Instruct",
]
# Create Gradio interface
with gr.Blocks(title="Medical Transcription & Note Generation", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align: center; color: #2563eb;'>🏥 Medical Transcription & Note Generation</h1>")
    gr.HTML("<p style='text-align: center;'>For research purposes only; not for clinical use.</p>")

    # Add warning about HF token
    #if not HF_TOKEN:
    #    gr.HTML("<div style='background-color: #fef3cd; border: 1px solid #ffeaa7; padding: 10px; margin: 10px 0; border-radius: 5px;'><strong>⚠️ Warning:</strong> HF_TOKEN not found. Some models may not be accessible. Please add your Hugging Face token as a Space secret.</div>")
    #else:
    #    gr.HTML("<div style='background-color: #d4edda; border: 1px solid #c3e6cb; padding: 10px; margin: 10px 0; border-radius: 5px;'><strong>✅ Status:</strong> Hugging Face token loaded successfully.</div>")
    with gr.Row():
        # Left Column - Audio Recording and Transcription
        with gr.Column(scale=1):
            gr.HTML("<h2>🎙️ Audio Recording & Transcription</h2>")
            # Audio input
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Record or Upload Audio"
            )
            # Whisper model selection
            whisper_model = gr.Dropdown(
                choices=WHISPER_MODELS,
                value="openai/whisper-tiny",
                label="Select Whisper Model"
            )
            # Transcription controls
            with gr.Row():
                transcribe_btn = gr.Button("🎤 Transcribe Audio", variant="primary")
                download_audio_btn = gr.Button("💾 Download Audio")
            # Download file output
            audio_download = gr.File(label="Download Audio File", visible=False)

        # Right Column - Medical Note Generation
        with gr.Column(scale=1):
            gr.HTML("<h2>📝 Medical Note Generation</h2>")
            # LLM model selection
            llm_model = gr.Dropdown(
                choices=MEDICAL_LLM_MODELS,
                value="OnDeviceMedNotes/OnDevice_med_note_v0_5",  # Updated default
                label="Select Medical LLM Model"
            )
            # Note type selection
            note_type = gr.Radio(
                choices=["SOAP note", "Full structured medical note"],
                value="SOAP note",
                label="Note Type"
            )
            # Generate button
            generate_btn = gr.Button("📝 Generate Medical Note", variant="primary")

    # Output Section
    gr.HTML("<h2>📊 Results</h2>")
    with gr.Row():
        # Transcription output
        with gr.Column(scale=1):
            gr.HTML("<h3>Transcription</h3>")
            transcription_output = gr.Textbox(
                label="Transcribed Text (Editable)",
                lines=15,
                max_lines=20,
                interactive=True,
                placeholder="Transcribed text will appear here as it's processed..."
            )
            copy_transcription_btn = gr.Button("📋 Copy Transcription")
        # Medical note output
        with gr.Column(scale=1):
            gr.HTML("<h3>Medical Note</h3>")
            medical_note_output = gr.Textbox(
                label="Generated Medical Note",
                lines=15,
                max_lines=20,
                interactive=False,
                placeholder="Generated medical note will appear here..."
            )
            copy_note_btn = gr.Button("📋 Copy Medical Note")
    # Event handlers
    transcribe_btn.click(
        fn=transcribe_audio_streaming,
        inputs=[audio_input, whisper_model],
        outputs=transcription_output,
        show_progress=True
    )

    generate_btn.click(
        fn=generate_medical_note,
        inputs=[transcription_output, llm_model, note_type],
        outputs=medical_note_output,
        show_progress=True
    )
    # Download audio functionality
    def prepare_audio_download(audio_data):
        if audio_data is None:
            return gr.update(value=None, visible=False)
        # Store the original audio first
        store_original_audio(audio_data)
        # Save the original audio without processing
        file_path = save_original_audio()
        if file_path:
            return gr.update(value=file_path, visible=True)
        return gr.update(value=None, visible=False)
    # Audio input change handler to store original audio
    audio_input.change(
        fn=store_original_audio,
        inputs=audio_input,
        outputs=None
    )
    download_audio_btn.click(
        fn=prepare_audio_download,
        inputs=audio_input,
        outputs=audio_download  # Single output; gr.update sets both value and visibility
    )
    # Copy functionality (JavaScript)
    copy_transcription_btn.click(
        fn=None,
        inputs=transcription_output,
        outputs=None,
        js="(text) => {navigator.clipboard.writeText(text); return text;}"
    )

    copy_note_btn.click(
        fn=None,
        inputs=medical_note_output,
        outputs=None,
        js="(text) => {navigator.clipboard.writeText(text); return text;}"
    )
if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
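
# Note: on Hugging Face Spaces, share=True is typically unnecessary (Gradio warns that
# sharing is not supported there); when run locally with `python app.py` the interface
# is served at http://localhost:7860.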