horiyouta commited on
Commit
d8c161f
·
1 Parent(s): 9905360

2508181758

Browse files
Files changed (2) hide show
  1. main.py +58 -69
  2. web/script.js +20 -5
main.py CHANGED
@@ -4,16 +4,14 @@ from typing import List, Dict, Any
4
  import torch
5
  import torch.nn as nn
6
  import torch.optim as optim
7
- # from torchvision import datasets -> datasetsライブラリを使うため削除
8
  from torchvision import transforms
9
- from torch.utils.data import DataLoader, Subset
10
  import numpy as np
11
  import base64
12
  from io import BytesIO
13
  from PIL import Image
14
  import random
15
- import json
16
- from datasets import load_dataset # ★★★ Hugging Face datasetsライブラリをインポート
17
 
18
  # FastAPIアプリケーションインスタンスを作成
19
  app = FastAPI()
@@ -105,96 +103,84 @@ class PlayerModel(nn.Module):
105
  x = layer(x)
106
  return x
107
 
108
- # --- グローバル変数とデータ準備 (★★★ ここを大幅に変更) ---
 
109
  device = torch.device("cpu")
110
-
111
- # 1. Hugging Face HubからMNISTデータセットをロード
112
  mnist_dataset = load_dataset("mnist")
 
113
 
114
- # 2. torchvisionのtransformを定義
115
- transform = transforms.Compose([
116
- transforms.ToTensor(),
117
- # transforms.Normalize((0.1307,), (0.3081,)) # 必要に応じて正規化
118
- ])
119
-
120
- # 3. データセットにtransformを適用する関数を定義
121
  def apply_transforms(examples):
122
- # PIL Imageのリストをテンソルのリストに変換
123
  examples['image'] = [transform(image.convert("L")) for image in examples['image']]
124
  return examples
125
 
126
- # 4. データセットにtransformを適用
127
  mnist_dataset.set_transform(apply_transforms)
128
-
129
- # 5. DataLoaderを準備
130
  train_subset = mnist_dataset['train'].select(range(1000))
131
  train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
132
 
133
- # 6. テスト用の画像リストを作成 (DataLoaderと同じ (image, label) タプルの形式を維持)
134
  test_images = []
135
- # メモリ使用量を考慮し、テスト画像は1000個に絞る
136
  test_subset_for_inference = mnist_dataset['test'].shuffle().select(range(1000))
137
  for item in test_subset_for_inference:
138
- # itemは {'image': tensor, 'label': int} という辞書
139
- image_tensor = item['image'].unsqueeze(0) # バッチ次元 (1) を追加
140
  label_tensor = torch.tensor(item['label'])
141
  test_images.append((image_tensor, label_tensor))
142
 
143
- current_enemy = None
144
- trained_player_model = None
145
 
146
- # --- バックエンドロジック ---
147
  def get_enemy():
148
- global current_enemy
149
- image, label = random.choice(test_images)
150
- current_enemy = {"image": image, "label": label}
151
 
152
- img_pil = transforms.ToPILImage()(image.squeeze(0))
153
  buffered = BytesIO()
154
  img_pil.save(buffered, format="PNG")
155
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
156
 
157
- return {"image_b64": "data:image/png;base64," + img_str}
 
 
 
158
 
159
- def train_player_model(layer_configs: list):
160
- global trained_player_model
 
 
 
 
161
  if not layer_configs:
162
- return {"success": False, "message": "モデルが空です。"}
163
-
164
  try:
165
  model = PlayerModel(layer_configs).to(device)
166
  optimizer = optim.Adam(model.parameters(), lr=0.001)
167
  loss_fn = nn.CrossEntropyLoss()
168
 
169
  model.train()
170
- for epoch in range(3): # 3エポック学習
171
- # ★★★ DataLoaderからのデータ受け取り方を変更
172
- for batch_idx, batch in enumerate(train_loader):
173
  data, target = batch['image'].to(device), batch['label'].to(device)
174
  optimizer.zero_grad()
175
  output = model(data)
176
  loss = loss_fn(output, target)
177
  loss.backward()
178
  optimizer.step()
179
- print(f"Epoch {epoch+1} training finished.")
180
-
181
- trained_player_model = model
182
- return {"success": True, "message": "モデルの訓練が完了しました!"}
183
  except Exception as e:
184
- print(f"Error during training: {e}")
185
- return {"success": False, "message": f"訓練中にエラーが発生しました: {e}"}
186
-
187
- def run_inference():
188
- global trained_player_model, current_enemy
189
- if trained_player_model is None:
190
- return {"error": "モデルが訓練されていません。"}
191
 
192
- image, label = random.choice(test_images)
193
- current_enemy = {"image": image, "label": label}
194
-
195
- model = trained_player_model
196
  model.eval()
197
 
 
 
 
 
 
 
 
 
 
 
 
198
  intermediate_outputs = {}
199
  hooks = []
200
  def get_hook(name):
@@ -206,7 +192,6 @@ def run_inference():
206
  hooks.append(layer.register_forward_hook(get_hook(name)))
207
 
208
  with torch.no_grad():
209
- image_tensor = current_enemy["image"].to(device)
210
  output = model(image_tensor)
211
 
212
  for h in hooks: h.remove()
@@ -224,34 +209,38 @@ def run_inference():
224
  weights[name + '_w'] = layer.weight.cpu().detach().numpy().tolist()
225
  weights[name + '_b'] = layer.bias.cpu().detach().numpy().tolist()
226
 
227
- is_correct = (prediction == current_enemy["label"].item())
228
-
229
- img_pil = transforms.ToPILImage()(image.squeeze(0))
230
- buffered = BytesIO()
231
- img_pil.save(buffered, format="PNG")
232
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
233
 
 
234
  return {
235
- "prediction": prediction, "label": current_enemy["label"].item(), "is_correct": is_correct,
 
 
236
  "confidence": confidence,
237
- "image_b64": "data:image/png;base64," + img_str,
238
  "architecture": [{"type": "Input", "name": "input", "shape": [1, 28, 28]}] + model.architecture_info,
239
  "outputs": intermediate_outputs,
240
  "weights": weights
241
  }
242
 
243
- # --- FastAPI Endpoints (変更なし) ---
244
  @app.get("/api/get_enemy")
245
  async def get_enemy_endpoint():
246
  return get_enemy()
247
 
248
- @app.post("/api/train_player_model")
249
- async def train_player_model_endpoint(layer_configs: List[Dict[str, Any]] = Body(...)):
250
- return train_player_model(layer_configs)
251
-
252
  @app.post("/api/run_inference")
253
- async def run_inference_endpoint():
254
- return run_inference()
 
 
 
 
 
 
 
 
 
 
255
 
256
- # --- 静的ファイルの配信 (変更なし) ---
257
  app.mount("/", StaticFiles(directory="web", html=True), name="static")
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.optim as optim
 
7
  from torchvision import transforms
8
+ from torch.utils.data import DataLoader
9
  import numpy as np
10
  import base64
11
  from io import BytesIO
12
  from PIL import Image
13
  import random
14
+ from datasets import load_dataset
 
15
 
16
  # FastAPIアプリケーションインスタンスを作成
17
  app = FastAPI()
 
103
  x = layer(x)
104
  return x
105
 
106
+ # --- グローバル変数とデータ準備 (ステートレス対応) ---
107
+ # これらの変数はサーバー起動時に一度だけ初期化され、リクエスト間で変更されない定数として扱う
108
  device = torch.device("cpu")
 
 
109
  mnist_dataset = load_dataset("mnist")
110
+ transform = transforms.Compose([transforms.ToTensor()])
111
 
 
 
 
 
 
 
 
112
def apply_transforms(examples):
    """Batch transform for the HF dataset: grayscale-convert each PIL image and tensorize it in place."""
    tensorized = []
    for pil_img in examples['image']:
        tensorized.append(transform(pil_img.convert("L")))
    examples['image'] = tensorized
    return examples
115
 
 
116
  mnist_dataset.set_transform(apply_transforms)
 
 
117
  train_subset = mnist_dataset['train'].select(range(1000))
118
  train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
119
 
 
120
  test_images = []
 
121
  test_subset_for_inference = mnist_dataset['test'].shuffle().select(range(1000))
122
  for item in test_subset_for_inference:
123
+ image_tensor = item['image'].unsqueeze(0)
 
