Rust-TLS

一般 Web 开发中,类似 MySQL,HTTP 等基于 TCP 的协议,会有 TLS 加密的需要。一个网络应用,即便是在内网使用,如果没有安全协议来保护,都是很危险的。

TLS 是传输层安全协议,全称 Transport Layer Security。TLS 的前身是 SSL。TLS 详细内容可以参考 TLS RFC 文档

TLS 握手过程:

  1. 客户端向服务器发送一个 ClientHello 消息,消息中包含客户端支持的协议版本、加密套件、压缩算法等。
  2. 服务器收到 ClientHello 消息后,会返回一个 ServerHello 消息,消息中包含服务器选择的协议版本、加密套件、压缩算法等。
  3. 服务器发送 Certificate 消息,消息中包含服务器的数字证书。
  4. 服务器发送 ServerKeyExchange 消息,消息中包含服务器生成的密钥交换数据。
  5. 服务器发送 ServerHelloDone 消息,表示服务器已经发送完毕。

参考 Wiki TLS 以及 Full TLS 1.3 Handshake

在数据帧中,TCP payload 中会包含 TLS 数据帧。。

Rust TLS

Rust 对 openssl 有很不错的封装,也有不依赖 openssl 用 Rust 撰写的 rustls。tokio 进一步提供了符合 tokio 生态圈的 tls 支持,有 openssl 版本和 rustls 版本可选。

Rust tokio-tls 示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
use std::io::Cursor;
use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::rustls::{internal::pemfile, Certificate, ClientConfig, ServerConfig};
use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, PrivateKey, RootCertStore};
use tokio_rustls::webpki::DNSNameRef;
use tokio_rustls::TlsConnector;
use tokio_rustls::{
client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream, TlsAcceptor,
};

use crate::KvError;

/// Server ALPN (Application-Layer Protocol Negotiation)
const ALPN_KV: &str = "kv";

/// 存放 TLS ServerConfig 并提供方法 accept 把底层的协议转换成 TLS
#[derive(Clone)]
pub struct TlsServerAcceptor {
inner: Arc<ServerConfig>,
}

/// 存放 TLS Client 并提供方法 connect 把底层的协议转换成 TLS
#[derive(Clone)]
pub struct TlsClientConnector {
pub config: Arc<ClientConfig>,
pub domain: Arc<String>,
}

impl TlsClientConnector {
/// 加载 client cert / CA cert,生成 ClientConfig
pub fn new(
domain: impl Into<String>,
identity: Option<(&str, &str)>,
server_ca: Option<&str>,
) -> Result<Self, KvError> {
let mut config = ClientConfig::new();

// 如果有客户端证书,加载之
if let Some((cert, key)) = identity {
let certs = load_certs(cert)?;
let key = load_key(key)?;
config.set_single_client_cert(certs, key)?;
}

// 加载本地信任的根证书链
config.root_store = match rustls_native_certs::load_native_certs() {
Ok(store) | Err((Some(store), _)) => store,
Err((None, error)) => return Err(error.into()),
};

// 如果有签署服务器的 CA 证书,则加载它
if let Some(cert) = server_ca {
let mut buf = Cursor::new(cert);
config.root_store.add_pem_file(&mut buf).unwrap();
}

Ok(Self {
config: Arc::new(config),
domain: Arc::new(domain.into()),
})
}

/// 触发 TLS 协议,把底层的 stream 转换成 TLS stream
pub async fn connect<S>(&self, stream: S) -> Result<ClientTlsStream<S>, KvError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())
.map_err(|_| KvError::Internal("Invalid DNS name".into()))?;

let stream = TlsConnector::from(self.config.clone())
.connect(dns, stream)
.await?;

Ok(stream)
}
}

impl TlsServerAcceptor {
/// 加载 server cert / CA cert,生成 ServerConfig
pub fn new(cert: &str, key: &str, client_ca: Option<&str>) -> Result<Self, KvError> {
let certs = load_certs(cert)?;
let key = load_key(key)?;

let mut config = match client_ca {
None => ServerConfig::new(NoClientAuth::new()),
Some(cert) => {
// 如果客户端证书是某个 CA 证书签发的,则把这个 CA 证书加载到信任链中
let mut cert = Cursor::new(cert);
let mut client_root_cert_store = RootCertStore::empty();
client_root_cert_store
.add_pem_file(&mut cert)
.map_err(|_| KvError::CertifcateParseError("CA", "cert"))?;

let client_auth = AllowAnyAuthenticatedClient::new(client_root_cert_store);
ServerConfig::new(client_auth)
}
};

// 加载服务器证书
config
.set_single_cert(certs, key)
.map_err(|_| KvError::CertifcateParseError("server", "cert"))?;
config.set_protocols(&[Vec::from(&ALPN_KV[..])]);

Ok(Self {
inner: Arc::new(config),
})
}

/// 触发 TLS 协议,把底层的 stream 转换成 TLS stream
pub async fn accept<S>(&self, stream: S) -> Result<ServerTlsStream<S>, KvError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
let acceptor = TlsAcceptor::from(self.inner.clone());
Ok(acceptor.accept(stream).await?)
}
}

