Akshatha Arodi commited on
Commit
db7977e
Β·
1 Parent(s): 78a0909
Files changed (1) hide show
  1. app.py +68 -55
app.py CHANGED
@@ -1,62 +1,35 @@
1
  import gradio as gr
2
  from PIL import Image
3
- import onnxruntime as ort
4
  import torchvision.transforms as transforms
5
  import json
6
  import os
7
  import numpy as np
8
  import pandas as pd
9
  import random
10
- from huggingface_hub import snapshot_download, HfApi
11
- from transformers import CLIPTokenizer
 
 
12
 
13
  # --- Config ---
14
  HUB_REPO_ID = "CDL-AMLRT/OpenArenaLeaderboard"
15
  HF_TOKEN = os.environ.get("HF_TOKEN")
16
  LOCAL_JSON = "leaderboard.json"
17
  HUB_JSON = "leaderboard.json"
18
- MODEL_PATH = "mobilenet_v2_fake_detector.onnx"
 
19
  CLIP_IMAGE_ENCODER_PATH = "clip_image_encoder.onnx"
20
  CLIP_TEXT_ENCODER_PATH = "clip_text_encoder.onnx"
21
  PROMPT_CSV_PATH = "generate2_1.csv"
22
  PROMPT_MATCH_THRESHOLD = 10 # percent
23
 
24
- # --- Download leaderboard + model checkpoint from HF Hub ---
25
  def load_assets():
26
- try:
27
- snapshot_download(
28
- repo_id=HUB_REPO_ID,
29
- local_dir=".",
30
- repo_type="dataset",
31
- token=HF_TOKEN,
32
- allow_patterns=[HUB_JSON, MODEL_PATH, CLIP_IMAGE_ENCODER_PATH, CLIP_TEXT_ENCODER_PATH, PROMPT_CSV_PATH]
33
- )
34
- except Exception as e:
35
- print(f"Failed to load assets from HF Hub: {e}")
36
 
37
  load_assets()
38
 
39
- # --- Load prompts from CSV ---
40
- def load_prompts():
41
- try:
42
- df = pd.read_csv(PROMPT_CSV_PATH)
43
- if "prompt" in df.columns:
44
- return df["prompt"].dropna().tolist()
45
- else:
46
- print("CSV missing 'prompt' column.")
47
- return []
48
- except Exception as e:
49
- print(f"Failed to load prompts: {e}")
50
- return []
51
-
52
- PROMPT_LIST = load_prompts()
53
-
54
- def load_initial_state():
55
- sorted_scores = sorted(leaderboard_scores.items(), key=lambda x: x[1], reverse=True)
56
- leaderboard_table = [[name, points] for name, points in sorted_scores]
57
- return gr.update(value=get_random_prompt()), leaderboard_table
58
-
59
-
60
  # --- Load leaderboard ---
61
  def load_leaderboard():
62
  try:
@@ -66,10 +39,8 @@ def load_leaderboard():
66
  print(f"Failed to read leaderboard: {e}")
67
  return {}
68
 
69
-
70
  leaderboard_scores = load_leaderboard()
71
 
72
- # --- Save and push to HF Hub ---
73
  def save_leaderboard():
74
  try:
75
  with open(HUB_JSON, "w", encoding="utf-8") as f:
@@ -91,10 +62,33 @@ def save_leaderboard():
91
  except Exception as e:
92
  print(f"Failed to save leaderboard to HF Hub: {e}")
93
 
94
- # --- Load ONNX models ---
95
- session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
96
- input_name = session.get_inputs()[0].name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
 
 
 
 
 
98
  clip_image_sess = ort.InferenceSession(CLIP_IMAGE_ENCODER_PATH, providers=["CPUExecutionProvider"])
99
  clip_text_sess = ort.InferenceSession(CLIP_TEXT_ENCODER_PATH, providers=["CPUExecutionProvider"])
100
  clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
@@ -134,19 +128,22 @@ def detect_with_model(image: Image.Image, prompt: str, username: str):
134
  prompt_score = compute_prompt_match(image, prompt)
135
  if prompt_score < PROMPT_MATCH_THRESHOLD:
136
  message = f"⚠️ Prompt match too low ({round(prompt_score, 2)}%). Please generate an image that better matches the prompt."
137
- return message, None, [], gr.update(visible=True), gr.update(visible=False), username
138
 
139
- image_tensor = transforms.Resize((224, 224))(image)
140
- image_tensor = transforms.ToTensor()(image_tensor).unsqueeze(0).numpy().astype(np.float32)
141
- outputs = session.run(None, {input_name: image_tensor})
142
- prob = round(1 / (1 + np.exp(-outputs[0][0][0])), 2)
143
- prediction = "Real" if prob > 0.5 else "Fake"
 
 
144
 
145
- score = 1 if prediction == "Real" else 0
146
- confidence = round(prob * 100, 2) if prediction == "Real" else round((1 - prob) * 100, 2)
147
 
