alwaysgood commited on
Commit
7d0339b
·
verified ·
1 Parent(s): 12db48a

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +48 -12
prediction.py CHANGED
@@ -13,6 +13,7 @@ from config import STATION_NAMES
13
  from supabase_utils import (
14
  get_harmonic_predictions, save_predictions_to_supabase, get_supabase_client
15
  )
 
16
 
17
  def get_common_args(station_id):
18
  return [
@@ -23,19 +24,30 @@ def get_common_args(station_id):
23
  ]
24
 
25
  def validate_csv_file(file_path, required_rows=144):
26
- """CSV 파일 유효성 검사"""
27
  try:
28
  df = pd.read_csv(file_path)
29
- required_columns = ['date', 'air_pres', 'wind_dir', 'wind_speed', 'air_temp', 'residual']
30
- missing_columns = [col for col in required_columns if col not in df.columns]
31
 
32
- if missing_columns:
33
- return False, f"필수 컬럼이 누락되었습니다: {missing_columns}"
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  if len(df) < required_rows:
36
  return False, f"데이터가 부족합니다. 최소 {required_rows}행 필요, 현재 {len(df)}행"
37
-
38
- return True, "파일이 유효합니다."
 
 
39
  except Exception as e:
40
  return False, f"파일 읽기 오류: {str(e)}"
41
 
@@ -231,12 +243,27 @@ def single_prediction(station_id, input_csv_file):
231
  if input_csv_file is None:
232
  raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
233
 
 
234
  is_valid, message = validate_csv_file(input_csv_file.name)
235
  if not is_valid:
236
  raise gr.Error(f"파일 오류: {message}")
237
 
238
  station_name = STATION_NAMES.get(station_id, station_id)
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  common_args = get_common_args(station_id)
241
  setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
242
  checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
@@ -247,10 +274,11 @@ def single_prediction(station_id, input_csv_file):
247
  if not os.path.exists(scaler_path):
248
  raise gr.Error(f"스케일러 파일을 찾을 수 없습니다: {scaler_path}")
249
 
 
250
  command = ["python", "inference.py",
251
  "--checkpoint_path", checkpoint_path,
252
  "--scaler_path", scaler_path,
253
- "--predict_input_file", input_csv_file.name] + common_args
254
 
255
  gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...")
256
 
@@ -261,12 +289,13 @@ def single_prediction(station_id, input_csv_file):
261
  if os.path.exists(prediction_file):
262
  residual_predictions = np.load(prediction_file)
263
 
264
- input_df = pd.read_csv(input_csv_file.name)
265
- input_df['date'] = pd.to_datetime(input_df['date'])
266
  last_time = input_df['date'].iloc[-1]
267
 
268
  prediction_results = calculate_final_tide(residual_predictions, station_id, last_time)
269
- plot = create_enhanced_prediction_plot(prediction_results, input_csv_file, station_name)
 
270
 
271
  has_harmonic = any(h != 0 for h in prediction_results['harmonic'])
272
 
@@ -291,7 +320,14 @@ def single_prediction(station_id, input_csv_file):
291
  else:
292
  save_message = "\n⚠️ Supabase 저장 실패"
293
 
294
- return plot, result_df, f"✅ 예측 완료!{save_message}\n\n{output}"
 
 
 
 
 
 
 
295
  else:
296
  return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
297
  except Exception as e:
 
13
  from supabase_utils import (
14
  get_harmonic_predictions, save_predictions_to_supabase, get_supabase_client
15
  )
16
+ from preprocessing import preprocess_uploaded_file
17
 
18
  def get_common_args(station_id):
19
  return [
 
24
  ]
25
 
26
  def validate_csv_file(file_path, required_rows=144):
27
+ """CSV 파일 유효성 검사 - tide_level 또는 residual 지원"""
28
  try:
29
  df = pd.read_csv(file_path)
 
 
30
 
