新增 MessageSender 类型(on_message 回调的第三个参数),可在回调中同步发送消息

This commit is contained in:
zqm
2026-04-08 10:50:49 +08:00
parent 35ba667583
commit e439084286
3 changed files with 37 additions and 7 deletions

View File

@@ -30,4 +30,4 @@
pub mod websocket;
pub use websocket::{WebSocketClient, WebSocketConfig, WebSocketMessage, OutgoingMessage, ConnectionStatus, ReconnectedCallback};
pub use websocket::{WebSocketClient, WebSocketConfig, WebSocketMessage, OutgoingMessage, ConnectionStatus, ReconnectedCallback, MessageSender};

View File

@@ -51,7 +51,21 @@ impl std::fmt::Display for ConnectionStatus {
/// Event callback types (wrapped in Arc for cloneability)
pub type ConnectedCallback = Arc<dyn Fn(String) + Send + Sync>;
pub type DisconnectedCallback = Arc<dyn Fn() + Send + Sync>;
pub type MessageCallback = Arc<dyn Fn(String, Value) + Send + Sync>;
/// Callback for text messages - args: (msg_type, parsed_data, sender)
/// sender can be used to send messages back synchronously
#[derive(Clone)]
pub struct MessageSender(std::sync::Arc<tokio::sync::Mutex<Option<mpsc::Sender<OutgoingMessage>>>>);
impl MessageSender {
/// Send a text message synchronously (from within the sync callback)
pub fn send(&self, text: String) {
if let Ok(guard) = self.0.try_lock() {
if let Some(ref tx) = *guard {
let _ = tx.try_send(OutgoingMessage::Text(text));
}
}
}
}
pub type MessageCallback = Arc<dyn Fn(String, Value, MessageSender) + Send + Sync>;
pub type BinaryCallback = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
pub type ErrorCallback = Arc<dyn Fn(String) + Send + Sync>;
pub type StatusCallback = Arc<dyn Fn(ConnectionStatus, ConnectionStatus) + Send + Sync>;
@@ -153,9 +167,11 @@ impl WebSocketClient {
}
/// Set callback for text message received
/// The callback receives: (msg_type, parsed_data, sender)
/// The sender can be used to send messages back synchronously
pub fn on_message<F>(&mut self, callback: F) -> &mut Self
where
F: Fn(String, Value) + Send + Sync + 'static,
F: Fn(String, Value, MessageSender) + Send + Sync + 'static,
{
self.on_message = Some(Arc::new(callback));
self
@@ -239,6 +255,17 @@ impl WebSocketClient {
*self.url.blocking_lock() = new_url;
}
/// Send a text message through the WebSocket (can be called from any context)
/// Returns Ok(()) if the message was sent, Err(msg) if not connected
pub async fn send_text(&self, text: String) -> Result<(), String> {
let sender = self.sender.lock().await;
if let Some(ref tx) = *sender {
tx.send(OutgoingMessage::Text(text)).await.map_err(|e| e.to_string())
} else {
Err("Not connected".to_string())
}
}
// ==================== Connection Management ====================
/// Connect to the WebSocket server (uses dynamic URL that can be updated)
@@ -354,8 +381,10 @@ impl WebSocketClient {
// Clone sender for first_connect callback
let sender_clone = Arc::clone(&sender);
// Create MessageSender for on_message callback (same underlying sender)
let message_sender = MessageSender(Arc::clone(&sender));
match Self::connect_and_handle(&current_url, &config, &mut receiver, &queue, &on_connected, &on_message, &on_binary, &on_error, is_first, &on_first_connect, &on_reconnected, sender_clone).await {
match Self::connect_and_handle(&current_url, &config, &mut receiver, &queue, &on_connected, &on_message, &on_binary, &on_error, is_first, &on_first_connect, &on_reconnected, sender_clone, message_sender).await {
Ok(_) => {
if config.debug_mode {
debug!("WebSocket connection closed normally");
@@ -445,6 +474,7 @@ impl WebSocketClient {
on_first_connect: &Option<FirstConnectCallback>,
on_reconnected: &Option<ReconnectedCallback>,
sender: Arc<Mutex<Option<mpsc::Sender<OutgoingMessage>>>>,
message_sender: MessageSender,
) -> Result<(), WebSocketError> {
use http::header::{HeaderName, HeaderValue};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
@@ -536,10 +566,10 @@ impl WebSocketClient {
.to_string();
if let Some(ref callback) = on_message {
callback(msg_type, parsed);
callback(msg_type, parsed, message_sender.clone());
}
} else if let Some(ref callback) = on_message {
callback("raw".to_string(), serde_json::json!(text));
callback("raw".to_string(), serde_json::json!(text), message_sender.clone());
}
}
Some(Ok(tokio_tungstenite::tungstenite::Message::Binary(data))) => {

View File

@@ -12,7 +12,7 @@ mod config;
mod message;
mod error;
pub use client::{WebSocketClient, ConnectionStatus, ReconnectingCallback, ReconnectedCallback, OutgoingMessage};
pub use client::{WebSocketClient, ConnectionStatus, ReconnectingCallback, ReconnectedCallback, OutgoingMessage, MessageSender};
pub use config::WebSocketConfig;
pub use message::{WebSocketMessage, ClientType};
pub use error::WebSocketError;