124
  label_tensor = torch.tensor(item['label'])
125
  test_images.append((image_tensor, label_tensor))
126
 
127
+ # --- バックエンドロジック (ステートレス関数) ---
 
128
 
 
129
def get_enemy():
    """Sample a random test digit and return it as a base64 PNG with its label.

    Stateless by design: the server caches nothing between requests; the
    client is responsible for remembering the enemy it was given.
    """
    sample_image, sample_label = random.choice(test_images)

    # Strip the batch dimension, render back to PIL, and encode as PNG.
    pil_img = transforms.ToPILImage()(sample_image.squeeze(0))
    png_buffer = BytesIO()
    pil_img.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    return {
        "image_b64": "data:image/png;base64," + encoded,
        "label": sample_label.item(),
    }
142
 
143
+ def run_inference(layer_configs: list, enemy_image_b64: str, enemy_label: int):
144
+ """
145
+ リクエストごとにモデルを構築・訓練し、与えられた敵データで推論を実行する。
146
+ サーバー側では状態を一切保持しない。
147
+ """
148
+ # 1. モデルをその場で構築し、訓練する
149
  if not layer_configs:
150
+ return {"error": "モデルが空です。"}
 
151
  try:
152
  model = PlayerModel(layer_configs).to(device)
153
  optimizer = optim.Adam(model.parameters(), lr=0.001)
154
  loss_fn = nn.CrossEntropyLoss()
155
 
156
  model.train()
157
+ for epoch in range(3): # 毎回3エポック学習
158
+ for batch in train_loader:
 
159
  data, target = batch['image'].to(device), batch['label'].to(device)
160
  optimizer.zero_grad()
161
  output = model(data)
162
  loss = loss_fn(output, target)
163
  loss.backward()
164
  optimizer.step()
165
+ print("On-the-fly training for inference finished.")
 
 
 
166
  except Exception as e:
167
+ print(f"Error during on-the-fly training: {e}")
168
+ return {"error": f"推論中のモデル構築・訓練エラー: {e}"}
 
 
 
 
 
169
 
170
+ # 2. クライアントから送られてきた敵画像で推論する
 
 
 
171
  model.eval()
172
 
173
+ # Base64文字列から画像テンソルにデコード
174
+ try:
175
+ header, encoded = enemy_image_b64.split(",", 1)
176
+ image_data = base64.b64decode(encoded)
177
+ image_pil = Image.open(BytesIO(image_data)).convert("L")
178
+ image_tensor = transforms.ToTensor()(image_pil).unsqueeze(0).to(device)
179
+ except Exception as e:
180
+ print(f"Error decoding enemy image: {e}")
181
+ return {"error": f"敵画像のデコードエラー: {e}"}
182
+
183
+ # 3. 推論と中間出力のキャプチャ
184
  intermediate_outputs = {}
185
  hooks = []
186
  def get_hook(name):
 
192
  hooks.append(layer.register_forward_hook(get_hook(name)))
193
 
194
  with torch.no_grad():
 
195
  output = model(image_tensor)
196
 
197
  for h in hooks: h.remove()
 
209
  weights[name + '_w'] = layer.weight.cpu().detach().numpy().tolist()
210
  weights[name + '_b'] = layer.bias.cpu().detach().numpy().tolist()
211
 
212
+ is_correct = (prediction == enemy_label)
 
 
 
 
 
213
 
214
+ # 4. 結果をクライアントに返す
215
  return {
216
+ "prediction": prediction,
217
+ "label": enemy_label,
218
+ "is_correct": is_correct,
219
  "confidence": confidence,
220
+ "image_b64": enemy_image_b64, # 受け取った画像をそのまま返す
221
  "architecture": [{"type": "Input", "name": "input", "shape": [1, 28, 28]}] + model.architecture_info,
222
  "outputs": intermediate_outputs,
223
  "weights": weights
224
  }
225
 
226
+ # --- FastAPI Endpoints ---
227
@app.get("/api/get_enemy")
async def get_enemy_endpoint():
    """Serve a freshly sampled enemy image (base64 PNG) and its label."""
    return get_enemy()


