You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
158 lines
5.2 KiB
158 lines
5.2 KiB
use core::task::{Context, Poll}; |
|
use futures::ready; |
|
use hyper::server::accept::Accept; |
|
use hyper::server::conn::{AddrIncoming, AddrStream}; |
|
use rustls::{Certificate, PrivateKey}; |
|
use std::future::Future; |
|
use std::net::SocketAddr; |
|
use std::pin::Pin; |
|
use std::sync::Arc; |
|
use std::{fs, io}; |
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
|
use tokio_rustls::rustls::ServerConfig; |
|
|
|
enum State { |
|
Handshaking(tokio_rustls::Accept<AddrStream>), |
|
Streaming(tokio_rustls::server::TlsStream<AddrStream>), |
|
} |
|
|
|
// tokio_rustls::server::TlsStream doesn't expose constructor methods, |
|
// so we have to TlsAcceptor::accept and handshake to have access to it |
|
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first |
|
pub struct TlsStream { |
|
state: State, |
|
remote_addr: SocketAddr, |
|
} |
|
|
|
impl TlsStream { |
|
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream { |
|
let remote_addr = stream.remote_addr(); |
|
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); |
|
TlsStream { |
|
state: State::Handshaking(accept), |
|
remote_addr, |
|
} |
|
} |
|
pub fn remote_addr(&self) -> SocketAddr { |
|
self.remote_addr |
|
} |
|
} |
|
|
|
impl AsyncRead for TlsStream { |
|
fn poll_read( |
|
self: Pin<&mut Self>, |
|
cx: &mut Context, |
|
buf: &mut ReadBuf, |
|
) -> Poll<io::Result<()>> { |
|
let pin = self.get_mut(); |
|
match pin.state { |
|
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { |
|
Ok(mut stream) => { |
|
let result = Pin::new(&mut stream).poll_read(cx, buf); |
|
pin.state = State::Streaming(stream); |
|
result |
|
} |
|
Err(err) => Poll::Ready(Err(err)), |
|
}, |
|
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), |
|
} |
|
} |
|
} |
|
|
|
impl AsyncWrite for TlsStream { |
|
fn poll_write( |
|
self: Pin<&mut Self>, |
|
cx: &mut Context<'_>, |
|
buf: &[u8], |
|
) -> Poll<io::Result<usize>> { |
|
let pin = self.get_mut(); |
|
match pin.state { |
|
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { |
|
Ok(mut stream) => { |
|
let result = Pin::new(&mut stream).poll_write(cx, buf); |
|
pin.state = State::Streaming(stream); |
|
result |
|
} |
|
Err(err) => Poll::Ready(Err(err)), |
|
}, |
|
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), |
|
} |
|
} |
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
|
match self.state { |
|
State::Handshaking(_) => Poll::Ready(Ok(())), |
|
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), |
|
} |
|
} |
|
|
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
|
match self.state { |
|
State::Handshaking(_) => Poll::Ready(Ok(())), |
|
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), |
|
} |
|
} |
|
} |
|
|
|
pub struct TlsAcceptor { |
|
config: Arc<ServerConfig>, |
|
incoming: AddrIncoming, |
|
} |
|
|
|
impl TlsAcceptor { |
|
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor { |
|
TlsAcceptor { config, incoming } |
|
} |
|
} |
|
|
|
impl Accept for TlsAcceptor { |
|
type Conn = TlsStream; |
|
type Error = io::Error; |
|
|
|
fn poll_accept( |
|
self: Pin<&mut Self>, |
|
cx: &mut Context<'_>, |
|
) -> Poll<Option<Result<Self::Conn, Self::Error>>> { |
|
let pin = self.get_mut(); |
|
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { |
|
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), |
|
Some(Err(e)) => Poll::Ready(Some(Err(e))), |
|
None => Poll::Ready(None), |
|
} |
|
} |
|
} |
|
|
|
// Load public certificate from file. |
|
pub fn load_certs(filename: &str) -> Result<Vec<Certificate>, Box<dyn std::error::Error>> { |
|
// Open certificate file. |
|
let certfile = fs::File::open(&filename) |
|
.map_err(|e| format!("Failed to access `{}`, {}", &filename, e))?; |
|
let mut reader = io::BufReader::new(certfile); |
|
|
|
// Load and return certificate. |
|
let certs = rustls_pemfile::certs(&mut reader).map_err(|_| "Failed to load certificate")?; |
|
if certs.is_empty() { |
|
return Err("No supported certificate in file".into()); |
|
} |
|
Ok(certs.into_iter().map(Certificate).collect()) |
|
} |
|
|
|
// Load private key from file. |
|
pub fn load_private_key(filename: &str) -> Result<PrivateKey, Box<dyn std::error::Error>> { |
|
// Open keyfile. |
|
let keyfile = fs::File::open(&filename) |
|
.map_err(|e| format!("Failed to access `{}`, {}", &filename, e))?; |
|
let mut reader = io::BufReader::new(keyfile); |
|
|
|
// Load and return a single private key. |
|
let keys = rustls_pemfile::read_all(&mut reader) |
|
.map_err(|e| format!("There was a problem with reading private key: {:?}", e))? |
|
.into_iter() |
|
.find_map(|item| match item { |
|
rustls_pemfile::Item::RSAKey(key) | rustls_pemfile::Item::PKCS8Key(key) => Some(key), |
|
_ => None, |
|
}) |
|
.ok_or("No supported private key in file")?; |
|
|
|
Ok(PrivateKey(keys)) |
|
}
|
|
|