adnlp commited on
Commit
9f2785e
·
verified ·
1 Parent(s): 01e3bfb

Upload 26 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ logo.png filter=lfs diff=lfs merge=lfs -text
MultiModalTimer.py CHANGED
@@ -71,7 +71,7 @@ class MultiModalTimerModel(PreTrainedModel):
71
  text_model = AutoModelForCausalLM.from_pretrained(
72
  "Qwen/Qwen2-1.5B-Instruct",
73
  torch_dtype=torch.bfloat16,
74
- device_map="cpu",
75
  attn_implementation="sdpa"
76
  ).model
77
  state_dict = text_model.state_dict()
 
71
  text_model = AutoModelForCausalLM.from_pretrained(
72
  "Qwen/Qwen2-1.5B-Instruct",
73
  torch_dtype=torch.bfloat16,
74
+ device_map="auto",
75
  attn_implementation="sdpa"
76
  ).model
77
  state_dict = text_model.state_dict()
app.py CHANGED
@@ -3,57 +3,19 @@ from safetensors.torch import load_file
3
  import gradio as gr
4
  import torch
5
  import numpy as np
 
6
  import matplotlib.pyplot as plt
7
  import io
8
  from PIL import Image
 
9
 
10
  from transformers import CLIPImageProcessor
11
 
