Files
JoyD/Windows/Rust/CubeLib/src/websocket/client.rs

818 lines
31 KiB
Rust
Raw Normal View History

2026-04-07 13:55:40 +08:00
//! WebSocket client implementation
use crate::websocket::{WebSocketConfig, WebSocketMessage, WebSocketError};
use futures_util::{SinkExt, StreamExt};
2026-04-07 15:44:30 +08:00
use std::pin::Pin;
2026-04-07 13:55:40 +08:00
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::time::Duration;
use tracing::{warn, debug};
use serde_json::Value;
/// A frame handed to the connection task to be written to the socket.
///
/// Values of this type travel over the client's internal `mpsc` channel; the
/// read/write loop turns each variant into the corresponding WebSocket frame.
#[derive(Debug)]
pub enum OutgoingMessage {
    /// A text frame carrying this string.
    Text(String),
    /// A binary frame carrying these bytes.
    Binary(Vec<u8>),
    /// Ask the connection task to close the socket and stop handling the current connection.
    Close,
}
/// Lifecycle state of a [`WebSocketClient`] connection.
///
/// The `Display` form is the lowercase variant name (e.g. `"connected"`), which is
/// also available without allocating via [`ConnectionStatus::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// No connection is open and none is being attempted.
    Disconnected,
    /// A connection attempt has been started.
    Connecting,
    /// A WebSocket connection is established.
    Connected,
    /// Waiting to retry after a connection ended or an attempt failed.
    Reconnecting,
    /// Error state. NOTE(review): nothing in `client.rs` currently sets this variant.
    Error,
}

impl ConnectionStatus {
    /// Lowercase, human-readable name of this status (the same text `Display` produces).
    pub const fn as_str(self) -> &'static str {
        match self {
            ConnectionStatus::Disconnected => "disconnected",
            ConnectionStatus::Connecting => "connecting",
            ConnectionStatus::Connected => "connected",
            ConnectionStatus::Reconnecting => "reconnecting",
            ConnectionStatus::Error => "error",
        }
    }
}

