# vrp-shanghai-transformer / inference.py
# (Header recovered from the Hugging Face file page: author a-ragab-h-m,
#  commit aa1a460 "Update inference.py" — kept as a comment so the file parses.)
import torch
import json
import os
import csv
from datetime import datetime
import pandas as pd
import numpy as np
from nets.model import Model
from Actor.actor import Actor
# --- Path setup ---
safe_data_dir = "/home/user/data"
orders_file = os.path.join(safe_data_dir, "orders.csv")
params_path = os.path.join(safe_data_dir, 'params_saved.json')
model_path = os.path.join(safe_data_dir, "model_state_dict.pt")
txt_results_file = os.path.join(safe_data_dir, "inference_results.txt")
csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")

# --- Validate required files (fail fast, one loop instead of three copies) ---
# Labels chosen so the error messages match the original script exactly.
_required_files = {
    "Settings file": params_path,
    "Model": model_path,
    "orders.csv": orders_file,
}
for _label, _path in _required_files.items():
    if not os.path.exists(_path):
        raise FileNotFoundError(f"{_label} not found at {_path}")

# --- Load settings ---
with open(params_path, 'r') as f:
    params = json.load(f)

# Device from the saved config; fall back to CPU when CUDA was requested but
# is unavailable on this host, so torch.load(map_location=...) below still works.
device = params['device']
if str(device).startswith("cuda") and not torch.cuda.is_available():
    device = "cpu"
# --- Build the network with input_size = 4 ---
# (4 input features per order: pickup lng/lat + delivery lng/lat, see the
#  coordinate extraction below.)
model = Model(
    input_size=4,
    embedding_size=params["embedding_size"],
    decoder_input_size=params["decoder_input_size"],
)

# Restore the trained weights and switch to evaluation mode.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()
# --- Read the order data ---
df = pd.read_csv(orders_file)

# --- Extract coordinates as an (N, 4) array: pickup lng/lat, delivery lng/lat ---
coord_columns = ['lng', 'lat', 'delivery_gps_lng', 'delivery_gps_lat']
coords = df[coord_columns].to_numpy()

# Add a leading batch dimension -> (1, N, 4) and move to the target device.
coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
# --- Pack the batch as a (graph, fleet) tuple so Actor can unpack it ---
batch = (
    {"coords": coords_tensor},
    {"dummy": torch.tensor([0])},  # fleet placeholder; can be extended later when needed
)

# --- Trained policy actor ---
actor = Actor(
    model=model,
    num_movers=params['num_movers'],
    num_neighbors_encoder=params['num_neighbors_encoder'],
    num_neighbors_action=params['num_neighbors_action'],
    device=device,
    normalize=False
)

# --- Nearest-neighbour baseline actor (no learned model) ---
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
nn_actor.nearest_neighbors()
# --- Run inference (no gradients needed at eval time) ---
with torch.no_grad():
    actor.greedy_search()
    actor_output = actor(batch)
    total_time = actor_output['total_time'].item()
    nn_output = nn_actor(batch)
    nn_time = nn_output['total_time'].item()

# Relative improvement of the learned policy over the NN baseline, in percent.
# Guard against a zero-cost baseline to avoid ZeroDivisionError.
improvement = (nn_time - total_time) / nn_time * 100 if nn_time else 0.0
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# --- Build a short preview of the input orders (first 5 rows only) ---
preview_lines = []
for idx, order_row in enumerate(coords[:5]):
    p_lng, p_lat, d_lng, d_lat = order_row
    preview_lines.append(
        f"Order {idx}: P=({p_lng:.4f},{p_lat:.4f}) → D=({d_lng:.4f},{d_lat:.4f})"
    )
num_orders = coords.shape[0]
if num_orders > 5:
    preview_lines.append(f"... (showing 5 of {num_orders} orders)")
coords_preview = "\n".join(preview_lines)
input_summary = f"📌 Input Orders Preview:\n{coords_preview}"
# --- Detailed result block for the console/log ---
result_text = "\n".join([
    "",
    "===== INFERENCE RESULT =====",
    f"Time: {timestamp}",
    f"Actor Model Total Cost: {total_time:.4f} units",
    f"Nearest Neighbor Cost : {nn_time:.4f} units",
    f"Improvement over NN : {improvement:.2f}%",
    "",
])
print(result_text)

# --- Compact summary for the UI ---
summary_text = "\n".join([
    f"🕒 Time: {timestamp}",
    f"🚚 Actor Cost: {total_time:.4f} units",
    f"📍 NN Cost: {nn_time:.4f} units",
    f"📈 Improvement: {improvement:.2f}%",
    "",
    input_summary,
])
print(f"\n🔍 Summary for UI:\n{summary_text}")
# --- Append the result row to the CSV log ---
# newline='' is required by the csv module; an explicit utf-8 encoding keeps
# both log files portable instead of depending on the platform's locale default.
write_header = not os.path.exists(csv_results_file)
with open(csv_results_file, 'a', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    if write_header:
        writer.writerow(["Timestamp", "Actor Cost", "NN Cost", "Improvement (%)"])
    writer.writerow([timestamp, f"{total_time:.4f}", f"{nn_time:.4f}", f"{improvement:.2f}"])

# --- Append the human-readable result to the text log ---
with open(txt_results_file, 'a', encoding='utf-8') as f:
    f.write(result_text)
    f.write("\n=============================\n")