import gradio as gr
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, pipeline
import numpy as np
import io
import tempfile
import os
from datetime import datetime
import spaces
import threading
import queue
import time
from huggingface_hub import login
# Get HuggingFace token from environment (Spaces secret)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Global variables to store loaded models to avoid reloading
loaded_whisper_model = None
loaded_whisper_processor = None
loaded_whisper_model_name = None
loaded_llm_model = None
loaded_llm_tokenizer = None
loaded_llm_model_name = None
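# Whisper transcribes roughly 30 seconds of audio per pass, so longer recordings
# are split into overlapping chunks below and stitched back together; the overlap
# lets the transcription loop drop duplicated words at chunk boundaries.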
def chunk_audio(audio, chunk_length_seconds=30, overlap_seconds=5, sample_rate=16000):
"""Split audio into overlapping chunks for better transcription"""
chunk_length_samples = chunk_length_seconds * sample_rate
overlap_samples = overlap_seconds * sample_rate
chunks = []
start = 0
while start < len(audio):
end = min(start + chunk_length_samples, len(audio))
chunk = audio[start:end]
chunks.append((chunk, start / sample_rate)) # Include timestamp
if end >= len(audio):
break
start += chunk_length_samples - overlap_samples
return chunks
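# Example: a 70-second clip at 16 kHz with the defaults (30 s chunks, 5 s overlap)
# steps forward 25 s at a time, yielding chunks that start at 0 s, 25 s, and 50 s.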
@spaces.GPU
def transcribe_audio_streaming(audio_file, model_name, progress=gr.Progress()):
"""Transcribe audio using Whisper model with streaming output"""
global loaded_whisper_model, loaded_whisper_processor, loaded_whisper_model_name
try:
if audio_file is None:
yield "No audio file provided"
return
progress(0.05, desc="Processing audio input...")
# Load and process audio
if isinstance(audio_file, str):
waveform, sample_rate = torchaudio.load(audio_file)
else:
sample_rate, audio_data = audio_file
if isinstance(audio_data, np.ndarray):
if len(audio_data.shape) > 1:
audio_data = audio_data.mean(axis=1) # Convert to mono
if audio_data.dtype == np.int16:
audio_data = audio_data.astype(np.float32) / 32768.0
elif audio_data.dtype == np.int32:
audio_data = audio_data.astype(np.float32) / 2147483648.0
waveform = torch.from_numpy(audio_data).unsqueeze(0)
else:
waveform = torch.tensor(audio_data).unsqueeze(0)
# Resample if necessary
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
# Ensure mono audio
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
audio_length = waveform.shape[1] / 16000 # Length in seconds
progress(0.1, desc=f"Audio loaded: {audio_length:.1f} seconds")
# Load model if not already loaded or if model changed
if (loaded_whisper_model is None or
loaded_whisper_processor is None or
loaded_whisper_model_name != model_name):
progress(0.15, desc="Loading Whisper model...")
loaded_whisper_processor = WhisperProcessor.from_pretrained(model_name)
loaded_whisper_model = WhisperForConditionalGeneration.from_pretrained(model_name)
loaded_whisper_model_name = model_name
# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
loaded_whisper_model = loaded_whisper_model.to(device)
progress(0.25, desc=f"Model loaded on {device}")
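        # Re-read the device from the (possibly cached) model so chunk tensors land on
        # the same GPU/CPU even when the model was loaded during an earlier call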
        device = next(loaded_whisper_model.parameters()).device
        # Split audio into chunks for long audio files
        audio_numpy = waveform.squeeze().numpy()
        chunks = chunk_audio(audio_numpy, chunk_length_seconds=30, overlap_seconds=2)
        progress(0.3, desc=f"Processing {len(chunks)} audio chunks...")
        full_transcription = ""
        overlap_words = []
        for i, (chunk, timestamp) in enumerate(chunks):
            chunk_progress = 0.3 + (0.6 * i / len(chunks))
            progress(chunk_progress, desc=f"Transcribing chunk {i+1}/{len(chunks)} (at {timestamp:.1f}s)")
            # Process chunk
            input_features = loaded_whisper_processor(
                chunk,
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features
            input_features = input_features.to(device)
            # Create prompt tokens for medical context
prompt = """ Endocrinology clinic visit transcript.
Previous discussion mentioned levothyroxine dose adjustment for Hashimoto’s thyroiditis and an Hemoglobin A1C review.
Terms we frequently use: levothyroxine, Synthroid, Cytomel, methimazole, propylthiouracil, PTU semaglutide, Ozempic, Mounjaro, insulin glargine, Lantus, Tresiba, Toujeo, Metformin, Glipizide, Januvia, Jardiance, Atorvastatin, Rosuvastatin, Pravastatin.
Common labs and hormones: TSH, free T4, free T3, HbA1C, cortisol, ACTH, parathyroid hormone, LDL, microalbumin.
Typical procedures: thyroid ultrasound, fine-needle aspiration biopsy, FNA, medtronics pump, Tandem pump, OmniPod pump, radioactive iodine ablation, continuous glucose monitoring, CGM, DEXA scan."""
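            # Note (assumption): Hugging Face's Whisper generate() may trim long prompts to
            # roughly the last half of the decoder context (~224 tokens), so keep this list short.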
            prompt_ids = loaded_whisper_processor.get_prompt_ids(prompt, return_tensors="pt").to(device)
            # Generate transcription for chunk
            with torch.no_grad():
                predicted_ids = loaded_whisper_model.generate(
                    input_features,
                    max_length=448,
                    num_beams=1,  # Faster generation
                    prompt_ids=prompt_ids,
                    do_sample=False
                )
            # Decode chunk transcription
            chunk_transcription = loaded_whisper_processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0].strip()
            # Handle overlap removal for better continuity
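            # Example of the heuristic below: if the previous chunk ended "... increase the
            # levothyroxine dose" and this chunk starts "increase the levothyroxine dose to 125",
            # everything up to and including the matched 3-word run is trimmed from the new chunk.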
            if i > 0 and overlap_words:
                chunk_words = chunk_transcription.split()
                # Remove potential overlap by checking first few words
                for j in range(min(10, len(chunk_words))):
                    if chunk_words[j:j+3] == overlap_words[-3:]:
                        chunk_transcription = " ".join(chunk_words[j+3:])
                        break
            # Store last few words for next chunk overlap detection
            chunk_words = chunk_transcription.split()
            overlap_words = chunk_words[-5:] if len(chunk_words) >= 5 else chunk_words
            # Add to full transcription
            if chunk_transcription:
                if full_transcription and not full_transcription.endswith(" "):
                    full_transcription += " "
                full_transcription += chunk_transcription
            # Yield intermediate result for streaming
            yield full_transcription
            time.sleep(0.1)  # Small delay for visual streaming effect
        progress(0.95, desc="Finalizing transcription...")
        # Final cleanup
        full_transcription = full_transcription.strip()
        if not full_transcription:
            full_transcription = "No speech detected in the audio."
        progress(1.0, desc="Transcription complete!")
        yield full_transcription
    except Exception as e:
        yield f"Error during transcription: {str(e)}"
@spaces.GPU
def generate_medical_note(transcription, model_name, note_type, progress=gr.Progress()):
"""Generate medical note using selected LLM with proper chat template formatting"""
global loaded_llm_model, loaded_llm_tokenizer, loaded_llm_model_name
try:
if not transcription or transcription.strip() == "":
return "No transcription provided"
progress(0.1, desc="Loading medical note model...")
# Load model if not already loaded or if model changed
if (loaded_llm_model is None or
loaded_llm_tokenizer is None or
loaded_llm_model_name != model_name):
# Load tokenizer and model with token if needed
model_kwargs = {
"trust_remote_code": True,
"torch_dtype": torch.float16,
"device_map": "auto"
}
tokenizer_kwargs = {
"trust_remote_code": True
}
# Add token if available for private models
if HF_TOKEN:
model_kwargs["token"] = HF_TOKEN # Updated parameter name
tokenizer_kwargs["token"] = HF_TOKEN
loaded_llm_tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
loaded_llm_model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
loaded_llm_model_name = model_name
# Add padding token if not present
if loaded_llm_tokenizer.pad_token is None:
loaded_llm_tokenizer.pad_token = loaded_llm_tokenizer.eos_token
progress(0.3, desc="Preparing prompt...")
# Define system prompts based on note type (matching your local code)
        if note_type == "SOAP note":
            system_prompt = """You are an expert medical professor assisting in the creation of medically accurate SOAP summaries.
Please ensure the response follows the structured format: S:, O:, A:, P: without using markdown or special formatting.
Create a Medical SOAP note summary from the dialogue, following these guidelines:
S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history.
Rely on the patient's statements as the primary source and ensure standardized terminology.
O (Objective): Highlight critical findings such as vital signs, lab results, and imaging, emphasizing important details like the side of the body affected and specific dosages.
Include normal ranges where relevant.
A (Assessment): Offer a concise assessment combining subjective and objective data. State the primary diagnosis and any differential diagnoses, noting potential complications and the prognostic outlook.
P (Plan): Outline the management plan, covering medication, diet, consultations, and education. Be sure to mention necessary referrals to other specialties and address compliance challenges.
Considerations: Compile the report based solely on the transcript provided. Use concise medical jargon and abbreviations for effective doctor communication.
Please format the summary in a clean, simple list format without using markdown or bullet points. Use 'S:', 'O:', 'A:', 'P:' directly followed by the text. Avoid any styling or special characters."""
            prompt_prefix = '''Convert the following medical transcript to a SOAP note. Transcript: \n'''
        else:
            # Structured note system prompt (from your local code)
            system_prompt = """You are an expert medical transcriptionist. You convert a medical transcript to a structured medical note with these sections in this order:
1. Presenting Illness
(Bullet point statements of the main problem)
2. History of Presenting Illness
(Chronological narrative: symptom onset, progression, modifiers, associated factors)
3. Past Medical History
(List chronic illnesses and past medical diagnoses mentioned in the transcript. Do not include surgeries)
4. Surgical History
(List prior surgeries with year if known mentioned in the transcript)
5. Family History
(Relevant family history mentioned in the transcript)
6. Social History
(Occupation, tobacco/alcohol/drug use, exercise, living situation if mentioned in the transcript)
7. Allergy History
(Drug/food/environmental allergies + reactions - if mentioned in the transcript)
8. Medication History
(List medications the patient is already taking. Do not include any medication the patient is not currently taking.)
9. Dietary History
("Not applicable" if unrelated, otherwise summarize diet pattern)
10. Review of Systems
(Head-to-toe ordered bullets; note positives and pertinent negatives mentioned in the transcript)
11. Physical Exam Findings
Vital Signs (BP, HR, RR, Temp, SpOβ‚‚, HT, WT, BMI) - if mentioned in the transcript
(Structured by system: General, HEENT, CV, Resp, Abd, Neuro, MSK, Skin, Psych) - if mentioned in the transcript
12. Labs and Imaging
(Labs, imaging results)
13. Assessment and Plan
(List each diagnosis and its treatment plan. No other information is needed in this section. Do not generate new diagnoses.)"""
            prompt_prefix = '''Convert the following medical transcript to a structured medical note. Transcript: \n'''
        # Create the user prompt
        user_prompt = prompt_prefix + transcription
        # Format using chat template (matching your local approach)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        progress(0.4, desc="Applying chat template...")
        # Apply chat template
        input_ids = loaded_llm_tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,  # Adds the assistant turn prompt
            truncation=True,
            max_length=4096  # Adjust based on model's context window
        ).to(loaded_llm_model.device)
        progress(0.5, desc="Generating medical note...")
        # Generate with parameters matching your local setup
        with torch.no_grad():
            outputs = loaded_llm_model.generate(
                input_ids,
                max_new_tokens=4096,  # Matching your local setup
                do_sample=True,  # Matching your local setup
                temperature=0.01,  # Matching your local setup (very low temperature)
                top_p=0.95,  # Matching your local setup
                num_return_sequences=1,
                eos_token_id=loaded_llm_tokenizer.eos_token_id,
                pad_token_id=loaded_llm_tokenizer.pad_token_id,
                use_cache=True  # Matching your local setup
            )
        progress(0.9, desc="Processing response...")
        # Decode the generated text (matching your local decoding logic)
        decoded_output_full = loaded_llm_tokenizer.decode(outputs[0], skip_special_tokens=False)
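        # The header marker below is specific to Llama-3-style chat templates; whichever
        # branch runs, only the newly generated tokens (everything after input_ids) are decoded.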
        # Extract generated text using the same logic as your local code
        assistant_prompt_end_marker = "<|start_header_id|>assistant<|end_header_id|>"
        assistant_prompt_end_index = decoded_output_full.rfind(assistant_prompt_end_marker)
        if assistant_prompt_end_index != -1:
            # Find the newline character immediately following the assistant prompt marker
            start_of_generation = decoded_output_full.find('\n', assistant_prompt_end_index + len(assistant_prompt_end_marker))
            if start_of_generation != -1:
                # Extract text after the newline, skipping special tokens
                medical_note = loaded_llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
            else:
                # Fallback if structure is unexpected
                medical_note = loaded_llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
        else:
            # Fallback if assistant prompt marker is not found
            medical_note = loaded_llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
        # Clean up the response
        if not medical_note:
            medical_note = "Unable to generate medical note. Please check the transcription and try again."
        progress(1.0, desc="Complete!")
        return medical_note
    except Exception as e:
        return f"Error generating medical note: {str(e)}"
# Global variable to store original audio for download
original_audio_data = None
def store_original_audio(audio_data):
"""Store the original audio data for download"""
global original_audio_data
original_audio_data = audio_data
return audio_data
def save_original_audio():
"""Save the stored original audio without any processing"""
global original_audio_data
if original_audio_data is None:
return None
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
sample_rate, audio_array = original_audio_data
# Create temporary file with timestamp
temp_dir = tempfile.gettempdir()
filename = f"medical_recording_{timestamp}.wav"
filepath = os.path.join(temp_dir, filename)
# Method 1: Try using scipy.io.wavfile (preserves original format best)
try:
import scipy.io.wavfile as wavfile
wavfile.write(filepath, sample_rate, audio_array)
return filepath
except:
pass
# Method 2: Fallback to torchaudio with minimal processing
try:
if isinstance(audio_array, np.ndarray):
audio_tensor = torch.from_numpy(audio_array.copy()) # Copy to avoid modifications
else:
audio_tensor = torch.tensor(audio_array)
# Handle tensor dimensions
if len(audio_tensor.shape) == 1:
audio_tensor = audio_tensor.unsqueeze(0) # Add channel dimension
elif len(audio_tensor.shape) == 2:
if audio_tensor.shape[0] > audio_tensor.shape[1]:
audio_tensor = audio_tensor.T # Transpose if needed
# Save with original data type preservation
torchaudio.save(
filepath,
audio_tensor,
sample_rate,
encoding="PCM_S",
bits_per_sample=16
)
return filepath
except Exception as e:
print(f"Torchaudio save failed: {e}")
# Method 3: Last resort - use soundfile
try:
import soundfile as sf
sf.write(filepath, audio_array, sample_rate)
return filepath
except Exception as e:
print(f"Soundfile save failed: {e}")
return None
except Exception as e:
print(f"Error saving original audio: {e}")
return None
# Available models
WHISPER_MODELS = [
"openai/whisper-tiny",
"openai/whisper-base",
"openai/whisper-small",
"openai/whisper-medium",
"openai/whisper-large-v2",
"openai/whisper-large-v3"
]
MEDICAL_LLM_MODELS = [
"OnDeviceMedNotes/OnDevice_med_note_v0_5", # Updated to match your local model
"OnDeviceMedNotes/Struct_Med_Note_v01",
"OnDeviceMedNotes/JT_latest_model",
"OnDeviceMedNotes/Medical_Summary_Notes",
"Johnyquest7/combined_hpi",
"meta-llama/Llama-3.2-1B-Instruct",
]
# Create Gradio interface
with gr.Blocks(title="Medical Transcription & Note Generation", theme=gr.themes.Soft()) as demo:
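    # Layout: two input columns (audio/transcription controls and note options) on top,
    # followed by a results row with the editable transcript and the generated note.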
gr.HTML("<h1 style='text-align: center; color: #2563eb;'>πŸ₯ Medical Transcription & Note Generation</h1>")
gr.HTML("<p style='text-align: center;'>Only for research purpose, not for clinical use.</p>")
# Add warning about HF token
#if not HF_TOKEN:
# gr.HTML("<div style='background-color: #fef3cd; border: 1px solid #ffeaa7; padding: 10px; margin: 10px 0; border-radius: 5px;'><strong>⚠️ Warning:</strong> HF_TOKEN not found. Some models may not be accessible. Please add your Hugging Face token as a Space secret.</div>")
#else:
# gr.HTML("<div style='background-color: #d4edda; border: 1px solid #c3e6cb; padding: 10px; margin: 10px 0; border-radius: 5px;'><strong>βœ… Status:</strong> Hugging Face token loaded successfully.</div>")
with gr.Row():
# Left Column - Audio Recording and Transcription
with gr.Column(scale=1):
gr.HTML("<h2>πŸŽ™οΈ Audio Recording & Transcription</h2>")
# Audio input
audio_input = gr.Audio(
sources=["microphone", "upload"],
type="numpy",
label="Record or Upload Audio"
)
# Whisper model selection
whisper_model = gr.Dropdown(
choices=WHISPER_MODELS,
value="openai/whisper-tiny",
label="Select Whisper Model"
)
# Transcription controls
with gr.Row():
transcribe_btn = gr.Button("πŸ”€ Transcribe Audio", variant="primary")
download_audio_btn = gr.Button("πŸ’Ύ Download Audio")
# Download file output
audio_download = gr.File(label="Download Audio File", visible=False)
# Right Column - Medical Note Generation
with gr.Column(scale=1):
gr.HTML("<h2>πŸ“‹ Medical Note Generation</h2>")
# LLM model selection
llm_model = gr.Dropdown(
choices=MEDICAL_LLM_MODELS,
value="OnDeviceMedNotes/OnDevice_med_note_v0_5", # Updated default
label="Select Medical LLM Model"
)
# Note type selection
note_type = gr.Radio(
choices=["SOAP note", "Full structured medical note"],
value="SOAP note",
label="Note Type"
)
# Generate button
generate_btn = gr.Button("πŸ“ Generate Medical Note", variant="primary")
# Output Section
gr.HTML("<h2>πŸ“„ Results</h2>")
    with gr.Row():
        # Transcription output
        with gr.Column(scale=1):
            gr.HTML("<h3>Transcription</h3>")
            transcription_output = gr.Textbox(
                label="Transcribed Text (Editable)",
                lines=15,
                max_lines=20,
                interactive=True,
                placeholder="Transcribed text will appear here as it's processed..."
            )
            copy_transcription_btn = gr.Button("πŸ“‹ Copy Transcription")
        # Medical note output
        with gr.Column(scale=1):
            gr.HTML("<h3>Medical Note</h3>")
            medical_note_output = gr.Textbox(
                label="Generated Medical Note",
                lines=15,
                max_lines=20,
                interactive=False,
                placeholder="Generated medical note will appear here..."
            )
            copy_note_btn = gr.Button("πŸ“‹ Copy Medical Note")
    # Event handlers
    transcribe_btn.click(
        fn=transcribe_audio_streaming,
        inputs=[audio_input, whisper_model],
        outputs=transcription_output,
        show_progress=True
    )
    generate_btn.click(
        fn=generate_medical_note,
        inputs=[transcription_output, llm_model, note_type],
        outputs=medical_note_output,
        show_progress=True
    )
    # Download audio functionality
    def prepare_audio_download(audio_data):
        if audio_data is None:
            return None, gr.update(visible=False)
        # Store the original audio first
        store_original_audio(audio_data)
        # Save the original audio without processing
        file_path = save_original_audio()
        if file_path:
            return file_path, gr.update(visible=True)
        return None, gr.update(visible=False)
    # Audio input change handler to store original audio
    audio_input.change(
        fn=store_original_audio,
        inputs=audio_input,
        outputs=None
    )
    download_audio_btn.click(
        fn=prepare_audio_download,
        inputs=audio_input,
        outputs=[audio_download, audio_download]
    )
    # Copy functionality (JavaScript)
    copy_transcription_btn.click(
        fn=None,
        inputs=transcription_output,
        outputs=None,
        js="(text) => {navigator.clipboard.writeText(text); return text;}"
    )
    copy_note_btn.click(
        fn=None,
        inputs=medical_note_output,
        outputs=None,
        js="(text) => {navigator.clipboard.writeText(text); return text;}"
    )
if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)