12
- inputs = {
13
- "NN5 Daily": [
14
- "0.3910, 1.7167, 1.1042, -0.6679, -0.0730, -0.5204, -0.2152, 0.6485, 2.8109, 1.4125, -0.2568, 0.7014, 1.2558, 2.9853, -0.9938, -1.0800, -0.2318, -0.9222, -0.6326, -0.1353, 1.3316, -1.1879, 0.3443, -1.3457, -0.6679, -0.6450, -0.7810, -0.9419, -0.1787, -1.2575, 0.7232, -0.0730, -0.5599, -0.8962, -0.7987, 0.1450, 1.1177, 0.4772, -0.9699, -0.7748, -0.7634, -0.6741, -0.1602, 1.2849, 2.1974, 0.2062, -0.3564, -0.5547, -0.5588, 0.1626, 1.1073, 0.4575, -0.7198, -0.5443, -0.6357, -0.8070",
15
- "0.5456, 1.4991, 1.1747, -0.6072, -0.1441, 0.3698, -0.3747, 0.3952, 1.3213, 2.2103, -0.2672, 1.3604, 1.3135, 2.1556, -1.9260, -1.8713, -0.5368, -1.2031, 0.4577, 0.0024, 1.6632, -1.3907, -0.0855, -0.0914, -1.1113, -0.3629, -1.3047, -0.7185, -0.3414, 0.7820, 0.4909, -0.6033, -0.2496, -0.7478, -0.3942, 0.2901, 0.3600, 0.7508, -0.7439, -0.6404, -0.7088, -0.6462, 0.5163, 1.2275, 0.4850, -2.0940, -1.6876, -0.3199, -0.1988, 0.4206, 1.4014, 1.3916, 0.0591, -0.5935, -0.1773, -0.5017",
16
- "0.6767, -0.7404, -0.0247, -0.5122, 0.1192, 0.8660, 3.3852, 1.7256, -0.4448, 0.4084, 1.4521, 1.7801, -1.4599, -1.1773, -0.2282, -1.0087, 0.0907, -0.0234, 1.0890, -1.3394, 0.5938, -0.1453, -1.1928, -0.4085, -0.7144, -0.1673, -0.0182, 0.8453, 0.3980, -1.0645, -0.4875, -0.6755, -0.3086, 0.3358, 1.1383, 0.2074, -1.0554, -0.4357, -0.4759, -0.3631, 0.1529, 1.1953, 0.1633, -0.9932, -0.3177, -0.6457, -0.6807, -0.2412, -1.5286, 2.7797, -0.5757, -0.0986, -1.3899, -0.2995, 0.5471, 1.2926"
17
- ],
18
- "Australian Electricity": [
19
- "0.1561, 0.0250, -0.1232, -0.3055, -0.3450, -0.6220, -0.9393, -1.1442, -1.3506, -1.5285, -1.6454, -1.7910, -1.8572, -1.8765, -1.8602, -1.7844, -1.5492, -1.3167, -0.8213, -0.3058, 0.1053, 0.4125, 0.6468, 0.7649, 0.8247, 0.8206, 0.8007, 0.8021, 0.8096, 0.8684, 0.8706, 0.8588, 0.8742, 0.9041, 0.8933, 0.8302, 0.8022, 0.8173, 0.8135, 0.7948, 0.6960, 0.6982, 0.9220, 0.9045, 0.7065, 0.6502, 0.6367, 0.4564",
20
- "-0.3264, -0.3155, -0.3910, -0.2909, -0.3695, -0.4594, -0.6087, -0.7602, -0.9589, -1.2557, -1.5857, -1.8475, -1.9940, -2.0640, -2.0380, -1.8492, -1.6252, -1.1220, -0.5189, -0.2555, 0.1371, 0.3406, 0.4399, 0.6321, 0.7476, 0.8676, 0.9294, 0.9488, 1.0222, 1.0061, 1.0013, 1.0430, 1.0417, 1.0566, 1.0340, 0.9836, 0.9892, 0.9969, 0.9595, 0.8409, 0.7677, 0.6153, 0.5256, 0.4210, 0.5014, 0.4849, 0.3080, -0.0058",
21
- "0.5091, 0.1714, -0.2558, -0.6192, -0.9358, -1.1645, -1.3079, -1.3875, -1.4657, -1.6180, -1.6260, -1.5982, -1.5640, -1.5354, -1.4960, -1.3395, -1.1070, -0.5613, 0.2179, 1.1485, 1.7363, 1.8791, 1.6987, 1.3585, 1.1350, 0.8669, 0.7281, 0.5676, 0.4495, 0.2027, 0.0901, -0.0161, -0.0750, -0.0981, -0.0711, 0.0944, 0.0466, 0.1898, 0.2478, 0.3899, 0.5053, 0.5426, 0.6315, 0.6308, 0.8361, 1.0080, 1.0234, 0.9366"
22
- ],
23
- "CIF 2016": [
24
- "-2.1365, -0.6420, -0.9545, -0.4169, 0.5554, -0.5115, 0.2702, -0.1587, 1.0649, 1.3542, 0.6126, 0.9629",
25
- "1.5888, 1.3587, 0.6136, 0.7122, 0.6355, -0.4164, -0.7670, -1.0409, -1.0300, 0.0219, -1.5559, -0.1205",
26
- "1.3580, 1.1319, 1.5850, 0.7912, 0.0083, -0.0066, -0.3669, -0.6251, -0.5215, -1.3642, -0.7737, -1.2164"
27
- ],
28
- "Tourism Monthly": [
29
- "-0.7495, -0.8636, -0.9378, -0.7584, -0.3905, -0.1575, 0.3522, 1.9419, 1.9991, 0.5223, -0.2807, -0.8089, -0.7009, -0.8471, -0.9227, -0.6670, -0.4053, -0.1508, 0.3640, 1.8908, 1.9565, 0.5652, -0.1896, -0.7617",
30
- "0.8332, 1.4813, 0.1061, -0.4110, -1.0407, -1.1521, -0.8156, -0.8842, -0.9908, 0.0014, 0.8833, 1.5613, 1.3095, 1.6260, 0.2534, -0.4883, -1.0412, -1.1210, -0.7034, -0.8304, -0.8239, -0.0673, 0.7531, 1.5613",
31
- "0.8778, -0.0008, -0.6814, -0.7851, -0.7828, -0.6682, -0.8473, -0.9634, -0.4615, 0.1523, 2.3755, 1.0437, 1.5980, 0.1482, -0.1063, -0.7360, -0.7700, -0.7560, -0.7145, -0.8552, -0.4242, 0.2896, 2.2307, 0.8371"
32
- ]
33
- }
34
 
