支持断点续传

This commit is contained in:
zqm
2026-04-08 14:39:36 +08:00
parent acfc82f04b
commit 8df4f6c473
3 changed files with 112 additions and 6 deletions

View File

@@ -10,6 +10,7 @@ dependencies = [
"chrono",
"cube_lib",
"dirs",
"md5",
"serde",
"serde_json",
"tokio",
@@ -609,6 +610,12 @@ dependencies = [
"regex-automata",
]
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "memchr"
version = "2.8.0"

View File

@@ -11,6 +11,7 @@ tokio = { version = "1.37", features = ["full"] }
windows = { version = "0.56", features = ["Win32_System_Console"] }
chrono = "0.4"
base64 = "0.22"
md5 = "0.7"
# Local CubeLib for WebSocket
cube_lib = { path = "../../../Rust/CubeLib" }

View File

@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
use std::io::Read;
use std::fs::{self, File};
use std::io::Write as IoWrite;
use std::path::PathBuf;
@@ -119,6 +120,51 @@ fn version_less_than(a: &str, b: &str) -> bool {
false
}
/// 计算本地文件前 N 字节的 MD5 hash字节数为 0 表示全部)
fn compute_file_hash(filename: &str, bytes: u64, _debug: bool) -> Option<String> {
let file_path = get_updater_data_dir().join(filename);
if !file_path.exists() {
return None;
}
let mut file = File::open(&file_path).ok()?;
let file_size = file.metadata().ok()?.len();
// 如果 bytes=0 或超过文件大小,则计算整个文件
let read_bytes = if bytes == 0 || bytes > file_size {
file_size
} else {
bytes
};
// 读取前 read_bytes 字节
let mut buffer = vec![0u8; read_bytes as usize];
file.read_exact(&mut buffer).ok()?;
let hash = md5::compute(&buffer);
Some(format!("{:x}", hash))
}
/// 获取临时文件的当前大小(字节数)
fn get_tmp_file_size(filename: &str) -> u64 {
let tmp_path = get_updater_data_dir().join(format!("{}.tmp", filename));
tmp_path.metadata().map(|m| m.len()).unwrap_or(0)
}
/// 发送 GetFileMd5 请求
fn request_file_md5(sender: &cube_lib::websocket::MessageSender, filename: &str, bytes: u64, debug: bool) {
let msg_str = format!(
r#"{{"Type":"GetFileMd5","Data":{{"filename":"{}","bytes":{}}}}}"#,
filename, bytes
);
if debug {
let ts = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
println!("{} 发送消息:{}", ts, msg_str);
}
sender.send(msg_str);
}
/// 发送文件下载请求(断点续传)
fn request_download(sender: &cube_lib::websocket::MessageSender, filename: &str, offset: u64, debug: bool) {
let msg_str = format!(
@@ -440,26 +486,78 @@ async fn run_updater(debug_mode: bool) {
for (filename, server_ver) in file_versions {
let server_version = server_ver.as_str().unwrap_or("0.0.0");
let local_version = get_local_file_version(filename);
let tmp_size = get_tmp_file_size(filename);
if debug_msg {
let ts = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
println!("{} [版本] {}: 服务端={}, 本地={}", ts, filename, server_version, local_version);
println!("{} [版本] {}: 服务端={}, 本地={}, tmp大小={}",
ts, filename, server_version, local_version, tmp_size);
}
// 比较版本:如果本地不存在或比服务端旧,则下载
let need_update = local_version == "0.0.0" || version_less_than(&local_version, server_version);
if need_update {
if debug_msg {
let ts = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
println!("{} [升级] {} 需要更新,开始下载...", ts, filename);
if tmp_size > 0 {
// 有临时文件,请求 hash 校验
if debug_msg {
let ts = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
println!("{} [续传] {} 发现未完成下载,请求 hash 校验...", ts, filename);
}
request_file_md5(&sender, filename, tmp_size, debug_msg);
} else {
// 无临时文件,从头下载
if debug_msg {
let ts = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
println!("{} [升级] {} 需要更新,开始下载...", ts, filename);
}
request_download(&sender, filename, 0, debug_msg);
}
// 从 offset 0 开始下载(断点续传逻辑在 handle_file_chunk 中处理)
request_download(&sender, filename, 0, debug_msg);
}
}
}
}
// 处理 Md5 响应
if msg_type == "Md5" {
if let Some(md5_data) = data.get("Data").and_then(|v| v.as_object()) {
let filename = md5_data.get("filename").and_then(|v| v.as_str()).unwrap_or("");
let server_md5 = md5_data.get("md5").and_then(|v| v.as_str()).unwrap_or("");
let bytes = md5_data.get("bytes").and_then(|v| v.as_u64()).unwrap_or(0);
let tmp_filename = format!("{}.tmp", filename);
// 计算本地 tmp 文件前 bytes 字节的 md5
let local_md5 = compute_file_hash(&tmp_filename, bytes, debug_msg);
if debug_msg {
let ts = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
if let Some(ref lm) = local_md5 {
println!("{} [续传] {} md5对比: 本地={}, 服务端={}", ts, filename, lm, server_md5);
} else {
println!("{} [续传] {} 无法计算本地md5重新下载", ts, filename);
}
}
// 比较 md5
if local_md5.as_deref() == Some(server_md5) {
// md5 相同,从 offset bytes 处续传
if debug_msg {
let ts = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
println!("{} [续传] {} md5匹配{} 字节处续传", ts, filename, bytes);
}
request_download(&sender, filename, bytes, debug_msg);
} else {
// md5 不同,删除临时文件,重新下载
if debug_msg {
let ts = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
println!("{} [续传] {} md5不匹配重新下载", ts, filename);
}
let tmp_path = get_updater_data_dir().join(&tmp_filename);
let _ = fs::remove_file(&tmp_path);
request_download(&sender, filename, 0, debug_msg);
}
}
}
// 处理文件块
if msg_type == "FileChunk" {
if let Some(data_obj) = data.get("Data").and_then(|v| v.as_object()) {