mduppes commited on
Commit
cca4f79
·
1 Parent(s): db54d0f

Add dataset selector for examples

Browse files
backend/app.py CHANGED
@@ -1,32 +1,27 @@
1
- from backend.config import (
2
- ABS_DATASET_DOMAIN,
3
- get_dataset_config,
4
- get_datasets,
5
- )
 
 
 
 
 
 
6
  from backend.descriptions import (
7
  DATASET_DESCRIPTIONS,
8
  DESCRIPTIONS,
9
  METRIC_DESCRIPTIONS,
10
  MODEL_DESCRIPTIONS,
11
  )
12
- from backend.examples import (
13
- get_examples_tab,
14
- )
15
- from flask import Flask, Response, send_from_directory, request
16
  from flask_cors import CORS
17
- import os
18
- import logging
19
- import pandas as pd
20
- import json
21
- from io import StringIO
22
  from tools import (
23
  get_leaderboard_filters,
24
  get_old_format_dataframe,
25
  ) # Import your function
26
- import typing as tp
27
- import requests
28
- from urllib.parse import unquote
29
- import mimetypes
30
 
31
 
32
  logger = logging.getLogger(__name__)
@@ -110,9 +105,17 @@ def example_files(type):
110
  """
111
  Serve example files from S3 or locally based on config
112
  """
 
 
113
 
114
- result = get_examples_tab(type)
115
- return Response(json.dumps(result), mimetype="application/json")
 
 
 
 
 
 
116
 
117
 
118
  @app.route("/descriptions")
 
1
+ import json
2
+ import logging
3
+ import mimetypes
4
+ import os
5
+ import typing as tp
6
+ from io import StringIO
7
+ from urllib.parse import unquote
8
+
9
+ import pandas as pd
10
+ import requests
11
+ from backend.config import ABS_DATASET_DOMAIN, get_dataset_config, get_datasets
12
  from backend.descriptions import (
13
  DATASET_DESCRIPTIONS,
14
  DESCRIPTIONS,
15
  METRIC_DESCRIPTIONS,
16
  MODEL_DESCRIPTIONS,
17
  )
18
+ from backend.examples import get_examples_tab
19
+ from flask import Flask, request, Response, send_from_directory
 
 
20
  from flask_cors import CORS
 
 
 
 
 
21
  from tools import (
22
  get_leaderboard_filters,
23
  get_old_format_dataframe,
24
  ) # Import your function
 
 
 
 
25
 
26
 
27
  logger = logging.getLogger(__name__)
 
105
  """
106
  Serve example files from S3 or locally based on config
107
  """
108
+ # Get dataset parameter from query string
109
+ dataset_name = request.args.get("dataset")
110
 
111
+ if not dataset_name:
112
+ return {"error": "Dataset parameter is required"}, 400
113
+
114
+ try:
115
+ result = get_examples_tab(type, dataset_name)
116
+ return Response(json.dumps(result), mimetype="application/json")
117
+ except ValueError as e:
118
+ return {"error": str(e)}, 400
119
 
120
 
121
  @app.route("/descriptions")
backend/config.py CHANGED
@@ -2,8 +2,6 @@
2
  # IMPORTANT: When running from docker more setup is required (e.g. on Huggingface)
3
  import os
4
  from collections import defaultdict
5
- from copy import deepcopy
6
- from typing import Any, Dict
7
 
8
  ABS_DATASET_DOMAIN = "https://dl.fbaipublicfiles.com"
9
 
@@ -127,7 +125,7 @@ MODALITY_CONFIG_CONSTANTS = {
127
  "H264rgb",
128
  "H265",
129
  ],
130
- }
131
  }
132
 
