use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; use std::process::Command; use tokio::signal; use tokio::time; use futures_util::sink::SinkExt; use futures_util::stream::StreamExt; /// Updater 自身配置(AppData/Updater/config.json) /// 只负责 Updater 自己的行为参数,连接地址从公共 config.json 加载 #[derive(Debug, Serialize, Deserialize)] struct Config { /// 调试模式:true 时保留控制台窗口并输出日志 debug_mode: bool, } impl Default for Config { fn default() -> Self { Self { debug_mode: false } } } /// 获取 Updater 自身配置路径 AppData/Updater/config.json fn get_updater_config_path() -> PathBuf { let exe_path = std::env::current_exe().expect("Failed to get executable path"); let drive = exe_path .parent() .and_then(|p| p.as_os_str().to_str()) .and_then(|s| s.split('\\').next()) .unwrap_or("C:"); let appdata = PathBuf::from(format!("{}/AppData", drive)); let updater_dir = appdata.join("Updater"); let _ = fs::create_dir_all(&updater_dir); updater_dir.join("config.json") } /// 获取公共配置路径 AppData/config.json(与 BootLoader 同级) fn get_public_config_path() -> PathBuf { let exe_path = std::env::current_exe().expect("Failed to get executable path"); let drive = exe_path .parent() .and_then(|p| p.as_os_str().to_str()) .and_then(|s| s.split('\\').next()) .unwrap_or("C:"); PathBuf::from(format!("{}/AppData/config.json", drive)) } /// 加载 Updater 自身配置;若文件不存在则写入默认值 fn load_updater_config() -> Config { let config_path = get_updater_config_path(); if config_path.exists() { if let Ok(content) = fs::read_to_string(&config_path) { if let Ok(config) = serde_json::from_str::(&content) { return config; } } } // 文件不存在或解析失败 → 写入默认值 let default_config = Config::default(); if let Ok(content) = serde_json::to_string_pretty(&default_config) { let _ = fs::write(&config_path, content); } default_config } /// 从公共 config.json 读取 ServerUrl 字段 fn resolve_ws_url() -> String { let config_path = get_public_config_path(); if let Ok(content) = fs::read_to_string(&config_path) { if let Ok(json) = serde_json::from_str::(&content) { if let Some(url) = json.get("ServerUrl").and_then(|v| v.as_str()) { return url.to_string(); } } } // 读取失败 → 降级到默认值 "ws://127.0.0.1:8087/ws".to_string() } fn is_process_running(process_name: &str) -> bool { use std::process::id; let current_pid = id().to_string(); let output = Command::new("tasklist") .args(["/FI", &format!("IMAGENAME eq {}", process_name), "/FO", "CSV"]) .output() .expect("Failed to execute tasklist"); let output_str = String::from_utf8_lossy(&output.stdout); let lines: Vec<&str> = output_str.lines().collect(); let mut count = 0; for line in lines { if line.contains(&format!("\"{}\"", process_name)) && !line.contains(¤t_pid) { count += 1; } } count > 0 } async fn upgrade(server_url: &str, debug_mode: bool) { if debug_mode { println!("开始升级检查,连接服务端..."); } connect_to_websocket(server_url, debug_mode).await; } async fn connect_to_websocket(server_url: &str, debug_mode: bool) { use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_tungstenite::{connect_async, tungstenite::Message}; if debug_mode { println!("Connecting to WebSocket server: {}", server_url); } if let Ok(request) = server_url.into_client_request() { match connect_async(request).await { Ok((mut ws_stream, _)) => { if debug_mode { println!("Connected to WebSocket server"); } // 注册身份 let app_info = serde_json::json!({ "type": "Updater", "action": "register" }); if let Err(e) = ws_stream.send(Message::Text(app_info.to_string())).await { eprintln!("Failed to send register info: {:?}", e); } else { match ws_stream.next().await { Some(Ok(message)) => { if debug_mode { println!("Server response: {:?}", message); } } Some(Err(e)) => { eprintln!("WebSocket error: {:?}", e); } None => { eprintln!("WebSocket connection closed by server"); } } } } Err(e) => { eprintln!("Failed to connect to WebSocket server: {:?}", e); } } } else { eprintln!("Invalid WebSocket URL: {}", server_url); } } #[tokio::main] async fn main() { // 检查是否已有 Updater 进程在运行 if is_process_running("Updater.exe") { return; } // 加载 Updater 自身配置(debug_mode) let config = load_updater_config(); // 从公共 config.json 解析 WebSocket 连接地址 let server_url = resolve_ws_url(); // 非 debug 模式下释放控制台,后台静默运行 if !config.debug_mode { #[cfg(windows)] { use windows::Win32::System::Console; use windows::Win32::Foundation::HWND; unsafe { let console = Console::GetConsoleWindow(); if console != HWND::default() { let _ = Console::FreeConsole(); } } } } if config.debug_mode { println!("Updater started in debug mode"); println!("Server URL: {}", server_url); let mut interval = time::interval(time::Duration::from_secs(300)); loop { tokio::select! { _ = interval.tick() => { upgrade(&server_url, config.debug_mode).await; } _ = signal::ctrl_c() => { println!("Received Ctrl+C, exiting..."); break; } } } } else { let mut interval = time::interval(time::Duration::from_secs(300)); loop { interval.tick().await; upgrade(&server_url, config.debug_mode).await; } } }