@@ -10,7 +10,7 @@ use serde_json::Value;
/// Message to send through the channel
#[ derive(Debug) ]
enum OutgoingMessage {
pub enum OutgoingMessage {
/// Text message
Text ( String ) ,
/// Binary message
@@ -55,10 +55,18 @@ pub type BinaryCallback = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
pub type ErrorCallback = Arc < dyn Fn ( String ) + Send + Sync > ;
pub type StatusCallback = Arc < dyn Fn ( ConnectionStatus , ConnectionStatus ) + Send + Sync > ;
pub type SentCallback = Arc < dyn Fn ( String , Value ) + Send + Sync > ;
/// Callback triggered before reconnecting
/// Arguments: (attempt_number, url_arc) - app can update the URL via url_arc
pub type ReconnectingCallback = Arc < dyn Fn ( u32 , Arc < tokio ::sync ::Mutex < String > > ) + Send + Sync > ;
/// Callback triggered on first successful connection (before any reconnect)
/// Arguments: (url, send_fn) - send_fn can be called to send messages
pub type FirstConnectCallback = Arc < dyn Fn ( String , Arc < Mutex < Option < mpsc ::Sender < OutgoingMessage > > > > ) + Send + Sync > ;
/// WebSocket client with event-driven architecture
pub struct WebSocketClient {
config : WebSocketConfig ,
/// Dynamic URL (can be updated for reconnecting with new URL)
url : Arc < Mutex < String > > ,
status : Arc < Mutex < ConnectionStatus > > ,
sender : Arc < Mutex < Option < mpsc ::Sender < OutgoingMessage > > > > ,
message_queue : Arc < Mutex < Vec < OutgoingMessage > > > ,
@@ -73,6 +81,10 @@ pub struct WebSocketClient {
on_error : Option < ErrorCallback > ,
on_status_changed : Option < StatusCallback > ,
on_message_sent : Option < SentCallback > ,
/// Callback triggered before reconnecting (attempt number passed as argument)
on_reconnecting : Option < ReconnectingCallback > ,
/// Callback triggered on first successful connection
on_first_connect : Option < FirstConnectCallback > ,
// Reconnection state
reconnect_attempts : Arc < Mutex < u32 > > ,
@@ -83,7 +95,8 @@ impl WebSocketClient {
/// Create a new WebSocket client
pub fn new ( config : WebSocketConfig ) -> Self {
Self {
config ,
config : config . clone ( ) ,
url : Arc ::new ( Mutex ::new ( config . ws_url . clone ( ) ) ) ,
status : Arc ::new ( Mutex ::new ( ConnectionStatus ::Disconnected ) ) ,
sender : Arc ::new ( Mutex ::new ( None ) ) ,
message_queue : Arc ::new ( Mutex ::new ( Vec ::new ( ) ) ) ,
@@ -96,8 +109,10 @@ impl WebSocketClient {
on_error : None ,
on_status_changed : None ,
on_message_sent : None ,
on_reconnecting : None ,
on_first_connect : None ,
reconnect_attempts : Arc ::new ( Mutex ::new ( 0 ) ) ,
reconnect_delay_ms : Arc ::new ( Mutex ::new ( 1000 ) ) ,
reconnect_delay_ms : Arc ::new ( Mutex ::new ( config . reconnect_delay_ms ) ) ,
}
}
@@ -171,15 +186,48 @@ impl WebSocketClient {
self
}
/// Set callback for reconnecting (called before each reconnect attempt)
/// The callback receives: (attempt_number, url_arc)
/// Use url_arc to update the URL for the next connection attempt
pub fn on_reconnecting < F > ( & mut self , callback : F ) -> & mut Self
where
F : Fn ( u32 , Arc < tokio ::sync ::Mutex < String > > ) + Send + Sync + 'static ,
{
self . on_reconnecting = Some ( Arc ::new ( callback ) ) ;
self
}
/// Set callback for first successful connection (before any reconnect)
/// This is called only on the first connection, allowing the app to send initial messages
pub fn on_first_connect < F > ( & mut self , callback : F ) -> & mut Self
where
F : Fn ( String , Arc < Mutex < Option < mpsc ::Sender < OutgoingMessage > > > > ) + Send + Sync + 'static ,
{
self . on_first_connect = Some ( Arc ::new ( callback ) ) ;
self
}
/// Update the URL for future reconnection attempts
/// Call this method when the server URL changes
pub fn update_url ( & mut self , new_url : String ) {
if self . config . debug_mode {
debug! ( " Updating URL: {} -> {} " , self . config . ws_url , new_url ) ;
}
self . config . ws_url = new_url . clone ( ) ;
* self . url . blocking_lock ( ) = new_url ;
}
// ==================== Connection Management ====================
/// Connect to the WebSocket server
/// Connect to the WebSocket server (uses dynamic URL that can be updated)
/// This method blocks until the connection ends completely (or is stopped via disconnect())
pub async fn connect ( & mut self ) {
let url = self . config . ws_url . clone ( ) ;
let url = self . url . lock ( ) . await . clone ( ) ;
self . connect_with_url ( & url ) . await ;
}
/// Connect to a specific URL (overrides config URL)
/// Spawns the WebSocket task and waits for it to complete
pub async fn connect_with_url ( & mut self , url : & str ) {
let _old_status = * self . status . lock ( ) . await ;
self . set_status ( ConnectionStatus ::Connecting ) . await ;
@@ -200,6 +248,8 @@ impl WebSocketClient {
let is_running = Arc ::clone ( & self . is_running ) ;
let reconnect_attempts = Arc ::clone ( & self . reconnect_attempts ) ;
let reconnect_delay_ms = Arc ::clone ( & self . reconnect_delay_ms ) ;
let client_url = Arc ::clone ( & self . url ) ;
let sender = Arc ::clone ( & self . sender ) ;
// Callbacks
let on_connected = self . on_connected . clone ( ) ;
@@ -207,6 +257,8 @@ impl WebSocketClient {
let on_message = self . on_message . clone ( ) ;
let on_binary = self . on_binary . clone ( ) ;
let on_error = self . on_error . clone ( ) ;
let on_reconnecting = self . on_reconnecting . clone ( ) ;
let on_first_connect = self . on_first_connect . clone ( ) ;
// Spawn the WebSocket task
* self . is_running . lock ( ) . await = true ;
@@ -221,21 +273,31 @@ impl WebSocketClient {
is_running ,
reconnect_attempts ,
reconnect_delay_ms ,
client_url ,
on_connected ,
on_disconnected ,
on_message ,
on_binary ,
on_error ,
on_reconnecting ,
on_first_connect ,
sender ,
)
. await ;
} ) ;
// Store handle and wait for task to complete
* self . task_handle . lock ( ) . await = Some ( task_handle ) ;
// Wait for the task to complete (blocks until websocket_loop ends)
if let Some ( handle ) = self . task_handle . lock ( ) . await . take ( ) {
let _ = handle . await ;
}
}
/// Main WebSocket loop
async fn websocket_loop (
url : String ,
_ url : String , // Initial URL (superseded by client_url for reconnecting)
config : WebSocketConfig ,
mut receiver : mpsc ::Receiver < OutgoingMessage > ,
status : Arc < Mutex < ConnectionStatus > > ,
@@ -243,11 +305,15 @@ impl WebSocketClient {
is_running : Arc < Mutex < bool > > ,
reconnect_attempts : Arc < Mutex < u32 > > ,
reconnect_delay_ms : Arc < Mutex < u64 > > ,
client_url : Arc < Mutex < String > > ,
on_connected : Option < ConnectedCallback > ,
on_disconnected : Option < DisconnectedCallback > ,
on_message : Option < MessageCallback > ,
on_binary : Option < BinaryCallback > ,
on_error : Option < ErrorCallback > ,
on_reconnecting : Option < ReconnectingCallback > ,
on_first_connect : Option < FirstConnectCallback > ,
sender : Arc < Mutex < Option < mpsc ::Sender < OutgoingMessage > > > > ,
) {
loop {
let should_run = * is_running . lock ( ) . await ;
@@ -255,13 +321,47 @@ impl WebSocketClient {
break ;
}
match Self ::connect_and_handle ( & url , & config , & mut receiver , & queue , & on_connected , & on_message , & on_binary , & on_error ) . await {
// Get current URL (may have been updated)
let current_url = client_url . lock ( ) . await . clone ( ) ;
// Check if this is the first connection (not a reconnect)
let is_first = * reconnect_attempts . lock ( ) . await = = 0 ;
// Clone sender for first_connect callback
let sender_clone = Arc ::clone ( & sender ) ;
match Self ::connect_and_handle ( & current_url , & config , & mut receiver , & queue , & on_connected , & on_message , & on_binary , & on_error , is_first , & on_first_connect , sender_clone ) . await {
Ok ( _ ) = > {
if config . debug_mode {
debug! ( " WebSocket connection closed normally " ) ;
}
// 即使正常关闭,也尝试重连(如果启用)
if config . reconnect & & * is_running . lock ( ) . await {
let delay = * reconnect_delay_ms . lock ( ) . await ;
let attempts = * reconnect_attempts . lock ( ) . await + 1 ;
* reconnect_attempts . lock ( ) . await = attempts ;
if config . debug_mode {
debug! ( " Reconnecting in {}ms (attempt {}) " , delay , attempts ) ;
}
// Trigger reconnecting callback
if let Some ( ref callback ) = on_reconnecting {
callback ( attempts , client_url . clone ( ) ) ;
}
* status . lock ( ) . await = ConnectionStatus ::Reconnecting ;
tokio ::time ::sleep ( Duration ::from_millis ( delay ) ) . await ;
// Exponential backoff
let new_delay = std ::cmp ::min ( delay * 2 , config . max_reconnect_delay_ms ) ;
* reconnect_delay_ms . lock ( ) . await = new_delay ;
continue ;
} else {
break ;
}
}
Err ( e ) = > {
if config . debug_mode {
warn! ( " WebSocket error: {:?} " , e ) ;
@@ -280,6 +380,11 @@ impl WebSocketClient {
debug! ( " Reconnecting in {}ms (attempt {}) " , delay , attempts ) ;
}
// Trigger reconnecting callback (allows app to update URL, reload config, etc.)
if let Some ( ref callback ) = on_reconnecting {
callback ( attempts , client_url . clone ( ) ) ;
}
* status . lock ( ) . await = ConnectionStatus ::Reconnecting ;
tokio ::time ::sleep ( Duration ::from_millis ( delay ) ) . await ;
@@ -311,6 +416,9 @@ impl WebSocketClient {
on_message : & Option < MessageCallback > ,
on_binary : & Option < BinaryCallback > ,
on_error : & Option < ErrorCallback > ,
is_first : bool ,
on_first_connect : & Option < FirstConnectCallback > ,
sender : Arc < Mutex < Option < mpsc ::Sender < OutgoingMessage > > > > ,
) -> Result < ( ) , WebSocketError > {
use http ::header ::{ HeaderName , HeaderValue } ;
use tokio_tungstenite ::tungstenite ::client ::IntoClientRequest ;
@@ -357,6 +465,13 @@ impl WebSocketClient {
callback ( url . to_string ( ) ) ;
}
// Call on_first_connect callback (only on first connection, not on reconnect)
if is_first {
if let Some ( ref callback ) = on_first_connect {
callback ( url . to_string ( ) , sender ) ;
}
}
// Start heartbeat task if enabled
let heartbeat_handle = if config . heartbeat_interval_ms > 0 {
let interval_duration = Duration ::from_millis ( config . heartbeat_interval_ms ) ;
@@ -605,8 +720,11 @@ impl WebSocketClient {
impl Drop for WebSocketClient {
fn drop ( & mut self ) {
// Attempt to disconnect synchronously
let _ = tokio ::runtime ::Handle ::current ( ) . block_on ( self . disconnect ( ) ) ;
// 不在这里调用异步 disconnect(),因为:
// 1. 无法在 Drop 中安全地使用 block_on (会导致 "Cannot start a runtime from within a runtime")
// 2. 调用者应该显式调用 disconnect() 方法
// 3. Tokio 的 JoinHandle 会在 task 结束时自动清理资源
// 如果 client 未被显式 disconnect() 就被 drop, 资源会通过 JoinHandle 的 abort 自然清理
}
}