impl std::fmt::Display for ConnectionStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (rather than `write!`) honours width/alignment flags such as `{:>12}`.
        f.pad(self.as_str())
    }
}
/// Event callback types (wrapped in Arc for cloneability)
pub type ConnectedCallback = Arc<dyn Fn(String) + Send + Sync>;
pub type DisconnectedCallback = Arc<dyn Fn() + Send + Sync>;
/// Callback for text messages - args: (msg_type, parsed_data, sender)
/// sender can be used to send messages back synchronously
#[derive(Clone)]
pub struct MessageSender(std::sync::Arc<tokio::sync::Mutex<Option<mpsc::Sender<OutgoingMessage>>>>);
impl MessageSender {
/// Send a text message synchronously (from within the sync callback)
pub fn send(&self, text: String) {
if let Ok(guard) = self.0.try_lock() {
if let Some(ref tx) = *guard {
let _ = tx.try_send(OutgoingMessage::Text(text));
}
}
}
}
pub type MessageCallback = Arc<dyn Fn(String, Value, MessageSender) + Send + Sync>;
2026-04-07 13:55:40 +08:00
pub type BinaryCallback = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
pub type ErrorCallback = Arc<dyn Fn(String) + Send + Sync>;
pub type StatusCallback = Arc<dyn Fn(ConnectionStatus, ConnectionStatus) + Send + Sync>;
pub type SentCallback = Arc<dyn Fn(String, Value) + Send + Sync>;
/// Callback triggered before reconnecting
/// Arguments: (attempt_number, url_arc) - app can update the URL via url_arc
2026-04-07 16:09:34 +08:00
/// Note: This is an async callback - return a boxed Future
pub type ReconnectingCallback = Arc<dyn Fn(u32, Arc<tokio::sync::Mutex<String>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync>;
/// Callback triggered after successful reconnection (after the first connection)
/// Arguments: (url, send_fn) - send_fn can be called to send messages
/// Note: This is an async callback - return a boxed Future
pub type ReconnectedCallback = Arc<dyn Fn(String, Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync>;
/// Callback triggered on first successful connection (before any reconnect)
/// Arguments: (url, send_fn) - send_fn can be called to send messages
2026-04-07 15:44:30 +08:00
/// Note: This is an async callback - return a boxed Future
pub type FirstConnectCallback = Arc<dyn Fn(String, Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync>;
2026-04-07 13:55:40 +08:00
/// WebSocket client with event-driven architecture
pub struct WebSocketClient {
config: WebSocketConfig,
/// Dynamic URL (can be updated for reconnecting with new URL)
url: Arc<Mutex<String>>,
2026-04-07 13:55:40 +08:00
status: Arc<Mutex<ConnectionStatus>>,
sender: Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>,
message_queue: Arc<Mutex<Vec<OutgoingMessage>>>,
task_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
is_running: Arc<Mutex<bool>>,
// Event callbacks
on_connected: Option<ConnectedCallback>,
on_disconnected: Option<DisconnectedCallback>,
on_message: Option<MessageCallback>,
on_binary: Option<BinaryCallback>,
on_error: Option<ErrorCallback>,
on_status_changed: Option<StatusCallback>,
on_message_sent: Option<SentCallback>,
/// Callback triggered before reconnecting (attempt number passed as argument)
on_reconnecting: Option<ReconnectingCallback>,
/// Callback triggered on first successful connection
on_first_connect: Option<FirstConnectCallback>,
2026-04-07 16:09:34 +08:00
/// Callback triggered after successful reconnection (after the first connection)
on_reconnected: Option<ReconnectedCallback>,
2026-04-07 13:55:40 +08:00
// Reconnection state
reconnect_attempts: Arc<Mutex<u32>>,
reconnect_delay_ms: Arc<Mutex<u64>>,
}
impl WebSocketClient {
/// Create a new WebSocket client
pub fn new(config: WebSocketConfig) -> Self {
Self {
config: config.clone(),
url: Arc::new(Mutex::new(config.ws_url.clone())),
2026-04-07 13:55:40 +08:00
status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
sender: Arc::new(Mutex::new(None)),
message_queue: Arc::new(Mutex::new(Vec::new())),
task_handle: Arc::new(Mutex::new(None)),
is_running: Arc::new(Mutex::new(false)),
on_connected: None,
on_disconnected: None,
on_message: None,
on_binary: None,
on_error: None,
on_status_changed: None,
on_message_sent: None,
on_reconnecting: None,
on_first_connect: None,
2026-04-07 16:09:34 +08:00
on_reconnected: None,
2026-04-07 13:55:40 +08:00
reconnect_attempts: Arc::new(Mutex::new(0)),
reconnect_delay_ms: Arc::new(Mutex::new(config.reconnect_delay_ms)),
2026-04-07 13:55:40 +08:00
}
}
/// Convenience constructor: a client for `url` with whatever defaults `WebSocketConfig::new` applies.
pub fn simple(url: impl Into<String>) -> Self {
    let config = WebSocketConfig::new(url);
    Self::new(config)
}
// ==================== Event Handlers ====================
/// Set callback for connected event
pub fn on_connected<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String) + Send + Sync + 'static,
{
self.on_connected = Some(Arc::new(callback));
self
}
/// Set callback for disconnected event
pub fn on_disconnected<F>(&mut self, callback: F) -> &mut Self
where
F: Fn() + Send + Sync + 'static,
{
self.on_disconnected = Some(Arc::new(callback));
self
}
/// Set callback for text message received
/// The callback receives: (msg_type, parsed_data, sender)
/// The sender can be used to send messages back synchronously
2026-04-07 13:55:40 +08:00
pub fn on_message<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String, Value, MessageSender) + Send + Sync + 'static,
2026-04-07 13:55:40 +08:00
{
self.on_message = Some(Arc::new(callback));
self
}
/// Set callback for binary message received
pub fn on_binary<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
self.on_binary = Some(Arc::new(callback));
self
}
/// Set callback for error
pub fn on_error<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String) + Send + Sync + 'static,
{
self.on_error = Some(Arc::new(callback));
self
}
/// Set callback for status changed
pub fn on_status_changed<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(ConnectionStatus, ConnectionStatus) + Send + Sync + 'static,
{
self.on_status_changed = Some(Arc::new(callback));
self
}
/// Set callback for message sent
pub fn on_message_sent<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String, Value) + Send + Sync + 'static,
{
self.on_message_sent = Some(Arc::new(callback));
self
}
/// Set callback for reconnecting (called before each reconnect attempt)
2026-04-07 16:09:34 +08:00
/// This is an async callback - the returned Future will be awaited
/// The callback receives: (attempt_number, url_arc)
/// Use url_arc to update the URL for the next connection attempt
pub fn on_reconnecting<F>(&mut self, callback: F) -> &mut Self
where
2026-04-07 16:09:34 +08:00
F: Fn(u32, Arc<tokio::sync::Mutex<String>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync + 'static,
{
self.on_reconnecting = Some(Arc::new(callback));
self
}
/// Set callback for first successful connection (before any reconnect)
2026-04-07 15:44:30 +08:00
/// This is an async callback - the returned Future will be awaited
pub fn on_first_connect<F>(&mut self, callback: F) -> &mut Self
where
2026-04-07 15:44:30 +08:00
F: Fn(String, Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync + 'static,
{
self.on_first_connect = Some(Arc::new(callback));
self
}
2026-04-07 16:09:34 +08:00
/// Set callback for reconnection (called after successful reconnection, not on first connection)
/// This is an async callback - the returned Future will be awaited
pub fn on_reconnected<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String, Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>) -> Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>> + Send + Sync + 'static,
{
self.on_reconnected = Some(Arc::new(callback));
self
}
/// Replace the URL used for future connection attempts.
///
/// Updates both `config.ws_url` and the shared URL slot. The connection loop reads that
/// slot before every attempt, and reconnecting callbacks receive it.
///
/// # Panics
/// Panics only when both of these hold:
/// - the URL slot is locked by another task at that moment (e.g. inside an `on_reconnecting`
///   callback of a still-running connection task), and
/// - this is called from within an async runtime, where Tokio's `blocking_lock` panics.
///
/// In the uncontended case no blocking lock is taken.
pub fn update_url(&mut self, new_url: String) {
    if self.config.debug_mode {
        debug!("Updating URL: {} -> {}", self.config.ws_url, new_url);
    }
    self.config.ws_url = new_url.clone();
    // `blocking_lock` unconditionally panics inside an async context, which is where a
    // Tokio client is normally driven. `try_lock` succeeds whenever nobody holds the slot,
    // so fall back to the blocking path only on contention.
    match self.url.try_lock() {
        Ok(mut slot) => *slot = new_url,
        Err(_) => *self.url.blocking_lock() = new_url,
    }
}
/// Send a text message through the WebSocket (can be called from any context)
/// Returns Ok(()) if the message was sent, Err(msg) if not connected
pub async fn send_text(&self, text: String) -> Result<(), String> {
let sender = self.sender.lock().await;
if let Some(ref tx) = *sender {
tx.send(OutgoingMessage::Text(text)).await.map_err(|e| e.to_string())
} else {
Err("Not connected".to_string())
}
}
2026-04-07 13:55:40 +08:00
// ==================== Connection Management ====================
/// Connect to the WebSocket server (uses dynamic URL that can be updated)
/// This method blocks until the connection ends completely (or is stopped via disconnect())
2026-04-07 13:55:40 +08:00
pub async fn connect(&mut self) {
let url = self.url.lock().await.clone();
2026-04-07 13:55:40 +08:00
self.connect_with_url(&url).await;
}
/// Connect to a specific URL (overrides config URL)
/// Spawns the WebSocket task and waits for it to complete
2026-04-07 13:55:40 +08:00
pub async fn connect_with_url(&mut self, url: &str) {
let _old_status = *self.status.lock().await;
self.set_status(ConnectionStatus::Connecting).await;
if self.config.debug_mode {
debug!("Connecting to WebSocket server: {}", url);
}
// Create channel for outgoing messages
let (tx, rx) = mpsc::channel::<OutgoingMessage>(100);
*self.sender.lock().await = Some(tx);
// Clone shared state for the task
let url = url.to_string();
let config = self.config.clone();
let status = Arc::clone(&self.status);
let queue = Arc::clone(&self.message_queue);
let is_running = Arc::clone(&self.is_running);
let reconnect_attempts = Arc::clone(&self.reconnect_attempts);
let reconnect_delay_ms = Arc::clone(&self.reconnect_delay_ms);
let client_url = Arc::clone(&self.url);
let sender = Arc::clone(&self.sender);
2026-04-07 13:55:40 +08:00
// Callbacks
let on_connected = self.on_connected.clone();
let on_disconnected = self.on_disconnected.clone();
let on_message = self.on_message.clone();
let on_binary = self.on_binary.clone();
let on_error = self.on_error.clone();
let on_reconnecting = self.on_reconnecting.clone();
let on_first_connect = self.on_first_connect.clone();
2026-04-07 16:09:34 +08:00
let on_reconnected = self.on_reconnected.clone();
2026-04-07 13:55:40 +08:00
// Spawn the WebSocket task
*self.is_running.lock().await = true;
let task_handle = tokio::spawn(async move {
Self::websocket_loop(
url,
config,
rx,
status,
queue,
is_running,
reconnect_attempts,
reconnect_delay_ms,
client_url,
2026-04-07 13:55:40 +08:00
on_connected,
on_disconnected,
on_message,
on_binary,
on_error,
on_reconnecting,
on_first_connect,
2026-04-07 16:09:34 +08:00
on_reconnected,
sender,
2026-04-07 13:55:40 +08:00
)
.await;
});
// Store handle and wait for task to complete
2026-04-07 13:55:40 +08:00
*self.task_handle.lock().await = Some(task_handle);
// Wait for the task to complete (blocks until websocket_loop ends)
if let Some(handle) = self.task_handle.lock().await.take() {
let _ = handle.await;
}
2026-04-07 13:55:40 +08:00
}
/// Main WebSocket loop
async fn websocket_loop(
_url: String, // Initial URL (superseded by client_url for reconnecting)
2026-04-07 13:55:40 +08:00
config: WebSocketConfig,
mut receiver: mpsc::Receiver<OutgoingMessage>,
status: Arc<Mutex<ConnectionStatus>>,
queue: Arc<Mutex<Vec<OutgoingMessage>>>,
is_running: Arc<Mutex<bool>>,
reconnect_attempts: Arc<Mutex<u32>>,
reconnect_delay_ms: Arc<Mutex<u64>>,
client_url: Arc<Mutex<String>>,
2026-04-07 13:55:40 +08:00
on_connected: Option<ConnectedCallback>,
on_disconnected: Option<DisconnectedCallback>,
on_message: Option<MessageCallback>,
on_binary: Option<BinaryCallback>,
on_error: Option<ErrorCallback>,
on_reconnecting: Option<ReconnectingCallback>,
on_first_connect: Option<FirstConnectCallback>,
2026-04-07 16:09:34 +08:00
on_reconnected: Option<ReconnectedCallback>,
sender: Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>,
2026-04-07 13:55:40 +08:00
) {
loop {
let should_run = *is_running.lock().await;
if !should_run {
break;
}
// Get current URL (may have been updated)
let current_url = client_url.lock().await.clone();
// Check if this is the first connection (not a reconnect)
let is_first = *reconnect_attempts.lock().await == 0;
// Clone sender for first_connect callback
let sender_clone = Arc::clone(&sender);
// Create MessageSender for on_message callback (same underlying sender)
let message_sender = MessageSender(Arc::clone(&sender));
match Self::connect_and_handle(&current_url, &config, &mut receiver, &queue, &on_connected, &on_message, &on_binary, &on_error, is_first, &on_first_connect, &on_reconnected, sender_clone, message_sender).await {
2026-04-07 13:55:40 +08:00
Ok(_) => {
if config.debug_mode {
debug!("WebSocket connection closed normally");
}
// 即使正常关闭,也尝试重连(如果启用)
if config.reconnect && *is_running.lock().await {
let delay = *reconnect_delay_ms.lock().await;
let attempts = *reconnect_attempts.lock().await + 1;
*reconnect_attempts.lock().await = attempts;
if config.debug_mode {
debug!("Reconnecting in {}ms (attempt {})", delay, attempts);
}
// Trigger reconnecting callback
if let Some(ref callback) = on_reconnecting {
2026-04-07 16:09:34 +08:00
callback(attempts, client_url.clone()).await;
}
*status.lock().await = ConnectionStatus::Reconnecting;
tokio::time::sleep(Duration::from_millis(delay)).await;
// Exponential backoff
let new_delay = std::cmp::min(delay * 2, config.max_reconnect_delay_ms);
*reconnect_delay_ms.lock().await = new_delay;
continue;
} else {
break;
}
2026-04-07 13:55:40 +08:00
}
Err(e) => {
if config.debug_mode {
warn!("WebSocket error: {:?}", e);
}
if let Some(ref callback) = on_error {
callback(e.to_string());
}
if config.reconnect && *is_running.lock().await {
let delay = *reconnect_delay_ms.lock().await;
let attempts = *reconnect_attempts.lock().await + 1;
*reconnect_attempts.lock().await = attempts;
if config.debug_mode {
debug!("Reconnecting in {}ms (attempt {})", delay, attempts);
}
// Trigger reconnecting callback (allows app to update URL, reload config, etc.)
if let Some(ref callback) = on_reconnecting {
2026-04-07 16:09:34 +08:00
callback(attempts, client_url.clone()).await;
}
2026-04-07 13:55:40 +08:00
*status.lock().await = ConnectionStatus::Reconnecting;
tokio::time::sleep(Duration::from_millis(delay)).await;
// Exponential backoff
let new_delay = std::cmp::min(delay * 2, config.max_reconnect_delay_ms);
*reconnect_delay_ms.lock().await = new_delay;
continue;
} else {
break;
}
}
}
}
*status.lock().await = ConnectionStatus::Disconnected;
if let Some(callback) = on_disconnected {
callback();
}
}
/// Connect and handle messages
async fn connect_and_handle(
url: &str,
config: &WebSocketConfig,
receiver: &mut mpsc::Receiver<OutgoingMessage>,
_queue: &Arc<Mutex<Vec<OutgoingMessage>>>,
on_connected: &Option<ConnectedCallback>,
on_message: &Option<MessageCallback>,
on_binary: &Option<BinaryCallback>,
on_error: &Option<ErrorCallback>,
is_first: bool,
on_first_connect: &Option<FirstConnectCallback>,
2026-04-07 16:09:34 +08:00
on_reconnected: &Option<ReconnectedCallback>,
sender: Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>,
message_sender: MessageSender,
2026-04-07 13:55:40 +08:00
) -> Result<(), WebSocketError> {
use http::header::{HeaderName, HeaderValue};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
// Build request with headers
let mut request = url.into_client_request()?;
// Add client type header if specified
if let Some(ref client_type) = config.client_type {
request.headers_mut().insert(
HeaderName::from_static("clienttype"),
HeaderValue::from_str(client_type)
.map_err(|e| WebSocketError::HeaderError(e.to_string()))?,
);
}
// Add custom headers
for (key, value) in &config.custom_headers {
if let (Ok(name), Ok(val)) = (
HeaderName::try_from(key.as_str()),
HeaderValue::from_str(value),
) {
request.headers_mut().insert(name, val);
}
}
// Connect with timeout
let ws_stream = tokio::time::timeout(
Duration::from_millis(config.connect_timeout_ms),
tokio_tungstenite::connect_async(request),
)
.await
.map_err(|_| WebSocketError::ConnectionTimeout)?
.map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
let (mut write, mut read) = ws_stream.0.split();
if config.debug_mode {
debug!("Connected to WebSocket server");
}
// Call on_connected callback
if let Some(ref callback) = on_connected {
callback(url.to_string());
}
// Call on_first_connect callback (only on first connection, not on reconnect)
if is_first {
if let Some(ref callback) = on_first_connect {
2026-04-07 15:44:30 +08:00
callback(url.to_string(), sender).await;
}
2026-04-07 16:09:34 +08:00
} else {
// Call on_reconnected callback (after successful reconnection)
if let Some(ref callback) = on_reconnected {
callback(url.to_string(), sender).await;
}
}
2026-04-07 13:55:40 +08:00
// Start heartbeat task if enabled
let heartbeat_handle = if config.heartbeat_interval_ms > 0 {
let interval_duration = Duration::from_millis(config.heartbeat_interval_ms);
Some(tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval_duration);
loop {
ticker.tick().await;
// Heartbeat will be sent via the receiver
}
}))
} else {
None
};
// Handle messages
loop {
tokio::select! {
// Incoming message from server
msg = read.next() => {
match msg {
Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => {
if config.debug_mode {
debug!("Received text: {}", text);
}
// Parse and extract type
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
let msg_type = parsed.get("Type")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
if let Some(ref callback) = on_message {
callback(msg_type, parsed, message_sender.clone());
2026-04-07 13:55:40 +08:00
}
} else if let Some(ref callback) = on_message {
callback("raw".to_string(), serde_json::json!(text), message_sender.clone());
2026-04-07 13:55:40 +08:00
}
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {
if config.debug_mode {
debug!("Received binary: {} bytes", data.len());
}
if let Some(ref callback) = on_binary {
callback(data);
}
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) => {
if config.debug_mode {
debug!("Server initiated close");
}
break;
}
Some(Err(e)) => {
if config.debug_mode {
warn!("WebSocket read error: {}", e);
}
if let Some(ref callback) = on_error {
callback(e.to_string());
}
break;
}
None => {
if config.debug_mode {
debug!("WebSocket stream ended");
}
break;
}
_ => {}
}
}
// Outgoing message
outgoing = receiver.recv() => {
match outgoing {
Some(OutgoingMessage::Text(text)) => {
if config.debug_mode {
debug!("Sending text: {}", text);
}
write.send(tokio_tungstenite::tungstenite::Message::Text(text))
.await
.map_err(|e| WebSocketError::SendFailed(e.to_string()))?;
}
Some(OutgoingMessage::Binary(data)) => {
if config.debug_mode {
debug!("Sending binary: {} bytes", data.len());
}
write.send(tokio_tungstenite::tungstenite::Message::Binary(data))
.await
.map_err(|e| WebSocketError::SendFailed(e.to_string()))?;
}
Some(OutgoingMessage::Close) => {
if config.debug_mode {
debug!("Initiating close");
}
write.close().await.ok();
break;
}
None => {
break;
}
}
}
}
}
// Wait for heartbeat handle
if let Some(handle) = heartbeat_handle {
handle.abort();
}
Ok(())
}
/// Stop the connection loop and mark the client `Disconnected`.
///
/// Steps, in order:
/// 1. Clear the running flag, so no further reconnect attempts start.
/// 2. Remove the outgoing channel and ask the connection task (through it) to close the socket.
/// 3. Abort a tracked connection task, if any.
/// 4. Set the status via `set_status`, which fires `on_status_changed` when the status
///    actually changes.
pub async fn disconnect(&mut self) {
    *self.is_running.lock().await = false;
    // Bind the taken sender first so the slot lock is released before awaiting the send.
    // With `if let ... lock().await.take()` the guard would live for the whole body.
    let tx = self.sender.lock().await.take();
    if let Some(tx) = tx {
        // Fails immediately if the task is already gone; nothing left to close then.
        let _ = tx.send(OutgoingMessage::Close).await;
    }
    let task = self.task_handle.lock().await.take();
    if let Some(task) = task {
        task.abort();
    }
    self.set_status(ConnectionStatus::Disconnected).await;
}
// ==================== Message Sending ====================