133
  DATASET_CONFIGS = {
@@ -139,30 +137,6 @@ DATASET_CONFIGS = {
139
  }
140
 
141
 
142
- EXAMPLE_CONFIGS = {
143
- "audio": {
144
- "dataset_name": "voxpopuli_1k",
145
- "path": ABS_DATASET_PATH,
146
- "db_key": "voxpopuli",
147
- },
148
- # "image": {
149
- # "dataset_name": "val2014_1k_v2",
150
- # "path": ABS_DATASET_PATH,
151
- # "db_key": "local_val2014",
152
- # },
153
- "image": {
154
- "dataset_name": "sa_1b_val_1k",
155
- "path": ABS_DATASET_PATH,
156
- "db_key": "local_valid",
157
- },
158
- "video": {
159
- "dataset_name": "sav_val_full_v2",
160
- "path": ABS_DATASET_PATH,
161
- "db_key": "sa-v_sav_val_videos",
162
- },
163
- }
164
-
165
-
166
  def get_user_dataset():
167
  datasets = defaultdict(list)
168
  user_data_dir = os.getenv("OMNISEAL_LEADERBOARD_DATA", "./data")
@@ -170,7 +144,9 @@ def get_user_dataset():
170
  for user_data in os.listdir(user_data_dir):
171
  if not os.path.isdir(os.path.join(user_data_dir, user_data)):
172
  continue
173
- user_dtype = os.listdir(os.path.join(user_data_dir, user_data, "examples"))[0]
 
 
174
  datasets[user_dtype].append(user_data + "/" + user_dtype)
175
 
176
  return datasets
@@ -192,28 +168,59 @@ def get_datasets():
192
  return grouped
193
 
194
 
195
- def get_example_config(type):
196
- if type not in EXAMPLE_CONFIGS:
197
- raise ValueError(f"Unknown example type: {type}")
 
 
 
 
 
 
198
 
199
- examples_config: Dict[str, Any] = deepcopy(EXAMPLE_CONFIGS[type])
 
200
 
 
201
  user_datasets = get_user_dataset()
202
  user_data_dir = os.getenv("OMNISEAL_LEADERBOARD_DATA", "./data")
203
- if len(user_datasets) > 0:
204
- assert user_data_dir, f"OMNISEAL_LEADERBOARD_DATA is reset during loading the examples for {type}. Please set it correctly"
205
- for dtype, user_names in user_datasets.items():
206
- if dtype == type:
207
- dataset_name = user_names[0].split("/")[0]
208
- path = user_data_dir + "/"
209
- examples_config = {
210
- "dataset_name": dataset_name,
211
- "path": path,
212
- "db_key": dataset_name,
213
- }
 
 
 
 
 
 
 
 
 
214
  return examples_config
215
 
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  def get_dataset_config(dataset_name):
218
  if dataset_name in DATASET_CONFIGS:
219
  cfg = DATASET_CONFIGS[dataset_name]
 
2
  # IMPORTANT: When running from docker more setup is required (e.g. on Huggingface)
3
  import os
4
  from collections import defaultdict
 
 
5
 
6
  ABS_DATASET_DOMAIN = "https://dl.fbaipublicfiles.com"
7
 
 
125
  "H264rgb",
126
  "H265",
127
  ],
128
+ },
129
  }
130
 
131
  DATASET_CONFIGS = {
 
137
  }
138
 
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  def get_user_dataset():
141
  datasets = defaultdict(list)
142
  user_data_dir = os.getenv("OMNISEAL_LEADERBOARD_DATA", "./data")
 
144
  for user_data in os.listdir(user_data_dir):
145
  if not os.path.isdir(os.path.join(user_data_dir, user_data)):
146
  continue
147
+ user_dtype = os.listdir(os.path.join(user_data_dir, user_data, "examples"))[
148
+ 0
149
+ ]
150
  datasets[user_dtype].append(user_data + "/" + user_dtype)
151
 
152
  return datasets
 
168
  return grouped
169
 
170
 
171
+ def get_example_config(type, dataset_name):
172
+ """Get example configuration for a specific dataset."""
173
+ if not dataset_name:
174
+ raise ValueError(f"Dataset name is required")
175
+
176
+ # Check if it's a valid dataset for this type
177
+ all_datasets = get_datasets()
178
+ if dataset_name not in all_datasets.get(type, []):
179
+ raise ValueError(f"Unknown dataset {dataset_name} for type {type}")
180
 
181
+ # Extract the dataset name without the type suffix
182
+ dataset_base_name = dataset_name.split("/")[0]
183
 
184
+ # Check if it's a user dataset
185
  user_datasets = get_user_dataset()
186
  user_data_dir = os.getenv("OMNISEAL_LEADERBOARD_DATA", "./data")
187
+
188
+ if dataset_name in user_datasets.get(type, []):
189
+ # It's a user dataset
190
+ examples_config = {
191
+ "dataset_name": dataset_base_name,
192
+ "path": user_data_dir + "/",
193
+ "db_key": dataset_base_name,
194
+ }
195
+ else:
196
+ # It's a predefined dataset from DATASET_CONFIGS
197
+ if dataset_name in DATASET_CONFIGS:
198
+ dataset_config = DATASET_CONFIGS[dataset_name]
199
+ examples_config = {
200
+ "dataset_name": dataset_base_name,
201
+ "path": dataset_config["path"],
202
+ "db_key": _get_db_key_for_dataset(dataset_base_name, type),
203
+ }
204
+ else:
205
+ raise ValueError(f"Dataset {dataset_name} not found in configurations")
206
+
207
  return examples_config
