# Hugging Face Spaces page header ("Spaces: Sleeping") captured together with the file; not part of the program.
import gradio as gr | |
from PIL import Image | |
import onnxruntime as ort | |
import torchvision.transforms as transforms | |
import json | |
import os | |
import numpy as np | |
import pandas as pd | |
import random | |
from huggingface_hub import snapshot_download, HfApi | |
from transformers import CLIPTokenizer | |
# --- Config ---
# Hub repository (accessed with repo_type="dataset") holding the leaderboard, models and prompt CSV.
HUB_REPO_ID = "CDL-AMLRT/OpenArenaLeaderboard"
# Access token taken from the environment; None when the variable is unset,
# in which case save_leaderboard() skips the upload.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): LOCAL_JSON is not referenced anywhere in the visible code — confirm it is still needed.
LOCAL_JSON = "leaderboard.json"
# Leaderboard file name, used both for the local copy and as the path inside the Hub repo.
HUB_JSON = "leaderboard.json"
# ONNX model used by detect_with_model() to classify a submission as Real/Fake.
MODEL_PATH = "mobilenet_v2_fake_detector.onnx"
# ONNX encoders used by compute_prompt_match() for image/text similarity.
CLIP_IMAGE_ENCODER_PATH = "clip_image_encoder.onnx"
CLIP_TEXT_ENCODER_PATH = "clip_text_encoder.onnx"
# CSV expected to contain a "prompt" column (see load_prompts()).
PROMPT_CSV_PATH = "generate2_1.csv"
# Minimum image/prompt similarity (cosine similarity * 100) a submission must reach.
PROMPT_MATCH_THRESHOLD = 10 # percent
# --- Download leaderboard + model checkpoint from HF Hub ---
def load_assets():
    """Download the leaderboard, ONNX models and prompt CSV into the working directory.

    Best-effort: any failure is printed and swallowed, so later loaders have
    to cope with files that may be missing.
    """
    try:
        wanted_files = [
            HUB_JSON,
            MODEL_PATH,
            CLIP_IMAGE_ENCODER_PATH,
            CLIP_TEXT_ENCODER_PATH,
            PROMPT_CSV_PATH,
        ]
        request = {
            "repo_id": HUB_REPO_ID,
            "local_dir": ".",
            "repo_type": "dataset",
            "token": HF_TOKEN,
            # Restrict the snapshot to the files this app actually reads.
            "allow_patterns": wanted_files,
        }
        snapshot_download(**request)
    except Exception as e:
        print(f"Failed to load assets from HF Hub: {e}")


# Executed at import time so the files exist before the loaders below run.
load_assets()
# --- Load prompts from CSV ---
def load_prompts():
    """Return the non-null values of the CSV's "prompt" column as a list.

    An empty list is returned (and a message printed) when the file cannot be
    read or has no "prompt" column.
    """
    try:
        frame = pd.read_csv(PROMPT_CSV_PATH)
        # Guard clause: without the expected column there is nothing to offer.
        if "prompt" not in frame.columns:
            print("CSV missing 'prompt' column.")
            return []
        prompt_column = frame["prompt"]
        return prompt_column.dropna().tolist()
    except Exception as e:
        print(f"Failed to load prompts: {e}")
        return []


PROMPT_LIST = load_prompts()
def load_initial_state():
    """Build the values shown when the page loads.

    Returns:
        A pair ``(prompt_update, rows)``: a Gradio update carrying a random
        prompt, and the leaderboard as ``[username, score]`` rows ordered by
        descending score.
    """
    ranking = sorted(
        leaderboard_scores.items(),
        key=lambda entry: entry[1],
        reverse=True,
    )
    rows = [[player, total] for player, total in ranking]
    prompt_update = gr.update(value=get_random_prompt())
    return prompt_update, rows