/// Wrap `data` in a `WebSocketMessage` of type `msg_type`, serialize it, and send or queue
/// the JSON via `send_raw_text`.
///
/// # Errors
/// Serialization errors from `to_json_string`, plus anything `send_raw_text` returns.
pub async fn send(&self, msg_type: &str, data: Value) -> Result<(), WebSocketError> {
    let json = WebSocketMessage::new(msg_type, data).to_json_string()?;
    self.send_raw_text(json).await
}
/// Send raw text
pub async fn send_raw_text(&self, text: String) -> Result<(), WebSocketError> {
let status = *self.status.lock().await;
if status == ConnectionStatus::Connected {
if let Some(ref sender) = *self.sender.lock().await {
sender.send(OutgoingMessage::Text(text))
.await
.map_err(|_| WebSocketError::NotConnected)?;
}
} else {
// Queue the message
let mut queue = self.message_queue.lock().await;
if queue.len() < self.config.max_queue_size {
if self.config.debug_mode {
debug!("Queueing message (queue size: {})", queue.len() + 1);
}
queue.push(OutgoingMessage::Text(text));
} else {
if self.config.debug_mode {
warn!("Queue full, message dropped");
}
return Err(WebSocketError::QueueFull);
}
}
Ok(())
}
/// Send binary data
pub async fn send_binary(&self, data: Vec<u8>) -> Result<(), WebSocketError> {
let status = *self.status.lock().await;
if status == ConnectionStatus::Connected {
if let Some(ref sender) = *self.sender.lock().await {
sender.send(OutgoingMessage::Binary(data))
.await
.map_err(|_| WebSocketError::NotConnected)?;
}
} else {
let mut queue = self.message_queue.lock().await;
if queue.len() < self.config.max_queue_size {
if self.config.debug_mode {
debug!("Queueing binary message");
}
queue.push(OutgoingMessage::Binary(data));
} else {
return Err(WebSocketError::QueueFull);
}
}
Ok(())
}
// ==================== Status & Queue ====================

