diff --git a/Cargo.lock b/Cargo.lock
index 7738f23..b4acd97 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -188,6 +188,27 @@ dependencies = [
"wasm-bindgen-futures",
]
+[[package]]
+name = "async-stream"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
+dependencies = [
+ "async-stream-impl",
+ "futures-core",
+]
+
+[[package]]
+name = "async-stream-impl"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
[[package]]
name = "async-task"
version = "4.2.0"
@@ -554,6 +575,7 @@ version = "0.17.0"
dependencies = [
"assert_cmd",
"assert_fs",
+ "async-stream",
"async-walkdir",
"async_zip",
"base64",
diff --git a/Cargo.toml b/Cargo.toml
index 6d16651..8670894 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,7 @@ clap = { version = "3", default-features = false, features = ["std"] }
chrono = "0.4"
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "signal"]}
tokio-rustls = "0.23"
-tokio-util = { version = "0.7", features = ["codec", "io-util"] }
+tokio-util = { version = "0.7", features = ["io-util"] }
hyper = { version = "0.14", features = ["http1", "server", "tcp", "stream"] }
percent-encoding = "2.1"
serde = { version = "1", features = ["derive"] }
@@ -37,6 +37,7 @@ xml-rs = "0.8"
env_logger = { version = "0.9", default-features = false, features = ["humantime"] }
log = "0.4"
socket2 = "0.4"
+async-stream = "0.3"
[dev-dependencies]
assert_cmd = "2"
diff --git a/src/main.rs b/src/main.rs
index a44acd9..30d4ac8 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,7 @@
mod args;
mod auth;
mod server;
+mod streamer;
mod tls;
#[macro_use]
diff --git a/src/server.rs b/src/server.rs
index 2aee843..23b899b 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,4 +1,5 @@
use crate::auth::{generate_www_auth, valid_digest};
+use crate::streamer::Streamer;
use crate::{Args, BoxResult};
use xml::escape::escape_str_pcdata;
@@ -10,26 +11,26 @@ use futures::stream::StreamExt;
use futures::TryStreamExt;
use headers::{
AcceptRanges, AccessControlAllowCredentials, AccessControlAllowHeaders,
- AccessControlAllowOrigin, Connection, ContentLength, ContentRange, ContentType, ETag,
- HeaderMap, HeaderMapExt, IfModifiedSince, IfNoneMatch, IfRange, LastModified, Range,
+ AccessControlAllowOrigin, Connection, ContentLength, ContentType, ETag, HeaderMap,
+ HeaderMapExt, IfModifiedSince, IfNoneMatch, IfRange, LastModified, Range,
};
use hyper::header::{
- HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_TYPE, ORIGIN, RANGE,
- WWW_AUTHENTICATE,
+ HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_RANGE,
+ CONTENT_TYPE, ORIGIN, RANGE, WWW_AUTHENTICATE,
};
use hyper::{Body, Method, StatusCode, Uri};
use percent_encoding::percent_decode;
use serde::Serialize;
use std::fs::Metadata;
+use std::io::SeekFrom;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::SystemTime;
use tokio::fs::File;
-use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWrite};
+use tokio::io::{AsyncSeekExt, AsyncWrite};
use tokio::{fs, io};
-use tokio_util::codec::{BytesCodec, FramedRead};
-use tokio_util::io::{ReaderStream, StreamReader};
+use tokio_util::io::StreamReader;
use uuid::Uuid;
pub type Request = hyper::Request
;
@@ -40,7 +41,7 @@ const INDEX_CSS: &str = include_str!("../assets/index.css");
const INDEX_JS: &str = include_str!("../assets/index.js");
const FAVICON_ICO: &[u8] = include_bytes!("../assets/favicon.ico");
const INDEX_NAME: &str = "index.html";
-const BUF_SIZE: usize = 1024 * 16;
+const BUF_SIZE: usize = 65536;
pub struct Server {
args: Arc,
@@ -353,8 +354,8 @@ impl Server {
error!("Failed to zip {}, {}", path.display(), e);
}
});
- let stream = ReaderStream::new(reader);
- *res.body_mut() = Body::wrap_stream(stream);
+ let reader = Streamer::new(reader, BUF_SIZE);
+ *res.body_mut() = Body::wrap_stream(reader.into_stream());
Ok(())
}
@@ -425,7 +426,7 @@ impl Server {
) -> BoxResult<()> {
let (file, meta) = tokio::join!(fs::File::open(path), fs::metadata(path),);
let (mut file, meta) = (file?, meta?);
- let mut maybe_range = true;
+ let mut use_range = true;
if let Some((etag, last_modified)) = extract_cache_headers(&meta) {
let cached = {
if let Some(if_none_match) = headers.typed_get::() {
@@ -436,55 +437,77 @@ impl Server {
false
}
};
- res.headers_mut().typed_insert(last_modified);
- res.headers_mut().typed_insert(etag.clone());
if cached {
*res.status_mut() = StatusCode::NOT_MODIFIED;
return Ok(());
}
+
+ res.headers_mut().typed_insert(last_modified);
+ res.headers_mut().typed_insert(etag.clone());
+
if headers.typed_get::().is_some() {
- maybe_range = headers
+ use_range = headers
.typed_get::()
.map(|if_range| !if_range.is_modified(Some(&etag), Some(&last_modified)))
// Always be fresh if there is no validators
.unwrap_or(true);
} else {
- maybe_range = false;
+ use_range = false;
}
}
- let file_range = if maybe_range {
- if let Some(content_range) = headers
- .typed_get::()
- .and_then(|range| to_content_range(&range, meta.len()))
- {
- res.headers_mut().typed_insert(content_range.clone());
- *res.status_mut() = StatusCode::PARTIAL_CONTENT;
- content_range.bytes_range()
- } else {
- None
- }
+
+ let range = if use_range {
+ parse_range(headers)
} else {
None
};
+
if let Some(mime) = mime_guess::from_path(&path).first() {
res.headers_mut().typed_insert(ContentType::from(mime));
+ } else {
+ res.headers_mut().insert(
+ CONTENT_TYPE,
+ HeaderValue::from_static("application/octet-stream"),
+ );
}
+
res.headers_mut().typed_insert(AcceptRanges::bytes());
- res.headers_mut()
- .typed_insert(ContentLength(meta.len() as u64));
- if head_only {
- return Ok(());
- }
- let body = if let Some((begin, end)) = file_range {
- file.seek(io::SeekFrom::Start(begin)).await?;
- let stream = FramedRead::new(file.take(end - begin + 1), BytesCodec::new());
- Body::wrap_stream(stream)
+ let size = meta.len();
+
+ if let Some(range) = range {
+ if range
+ .end
+ .map_or_else(|| range.start < size, |v| v >= range.start)
+ && file.seek(SeekFrom::Start(range.start)).await.is_ok()
+ {
+ let end = range.end.unwrap_or(size - 1).min(size - 1);
+ let part_size = end - range.start + 1;
+ let reader = Streamer::new(file, BUF_SIZE);
+ *res.status_mut() = StatusCode::PARTIAL_CONTENT;
+ let content_range = format!("bytes {}-{}/{}", range.start, end, size);
+ res.headers_mut()
+ .insert(CONTENT_RANGE, content_range.parse().unwrap());
+ res.headers_mut()
+ .insert(CONTENT_LENGTH, format!("{}", part_size).parse().unwrap());
+ if head_only {
+ return Ok(());
+ }
+ *res.body_mut() = Body::wrap_stream(reader.into_stream_sized(part_size));
+ } else {
+ *res.status_mut() = StatusCode::RANGE_NOT_SATISFIABLE;
+ res.headers_mut()
+ .insert(CONTENT_RANGE, format!("bytes */{}", size).parse().unwrap());
+ }
} else {
- let stream = FramedRead::new(file, BytesCodec::new());
- Body::wrap_stream(stream)
- };
- *res.body_mut() = body;
+ res.headers_mut()
+ .insert(CONTENT_LENGTH, format!("{}", size).parse().unwrap());
+ if head_only {
+ return Ok(());
+ }
+ let reader = Streamer::new(file, BUF_SIZE);
+ *res.body_mut() = Body::wrap_stream(reader.into_stream());
+ }
Ok(())
}
@@ -965,32 +988,34 @@ fn extract_cache_headers(meta: &Metadata) -> Option<(ETag, LastModified)> {
Some((etag, last_modified))
}
-fn to_content_range(range: &Range, complete_length: u64) -> Option {
- use core::ops::Bound::{Included, Unbounded};
- let mut iter = range.iter();
- let bounds = iter.next();
+#[derive(Debug)]
+struct RangeValue {
+ start: u64,
+ end: Option,
+}
- if iter.next().is_some() {
- // Found multiple byte-range-spec. Drop.
- return None;
+fn parse_range(headers: &HeaderMap) -> Option {
+ let range_hdr = headers.get(RANGE)?;
+ let hdr = range_hdr.to_str().ok()?;
+ let mut sp = hdr.splitn(2, '=');
+ let units = sp.next().unwrap();
+ if units == "bytes" {
+ let range = sp.next()?;
+ let mut sp_range = range.splitn(2, '-');
+ let start: u64 = sp_range.next().unwrap().parse().ok()?;
+ let end: Option = if let Some(end) = sp_range.next() {
+ if end.is_empty() {
+ None
+ } else {
+ Some(end.parse().ok()?)
+ }
+ } else {
+ None
+ };
+ Some(RangeValue { start, end })
+ } else {
+ None
}
-
- bounds.and_then(|b| match b {
- (Included(start), Included(end)) if start <= end && start < complete_length => {
- ContentRange::bytes(
- start..=end.min(complete_length.saturating_sub(1)),
- complete_length,
- )
- .ok()
- }
- (Included(start), Unbounded) if start < complete_length => {
- ContentRange::bytes(start.., complete_length).ok()
- }
- (Unbounded, Included(end)) if end > 0 => {
- ContentRange::bytes(complete_length.saturating_sub(end).., complete_length).ok()
- }
- _ => None,
- })
}
fn encode_uri(v: &str) -> String {
diff --git a/src/streamer.rs b/src/streamer.rs
new file mode 100644
index 0000000..163b36f
--- /dev/null
+++ b/src/streamer.rs
@@ -0,0 +1,68 @@
+use async_stream::stream;
+use futures::{Stream, StreamExt};
+use std::io::Error;
+use std::pin::Pin;
+use tokio::io::{AsyncRead, AsyncReadExt};
+
+pub struct Streamer
+where
+ R: AsyncRead + Unpin + Send + 'static,
+{
+ reader: R,
+ buf_size: usize,
+}
+
+impl Streamer
+where
+ R: AsyncRead + Unpin + Send + 'static,
+{
+ #[inline]
+ pub fn new(reader: R, buf_size: usize) -> Self {
+ Self { reader, buf_size }
+ }
+ pub fn into_stream(
+ mut self,
+ ) -> Pin, Error>> + 'static>> {
+ let stream = stream! {
+ loop {
+ let mut buf = vec![0; self.buf_size];
+ let r = self.reader.read(&mut buf).await?;
+ if r == 0 {
+ break
+ }
+ buf.truncate(r);
+ yield Ok(buf);
+ }
+ };
+ stream.boxed()
+ }
+ // allow truncation as truncated remaining is always less than buf_size: usize
+ pub fn into_stream_sized(
+ mut self,
+ max_length: u64,
+ ) -> Pin, Error>> + 'static>> {
+ let stream = stream! {
+ let mut remaining = max_length;
+ loop {
+ if remaining == 0 {
+ break;
+ }
+ let bs = if remaining >= self.buf_size as u64 {
+ self.buf_size
+ } else {
+ remaining as usize
+ };
+ let mut buf = vec![0; bs];
+ let r = self.reader.read(&mut buf).await?;
+ if r == 0 {
+ break;
+ } else {
+ buf.truncate(r);
+ yield Ok(buf);
+ }
+ remaining -= r as u64;
+ }
+ };
+ stream.boxed()
+ }
+}
diff --git a/tests/range.rs b/tests/range.rs
new file mode 100644
index 0000000..a2c9c50
--- /dev/null
+++ b/tests/range.rs
@@ -0,0 +1,45 @@
+mod fixtures;
+mod utils;
+
+use fixtures::{server, Error, TestServer};
+use headers::HeaderValue;
+use rstest::rstest;
+
+#[rstest]
+fn get_file_range(server: TestServer) -> Result<(), Error> {
+ let resp = fetch!(b"GET", format!("{}index.html", server.url()))
+ .header("range", HeaderValue::from_static("bytes=0-6"))
+ .send()?;
+ assert_eq!(resp.status(), 206);
+ assert_eq!(resp.headers().get("content-range").unwrap(), "bytes 0-6/18");
+ assert_eq!(resp.headers().get("accept-ranges").unwrap(), "bytes");
+ assert_eq!(resp.headers().get("content-length").unwrap(), "7");
+ assert_eq!(resp.text()?, "This is");
+ Ok(())
+}
+
+#[rstest]
+fn get_file_range_beyond(server: TestServer) -> Result<(), Error> {
+ let resp = fetch!(b"GET", format!("{}index.html", server.url()))
+ .header("range", HeaderValue::from_static("bytes=12-20"))
+ .send()?;
+ assert_eq!(resp.status(), 206);
+ assert_eq!(
+ resp.headers().get("content-range").unwrap(),
+ "bytes 12-17/18"
+ );
+ assert_eq!(resp.headers().get("accept-ranges").unwrap(), "bytes");
+ assert_eq!(resp.headers().get("content-length").unwrap(), "6");
+ assert_eq!(resp.text()?, "x.html");
+ Ok(())
+}
+
+#[rstest]
+fn get_file_range_invalid(server: TestServer) -> Result<(), Error> {
+ let resp = fetch!(b"GET", format!("{}index.html", server.url()))
+ .header("range", HeaderValue::from_static("bytes=20-"))
+ .send()?;
+ assert_eq!(resp.status(), 416);
+ assert_eq!(resp.headers().get("content-range").unwrap(), "bytes */18");
+ Ok(())
+}