35
- targets = {
36
- "NN5 Daily": [
37
- [-0.5433, 0.6589, 0.4668, -0.6959, -0.5474, -0.7685, -0.7125, -0.0273, 1.3170, 0.5883, -0.7675, -0.5163, -0.6035, -0.5028, 0.0505, 0.5530, 0.5810, -0.8049, -0.5370, -0.6149, -0.5609, 0.2166, 1.3202, 0.5852, -0.5921, -0.5038, -0.7301, -0.7644, 0.2011, 0.6942, 0.5052, -0.7644, -0.7644, -0.7364, -0.3896, 0.1917, 1.1084, 0.3827, -0.6679, -0.5568, -0.5910, -0.4052, 0.3599, 1.0617, 1.0461, -0.6326, -0.6990, -0.9056, -0.6159, 0.2696, 1.5299, 0.8032, -1.1879, -0.8153, -0.4872, -0.3491],
38
- [0.6140, 1.0575, 0.2213, -0.3903, -0.5388, -0.6677, -0.4196, 0.1841, 1.6300, 0.4147, -0.5505, -0.4469, -0.3903, -0.3786, 0.5886, 1.0321, 0.3053, -0.5368, -0.6443, -0.5700, -0.5290, 0.1587, 1.6437, 1.1435, -0.6404, -0.1988, -0.6423, -0.4489, 0.1880, 0.9872, 1.1044, -0.7654, -0.2320, -0.6931, -0.9725, 0.3014, 1.3447, 0.8895, -0.5446, -0.0445, -0.3336, -0.6482, 0.7097, 1.7238, 1.3037, -0.3551, -0.4919, -0.5544, -0.4196, 0.4753, 1.4795, 1.3760, -0.4430, -0.0347, 0.2330, -0.4079],
39
- [0.2800, -0.9634, -0.2114, -0.6185, -0.4473, 0.4408, 1.3276, 0.4408, -1.0165, -0.6470, -0.4103, -0.7404, -0.2490, 1.0709, 0.6690, -1.0528, -0.4655, -0.7365, -0.3981, 0.0907, 1.4754, 1.0359, -0.9400, -0.3449, -0.6418, 0.0155, 0.5717, 1.2355, 0.4589, -0.9776, -0.1867, -0.6587, -0.3358, 0.2644, 1.4145, 0.3189, -1.1708, -0.2412, -0.5848, -0.1077, 0.5951, 0.8777, 0.6599, -1.1565, -0.5122, -0.5666, -0.2931, 0.1983, 1.8410, 1.1370, -1.2330, -0.2840, 0.0090, 0.2411, 0.5912, 1.2822]
40
- ],
41
- "Australian Electricity": [
42
- [0.2418, 0.1000, -0.0648, -0.2259, -0.2967, -0.5842, -0.9359, -1.1101, -1.3467, -1.5249, -1.7000, -1.7966, -1.8497, -1.8812, -1.9009, -1.8194, -1.6176, -1.3769, -0.8809, -0.4156, 0.0521, 0.3505, 0.6003, 0.7191, 0.8154, 0.8006, 0.7769, 0.7888, 0.8156, 0.8311, 0.8516, 0.8480, 0.8405, 0.9085, 0.8990, 0.9031, 0.9135, 0.9160, 0.8557, 0.8282, 0.7075, 0.7170, 0.9111, 0.9671, 0.7844, 0.7513, 0.6663, 0.5040],
43
- [-0.2806, -0.2256, -0.3062, -0.2318, -0.3017, -0.4339, -0.5957, -0.7385, -0.9333, -1.2554, -1.5583, -1.8179, -1.9639, -2.0259, -2.0105, -1.8515, -1.5978, -1.1026, -0.5616, -0.2580, 0.0937, 0.3312, 0.3877, 0.5459, 0.6491, 0.7382, 0.7652, 0.7550, 0.8130, 0.7863, 0.7734, 0.7761, 0.7552, 0.7527, 0.7178, 0.6544, 0.6689, 0.6651, 0.6142, 0.4929, 0.3658, 0.1463, 0.0122, -0.0509, 0.0453, 0.0132, -0.1152, -0.3515],
44
- [0.6565, 0.3575, -0.0906, -0.4682, -0.7850, -1.0572, -1.2190, -1.3699, -1.4444, -1.4917, -1.4912, -1.4580, -1.4893, -1.4474, -1.3971, -1.2866, -1.0777, -0.6444, -0.0932, 0.5655, 1.0824, 1.1828, 1.1444, 1.1274, 0.9795, 0.7276, 0.5866, 0.4147, 0.1967, 0.0191, -0.1381, -0.2384, -0.3049, -0.3523, -0.3015, -0.1735, -0.1186, -0.0017, 0.1577, 0.2584, 0.3643, 0.5303, 0.6169, 0.6950, 0.8434, 0.9273, 0.8578, 0.7277]
45
- ],
46
- "CIF 2016": [
47
- [0.7780, 1.4980, 0.9115, 0.5034, 1.8105, 1.6583, 1.6255, 1.9613, 1.6311, 2.2943, 2.4116, 2.4785],
48
- [-1.1615, -0.6684, -1.1396, -0.8766, -0.1863, -0.5040, -1.4683, -1.6765, -1.7312, -0.9862, -0.7341, -0.5807],
49
- [-0.9360, -1.0646, -1.2660, -1.1586, -1.7037, -2.1642, -2.2156, -2.0456, -2.1657, -2.6387, -2.3959, -2.3520]
50
- ],
51
- "Tourism Monthly": [
52
- [-0.7218, -0.8162, -0.8599, -0.6781, -0.1500, 0.0297, 0.5061, 2.2237, 1.9718, 0.7502, -0.0718, -0.7254, -0.5999, -0.8048, -0.8552, -0.5962, -0.2393, -0.0513, 0.5497, 2.0177, 2.3518, 0.6854, -0.1190, -0.7002],
53
- [1.4626, 1.6710, 0.9104, -0.5530, -0.9392, -1.0148, -0.6596, -0.7959, -0.7170, 0.1262, 1.1316, 1.9928, 1.8925, 2.1458, 0.7225, -0.1372, -0.8087, -0.8994, -0.5187, -0.6408, -0.5757, 0.1131, 1.2596, 2.4437],
54
- [0.7737, 0.2534, -0.4698, -0.8390, -0.7692, -0.6648, -0.7311, -0.8054, -0.3891, 0.8623, 2.7314, 1.3378, 1.2805, 0.5682, -0.4299, -0.7511, -0.8982, -0.6067, -0.6572, -0.8337, -0.2564, 0.5460, 2.5010, 1.5309]
55
- ]
56
- }
57
 