148
- message = f"πŸ” Prediction: {prediction} ({round(confidence, 2)}% confidence)\n🧐 Prompt match: {prompt_score}%"
149
 
 
150
  if prediction == "Real":
151
  leaderboard_scores[username] = leaderboard_scores.get(username, 0) + score
152
  message += "\nπŸŽ‰ Nice! You fooled the AI. +1 point!"
@@ -167,17 +164,22 @@ def detect_with_model(image: Image.Image, prompt: str, username: str):
167
  username
168
  )
169
 
170
- # --- UI Layout ---
171
  def get_random_prompt():
172
  return random.choice(PROMPT_LIST) if PROMPT_LIST else "A synthetic scene with dramatic lighting"
173
 
 
 
 
 
 
 
174
  with gr.Blocks(css=".gr-button {font-size: 16px !important}") as demo:
175
  gr.Markdown("## 🌝 OpenFake Arena")
176
  gr.Markdown("Welcome to the OpenFake Arena!\n\n**Your mission:** Generate a synthetic image for the prompt, upload it, and try to fool the AI detector into thinking it’s real.\n\n**Rules:**\n- Only synthetic images allowed!\n- No cheating with real photos.\n\nMake it wild. Make it weird. Most of all β€” make it fun.")
177
 
178
  with gr.Group(visible=True) as input_section:
179
  username_input = gr.Textbox(label="Your Name", placeholder="Enter your name", interactive=True)
180
- model_input = gr.Textbox(label="Model Used", placeholder="Name of tge model used to generate the image", interactive=True)
181
 
182
  with gr.Row():
183
  prompt_input = gr.Textbox(
@@ -225,7 +227,16 @@ with gr.Blocks(css=".gr-button {font-size: 16px !important}") as demo:
225
  )
226
 
227
  try_again_btn.click(
228
- fn=lambda name: ("", None, [], gr.update(visible=True), gr.update(visible=False), name, gr.update(value=get_random_prompt())),
 
 
 
 
 
 
 
 
 
229
  inputs=[username_input],
230
  outputs=[
231
  prediction_output,
@@ -234,14 +245,16 @@ with gr.Blocks(css=".gr-button {font-size: 16px !important}") as demo:
234
  input_section,
235
  try_again_btn,
236
  username_input,
237
- prompt_input
 
238
  ]
239
  )
240
 
 
241
  demo.load(
242
  fn=load_initial_state,
243
  outputs=[prompt_input, leaderboard]
244
- )
245
 
246
  if __name__ == "__main__":
247
  demo.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
+ import torch
4
  import torchvision.transforms as transforms
5
  import json
6
  import os
7
  import numpy as np
8
  import pandas as pd
9
  import random
10
+ import onnxruntime as ort
11
+ from huggingface_hub import HfApi
12
+ from transformers import CLIPTokenizer, AutoImageProcessor, AutoModelForImageClassification
13
+ from safetensors.torch import load_file as safe_load
14
 
15
  # --- Config ---
16
  HUB_REPO_ID = "CDL-AMLRT/OpenArenaLeaderboard"
17
  HF_TOKEN = os.environ.get("HF_TOKEN")
18
  LOCAL_JSON = "leaderboard.json"
19
  HUB_JSON = "leaderboard.json"
20
+ MODEL_PATH = "model.safetensors" # βœ… updated filename
21
+ MODEL_BACKBONE = "microsoft/swinv2-small-patch4-window16-256"
22
  CLIP_IMAGE_ENCODER_PATH = "clip_image_encoder.onnx"
23
  CLIP_TEXT_ENCODER_PATH = "clip_text_encoder.onnx"
24
  PROMPT_CSV_PATH = "generate2_1.csv"
25
  PROMPT_MATCH_THRESHOLD = 10 # percent
26
 
27
+ # --- No-op for HF Space ---
28
  def load_assets():
29
+ print("Skipping snapshot_download. Assuming files exist via Git LFS in HF Space.")
 
 
 
 
 
 
 
 
 
30
 
31
  load_assets()
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # --- Load leaderboard ---
34
  def load_leaderboard():
35
  try:
 
39
  print(f"Failed to read leaderboard: {e}")
40
  return {}
41
 
 
42
  leaderboard_scores = load_leaderboard()
43
 
 
44
  def save_leaderboard():
45
  try:
46
  with open(HUB_JSON, "w", encoding="utf-8") as f:
 
62
  except Exception as e:
63
  print(f"Failed to save leaderboard to HF Hub: {e}")
64
 