/// Snapshot of the current connection status.
pub async fn get_status(&self) -> ConnectionStatus {
    let guard = self.status.lock().await;
    *guard
}

/// `true` when the current status is `Connected`.
pub async fn is_connected(&self) -> bool {
    self.get_status().await == ConnectionStatus::Connected
}

/// Number of messages currently buffered in the outgoing queue.
pub async fn get_queue_size(&self) -> usize {
    let queue = self.message_queue.lock().await;
    queue.len()
}
/// Flush queued messages
pub async fn flush_queue(&self) -> Result<(), WebSocketError> {
let status = *self.status.lock().await;
if status != ConnectionStatus::Connected {
return Err(WebSocketError::NotConnected);
}
let mut queue = self.message_queue.lock().await;
if let Some(ref sender) = *self.sender.lock().await {
while let Some(msg) = queue.pop() {
sender.send(msg)
.await
.map_err(|_| WebSocketError::NotConnected)?;
}
}
Ok(())
}
// ==================== Private Helpers ====================

/// Store `new_status`. If it differs from the previous status:
/// - on a transition to `Connected`, reset the reconnect counters;
/// - notify `on_status_changed` with `(old, new)`.
async fn set_status(&self, new_status: ConnectionStatus) {
    // Read and write under a single lock acquisition, so the transition reported to the
    // callback is exactly the one that was applied.
    let old_status = std::mem::replace(&mut *self.status.lock().await, new_status);
    if old_status == new_status {
        return;
    }
    // Reset reconnect state on successful connection
    if new_status == ConnectionStatus::Connected {
        *self.reconnect_attempts.lock().await = 0;
        *self.reconnect_delay_ms.lock().await = self.config.reconnect_delay_ms;
    }
    if let Some(ref callback) = self.on_status_changed {
        callback(old_status, new_status);
    }
}
}
impl Drop for WebSocketClient {
fn drop(&mut self) {
// 不在这里调用异步 disconnect(),因为:
// 1. 无法在 Drop 中安全地使用 block_on (会导致 "Cannot start a runtime from within a runtime")
// 2. 调用者应该显式调用 disconnect() 方法
// 3. Tokio 的 JoinHandle 会在 task 结束时自动清理资源
// 如果 client 未被显式 disconnect() 就被 drop资源会通过 JoinHandle 的 abort 自然清理
2026-04-07 13:55:40 +08:00
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls on `WebSocketConfig` are reflected in the resulting fields.
    #[test]
    fn test_config_builder() {
        let cfg = WebSocketConfig::new("ws://localhost:8080")
            .with_client_type("Client")
            .with_debug(true)
            .with_reconnect(false);

        assert_eq!(cfg.ws_url, "ws://localhost:8080");
        assert_eq!(cfg.client_type.as_deref(), Some("Client"));
        assert!(cfg.debug_mode);
        assert!(!cfg.reconnect);
    }

    /// A freshly constructed client starts out disconnected.
    #[tokio::test]
    async fn test_client_creation() {
        let client = WebSocketClient::new(WebSocketConfig::default());
        assert_eq!(client.get_status().await, ConnectionStatus::Disconnected);
    }
}