58
  descriptions = {
59
  "NN5 Daily": "Daily cash withdrawal volumes from automated teller machines (ATMs) in the United Kingdom, originally used in the NN5 forecasting competition.",
@@ -63,7 +25,6 @@ descriptions = {
63
  }
64
 
65
  models = {}
66
- # for dataset in ["NN5_Daily", "Australian_Electricity", "CIF_2016", "Tourism_Monthly"]:
67
  for dataset in ["NN5_Daily", "Australian_Electricity"]:
68
  config = MultiModalTimerConfig.from_pretrained(f"ckpt/CLIPQwenTimer/{dataset}/config.json")
69
  model = MultiModalTimerModel(config)
@@ -83,16 +44,16 @@ context_length = {
83
  "Tourism Monthly": 24
84
  }
85
 
86
- def predict(dataset, example, inputs, text):
87
- inputs = np.array([float(x.strip()) for x in inputs.split(',')])
88
- mean = np.mean(inputs)
89
- std = np.std(inputs)
90
- inputs = (inputs-mean)/std
91
- input_ids = torch.tensor(inputs).to(torch.float32).to(device)
92
  input_ids = input_ids.unsqueeze(0)
93
 
94
  plt.figure(figsize=(384/100, 384/100), dpi=100)
95
- plt.plot(inputs, color="black", linestyle="-", linewidth=1, marker="*", markersize=1)
96
  plt.xticks([])
97
  plt.yticks([])
98
  plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
@@ -114,79 +75,88 @@ def predict(dataset, example, inputs, text):
114
 
115
  cl = context_length[dataset]
116
  out = out[0, :cl]
 
 
 
 
 
 
 
117
 
118
  plt.style.use("seaborn-v0_8")
119
  fig, ax = plt.subplots()
120
- ax.plot(range(cl), inputs, color="black", alpha=0.7, linewidth=3, label='Input')
121
- if example == "Custom":
 
122
  pass
123
  else:
124
- ax.plot(range(cl, 2*cl), targets[dataset][int(example)-1], color='C0', alpha=0.7, linewidth=3, label='True')
125
- ax.plot(range(cl, 2*cl), out.detach().cpu().numpy(), color='C2', alpha=0.7, linewidth=3, label='Forecast')
 
126
  ax.legend()
127
 
128
- buf = io.BytesIO()
129
- fig.savefig(buf, format='png')
130
- buf.seek(0)
131
- forecast_img = Image.open(buf).convert('RGB')
132
 
133
- # return plot_img, out, forecast_img
134
- return forecast_img
 
 
135
 
136
- def make_input_example_dropdown(example, done):
137
- if done:
138
- return example, True
139
- else:
140
- return gr.Dropdown(["1", "2", "3", "Custom"], label="Input Examples", value=None, interactive=True), True
141
-
142
- def update_options(dataset, example):
143
- if example == "1":
144
- time_series = inputs[dataset][0]
145
- desc = descriptions[dataset]
146
- placeholder = None
147
- interactive = False
148
- elif example == "2":
149
- time_series = inputs[dataset][1]
150
- desc = descriptions[dataset]
151
- placeholder = None
152
- interactive = False
153
- elif example == "3":
154
- time_series = inputs[dataset][2]
155
- desc = descriptions[dataset]
156
- placeholder = None
157
- interactive = False
158
- elif example == "Custom":
159
- time_series = ""
160
- desc = ""
161
- placeholder = f"Please Enter {context_length[dataset]} Time Steps Long Time Series Input."
162
- interactive = True
163
- else:
164
- time_series = ""
165
- desc = ""
166
- placeholder = None
167
- interactive = False
168
 
169
- return gr.Textbox(value=time_series, label="Time Series Input", placeholder=placeholder, interactive=interactive), gr.Textbox(value=desc, label="Dataset Description", interactive=interactive)
 
 
 
 
 
 
 
170
 
171
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  with gr.Row():
173
  with gr.Column():
174
- # dataset_dropdown = gr.Dropdown(["NN5 Daily", "Australian Electricity", "CIF 2016", "Tourism Monthly"], value=None, label="Datasets", interactive=True)
175
  dataset_dropdown = gr.Dropdown(["NN5 Daily", "Australian Electricity"], value=None, label="Datasets", interactive=True)
176
- input_example_dropdown = gr.Dropdown([], label="Input Examples", value=None, interactive=False)
177
- done = gr.State(False)
178
 
179
- time_series_textbox = gr.Textbox(label="Time Series Input")
180
- dataset_description_textbox = gr.Textbox(label="Dataset Description")
 
 
 
 
 
 
 
 
181
 
182
- dataset_dropdown.change(make_input_example_dropdown, inputs=[input_example_dropdown, done], outputs=[input_example_dropdown, done])
183
- dataset_dropdown.change(update_options, inputs=[dataset_dropdown, input_example_dropdown], outputs=[time_series_textbox, dataset_description_textbox])
184
- input_example_dropdown.change(update_options, inputs=[dataset_dropdown, input_example_dropdown], outputs=[time_series_textbox, dataset_description_textbox])
185
 
186
  btn = gr.Button("Run")
187
  with gr.Column():
188
- forecast_image = gr.Image(label="Forecast")
189
 
190
- btn.click(predict, inputs=[dataset_dropdown, input_example_dropdown, time_series_textbox, dataset_description_textbox], outputs=forecast_image)
191
 
192
- demo.launch(ssr_mode=False)
 
 
3
  import gradio as gr
4
  import torch
5
  import numpy as np
6
+ import pandas as pd
7
  import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
10
+ import pickle
11
 
12
  from transformers import CLIPImageProcessor
13
 
14
+ with open('example/inputs.pkl', 'rb') as f:
15
+ inputs = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ with open('example/targets.pkl', 'rb') as f:
18
+ targets = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  descriptions = {
21
  "NN5 Daily": "Daily cash withdrawal volumes from automated teller machines (ATMs) in the United Kingdom, originally used in the NN5 forecasting competition.",
 
25
  }
26
 
27
  models = {}
 
28
  for dataset in ["NN5_Daily", "Australian_Electricity"]:
29
  config = MultiModalTimerConfig.from_pretrained(f"ckpt/CLIPQwenTimer/{dataset}/config.json")
30
  model = MultiModalTimerModel(config)
 
44
  "Tourism Monthly": 24
45
  }