65
+ # --- Load prompts from CSV ---
66
+ def load_prompts():
67
+ try:
68
+ df = pd.read_csv(PROMPT_CSV_PATH)
69
+ if "prompt" in df.columns:
70
+ return df["prompt"].dropna().tolist()
71
+ else:
72
+ print("CSV missing 'prompt' column.")
73
+ return []
74
+ except Exception as e:
75
+ print(f"Failed to load prompts: {e}")
76
+ return []
77
+
78
+ PROMPT_LIST = load_prompts()
79
+
80
+ # --- Load model + processor ---
81
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
+
83
+ processor = AutoImageProcessor.from_pretrained(MODEL_BACKBONE)
84
+ model = AutoModelForImageClassification.from_pretrained(MODEL_BACKBONE)
85
+ model.classifier = torch.nn.Linear(model.config.hidden_size, 2)
86
 
87
+ model.load_state_dict(safe_load(MODEL_PATH, device="cpu"), strict=False)
88
+ model.to(device)
89
+ model.eval()
90
+
91
+ # --- CLIP prompt matching ---
92
  clip_image_sess = ort.InferenceSession(CLIP_IMAGE_ENCODER_PATH, providers=["CPUExecutionProvider"])
93
  clip_text_sess = ort.InferenceSession(CLIP_TEXT_ENCODER_PATH, providers=["CPUExecutionProvider"])
94
  clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
 
128
  prompt_score = compute_prompt_match(image, prompt)
129
  if prompt_score < PROMPT_MATCH_THRESHOLD:
130
  message = f"⚠️ Prompt match too low ({round(prompt_score, 2)}%). Please generate an image that better matches the prompt."
131
+ return message, None, leaderboard, gr.update(visible=True), gr.update(visible=False), username
132
 
133
+ # Run model inference
134
+ inputs = processor(image, return_tensors="pt").to(device)
135
+ with torch.no_grad():
136
+ outputs = model(**inputs)
137
+ logits = outputs.logits
138
+ pred_class = torch.argmax(logits, dim=-1).item()
139
+ prediction = "Real" if pred_class == 0 else "Fake"
140
 
141
+ probs = torch.softmax(logits, dim=-1)[0]
142
+ confidence = round(probs[pred_class].item() * 100, 2)
143
 
144
+ score = 1 if prediction == "Real" else 0
145
 
146
+ message = f"πŸ” Prediction: {prediction} ({confidence}% confidence)\n🧐 Prompt match: {round(prompt_score, 2)}%"
147
  if prediction == "Real":
148
  leaderboard_scores[username] = leaderboard_scores.get(username, 0) + score
149
  message += "\nπŸŽ‰ Nice! You fooled the AI. +1 point!"
 
164
  username
165
  )
166
 
 
167
  def get_random_prompt():
168
  return random.choice(PROMPT_LIST) if PROMPT_LIST else "A synthetic scene with dramatic lighting"
169
 
170
+ def load_initial_state():
171
+ sorted_scores = sorted(leaderboard_scores.items(), key=lambda x: x[1], reverse=True)
172
+ leaderboard_table = [[name, points] for name, points in sorted_scores]
173
+ return gr.update(value=get_random_prompt()), leaderboard_table
174
+
175
+ # --- Gradio UI ---
176
  with gr.Blocks(css=".gr-button {font-size: 16px !important}") as demo:
177
  gr.Markdown("## 🌝 OpenFake Arena")
178
  gr.Markdown("Welcome to the OpenFake Arena!\n\n**Your mission:** Generate a synthetic image for the prompt, upload it, and try to fool the AI detector into thinking it’s real.\n\n**Rules:**\n- Only synthetic images allowed!\n- No cheating with real photos.\n\nMake it wild. Make it weird. Most of all β€” make it fun.")
179
 
180
  with gr.Group(visible=True) as input_section:
181
  username_input = gr.Textbox(label="Your Name", placeholder="Enter your name", interactive=True)
182
+ model_input = gr.Textbox(label="Model Used", placeholder="Name of the model used to generate the image", interactive=True)
183
 
184
  with gr.Row():
185
  prompt_input = gr.Textbox(
 
227
  )
228
 
229
  try_again_btn.click(
230
+ fn=lambda name: (
231
+ "", # Clear prediction text
232
+ None, # Clear uploaded image
233
+ leaderboard, # Clear leaderboard (temporarily, gets reloaded on next submit)
234
+ gr.update(visible=True), # Show input section
235
+ gr.update(visible=False), # Hide "Try Again" button
236
+ name, # Keep username
237
+ gr.update(value=get_random_prompt()), # Load new prompt
238
+ None # Clear image input
239
+ ),
240
  inputs=[username_input],
241
  outputs=[
242
  prediction_output,
 
245
  input_section,
246
  try_again_btn,
247
  username_input,
248
+ prompt_input,
249
+ image_input # ← added output to clear image
250
  ]
251
  )
252
 
253
+
254
  demo.load(
255
  fn=load_initial_state,
256
  outputs=[prompt_input, leaderboard]
257
+ )
258
 
259
  if __name__ == "__main__":
260
  demo.launch()