# --- Load leaderboard ---
def load_leaderboard():
    """Read the leaderboard mapping from the local JSON file.

    Returns:
        The decoded ``{username: score}`` dict. An empty dict is returned (and
        a message printed) when the file is missing, unreadable, not valid
        JSON, or does not contain a JSON object.
    """
    try:
        # save_leaderboard() writes UTF-8 with ensure_ascii=False, so read it
        # back with the same encoding instead of the platform default.
        with open(HUB_JSON, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        print(f"Failed to read leaderboard: {e}")
        return {}
    # The rest of the app calls .items() / .get() on the result; anything that
    # is not a JSON object would only fail later, on the first submission.
    if not isinstance(data, dict):
        print(f"Failed to read leaderboard: expected a JSON object, got {type(data).__name__}")
        return {}
    return data


leaderboard_scores = load_leaderboard()
# --- Save and push to HF Hub ---
def save_leaderboard():
    """Write the in-memory leaderboard to disk and, when a token exists, upload it.

    The local JSON file is always rewritten first. The upload to the Hub is
    skipped with a message when HF_TOKEN is None. Every failure (local write
    or upload) is printed and swallowed.
    """
    try:
        with open(HUB_JSON, "w", encoding="utf-8") as out_file:
            json.dump(leaderboard_scores, out_file, ensure_ascii=False)

        if HF_TOKEN is not None:
            upload_request = {
                "path_or_fileobj": HUB_JSON,
                "path_in_repo": HUB_JSON,
                "repo_id": HUB_REPO_ID,
                "repo_type": "dataset",
                "token": HF_TOKEN,
                "commit_message": "Update leaderboard",
            }
            HfApi().upload_file(**upload_request)
        else:
            print("HF_TOKEN not set. Skipping push to hub.")
    except Exception as err:
        print(f"Failed to save leaderboard to HF Hub: {err}")
# --- Load ONNX models ---
def _cpu_session(model_file):
    """Open *model_file* with ONNX Runtime using only the CPU execution provider."""
    return ort.InferenceSession(model_file, providers=["CPUExecutionProvider"])


# Real/Fake detector and the name of its first input, used by detect_with_model().
session = _cpu_session(MODEL_PATH)
input_name = session.get_inputs()[0].name

# CLIP image and text encoders plus the tokenizer that prepares the text encoder's inputs.
clip_image_sess = _cpu_session(CLIP_IMAGE_ENCODER_PATH)
clip_text_sess = _cpu_session(CLIP_TEXT_ENCODER_PATH)
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Preprocessing applied to an image before the CLIP image encoder
# (compute_prompt_match() is the only user in the visible code).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)
def compute_prompt_match(image: Image.Image, prompt: str) -> float:
    """Return the image/prompt similarity as a percentage.

    Both inputs are embedded with the CLIP ONNX encoders, L2-normalised, and
    compared with a dot product (cosine similarity), which is then scaled by
    100 and rounded to two decimals.

    Args:
        image: PIL image to score; any mode is accepted.
        prompt: Text the image is supposed to depict.

    Returns:
        The similarity as a plain Python float, or 0.0 when anything in the
        pipeline fails (the error is printed).
    """
    try:
        # Normalize() is configured for three channels, so RGBA / greyscale /
        # palette uploads must be converted first; otherwise they raise here
        # and are silently scored 0.0.
        rgb_image = image.convert("RGB")
        img_tensor = transform(rgb_image).unsqueeze(0).numpy().astype(np.float32)
        image_input_name = clip_image_sess.get_inputs()[0].name
        image_features = clip_image_sess.run(None, {image_input_name: img_tensor})[0][0]
        image_features = image_features / np.linalg.norm(image_features)

        # Pad/truncate to the fixed length of 77 tokens requested from the tokenizer.
        inputs = clip_tokenizer(
            prompt,
            return_tensors="np",
            padding="max_length",
            truncation=True,
            max_length=77,
        )
        text_inputs = clip_text_sess.get_inputs()
        text_features = clip_text_sess.run(None, {
            text_inputs[0].name: inputs["input_ids"],
            text_inputs[1].name: inputs["attention_mask"],
        })[0][0]
        text_features = text_features / np.linalg.norm(text_features)

        sim = np.dot(image_features, text_features)
        # float(): honour the annotated return type instead of leaking np.float32.
        return float(round(sim * 100, 2))
    except Exception as e:
        print(f"CLIP ONNX match failed: {e}")
        return 0.0
# --- Main prediction logic ---
def _leaderboard_rows():
    """Return the leaderboard as [username, score] rows, highest score first."""
    ranked = sorted(leaderboard_scores.items(), key=lambda x: x[1], reverse=True)
    return [[name, points] for name, points in ranked]


def detect_with_model(image: Image.Image, prompt: str, username: str):
    """Score one submission and update the leaderboard.

    The submission is rejected when the name is blank, no image was uploaded,
    or the image/prompt similarity is below PROMPT_MATCH_THRESHOLD. Otherwise
    the detector is run; a "Real" verdict earns the player one point and the
    leaderboard is persisted.

    Args:
        image: Uploaded PIL image, or None when nothing was uploaded.
        prompt: Prompt the image is supposed to depict.
        username: Player name; surrounding whitespace is ignored for scoring.

    Returns:
        A 6-tuple matching the outputs wired to the submit button:
        (message, image to display, leaderboard rows, update for the input
        section, update for the "Try Again" button, username).
    """
    # NOTE(review): the symbols at the start of the user-facing messages below
    # look like mis-encoded emoji; they are kept as found — confirm against
    # the deployed file.

    def _rejected(message):
        # Keep the form visible and return the current leaderboard rows so a
        # validation failure does not blank the leaderboard table.
        return (
            message,
            None,
            _leaderboard_rows(),
            gr.update(visible=True),
            gr.update(visible=False),
            username,
        )

    player = (username or "").strip()
    if not player:
        return _rejected("Please enter your name.")
    if image is None:
        # Without this guard the CLIP step fails internally and the user is
        # told the prompt match is 0.0%, which is misleading.
        return _rejected("Please upload an image.")

    prompt_score = compute_prompt_match(image, prompt)
    if prompt_score < PROMPT_MATCH_THRESHOLD:
        return _rejected(
            f"β οΈ Prompt match too low ({round(prompt_score, 2)}%). Please generate an image that better matches the prompt."
        )

    # Detector input: 224x224, three channels, values in [0, 1]. Unlike the
    # CLIP path, no mean/std normalisation is applied here. Converting to RGB
    # keeps RGBA / greyscale uploads from reaching the model with the wrong
    # channel count.
    rgb_image = image.convert("RGB")
    image_tensor = transforms.Resize((224, 224))(rgb_image)
    image_tensor = transforms.ToTensor()(image_tensor).unsqueeze(0).numpy().astype(np.float32)
    outputs = session.run(None, {input_name: image_tensor})

    # Sigmoid of the first output value; the verdict is taken on the value
    # rounded to two decimals, exactly as before.
    prob = float(round(1 / (1 + np.exp(-outputs[0][0][0])), 2))
    prediction = "Real" if prob > 0.5 else "Fake"
    confidence = round((prob if prediction == "Real" else 1 - prob) * 100, 2)
    message = f"π Prediction: {prediction} ({confidence}% confidence)\nπ§ Prompt match: {prompt_score}%"

    if prediction == "Real":
        leaderboard_scores[player] = leaderboard_scores.get(player, 0) + 1
        message += "\nπ Nice! You fooled the AI. +1 point!"
        # Persist (and push to the Hub) only when a score actually changed.
        save_leaderboard()
    else:
        message += "\nπ The AI caught you this time. Try again!"

    return (
        message,
        image,
        _leaderboard_rows(),
        gr.update(visible=False),
        gr.update(visible=True),
        username,
    )
# --- UI Layout ---
def get_random_prompt():
    """Pick a random entry of PROMPT_LIST, or a fixed fallback prompt when the list is empty."""
    if not PROMPT_LIST:
        return "A synthetic scene with dramatic lighting"
    return random.choice(PROMPT_LIST)
# Gradio page: input form, result panel and leaderboard, plus the event wiring.
# NOTE(review): the symbols inside the Markdown strings below look like
# mis-encoded emoji/punctuation; they are kept as found — confirm against the
# deployed file.
# NOTE(review): the nesting was reconstructed from the visibility logic (the
# input group is hidden after a successful submission while "Try Again" is
# shown, so that button sits outside the group) — confirm the layout.
with gr.Blocks(css=".gr-button {font-size: 16px !important}") as demo:
    gr.Markdown("## π OpenFake Arena")
    gr.Markdown("Welcome to the OpenFake Arena!\n\n**Your mission:** Generate a synthetic image for the prompt, upload it, and try to fool the AI detector into thinking itβs real.\n\n**Rules:**\n- Only synthetic images allowed!\n- No cheating with real photos.\n\nMake it wild. Make it weird. Most of all β make it fun.")

    with gr.Group(visible=True) as input_section:
        username_input = gr.Textbox(label="Your Name", placeholder="Enter your name", interactive=True)
        # NOTE(review): model_input is collected but never passed to a handler.
        model_input = gr.Textbox(label="Model Used", placeholder="Name of the model used to generate the image", interactive=True)
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Prompt to use",
                placeholder="e.g., ...",
                value="",
                lines=2
            )
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Synthetic Image")
        with gr.Row():
            submit_btn = gr.Button("Upload")

    try_again_btn = gr.Button("Try Again", visible=False)

    with gr.Group():
        gr.Markdown("### π― Result")
        with gr.Row():
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            image_output = gr.Image(label="Submitted Image", show_label=False)

    with gr.Group():
        gr.Markdown("### π Leaderboard")
        leaderboard = gr.Dataframe(
            headers=["Username", "Score"],
            datatype=["str", "number"],
            interactive=False,
            row_count=5,
            visible=True
        )

    def _reset_for_retry(name):
        """Clear the result panel, show the form again and draw a new prompt."""
        return (
            "",
            None,
            gr.update(visible=True),
            gr.update(visible=False),
            name,
            gr.update(value=get_random_prompt()),
        )

    submit_btn.click(
        fn=detect_with_model,
        inputs=[image_input, prompt_input, username_input],
        outputs=[
            prediction_output,
            image_output,
            leaderboard,
            input_section,
            try_again_btn,
            username_input
        ]
    )

    # The leaderboard is intentionally not an output here: the previous
    # handler wrote an empty list into it, wiping the table on "Try Again".
    try_again_btn.click(
        fn=_reset_for_retry,
        inputs=[username_input],
        outputs=[
            prediction_output,
            image_output,
            input_section,
            try_again_btn,
            username_input,
            prompt_input
        ]
    )

    # Populate the prompt box and the leaderboard when the page is opened.
    demo.load(
        fn=load_initial_state,
        outputs=[prompt_input, leaderboard]
    )

if __name__ == "__main__":
    demo.launch()