46
 
47
+ def predict(dataset, text, df, example_index):
48
+ time_series = np.array(df.iloc[:, -1])
49
+ mean = np.mean(time_series)
50
+ std = np.std(time_series)
51
+ time_series_normalized = (time_series-mean)/std
52
+ input_ids = torch.tensor(time_series_normalized).to(torch.float32).to(device)
53
  input_ids = input_ids.unsqueeze(0)
54
 
55
  plt.figure(figsize=(384/100, 384/100), dpi=100)
56
+ plt.plot(time_series_normalized, color="black", linestyle="-", linewidth=1, marker="*", markersize=1)
57
  plt.xticks([])
58
  plt.yticks([])
59
  plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
 
75
 
76
  cl = context_length[dataset]
77
  out = out[0, :cl]
78
+ out = out.detach().cpu().numpy()
79
+ out = out*std+mean
80
+
81
+ input_dates_series = pd.to_datetime(df["timestamp"])
82
+ time_diff = input_dates_series.diff().mode()[0]
83
+ start_time = input_dates_series.iloc[-1] + time_diff
84
+ forecast_dates_series = pd.date_range(start=start_time, periods=len(input_dates_series), freq=time_diff)
85
 
86
  plt.style.use("seaborn-v0_8")
87
  fig, ax = plt.subplots()
