Upload 26 files
Browse files- .gitattributes +1 -0
- MultiModalTimer.py +1 -1
- app.py +76 -106
- example/img/Australian_Electricity/0.png +0 -0
- example/img/Australian_Electricity/1.png +0 -0
- example/img/Australian_Electricity/2.png +0 -0
- example/img/NN5_Daily/0.png +0 -0
- example/img/NN5_Daily/1.png +0 -0
- example/img/NN5_Daily/2.png +0 -0
- example/inputs.pkl +3 -0
- example/targets.pkl +3 -0
- logo.png +3 -0
- requirements.txt +0 -1
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
logo.png filter=lfs diff=lfs merge=lfs -text
|
MultiModalTimer.py
CHANGED
@@ -71,7 +71,7 @@ class MultiModalTimerModel(PreTrainedModel):
|
|
71 |
text_model = AutoModelForCausalLM.from_pretrained(
|
72 |
"Qwen/Qwen2-1.5B-Instruct",
|
73 |
torch_dtype=torch.bfloat16,
|
74 |
-
device_map="
|
75 |
attn_implementation="sdpa"
|
76 |
).model
|
77 |
state_dict = text_model.state_dict()
|
|
|
71 |
text_model = AutoModelForCausalLM.from_pretrained(
|
72 |
"Qwen/Qwen2-1.5B-Instruct",
|
73 |
torch_dtype=torch.bfloat16,
|
74 |
+
device_map="auto",
|
75 |
attn_implementation="sdpa"
|
76 |
).model
|
77 |
state_dict = text_model.state_dict()
|
app.py
CHANGED
@@ -3,57 +3,19 @@ from safetensors.torch import load_file
|
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
import numpy as np
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
import io
|
8 |
from PIL import Image
|
|
|
9 |
|
10 |
from transformers import CLIPImageProcessor
|
11 |
|
12 |
-
inputs
|
13 |
-
|
14 |
-
"0.3910, 1.7167, 1.1042, -0.6679, -0.0730, -0.5204, -0.2152, 0.6485, 2.8109, 1.4125, -0.2568, 0.7014, 1.2558, 2.9853, -0.9938, -1.0800, -0.2318, -0.9222, -0.6326, -0.1353, 1.3316, -1.1879, 0.3443, -1.3457, -0.6679, -0.6450, -0.7810, -0.9419, -0.1787, -1.2575, 0.7232, -0.0730, -0.5599, -0.8962, -0.7987, 0.1450, 1.1177, 0.4772, -0.9699, -0.7748, -0.7634, -0.6741, -0.1602, 1.2849, 2.1974, 0.2062, -0.3564, -0.5547, -0.5588, 0.1626, 1.1073, 0.4575, -0.7198, -0.5443, -0.6357, -0.8070",
|
15 |
-
"0.5456, 1.4991, 1.1747, -0.6072, -0.1441, 0.3698, -0.3747, 0.3952, 1.3213, 2.2103, -0.2672, 1.3604, 1.3135, 2.1556, -1.9260, -1.8713, -0.5368, -1.2031, 0.4577, 0.0024, 1.6632, -1.3907, -0.0855, -0.0914, -1.1113, -0.3629, -1.3047, -0.7185, -0.3414, 0.7820, 0.4909, -0.6033, -0.2496, -0.7478, -0.3942, 0.2901, 0.3600, 0.7508, -0.7439, -0.6404, -0.7088, -0.6462, 0.5163, 1.2275, 0.4850, -2.0940, -1.6876, -0.3199, -0.1988, 0.4206, 1.4014, 1.3916, 0.0591, -0.5935, -0.1773, -0.5017",
|
16 |
-
"0.6767, -0.7404, -0.0247, -0.5122, 0.1192, 0.8660, 3.3852, 1.7256, -0.4448, 0.4084, 1.4521, 1.7801, -1.4599, -1.1773, -0.2282, -1.0087, 0.0907, -0.0234, 1.0890, -1.3394, 0.5938, -0.1453, -1.1928, -0.4085, -0.7144, -0.1673, -0.0182, 0.8453, 0.3980, -1.0645, -0.4875, -0.6755, -0.3086, 0.3358, 1.1383, 0.2074, -1.0554, -0.4357, -0.4759, -0.3631, 0.1529, 1.1953, 0.1633, -0.9932, -0.3177, -0.6457, -0.6807, -0.2412, -1.5286, 2.7797, -0.5757, -0.0986, -1.3899, -0.2995, 0.5471, 1.2926"
|
17 |
-
],
|
18 |
-
"Australian Electricity": [
|
19 |
-
"0.1561, 0.0250, -0.1232, -0.3055, -0.3450, -0.6220, -0.9393, -1.1442, -1.3506, -1.5285, -1.6454, -1.7910, -1.8572, -1.8765, -1.8602, -1.7844, -1.5492, -1.3167, -0.8213, -0.3058, 0.1053, 0.4125, 0.6468, 0.7649, 0.8247, 0.8206, 0.8007, 0.8021, 0.8096, 0.8684, 0.8706, 0.8588, 0.8742, 0.9041, 0.8933, 0.8302, 0.8022, 0.8173, 0.8135, 0.7948, 0.6960, 0.6982, 0.9220, 0.9045, 0.7065, 0.6502, 0.6367, 0.4564",
|
20 |
-
"-0.3264, -0.3155, -0.3910, -0.2909, -0.3695, -0.4594, -0.6087, -0.7602, -0.9589, -1.2557, -1.5857, -1.8475, -1.9940, -2.0640, -2.0380, -1.8492, -1.6252, -1.1220, -0.5189, -0.2555, 0.1371, 0.3406, 0.4399, 0.6321, 0.7476, 0.8676, 0.9294, 0.9488, 1.0222, 1.0061, 1.0013, 1.0430, 1.0417, 1.0566, 1.0340, 0.9836, 0.9892, 0.9969, 0.9595, 0.8409, 0.7677, 0.6153, 0.5256, 0.4210, 0.5014, 0.4849, 0.3080, -0.0058",
|
21 |
-
"0.5091, 0.1714, -0.2558, -0.6192, -0.9358, -1.1645, -1.3079, -1.3875, -1.4657, -1.6180, -1.6260, -1.5982, -1.5640, -1.5354, -1.4960, -1.3395, -1.1070, -0.5613, 0.2179, 1.1485, 1.7363, 1.8791, 1.6987, 1.3585, 1.1350, 0.8669, 0.7281, 0.5676, 0.4495, 0.2027, 0.0901, -0.0161, -0.0750, -0.0981, -0.0711, 0.0944, 0.0466, 0.1898, 0.2478, 0.3899, 0.5053, 0.5426, 0.6315, 0.6308, 0.8361, 1.0080, 1.0234, 0.9366"
|
22 |
-
],
|
23 |
-
"CIF 2016": [
|
24 |
-
"-2.1365, -0.6420, -0.9545, -0.4169, 0.5554, -0.5115, 0.2702, -0.1587, 1.0649, 1.3542, 0.6126, 0.9629",
|
25 |
-
"1.5888, 1.3587, 0.6136, 0.7122, 0.6355, -0.4164, -0.7670, -1.0409, -1.0300, 0.0219, -1.5559, -0.1205",
|
26 |
-
"1.3580, 1.1319, 1.5850, 0.7912, 0.0083, -0.0066, -0.3669, -0.6251, -0.5215, -1.3642, -0.7737, -1.2164"
|
27 |
-
],
|
28 |
-
"Tourism Monthly": [
|
29 |
-
"-0.7495, -0.8636, -0.9378, -0.7584, -0.3905, -0.1575, 0.3522, 1.9419, 1.9991, 0.5223, -0.2807, -0.8089, -0.7009, -0.8471, -0.9227, -0.6670, -0.4053, -0.1508, 0.3640, 1.8908, 1.9565, 0.5652, -0.1896, -0.7617",
|
30 |
-
"0.8332, 1.4813, 0.1061, -0.4110, -1.0407, -1.1521, -0.8156, -0.8842, -0.9908, 0.0014, 0.8833, 1.5613, 1.3095, 1.6260, 0.2534, -0.4883, -1.0412, -1.1210, -0.7034, -0.8304, -0.8239, -0.0673, 0.7531, 1.5613",
|
31 |
-
"0.8778, -0.0008, -0.6814, -0.7851, -0.7828, -0.6682, -0.8473, -0.9634, -0.4615, 0.1523, 2.3755, 1.0437, 1.5980, 0.1482, -0.1063, -0.7360, -0.7700, -0.7560, -0.7145, -0.8552, -0.4242, 0.2896, 2.2307, 0.8371"
|
32 |
-
]
|
33 |
-
}
|
34 |
|
35 |
-
targets
|
36 |
-
|
37 |
-
[-0.5433, 0.6589, 0.4668, -0.6959, -0.5474, -0.7685, -0.7125, -0.0273, 1.3170, 0.5883, -0.7675, -0.5163, -0.6035, -0.5028, 0.0505, 0.5530, 0.5810, -0.8049, -0.5370, -0.6149, -0.5609, 0.2166, 1.3202, 0.5852, -0.5921, -0.5038, -0.7301, -0.7644, 0.2011, 0.6942, 0.5052, -0.7644, -0.7644, -0.7364, -0.3896, 0.1917, 1.1084, 0.3827, -0.6679, -0.5568, -0.5910, -0.4052, 0.3599, 1.0617, 1.0461, -0.6326, -0.6990, -0.9056, -0.6159, 0.2696, 1.5299, 0.8032, -1.1879, -0.8153, -0.4872, -0.3491],
|
38 |
-
[0.6140, 1.0575, 0.2213, -0.3903, -0.5388, -0.6677, -0.4196, 0.1841, 1.6300, 0.4147, -0.5505, -0.4469, -0.3903, -0.3786, 0.5886, 1.0321, 0.3053, -0.5368, -0.6443, -0.5700, -0.5290, 0.1587, 1.6437, 1.1435, -0.6404, -0.1988, -0.6423, -0.4489, 0.1880, 0.9872, 1.1044, -0.7654, -0.2320, -0.6931, -0.9725, 0.3014, 1.3447, 0.8895, -0.5446, -0.0445, -0.3336, -0.6482, 0.7097, 1.7238, 1.3037, -0.3551, -0.4919, -0.5544, -0.4196, 0.4753, 1.4795, 1.3760, -0.4430, -0.0347, 0.2330, -0.4079],
|
39 |
-
[0.2800, -0.9634, -0.2114, -0.6185, -0.4473, 0.4408, 1.3276, 0.4408, -1.0165, -0.6470, -0.4103, -0.7404, -0.2490, 1.0709, 0.6690, -1.0528, -0.4655, -0.7365, -0.3981, 0.0907, 1.4754, 1.0359, -0.9400, -0.3449, -0.6418, 0.0155, 0.5717, 1.2355, 0.4589, -0.9776, -0.1867, -0.6587, -0.3358, 0.2644, 1.4145, 0.3189, -1.1708, -0.2412, -0.5848, -0.1077, 0.5951, 0.8777, 0.6599, -1.1565, -0.5122, -0.5666, -0.2931, 0.1983, 1.8410, 1.1370, -1.2330, -0.2840, 0.0090, 0.2411, 0.5912, 1.2822]
|
40 |
-
],
|
41 |
-
"Australian Electricity": [
|
42 |
-
[0.2418, 0.1000, -0.0648, -0.2259, -0.2967, -0.5842, -0.9359, -1.1101, -1.3467, -1.5249, -1.7000, -1.7966, -1.8497, -1.8812, -1.9009, -1.8194, -1.6176, -1.3769, -0.8809, -0.4156, 0.0521, 0.3505, 0.6003, 0.7191, 0.8154, 0.8006, 0.7769, 0.7888, 0.8156, 0.8311, 0.8516, 0.8480, 0.8405, 0.9085, 0.8990, 0.9031, 0.9135, 0.9160, 0.8557, 0.8282, 0.7075, 0.7170, 0.9111, 0.9671, 0.7844, 0.7513, 0.6663, 0.5040],
|
43 |
-
[-0.2806, -0.2256, -0.3062, -0.2318, -0.3017, -0.4339, -0.5957, -0.7385, -0.9333, -1.2554, -1.5583, -1.8179, -1.9639, -2.0259, -2.0105, -1.8515, -1.5978, -1.1026, -0.5616, -0.2580, 0.0937, 0.3312, 0.3877, 0.5459, 0.6491, 0.7382, 0.7652, 0.7550, 0.8130, 0.7863, 0.7734, 0.7761, 0.7552, 0.7527, 0.7178, 0.6544, 0.6689, 0.6651, 0.6142, 0.4929, 0.3658, 0.1463, 0.0122, -0.0509, 0.0453, 0.0132, -0.1152, -0.3515],
|
44 |
-
[0.6565, 0.3575, -0.0906, -0.4682, -0.7850, -1.0572, -1.2190, -1.3699, -1.4444, -1.4917, -1.4912, -1.4580, -1.4893, -1.4474, -1.3971, -1.2866, -1.0777, -0.6444, -0.0932, 0.5655, 1.0824, 1.1828, 1.1444, 1.1274, 0.9795, 0.7276, 0.5866, 0.4147, 0.1967, 0.0191, -0.1381, -0.2384, -0.3049, -0.3523, -0.3015, -0.1735, -0.1186, -0.0017, 0.1577, 0.2584, 0.3643, 0.5303, 0.6169, 0.6950, 0.8434, 0.9273, 0.8578, 0.7277]
|
45 |
-
],
|
46 |
-
"CIF 2016": [
|
47 |
-
[0.7780, 1.4980, 0.9115, 0.5034, 1.8105, 1.6583, 1.6255, 1.9613, 1.6311, 2.2943, 2.4116, 2.4785],
|
48 |
-
[-1.1615, -0.6684, -1.1396, -0.8766, -0.1863, -0.5040, -1.4683, -1.6765, -1.7312, -0.9862, -0.7341, -0.5807],
|
49 |
-
[-0.9360, -1.0646, -1.2660, -1.1586, -1.7037, -2.1642, -2.2156, -2.0456, -2.1657, -2.6387, -2.3959, -2.3520]
|
50 |
-
],
|
51 |
-
"Tourism Monthly": [
|
52 |
-
[-0.7218, -0.8162, -0.8599, -0.6781, -0.1500, 0.0297, 0.5061, 2.2237, 1.9718, 0.7502, -0.0718, -0.7254, -0.5999, -0.8048, -0.8552, -0.5962, -0.2393, -0.0513, 0.5497, 2.0177, 2.3518, 0.6854, -0.1190, -0.7002],
|
53 |
-
[1.4626, 1.6710, 0.9104, -0.5530, -0.9392, -1.0148, -0.6596, -0.7959, -0.7170, 0.1262, 1.1316, 1.9928, 1.8925, 2.1458, 0.7225, -0.1372, -0.8087, -0.8994, -0.5187, -0.6408, -0.5757, 0.1131, 1.2596, 2.4437],
|
54 |
-
[0.7737, 0.2534, -0.4698, -0.8390, -0.7692, -0.6648, -0.7311, -0.8054, -0.3891, 0.8623, 2.7314, 1.3378, 1.2805, 0.5682, -0.4299, -0.7511, -0.8982, -0.6067, -0.6572, -0.8337, -0.2564, 0.5460, 2.5010, 1.5309]
|
55 |
-
]
|
56 |
-
}
|
57 |
|
58 |
descriptions = {
|
59 |
"NN5 Daily": "Daily cash withdrawal volumes from automated teller machines (ATMs) in the United Kingdom, originally used in the NN5 forecasting competition.",
|
@@ -63,7 +25,6 @@ descriptions = {
|
|
63 |
}
|
64 |
|
65 |
models = {}
|
66 |
-
# for dataset in ["NN5_Daily", "Australian_Electricity", "CIF_2016", "Tourism_Monthly"]:
|
67 |
for dataset in ["NN5_Daily", "Australian_Electricity"]:
|
68 |
config = MultiModalTimerConfig.from_pretrained(f"ckpt/CLIPQwenTimer/{dataset}/config.json")
|
69 |
model = MultiModalTimerModel(config)
|
@@ -83,16 +44,16 @@ context_length = {
|
|
83 |
"Tourism Monthly": 24
|
84 |
}
|
85 |
|
86 |
-
def predict(dataset,
|
87 |
-
|
88 |
-
mean = np.mean(
|
89 |
-
std = np.std(
|
90 |
-
|
91 |
-
input_ids = torch.tensor(
|
92 |
input_ids = input_ids.unsqueeze(0)
|
93 |
|
94 |
plt.figure(figsize=(384/100, 384/100), dpi=100)
|
95 |
-
plt.plot(
|
96 |
plt.xticks([])
|
97 |
plt.yticks([])
|
98 |
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
@@ -114,79 +75,88 @@ def predict(dataset, example, inputs, text):
|
|
114 |
|
115 |
cl = context_length[dataset]
|
116 |
out = out[0, :cl]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
plt.style.use("seaborn-v0_8")
|
119 |
fig, ax = plt.subplots()
|
120 |
-
ax.plot(
|
121 |
-
|
|
|
122 |
pass
|
123 |
else:
|
124 |
-
|
125 |
-
|
|
|
126 |
ax.legend()
|
127 |
|
128 |
-
|
129 |
-
fig.savefig(buf, format='png')
|
130 |
-
buf.seek(0)
|
131 |
-
forecast_img = Image.open(buf).convert('RGB')
|
132 |
|
133 |
-
|
134 |
-
|
|
|
|
|
135 |
|
136 |
-
def
|
137 |
-
|
138 |
-
return example, True
|
139 |
-
else:
|
140 |
-
return gr.Dropdown(["1", "2", "3", "Custom"], label="Input Examples", value=None, interactive=True), True
|
141 |
-
|
142 |
-
def update_options(dataset, example):
|
143 |
-
if example == "1":
|
144 |
-
time_series = inputs[dataset][0]
|
145 |
-
desc = descriptions[dataset]
|
146 |
-
placeholder = None
|
147 |
-
interactive = False
|
148 |
-
elif example == "2":
|
149 |
-
time_series = inputs[dataset][1]
|
150 |
-
desc = descriptions[dataset]
|
151 |
-
placeholder = None
|
152 |
-
interactive = False
|
153 |
-
elif example == "3":
|
154 |
-
time_series = inputs[dataset][2]
|
155 |
-
desc = descriptions[dataset]
|
156 |
-
placeholder = None
|
157 |
-
interactive = False
|
158 |
-
elif example == "Custom":
|
159 |
-
time_series = ""
|
160 |
-
desc = ""
|
161 |
-
placeholder = f"Please Enter {context_length[dataset]} Time Steps Long Time Series Input."
|
162 |
-
interactive = True
|
163 |
-
else:
|
164 |
-
time_series = ""
|
165 |
-
desc = ""
|
166 |
-
placeholder = None
|
167 |
-
interactive = False
|
168 |
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
with gr.Row():
|
173 |
with gr.Column():
|
174 |
-
# dataset_dropdown = gr.Dropdown(["NN5 Daily", "Australian Electricity", "CIF 2016", "Tourism Monthly"], value=None, label="Datasets", interactive=True)
|
175 |
dataset_dropdown = gr.Dropdown(["NN5 Daily", "Australian Electricity"], value=None, label="Datasets", interactive=True)
|
176 |
-
input_example_dropdown = gr.Dropdown([], label="Input Examples", value=None, interactive=False)
|
177 |
-
done = gr.State(False)
|
178 |
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
-
dataset_dropdown.change(
|
183 |
-
dataset_dropdown.change(
|
184 |
-
|
185 |
|
186 |
btn = gr.Button("Run")
|
187 |
with gr.Column():
|
188 |
-
|
189 |
|
190 |
-
btn.click(predict, inputs=[dataset_dropdown,
|
191 |
|
192 |
-
|
|
|
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
import matplotlib.pyplot as plt
|
8 |
import io
|
9 |
from PIL import Image
|
10 |
+
import pickle
|
11 |
|
12 |
from transformers import CLIPImageProcessor
|
13 |
|
14 |
+
with open('example/inputs.pkl', 'rb') as f:
|
15 |
+
inputs = pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
+
with open('example/targets.pkl', 'rb') as f:
|
18 |
+
targets = pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
descriptions = {
|
21 |
"NN5 Daily": "Daily cash withdrawal volumes from automated teller machines (ATMs) in the United Kingdom, originally used in the NN5 forecasting competition.",
|
|
|
25 |
}
|
26 |
|
27 |
models = {}
|
|
|
28 |
for dataset in ["NN5_Daily", "Australian_Electricity"]:
|
29 |
config = MultiModalTimerConfig.from_pretrained(f"ckpt/CLIPQwenTimer/{dataset}/config.json")
|
30 |
model = MultiModalTimerModel(config)
|
|
|
44 |
"Tourism Monthly": 24
|
45 |
}
|
46 |
|
47 |
+
def predict(dataset, text, df, example_index):
|
48 |
+
time_series = np.array(df.iloc[:, -1])
|
49 |
+
mean = np.mean(time_series)
|
50 |
+
std = np.std(time_series)
|
51 |
+
time_series_normalized = (time_series-mean)/std
|
52 |
+
input_ids = torch.tensor(time_series_normalized).to(torch.float32).to(device)
|
53 |
input_ids = input_ids.unsqueeze(0)
|
54 |
|
55 |
plt.figure(figsize=(384/100, 384/100), dpi=100)
|
56 |
+
plt.plot(time_series_normalized, color="black", linestyle="-", linewidth=1, marker="*", markersize=1)
|
57 |
plt.xticks([])
|
58 |
plt.yticks([])
|
59 |
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
|
|
75 |
|
76 |
cl = context_length[dataset]
|
77 |
out = out[0, :cl]
|
78 |
+
out = out.detach().cpu().numpy()
|
79 |
+
out = out*std+mean
|
80 |
+
|
81 |
+
input_dates_series = pd.to_datetime(df["timestamp"])
|
82 |
+
time_diff = input_dates_series.diff().mode()[0]
|
83 |
+
start_time = input_dates_series.iloc[-1] + time_diff
|
84 |
+
forecast_dates_series = pd.date_range(start=start_time, periods=len(input_dates_series), freq=time_diff)
|
85 |
|
86 |
plt.style.use("seaborn-v0_8")
|
87 |
fig, ax = plt.subplots()
|
88 |
+
ax.plot(input_dates_series, time_series, color="black", alpha=0.7, linewidth=3, label='Input')
|
89 |
+
ax.plot(forecast_dates_series, out, color='C2', alpha=0.7, linewidth=3, label='Forecast')
|
90 |
+
if example_index == 3: # Custom Input
|
91 |
pass
|
92 |
else:
|
93 |
+
true = targets[dataset][example_index].iloc[:, -1]
|
94 |
+
ax.plot(forecast_dates_series, true, color='C0', alpha=0.7, linewidth=3, label='True')
|
95 |
+
pass
|
96 |
ax.legend()
|
97 |
|
98 |
+
return fig
|
|
|
|
|
|
|
99 |
|
100 |
+
def selected_dataset(dataset):
|
101 |
+
gallery_items = [(Image.open(f'example/img/{dataset.replace(" ", "_")}/{i}.png').convert('RGB'), str(i+1)) for i in range(3)]
|
102 |
+
gallery_items.append((np.ones((64, 64)), 'Custom Input'))
|
103 |
+
return gr.Gallery(gallery_items, interactive=True, columns=2, height="350px", object_fit="contain"), gr.Textbox(value=descriptions[dataset], label="Dataset Description", interactive=False)
|
104 |
|
105 |
+
def selected_example(evt: gr.SelectData):
|
106 |
+
return evt.index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
+
def update_time_series_dataframe(dataset, example_index):
|
109 |
+
if example_index is None:
|
110 |
+
pass
|
111 |
+
elif example_index == 3: # Custom Input
|
112 |
+
return gr.Dataframe(value=None, datatype="str", label="Time Series Input", interactive=True)
|
113 |
+
else:
|
114 |
+
df = inputs[dataset][example_index]
|
115 |
+
return gr.Dataframe(value=df, label="Time Series Input", interactive=False)
|
116 |
|
117 |
with gr.Blocks() as demo:
|
118 |
+
gr.Image(
|
119 |
+
value="logo.png",
|
120 |
+
show_label=False,
|
121 |
+
show_download_button=False,
|
122 |
+
show_fullscreen_button=False,
|
123 |
+
interactive=False,
|
124 |
+
height=128,
|
125 |
+
container=False,
|
126 |
+
elem_id="logo"
|
127 |
+
)
|
128 |
+
gr.HTML("""
|
129 |
+
<style>
|
130 |
+
#logo {
|
131 |
+
display: flex;
|
132 |
+
justify-content: flex-start;
|
133 |
+
}
|
134 |
+
</style>
|
135 |
+
""")
|
136 |
with gr.Row():
|
137 |
with gr.Column():
|
|
|
138 |
dataset_dropdown = gr.Dropdown(["NN5 Daily", "Australian Electricity"], value=None, label="Datasets", interactive=True)
|
|
|
|
|
139 |
|
140 |
+
dataset_description_textbox = gr.Textbox(label="Dataset Description", interactive=False)
|
141 |
+
|
142 |
+
example_gallery = gr.Gallery(
|
143 |
+
None,
|
144 |
+
interactive=False
|
145 |
+
)
|
146 |
+
example_index = gr.State(value=None)
|
147 |
+
example_gallery.select(selected_example, None, example_index)
|
148 |
+
|
149 |
+
time_series_dataframe = gr.Dataframe(value=None, headers=["Timestamp", "Value"], label="Time Series Input", interactive=False)
|
150 |
|
151 |
+
dataset_dropdown.change(selected_dataset, inputs=dataset_dropdown, outputs=[example_gallery, dataset_description_textbox])
|
152 |
+
dataset_dropdown.change(update_time_series_dataframe, inputs=[dataset_dropdown, example_index], outputs=time_series_dataframe)
|
153 |
+
example_index.change(update_time_series_dataframe, inputs=[dataset_dropdown, example_index], outputs=time_series_dataframe)
|
154 |
|
155 |
btn = gr.Button("Run")
|
156 |
with gr.Column():
|
157 |
+
forecast_plot = gr.Plot(label="Forecast", format="png")
|
158 |
|
159 |
+
btn.click(predict, inputs=[dataset_dropdown, dataset_description_textbox, time_series_dataframe, example_index], outputs=forecast_plot)
|
160 |
|
161 |
+
if __name__ == "__main__":
|
162 |
+
demo.launch(ssr_mode=False)
|
example/img/Australian_Electricity/0.png
ADDED
![]() |
example/img/Australian_Electricity/1.png
ADDED
![]() |
example/img/Australian_Electricity/2.png
ADDED
![]() |
example/img/NN5_Daily/0.png
ADDED
![]() |
example/img/NN5_Daily/1.png
ADDED
![]() |
example/img/NN5_Daily/2.png
ADDED
![]() |
example/inputs.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e2c7ada397bda4d8be66a84d12a145283bda4021a0beed57071b53dac67cc907
|
3 |
+
size 8919
|
example/targets.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d2e0f82b9c0e9b1e922043a9cc4588ac4bb0ae298b2accc8fbb1c145b6566a18
|
3 |
+
size 8919
|
logo.png
ADDED
![]() |
Git LFS Details
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
gradio==5.42.0
|
2 |
-
accelerate==1.6.0
|
3 |
torch==2.6.0
|
4 |
numpy==2.2.4
|
5 |
transformers==4.40.1
|
|
|
1 |
gradio==5.42.0
|
|
|
2 |
torch==2.6.0
|
3 |
numpy==2.2.4
|
4 |
transformers==4.40.1
|