208
 
209
 
210
+ def _get_db_key_for_dataset(dataset_base_name, type):
211
+ """Helper function to determine the database key for a dataset"""
212
+ # Map of dataset names to their db keys
213
+ db_key_mapping = {
214
+ "voxpopuli_1k": "voxpopuli",
215
+ "val2014_1k_v2": "local_val2014",
216
+ "sa_1b_val_1k": "local_valid",
217
+ "sav_val_full_v2": "sa-v_sav_val_videos",
218
+ "ravdess_1k": "ravdess", # Add mapping for ravdess dataset
219
+ }
220
+
221
+ return db_key_mapping.get(dataset_base_name, dataset_base_name)
222
+
223
+
224
  def get_dataset_config(dataset_name):
225
  if dataset_name in DATASET_CONFIGS:
226
  cfg = DATASET_CONFIGS[dataset_name]
backend/examples.py CHANGED
@@ -92,9 +92,9 @@ def build_description(
92
  }
93
 
94
 
95
- def build_infos(abs_path: Path, datatype: str, dataset_name: str, db_key: str):
96
 
97
- def generate_file_patterns(prefixes, extensions, indices):
98
  return [
99
  f"{prefix}_{index:05d}.{ext}"
100
  for prefix in prefixes
@@ -102,6 +102,11 @@ def build_infos(abs_path: Path, datatype: str, dataset_name: str, db_key: str):
102
  for ext in extensions
103
  ]
104
 
 
 
 
 
 
105
  if datatype == "audio":
106
  quality_metrics = ["snr", "sisnr", "stoi", "pesq"]
107
  extensions = ["wav"]
@@ -118,7 +123,7 @@ def build_infos(abs_path: Path, datatype: str, dataset_name: str, db_key: str):
118
  datatype_abbr = "video"
119
  # indices = [0, 1, 3, 4, 5]
120
 
121
- eval_results_path = abs_path + f"{dataset_name}/examples_eval_results.json"
122
 
123
  # Determine if eval_results_path is a URL or local file
