//! WebSocket client implementation

use crate::websocket::{WebSocketConfig, WebSocketMessage, WebSocketError};
use futures_util::{SinkExt, StreamExt};
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::time::Duration;
use tracing::{warn, debug};
use serde_json::Value;

/// Message to send through the channel
#[derive(Debug)]
pub enum OutgoingMessage {
    /// Text message
    Text(String),
    /// Binary message (raw bytes)
    Binary(Vec<u8>),
    /// Close connection
    Close,
}

/// Connection status
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Disconnected
    Disconnected,
    /// Connecting
    Connecting,
    /// Connected
    Connected,
    /// Reconnecting
    Reconnecting,
    /// Error
    Error,
}

impl std::fmt::Display for ConnectionStatus {
    /// Writes the lower-case name of the status (e.g. `"connected"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ConnectionStatus::Disconnected => "disconnected",
            ConnectionStatus::Connecting => "connecting",
            ConnectionStatus::Connected => "connected",
            ConnectionStatus::Reconnecting => "reconnecting",
            ConnectionStatus::Error => "error",
        };
        f.write_str(label)
    }
}

// Event callback types (wrapped in Arc for cloneability)

/// Callback for the connected event - args: (url)
pub type ConnectedCallback = Arc<dyn Fn(String) + Send + Sync>;

/// Callback for the disconnected event
pub type DisconnectedCallback = Arc<dyn Fn() + Send + Sync>;

/// Handle passed to the text-message callback so it can reply synchronously.
///
/// It wraps the same shared outgoing-channel slot the client uses.
#[derive(Clone)]
pub struct MessageSender(Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>);

impl MessageSender {
    /// Send a text message synchronously (from within the sync callback)
    ///
    /// This is best effort: the message is silently dropped when the sender
    /// slot is currently locked, when no channel is installed, or when the
    /// channel is full or closed, because a sync callback cannot wait.
    pub fn send(&self, text: String) {
        if let Ok(slot) = self.0.try_lock() {
            if let Some(tx) = slot.as_ref() {
                let _ = tx.try_send(OutgoingMessage::Text(text));
            }
        }
    }
}

/// Callback for text messages - args: (msg_type, parsed_data, sender)
/// sender can be used to send messages back synchronously
pub type MessageCallback = Arc<dyn Fn(String, Value, MessageSender) + Send + Sync>;

/// Callback for binary messages - args: (bytes)
pub type BinaryCallback = Arc<dyn Fn(Vec<u8>) + Send + Sync>;

/// Callback for errors - args: (error text)
pub type ErrorCallback = Arc<dyn Fn(String) + Send + Sync>;

/// Callback for status changes - args: (old_status, new_status)
pub type StatusCallback = Arc<dyn Fn(ConnectionStatus, ConnectionStatus) + Send + Sync>;

/// Callback for sent messages - args: (msg_type, data)
pub type SentCallback = Arc<dyn Fn(String, Value) + Send + Sync>;

