Spaces:
Sleeping
Sleeping
Update prediction.py
Browse files- prediction.py +48 -12
prediction.py
CHANGED
@@ -13,6 +13,7 @@ from config import STATION_NAMES
|
|
13 |
from supabase_utils import (
|
14 |
get_harmonic_predictions, save_predictions_to_supabase, get_supabase_client
|
15 |
)
|
|
|
16 |
|
17 |
def get_common_args(station_id):
|
18 |
return [
|
@@ -23,19 +24,30 @@ def get_common_args(station_id):
|
|
23 |
]
|
24 |
|
25 |
def validate_csv_file(file_path, required_rows=144):
|
26 |
-
"""CSV 파일 유효성 검사"""
|
27 |
try:
|
28 |
df = pd.read_csv(file_path)
|
29 |
-
required_columns = ['date', 'air_pres', 'wind_dir', 'wind_speed', 'air_temp', 'residual']
|
30 |
-
missing_columns = [col for col in required_columns if col not in df.columns]
|
31 |
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
if len(df) < required_rows:
|
36 |
return False, f"데이터가 부족합니다. 최소 {required_rows}행 필요, 현재 {len(df)}행"
|
37 |
-
|
38 |
-
|
|
|
|
|
39 |
except Exception as e:
|
40 |
return False, f"파일 읽기 오류: {str(e)}"
|
41 |
|
@@ -231,12 +243,27 @@ def single_prediction(station_id, input_csv_file):
|
|
231 |
if input_csv_file is None:
|
232 |
raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
|
233 |
|
|
|
234 |
is_valid, message = validate_csv_file(input_csv_file.name)
|
235 |
if not is_valid:
|
236 |
raise gr.Error(f"파일 오류: {message}")
|
237 |
|
238 |
station_name = STATION_NAMES.get(station_id, station_id)
|
239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
common_args = get_common_args(station_id)
|
241 |
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
|
242 |
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
|
@@ -247,10 +274,11 @@ def single_prediction(station_id, input_csv_file):
|
|
247 |
if not os.path.exists(scaler_path):
|
248 |
raise gr.Error(f"스케일러 파일을 찾을 수 없습니다: {scaler_path}")
|
249 |
|
|
|
250 |
command = ["python", "inference.py",
|
251 |
"--checkpoint_path", checkpoint_path,
|
252 |
"--scaler_path", scaler_path,
|
253 |
-
"--predict_input_file",
|
254 |
|
255 |
gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...")
|
256 |
|
@@ -261,12 +289,13 @@ def single_prediction(station_id, input_csv_file):
|
|
261 |
if os.path.exists(prediction_file):
|
262 |
residual_predictions = np.load(prediction_file)
|
263 |
|
264 |
-
|
265 |
-
input_df
|
266 |
last_time = input_df['date'].iloc[-1]
|
267 |
|
268 |
prediction_results = calculate_final_tide(residual_predictions, station_id, last_time)
|
269 |
-
|
|
|
270 |
|
271 |
has_harmonic = any(h != 0 for h in prediction_results['harmonic'])
|
272 |
|
@@ -291,7 +320,14 @@ def single_prediction(station_id, input_csv_file):
|
|
291 |
else:
|
292 |
save_message = "\n⚠️ Supabase 저장 실패"
|
293 |
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
else:
|
296 |
return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
|
297 |
except Exception as e:
|
|
|
13 |
from supabase_utils import (
|
14 |
get_harmonic_predictions, save_predictions_to_supabase, get_supabase_client
|
15 |
)
|
16 |
+
from preprocessing import preprocess_uploaded_file
|
17 |
|
18 |
def get_common_args(station_id):
|
19 |
return [
|
|
|
24 |
]
|
25 |
|
26 |
def validate_csv_file(file_path, required_rows=144):
|
27 |
+
"""CSV 파일 유효성 검사 - tide_level 또는 residual 지원"""
|
28 |
try:
|
29 |
df = pd.read_csv(file_path)
|
|
|
|
|
30 |
|
31 |
+
# 기본 필수 컬럼 (tide_level 또는 residual 중 하나는 있어야 함)
|
32 |
+
base_columns = ['date', 'air_pres', 'wind_dir', 'wind_speed', 'air_temp']
|
33 |
+
missing_base = [col for col in base_columns if col not in df.columns]
|
34 |
+
|
35 |
+
if missing_base:
|
36 |
+
return False, f"필수 컬럼이 누락되었습니다: {missing_base}"
|
37 |
+
|
38 |
+
# tide_level 또는 residual 중 하나는 있어야 함
|
39 |
+
has_tide_level = 'tide_level' in df.columns
|
40 |
+
has_residual = 'residual' in df.columns
|
41 |
+
|
42 |
+
if not has_tide_level and not has_residual:
|
43 |
+
return False, "tide_level 또는 residual 컬럼이 필요합니다."
|
44 |
|
45 |
if len(df) < required_rows:
|
46 |
return False, f"데이터가 부족합니다. 최소 {required_rows}행 필요, 현재 {len(df)}행"
|
47 |
+
|
48 |
+
data_type = "tide_level" if has_tide_level else "residual"
|
49 |
+
return True, f"파일이 유효합니다. (데이터 형태: {data_type})"
|
50 |
+
|
51 |
except Exception as e:
|
52 |
return False, f"파일 읽기 오류: {str(e)}"
|
53 |
|
|
|
243 |
if input_csv_file is None:
|
244 |
raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
|
245 |
|
246 |
+
# 1. 초기 파일 검증
|
247 |
is_valid, message = validate_csv_file(input_csv_file.name)
|
248 |
if not is_valid:
|
249 |
raise gr.Error(f"파일 오류: {message}")
|
250 |
|
251 |
station_name = STATION_NAMES.get(station_id, station_id)
|
252 |
|
253 |
+
# 2. 전처리 수행 (tide_level → residual 변환 포함)
|
254 |
+
gr.Info(f"📊 {station_name}({station_id}) 데이터 전처리 중...")
|
255 |
+
processed_data, preprocess_result = preprocess_uploaded_file(input_csv_file.name, station_id)
|
256 |
+
|
257 |
+
if processed_data is None:
|
258 |
+
raise gr.Error(f"전처리 실패: {preprocess_result}")
|
259 |
+
|
260 |
+
# 전처리 결과가 문자열(에러)인지 딕셔너리(성공)인지 확인
|
261 |
+
if isinstance(preprocess_result, str):
|
262 |
+
raise gr.Error(f"전처리 오류: {preprocess_result}")
|
263 |
+
|
264 |
+
# 전처리된 파일 경로 사용
|
265 |
+
processed_file_path = preprocess_result['output_file']
|
266 |
+
|
267 |
common_args = get_common_args(station_id)
|
268 |
setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
|
269 |
checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
|
|
|
274 |
if not os.path.exists(scaler_path):
|
275 |
raise gr.Error(f"스케일러 파일을 찾을 수 없습니다: {scaler_path}")
|
276 |
|
277 |
+
# 전처리된 파일을 inference에 전달
|
278 |
command = ["python", "inference.py",
|
279 |
"--checkpoint_path", checkpoint_path,
|
280 |
"--scaler_path", scaler_path,
|
281 |
+
"--predict_input_file", processed_file_path] + common_args
|
282 |
|
283 |
gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...")
|
284 |
|
|
|
289 |
if os.path.exists(prediction_file):
|
290 |
residual_predictions = np.load(prediction_file)
|
291 |
|
292 |
+
# 전처리된 데이터 사용
|
293 |
+
input_df = processed_data
|
294 |
last_time = input_df['date'].iloc[-1]
|
295 |
|
296 |
prediction_results = calculate_final_tide(residual_predictions, station_id, last_time)
|
297 |
+
# 플롯은 전처리된 데이터 파일을 사용
|
298 |
+
plot = create_enhanced_prediction_plot(prediction_results, type('obj', (object,), {'name': processed_file_path}), station_name)
|
299 |
|
300 |
has_harmonic = any(h != 0 for h in prediction_results['harmonic'])
|
301 |
|
|
|
320 |
else:
|
321 |
save_message = "\n⚠️ Supabase 저장 실패"
|
322 |
|
323 |
+
# 전처리 정보 추가
|
324 |
+
preprocess_info = f"""📊 전처리 결과:
|
325 |
+
- 원본 데이터: {preprocess_result['original_rows']}행
|
326 |
+
- 처리 데이터: {preprocess_result['processed_rows']}행
|
327 |
+
- Residual 평균: {preprocess_result['residual_mean']:.2f}cm
|
328 |
+
- Residual 표준편차: {preprocess_result['residual_std']:.2f}cm"""
|
329 |
+
|
330 |
+
return plot, result_df, f"✅ 예측 완료!{save_message}\n\n{preprocess_info}\n\n{output}"
|
331 |
else:
|
332 |
return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
|
333 |
except Exception as e:
|