Files
JoyD/Windows/Rust/CubeLib/src/websocket/client.rs

818 lines
31 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! WebSocket client implementation
use crate::websocket::{WebSocketConfig, WebSocketMessage, WebSocketError};
use futures_util::{SinkExt, StreamExt};
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::time::Duration;
use tracing::{warn, debug};
use serde_json::Value;
/// A unit of work for the connection task's outgoing channel.
///
/// `Text` and `Binary` become data frames on the socket; `Close` asks the
/// connection task to close the socket instead of sending data.
#[derive(Debug)]
pub enum OutgoingMessage {
    /// A UTF-8 text frame.
    Text(String),
    /// A raw binary frame.
    Binary(Vec<u8>),
    /// Close the connection.
    Close,
}
/// Lifecycle state of a WebSocket client connection.
///
/// `Copy` and cheap to compare; `Hash` allows using it as a map/set key
/// (e.g. for per-state counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionStatus {
    /// Not connected and no connection loop running.
    Disconnected,
    /// A connection attempt has been started.
    Connecting,
    /// Connected to the server.
    Connected,
    /// Waiting before the next reconnection attempt.
    Reconnecting,
    /// Error state. NOTE(review): nothing in the client module appears to set
    /// this variant — confirm whether it is still needed.
    Error,
}
impl std::fmt::Display for ConnectionStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConnectionStatus::Disconnected => write!(f, "disconnected"),
ConnectionStatus::Connecting => write!(f, "connecting"),
ConnectionStatus::Connected => write!(f, "connected"),
ConnectionStatus::Reconnecting => write!(f, "reconnecting"),
ConnectionStatus::Error => write!(f, "error"),
}
}
}
// Event callbacks are stored behind `Arc` so they can be cloned cheaply into the
// background connection task.

/// Invoked after a successful handshake with the URL that was connected to.
pub type ConnectedCallback = Arc<dyn Fn(String) + Sync + Send>;
/// Invoked once the connection loop has stopped for good.
pub type DisconnectedCallback = Arc<dyn Fn() + Sync + Send>;
/// Callback for text messages - args: (msg_type, parsed_data, sender)
/// sender can be used to send messages back synchronously
#[derive(Clone)]
pub struct MessageSender(std::sync::Arc<tokio::sync::Mutex<Option<mpsc::Sender<OutgoingMessage>>>>);
impl MessageSender {
/// Send a text message synchronously (from within the sync callback)
pub fn send(&self, text: String) {
if let Ok(guard) = self.0.try_lock() {
if let Some(ref tx) = *guard {
let _ = tx.try_send(OutgoingMessage::Text(text));
}
}
}
}
/// Invoked for each incoming text frame as `(msg_type, payload, sender)`.
///
/// For JSON frames, `msg_type` is the top-level `"Type"` string, or `"unknown"` if absent.
/// Frames that are not valid JSON arrive as a JSON string with `msg_type == "raw"`.
/// `sender` lets the synchronous callback queue a reply.
pub type MessageCallback = Arc<dyn Fn(String, Value, MessageSender) + Sync + Send>;
/// Invoked with the payload of each incoming binary frame.
pub type BinaryCallback = Arc<dyn Fn(Vec<u8>) + Sync + Send>;
/// Invoked with a textual description of an error.
pub type ErrorCallback = Arc<dyn Fn(String) + Sync + Send>;
/// Invoked as `(old_status, new_status)` when the connection status changes.
pub type StatusCallback = Arc<dyn Fn(ConnectionStatus, ConnectionStatus) + Sync + Send>;
/// Callback type registered through `on_message_sent`, receiving `(String, Value)`.
pub type SentCallback = Arc<dyn Fn(String, Value) + Sync + Send>;
/// Async hook awaited before each reconnect attempt.
///
/// Receives `(attempt_number, url)`. Writing through `url` changes the address used for
/// the following connection attempt. The returned boxed future is awaited.
pub type ReconnectingCallback = Arc<
    dyn Fn(u32, Arc<Mutex<String>>)
            -> Pin<Box<dyn std::future::Future<Output = ()> + Sync + Send + 'static>>
        + Sync
        + Send,
