Spaces:
Sleeping
Sleeping
2508181758
Browse files- main.py +58 -69
- web/script.js +20 -5
main.py
CHANGED
@@ -4,16 +4,14 @@ from typing import List, Dict, Any
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.optim as optim
|
7 |
-
# from torchvision import datasets -> datasetsライブラリを使うため削除
|
8 |
from torchvision import transforms
|
9 |
-
from torch.utils.data import DataLoader
|
10 |
import numpy as np
|
11 |
import base64
|
12 |
from io import BytesIO
|
13 |
from PIL import Image
|
14 |
import random
|
15 |
-
import
|
16 |
-
from datasets import load_dataset # ★★★ Hugging Face datasetsライブラリをインポート
|
17 |
|
18 |
# FastAPIアプリケーションインスタンスを作成
|
19 |
app = FastAPI()
|
@@ -105,96 +103,84 @@ class PlayerModel(nn.Module):
|
|
105 |
x = layer(x)
|
106 |
return x
|
107 |
|
108 |
-
# --- グローバル変数とデータ準備 (
|
|
|
109 |
device = torch.device("cpu")
|
110 |
-
|
111 |
-
# 1. Hugging Face HubからMNISTデータセットをロード
|
112 |
mnist_dataset = load_dataset("mnist")
|
|
|
113 |
|
114 |
-
# 2. torchvisionのtransformを定義
|
115 |
-
transform = transforms.Compose([
|
116 |
-
transforms.ToTensor(),
|
117 |
-
# transforms.Normalize((0.1307,), (0.3081,)) # 必要に応じて正規化
|
118 |
-
])
|
119 |
-
|
120 |
-
# 3. データセットにtransformを適用する関数を定義
|
121 |
def apply_transforms(examples):
|
122 |
-
# PIL Imageのリストをテンソルのリストに変換
|
123 |
examples['image'] = [transform(image.convert("L")) for image in examples['image']]
|
124 |
return examples
|
125 |
|
126 |
-
# 4. データセットにtransformを適用
|
127 |
mnist_dataset.set_transform(apply_transforms)
|
128 |
-
|
129 |
-
# 5. DataLoaderを準備
|
130 |
train_subset = mnist_dataset['train'].select(range(1000))
|
131 |
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
|
132 |
|
133 |
-
# 6. テスト用の画像リストを作成 (DataLoaderと同じ (image, label) タプルの形式を維持)
|
134 |
test_images = []
|
135 |
-
# メモリ使用量を考慮し、テスト画像は1000個に絞る
|
136 |
test_subset_for_inference = mnist_dataset['test'].shuffle().select(range(1000))
|
137 |
for item in test_subset_for_inference:
|
138 |
-
|
139 |
-
image_tensor = item['image'].unsqueeze(0) # バッチ次元 (1) を追加
|
140 |
label_tensor = torch.tensor(item['label'])
|
141 |
test_images.append((image_tensor, label_tensor))
|
142 |
|
143 |
-
|
144 |
-
trained_player_model = None
|
145 |
|
146 |
-
# --- バックエンドロジック ---
|
147 |
def get_enemy():
|
148 |
-
|
149 |
-
|
150 |
-
current_enemy = {"image": image, "label": label}
|
151 |
|
152 |
-
img_pil = transforms.ToPILImage()(
|
153 |
buffered = BytesIO()
|
154 |
img_pil.save(buffered, format="PNG")
|
155 |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
156 |
|
157 |
-
return {
|
|
|
|
|
|
|
158 |
|
159 |
-
def
|
160 |
-
|
|
|
|
|
|
|
|
|
161 |
if not layer_configs:
|
162 |
-
return {"
|
163 |
-
|
164 |
try:
|
165 |
model = PlayerModel(layer_configs).to(device)
|
166 |
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
167 |
loss_fn = nn.CrossEntropyLoss()
|
168 |
|
169 |
model.train()
|
170 |
-
for epoch in range(3): # 3エポック学習
|
171 |
-
|
172 |
-
for batch_idx, batch in enumerate(train_loader):
|
173 |
data, target = batch['image'].to(device), batch['label'].to(device)
|
174 |
optimizer.zero_grad()
|
175 |
output = model(data)
|
176 |
loss = loss_fn(output, target)
|
177 |
loss.backward()
|
178 |
optimizer.step()
|
179 |
-
|
180 |
-
|
181 |
-
trained_player_model = model
|
182 |
-
return {"success": True, "message": "モデルの訓練が完了しました!"}
|
183 |
except Exception as e:
|
184 |
-
print(f"Error during training: {e}")
|
185 |
-
return {"
|
186 |
-
|
187 |
-
def run_inference():
|
188 |
-
global trained_player_model, current_enemy
|
189 |
-
if trained_player_model is None:
|
190 |
-
return {"error": "モデルが訓練されていません。"}
|
191 |
|
192 |
-
|
193 |
-
current_enemy = {"image": image, "label": label}
|
194 |
-
|
195 |
-
model = trained_player_model
|
196 |
model.eval()
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
intermediate_outputs = {}
|
199 |
hooks = []
|
200 |
def get_hook(name):
|
@@ -206,7 +192,6 @@ def run_inference():
|
|
206 |
hooks.append(layer.register_forward_hook(get_hook(name)))
|
207 |
|
208 |
with torch.no_grad():
|
209 |
-
image_tensor = current_enemy["image"].to(device)
|
210 |
output = model(image_tensor)
|
211 |
|
212 |
for h in hooks: h.remove()
|
@@ -224,34 +209,38 @@ def run_inference():
|
|
224 |
weights[name + '_w'] = layer.weight.cpu().detach().numpy().tolist()
|
225 |
weights[name + '_b'] = layer.bias.cpu().detach().numpy().tolist()
|
226 |
|
227 |
-
is_correct = (prediction ==
|
228 |
-
|
229 |
-
img_pil = transforms.ToPILImage()(image.squeeze(0))
|
230 |
-
buffered = BytesIO()
|
231 |
-
img_pil.save(buffered, format="PNG")
|
232 |
-
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
233 |
|
|
|
234 |
return {
|
235 |
-
"prediction": prediction,
|
|
|
|
|
236 |
"confidence": confidence,
|
237 |
-
"image_b64":
|
238 |
"architecture": [{"type": "Input", "name": "input", "shape": [1, 28, 28]}] + model.architecture_info,
|
239 |
"outputs": intermediate_outputs,
|
240 |
"weights": weights
|
241 |
}
|
242 |
|
243 |
-
# --- FastAPI Endpoints
|
244 |
@app.get("/api/get_enemy")
|
245 |
async def get_enemy_endpoint():
|
246 |
return get_enemy()
|
247 |
|
248 |
-
@app.post("/api/train_player_model")
|
249 |
-
async def train_player_model_endpoint(layer_configs: List[Dict[str, Any]] = Body(...)):
|
250 |
-
return train_player_model(layer_configs)
|
251 |
-
|
252 |
@app.post("/api/run_inference")
|
253 |
-
async def run_inference_endpoint():
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
-
# --- 静的ファイルの配信
|
257 |
app.mount("/", StaticFiles(directory="web", html=True), name="static")
|
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.optim as optim
|
|
|
7 |
from torchvision import transforms
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
import numpy as np
|
10 |
import base64
|
11 |
from io import BytesIO
|
12 |
from PIL import Image
|
13 |
import random
|
14 |
+
from datasets import load_dataset
|
|
|
15 |
|
16 |
# FastAPIアプリケーションインスタンスを作成
|
17 |
app = FastAPI()
|
|
|
103 |
x = layer(x)
|
104 |
return x
|
105 |
|
106 |
+
# --- グローバル変数とデータ準備 (ステートレス対応) ---
|
107 |
+
# これらの変数はサーバー起動時に一度だけ初期化され、リクエスト間で変更されない定数として扱う
|
108 |
device = torch.device("cpu")
|
|
|
|
|
109 |
mnist_dataset = load_dataset("mnist")
|
110 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
def apply_transforms(examples):
|
|
|
113 |
examples['image'] = [transform(image.convert("L")) for image in examples['image']]
|
114 |
return examples
|
115 |
|
|
|
116 |
mnist_dataset.set_transform(apply_transforms)
|
|
|
|
|
117 |
train_subset = mnist_dataset['train'].select(range(1000))
|
118 |
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
|
119 |
|
|
|
120 |
test_images = []
|
|
|
121 |
test_subset_for_inference = mnist_dataset['test'].shuffle().select(range(1000))
|
122 |
for item in test_subset_for_inference:
|
123 |
+
image_tensor = item['image'].unsqueeze(0)
|
|
|
124 |
label_tensor = torch.tensor(item['label'])
|
125 |
test_images.append((image_tensor, label_tensor))
|
126 |
|
127 |
+
# --- バックエンドロジック (ステートレス関数) ---
|
|
|
128 |
|
|
|
129 |
def get_enemy():
|
130 |
+
"""新しい敵の画像(base64)と正解ラベルを返す。サーバー側では状態を保持しない。"""
|
131 |
+
image_tensor, label_tensor = random.choice(test_images)
|
|
|
132 |
|
133 |
+
img_pil = transforms.ToPILImage()(image_tensor.squeeze(0))
|
134 |
buffered = BytesIO()
|
135 |
img_pil.save(buffered, format="PNG")
|
136 |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
137 |
|
138 |
+
return {
|
139 |
+
"image_b64": "data:image/png;base64," + img_str,
|
140 |
+
"label": label_tensor.item()
|
141 |
+
}
|
142 |
|
143 |
+
def run_inference(layer_configs: list, enemy_image_b64: str, enemy_label: int):
|
144 |
+
"""
|
145 |
+
リクエストごとにモデルを構築・訓練し、与えられた敵データで推論を実行する。
|
146 |
+
サーバー側では状態を一切保持しない。
|
147 |
+
"""
|
148 |
+
# 1. モデルをその場で構築し、訓練する
|
149 |
if not layer_configs:
|
150 |
+
return {"error": "モデルが空です。"}
|
|
|
151 |
try:
|
152 |
model = PlayerModel(layer_configs).to(device)
|
153 |
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
154 |
loss_fn = nn.CrossEntropyLoss()
|
155 |
|
156 |
model.train()
|
157 |
+
for epoch in range(3): # 毎回3エポック学習
|
158 |
+
for batch in train_loader:
|
|
|
159 |
data, target = batch['image'].to(device), batch['label'].to(device)
|
160 |
optimizer.zero_grad()
|
161 |
output = model(data)
|
162 |
loss = loss_fn(output, target)
|
163 |
loss.backward()
|
164 |
optimizer.step()
|
165 |
+
print("On-the-fly training for inference finished.")
|
|
|
|
|
|
|
166 |
except Exception as e:
|
167 |
+
print(f"Error during on-the-fly training: {e}")
|
168 |
+
return {"error": f"推論中のモデル構築・訓練エラー: {e}"}
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
# 2. クライアントから送られてきた敵画像で推論する
|
|
|
|
|
|
|
171 |
model.eval()
|
172 |
|
173 |
+
# Base64文字列から画像テンソルにデコード
|
174 |
+
try:
|
175 |
+
header, encoded = enemy_image_b64.split(",", 1)
|
176 |
+
image_data = base64.b64decode(encoded)
|
177 |
+
image_pil = Image.open(BytesIO(image_data)).convert("L")
|
178 |
+
image_tensor = transforms.ToTensor()(image_pil).unsqueeze(0).to(device)
|
179 |
+
except Exception as e:
|
180 |
+
print(f"Error decoding enemy image: {e}")
|
181 |
+
return {"error": f"敵画像のデコードエラー: {e}"}
|
182 |
+
|
183 |
+
# 3. 推論と中間出力のキャプチャ
|
184 |
intermediate_outputs = {}
|
185 |
hooks = []
|
186 |
def get_hook(name):
|
|
|
192 |
hooks.append(layer.register_forward_hook(get_hook(name)))
|
193 |
|
194 |
with torch.no_grad():
|
|
|
195 |
output = model(image_tensor)
|
196 |
|
197 |
for h in hooks: h.remove()
|
|
|
209 |
weights[name + '_w'] = layer.weight.cpu().detach().numpy().tolist()
|
210 |
weights[name + '_b'] = layer.bias.cpu().detach().numpy().tolist()
|
211 |
|
212 |
+
is_correct = (prediction == enemy_label)
|
|
|
|
|
|
|
|
|
|
|
213 |
|
214 |
+
# 4. 結果をクライアントに返す
|
215 |
return {
|
216 |
+
"prediction": prediction,
|
217 |
+
"label": enemy_label,
|
218 |
+
"is_correct": is_correct,
|
219 |
"confidence": confidence,
|
220 |
+
"image_b64": enemy_image_b64, # 受け取った画像をそのまま返す
|
221 |
"architecture": [{"type": "Input", "name": "input", "shape": [1, 28, 28]}] + model.architecture_info,
|
222 |
"outputs": intermediate_outputs,
|
223 |
"weights": weights
|
224 |
}
|
225 |
|
226 |
+
# --- FastAPI Endpoints ---
|
227 |
@app.get("/api/get_enemy")
|
228 |
async def get_enemy_endpoint():
|
229 |
return get_enemy()
|
230 |
|
|
|
|
|
|
|
|
|
231 |
@app.post("/api/run_inference")
|
232 |
+
async def run_inference_endpoint(payload: Dict[str, Any] = Body(...)):
|
233 |
+
"""
|
234 |
+
クライアントからモデル構成と敵データを受け取り、推論結果を返すエンドポイント。
|
235 |
+
"""
|
236 |
+
layer_configs = payload.get("layer_configs")
|
237 |
+
enemy_image_b64 = payload.get("enemy_image_b64")
|
238 |
+
enemy_label = payload.get("enemy_label")
|
239 |
+
|
240 |
+
if not all([layer_configs, enemy_image_b64, enemy_label is not None]):
|
241 |
+
return {"error": "リクエストのパラメータが不足しています。"}
|
242 |
+
|
243 |
+
return run_inference(layer_configs, enemy_image_b64, enemy_label)
|
244 |
|
245 |
+
# --- 静的ファイルの配信 ---
|
246 |
app.mount("/", StaticFiles(directory="web", html=True), name="static")
|
web/script.js
CHANGED
@@ -34,6 +34,7 @@ let isBattleInProgress = false; // ★★★ バトルループ中のフラグ
|
|
34 |
let draggedItem = null; // { type, layer, index }
|
35 |
let dragOverIndex = null; // 並び替え先のインデックス
|
36 |
let wasDroppedSuccessfully = false; // ★★★ このフラグを追加
|
|
|
37 |
let ENEMY_MAX_HP = 100;
|
38 |
const PLAYER_MAX_HP = 100;
|
39 |
|
@@ -285,14 +286,19 @@ function updateHpBars() {
|
|
285 |
}
|
286 |
|
287 |
async function fetchNewEnemy() {
|
288 |
-
// EelからFetch APIに変更
|
289 |
const response = await fetch('/api/get_enemy');
|
290 |
-
const
|
291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
enemyMessage.textContent = '野生のMNISTモンスターが現れた!';
|
293 |
-
enemyImage.src =
|
294 |
enemyImage.classList.remove('hidden');
|
295 |
-
await animateBattleLog('', true);
|
296 |
}
|
297 |
|
298 |
// --- D&D Functions ---
|
@@ -891,7 +897,15 @@ async function handleBattle() {
|
|
891 |
await animateBattleLog('新たな敵をスキャン... 推論実行...');
|
892 |
|
893 |
// EelからFetch APIに変更
|
894 |
-
const inferenceResponse = await fetch('/api/run_inference', {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
895 |
const result = await inferenceResponse.json();
|
896 |
|
897 |
if (result.error) {
|
@@ -921,6 +935,7 @@ async function handleBattle() {
|
|
921 |
|
922 |
if (enemyHP > 0 && playerHP > 0) {
|
923 |
await sleep(1500);
|
|
|
924 |
}
|
925 |
}
|
926 |
|
|
|
34 |
let draggedItem = null; // { type, layer, index }
|
35 |
let dragOverIndex = null; // 並び替え先のインデックス
|
36 |
let wasDroppedSuccessfully = false; // ★★★ このフラグを追加
|
37 |
+
let currentEnemy = { image_b64: null, label: null }; // ★★★ クライアント側で敵の状態を保持
|
38 |
let ENEMY_MAX_HP = 100;
|
39 |
const PLAYER_MAX_HP = 100;
|
40 |
|
|
|
286 |
}
|
287 |
|
288 |
async function fetchNewEnemy() {
|
|
|
289 |
const response = await fetch('/api/get_enemy');
|
290 |
+
const enemyData = await response.json();
|
291 |
|
292 |
+
// ★★★ グローバル変数に保存
|
293 |
+
currentEnemy = {
|
294 |
+
image_b64: enemyData.image_b64,
|
295 |
+
label: enemyData.label
|
296 |
+
};
|
297 |
+
|
298 |
enemyMessage.textContent = '野生のMNISTモンスターが現れた!';
|
299 |
+
enemyImage.src = currentEnemy.image_b64;
|
300 |
enemyImage.classList.remove('hidden');
|
301 |
+
await animateBattleLog('', true);
|
302 |
}
|
303 |
|
304 |
// --- D&D Functions ---
|
|
|
897 |
await animateBattleLog('新たな敵をスキャン... 推論実行...');
|
898 |
|
899 |
// EelからFetch APIに変更
|
900 |
+
const inferenceResponse = await fetch('/api/run_inference', {
|
901 |
+
method: 'POST',
|
902 |
+
headers: { 'Content-Type': 'application/json' },
|
903 |
+
body: JSON.stringify({
|
904 |
+
layer_configs: playerLayers,
|
905 |
+
enemy_image_b64: currentEnemy.image_b64,
|
906 |
+
enemy_label: currentEnemy.label,
|
907 |
+
}),
|
908 |
+
});
|
909 |
const result = await inferenceResponse.json();
|
910 |
|
911 |
if (result.error) {
|
|
|
935 |
|
936 |
if (enemyHP > 0 && playerHP > 0) {
|
937 |
await sleep(1500);
|
938 |
+
await fetchNewEnemy(); // 次の敵を準備
|
939 |
}
|
940 |
}
|
941 |
|