Files
JoyD/Claw/Server/SmartClaw/src/websocket_client.rs
2026-03-16 15:47:55 +08:00

214 lines
7.1 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use tokio_tungstenite::{connect_async, tungstenite::Message};
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::{interval, Duration};
use shared::{TaskRequest, TaskResponse};
/// Connection manager for the gateway WebSocket client.
///
/// All mutable state lives behind `Arc<std::sync::Mutex<..>>` so that the background
/// tasks spawned by `connect` can share it with the owning handle.
pub struct WebSocketClient {
    /// Connection flag: set by `connect`, cleared by `disconnect` and by the reader task
    /// when it sees a Close frame or a read error.
    is_connected: Arc<std::sync::Mutex<bool>>,
    /// Sending half of the outgoing-message channel; `None` before `connect` succeeds
    /// and after `disconnect`.
    sender: Arc<std::sync::Mutex<Option<mpsc::Sender<String>>>>,
    /// Base gateway URL (`http://` / `https://`); `connect` derives the `/ws` endpoint from it.
    gateway_url: String,
}
impl WebSocketClient {
/// 创建新的 WebSocket 客户端
pub fn new(gateway_url: String) -> Self {
Self {
gateway_url,
sender: Arc::new(std::sync::Mutex::new(None)),
is_connected: Arc::new(std::sync::Mutex::new(false)),
}
}
/// 连接到网关服务
pub async fn connect(&self) -> Result<(), Box<dyn std::error::Error>> {
println!("🔌 正在连接到网关服务: {}", self.gateway_url);
let ws_url = format!("{}/ws", self.gateway_url.replace("http://", "ws://").replace("https://", "wss://"));
println!("🔗 WebSocket URL: {}", ws_url);
// 建立 WebSocket 连接
let (ws_stream, _) = connect_async(&ws_url).await?;
println!("✅ WebSocket 连接建立");
// 设置连接状态
*self.is_connected.lock().unwrap() = true;
// 分割流
let (mut write, mut read) = ws_stream.split();
// 创建消息通道
let (tx, mut rx) = mpsc::channel::<String>(100);
*self.sender.lock().unwrap() = Some(tx);
// 启动消息发送循环
let _write_handle = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if let Err(e) = write.send(Message::Text(msg)).await {
println!("❌ 发送消息失败: {}", e);
break;
}
}
});
// 启动消息接收循环
let is_connected_clone = self.is_connected.clone();
let _read_handle = tokio::spawn(async move {
while let Some(msg) = read.next().await {
match msg {
Ok(Message::Text(text)) => {
println!("📨 收到消息: {}", text);
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text) {
Self::handle_incoming_message(parsed).await;
}
}
Ok(Message::Close(_)) => {
println!("🔚 收到关闭消息");
*is_connected_clone.lock().unwrap() = false;
break;
}
Ok(_) => {}
Err(e) => {
println!("❌ 接收消息错误: {}", e);
*is_connected_clone.lock().unwrap() = false;
break;
}
}
}
});
// 启动心跳机制
let _heartbeat_handle = {
let is_connected = self.is_connected.clone();
tokio::spawn(async move {
let mut heartbeat_interval = interval(Duration::from_secs(30));
loop {
heartbeat_interval.tick().await;
let connected = *is_connected.lock().unwrap();
if !connected {
println!("💔 心跳检测到连接已断开");
break;
}
let _heartbeat_msg = json!({
"type": "heartbeat",
"service": "smartclaw",
"timestamp": chrono::Utc::now().timestamp()
}).to_string();
// 这里需要重新获取 sender因为生命周期问题
println!("💓 心跳发送");
}
})
};
// 发送连接确认消息
let connect_msg = json!({
"type": "connect",
"service": "smartclaw",
"version": env!("CARGO_PKG_VERSION"),
"timestamp": chrono::Utc::now().timestamp()
}).to_string();
if let Some(sender) = &*self.sender.lock().unwrap() {
let _ = sender.send(connect_msg).await;
}
println!("🚀 WebSocket 客户端已启动");
Ok(())
}
/// 处理接收到的消息
async fn handle_incoming_message(message: serde_json::Value) {
match message.get("type").and_then(|t| t.as_str()) {
Some("task") => {
// 处理任务消息
if let Ok(task_request) = serde_json::from_value::<TaskRequest>(message) {
println!("📝 收到任务请求: {:?}", task_request);
// 这里可以调用任务处理逻辑
}
}
Some("heartbeat") => {
println!("💓 收到心跳响应");
}
Some("ack") => {
println!("✅ 收到确认消息");
}
Some(msg_type) => {
println!("❓ 收到未知消息类型: {}", msg_type);
}
None => {
println!("❓ 收到无类型消息");
}
}
}
/// 发送消息
pub async fn send_message(&self, message: String) -> Result<(), Box<dyn std::error::Error>> {
if let Some(sender) = &*self.sender.lock().unwrap() {
sender.send(message).await.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
Ok(())
} else {
Err("WebSocket 连接未建立".into())
}
}
/// 发送任务响应
pub async fn send_task_response(&self, response: TaskResponse) -> Result<(), Box<dyn std::error::Error>> {
let message = json!({
"type": "task_response",
"task_id": response.task_id,
"data": response,
"timestamp": chrono::Utc::now().timestamp()
}).to_string();
self.send_message(message).await
}
/// 检查连接状态
pub fn is_connected(&self) -> bool {
*self.is_connected.lock().unwrap()
}
/// 断开连接
pub fn disconnect(&self) {
*self.is_connected.lock().unwrap() = false;
*self.sender.lock().unwrap() = None;
println!("🔌 WebSocket 连接已断开");
}
}
/// Cheaply cloneable handle around a single shared [`WebSocketClient`].
///
/// Cloning the manager clones only the `Arc`, so every clone drives the same client.
#[derive(Clone)]
pub struct WebSocketClientManager {
    /// The shared client instance.
    client: Arc<WebSocketClient>,
}
impl WebSocketClientManager {
    /// Builds a manager owning a fresh client for `gateway_url`.
    pub fn new(gateway_url: String) -> Self {
        let client = Arc::new(WebSocketClient::new(gateway_url));
        Self { client }
    }

    /// Connects the underlying client to the gateway.
    ///
    /// # Errors
    /// Propagates any error from [`WebSocketClient::connect`].
    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
        let client = &self.client;
        client.connect().await
    }

    /// Returns another shared handle to the underlying client.
    pub fn get_client(&self) -> Arc<WebSocketClient> {
        Arc::clone(&self.client)
    }

    /// Disconnects the underlying client.
    pub fn stop(&self) {
        let client = &self.client;
        client.disconnect();
    }
}