1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
| use std::io::Cursor; use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::rustls::{internal::pemfile, Certificate, ClientConfig, ServerConfig}; use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, PrivateKey, RootCertStore}; use tokio_rustls::webpki::DNSNameRef; use tokio_rustls::TlsConnector; use tokio_rustls::{ client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream, TlsAcceptor, };
use crate::KvError;
/// ALPN protocol identifier used by both the kv client and the kv server.
const ALPN_KV: &str = "kv";

/// Server-side TLS wrapper holding a shared rustls [`ServerConfig`].
#[derive(Clone)]
pub struct TlsServerAcceptor {
    inner: Arc<ServerConfig>,
}

/// Client-side TLS wrapper holding a shared rustls [`ClientConfig`] and the
/// server name that [`TlsClientConnector::connect`] verifies against.
#[derive(Clone)]
pub struct TlsClientConnector {
    pub config: Arc<ClientConfig>,
    pub domain: Arc<String>,
}

impl TlsClientConnector {
    /// Builds a client TLS configuration.
    ///
    /// * `domain` – server name passed to the handshake in [`connect`](Self::connect).
    /// * `identity` – optional `(cert_pem, key_pem)` pair used as the client
    ///   certificate.
    /// * `server_ca` – optional PEM CA certificate(s) trusted in addition to the
    ///   native root store.
    ///
    /// The client offers the `kv` ALPN protocol, the same one
    /// [`TlsServerAcceptor`] advertises.
    ///
    /// # Errors
    ///
    /// * `KvError::CertifcateParseError("client", "cert")` if the client
    ///   certificate PEM cannot be parsed;
    /// * `KvError::CertifcateParseError("private", "key")` (from `load_key`) if the
    ///   client key cannot be parsed;
    /// * `KvError::CertifcateParseError("CA", "cert")` if `server_ca` is malformed;
    /// * converted errors from rustls or from loading the native root store.
    pub fn new(
        domain: impl Into<String>,
        identity: Option<(&str, &str)>,
        server_ca: Option<&str>,
    ) -> Result<Self, KvError> {
        let mut config = ClientConfig::new();

        if let Some((cert, key)) = identity {
            // `load_certs` labels its failure as a *server* cert; relabel it so the
            // error points at the client certificate that actually failed.
            let certs = load_certs(cert)
                .map_err(|_| KvError::CertifcateParseError("client", "cert"))?;
            let key = load_key(key)?;
            config.set_single_client_cert(certs, key)?;
        }

        // Start from the native trust store. If loading reports an error but still
        // hands back a store, use that store; only fail when no store is returned.
        config.root_store = match rustls_native_certs::load_native_certs() {
            Ok(store) | Err((Some(store), _)) => store,
            Err((None, error)) => return Err(error.into()),
        };

        // Optionally trust an extra CA. This PEM comes from the caller, so a
        // malformed value is reported as an error instead of panicking.
        if let Some(cert) = server_ca {
            let mut buf = Cursor::new(cert);
            config
                .root_store
                .add_pem_file(&mut buf)
                .map_err(|_| KvError::CertifcateParseError("CA", "cert"))?;
        }

        // Offer the same ALPN protocol that the server side configures.
        config.set_protocols(&[Vec::from(&ALPN_KV[..])]);

        Ok(Self {
            config: Arc::new(config),
            domain: Arc::new(domain.into()),
        })
    }

    /// Runs the TLS client handshake over `stream`, using `self.domain` as the
    /// expected server name.
    ///
    /// # Errors
    ///
    /// `KvError::Internal("Invalid DNS name")` if `domain` is not a valid DNS
    /// name; otherwise the converted handshake error.
    pub async fn connect<S>(&self, stream: S) -> Result<ClientTlsStream<S>, KvError>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send,
    {
        let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())
            .map_err(|_| KvError::Internal("Invalid DNS name".into()))?;

        let stream = TlsConnector::from(self.config.clone())
            .connect(dns, stream)
            .await?;

        Ok(stream)
    }
}
impl TlsServerAcceptor {
    /// Builds a server TLS configuration from PEM-encoded material.
    ///
    /// * `cert` / `key` – server certificate chain and private key.
    /// * `client_ca` – when `Some`, clients are verified with
    ///   `AllowAnyAuthenticatedClient` against this CA; when `None`,
    ///   `NoClientAuth` is used.
    ///
    /// The resulting configuration advertises the `kv` ALPN protocol.
    ///
    /// # Errors
    ///
    /// `KvError::CertifcateParseError` when the certificate chain, private key
    /// or CA cannot be parsed, or when rustls rejects the cert/key pair.
    pub fn new(cert: &str, key: &str, client_ca: Option<&str>) -> Result<Self, KvError> {
        // Parse the server identity first so its errors take precedence.
        let chain = load_certs(cert)?;
        let private_key = load_key(key)?;

        let mut server_config = match client_ca {
            Some(ca_pem) => {
                let mut trusted = RootCertStore::empty();
                trusted
                    .add_pem_file(&mut Cursor::new(ca_pem))
                    .map_err(|_| KvError::CertifcateParseError("CA", "cert"))?;
                ServerConfig::new(AllowAnyAuthenticatedClient::new(trusted))
            }
            None => ServerConfig::new(NoClientAuth::new()),
        };

        server_config
            .set_single_cert(chain, private_key)
            .map_err(|_| KvError::CertifcateParseError("server", "cert"))?;

        let alpn: Vec<u8> = ALPN_KV.as_bytes().to_vec();
        server_config.set_protocols(&[alpn]);

        Ok(Self {
            inner: Arc::new(server_config),
        })
    }

    /// Runs the TLS server handshake over an accepted `stream`.
    ///
    /// # Errors
    ///
    /// The handshake error, converted into `KvError`.
    pub async fn accept<S>(&self, stream: S) -> Result<ServerTlsStream<S>, KvError>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send,
    {
        let tls = TlsAcceptor::from(Arc::clone(&self.inner))
            .accept(stream)
            .await?;
        Ok(tls)
    }
}
/// Parses every certificate found in a PEM string.
///
/// # Errors
///
/// Returns `KvError::CertifcateParseError("server", "cert")` if the PEM cannot be
/// decoded, or if it decodes but contains no certificate at all. An empty chain is
/// rejected here instead of being handed on to rustls.
fn load_certs(cert: &str) -> Result<Vec<Certificate>, KvError> {
    let mut reader = Cursor::new(cert);
    match pemfile::certs(&mut reader) {
        Ok(certs) if !certs.is_empty() => Ok(certs),
        _ => Err(KvError::CertifcateParseError("server", "cert")),
    }
}
/// Extracts the first private key from a PEM string.
///
/// PKCS#8 keys are looked for first; if none is found, the input is scanned
/// again for RSA (PKCS#1) keys.
///
/// # Errors
///
/// `KvError::CertifcateParseError("private", "key")` when neither scan yields a key.
fn load_key(key: &str) -> Result<PrivateKey, KvError> {
    let mut reader = Cursor::new(key);

    let pkcs8 = pemfile::pkcs8_private_keys(&mut reader)
        .ok()
        .and_then(|keys| keys.into_iter().next());
    if let Some(found) = pkcs8 {
        return Ok(found);
    }

    // The first scan may have advanced the cursor; rewind before the RSA scan.
    reader.set_position(0);
    pemfile::rsa_private_keys(&mut reader)
        .ok()
        .and_then(|keys| keys.into_iter().next())
        .ok_or_else(|| KvError::CertifcateParseError("private", "key"))
}
|