From d8f7335053d151c3db90d7f07a752688c3c7bdfd Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 16 Jun 2022 10:24:32 +0800 Subject: [PATCH] fix: range request (#44) close #43 --- Cargo.lock | 22 +++++++ Cargo.toml | 3 +- src/main.rs | 1 + src/server.rs | 151 ++++++++++++++++++++++++++++-------------------- src/streamer.rs | 68 ++++++++++++++++++++++ tests/range.rs | 45 +++++++++++++++ 6 files changed, 226 insertions(+), 64 deletions(-) create mode 100644 src/streamer.rs create mode 100644 tests/range.rs diff --git a/Cargo.lock b/Cargo.lock index 7738f23..b4acd97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -188,6 +188,27 @@ dependencies = [ "wasm-bindgen-futures", ] +[[package]] +name = "async-stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-task" version = "4.2.0" @@ -554,6 +575,7 @@ version = "0.17.0" dependencies = [ "assert_cmd", "assert_fs", + "async-stream", "async-walkdir", "async_zip", "base64", diff --git a/Cargo.toml b/Cargo.toml index 6d16651..8670894 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ clap = { version = "3", default-features = false, features = ["std"] } chrono = "0.4" tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "signal"]} tokio-rustls = "0.23" -tokio-util = { version = "0.7", features = ["codec", "io-util"] } +tokio-util = { version = "0.7", features = ["io-util"] } hyper = { version = "0.14", features = ["http1", "server", "tcp", "stream"] } percent-encoding = "2.1" serde = { version = "1", features = ["derive"] } @@ -37,6 +37,7 @@ xml-rs = "0.8" env_logger = { version = "0.9", default-features = false, features = ["humantime"] } log = "0.4" socket2 = "0.4" +async-stream = "0.3" [dev-dependencies] assert_cmd = "2" diff --git a/src/main.rs b/src/main.rs index a44acd9..30d4ac8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod args; mod auth; mod server; +mod streamer; mod tls; #[macro_use] diff --git a/src/server.rs b/src/server.rs index 2aee843..23b899b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,5 @@ use crate::auth::{generate_www_auth, valid_digest}; +use crate::streamer::Streamer; use crate::{Args, BoxResult}; use xml::escape::escape_str_pcdata; @@ -10,26 +11,26 @@ use futures::stream::StreamExt; use futures::TryStreamExt; use headers::{ AcceptRanges, AccessControlAllowCredentials, AccessControlAllowHeaders, - AccessControlAllowOrigin, Connection, ContentLength, ContentRange, ContentType, ETag, - HeaderMap, HeaderMapExt, IfModifiedSince, IfNoneMatch, IfRange, LastModified, Range, + AccessControlAllowOrigin, Connection, ContentLength, ContentType, ETag, HeaderMap, + HeaderMapExt, IfModifiedSince, IfNoneMatch, IfRange, LastModified, Range, }; use hyper::header::{ - HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_TYPE, ORIGIN, RANGE, - WWW_AUTHENTICATE, + HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_RANGE, + CONTENT_TYPE, ORIGIN, RANGE, WWW_AUTHENTICATE, }; use hyper::{Body, Method, StatusCode, Uri}; use percent_encoding::percent_decode; use serde::Serialize; use std::fs::Metadata; +use std::io::SeekFrom; use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::SystemTime; use tokio::fs::File; -use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWrite}; +use tokio::io::{AsyncSeekExt, AsyncWrite}; use tokio::{fs, io}; -use tokio_util::codec::{BytesCodec, FramedRead}; -use tokio_util::io::{ReaderStream, StreamReader}; +use tokio_util::io::StreamReader; use uuid::Uuid; pub type Request = hyper::Request; @@ -40,7 +41,7 @@ const INDEX_CSS: &str = include_str!("../assets/index.css"); const INDEX_JS: &str = include_str!("../assets/index.js"); const FAVICON_ICO: &[u8] = include_bytes!("../assets/favicon.ico"); const INDEX_NAME: &str = "index.html"; -const BUF_SIZE: usize = 1024 * 16; +const BUF_SIZE: usize = 65536; pub struct Server { args: Arc, @@ -353,8 +354,8 @@ impl Server { error!("Failed to zip {}, {}", path.display(), e); } }); - let stream = ReaderStream::new(reader); - *res.body_mut() = Body::wrap_stream(stream); + let reader = Streamer::new(reader, BUF_SIZE); + *res.body_mut() = Body::wrap_stream(reader.into_stream()); Ok(()) } @@ -425,7 +426,7 @@ impl Server { ) -> BoxResult<()> { let (file, meta) = tokio::join!(fs::File::open(path), fs::metadata(path),); let (mut file, meta) = (file?, meta?); - let mut maybe_range = true; + let mut use_range = true; if let Some((etag, last_modified)) = extract_cache_headers(&meta) { let cached = { if let Some(if_none_match) = headers.typed_get::() { @@ -436,55 +437,77 @@ impl Server { false } }; - res.headers_mut().typed_insert(last_modified); - res.headers_mut().typed_insert(etag.clone()); if cached { *res.status_mut() = StatusCode::NOT_MODIFIED; return Ok(()); } + + res.headers_mut().typed_insert(last_modified); + res.headers_mut().typed_insert(etag.clone()); + if headers.typed_get::().is_some() { - maybe_range = headers + use_range = headers .typed_get::() .map(|if_range| !if_range.is_modified(Some(&etag), Some(&last_modified))) // Always be fresh if there is no validators .unwrap_or(true); } else { - maybe_range = false; + use_range = false; } } - let file_range = if maybe_range { - if let Some(content_range) = headers - .typed_get::() - .and_then(|range| to_content_range(&range, meta.len())) - { - res.headers_mut().typed_insert(content_range.clone()); - *res.status_mut() = StatusCode::PARTIAL_CONTENT; - content_range.bytes_range() - } else { - None - } + + let range = if use_range { + parse_range(headers) } else { None }; + if let Some(mime) = mime_guess::from_path(&path).first() { res.headers_mut().typed_insert(ContentType::from(mime)); + } else { + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); } + res.headers_mut().typed_insert(AcceptRanges::bytes()); - res.headers_mut() - .typed_insert(ContentLength(meta.len() as u64)); - if head_only { - return Ok(()); - } - let body = if let Some((begin, end)) = file_range { - file.seek(io::SeekFrom::Start(begin)).await?; - let stream = FramedRead::new(file.take(end - begin + 1), BytesCodec::new()); - Body::wrap_stream(stream) + let size = meta.len(); + + if let Some(range) = range { + if range + .end + .map_or_else(|| range.start < size, |v| v >= range.start) + && file.seek(SeekFrom::Start(range.start)).await.is_ok() + { + let end = range.end.unwrap_or(size - 1).min(size - 1); + let part_size = end - range.start + 1; + let reader = Streamer::new(file, BUF_SIZE); + *res.status_mut() = StatusCode::PARTIAL_CONTENT; + let content_range = format!("bytes {}-{}/{}", range.start, end, size); + res.headers_mut() + .insert(CONTENT_RANGE, content_range.parse().unwrap()); + res.headers_mut() + .insert(CONTENT_LENGTH, format!("{}", part_size).parse().unwrap()); + if head_only { + return Ok(()); + } + *res.body_mut() = Body::wrap_stream(reader.into_stream_sized(part_size)); + } else { + *res.status_mut() = StatusCode::RANGE_NOT_SATISFIABLE; + res.headers_mut() + .insert(CONTENT_RANGE, format!("bytes */{}", size).parse().unwrap()); + } } else { - let stream = FramedRead::new(file, BytesCodec::new()); - Body::wrap_stream(stream) - }; - *res.body_mut() = body; + res.headers_mut() + .insert(CONTENT_LENGTH, format!("{}", size).parse().unwrap()); + if head_only { + return Ok(()); + } + let reader = Streamer::new(file, BUF_SIZE); + *res.body_mut() = Body::wrap_stream(reader.into_stream()); + } Ok(()) } @@ -965,32 +988,34 @@ fn extract_cache_headers(meta: &Metadata) -> Option<(ETag, LastModified)> { Some((etag, last_modified)) } -fn to_content_range(range: &Range, complete_length: u64) -> Option { - use core::ops::Bound::{Included, Unbounded}; - let mut iter = range.iter(); - let bounds = iter.next(); +#[derive(Debug)] +struct RangeValue { + start: u64, + end: Option, +} - if iter.next().is_some() { - // Found multiple byte-range-spec. Drop. - return None; +fn parse_range(headers: &HeaderMap) -> Option { + let range_hdr = headers.get(RANGE)?; + let hdr = range_hdr.to_str().ok()?; + let mut sp = hdr.splitn(2, '='); + let units = sp.next().unwrap(); + if units == "bytes" { + let range = sp.next()?; + let mut sp_range = range.splitn(2, '-'); + let start: u64 = sp_range.next().unwrap().parse().ok()?; + let end: Option = if let Some(end) = sp_range.next() { + if end.is_empty() { + None + } else { + Some(end.parse().ok()?) + } + } else { + None + }; + Some(RangeValue { start, end }) + } else { + None } - - bounds.and_then(|b| match b { - (Included(start), Included(end)) if start <= end && start < complete_length => { - ContentRange::bytes( - start..=end.min(complete_length.saturating_sub(1)), - complete_length, - ) - .ok() - } - (Included(start), Unbounded) if start < complete_length => { - ContentRange::bytes(start.., complete_length).ok() - } - (Unbounded, Included(end)) if end > 0 => { - ContentRange::bytes(complete_length.saturating_sub(end).., complete_length).ok() - } - _ => None, - }) } fn encode_uri(v: &str) -> String { diff --git a/src/streamer.rs b/src/streamer.rs new file mode 100644 index 0000000..163b36f --- /dev/null +++ b/src/streamer.rs @@ -0,0 +1,68 @@ +use async_stream::stream; +use futures::{Stream, StreamExt}; +use std::io::Error; +use std::pin::Pin; +use tokio::io::{AsyncRead, AsyncReadExt}; + +pub struct Streamer +where + R: AsyncRead + Unpin + Send + 'static, +{ + reader: R, + buf_size: usize, +} + +impl Streamer +where + R: AsyncRead + Unpin + Send + 'static, +{ + #[inline] + pub fn new(reader: R, buf_size: usize) -> Self { + Self { reader, buf_size } + } + pub fn into_stream( + mut self, + ) -> Pin, Error>> + 'static>> { + let stream = stream! { + loop { + let mut buf = vec![0; self.buf_size]; + let r = self.reader.read(&mut buf).await?; + if r == 0 { + break + } + buf.truncate(r); + yield Ok(buf); + } + }; + stream.boxed() + } + // allow truncation as truncated remaining is always less than buf_size: usize + pub fn into_stream_sized( + mut self, + max_length: u64, + ) -> Pin, Error>> + 'static>> { + let stream = stream! { + let mut remaining = max_length; + loop { + if remaining == 0 { + break; + } + let bs = if remaining >= self.buf_size as u64 { + self.buf_size + } else { + remaining as usize + }; + let mut buf = vec![0; bs]; + let r = self.reader.read(&mut buf).await?; + if r == 0 { + break; + } else { + buf.truncate(r); + yield Ok(buf); + } + remaining -= r as u64; + } + }; + stream.boxed() + } +} diff --git a/tests/range.rs b/tests/range.rs new file mode 100644 index 0000000..a2c9c50 --- /dev/null +++ b/tests/range.rs @@ -0,0 +1,45 @@ +mod fixtures; +mod utils; + +use fixtures::{server, Error, TestServer}; +use headers::HeaderValue; +use rstest::rstest; + +#[rstest] +fn get_file_range(server: TestServer) -> Result<(), Error> { + let resp = fetch!(b"GET", format!("{}index.html", server.url())) + .header("range", HeaderValue::from_static("bytes=0-6")) + .send()?; + assert_eq!(resp.status(), 206); + assert_eq!(resp.headers().get("content-range").unwrap(), "bytes 0-6/18"); + assert_eq!(resp.headers().get("accept-ranges").unwrap(), "bytes"); + assert_eq!(resp.headers().get("content-length").unwrap(), "7"); + assert_eq!(resp.text()?, "This is"); + Ok(()) +} + +#[rstest] +fn get_file_range_beyond(server: TestServer) -> Result<(), Error> { + let resp = fetch!(b"GET", format!("{}index.html", server.url())) + .header("range", HeaderValue::from_static("bytes=12-20")) + .send()?; + assert_eq!(resp.status(), 206); + assert_eq!( + resp.headers().get("content-range").unwrap(), + "bytes 12-17/18" + ); + assert_eq!(resp.headers().get("accept-ranges").unwrap(), "bytes"); + assert_eq!(resp.headers().get("content-length").unwrap(), "6"); + assert_eq!(resp.text()?, "x.html"); + Ok(()) +} + +#[rstest] +fn get_file_range_invalid(server: TestServer) -> Result<(), Error> { + let resp = fetch!(b"GET", format!("{}index.html", server.url())) + .header("range", HeaderValue::from_static("bytes=20-")) + .send()?; + assert_eq!(resp.status(), 416); + assert_eq!(resp.headers().get("content-range").unwrap(), "bytes */18"); + Ok(()) +}