diff --git a/Cargo.lock b/Cargo.lock index 5849b22..e62110a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -882,6 +882,15 @@ dependencies = [ "digest", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + [[package]] name = "slab" version = "0.4.6" @@ -965,6 +974,7 @@ dependencies = [ "num_cpus", "once_cell", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "winapi 0.3.9", diff --git a/Cargo.toml b/Cargo.toml index c068ede..3e46802 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ keywords = ["static", "file", "server", "http", "cli"] [dependencies] clap = { version = "3", default-features = false, features = ["std", "cargo"] } chrono = "0.4" -tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util"]} +tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "signal"]} tokio-rustls = "0.23" tokio-stream = { version = "0.1", features = ["net"] } tokio-util = { version = "0.7", features = ["codec", "io-util"] } diff --git a/src/server.rs b/src/server.rs index de53d37..d1f38e7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -53,52 +53,65 @@ macro_rules! status { } pub async fn serve(args: Args) -> BoxResult<()> { + match args.tls.as_ref() { + Some(_) => serve_https(args).await, + None => serve_http(args).await, + } +} + +pub async fn serve_https(args: Args) -> BoxResult<()> { let args = Arc::new(args); let socket_addr = args.address()?; + let (certs, key) = args.tls.clone().unwrap(); let inner = Arc::new(InnerService::new(args.clone())); - if let Some((certs, key)) = args.tls.as_ref() { - let config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs.clone(), key.clone())?; - let tls_acceptor = TlsAcceptor::from(Arc::new(config)); - let arc_acceptor = Arc::new(tls_acceptor); - let listener = TcpListener::bind(&socket_addr).await?; - let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); - let incoming = hyper::server::accept::from_stream(incoming.filter_map(|socket| async { - match socket { - Ok(stream) => match arc_acceptor.clone().accept(stream).await { - Ok(val) => Some(Ok::<_, Infallible>(val)), - Err(_) => None, - }, + let config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, key)?; + let tls_acceptor = TlsAcceptor::from(Arc::new(config)); + let arc_acceptor = Arc::new(tls_acceptor); + let listener = TcpListener::bind(&socket_addr).await?; + let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); + let incoming = hyper::server::accept::from_stream(incoming.filter_map(|socket| async { + match socket { + Ok(stream) => match arc_acceptor.clone().accept(stream).await { + Ok(val) => Some(Ok::<_, Infallible>(val)), Err(_) => None, - } - })); - let server = hyper::Server::builder(incoming).serve(make_service_fn(move |_| { - let inner = inner.clone(); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - let inner = inner.clone(); - inner.call(req) - })) - } - })); - print_listening(args.address.as_str(), args.port, true); - server.await?; - } else { - let server = hyper::Server::try_bind(&socket_addr)?.serve(make_service_fn(move |_| { - let inner = inner.clone(); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - let inner = inner.clone(); - inner.call(req) - })) - } - })); - print_listening(args.address.as_str(), args.port, false); - server.await?; - } + }, + Err(_) => None, + } + })); + let server = hyper::Server::builder(incoming).serve(make_service_fn(move |_| { + let inner = inner.clone(); + async move { + Ok::<_, Infallible>(service_fn(move |req| { + let inner = inner.clone(); + inner.call(req) + })) + } + })); + print_listening(args.address.as_str(), args.port, true); + let graceful = server.with_graceful_shutdown(shutdown_signal()); + graceful.await?; + Ok(()) +} +pub async fn serve_http(args: Args) -> BoxResult<()> { + let args = Arc::new(args); + let socket_addr = args.address()?; + let inner = Arc::new(InnerService::new(args.clone())); + let server = hyper::Server::try_bind(&socket_addr)?.serve(make_service_fn(move |_| { + let inner = inner.clone(); + async move { + Ok::<_, Infallible>(service_fn(move |req| { + let inner = inner.clone(); + inner.call(req) + })) + } + })); + print_listening(args.address.as_str(), args.port, false); + let graceful = server.with_graceful_shutdown(shutdown_signal()); + graceful.await?; Ok(()) } @@ -751,3 +764,9 @@ fn retrive_listening_addrs(address: &str) -> Vec { } vec![address.to_owned()] } + +async fn shutdown_signal() { + tokio::signal::ctrl_c() + .await + .expect("Failed to install CTRL+C signal handler") +}