@app.post("/api/run_inference")
async def run_inference_endpoint(payload: Dict[str, Any] = Body(...)):
    """Receive the player's model layout plus the enemy data and return the
    on-the-fly training/inference result. The server keeps no state."""
    layer_configs = payload.get("layer_configs")
    enemy_image_b64 = payload.get("enemy_image_b64")
    enemy_label = payload.get("enemy_label")

    # Reject the request when any field is absent/empty. Note the label is
    # checked against None explicitly so that a legitimate label of 0 passes.
    if not layer_configs or not enemy_image_b64 or enemy_label is None:
        return {"error": "リクエストのパラメータが不足しています。"}

    return run_inference(layer_configs, enemy_image_b64, enemy_label)


# --- Static file serving ---
app.mount("/", StaticFiles(directory="web", html=True), name="static")
web/script.js CHANGED
@@ -34,6 +34,7 @@ let isBattleInProgress = false; // ★★★ バトルループ中のフラグ
34
  let draggedItem = null; // { type, layer, index }
35
  let dragOverIndex = null; // 並び替え先のインデックス
36
  let wasDroppedSuccessfully = false; // ★★★ このフラグを追加
 
37
  let ENEMY_MAX_HP = 100;
38
  const PLAYER_MAX_HP = 100;
39
 
@@ -285,14 +286,19 @@ function updateHpBars() {
285
  }
286
 
287
  async function fetchNewEnemy() {
288
- // EelからFetch APIに変更
289
  const response = await fetch('/api/get_enemy');
290
- const enemy = await response.json();
291
 
 
 
 
 
 
 
292
  enemyMessage.textContent = '野生のMNISTモンスターが現れた!';
293
- enemyImage.src = enemy.image_b64;
294
  enemyImage.classList.remove('hidden');
295
- await animateBattleLog('', true); // Clear battle log
296
  }
297
 
298
  // --- D&D Functions ---
@@ -891,7 +897,15 @@ async function handleBattle() {
891
  await animateBattleLog('新たな敵をスキャン... 推論実行...');
892
 
893
  // EelからFetch APIに変更
894
- const inferenceResponse = await fetch('/api/run_inference', { method: 'POST' });
 
 
 
 
 
 
 
 
895
  const result = await inferenceResponse.json();
896
 
897
  if (result.error) {
@@ -921,6 +935,7 @@ async function handleBattle() {
921
 
922
  if (enemyHP > 0 && playerHP > 0) {
923
  await sleep(1500);
 
924
  }
925
  }
926
 
 
34
let draggedItem = null; // { type, layer, index }
let dragOverIndex = null; // destination index while reordering layers
let wasDroppedSuccessfully = false; // set when a drag ends in a valid drop
let currentEnemy = { image_b64: null, label: null }; // enemy state is held client-side (server is stateless)
let ENEMY_MAX_HP = 100;
const PLAYER_MAX_HP = 100;
40
 
 
286
  }
287
 
288
async function fetchNewEnemy() {
    // Ask the backend for a new enemy. The server keeps no per-session
    // state, so the image/label pair is remembered here in `currentEnemy`.
    const enemyData = await (await fetch('/api/get_enemy')).json();

    currentEnemy = { image_b64: enemyData.image_b64, label: enemyData.label };

    enemyMessage.textContent = '野生のMNISTモンスターが現れた!';
    enemyImage.src = currentEnemy.image_b64;
    enemyImage.classList.remove('hidden');
    await animateBattleLog('', true); // reset the battle log
}
303
 
304
  // --- D&D Functions ---
 
897
  await animateBattleLog('新たな敵をスキャン... 推論実行...');
898
 
899
  // EelからFetch APIに変更
900
+ const inferenceResponse = await fetch('/api/run_inference', {
901
+ method: 'POST',
902
+ headers: { 'Content-Type': 'application/json' },
903
+ body: JSON.stringify({
904
+ layer_configs: playerLayers,
905
+ enemy_image_b64: currentEnemy.image_b64,
906
+ enemy_label: currentEnemy.label,
907
+ }),
908
+ });
909
  const result = await inferenceResponse.json();
910
 
911
  if (result.error) {
 
935
 
936
  if (enemyHP > 0 && playerHP > 0) {
937
  await sleep(1500);
938
+ await fetchNewEnemy(); // 次の敵を準備
939
  }
940
  }
941