31
+ # 기본 필수 컬럼 (tide_level 또는 residual 중 하나는 있어야 함)
32
+ base_columns = ['date', 'air_pres', 'wind_dir', 'wind_speed', 'air_temp']
33
+ missing_base = [col for col in base_columns if col not in df.columns]
34
+
35
+ if missing_base:
36
+ return False, f"필수 컬럼이 누락되었습니다: {missing_base}"
37
+
38
+ # tide_level 또는 residual 중 하나는 있어야 함
39
+ has_tide_level = 'tide_level' in df.columns
40
+ has_residual = 'residual' in df.columns
41
+
42
+ if not has_tide_level and not has_residual:
43
+ return False, "tide_level 또는 residual 컬럼이 필요합니다."
44
 
45
  if len(df) < required_rows:
46
  return False, f"데이터가 부족합니다. 최소 {required_rows}행 필요, 현재 {len(df)}행"
47
+
48
+ data_type = "tide_level" if has_tide_level else "residual"
49
+ return True, f"파일이 유효합니다. (데이터 형태: {data_type})"
50
+
51
  except Exception as e:
52
  return False, f"파일 읽기 오류: {str(e)}"
53
 
 
243
  if input_csv_file is None:
244
  raise gr.Error("예측을 위한 입력 파일을 업로드해주세요.")
245
 
246
+ # 1. 초기 파일 검증
247
  is_valid, message = validate_csv_file(input_csv_file.name)
248
  if not is_valid:
249
  raise gr.Error(f"파일 오류: {message}")
250
 
251
  station_name = STATION_NAMES.get(station_id, station_id)
252
 
253
+ # 2. 전처리 수행 (tide_level → residual 변환 포함)
254
+ gr.Info(f"📊 {station_name}({station_id}) 데이터 전처리 중...")
255
+ processed_data, preprocess_result = preprocess_uploaded_file(input_csv_file.name, station_id)
256
+
257
+ if processed_data is None:
258
+ raise gr.Error(f"전처리 실패: {preprocess_result}")
259
+
260
+ # 전처리 결과가 문자열(에러)인지 딕셔너리(성공)인지 확인
261
+ if isinstance(preprocess_result, str):
262
+ raise gr.Error(f"전처리 오류: {preprocess_result}")
263
+
264
+ # 전처리된 파일 경로 사용
265
+ processed_file_path = preprocess_result['output_file']
266
+
267
  common_args = get_common_args(station_id)
268
  setting_name = f"long_term_forecast_{station_id}_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0"
269
  checkpoint_path = f"./checkpoints/{setting_name}/checkpoint.pth"
 
274
  if not os.path.exists(scaler_path):
275
  raise gr.Error(f"스케일러 파일을 찾을 수 없습니다: {scaler_path}")
276
 
277
+ # 전처리된 파일을 inference에 전달
278
  command = ["python", "inference.py",
279
  "--checkpoint_path", checkpoint_path,
280
  "--scaler_path", scaler_path,
281
+ "--predict_input_file", processed_file_path] + common_args
282
 
283
  gr.Info(f"{station_name}({station_id}) 통합 조위 예측을 실행중입니다...")
284
 
 
289
  if os.path.exists(prediction_file):
290
  residual_predictions = np.load(prediction_file)
291
 
292
+ # 전처리된 데이터 사용
293
+ input_df = processed_data
294
  last_time = input_df['date'].iloc[-1]
295
 
296
  prediction_results = calculate_final_tide(residual_predictions, station_id, last_time)
297
+ # 플롯은 전처리된 데이터 파일을 사용
298
+ plot = create_enhanced_prediction_plot(prediction_results, type('obj', (object,), {'name': processed_file_path}), station_name)
299
 
300
  has_harmonic = any(h != 0 for h in prediction_results['harmonic'])
301
 
 
320
  else:
321
  save_message = "\n⚠️ Supabase 저장 실패"
322
 
323
+ # 전처리 정보 추가
324
+ preprocess_info = f"""📊 전처리 결과:
325
+ - 원본 데이터: {preprocess_result['original_rows']}행
326
+ - 처리 데이터: {preprocess_result['processed_rows']}행
327
+ - Residual 평균: {preprocess_result['residual_mean']:.2f}cm
328
+ - Residual 표준편차: {preprocess_result['residual_std']:.2f}cm"""
329
+
330
+ return plot, result_df, f"✅ 예측 완료!{save_message}\n\n{preprocess_info}\n\n{output}"
331
  else:
332
  return None, None, f"❌ 결과 파일을 찾을 수 없습니다.\n\n{output}"
333
  except Exception as e: