You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
188 lines
5.7 KiB
188 lines
5.7 KiB
mod args; |
|
mod auth; |
|
mod server; |
|
mod streamer; |
|
mod tls; |
|
|
|
#[macro_use] |
|
extern crate log; |
|
|
|
use crate::args::{matches, Args}; |
|
use crate::server::{Request, Server}; |
|
use crate::tls::{TlsAcceptor, TlsStream}; |
|
|
|
use std::io::Write; |
|
use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener}; |
|
use std::{env, sync::Arc}; |
|
|
|
use futures::future::join_all; |
|
use tokio::net::TcpListener; |
|
use tokio::task::JoinHandle; |
|
|
|
use hyper::server::conn::{AddrIncoming, AddrStream}; |
|
use hyper::service::{make_service_fn, service_fn}; |
|
use rustls::ServerConfig; |
|
|
|
/// Crate-wide result type: either a success value `T` or any boxed error,
/// which lets `?` propagate heterogeneous error types from a single function.
pub type BoxResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
|
|
|
#[tokio::main] |
|
async fn main() { |
|
run().await.unwrap_or_else(handle_err) |
|
} |
|
|
|
async fn run() -> BoxResult<()> { |
|
if env::var("RUST_LOG").is_err() { |
|
env::set_var("RUST_LOG", "info") |
|
} |
|
env_logger::builder() |
|
.format(|buf, record| { |
|
let timestamp = buf.timestamp_millis(); |
|
writeln!(buf, "[{} {}] {}", timestamp, record.level(), record.args()) |
|
}) |
|
.init(); |
|
|
|
let args = Args::parse(matches())?; |
|
let args = Arc::new(args); |
|
let handles = serve(args.clone())?; |
|
print_listening(args)?; |
|
|
|
tokio::select! { |
|
ret = join_all(handles) => { |
|
for r in ret { |
|
if let Err(e) = r { |
|
error!("{}", e); |
|
} |
|
} |
|
Ok(()) |
|
}, |
|
_ = shutdown_signal() => { |
|
Ok(()) |
|
}, |
|
} |
|
} |
|
|
|
fn serve(args: Arc<Args>) -> BoxResult<Vec<JoinHandle<Result<(), hyper::Error>>>> { |
|
let inner = Arc::new(Server::new(args.clone())); |
|
let mut handles = vec![]; |
|
let port = args.port; |
|
for ip in args.addrs.iter() { |
|
let inner = inner.clone(); |
|
let incoming = create_addr_incoming(SocketAddr::new(*ip, port)) |
|
.map_err(|e| format!("Failed to bind `{}:{}`, {}", ip, port, e))?; |
|
let serv_func = move |remote_addr: SocketAddr| { |
|
let inner = inner.clone(); |
|
async move { |
|
Ok::<_, hyper::Error>(service_fn(move |req: Request| { |
|
let inner = inner.clone(); |
|
inner.call(req, remote_addr) |
|
})) |
|
} |
|
}; |
|
match args.tls.clone() { |
|
Some((certs, key)) => { |
|
let config = ServerConfig::builder() |
|
.with_safe_defaults() |
|
.with_no_client_auth() |
|
.with_single_cert(certs, key)?; |
|
let config = Arc::new(config); |
|
let accepter = TlsAcceptor::new(config.clone(), incoming); |
|
let new_service = make_service_fn(move |socket: &TlsStream| { |
|
let remote_addr = socket.remote_addr(); |
|
serv_func(remote_addr) |
|
}); |
|
let server = tokio::spawn(hyper::Server::builder(accepter).serve(new_service)); |
|
handles.push(server); |
|
} |
|
None => { |
|
let new_service = make_service_fn(move |socket: &AddrStream| { |
|
let remote_addr = socket.remote_addr(); |
|
serv_func(remote_addr) |
|
}); |
|
let server = tokio::spawn(hyper::Server::builder(incoming).serve(new_service)); |
|
handles.push(server); |
|
} |
|
}; |
|
} |
|
Ok(handles) |
|
} |
|
|
|
fn create_addr_incoming(addr: SocketAddr) -> BoxResult<AddrIncoming> { |
|
use socket2::{Domain, Protocol, Socket, Type}; |
|
let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?; |
|
if addr.is_ipv6() { |
|
socket.set_only_v6(true)?; |
|
} |
|
socket.set_reuse_address(true)?; |
|
socket.bind(&addr.into())?; |
|
socket.listen(1024 /* Default backlog */)?; |
|
let std_listener = StdTcpListener::from(socket); |
|
std_listener.set_nonblocking(true)?; |
|
let incoming = AddrIncoming::from_listener(TcpListener::from_std(std_listener)?)?; |
|
Ok(incoming) |
|
} |
|
|
|
fn print_listening(args: Arc<Args>) -> BoxResult<()> { |
|
let mut addrs = vec![]; |
|
let (mut ipv4, mut ipv6) = (false, false); |
|
for ip in args.addrs.iter() { |
|
if ip.is_unspecified() { |
|
if ip.is_ipv6() { |
|
ipv6 = true; |
|
} else { |
|
ipv4 = true; |
|
} |
|
} else { |
|
addrs.push(*ip); |
|
} |
|
} |
|
if ipv4 || ipv6 { |
|
let ifaces = get_if_addrs::get_if_addrs() |
|
.map_err(|e| format!("Failed to get local interface addresses: {}", e))?; |
|
for iface in ifaces.into_iter() { |
|
let local_ip = iface.ip(); |
|
if ipv4 && local_ip.is_ipv4() { |
|
addrs.push(local_ip) |
|
} |
|
if ipv6 && local_ip.is_ipv6() { |
|
addrs.push(local_ip) |
|
} |
|
} |
|
} |
|
addrs.sort_unstable(); |
|
let urls = addrs |
|
.into_iter() |
|
.map(|addr| match addr { |
|
IpAddr::V4(_) => format!("{}:{}", addr, args.port), |
|
IpAddr::V6(_) => format!("[{}]:{}", addr, args.port), |
|
}) |
|
.map(|addr| match &args.tls { |
|
Some(_) => format!("https://{}", addr), |
|
None => format!("http://{}", addr), |
|
}) |
|
.map(|url| format!("{}{}", url, args.uri_prefix)) |
|
.collect::<Vec<_>>(); |
|
|
|
if urls.len() == 1 { |
|
println!("Listening on {}", urls[0]); |
|
} else { |
|
let info = urls |
|
.iter() |
|
.map(|v| format!(" {}", v)) |
|
.collect::<Vec<String>>() |
|
.join("\n"); |
|
println!("Listening on:\n{}\n", info); |
|
} |
|
|
|
Ok(()) |
|
} |
|
|
|
/// Prints a fatal error to stderr and terminates the process with exit code 1.
///
/// The generic return type `T` lets this function be passed wherever a value
/// of any type is expected (for example to `Result::unwrap_or_else`). It never
/// actually returns.
fn handle_err<T>(err: Box<dyn std::error::Error>) -> T {
    let message = format!("error: {}", err);
    eprintln!("{}", message);
    std::process::exit(1)
}
|
|
|
async fn shutdown_signal() { |
|
tokio::signal::ctrl_c() |
|
.await |
|
.expect("Failed to install CTRL+C signal handler") |
|
}
|
|
|