Spaces:
Sleeping
Sleeping
2508181735
Browse files- .gitattributes +1 -2
- main.py +47 -31
- requirements.txt +2 -1
.gitattributes
CHANGED
@@ -32,5 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
-
data/* filter=lfs diff=lfs merge=lfs -text
|
|
|
32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
main.py
CHANGED
@@ -4,7 +4,8 @@ from typing import List, Dict, Any
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.optim as optim
|
7 |
-
from torchvision import datasets
|
|
|
8 |
from torch.utils.data import DataLoader, Subset
|
9 |
import numpy as np
|
10 |
import base64
|
@@ -12,6 +13,7 @@ from io import BytesIO
|
|
12 |
from PIL import Image
|
13 |
import random
|
14 |
import json
|
|
|
15 |
|
16 |
# FastAPIアプリケーションインスタンスを作成
|
17 |
app = FastAPI()
|
@@ -22,7 +24,7 @@ class PlayerModel(nn.Module):
|
|
22 |
super(PlayerModel, self).__init__()
|
23 |
self.layers = nn.ModuleList()
|
24 |
self.architecture_info = []
|
25 |
-
self.hookable_layers = {}
|
26 |
|
27 |
in_channels = 1
|
28 |
feature_map_size = 28
|
@@ -30,10 +32,8 @@ class PlayerModel(nn.Module):
|
|
30 |
|
31 |
for i, config in enumerate(layer_configs):
|
32 |
layer_type = config['type']
|
33 |
-
# ユニークな名前を生成
|
34 |
name = f"{layer_type.lower()}_{len([info for info in self.architecture_info if info['type'] == layer_type])}"
|
35 |
|
36 |
-
# 畳み込み/プーリング層
|
37 |
if layer_type in ['Conv2d', 'MaxPool2d', 'AvgPool2d']:
|
38 |
is_flattened = False
|
39 |
if layer_type == 'Conv2d':
|
@@ -44,27 +44,23 @@ class PlayerModel(nn.Module):
|
|
44 |
self.hookable_layers[name] = layer
|
45 |
in_channels = out_channels
|
46 |
self.architecture_info.append({"type": "Conv2d", "name": name, "shape": [out_channels, feature_map_size, feature_map_size]})
|
47 |
-
else:
|
48 |
kernel_size = config['params']['kernel_size']
|
49 |
if layer_type == 'MaxPool2d':
|
50 |
layer = nn.MaxPool2d(kernel_size=kernel_size, stride=kernel_size)
|
51 |
-
else:
|
52 |
layer = nn.AvgPool2d(kernel_size=kernel_size, stride=kernel_size)
|
53 |
self.layers.append(layer)
|
54 |
self.hookable_layers[name] = layer
|
55 |
feature_map_size //= kernel_size
|
56 |
self.architecture_info.append({"type": layer_type, "name": name, "shape": [in_channels, feature_map_size, feature_map_size]})
|
57 |
-
|
58 |
-
# 活性化/ドロップアウト層
|
59 |
elif layer_type in ['ReLU', 'Dropout']:
|
60 |
if layer_type == 'ReLU':
|
61 |
self.layers.append(nn.ReLU())
|
62 |
-
else:
|
63 |
p = config['params']['p']
|
64 |
self.layers.append(nn.Dropout(p=p))
|
65 |
self.architecture_info.append({"type": layer_type, "name": name})
|
66 |
-
|
67 |
-
# 構造変更層
|
68 |
elif layer_type == 'Flatten':
|
69 |
if not is_flattened:
|
70 |
layer = nn.Flatten()
|
@@ -74,8 +70,6 @@ class PlayerModel(nn.Module):
|
|
74 |
in_channels = flat_features
|
75 |
self.architecture_info.append({"type": "Flatten", "name": name, "shape": [flat_features]})
|
76 |
is_flattened = True
|
77 |
-
|
78 |
-
# 全結合/残差ブロック層
|
79 |
elif layer_type in ['Linear', 'ResidualBlock']:
|
80 |
if not is_flattened:
|
81 |
auto_flatten_name = f"auto_flatten_{i}"
|
@@ -84,28 +78,23 @@ class PlayerModel(nn.Module):
|
|
84 |
in_channels = flat_features
|
85 |
self.architecture_info.append({"type": "Flatten", "name": auto_flatten_name, "shape": [flat_features]})
|
86 |
is_flattened = True
|
87 |
-
|
88 |
if layer_type == 'Linear':
|
89 |
out_features = config['params']['out_features']
|
90 |
layer = nn.Linear(in_channels, out_features)
|
91 |
in_channels = out_features
|
92 |
-
else:
|
93 |
-
# ★★★ 残差ブロックは次元を維持する線形層として実装
|
94 |
features = in_channels
|
95 |
layer = nn.Linear(features, features)
|
96 |
-
|
97 |
self.layers.append(layer)
|
98 |
self.hookable_layers[name] = layer
|
99 |
self.architecture_info.append({"type": layer_type, "name": name, "shape": [in_channels]})
|
100 |
|
101 |
-
# 最終出力層を強制的に追加
|
102 |
if not self.layers or not isinstance(self.layers[-1], nn.Linear) or self.layers[-1].out_features != 10:
|
103 |
if not is_flattened:
|
104 |
self.layers.append(nn.Flatten())
|
105 |
final_in_features = in_channels * feature_map_size * feature_map_size
|
106 |
else:
|
107 |
final_in_features = in_channels
|
108 |
-
|
109 |
output_layer = nn.Linear(final_in_features, 10)
|
110 |
self.layers.append(output_layer)
|
111 |
self.hookable_layers["linear_output"] = output_layer
|
@@ -116,18 +105,45 @@ class PlayerModel(nn.Module):
|
|
116 |
x = layer(x)
|
117 |
return x
|
118 |
|
119 |
-
# ---
|
120 |
device = torch.device("cpu")
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
current_enemy = None
|
128 |
trained_player_model = None
|
129 |
|
130 |
-
# --- バックエンドロジック
|
131 |
def get_enemy():
|
132 |
global current_enemy
|
133 |
image, label = random.choice(test_images)
|
@@ -152,8 +168,9 @@ def train_player_model(layer_configs: list):
|
|
152 |
|
153 |
model.train()
|
154 |
for epoch in range(3): # 3エポック学習
|
155 |
-
|
156 |
-
|
|
|
157 |
optimizer.zero_grad()
|
158 |
output = model(data)
|
159 |
loss = loss_fn(output, target)
|
@@ -223,7 +240,7 @@ def run_inference():
|
|
223 |
"weights": weights
|
224 |
}
|
225 |
|
226 |
-
# --- FastAPI Endpoints ---
|
227 |
@app.get("/api/get_enemy")
|
228 |
async def get_enemy_endpoint():
|
229 |
return get_enemy()
|
@@ -236,6 +253,5 @@ async def train_player_model_endpoint(layer_configs: List[Dict[str, Any]] = Body
|
|
236 |
async def run_inference_endpoint():
|
237 |
return run_inference()
|
238 |
|
239 |
-
# --- 静的ファイルの配信 ---
|
240 |
-
# フロントエンドのファイル (index.html, style.css, script.js) を配信
|
241 |
app.mount("/", StaticFiles(directory="web", html=True), name="static")
|
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.optim as optim
|
7 |
+
# from torchvision import datasets -> datasetsライブラリを使うため削除
|
8 |
+
from torchvision import transforms
|
9 |
from torch.utils.data import DataLoader, Subset
|
10 |
import numpy as np
|
11 |
import base64
|
|
|
13 |
from PIL import Image
|
14 |
import random
|
15 |
import json
|
16 |
+
from datasets import load_dataset # ★★★ Hugging Face datasetsライブラリをインポート
|
17 |
|
18 |
# FastAPIアプリケーションインスタンスを作成
|
19 |
app = FastAPI()
|
|
|
24 |
super(PlayerModel, self).__init__()
|
25 |
self.layers = nn.ModuleList()
|
26 |
self.architecture_info = []
|
27 |
+
self.hookable_layers = {}
|
28 |
|
29 |
in_channels = 1
|
30 |
feature_map_size = 28
|
|
|
32 |
|
33 |
for i, config in enumerate(layer_configs):
|
34 |
layer_type = config['type']
|
|
|
35 |
name = f"{layer_type.lower()}_{len([info for info in self.architecture_info if info['type'] == layer_type])}"
|
36 |
|
|
|
37 |
if layer_type in ['Conv2d', 'MaxPool2d', 'AvgPool2d']:
|
38 |
is_flattened = False
|
39 |
if layer_type == 'Conv2d':
|
|
|
44 |
self.hookable_layers[name] = layer
|
45 |
in_channels = out_channels
|
46 |
self.architecture_info.append({"type": "Conv2d", "name": name, "shape": [out_channels, feature_map_size, feature_map_size]})
|
47 |
+
else:
|
48 |
kernel_size = config['params']['kernel_size']
|
49 |
if layer_type == 'MaxPool2d':
|
50 |
layer = nn.MaxPool2d(kernel_size=kernel_size, stride=kernel_size)
|
51 |
+
else:
|
52 |
layer = nn.AvgPool2d(kernel_size=kernel_size, stride=kernel_size)
|
53 |
self.layers.append(layer)
|
54 |
self.hookable_layers[name] = layer
|
55 |
feature_map_size //= kernel_size
|
56 |
self.architecture_info.append({"type": layer_type, "name": name, "shape": [in_channels, feature_map_size, feature_map_size]})
|
|
|
|
|
57 |
elif layer_type in ['ReLU', 'Dropout']:
|
58 |
if layer_type == 'ReLU':
|
59 |
self.layers.append(nn.ReLU())
|
60 |
+
else:
|
61 |
p = config['params']['p']
|
62 |
self.layers.append(nn.Dropout(p=p))
|
63 |
self.architecture_info.append({"type": layer_type, "name": name})
|
|
|
|
|
64 |
elif layer_type == 'Flatten':
|
65 |
if not is_flattened:
|
66 |
layer = nn.Flatten()
|
|
|
70 |
in_channels = flat_features
|
71 |
self.architecture_info.append({"type": "Flatten", "name": name, "shape": [flat_features]})
|
72 |
is_flattened = True
|
|
|
|
|
73 |
elif layer_type in ['Linear', 'ResidualBlock']:
|
74 |
if not is_flattened:
|
75 |
auto_flatten_name = f"auto_flatten_{i}"
|
|
|
78 |
in_channels = flat_features
|
79 |
self.architecture_info.append({"type": "Flatten", "name": auto_flatten_name, "shape": [flat_features]})
|
80 |
is_flattened = True
|
|
|
81 |
if layer_type == 'Linear':
|
82 |
out_features = config['params']['out_features']
|
83 |
layer = nn.Linear(in_channels, out_features)
|
84 |
in_channels = out_features
|
85 |
+
else:
|
|
|
86 |
features = in_channels
|
87 |
layer = nn.Linear(features, features)
|
|
|
88 |
self.layers.append(layer)
|
89 |
self.hookable_layers[name] = layer
|
90 |
self.architecture_info.append({"type": layer_type, "name": name, "shape": [in_channels]})
|
91 |
|
|
|
92 |
if not self.layers or not isinstance(self.layers[-1], nn.Linear) or self.layers[-1].out_features != 10:
|
93 |
if not is_flattened:
|
94 |
self.layers.append(nn.Flatten())
|
95 |
final_in_features = in_channels * feature_map_size * feature_map_size
|
96 |
else:
|
97 |
final_in_features = in_channels
|
|
|
98 |
output_layer = nn.Linear(final_in_features, 10)
|
99 |
self.layers.append(output_layer)
|
100 |
self.hookable_layers["linear_output"] = output_layer
|
|
|
105 |
x = layer(x)
|
106 |
return x
|
107 |
|
108 |
+
# --- グローバル変数とデータ準備 (★★★ ここを大幅に変更) ---
|
109 |
device = torch.device("cpu")
|
110 |
+
|
111 |
+
# 1. Hugging Face HubからMNISTデータセットをロード
|
112 |
+
mnist_dataset = load_dataset("mnist")
|
113 |
+
|
114 |
+
# 2. torchvisionのtransformを定義
|
115 |
+
transform = transforms.Compose([
|
116 |
+
transforms.ToTensor(),
|
117 |
+
# transforms.Normalize((0.1307,), (0.3081,)) # 必要に応じて正規化
|
118 |
+
])
|
119 |
+
|
120 |
+
# 3. データセットにtransformを適用する関数を定義
|
121 |
+
def apply_transforms(examples):
|
122 |
+
# PIL Imageのリストをテンソルのリストに変換
|
123 |
+
examples['image'] = [transform(image.convert("L")) for image in examples['image']]
|
124 |
+
return examples
|
125 |
+
|
126 |
+
# 4. データセットにtransformを適用
|
127 |
+
mnist_dataset.set_transform(apply_transforms)
|
128 |
+
|
129 |
+
# 5. DataLoaderを準備
|
130 |
+
train_subset = mnist_dataset['train'].select(range(1000))
|
131 |
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
|
132 |
+
|
133 |
+
# 6. テスト用の画像リストを作成 (DataLoaderと同じ (image, label) タプルの形式を維持)
|
134 |
+
test_images = []
|
135 |
+
# メモリ使用量を考慮し、テスト画像は1000個に絞る
|
136 |
+
test_subset_for_inference = mnist_dataset['test'].shuffle().select(range(1000))
|
137 |
+
for item in test_subset_for_inference:
|
138 |
+
# itemは {'image': tensor, 'label': int} という辞書
|
139 |
+
image_tensor = item['image'].unsqueeze(0) # バッチ次元 (1) を追加
|
140 |
+
label_tensor = torch.tensor(item['label'])
|
141 |
+
test_images.append((image_tensor, label_tensor))
|
142 |
+
|
143 |
current_enemy = None
|
144 |
trained_player_model = None
|
145 |
|
146 |
+
# --- バックエンドロジック ---
|
147 |
def get_enemy():
|
148 |
global current_enemy
|
149 |
image, label = random.choice(test_images)
|
|
|
168 |
|
169 |
model.train()
|
170 |
for epoch in range(3): # 3エポック学習
|
171 |
+
# ★★★ DataLoaderからのデータ受け取り方を変更
|
172 |
+
for batch_idx, batch in enumerate(train_loader):
|
173 |
+
data, target = batch['image'].to(device), batch['label'].to(device)
|
174 |
optimizer.zero_grad()
|
175 |
output = model(data)
|
176 |
loss = loss_fn(output, target)
|
|
|
240 |
"weights": weights
|
241 |
}
|
242 |
|
243 |
+
# --- FastAPI Endpoints (変更なし) ---
|
244 |
@app.get("/api/get_enemy")
|
245 |
async def get_enemy_endpoint():
|
246 |
return get_enemy()
|
|
|
253 |
async def run_inference_endpoint():
|
254 |
return run_inference()
|
255 |
|
256 |
+
# --- 静的ファイルの配信 (変更なし) ---
|
|
|
257 |
app.mount("/", StaticFiles(directory="web", html=True), name="static")
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ uvicorn[standard]
|
|
3 |
torch
|
4 |
torchvision
|
5 |
numpy
|
6 |
-
Pillow
|
|
|
|
3 |
torch
|
4 |
torchvision
|
5 |
numpy
|
6 |
+
Pillow
|
7 |
+
datasets
|