fn load_certs(cert: &str) -> Result<Vec<Certificate>, KvError> {
let mut cert = Cursor::new(cert);
pemfile::certs(&mut cert).map_err(|_| KvError::CertifcateParseError("server", "cert"))
}

fn load_key(key: &str) -> Result<PrivateKey, KvError> {
let mut cursor = Cursor::new(key);

// 先尝试用 PKCS8 加载私钥
if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) {
if !keys.is_empty() {
return Ok(keys.remove(0));
}
}

// 再尝试加载 RSA key
cursor.set_position(0);
if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) {
if !keys.is_empty() {
return Ok(keys.remove(0));
}
}

// 不支持的私钥类型
Err(KvError::CertifcateParseError("private", "key"))
}

根据提供的证书,来生成 tokio-tls 需要的 ServerConfig / ClientConfig。

因为 TLS 需要验证证书的 CA,所以还需要加载 CA 证书。虽然平时在做 Web 开发时,经常只使用服务器证书,但其实 TLS 支持双向验证,服务器也可以验证客户端的证书是否是它认识的 CA 签发的。

测试三种情况:

  1. 标准的 TLS 连接
  2. 带有客户端证书的 TLS 连接
  3. 客户端提供了错的域名
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#[cfg(test)]
mod tests {

use std::net::SocketAddr;

use super::*;
use anyhow::Result;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
};

// 在编译期把文件加载成字符串放在 RODATA 段
const CA_CERT: &str = include_str!("ca.cert");
const CLIENT_CERT: &str = include_str!("client.cert");
const CLIENT_KEY: &str = include_str!("client.key");
const SERVER_CERT: &str = include_str!("server.cert");
const SERVER_KEY: &str = include_str!("server.key");

#[tokio::test]
async fn tls_should_work() -> Result<()> {
let ca = Some(CA_CERT);

let addr = start_server(None).await?;

let connector = TlsClientConnector::new("right.com", None, ca)?;
let stream = TcpStream::connect(addr).await?;
let mut stream = connector.connect(stream).await?;
stream.write_all(b"hello world!").await?;
let mut buf = [0; 12];
stream.read_exact(&mut buf).await?;
assert_eq!(&buf, b"hello world!");

Ok(())
}

#[tokio::test]
async fn tls_with_client_cert_should_work() -> Result<()> {
let client_identity = Some((CLIENT_CERT, CLIENT_KEY));
let ca = Some(CA_CERT);

let addr = start_server(ca.clone()).await?;

let connector = TlsClientConnector::new("right.com", client_identity, ca)?;
let stream = TcpStream::connect(addr).await?;
let mut stream = connector.connect(stream).await?;
stream.write_all(b"hello world!").await?;
let mut buf = [0; 12];
stream.read_exact(&mut buf).await?;
assert_eq!(&buf, b"hello world!");

Ok(())
}

#[tokio::test]
async fn tls_with_bad_domain_should_not_work() -> Result<()> {
let addr = start_server(None).await?;

let connector = TlsClientConnector::new("wrong.com", None, Some(CA_CERT))?;
let stream = TcpStream::connect(addr).await?;
let result = connector.connect(stream).await;

assert!(result.is_err());

Ok(())
}

async fn start_server(ca: Option<&str>) -> Result<SocketAddr> {
let acceptor = TlsServerAcceptor::new(SERVER_CERT, SERVER_KEY, ca)?;

let echo = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = echo.local_addr().unwrap();

tokio::spawn(async move {
let (stream, _) = echo.accept().await.unwrap();
let mut stream = acceptor.accept(stream).await.unwrap();
let mut buf = [0; 12];
stream.read_exact(&mut buf).await.unwrap();
stream.write_all(&buf).await.unwrap();
});

Ok(addr)
}
}

其它选择

Noise Protocol Framework,可使用的工具 snow


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!