"""
Camie-Tagger-V2 Application

A Streamlit web app for tagging images with the Camie-Tagger-v2 model.
"""
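# Typical usage (entry-point filename assumed): streamlit run app.py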
import streamlit as st
import os
import sys
import traceback
import tempfile
import time
import platform
import subprocess
import webbrowser
import glob
import numpy as np
import matplotlib.pyplot as plt
import io
import base64
import json
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from pathlib import Path

# Make the project root importable so the utils package can be resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.image_processing import process_image, batch_process_images
from utils.file_utils import save_tags_to_file, get_default_save_locations
from utils.ui_components import display_progress_bar, show_example_images, display_batch_results
from utils.onnx_processing import batch_process_images_onnx

# Model files are expected in the project root (one level above this script).
MODEL_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print(f"Using model directory: {MODEL_DIR}")

threshold_profile_descriptions = {
    "Micro Optimized": "Maximizes micro-averaged F1 score (best for dominant classes). Optimal for overall prediction quality.",
    "Macro Optimized": "Maximizes macro-averaged F1 score (equal weight to all classes). Better for balanced performance across all tags.",
    "Balanced": "Provides a trade-off between precision and recall with moderate thresholds. Good general-purpose setting.",
    "Overall": "Uses a single threshold value across all categories. Simplest approach for consistent behavior.",
    "Category-specific": "Uses different optimal thresholds for each category. Best for fine-tuning results."
}

threshold_profile_explanations = {
    "Micro Optimized": """
### Micro Optimized Profile

**Technical definition**: Maximizes micro-averaged F1 score, which calculates metrics globally across all predictions.

**When to use**: When you want the best overall accuracy, especially for common tags and dominant categories.

**Effects**:
- Optimizes performance for the most frequent tags
- Gives more weight to categories with many examples (like 'character' and 'general')
- Provides higher precision in most common use cases

**Performance from validation**:
- Micro F1: ~67.3%
- Macro F1: ~46.3%
- Threshold: ~0.614
""",

    "Macro Optimized": """
### Macro Optimized Profile

**Technical definition**: Maximizes macro-averaged F1 score, which gives equal weight to all categories regardless of size.

**When to use**: When balanced performance across all categories is important, including rare tags.

**Effects**:
- More balanced performance across all tag categories
- Better at detecting rare or unusual tags
- Generally has lower thresholds than micro-optimized

**Performance from validation**:
- Micro F1: ~60.9%
- Macro F1: ~50.6%
- Threshold: ~0.492
""",

    "Balanced": """
### Balanced Profile

**Technical definition**: Same as Micro Optimized but provides a good reference point for manual adjustment.

**When to use**: For general-purpose tagging when you don't have specific recall or precision requirements.

**Effects**:
- Good middle ground between precision and recall
- Works well for most common use cases
- Default choice for most users

**Performance from validation**:
- Micro F1: ~67.3%
- Macro F1: ~46.3%
- Threshold: ~0.614
""",

    "Overall": """
### Overall Profile

**Technical definition**: Uses a single threshold value across all categories.

**When to use**: When you want consistent behavior across all categories and a simple approach.

**Effects**:
- Consistent tagging threshold for all categories
- Simpler to understand than category-specific thresholds
- User-adjustable with a single slider

**Default threshold value**: 0.5 (user-adjustable)

**Note**: The threshold value is user-adjustable with the slider below.
""",

    "Category-specific": """
### Category-specific Profile

**Technical definition**: Uses different optimal thresholds for each category, allowing fine-tuning.

**When to use**: When you want to customize tagging sensitivity for different categories.

**Effects**:
- Each category has its own independent threshold
- Full control over category sensitivity
- Best for fine-tuning results when some categories need different treatment

**Default threshold values**: Starts with balanced thresholds for each category

**Note**: Use the category sliders below to adjust thresholds for individual categories.
"""
}

def load_validation_results(results_path):
    """Load validation results from JSON file"""
    try:
        with open(results_path, 'r') as f:
            data = json.load(f)
        return data
    except Exception as e:
        print(f"Error loading validation results: {e}")
        return None

def extract_thresholds_from_results(validation_data):
    """Extract threshold information from validation results"""
    if not validation_data or 'results' not in validation_data:
        return {}

    thresholds = {
        'overall': {},
        'categories': {}
    }

    for result in validation_data['results']:
        category = result['CATEGORY'].lower()
        profile = result['PROFILE'].lower().replace(' ', '_')
        threshold = result['THRESHOLD']
        micro_f1 = result['MICRO-F1']
        macro_f1 = result['MACRO-F1']

        # Normalize abbreviated profile names to the keys used elsewhere in the app.
        if profile == 'micro_opt':
            profile = 'micro_optimized'
        elif profile == 'macro_opt':
            profile = 'macro_optimized'

        threshold_info = {
            'threshold': threshold,
            'micro_f1': micro_f1,
            'macro_f1': macro_f1
        }

        if category == 'overall':
            thresholds['overall'][profile] = threshold_info
        else:
            if category not in thresholds['categories']:
                thresholds['categories'][category] = {}
            thresholds['categories'][category][profile] = threshold_info

    return thresholds

def load_model_and_metadata():
    """Load model and metadata from available files"""
    safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
    safetensors_metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")
    onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")
    validation_results_path = os.path.join(MODEL_DIR, "full_validation_results.json")

    model_info = {
        'safetensors_available': os.path.exists(safetensors_path) and os.path.exists(safetensors_metadata_path),
        'onnx_available': os.path.exists(onnx_path) and os.path.exists(safetensors_metadata_path),
        'validation_results_available': os.path.exists(validation_results_path)
    }

    metadata = None
    if os.path.exists(safetensors_metadata_path):
        try:
            with open(safetensors_metadata_path, 'r') as f:
                metadata = json.load(f)
        except Exception as e:
            print(f"Error loading metadata: {e}")

    thresholds = {}
    if model_info['validation_results_available']:
        validation_data = load_validation_results(validation_results_path)
        if validation_data:
            thresholds = extract_thresholds_from_results(validation_data)

    # Fall back to default thresholds when no validation results are available.
    if not thresholds:
        thresholds = {
            'overall': {
                'balanced': {'threshold': 0.5, 'micro_f1': 0, 'macro_f1': 0},
                'micro_optimized': {'threshold': 0.6, 'micro_f1': 0, 'macro_f1': 0},
                'macro_optimized': {'threshold': 0.4, 'micro_f1': 0, 'macro_f1': 0}
            },
            'categories': {}
        }

    return model_info, metadata, thresholds

def load_safetensors_model(safetensors_path, metadata_path):
    """Load SafeTensors model"""
    try:
        from safetensors.torch import load_file
        import torch

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        from utils.model_loader import ImageTagger

        model_info = metadata['model_info']
        dataset_info = metadata['dataset_info']

        model = ImageTagger(
            total_tags=dataset_info['total_tags'],
            dataset=None,
            model_name=model_info['backbone'],
            num_heads=model_info['num_attention_heads'],
            dropout=0.0,
            pretrained=False,
            tag_context_size=model_info['tag_context_size'],
            use_gradient_checkpointing=False,
            img_size=model_info['img_size']
        )

        state_dict = load_file(safetensors_path)
        model.load_state_dict(state_dict)
        model.eval()

        return model, metadata
    except Exception as e:
        raise Exception(f"Failed to load SafeTensors model: {e}")

def get_profile_metrics(thresholds, profile_name):
    """Extract metrics for the given profile from the thresholds dictionary"""
    profile_key = None

    if profile_name == "Micro Optimized":
        profile_key = "micro_optimized"
    elif profile_name == "Macro Optimized":
        profile_key = "macro_optimized"
    elif profile_name == "Balanced":
        profile_key = "balanced"
    elif profile_name in ["Overall", "Category-specific"]:
        # These profiles have no dedicated validation entry, so the
        # macro-optimized metrics are shown as a reference point.
        profile_key = "macro_optimized"

    if profile_key and 'overall' in thresholds and profile_key in thresholds['overall']:
        return thresholds['overall'][profile_key]

    return None

def on_threshold_profile_change():
    """Handle threshold profile changes"""
    new_profile = st.session_state.threshold_profile

    if hasattr(st.session_state, 'thresholds') and hasattr(st.session_state, 'settings'):
        if st.session_state.settings['active_category_thresholds'] is None:
            st.session_state.settings['active_category_thresholds'] = {}

        current_thresholds = st.session_state.settings['active_category_thresholds']

        profile_key = None
        if new_profile == "Micro Optimized":
            profile_key = "micro_optimized"
        elif new_profile == "Macro Optimized":
            profile_key = "macro_optimized"
        elif new_profile == "Balanced":
            profile_key = "balanced"

        if profile_key and 'overall' in st.session_state.thresholds and profile_key in st.session_state.thresholds['overall']:
            # Validation-derived profiles: apply the stored overall threshold and
            # the matching per-category thresholds where they exist.
            st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall'][profile_key]['threshold']

            for category in st.session_state.categories:
                if category in st.session_state.thresholds['categories'] and profile_key in st.session_state.thresholds['categories'][category]:
                    current_thresholds[category] = st.session_state.thresholds['categories'][category][profile_key]['threshold']
                else:
                    current_thresholds[category] = st.session_state.settings['active_threshold']

        elif new_profile == "Overall":
            if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
                st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
            else:
                st.session_state.settings['active_threshold'] = 0.5

            # A single threshold applies everywhere, so per-category overrides are cleared.
            st.session_state.settings['active_category_thresholds'] = {}

        elif new_profile == "Category-specific":
            if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']:
                st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold']
            else:
                st.session_state.settings['active_threshold'] = 0.5

            for category in st.session_state.categories:
                if category in st.session_state.thresholds['categories'] and 'balanced' in st.session_state.thresholds['categories'][category]:
                    current_thresholds[category] = st.session_state.thresholds['categories'][category]['balanced']['threshold']
                else:
                    current_thresholds[category] = st.session_state.settings['active_threshold']

def apply_thresholds(all_probs, threshold_profile, active_threshold, active_category_thresholds, min_confidence, selected_categories):
    """Apply thresholds to raw probabilities and return filtered tags"""
    # Note: threshold_profile and min_confidence are accepted for interface
    # consistency but are not used directly here.
    tags = {}
    all_tags = []

    active_category_thresholds = active_category_thresholds or {}

    for category, cat_probs in all_probs.items():
        # Use the per-category threshold when one is set, otherwise the global one.
        threshold = active_category_thresholds.get(category, active_threshold)

        # Keep every (tag, probability) pair at or above the threshold.
        tags[category] = [(tag, prob) for tag, prob in cat_probs if prob >= threshold]

        # Only categories selected by the user contribute to the flat tag list.
        if selected_categories.get(category, True):
            for tag, prob in tags[category]:
                all_tags.append(tag)

    return tags, all_tags
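
# Example (hypothetical data) illustrating the expected shapes:
#   probs = {"general": [("blue_sky", 0.83), ("cloud", 0.41)]}
#   tags, flat = apply_thresholds(probs, "Overall", 0.5, {}, 0.01, {"general": True})
#   -> tags == {"general": [("blue_sky", 0.83)]} and flat == ["blue_sky"]
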
def image_tagger_app():
    """Main Streamlit application for image tagging."""
    st.set_page_config(layout="wide", page_title="Camie Tagger", page_icon="🖼️")

    st.title("Camie-Tagger-v2 Interface")
    st.markdown("---")

    # Initialize persistent UI settings on the first run. The default profile
    # matches the selectbox default ("Macro Optimized") further down.
    if 'settings' not in st.session_state:
        st.session_state.settings = {
            'show_all_tags': False,
            'compact_view': True,
            'min_confidence': 0.01,
            'threshold_profile': "Macro Optimized",
            'active_threshold': 0.5,
            'active_category_thresholds': {},
            'selected_categories': {},
            'replace_underscores': False
        }
        st.session_state.show_profile_help = False

    if 'model_loaded' not in st.session_state:
        st.session_state.model_loaded = False
        st.session_state.model = None
        st.session_state.thresholds = None
        st.session_state.metadata = None
        st.session_state.model_type = "onnx"

    with st.sidebar:
        st.subheader("💡 Notes")

        st.markdown("""
This tagger was trained on a subset of the available data due to hardware limitations.

A more comprehensive model trained on the full 3+ million image dataset would provide:
- More recent characters and tags.
- Improved accuracy.

If you find this tool useful and would like to support future development:
""")

        st.markdown("""
<style>
@keyframes coffee-button-glow {
    0% { box-shadow: 0 0 5px #FFD700; }
    50% { box-shadow: 0 0 15px #FFD700; }
    100% { box-shadow: 0 0 5px #FFD700; }
}

.coffee-button {
    display: inline-block;
    animation: coffee-button-glow 2s infinite;
    border-radius: 5px;
    transition: transform 0.3s ease;
}

.coffee-button:hover {
    transform: scale(1.05);
}
</style>

<a href="https://ko-fi.com/camais" target="_blank" class="coffee-button">
    <img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png"
         alt="Buy Me A Coffee"
         style="height: 45px; width: 162px; border-radius: 5px;" />
</a>
""", unsafe_allow_html=True)

        st.markdown("""
Your support helps with:
- GPU costs for training
- Storage for larger datasets
- Development of new features
- Future projects

Thank you! 🙏

Full Details: https://huggingface.co/Camais03/camie-tagger-v2
""")

        st.header("Model Selection")

        model_info, metadata, thresholds = load_model_and_metadata()

        model_options = []
        if model_info['onnx_available']:
            model_options.append("ONNX (Recommended)")
        if model_info['safetensors_available']:
            model_options.append("SafeTensors (PyTorch)")

        if not model_options:
            st.error("No model files found!")
            st.info(f"Looking for models in: {MODEL_DIR}")
            st.info("Expected files:")
            st.info("- camie-tagger-v2.onnx")
            st.info("- camie-tagger-v2.safetensors")
            st.info("- camie-tagger-v2-metadata.json")
            st.stop()

        # ONNX is listed first when available, so the first option is the default.
        model_type = st.radio(
            "Select Model Type:",
            model_options,
            index=0,
            help="ONNX: Optimized for speed and compatibility\nSafeTensors: Native PyTorch format"
        )

        if model_type == "ONNX (Recommended)":
            selected_model_type = "onnx"
        else:
            selected_model_type = "safetensors"

        # Force a reload when the user switches model type.
        if selected_model_type != st.session_state.model_type:
            st.session_state.model_loaded = False
            st.session_state.model_type = selected_model_type

        if st.button("Reload Model") and st.session_state.model_loaded:
            st.session_state.model_loaded = False
            st.info("Reloading model...")

    # Load the selected model on first run or after a model-type switch.
    if not st.session_state.model_loaded:
        try:
            with st.spinner(f"Loading {st.session_state.model_type.upper()} model..."):
                if st.session_state.model_type == "onnx":
                    import onnxruntime as ort

                    onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx")

                    # Use every available execution provider (CUDA first when present).
                    providers = ort.get_available_providers()
                    gpu_available = any('CUDA' in provider for provider in providers)

                    session = ort.InferenceSession(onnx_path, providers=providers)

                    st.session_state.model = session
                    st.session_state.device = f"ONNX Runtime ({'GPU' if gpu_available else 'CPU'})"
                    st.session_state.param_dtype = "float32"
                else:
                    safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors")
                    metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json")

                    model, loaded_metadata = load_safetensors_model(safetensors_path, metadata_path)

                    st.session_state.model = model
                    device = next(model.parameters()).device
                    param_dtype = next(model.parameters()).dtype
                    st.session_state.device = device
                    st.session_state.param_dtype = param_dtype
                    metadata = loaded_metadata

                st.session_state.thresholds = thresholds
                st.session_state.metadata = metadata
                st.session_state.model_loaded = True

                if metadata and 'dataset_info' in metadata:
                    tag_mapping = metadata['dataset_info']['tag_mapping']
                    categories = list(set(tag_mapping['tag_to_category'].values()))
                    st.session_state.categories = categories

                    if not st.session_state.settings['selected_categories']:
                        st.session_state.settings['selected_categories'] = {cat: True for cat in categories}

                # Default to the macro-optimized threshold when validation results provide one.
                if 'overall' in thresholds and 'macro_optimized' in thresholds['overall']:
                    st.session_state.settings['active_threshold'] = thresholds['overall']['macro_optimized']['threshold']

        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            st.code(traceback.format_exc())
            st.stop()

    with st.sidebar:
        st.header("Model Information")
        if st.session_state.model_loaded:
            if st.session_state.model_type == "onnx":
                st.success("Using ONNX Model")
            else:
                st.success("Using SafeTensors Model")

            st.write(f"Device: {st.session_state.device}")
            st.write(f"Precision: {st.session_state.param_dtype}")

            if st.session_state.metadata:
                if 'dataset_info' in st.session_state.metadata:
                    total_tags = st.session_state.metadata['dataset_info']['total_tags']
                    st.write(f"Total tags: {total_tags}")
                elif 'total_tags' in st.session_state.metadata:
                    st.write(f"Total tags: {st.session_state.metadata['total_tags']}")

            with st.expander("Available Categories"):
                for category in sorted(st.session_state.categories):
                    st.write(f"- {category.capitalize()}")

        with st.expander("About this app"):
            st.write("""
This app uses a trained image tagging model to analyze and tag images.

**Model Options**:
- **ONNX (Recommended)**: Optimized for inference speed with broad compatibility
- **SafeTensors**: Native PyTorch format for advanced users

**Features**:
- Upload or process images in batches
- Multiple threshold profiles based on validation results
- Category-specific threshold adjustment
- Export tags in various formats
- Fast inference with GPU acceleration (when available)

**Threshold Profiles**:
- **Micro Optimized**: Best overall F1 score (67.3% micro F1)
- **Macro Optimized**: Balanced across categories (50.6% macro F1)
- **Balanced**: Good general-purpose setting
- **Overall**: Single adjustable threshold
- **Category-specific**: Fine-tune each category individually
""")

    col1, col2 = st.columns([1, 1.5])

    with col1:
        st.header("Image")

        upload_tab, batch_tab = st.tabs(["Upload Image", "Batch Processing"])

        image_path = None

        with upload_tab:
            uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

            if uploaded_file:
                with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
                    tmp_file.write(uploaded_file.getvalue())
                    image_path = tmp_file.name

                st.session_state.original_filename = uploaded_file.name

                image = Image.open(uploaded_file)
                st.image(image, use_container_width=True)

        with batch_tab:
            st.subheader("Batch Process Images")

            batch_folder = st.text_input("Enter folder path containing images:", "")

            save_options = st.radio(
                "Where to save tag files:",
                ["Same folder as images", "Custom location", "Default save folder"],
                index=0
            )

            st.subheader("Performance Options")
            batch_size = st.number_input("Batch size", min_value=1, max_value=32, value=4,
                                         help="Higher values may improve speed but use more memory")

            enable_category_limits = st.checkbox("Limit tags per category in batch output", value=False)

            if enable_category_limits and hasattr(st.session_state, 'categories'):
                if 'category_limits' not in st.session_state:
                    st.session_state.category_limits = {}

                st.markdown("**Limit Values:** -1 = no limit, 0 = exclude, N = top N tags")

                limit_cols = st.columns(2)
                for i, category in enumerate(sorted(st.session_state.categories)):
                    col_idx = i % 2
                    with limit_cols[col_idx]:
                        current_limit = st.session_state.category_limits.get(category, -1)
                        new_limit = st.number_input(
                            f"{category.capitalize()}:",
                            value=current_limit,
                            min_value=-1,
                            step=1,
                            key=f"limit_{category}"
                        )
                        st.session_state.category_limits[category] = new_limit

            if batch_folder and os.path.isdir(batch_folder):
                image_files = []
                for ext in ['*.jpg', '*.jpeg', '*.png']:
                    image_files.extend(glob.glob(os.path.join(batch_folder, ext)))
                    image_files.extend(glob.glob(os.path.join(batch_folder, ext.upper())))

                if image_files:
                    st.write(f"Found {len(image_files)} images")

                    if st.button("🔄 Process All Images", type="primary"):
                        if not st.session_state.model_loaded:
                            st.error("Model not loaded")
                        else:
                            with st.spinner("Processing images..."):
                                progress_bar = st.progress(0)
                                status_text = st.empty()

                                def update_progress(current, total, image_path):
                                    progress = current / total if total > 0 else 0
                                    progress_bar.progress(progress)
                                    status_text.text(f"Processing {current}/{total}: {os.path.basename(image_path) if image_path else 'Complete'}")

                                if save_options == "Same folder as images":
                                    save_dir = batch_folder
                                elif save_options == "Custom location":
                                    save_dir = st.text_input("Custom save directory:", batch_folder)
                                else:
                                    save_dir = os.path.join(os.path.dirname(__file__), "saved_tags")
                                    os.makedirs(save_dir, exist_ok=True)

                                category_limits = st.session_state.category_limits if enable_category_limits else None

                                if st.session_state.model_type == "onnx":
                                    batch_results = batch_process_images_onnx(
                                        folder_path=batch_folder,
                                        model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
                                        metadata_path=os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json"),
                                        threshold_profile=st.session_state.settings['threshold_profile'],
                                        active_threshold=st.session_state.settings['active_threshold'],
                                        active_category_thresholds=st.session_state.settings['active_category_thresholds'],
                                        save_dir=save_dir,
                                        progress_callback=update_progress,
                                        min_confidence=st.session_state.settings['min_confidence'],
                                        batch_size=batch_size,
                                        category_limits=category_limits
                                    )
                                else:
                                    st.error("SafeTensors batch processing not implemented yet")
                                    batch_results = None

                                if batch_results:
                                    display_batch_results(batch_results)

    with col2:
        st.header("Tagging Controls")

        all_profiles = [
            "Micro Optimized",
            "Macro Optimized",
            "Balanced",
            "Overall",
            "Category-specific"
        ]

        profile_col1, profile_col2 = st.columns([3, 1])

        with profile_col1:
            threshold_profile = st.selectbox(
                "Select threshold profile",
                options=all_profiles,
                index=1,
                key="threshold_profile",
                on_change=on_threshold_profile_change
            )

        with profile_col2:
            if st.button("ℹ️ Help", key="profile_help"):
                st.session_state.show_profile_help = not st.session_state.get('show_profile_help', False)

        if st.session_state.get('show_profile_help', False):
            st.markdown(threshold_profile_explanations[threshold_profile])
        else:
            st.info(threshold_profile_descriptions[threshold_profile])

        if st.session_state.model_loaded:
            metrics = get_profile_metrics(st.session_state.thresholds, threshold_profile)

            if metrics:
                metrics_cols = st.columns(3)

                with metrics_cols[0]:
                    st.metric("Threshold", f"{metrics['threshold']:.3f}")

                with metrics_cols[1]:
                    st.metric("Micro F1", f"{metrics['micro_f1']:.1f}%")

                with metrics_cols[2]:
                    st.metric("Macro F1", f"{metrics['macro_f1']:.1f}%")

        if st.session_state.model_loaded:
            active_threshold = st.session_state.settings.get('active_threshold', 0.5)
            active_category_thresholds = st.session_state.settings.get('active_category_thresholds', {})

            if threshold_profile in ["Micro Optimized", "Macro Optimized", "Balanced"]:
                st.slider(
                    "Threshold (from validation)",
                    min_value=0.01,
                    max_value=1.0,
                    value=float(active_threshold),
                    step=0.01,
                    disabled=True,
                    help="This threshold is optimized from validation results"
                )

            elif threshold_profile == "Overall":
                active_threshold = st.slider(
                    "Overall threshold",
                    min_value=0.01,
                    max_value=1.0,
                    value=float(active_threshold),
                    step=0.01
                )
                st.session_state.settings['active_threshold'] = active_threshold

            elif threshold_profile == "Category-specific":
                st.slider(
                    "Overall threshold (reference)",
                    min_value=0.01,
                    max_value=1.0,
                    value=float(active_threshold),
                    step=0.01,
                    disabled=True
                )

                st.write("Adjust thresholds for individual categories:")

                slider_cols = st.columns(2)

                if not active_category_thresholds:
                    active_category_thresholds = {}

                for i, category in enumerate(sorted(st.session_state.categories)):
                    col_idx = i % 2
                    with slider_cols[col_idx]:
                        default_val = active_category_thresholds.get(category, active_threshold)
                        new_threshold = st.slider(
                            f"{category.capitalize()}",
                            min_value=0.01,
                            max_value=1.0,
                            value=float(default_val),
                            step=0.01,
                            key=f"slider_{category}"
                        )
                        active_category_thresholds[category] = new_threshold

                st.session_state.settings['active_category_thresholds'] = active_category_thresholds

        with st.expander("Display Options", expanded=False):
            # Use distinct names for these inner columns so the main layout
            # columns (col1/col2) defined above are not shadowed.
            opt_col1, opt_col2 = st.columns(2)
            with opt_col1:
                show_all_tags = st.checkbox("Show all tags (including below threshold)",
                                            value=st.session_state.settings['show_all_tags'])
                compact_view = st.checkbox("Compact view (hide progress bars)",
                                           value=st.session_state.settings['compact_view'])
                replace_underscores = st.checkbox("Replace underscores with spaces",
                                                  value=st.session_state.settings.get('replace_underscores', False))

            with opt_col2:
                min_confidence = st.slider("Minimum confidence to display", 0.0, 0.5,
                                           st.session_state.settings['min_confidence'], 0.01)

            st.session_state.settings.update({
                'show_all_tags': show_all_tags,
                'compact_view': compact_view,
                'min_confidence': min_confidence,
                'replace_underscores': replace_underscores
            })

            st.write("Categories to include in 'All Tags' section:")

            category_cols = st.columns(3)
            selected_categories = {}

            if hasattr(st.session_state, 'categories'):
                for i, category in enumerate(sorted(st.session_state.categories)):
                    col_idx = i % 3
                    with category_cols[col_idx]:
                        default_val = st.session_state.settings['selected_categories'].get(category, True)
                        selected_categories[category] = st.checkbox(
                            f"{category.capitalize()}",
                            value=default_val,
                            key=f"cat_select_{category}"
                        )

                st.session_state.settings['selected_categories'] = selected_categories

        if image_path and st.button("Run Tagging"):
            if not st.session_state.model_loaded:
                st.error("Model not loaded")
            else:
                with st.spinner("Analyzing image..."):
                    try:
                        if st.session_state.model_type == "onnx":
                            from utils.onnx_processing import process_single_image_onnx

                            result = process_single_image_onnx(
                                image_path=image_path,
                                model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"),
                                metadata=st.session_state.metadata,
                                threshold_profile=threshold_profile,
                                active_threshold=st.session_state.settings['active_threshold'],
                                active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
                                min_confidence=st.session_state.settings['min_confidence']
                            )
                        else:
                            result = process_image(
                                image_path=image_path,
                                model=st.session_state.model,
                                thresholds=st.session_state.thresholds,
                                metadata=st.session_state.metadata,
                                threshold_profile=threshold_profile,
                                active_threshold=st.session_state.settings['active_threshold'],
                                active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}),
                                min_confidence=st.session_state.settings['min_confidence']
                            )

                        if result['success']:
                            st.session_state.all_probs = result['all_probs']
                            st.session_state.tags = result['tags']
                            st.session_state.all_tags = result['all_tags']
                            st.success("Analysis completed!")
                        else:
                            st.error(f"Analysis failed: {result.get('error', 'Unknown error')}")

                    except Exception as e:
                        st.error(f"Error during analysis: {str(e)}")
                        st.code(traceback.format_exc())

        if image_path and hasattr(st.session_state, 'all_probs'):
            st.header("Predictions")

            # Re-apply the current thresholds so the display reflects any slider
            # changes made after the image was analyzed.
            filtered_tags, current_all_tags = apply_thresholds(
                st.session_state.all_probs,
                threshold_profile,
                st.session_state.settings['active_threshold'],
                st.session_state.settings.get('active_category_thresholds', {}),
                st.session_state.settings['min_confidence'],
                st.session_state.settings['selected_categories']
            )

            all_tags = []

            for category in sorted(st.session_state.all_probs.keys()):
                all_tags_in_category = st.session_state.all_probs.get(category, [])
                filtered_tags_in_category = filtered_tags.get(category, [])

                if all_tags_in_category:
                    expander_label = f"{category.capitalize()} ({len(filtered_tags_in_category)} tags)"

                    with st.expander(expander_label, expanded=True):
                        active_category_thresholds = st.session_state.settings.get('active_category_thresholds') or {}
                        threshold = active_category_thresholds.get(category, st.session_state.settings['active_threshold'])

                        if st.session_state.settings['show_all_tags']:
                            tags_to_display = all_tags_in_category
                        else:
                            tags_to_display = [(tag, prob) for tag, prob in all_tags_in_category if prob >= threshold]

                        if not tags_to_display:
                            st.info(f"No tags above {st.session_state.settings['min_confidence']:.2f} confidence")
                            continue

                        if st.session_state.settings['compact_view']:
                            # Compact view: a single comma-separated line per category.
                            tag_list = []
                            replace_underscores = st.session_state.settings.get('replace_underscores', False)

                            for tag, prob in tags_to_display:
                                percentage = int(prob * 100)
                                display_tag = tag.replace('_', ' ') if replace_underscores else tag
                                tag_list.append(f"{display_tag} ({percentage}%)")

                                if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
                                    all_tags.append(tag)

                            st.markdown(", ".join(tag_list))
                        else:
                            # Detailed view: one row per tag with a confidence bar.
                            for tag, prob in tags_to_display:
                                replace_underscores = st.session_state.settings.get('replace_underscores', False)
                                display_tag = tag.replace('_', ' ') if replace_underscores else tag

                                if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True):
                                    all_tags.append(tag)
                                    tag_display = f"**{display_tag}**"
                                else:
                                    tag_display = display_tag

                                st.write(tag_display)
                                st.markdown(display_progress_bar(prob), unsafe_allow_html=True)

            st.markdown("---")
            st.subheader(f"All Tags ({len(all_tags)} total)")
            if all_tags:
                replace_underscores = st.session_state.settings.get('replace_underscores', False)
                if replace_underscores:
                    display_tags = [tag.replace('_', ' ') for tag in all_tags]
                    st.write(", ".join(display_tags))
                else:
                    st.write(", ".join(all_tags))
            else:
                st.info("No tags detected above the threshold.")

            st.markdown("---")
            st.subheader("Save Tags")

            if 'custom_folders' not in st.session_state:
                st.session_state.custom_folders = get_default_save_locations()

            selected_folder = st.selectbox(
                "Select save location:",
                options=st.session_state.custom_folders,
                format_func=lambda x: os.path.basename(x) if os.path.basename(x) else x
            )

            if st.button("💾 Save to Selected Location"):
                try:
                    original_filename = st.session_state.original_filename if hasattr(st.session_state, 'original_filename') else None

                    saved_path = save_tags_to_file(
                        image_path=image_path,
                        all_tags=all_tags,
                        original_filename=original_filename,
                        custom_dir=selected_folder,
                        overwrite=True
                    )

                    st.success(f"Tags saved to: {os.path.basename(saved_path)}")
                    st.info(f"Full path: {saved_path}")

                    with st.expander("File Contents", expanded=True):
                        with open(saved_path, 'r', encoding='utf-8') as f:
                            content = f.read()
                        st.code(content, language='text')

                except Exception as e:
                    st.error(f"Error saving tags: {str(e)}")


if __name__ == "__main__":
    image_tagger_app()