>;
/// Async hook awaited after a successful reconnection (i.e. not the first connection).
///
/// Receives `(url, sender_slot)`. `sender_slot` is the client's shared outgoing-sender
/// slot and can be used to send messages.
pub type ReconnectedCallback = Arc<
    dyn Fn(String, Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>)
            -> Pin<Box<dyn std::future::Future<Output = ()> + Sync + Send + 'static>>
        + Sync
        + Send,
>;
/// Async hook awaited on the first successful connection (before any reconnect).
///
/// Receives `(url, sender_slot)`, with the same meaning as for [`ReconnectedCallback`].
pub type FirstConnectCallback = Arc<
    dyn Fn(String, Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>)
            -> Pin<Box<dyn std::future::Future<Output = ()> + Sync + Send + 'static>>
        + Sync
        + Send,
>;
/// WebSocket client with event-driven architecture
///
/// Shared state is kept behind `Arc<Mutex<..>>`, so clones of it can be moved into
/// the background task spawned by `connect_with_url` while the client keeps its own copy.
pub struct WebSocketClient {
    /// Client configuration (URL, reconnect/backoff settings, headers, queue limit, ...).
    config: WebSocketConfig,
    /// Dynamic URL (can be updated for reconnecting with new URL)
    ///
    /// The connection loop re-reads it before every attempt and hands it to `on_reconnecting`.
    url: Arc<Mutex<String>>,
    /// Current connection status.
    status: Arc<Mutex<ConnectionStatus>>,
    /// Sending half of the outgoing-message channel.
    ///
    /// `None` until `connect_with_url` creates the channel, and again after `disconnect` takes it.
    sender: Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>,
    /// Messages buffered by `send_raw_text` / `send_binary` while the status is not `Connected`.
    message_queue: Arc<Mutex<Vec<OutgoingMessage>>>,
    /// Join handle of the spawned connection task, when one is stored.
    task_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
    /// Set by `connect_with_url` and cleared by `disconnect`; the connection loop exits when false.
    is_running: Arc<Mutex<bool>>,
    // Event callbacks
    /// Called with the URL after a successful handshake.
    on_connected: Option<ConnectedCallback>,
    /// Called when the connection loop finishes.
    on_disconnected: Option<DisconnectedCallback>,
    /// Called for every incoming text frame.
    on_message: Option<MessageCallback>,
    /// Called for every incoming binary frame.
    on_binary: Option<BinaryCallback>,
    /// Called with an error description when a connection attempt or read fails.
    on_error: Option<ErrorCallback>,
    /// Status-change callback, invoked as `(old, new)`.
    on_status_changed: Option<StatusCallback>,
    /// Callback registered via `on_message_sent`.
    on_message_sent: Option<SentCallback>,
    /// Callback triggered before reconnecting (attempt number passed as argument)
    on_reconnecting: Option<ReconnectingCallback>,
    /// Callback triggered on first successful connection
    on_first_connect: Option<FirstConnectCallback>,
    /// Callback triggered after successful reconnection (after the first connection)
    on_reconnected: Option<ReconnectedCallback>,
    // Reconnection state
    /// Reconnect attempt counter; the current value is passed to `on_reconnecting`.
    reconnect_attempts: Arc<Mutex<u32>>,
    /// Current backoff delay in milliseconds.
    ///
    /// Doubled after each reconnect wait, capped at `config.max_reconnect_delay_ms`.
    reconnect_delay_ms: Arc<Mutex<u64>>,
}
impl WebSocketClient {
/// Create a new WebSocket client
pub fn new(config: WebSocketConfig) -> Self {
Self {
config: config.clone(),
url: Arc::new(Mutex::new(config.ws_url.clone())),
status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
sender: Arc::new(Mutex::new(None)),
message_queue: Arc::new(Mutex::new(Vec::new())),
task_handle: Arc::new(Mutex::new(None)),
is_running: Arc::new(Mutex::new(false)),
on_connected: None,
on_disconnected: None,
on_message: None,
on_binary: None,
on_error: None,
on_status_changed: None,
on_message_sent: None,
on_reconnecting: None,
on_first_connect: None,
on_reconnected: None,
reconnect_attempts: Arc::new(Mutex::new(0)),
reconnect_delay_ms: Arc::new(Mutex::new(config.reconnect_delay_ms)),
}
}
/// Create a simple client with just the URL
pub fn simple(url: impl Into<String>) -> Self {
Self::new(WebSocketConfig::new(url))
}
// ==================== Event Handlers ====================
/// Set callback for connected event
pub fn on_connected<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String) + Send + Sync + 'static,
{
self.on_connected = Some(Arc::new(callback));
self
}
/// Set callback for disconnected event
pub fn on_disconnected<F>(&mut self, callback: F) -> &mut Self
where
F: Fn() + Send + Sync + 'static,
{
self.on_disconnected = Some(Arc::new(callback));
self
}
/// Set callback for text message received
/// The callback receives: (msg_type, parsed_data, sender)
/// The sender can be used to send messages back synchronously
pub fn on_message<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String, Value, MessageSender) + Send + Sync + 'static,
{
self.on_message = Some(Arc::new(callback));
self
}
/// Set callback for binary message received
pub fn on_binary<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
self.on_binary = Some(Arc::new(callback));
self
}
/// Set callback for error
pub fn on_error<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String) + Send + Sync + 'static,
{
self.on_error = Some(Arc::new(callback));
self
}
/// Set callback for status changed
pub fn on_status_changed<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(ConnectionStatus, ConnectionStatus) + Send + Sync + 'static,
{
self.on_status_changed = Some(Arc::new(callback));
self
}
/// Set callback for message sent
pub fn on_message_sent<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String, Value) + Send + Sync + 'static,
{
self.on_message_sent = Some(Arc::new(callback));
self
}
/// Set callback for reconnecting (called before each reconnect attempt)
/// This is an async callback - the returned Future will be awaited
/// The callback receives: (attempt_number, url_arc)
/// Use url_arc to update the URL for the next connection attempt
pub fn on_reconnecting<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(u32, Arc<tokio::sync::Mutex<String>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync + 'static,
{
self.on_reconnecting = Some(Arc::new(callback));
self
}
/// Set callback for first successful connection (before any reconnect)
/// This is an async callback - the returned Future will be awaited
pub fn on_first_connect<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String, Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync + 'static,
{
self.on_first_connect = Some(Arc::new(callback));
self
}
/// Set callback for reconnection (called after successful reconnection, not on first connection)
/// This is an async callback - the returned Future will be awaited
pub fn on_reconnected<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String, Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync + 'static,
{
self.on_reconnected = Some(Arc::new(callback));
self
}
/// Update the URL for future reconnection attempts
/// Call this method when the server URL changes
pub fn update_url(&mut self, new_url: String) {
if self.config.debug_mode {
debug!("Updating URL: {} -> {}", self.config.ws_url, new_url);
}
self.config.ws_url = new_url.clone();
*self.url.blocking_lock() = new_url;
}
/// Send a text message through the WebSocket (can be called from any context)
/// Returns Ok(()) if the message was sent, Err(msg) if not connected
pub async fn send_text(&self, text: String) -> Result<(), String> {
let sender = self.sender.lock().await;
if let Some(ref tx) = *sender {
tx.send(OutgoingMessage::Text(text)).await.map_err(|e| e.to_string())
} else {
Err("Not connected".to_string())
}
}
// ==================== Connection Management ====================
/// Connect to the WebSocket server (uses dynamic URL that can be updated)
/// This method blocks until the connection ends completely (or is stopped via disconnect())
pub async fn connect(&mut self) {
let url = self.url.lock().await.clone();
self.connect_with_url(&url).await;
}
/// Connect to a specific URL (overrides config URL)
/// Spawns the WebSocket task and waits for it to complete
pub async fn connect_with_url(&mut self, url: &str) {
let _old_status = *self.status.lock().await;
self.set_status(ConnectionStatus::Connecting).await;
if self.config.debug_mode {
debug!("Connecting to WebSocket server: {}", url);
}
// Create channel for outgoing messages
let (tx, rx) = mpsc::channel::<OutgoingMessage>(100);
*self.sender.lock().await = Some(tx);
// Clone shared state for the task
let url = url.to_string();
let config = self.config.clone();
let status = Arc::clone(&self.status);
let queue = Arc::clone(&self.message_queue);
let is_running = Arc::clone(&self.is_running);
let reconnect_attempts = Arc::clone(&self.reconnect_attempts);
let reconnect_delay_ms = Arc::clone(&self.reconnect_delay_ms);
let client_url = Arc::clone(&self.url);
let sender = Arc::clone(&self.sender);
// Callbacks
let on_connected = self.on_connected.clone();
let on_disconnected = self.on_disconnected.clone();
let on_message = self.on_message.clone();
let on_binary = self.on_binary.clone();
let on_error = self.on_error.clone();
let on_reconnecting = self.on_reconnecting.clone();
let on_first_connect = self.on_first_connect.clone();
let on_reconnected = self.on_reconnected.clone();
// Spawn the WebSocket task
*self.is_running.lock().await = true;
let task_handle = tokio::spawn(async move {
Self::websocket_loop(
url,
config,
rx,
status,
queue,
is_running,
reconnect_attempts,
reconnect_delay_ms,
client_url,
on_connected,
on_disconnected,
on_message,
on_binary,
on_error,
on_reconnecting,
on_first_connect,
on_reconnected,
sender,
)
.await;
});
// Store handle and wait for task to complete
*self.task_handle.lock().await = Some(task_handle);
// Wait for the task to complete (blocks until websocket_loop ends)
if let Some(handle) = self.task_handle.lock().await.take() {
let _ = handle.await;
}
}
/// Main WebSocket loop
async fn websocket_loop(
_url: String, // Initial URL (superseded by client_url for reconnecting)
config: WebSocketConfig,
mut receiver: mpsc::Receiver<OutgoingMessage>,
status: Arc<Mutex<ConnectionStatus>>,
queue: Arc<Mutex<Vec<OutgoingMessage>>>,
is_running: Arc<Mutex<bool>>,
reconnect_attempts: Arc<Mutex<u32>>,
reconnect_delay_ms: Arc<Mutex<u64>>,
client_url: Arc<Mutex<String>>,
on_connected: Option<ConnectedCallback>,
on_disconnected: Option<DisconnectedCallback>,
on_message: Option<MessageCallback>,
on_binary: Option<BinaryCallback>,
on_error: Option<ErrorCallback>,
on_reconnecting: Option<ReconnectingCallback>,
on_first_connect: Option<FirstConnectCallback>,
on_reconnected: Option<ReconnectedCallback>,
sender: Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>,
) {
loop {
let should_run = *is_running.lock().await;
if !should_run {
break;
}
// Get current URL (may have been updated)
let current_url = client_url.lock().await.clone();
// Check if this is the first connection (not a reconnect)
let is_first = *reconnect_attempts.lock().await == 0;
// Clone sender for first_connect callback
let sender_clone = Arc::clone(&sender);
// Create MessageSender for on_message callback (same underlying sender)
let message_sender = MessageSender(Arc::clone(&sender));
match Self::connect_and_handle(&current_url, &config, &mut receiver, &queue, &on_connected, &on_message, &on_binary, &on_error, is_first, &on_first_connect, &on_reconnected, sender_clone, message_sender).await {
Ok(_) => {
if config.debug_mode {
debug!("WebSocket connection closed normally");
}
// 即使正常关闭,也尝试重连(如果启用)
if config.reconnect && *is_running.lock().await {
let delay = *reconnect_delay_ms.lock().await;
let attempts = *reconnect_attempts.lock().await + 1;
*reconnect_attempts.lock().await = attempts;
if config.debug_mode {
debug!("Reconnecting in {}ms (attempt {})", delay, attempts);
}
// Trigger reconnecting callback
if let Some(ref callback) = on_reconnecting {
callback(attempts, client_url.clone()).await;
}
*status.lock().await = ConnectionStatus::Reconnecting;
tokio::time::sleep(Duration::from_millis(delay)).await;
// Exponential backoff
let new_delay = std::cmp::min(delay * 2, config.max_reconnect_delay_ms);
*reconnect_delay_ms.lock().await = new_delay;
continue;
} else {
break;
}
}
Err(e) => {
if config.debug_mode {
warn!("WebSocket error: {:?}", e);
}
if let Some(ref callback) = on_error {
callback(e.to_string());
}
if config.reconnect && *is_running.lock().await {
let delay = *reconnect_delay_ms.lock().await;
let attempts = *reconnect_attempts.lock().await + 1;
*reconnect_attempts.lock().await = attempts;
if config.debug_mode {
debug!("Reconnecting in {}ms (attempt {})", delay, attempts);
}
// Trigger reconnecting callback (allows app to update URL, reload config, etc.)
if let Some(ref callback) = on_reconnecting {
callback(attempts, client_url.clone()).await;
}
*status.lock().await = ConnectionStatus::Reconnecting;
tokio::time::sleep(Duration::from_millis(delay)).await;
// Exponential backoff
let new_delay = std::cmp::min(delay * 2, config.max_reconnect_delay_ms);
*reconnect_delay_ms.lock().await = new_delay;
continue;
} else {
break;
}
}
}
}
*status.lock().await = ConnectionStatus::Disconnected;
if let Some(callback) = on_disconnected {
callback();
}
}
/// Connect and handle messages
async fn connect_and_handle(
url: &str,
config: &WebSocketConfig,
receiver: &mut mpsc::Receiver<OutgoingMessage>,
_queue: &Arc<Mutex<Vec<OutgoingMessage>>>,
on_connected: &Option<ConnectedCallback>,
on_message: &Option<MessageCallback>,
on_binary: &Option<BinaryCallback>,
on_error: &Option<ErrorCallback>,
is_first: bool,
on_first_connect: &Option<FirstConnectCallback>,
on_reconnected: &Option<ReconnectedCallback>,
sender: Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>,
message_sender: MessageSender,
) -> Result<(), WebSocketError> {
use http::header::{HeaderName, HeaderValue};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
// Build request with headers
let mut request = url.into_client_request()?;
// Add client type header if specified
if let Some(ref client_type) = config.client_type {
request.headers_mut().insert(
HeaderName::from_static("clienttype"),
HeaderValue::from_str(client_type)
.map_err(|e| WebSocketError::HeaderError(e.to_string()))?,
);
}
// Add custom headers
for (key, value) in &config.custom_headers {
if let (Ok(name), Ok(val)) = (
HeaderName::try_from(key.as_str()),
HeaderValue::from_str(value),
) {
request.headers_mut().insert(name, val);
}
}
// Connect with timeout
let ws_stream = tokio::time::timeout(
Duration::from_millis(config.connect_timeout_ms),
tokio_tungstenite::connect_async(request),
)
.await
.map_err(|_| WebSocketError::ConnectionTimeout)?
.map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
let (mut write, mut read) = ws_stream.0.split();
if config.debug_mode {
debug!("Connected to WebSocket server");
}
// Call on_connected callback
if let Some(ref callback) = on_connected {
callback(url.to_string());
}
// Call on_first_connect callback (only on first connection, not on reconnect)
if is_first {
if let Some(ref callback) = on_first_connect {
callback(url.to_string(), sender).await;
}
} else {
// Call on_reconnected callback (after successful reconnection)
if let Some(ref callback) = on_reconnected {
callback(url.to_string(), sender).await;
}
}
// Start heartbeat task if enabled
let heartbeat_handle = if config.heartbeat_interval_ms > 0 {
let interval_duration = Duration::from_millis(config.heartbeat_interval_ms);
Some(tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval_duration);
loop {
ticker.tick().await;
// Heartbeat will be sent via the receiver
}
}))
} else {
None
};
// Handle messages
loop {
tokio::select! {
// Incoming message from server
msg = read.next() => {
match msg {
Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => {
if config.debug_mode {
debug!("Received text: {}", text);
}
// Parse and extract type
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
let msg_type = parsed.get("Type")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
if let Some(ref callback) = on_message {
callback(msg_type, parsed, message_sender.clone());
}
} else if let Some(ref callback) = on_message {
callback("raw".to_string(), serde_json::json!(text), message_sender.clone());
}
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
if config.debug_mode {
debug!("Received binary: {} bytes", data.len());
}
if let Some(ref callback) = on_binary {
callback(data);
}
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) => {
if config.debug_mode {
debug!("Server initiated close");
}
break;
}
Some(Err(e)) => {
if config.debug_mode {
warn!("WebSocket read error: {}", e);
}
if let Some(ref callback) = on_error {
callback(e.to_string());
}
break;
}
None => {
if config.debug_mode {
debug!("WebSocket stream ended");
}
break;
}
_ => {}
}
}
// Outgoing message
outgoing = receiver.recv() => {
match outgoing {
Some(OutgoingMessage::Text(text)) => {
if config.debug_mode {
debug!("Sending text: {}", text);
}
write.send(tokio_tungstenite::tungstenite::Message::Text(text))
.await
.map_err(|e| WebSocketError::SendFailed(e.to_string()))?;
}
Some(OutgoingMessage::Binary(data)) => {
if config.debug_mode {
debug!("Sending binary: {} bytes", data.len());
}
write.send(tokio_tungstenite::tungstenite::Message::Binary(data))
.await
.map_err(|e| WebSocketError::SendFailed(e.to_string()))?;
}
Some(OutgoingMessage::Close) => {
if config.debug_mode {
debug!("Initiating close");
}
write.close().await.ok();
break;
}
None => {
break;
}
}
}
}
}
// Wait for heartbeat handle
if let Some(handle) = heartbeat_handle {
handle.abort();
}
Ok(())
}
/// Disconnect from the server
pub async fn disconnect(&mut self) {
*self.is_running.lock().await = false;
if let Some(sender) = self.sender.lock().await.take() {
let _ = sender.send(OutgoingMessage::Close).await;
}
if let Some(handle) = self.task_handle.lock().await.take() {
handle.abort();
}
self.set_status(ConnectionStatus::Disconnected).await;
}
// ==================== Message Sending ====================
/// Send a text message
pub async fn send(&self, msg_type: &str, data: Value) -> Result<(), WebSocketError> {
let message = WebSocketMessage::new(msg_type, data);
let json = message.to_json_string()?;
self.send_raw_text(json).await
}
/// Send raw text
pub async fn send_raw_text(&self, text: String) -> Result<(), WebSocketError> {
let status = *self.status.lock().await;
if status == ConnectionStatus::Connected {
if let Some(ref sender) = *self.sender.lock().await {
sender.send(OutgoingMessage::Text(text))
.await
.map_err(|_| WebSocketError::NotConnected)?;
}
} else {
// Queue the message
let mut queue = self.message_queue.lock().await;
if queue.len() < self.config.max_queue_size {
if self.config.debug_mode {
debug!("Queueing message (queue size: {})", queue.len() + 1);
}
queue.push(OutgoingMessage::Text(text));
} else {
if self.config.debug_mode {
warn!("Queue full, message dropped");
}
return Err(WebSocketError::QueueFull);
}
}
Ok(())
}
/// Send binary data
pub async fn send_binary(&self, data: Vec<u8>) -> Result<(), WebSocketError> {
let status = *self.status.lock().await;
if status == ConnectionStatus::Connected {
if let Some(ref sender) = *self.sender.lock().await {
sender.send(OutgoingMessage::Binary(data))
.await
.map_err(|_| WebSocketError::NotConnected)?;
}
} else {
let mut queue = self.message_queue.lock().await;
if queue.len() < self.config.max_queue_size {
if self.config.debug_mode {
debug!("Queueing binary message");
}
queue.push(OutgoingMessage::Binary(data));
} else {
return Err(WebSocketError::QueueFull);
}
}
Ok(())
}
// ==================== Status & Queue ====================
/// Get current connection status
pub async fn get_status(&self) -> ConnectionStatus {
*self.status.lock().await
}
/// Check if connected
pub async fn is_connected(&self) -> bool {
*self.status.lock().await == ConnectionStatus::Connected
}
/// Get queue size
pub async fn get_queue_size(&self) -> usize {
self.message_queue.lock().await.len()
}
/// Flush queued messages
pub async fn flush_queue(&self) -> Result<(), WebSocketError> {
let status = *self.status.lock().await;
if status != ConnectionStatus::Connected {
return Err(WebSocketError::NotConnected);
}
let mut queue = self.message_queue.lock().await;
if let Some(ref sender) = *self.sender.lock().await {
while let Some(msg) = queue.pop() {
sender.send(msg)
.await
.map_err(|_| WebSocketError::NotConnected)?;
}
}
Ok(())
}
// ==================== Private Helpers ====================
#[allow(clippy::large_enum_variant)]
async fn set_status(&mut self, new_status: ConnectionStatus) {
let old_status = *self.status.lock().await;
if old_status != new_status {
*self.status.lock().await = new_status.clone();
// Reset reconnect state on successful connection
if new_status == ConnectionStatus::Connected {
*self.reconnect_attempts.lock().await = 0;
*self.reconnect_delay_ms.lock().await = self.config.reconnect_delay_ms;
}
if let Some(ref callback) = self.on_status_changed {
callback(old_status, new_status);
}
}
}
}
impl Drop for WebSocketClient {
fn drop(&mut self) {
// 不在这里调用异步 disconnect(),因为:
// 1. 无法在 Drop 中安全地使用 block_on (会导致 "Cannot start a runtime from within a runtime")
// 2. 调用者应该显式调用 disconnect() 方法
// 3. Tokio 的 JoinHandle 会在 task 结束时自动清理资源
// 如果 client 未被显式 disconnect() 就被 drop资源会通过 JoinHandle 的 abort 自然清理
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_builder() {
        let config = WebSocketConfig::new("ws://localhost:8080")
            .with_client_type("Client")
            .with_debug(true)
            .with_reconnect(false);
        assert_eq!(config.ws_url, "ws://localhost:8080");
        assert_eq!(config.client_type, Some("Client".to_string()));
        assert!(config.debug_mode);
        assert!(!config.reconnect);
    }

    #[tokio::test]
    async fn test_client_creation() {
        let config = WebSocketConfig::default();
        let client = WebSocketClient::new(config);
        let status = client.get_status().await;
        assert_eq!(status, ConnectionStatus::Disconnected);
    }

    #[test]
    fn test_status_display() {
        let cases = [
            (ConnectionStatus::Disconnected, "disconnected"),
            (ConnectionStatus::Connecting, "connecting"),
            (ConnectionStatus::Connected, "connected"),
            (ConnectionStatus::Reconnecting, "reconnecting"),
            (ConnectionStatus::Error, "error"),
        ];
        for &(status, label) in cases.iter() {
            assert_eq!(status.to_string(), label);
        }
    }

    /// While disconnected, messages are buffered until the queue limit is reached.
    #[tokio::test]
    async fn test_messages_queue_until_connected() {
        let mut config = WebSocketConfig::new("ws://localhost:8080");
        config.max_queue_size = 1;
        let client = WebSocketClient::new(config);

        assert!(client.send_raw_text("first".to_string()).await.is_ok());
        assert_eq!(client.get_queue_size().await, 1);
        assert!(matches!(
            client.send_binary(vec![1, 2, 3]).await,
            Err(WebSocketError::QueueFull)
        ));
        assert_eq!(client.get_queue_size().await, 1);
        assert!(!client.is_connected().await);
    }
}