124
  if eval_results_path.startswith("http://") or eval_results_path.startswith(
@@ -141,7 +146,9 @@ def build_infos(abs_path: Path, datatype: str, dataset_name: str, db_key: str):
141
  first_model = next(iter(dataset.keys()))
142
  first_attack = next(iter(dataset[first_model].keys()))
143
  first_attack_variant = next(iter(dataset[first_model][first_attack].keys()))
144
- indices = [item["idx"] for item in dataset[first_model][first_attack][first_attack_variant]]
 
 
145
 
146
  prefixes = [
147
  f"attacked_{datatype_abbr}",
@@ -168,11 +175,15 @@ def build_infos(abs_path: Path, datatype: str, dataset_name: str, db_key: str):
168
  attack = attack_name
169
  else:
170
  # TODO: Update data on S3 with new Omniseal Bench V2 eval script
171
- if str(abs_path).startswith("http") or str(abs_path).startswith("https") or str(abs_path).startswith("s3://"):
 
 
 
 
172
  attack = f"{attack_name}_{attack_variant}"
173
  else:
174
  attack = f"{attack_name}__{attack_variant}"
175
-
176
  if len(attack_rows) == 0:
177
  model_infos[attack] = []
178
  continue
@@ -227,8 +238,8 @@ def build_infos(abs_path: Path, datatype: str, dataset_name: str, db_key: str):
227
  return infos
228
 
229
 
230
- def get_examples_tab(datatype: str):
231
- config = get_example_config(datatype)
232
  infos = build_infos(
233
  config["path"],
234
  datatype=datatype,
 
92
  }
93
 
94
 
95
+ def build_infos(abs_path, datatype: str, dataset_name: str, db_key: str):
96
 
97
+ def generate_file_patterns(prefixes, extensions, indices):
98
  return [
99
  f"{prefix}_{index:05d}.{ext}"
100
  for prefix in prefixes
 
102
  for ext in extensions
103
  ]
104
 
105
+ # Initialize defaults
106
+ quality_metrics = []
107
+ extensions = []
108
+ datatype_abbr = ""
109
+
110
  if datatype == "audio":
111
  quality_metrics = ["snr", "sisnr", "stoi", "pesq"]
112
  extensions = ["wav"]
 
123
  datatype_abbr = "video"
124
  # indices = [0, 1, 3, 4, 5]
125
 
126
+ eval_results_path = str(abs_path) + f"{dataset_name}/examples_eval_results.json"
127
 
128
  # Determine if eval_results_path is a URL or local file
129
  if eval_results_path.startswith("http://") or eval_results_path.startswith(
 
146
  first_model = next(iter(dataset.keys()))
147
  first_attack = next(iter(dataset[first_model].keys()))
148
  first_attack_variant = next(iter(dataset[first_model][first_attack].keys()))
149
+ indices = [
150
+ item["idx"] for item in dataset[first_model][first_attack][first_attack_variant]
151
+ ]
152
 
153
  prefixes = [
154
  f"attacked_{datatype_abbr}",
 
175
  attack = attack_name
176
  else:
177
  # TODO: Update data on S3 with new Omniseal Bench V2 eval script
178
+ if (
179
+ str(abs_path).startswith("http")
180
+ or str(abs_path).startswith("https")
181
+ or str(abs_path).startswith("s3://")
182
+ ):
183
  attack = f"{attack_name}_{attack_variant}"
184
  else:
185
  attack = f"{attack_name}__{attack_variant}"
186
+
187
  if len(attack_rows) == 0:
188
  model_infos[attack] = []
189
  continue
 
238
  return infos
239
 
240
 
241
+ def get_examples_tab(datatype: str, dataset_name: str):
242
+ config = get_example_config(datatype, dataset_name)
243
  infos = build_infos(
244
  config["path"],
245
  datatype=datatype,
frontend/src/API.ts CHANGED
@@ -17,8 +17,11 @@ class API {
17
  }
18
 
19
  // Rename the method to fetchExamplesByType
20
- static fetchExamplesByType(type: 'image' | 'audio' | 'video'): Promise<any> {
21
- return fetch(`${VITE_API_SERVER_URL}/examples/${type}`).then((response) => {
 
 
 
22
  if (!response.ok) {
23
  throw new Error(`Failed to fetch examples of type ${type}`)
24
  }
@@ -52,6 +55,21 @@ class API {
52
  if (!response.ok) throw new Error('Failed to fetch descriptions')
53
  return response.json()
54
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  }
56
 
57
  export default API
 
 
17
  }
18
 
19
  // Rename the method to fetchExamplesByType
20
+ static fetchExamplesByType(type: 'image' | 'audio' | 'video', dataset?: string): Promise<any> {
21
+ const url = dataset
22
+ ? `${VITE_API_SERVER_URL}/examples/${type}?dataset=${encodeURIComponent(dataset)}`
23
+ : `${VITE_API_SERVER_URL}/examples/${type}`
24
+ return fetch(url).then((response) => {
25
  if (!response.ok) {
26
  throw new Error(`Failed to fetch examples of type ${type}`)
27
  }
 
55
  if (!response.ok) throw new Error('Failed to fetch descriptions')
56
  return response.json()
57
  }
58
+
59
+ // Fetch leaderboard data from the backend
60
+ static async fetchLeaderboard(datasetName: string): Promise<any> {
61
+ const response = await fetch(`${VITE_API_SERVER_URL}/data/${datasetName}?dataset_type=benchmark`)
62
+ if (!response.ok) throw new Error(`Failed to fetch leaderboard for ${datasetName}`)
63
+ return response.json()
64
+ }
65
+
66
+ // Fetch leaderboard data from the backend
67
+ static async fetchChart(datasetName: string): Promise<any> {
68
+ const response = await fetch(`${VITE_API_SERVER_URL}/data/${datasetName}?dataset_type=attacks_variations`)
69
+ if (!response.ok) throw new Error(`Failed to fetch chart data for ${datasetName}`)
70
+ return response.json()
71
+ }
72
  }
73
 
74
  export default API
75
+ export { VITE_API_SERVER_URL as API_BASE_URL }
frontend/src/components/Examples.tsx CHANGED
@@ -7,6 +7,7 @@ import AudioGallery from './AudioGallery'
7
  import VideoGallery from './VideoGallery'
8
  import ModelInfoIcon from './ModelInfoIcon'
9
  import Descriptions from '../Descriptions'
 
10
 
11
  interface ExamplesProps {
12
  fileType: 'image' | 'audio' | 'video'
@@ -32,11 +33,24 @@ const Examples = ({ fileType }: ExamplesProps) => {
32
  const [selectedModel, setSelectedModel] = useState<string | null>(null)
33
  const [selectedAttack, setSelectedAttack] = useState<string | null>(null)
34
  const [descriptionsLoaded, setDescriptionsLoaded] = useState(false)
 
 
35
  const descriptions = useRef(Descriptions.getInstance())
36
 
37
  useEffect(() => {
38
  descriptions.current.load().then(() => setDescriptionsLoaded(true))
39
- }, [])
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  const location = useLocation()
42
  // Parse query params for model and attack
@@ -72,9 +86,11 @@ const Examples = ({ fileType }: ExamplesProps) => {
72
  }, [location.search, selectedModel, examples])
73
 
74
  useEffect(() => {
 
 
75
  setLoading(true)
76
  setError(null)
77
- API.fetchExamplesByType(fileType)
78
  .then((data) => {
79
  setExamples(data)
80
  const models = Object.keys(data)
@@ -97,7 +113,7 @@ const Examples = ({ fileType }: ExamplesProps) => {
97
  setError(err.message)
98
  setLoading(false)
99
  })
100
- }, [fileType])
101
 
102
  if (loading) {
103
  return <LoadingSpinner />
@@ -106,6 +122,11 @@ const Examples = ({ fileType }: ExamplesProps) => {
106
  return (
107
  <div className="examples-container">
108
  <div className="selectors-container flex flex-col gap-4">
 
 
 
 
 
109
  <fieldset className="fieldset w-full p-4 rounded border border-gray-700 bg-base-200">
110
  <legend className="fieldset-legend font-semibold">Model</legend>
111
  <div className="grid grid-cols-2 md:grid-cols-3 lg:grid-cols-4 gap-1 max-h-48 overflow-y-auto pr-2">
 
7
  import VideoGallery from './VideoGallery'
8
  import ModelInfoIcon from './ModelInfoIcon'
9
  import Descriptions from '../Descriptions'
10
+ import DatasetSelector from './DatasetSelector'
11
 
12
  interface ExamplesProps {
13
  fileType: 'image' | 'audio' | 'video'
 
33
  const [selectedModel, setSelectedModel] = useState<string | null>(null)
34
  const [selectedAttack, setSelectedAttack] = useState<string | null>(null)
35
  const [descriptionsLoaded, setDescriptionsLoaded] = useState(false)
36
+ const [datasets, setDatasets] = useState<any>({})
37
+ const [selectedDataset, setSelectedDataset] = useState<string>('')
38
  const descriptions = useRef(Descriptions.getInstance())
39
 
40
  useEffect(() => {
41
  descriptions.current.load().then(() => setDescriptionsLoaded(true))
42
+ // Fetch datasets when component loads
43
+ API.fetchDatasets().then((datasetsData) => {
44
+ setDatasets(datasetsData)
45
+ // Set default selected dataset based on fileType
46
+ const datasetsForType = datasetsData[fileType] || []
47
+ if (datasetsForType.length > 0) {
48
+ setSelectedDataset(datasetsForType[0])
49
+ }
50
+ }).catch((err) => {
51
+ console.error('Failed to fetch datasets:', err)
52
+ })
53
+ }, [fileType])
54
 
55
  const location = useLocation()
56
  // Parse query params for model and attack
 
86
  }, [location.search, selectedModel, examples])
87
 
88
  useEffect(() => {
89
+ if (!selectedDataset) return
90
+
91
  setLoading(true)
92
  setError(null)
93
+ API.fetchExamplesByType(fileType, selectedDataset)
94
  .then((data) => {
95
  setExamples(data)
96
  const models = Object.keys(data)
 
113
  setError(err.message)
114
  setLoading(false)
115
  })
116
+ }, [fileType, selectedDataset])
117
 
118
  if (loading) {
119
  return <LoadingSpinner />
 
122
  return (
123
  <div className="examples-container">
124
  <div className="selectors-container flex flex-col gap-4">
125
+ <DatasetSelector
126
+ datasetNames={datasets[fileType] || []}
127
+ selectedDatasetName={selectedDataset}
128
+ onDatasetNameChange={setSelectedDataset}
129
+ />
130
  <fieldset className="fieldset w-full p-4 rounded border border-gray-700 bg-base-200">
131
  <legend className="fieldset-legend font-semibold">Model</legend>
132
  <div className="grid grid-cols-2 md:grid-cols-3 lg:grid-cols-4 gap-1 max-h-48 overflow-y-auto pr-2">