/// Callback triggered before reconnecting
/// Arguments: (attempt_number, url_arc) - app can
update the URL via url_arc /// Note: This is an async callback - return a boxed Future pub type ReconnectingCallback = Arc>) -> Pin + Send + Sync + 'static>> + Send + Sync>; /// Callback triggered after successful reconnection (after the first connection) /// Arguments: (url, send_fn) - send_fn can be called to send messages /// Note: This is an async callback - return a boxed Future pub type ReconnectedCallback = Arc>>>) -> Pin + Send + Sync + 'static>> + Send + Sync>; /// Callback triggered on first successful connection (before any reconnect) /// Arguments: (url, send_fn) - send_fn can be called to send messages /// Note: This is an async callback - return a boxed Future pub type FirstConnectCallback = Arc>>>) -> Pin + Send + Sync + 'static>> + Send + Sync>; /// WebSocket client with event-driven architecture pub struct WebSocketClient { config: WebSocketConfig, /// Dynamic URL (can be updated for reconnecting with new URL) url: Arc>, status: Arc>, sender: Arc>>>, message_queue: Arc>>, task_handle: Arc>>>, is_running: Arc>, // Event callbacks on_connected: Option, on_disconnected: Option, on_message: Option, on_binary: Option, on_error: Option, on_status_changed: Option, on_message_sent: Option, /// Callback triggered before reconnecting (attempt number passed as argument) on_reconnecting: Option, /// Callback triggered on first successful connection on_first_connect: Option, /// Callback triggered after successful reconnection (after the first connection) on_reconnected: Option, // Reconnection state reconnect_attempts: Arc>, reconnect_delay_ms: Arc>, } impl WebSocketClient { /// Create a new WebSocket client pub fn new(config: WebSocketConfig) -> Self { Self { config: config.clone(), url: Arc::new(Mutex::new(config.ws_url.clone())), status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)), sender: Arc::new(Mutex::new(None)), message_queue: Arc::new(Mutex::new(Vec::new())), task_handle: Arc::new(Mutex::new(None)), is_running: Arc::new(Mutex::new(false)), 
on_connected: None, on_disconnected: None, on_message: None, on_binary: None, on_error: None, on_status_changed: None, on_message_sent: None, on_reconnecting: None, on_first_connect: None, on_reconnected: None, reconnect_attempts: Arc::new(Mutex::new(0)), reconnect_delay_ms: Arc::new(Mutex::new(config.reconnect_delay_ms)), } } /// Create a simple client with just the URL pub fn simple(url: impl Into) -> Self { Self::new(WebSocketConfig::new(url)) } // ==================== Event Handlers ==================== /// Set callback for connected event pub fn on_connected(&mut self, callback: F) -> &mut Self where F: Fn(String) + Send + Sync + 'static, { self.on_connected = Some(Arc::new(callback)); self } /// Set callback for disconnected event pub fn on_disconnected(&mut self, callback: F) -> &mut Self where F: Fn() + Send + Sync + 'static, { self.on_disconnected = Some(Arc::new(callback)); self } /// Set callback for text message received /// The callback receives: (msg_type, parsed_data, sender) /// The sender can be used to send messages back synchronously pub fn on_message(&mut self, callback: F) -> &mut Self where F: Fn(String, Value, MessageSender) + Send + Sync + 'static, { self.on_message = Some(Arc::new(callback)); self } /// Set callback for binary message received pub fn on_binary(&mut self, callback: F) -> &mut Self where F: Fn(Vec) + Send + Sync + 'static, { self.on_binary = Some(Arc::new(callback)); self } /// Set callback for error pub fn on_error(&mut self, callback: F) -> &mut Self where F: Fn(String) + Send + Sync + 'static, { self.on_error = Some(Arc::new(callback)); self } /// Set callback for status changed pub fn on_status_changed(&mut self, callback: F) -> &mut Self where F: Fn(ConnectionStatus, ConnectionStatus) + Send + Sync + 'static, { self.on_status_changed = Some(Arc::new(callback)); self } /// Set callback for message sent pub fn on_message_sent(&mut self, callback: F) -> &mut Self where F: Fn(String, Value) + Send + Sync + 'static, { 
self.on_message_sent = Some(Arc::new(callback)); self } /// Set callback for reconnecting (called before each reconnect attempt) /// This is an async callback - the returned Future will be awaited /// The callback receives: (attempt_number, url_arc) /// Use url_arc to update the URL for the next connection attempt pub fn on_reconnecting(&mut self, callback: F) -> &mut Self where F: Fn(u32, Arc>) -> Pin + Send + Sync + 'static>> + Send + Sync + 'static, { self.on_reconnecting = Some(Arc::new(callback)); self } /// Set callback for first successful connection (before any reconnect) /// This is an async callback - the returned Future will be awaited pub fn on_first_connect(&mut self, callback: F) -> &mut Self where F: Fn(String, Arc>>>) -> Pin + Send + Sync + 'static>> + Send + Sync + 'static, { self.on_first_connect = Some(Arc::new(callback)); self } /// Set callback for reconnection (called after successful reconnection, not on first connection) /// This is an async callback - the returned Future will be awaited pub fn on_reconnected(&mut self, callback: F) -> &mut Self where F: Fn(String, Arc>>>) -> Pin + Send + Sync + 'static>> + Send + Sync + 'static, { self.on_reconnected = Some(Arc::new(callback)); self } /// Update the URL for future reconnection attempts /// Call this method when the server URL changes pub fn update_url(&mut self, new_url: String) { if self.config.debug_mode { debug!("Updating URL: {} -> {}", self.config.ws_url, new_url); } self.config.ws_url = new_url.clone(); *self.url.blocking_lock() = new_url; } /// Send a text message through the WebSocket (can be called from any context) /// Returns Ok(()) if the message was sent, Err(msg) if not connected pub async fn send_text(&self, text: String) -> Result<(), String> { let sender = self.sender.lock().await; if let Some(ref tx) = *sender { tx.send(OutgoingMessage::Text(text)).await.map_err(|e| e.to_string()) } else { Err("Not connected".to_string()) } } // ==================== Connection Management 
==================== /// Connect to the WebSocket server (uses dynamic URL that can be updated) /// This method blocks until the connection ends completely (or is stopped via disconnect()) pub async fn connect(&mut self) { let url = self.url.lock().await.clone(); self.connect_with_url(&url).await; } /// Connect to a specific URL (overrides config URL) /// Spawns the WebSocket task and waits for it to complete pub async fn connect_with_url(&mut self, url: &str) { let _old_status = *self.status.lock().await; self.set_status(ConnectionStatus::Connecting).await; if self.config.debug_mode { debug!("Connecting to WebSocket server: {}", url); } // Create channel for outgoing messages let (tx, rx) = mpsc::channel::(100); *self.sender.lock().await = Some(tx); // Clone shared state for the task let url = url.to_string(); let config = self.config.clone(); let status = Arc::clone(&self.status); let queue = Arc::clone(&self.message_queue); let is_running = Arc::clone(&self.is_running); let reconnect_attempts = Arc::clone(&self.reconnect_attempts); let reconnect_delay_ms = Arc::clone(&self.reconnect_delay_ms); let client_url = Arc::clone(&self.url); let sender = Arc::clone(&self.sender); // Callbacks let on_connected = self.on_connected.clone(); let on_disconnected = self.on_disconnected.clone(); let on_message = self.on_message.clone(); let on_binary = self.on_binary.clone(); let on_error = self.on_error.clone(); let on_reconnecting = self.on_reconnecting.clone(); let on_first_connect = self.on_first_connect.clone(); let on_reconnected = self.on_reconnected.clone(); // Spawn the WebSocket task *self.is_running.lock().await = true; let task_handle = tokio::spawn(async move { Self::websocket_loop( url, config, rx, status, queue, is_running, reconnect_attempts, reconnect_delay_ms, client_url, on_connected, on_disconnected, on_message, on_binary, on_error, on_reconnecting, on_first_connect, on_reconnected, sender, ) .await; }); // Store handle and wait for task to complete 
*self.task_handle.lock().await = Some(task_handle); // Wait for the task to complete (blocks until websocket_loop ends) if let Some(handle) = self.task_handle.lock().await.take() { let _ = handle.await; } } /// Main WebSocket loop async fn websocket_loop( _url: String, // Initial URL (superseded by client_url for reconnecting) config: WebSocketConfig, mut receiver: mpsc::Receiver, status: Arc>, queue: Arc>>, is_running: Arc>, reconnect_attempts: Arc>, reconnect_delay_ms: Arc>, client_url: Arc>, on_connected: Option, on_disconnected: Option, on_message: Option, on_binary: Option, on_error: Option, on_reconnecting: Option, on_first_connect: Option, on_reconnected: Option, sender: Arc>>>, ) { loop { let should_run = *is_running.lock().await; if !should_run { break; } // Get current URL (may have been updated) let current_url = client_url.lock().await.clone(); // Check if this is the first connection (not a reconnect) let is_first = *reconnect_attempts.lock().await == 0; // Clone sender for first_connect callback let sender_clone = Arc::clone(&sender); // Create MessageSender for on_message callback (same underlying sender) let message_sender = MessageSender(Arc::clone(&sender)); match Self::connect_and_handle(¤t_url, &config, &mut receiver, &queue, &on_connected, &on_message, &on_binary, &on_error, is_first, &on_first_connect, &on_reconnected, sender_clone, message_sender).await { Ok(_) => { if config.debug_mode { debug!("WebSocket connection closed normally"); } // 即使正常关闭,也尝试重连(如果启用) if config.reconnect && *is_running.lock().await { let delay = *reconnect_delay_ms.lock().await; let attempts = *reconnect_attempts.lock().await + 1; *reconnect_attempts.lock().await = attempts; if config.debug_mode { debug!("Reconnecting in {}ms (attempt {})", delay, attempts); } // Trigger reconnecting callback if let Some(ref callback) = on_reconnecting { callback(attempts, client_url.clone()).await; } *status.lock().await = ConnectionStatus::Reconnecting; 
tokio::time::sleep(Duration::from_millis(delay)).await; // Exponential backoff let new_delay = std::cmp::min(delay * 2, config.max_reconnect_delay_ms); *reconnect_delay_ms.lock().await = new_delay; continue; } else { break; } } Err(e) => { if config.debug_mode { warn!("WebSocket error: {:?}", e); } if let Some(ref callback) = on_error { callback(e.to_string()); } if config.reconnect && *is_running.lock().await { let delay = *reconnect_delay_ms.lock().await; let attempts = *reconnect_attempts.lock().await + 1; *reconnect_attempts.lock().await = attempts; if config.debug_mode { debug!("Reconnecting in {}ms (attempt {})", delay, attempts); } // Trigger reconnecting callback (allows app to update URL, reload config, etc.) if let Some(ref callback) = on_reconnecting { callback(attempts, client_url.clone()).await; } *status.lock().await = ConnectionStatus::Reconnecting; tokio::time::sleep(Duration::from_millis(delay)).await; // Exponential backoff let new_delay = std::cmp::min(delay * 2, config.max_reconnect_delay_ms); *reconnect_delay_ms.lock().await = new_delay; continue; } else { break; } } } } *status.lock().await = ConnectionStatus::Disconnected; if let Some(callback) = on_disconnected { callback(); } } /// Connect and handle messages async fn connect_and_handle( url: &str, config: &WebSocketConfig, receiver: &mut mpsc::Receiver, _queue: &Arc>>, on_connected: &Option, on_message: &Option, on_binary: &Option, on_error: &Option, is_first: bool, on_first_connect: &Option, on_reconnected: &Option, sender: Arc>>>, message_sender: MessageSender, ) -> Result<(), WebSocketError> { use http::header::{HeaderName, HeaderValue}; use tokio_tungstenite::tungstenite::client::IntoClientRequest; // Build request with headers let mut request = url.into_client_request()?; // Add client type header if specified if let Some(ref client_type) = config.client_type { request.headers_mut().insert( HeaderName::from_static("clienttype"), HeaderValue::from_str(client_type) .map_err(|e| 
WebSocketError::HeaderError(e.to_string()))?, ); } // Add custom headers for (key, value) in &config.custom_headers { if let (Ok(name), Ok(val)) = ( HeaderName::try_from(key.as_str()), HeaderValue::from_str(value), ) { request.headers_mut().insert(name, val); } } // Connect with timeout let ws_stream = tokio::time::timeout( Duration::from_millis(config.connect_timeout_ms), tokio_tungstenite::connect_async(request), ) .await .map_err(|_| WebSocketError::ConnectionTimeout)? .map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?; let (mut write, mut read) = ws_stream.0.split(); if config.debug_mode { debug!("Connected to WebSocket server"); } // Call on_connected callback if let Some(ref callback) = on_connected { callback(url.to_string()); } // Call on_first_connect callback (only on first connection, not on reconnect) if is_first { if let Some(ref callback) = on_first_connect { callback(url.to_string(), sender).await; } } else { // Call on_reconnected callback (after successful reconnection) if let Some(ref callback) = on_reconnected { callback(url.to_string(), sender).await; } } // Start heartbeat task if enabled let heartbeat_handle = if config.heartbeat_interval_ms > 0 { let interval_duration = Duration::from_millis(config.heartbeat_interval_ms); Some(tokio::spawn(async move { let mut ticker = tokio::time::interval(interval_duration); loop { ticker.tick().await; // Heartbeat will be sent via the receiver } })) } else { None }; // Handle messages loop { tokio::select! 
{ // Incoming message from server msg = read.next() => { match msg { Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => { if config.debug_mode { debug!("Received text: {}", text); } // Parse and extract type if let Ok(parsed) = serde_json::from_str::(&text) { let msg_type = parsed.get("Type") .and_then(|v| v.as_str()) .unwrap_or("unknown") .to_string(); if let Some(ref callback) = on_message { callback(msg_type, parsed, message_sender.clone()); } } else if let Some(ref callback) = on_message { callback("raw".to_string(), serde_json::json!(text), message_sender.clone()); } } Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => { if config.debug_mode { debug!("Received binary: {} bytes", data.len()); } if let Some(ref callback) = on_binary { callback(data); } } Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) => { if config.debug_mode { debug!("Server initiated close"); } break; } Some(Err(e)) => { if config.debug_mode { warn!("WebSocket read error: {}", e); } if let Some(ref callback) = on_error { callback(e.to_string()); } break; } None => { if config.debug_mode { debug!("WebSocket stream ended"); } break; } _ => {} } } // Outgoing message outgoing = receiver.recv() => { match outgoing { Some(OutgoingMessage::Text(text)) => { if config.debug_mode { debug!("Sending text: {}", text); } write.send(tokio_tungstenite::tungstenite::Message::Text(text)) .await .map_err(|e| WebSocketError::SendFailed(e.to_string()))?; } Some(OutgoingMessage::Binary(data)) => { if config.debug_mode { debug!("Sending binary: {} bytes", data.len()); } write.send(tokio_tungstenite::tungstenite::Message::Binary(data)) .await .map_err(|e| WebSocketError::SendFailed(e.to_string()))?; } Some(OutgoingMessage::Close) => { if config.debug_mode { debug!("Initiating close"); } write.close().await.ok(); break; } None => { break; } } } } } // Wait for heartbeat handle if let Some(handle) = heartbeat_handle { handle.abort(); } Ok(()) } /// Disconnect from the server 
pub async fn disconnect(&mut self) { *self.is_running.lock().await = false; if let Some(sender) = self.sender.lock().await.take() { let _ = sender.send(OutgoingMessage::Close).await; } if let Some(handle) = self.task_handle.lock().await.take() { handle.abort(); } self.set_status(ConnectionStatus::Disconnected).await; } // ==================== Message Sending ==================== /// Send a text message pub async fn send(&self, msg_type: &str, data: Value) -> Result<(), WebSocketError> { let message = WebSocketMessage::new(msg_type, data); let json = message.to_json_string()?; self.send_raw_text(json).await } /// Send raw text pub async fn send_raw_text(&self, text: String) -> Result<(), WebSocketError> { let status = *self.status.lock().await; if status == ConnectionStatus::Connected { if let Some(ref sender) = *self.sender.lock().await { sender.send(OutgoingMessage::Text(text)) .await .map_err(|_| WebSocketError::NotConnected)?; } } else { // Queue the message let mut queue = self.message_queue.lock().await; if queue.len() < self.config.max_queue_size { if self.config.debug_mode { debug!("Queueing message (queue size: {})", queue.len() + 1); } queue.push(OutgoingMessage::Text(text)); } else { if self.config.debug_mode { warn!("Queue full, message dropped"); } return Err(WebSocketError::QueueFull); } } Ok(()) } /// Send binary data pub async fn send_binary(&self, data: Vec) -> Result<(), WebSocketError> { let status = *self.status.lock().await; if status == ConnectionStatus::Connected { if let Some(ref sender) = *self.sender.lock().await { sender.send(OutgoingMessage::Binary(data)) .await .map_err(|_| WebSocketError::NotConnected)?; } } else { let mut queue = self.message_queue.lock().await; if queue.len() < self.config.max_queue_size { if self.config.debug_mode { debug!("Queueing binary message"); } queue.push(OutgoingMessage::Binary(data)); } else { return Err(WebSocketError::QueueFull); } } Ok(()) } // ==================== Status & Queue ==================== /// 
Get current connection status pub async fn get_status(&self) -> ConnectionStatus { *self.status.lock().await } /// Check if connected pub async fn is_connected(&self) -> bool { *self.status.lock().await == ConnectionStatus::Connected } /// Get queue size pub async fn get_queue_size(&self) -> usize { self.message_queue.lock().await.len() } /// Flush queued messages pub async fn flush_queue(&self) -> Result<(), WebSocketError> { let status = *self.status.lock().await; if status != ConnectionStatus::Connected { return Err(WebSocketError::NotConnected); } let mut queue = self.message_queue.lock().await; if let Some(ref sender) = *self.sender.lock().await { while let Some(msg) = queue.pop() { sender.send(msg) .await .map_err(|_| WebSocketError::NotConnected)?; } } Ok(()) } // ==================== Private Helpers ==================== #[allow(clippy::large_enum_variant)] async fn set_status(&mut self, new_status: ConnectionStatus) { let old_status = *self.status.lock().await; if old_status != new_status { *self.status.lock().await = new_status.clone(); // Reset reconnect state on successful connection if new_status == ConnectionStatus::Connected { *self.reconnect_attempts.lock().await = 0; *self.reconnect_delay_ms.lock().await = self.config.reconnect_delay_ms; } if let Some(ref callback) = self.on_status_changed { callback(old_status, new_status); } } } } impl Drop for WebSocketClient { fn drop(&mut self) { // 不在这里调用异步 disconnect(),因为: // 1. 无法在 Drop 中安全地使用 block_on (会导致 "Cannot start a runtime from within a runtime") // 2. 调用者应该显式调用 disconnect() 方法 // 3. 
Tokio 的 JoinHandle 会在 task 结束时自动清理资源 // 如果 client 未被显式 disconnect() 就被 drop,资源会通过 JoinHandle 的 abort 自然清理 } } #[cfg(test)] mod tests { use super::*; #[test] fn test_config_builder() { let config = WebSocketConfig::new("ws://localhost:8080") .with_client_type("Client") .with_debug(true) .with_reconnect(false); assert_eq!(config.ws_url, "ws://localhost:8080"); assert_eq!(config.client_type, Some("Client".to_string())); assert!(config.debug_mode); assert!(!config.reconnect); } #[tokio::test] async fn test_client_creation() { let config = WebSocketConfig::default(); let client = WebSocketClient::new(config); let status = client.get_status().await; assert_eq!(status, ConnectionStatus::Disconnected); } }