[Rust] Tăng tốc download sử dụng nhiều threads trong Rust
Bài toán download một file từ mạng cũng là bài toán quen thuộc. Trong bài này chúng ta sẽ tìm hiểu cách để download một files lớn bằng cách sử dụng nhiều threads để download từng phần và ghép lại. Điều này sẽ giúp giảm thời gian download.
Bài này phân tích từ https://github.com/aochagavia/toy-download-accelerator. Các kỹ năng có thể học được trong bài download này là:
- Chia một buffer thành các
buffernhỏ - Sử dụng
semaphoređể giới hạn số threads thực hiện đồng thời - Cách gửi một
http requestđể download một phần nhỏ của một files
Lấy thông tin kích thước của file cần download
- Sử dụng
reqwestlibrary đểlet http_client = reqwest::Client::new(); - Gửi bản tin
HEADđể lấy thông tin của file size bằng cách tìm kiếm Header làACCEPT_RANGESandCONTENT_LENGTH
use reqwest::header::{ACCEPT_RANGES, CONTENT_LENGTH, RANGE};
let http_client = reqwest::Client::new();
let head_response = http_client
.head(url)
.send()
.await
.context("HEAD request failed")?;
let Some(content_length) = head_response.headers().get(CONTENT_LENGTH) else {
bail!("HEAD response did not contain a Content-Length header");
};
Tạo các tasks để download từng phần của files
Từ một kích thước file ban đầu chúng ta sẽ cần phải phần
- Giả sử chúng ta cần tạo
CONCURRENT_REQUEST_LIMIT = 20tasks đồng thời - Với kích thước file
10.000.000 MB = 10GB. Do đó mỗi task sẽ download500MB
Các bước như sau:
- Tạo một buffer kiểu dữ liệu
BytesMutvới kích thước là10GB(không biết có tạo được ko???) - Thự viện này có hàm
split_to. Sử dụng thư viện này để tách thành buffer cho mỗi chunk.
pub fn split_to(&mut self, at: usize) -> BytesMut
Hàm này sẽ chia một buffer thành 2 buffer, với điểm chia taị vị trí at. Ta sẽ có 2 mảng
[0, at) và [at, end)
Ví dụ:
let mut a = BytesMut::from(&b"Hello world"[..]);
let mut b = a.split_to(5)
/// a[0] = b'!';
/// b[0] = b'j';
///
/// assert_eq!(&a[..], b"!world");
/// assert_eq!(&b[..], b"jello");
Chúng ta sẽ có &a[..] = " world" còn &b[..] = "Hello"
- Một điểm chú ý nữa là đối với chunk cuối cùng, có thể
chunk_size(kích thước mỗi chunk) nhỏ hơn kich
let this_chunk_size = chunk_size.min(buffer.len());
- Tạo một task
let task = tokio::spawn(async move {
let _permit = chunk_semaphore.acquire().await?;
let start = Instant::now();
let range_start = chunk_number * chunk_size;
//TODO:
download_chunk();
let duration = start.elapsed();
let chunk_size_mb = buffer_slice_for_chunk.len() as f64 / 1024.0 / 1024.0;
println!("* Chunk {chunk_number} downloaded in {} ms (size = {:.2} MiB; throughput = {:.2} MiB/s)", duration.as_millis(), chunk_size_mb, chunk_size_mb / duration.as_secs_f64());
Ok::<_, anyhow::Error>(buffer_slice_for_chunk)
});
for task in download_tasks {
let buffer_slice = task
.await
.context("tokio task unexpectedly crashed")?
.context("chunk download failed")?;
buffer.unsplit(buffer_slice);
}
Trong mỗi task chúng ta sử dụng một semaphore. Đặc tính của semaphore là, khi khởi tạo cần khởi cần truyền một số thể hiện số max threads có thể cùng đồng thời xử lý
Mỗi khi chúng ta gọi hàm acquire(), giá trị biến max_threads trong sempahore này sẽ giảm đi một đơn vị. Đến khi giá trị này bằng 0, hàm acquire() sẽ bị block.
Chỉ khi có một thread thoát, biến max_threads sẽ được đếm lên một đơn vị. Đấy là cách chúng ta có thể sử dụng để giới hạn số task download cùng lúc.
use tokio::sync::Semaphore;
// Khởi tạo sempahore
let semaphore = Arc::new(Semaphore::new(CONCURRENT_REQUEST_LIMIT));
// Sử dụng semaphore
let _permit = chunk_semaphore.acquire().await?;
- Chúng ta tạo các
Taskvà đẩy vào một mảng. Chỉ khi chúng ta gọi hàm.awaitnó mới thực sự thực thi đoạn code trong khối lệnhasync {}. Kết quả của mỗi future này sẽ trả về mộtbuffercủa mỗi chunk. Chúng ta sử dụng hàmunsplitđể có thể gộp lại các phần nhỏ sau khi đã chia tách nó.
Cho ví dụ:
let mut buf = BytesMut::with_capacity(64);
buf.extend_from_slice(b"aaabbbcccddd");
let splitted = buf.split_off(6);
assert_eq!(b"aaabbb", &buf[..]);
assert_eq!(b"cccddd", &splitted[..]);
buf.unsplit(splitted);
assert_eq!(b"aaabbbcccddd", &buf[..]);
Thực hiện hàm download_chunk
- Để có thể dowload mỗi chunk với chính xác vị trí và range bytes download chúng ta cần một
httprequest như sau: - Cần có header
RANGEvới giá trị như saubytes=0-1000. Ví dụ như trên chúng ta sẽ download bytes từ0tới1000
async fn download_chunk(
http_client: &reqwest::Client,
url: &str,
buffer: &mut BytesMut,
range_start: u64,
) -> anyhow::Result<()>
// How to calculate
let range_end_inclusive = range_start + buffer.len()-1;
let range_header = format!("bytes={range_start}-{range_end_inclusive}");
Tóm lại chúng ta cần gửi một bản tin HTTP với nội dụng dạng như sau:
GET https://<api-need-to-download>
RANGE: bytes="0-122330"
Tính toán checksum
- Hàm checksum là hàm consume
buffervà trả về một chuỗi duy nhất. Chúng ta sử dụngsha256để tạo chuỗi này. Chỉ cần một bytes sai khác chúng ta cũng sẽ nhận chuỗi checksum khác nha
fn sha256(bytes: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(bytes);
let digest = hasher.finalize();
format!("{digest:x}")
}
// How to use
let checksum = sha256(&buffer);