88
+ ax.plot(input_dates_series, time_series, color="black", alpha=0.7, linewidth=3, label='Input')
89
+ ax.plot(forecast_dates_series, out, color='C2', alpha=0.7, linewidth=3, label='Forecast')
90
+ if example_index == 3: # Custom Input
91
  pass
92
  else:
93
+ true = targets[dataset][example_index].iloc[:, -1]
94
+ ax.plot(forecast_dates_series, true, color='C0', alpha=0.7, linewidth=3, label='True')
95
+ pass
96
  ax.legend()
97
 
98
+ return fig
 
 
 
99
 
100
+ def selected_dataset(dataset):
101
+ gallery_items = [(Image.open(f'example/img/{dataset.replace(" ", "_")}/{i}.png').convert('RGB'), str(i+1)) for i in range(3)]
102
+ gallery_items.append((np.ones((64, 64)), 'Custom Input'))
103
+ return gr.Gallery(gallery_items, interactive=True, columns=2, height="350px", object_fit="contain"), gr.Textbox(value=descriptions[dataset], label="Dataset Description", interactive=False)
104
 
105
+ def selected_example(evt: gr.SelectData):
106
+ return evt.index
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ def update_time_series_dataframe(dataset, example_index):
109
+ if example_index is None:
110
+ pass
111
+ elif example_index == 3: # Custom Input
112
+ return gr.Dataframe(value=None, datatype="str", label="Time Series Input", interactive=True)
113
+ else:
114
+ df = inputs[dataset][example_index]
115
+ return gr.Dataframe(value=df, label="Time Series Input", interactive=False)
116
 
