sigoden
3 years ago
committed by
GitHub
8 changed files with 405 additions and 235 deletions
@ -0,0 +1,158 @@
@@ -0,0 +1,158 @@
|
||||
use core::task::{Context, Poll}; |
||||
use futures::ready; |
||||
use hyper::server::accept::Accept; |
||||
use hyper::server::conn::{AddrIncoming, AddrStream}; |
||||
use rustls::{Certificate, PrivateKey}; |
||||
use std::future::Future; |
||||
use std::net::SocketAddr; |
||||
use std::pin::Pin; |
||||
use std::sync::Arc; |
||||
use std::{fs, io}; |
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
||||
use tokio_rustls::rustls::ServerConfig; |
||||
|
||||
enum State { |
||||
Handshaking(tokio_rustls::Accept<AddrStream>), |
||||
Streaming(tokio_rustls::server::TlsStream<AddrStream>), |
||||
} |
||||
|
||||
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
|
||||
// so we have to TlsAcceptor::accept and handshake to have access to it
|
||||
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
|
||||
pub struct TlsStream { |
||||
state: State, |
||||
remote_addr: SocketAddr, |
||||
} |
||||
|
||||
impl TlsStream { |
||||
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream { |
||||
let remote_addr = stream.remote_addr(); |
||||
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); |
||||
TlsStream { |
||||
state: State::Handshaking(accept), |
||||
remote_addr, |
||||
} |
||||
} |
||||
pub fn remote_addr(&self) -> SocketAddr { |
||||
self.remote_addr |
||||
} |
||||
} |
||||
|
||||
impl AsyncRead for TlsStream { |
||||
fn poll_read( |
||||
self: Pin<&mut Self>, |
||||
cx: &mut Context, |
||||
buf: &mut ReadBuf, |
||||
) -> Poll<io::Result<()>> { |
||||
let pin = self.get_mut(); |
||||
match pin.state { |
||||
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { |
||||
Ok(mut stream) => { |
||||
let result = Pin::new(&mut stream).poll_read(cx, buf); |
||||
pin.state = State::Streaming(stream); |
||||
result |
||||
} |
||||
Err(err) => Poll::Ready(Err(err)), |
||||
}, |
||||
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl AsyncWrite for TlsStream { |
||||
fn poll_write( |
||||
self: Pin<&mut Self>, |
||||
cx: &mut Context<'_>, |
||||
buf: &[u8], |
||||
) -> Poll<io::Result<usize>> { |
||||
let pin = self.get_mut(); |
||||
match pin.state { |
||||
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { |
||||
Ok(mut stream) => { |
||||
let result = Pin::new(&mut stream).poll_write(cx, buf); |
||||
pin.state = State::Streaming(stream); |
||||
result |
||||
} |
||||
Err(err) => Poll::Ready(Err(err)), |
||||
}, |
||||
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), |
||||
} |
||||
} |
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
||||
match self.state { |
||||
State::Handshaking(_) => Poll::Ready(Ok(())), |
||||
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), |
||||
} |
||||
} |
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
||||
match self.state { |
||||
State::Handshaking(_) => Poll::Ready(Ok(())), |
||||
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub struct TlsAcceptor { |
||||
config: Arc<ServerConfig>, |
||||
incoming: AddrIncoming, |
||||
} |
||||
|
||||
impl TlsAcceptor { |
||||
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor { |
||||
TlsAcceptor { config, incoming } |
||||
} |
||||
} |
||||
|
||||
impl Accept for TlsAcceptor { |
||||
type Conn = TlsStream; |
||||
type Error = io::Error; |
||||
|
||||
fn poll_accept( |
||||
self: Pin<&mut Self>, |
||||
cx: &mut Context<'_>, |
||||
) -> Poll<Option<Result<Self::Conn, Self::Error>>> { |
||||
let pin = self.get_mut(); |
||||
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { |
||||
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), |
||||
Some(Err(e)) => Poll::Ready(Some(Err(e))), |
||||
None => Poll::Ready(None), |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Load public certificate from file.
|
||||
pub fn load_certs(filename: &str) -> Result<Vec<Certificate>, Box<dyn std::error::Error>> { |
||||
// Open certificate file.
|
||||
let certfile = fs::File::open(&filename) |
||||
.map_err(|e| format!("Failed to access `{}`, {}", &filename, e))?; |
||||
let mut reader = io::BufReader::new(certfile); |
||||
|
||||
// Load and return certificate.
|
||||
let certs = rustls_pemfile::certs(&mut reader).map_err(|_| "Failed to load certificate")?; |
||||
if certs.is_empty() { |
||||
return Err("No supported certificate in file".into()); |
||||
} |
||||
Ok(certs.into_iter().map(Certificate).collect()) |
||||
} |
||||
|
||||
// Load private key from file.
|
||||
pub fn load_private_key(filename: &str) -> Result<PrivateKey, Box<dyn std::error::Error>> { |
||||
// Open keyfile.
|
||||
let keyfile = fs::File::open(&filename) |
||||
.map_err(|e| format!("Failed to access `{}`, {}", &filename, e))?; |
||||
let mut reader = io::BufReader::new(keyfile); |
||||
|
||||
// Load and return a single private key.
|
||||
let keys = rustls_pemfile::read_all(&mut reader) |
||||
.map_err(|e| format!("There was a problem with reading private key: {:?}", e))? |
||||
.into_iter() |
||||
.find_map(|item| match item { |
||||
rustls_pemfile::Item::RSAKey(key) | rustls_pemfile::Item::PKCS8Key(key) => Some(key), |
||||
_ => None, |
||||
}) |
||||
.ok_or("No supported private key in file")?; |
||||
|
||||
Ok(PrivateKey(keys)) |
||||
} |
Loading…
Reference in new issue