117
  with gr.Blocks() as demo:
118
+ gr.Image(
119
+ value="logo.png",
120
+ show_label=False,
121
+ show_download_button=False,
122
+ show_fullscreen_button=False,
123
+ interactive=False,
124
+ height=128,
125
+ container=False,
126
+ elem_id="logo"
127
+ )
128
+ gr.HTML("""
129
+ <style>
130
+ #logo {
131
+ display: flex;
132
+ justify-content: flex-start;
133
+ }
134
+ </style>
135
+ """)
136
  with gr.Row():
137
  with gr.Column():
 
138
  dataset_dropdown = gr.Dropdown(["NN5 Daily", "Australian Electricity"], value=None, label="Datasets", interactive=True)
 
 
139
 
140
+ dataset_description_textbox = gr.Textbox(label="Dataset Description", interactive=False)
141
+
142
+ example_gallery = gr.Gallery(
143
+ None,
144
+ interactive=False
145
+ )
146
+ example_index = gr.State(value=None)
147
+ example_gallery.select(selected_example, None, example_index)
148
+
149
+ time_series_dataframe = gr.Dataframe(value=None, headers=["Timestamp", "Value"], label="Time Series Input", interactive=False)
150
 
151
+ dataset_dropdown.change(selected_dataset, inputs=dataset_dropdown, outputs=[example_gallery, dataset_description_textbox])
152
+ dataset_dropdown.change(update_time_series_dataframe, inputs=[dataset_dropdown, example_index], outputs=time_series_dataframe)
153
+ example_index.change(update_time_series_dataframe, inputs=[dataset_dropdown, example_index], outputs=time_series_dataframe)
154
 
155
  btn = gr.Button("Run")
156
  with gr.Column():
157
+ forecast_plot = gr.Plot(label="Forecast", format="png")
158
 
159
+ btn.click(predict, inputs=[dataset_dropdown, dataset_description_textbox, time_series_dataframe, example_index], outputs=forecast_plot)
160
 
161
+ if __name__ == "__main__":
162
+ demo.launch(ssr_mode=False)
example/img/Australian_Electricity/0.png ADDED
example/img/Australian_Electricity/1.png ADDED
example/img/Australian_Electricity/2.png ADDED
example/img/NN5_Daily/0.png ADDED
example/img/NN5_Daily/1.png ADDED
example/img/NN5_Daily/2.png ADDED
example/inputs.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2c7ada397bda4d8be66a84d12a145283bda4021a0beed57071b53dac67cc907
3
+ size 8919
example/targets.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2e0f82b9c0e9b1e922043a9cc4588ac4bb0ae298b2accc8fbb1c145b6566a18
3
+ size 8919
logo.png ADDED

Git LFS Details

  • SHA256: 3e3832413e35a5e972cabea9a4b481dc16b668f1d177a4304c1072e3baf1d379
  • Pointer size: 131 Bytes
  • Size of remote file: 134 kB
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  gradio==5.42.0
2
- accelerate==1.6.0
3
  torch==2.6.0
4
  numpy==2.2.4
5
  transformers==4.40.1
 
1
  gradio==5.42.0
 
2
  torch==2.6.0
3
  numpy==2.2.4
4
  transformers==4.40.1