async rustls rewrite

This commit is contained in:
MeexReay 2025-04-07 00:15:10 +03:00
parent 4a0c00d421
commit 37c6122f87
9 changed files with 870 additions and 1274 deletions

1247
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -4,11 +4,13 @@ version = "0.1.2"
edition = "2021"
[dependencies]
openssl = "0.10.72"
tokio = { version = "1.44.2", features = ["full"] }
tokio-io-timeout = "1.2.0"
tokio-rustls = "0.26.2"
rustls = "0.23.25"
wildmatch = "2.4.0"
serde_yml = "0.0.12"
serde_json = "1.0.140"
log = "0.4.27"
colog = "1.3.0"
threadpool = "1.8.1"
wildcard_ex = "0.1.2"
websocket = "0.27.1"
serde_json = "1.0.140"
ignore-result = "0.2.0"

8
shell.nix Executable file
View File

@ -0,0 +1,8 @@
with import <nixpkgs> { };
mkShell {
nativeBuildInputs = [
openssl
pkg-config
];
}

View File

@ -1,3 +1,3 @@
pub mod config;
pub mod server;
pub mod ssl_cert;
pub mod tls;

View File

@ -1,15 +1,17 @@
use std::{fs, net::TcpStream, time::Duration};
use std::{fs, time::Duration};
use tokio::net::TcpStream;
use serde_yml::{Number, Value};
use wildcard_ex::is_match_simple;
use wildmatch::WildMatch;
use super::ssl_cert::SslCert;
use super::tls::TlsCertificate;
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct SiteConfig {
pub domain: String,
pub domain: WildMatch,
pub host: String,
pub ssl: Option<SslCert>,
pub ssl: Option<TlsCertificate>,
pub enable_keep_alive: bool,
pub support_keep_alive: bool,
pub ip_forwarding: IpForwarding,
@ -17,12 +19,12 @@ pub struct SiteConfig {
}
impl SiteConfig {
pub fn connect(&self) -> Option<TcpStream> {
TcpStream::connect(self.host.clone()).ok()
pub async fn connect(&self) -> Option<TcpStream> {
TcpStream::connect(self.host.clone()).await.ok()
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum IpForwarding {
Simple,
Header(String),
@ -46,7 +48,7 @@ impl IpForwarding {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Config {
pub sites: Vec<SiteConfig>,
pub http_host: String,
@ -80,12 +82,12 @@ impl Config {
let sites_yaml = doc["sites"].as_sequence()?;
for s in sites_yaml {
let mut cert: Option<SslCert> = None;
let mut cert: Option<TlsCertificate> = None;
let s = s.as_mapping()?;
if s.contains_key("ssl_cert") {
cert = Some(
SslCert::new(
TlsCertificate::new(
s.get("ssl_cert")?.as_str()?,
s.get("ssl_key")?.as_str()?,
)?,
@ -93,7 +95,7 @@ impl Config {
}
let site = SiteConfig {
domain: s.get("domain")?.as_str()?.to_string(),
domain: WildMatch::new(&s.get("domain")?.as_str()?.to_string()),
host: s.get("host")?.as_str()?.to_string(),
ssl: cert,
enable_keep_alive: s.get("enable_keep_alive")
@ -126,7 +128,7 @@ impl Config {
pub fn get_site(&self, domain: &str) -> Option<&SiteConfig> {
for i in &self.sites {
if is_match_simple(&i.domain, domain) {
if i.domain.matches(domain) {
return Some(i);
}
}

View File

@ -1,30 +1,23 @@
use std::{
io::{
Read,
Write
},
net::{
IpAddr,
Ipv4Addr,
Ipv6Addr,
SocketAddr,
SocketAddrV4,
SocketAddrV6,
TcpListener,
TcpStream
},
error::Error,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
str::FromStr,
sync::{
Arc,
RwLock
},
thread,
time::Duration
sync::Arc
};
use tokio::sync::RwLock;
use ignore_result::Ignore;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream}
};
use log::info;
use openssl::ssl::SslStream;
use threadpool::ThreadPool;
use tokio_io_timeout::TimeoutStream;
use tokio_rustls::TlsAcceptor;
use crate::tls::create_server_config;
use super::config::{
Config,
@ -40,15 +33,9 @@ pub trait Closeable {
fn close(&mut self);
}
impl Closeable for SslStream<TcpStream> {
fn close(&mut self) {
let _ = self.shutdown();
}
}
impl Closeable for TcpStream {
fn close(&mut self) {
let _ = self.shutdown(std::net::Shutdown::Both);
let _ = self.shutdown();
}
}
@ -64,157 +51,132 @@ impl FlowgateServer {
FlowgateServer { config }
}
pub fn start(&self) {
let pool = ThreadPool::new(self.config.read().unwrap().threadpool_size);
let pool = Arc::new(pool);
pub async fn start(&self) {
tokio::spawn({
let config = self.config.clone();
thread::spawn({
let config = Arc::clone(&self.config);
let pool = Arc::clone(&pool);
move || {
Self::run_http(config, pool)
async move {
Self::run_http(config).await.ignore();
}
});
thread::spawn({
let config = Arc::clone(&self.config);
let pool = Arc::clone(&pool);
tokio::spawn({
let config = self.config.clone();
move || {
Self::run_https(config, pool)
async move {
Self::run_https(config).await.ignore();
}
});
}
pub fn run_http(
config: Arc<RwLock<Config>>,
pool: Arc<ThreadPool>
) -> Option<()> {
let listener = TcpListener::bind(&config.read().ok()?.http_host).ok()?;
pub async fn run_http(
config: Arc<RwLock<Config>>
) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(&config.read().await.http_host).await?;
info!("HTTP server runned on {}", &config.read().ok()?.http_host);
info!("HTTP server runned on {}", &config.read().await.http_host);
loop {
let Ok((stream, addr)) = listener.accept().await else { break };
for stream in listener.incoming() {
pool.execute({
let config = config.clone();
move || {
let Ok(mut stream) = stream else { return };
tokio::spawn(async move {
let mut stream = TimeoutStream::new(stream);
let Ok(_) = stream.set_write_timeout(Some(Duration::from_secs(10))) else { return };
let Ok(_) = stream.set_read_timeout(Some(Duration::from_secs(10))) else { return };
stream.set_write_timeout(Some(config.read().await.connection_timeout));
stream.set_read_timeout(Some(config.read().await.connection_timeout));
let Ok(addr) = stream.peer_addr() else { return };
let mut stream = Box::pin(stream);
Self::accept_stream(
config,
&mut stream,
addr,
false
);
}
).await;
});
}
Some(())
Ok(())
}
pub fn run_https(
config: Arc<RwLock<Config>>,
pool: Arc<ThreadPool>
) -> Option<()> {
use openssl::ssl::{NameType, SniError, SslAcceptor, SslAlert, SslMethod, SslRef};
pub async fn run_https(
config: Arc<RwLock<Config>>
) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(&config.read().await.https_host).await?;
let acceptor = TlsAcceptor::from(Arc::new(create_server_config(config.clone()).await));
let listener = TcpListener::bind(&config.read().ok()?.https_host).ok()?;
info!("HTTPS server runned on {}", &config.read().await.http_host);
let mut cert = SslAcceptor::mozilla_intermediate(SslMethod::tls()).ok()?;
loop {
let Ok((stream, addr)) = listener.accept().await else { break };
cert.set_servername_callback(Box::new({
let config = config.clone();
let acceptor = acceptor.clone();
move |ssl: &mut SslRef, _: &mut SslAlert| -> Result<(), SniError> {
let servname = ssl.servername(NameType::HOST_NAME).ok_or(SniError::NOACK)?;
let c = config.read().unwrap();
let cert = c.get_site(servname).ok_or(SniError::NOACK)?;
ssl.set_ssl_context(&cert.ssl.as_ref().ok_or(SniError::NOACK)?.get_context()).ok().ok_or(SniError::NOACK)
}
}
));
tokio::spawn(async move {
let mut stream = TimeoutStream::new(stream);
let cert = cert.build();
stream.set_write_timeout(Some(config.read().await.connection_timeout));
stream.set_read_timeout(Some(config.read().await.connection_timeout));
info!("HTTPS server runned on {}", &config.read().ok()?.https_host);
for stream in listener.incoming() {
pool.execute({
let config = config.clone();
let cert = cert.clone();
move || {
let Ok(stream) = stream else { return };
let Ok(_) = stream.set_write_timeout(Some(config.read().unwrap().connection_timeout)) else { return };
let Ok(_) = stream.set_read_timeout(Some(config.read().unwrap().connection_timeout)) else { return };
let Ok(addr) = stream.peer_addr() else { return };
let Ok(mut stream) = cert.accept(stream) else { return };
let Ok(mut stream) = acceptor.accept(Box::pin(stream)).await else { return };
Self::accept_stream(
config,
&mut stream,
addr,
true
);
}
false
).await;
});
}
Some(())
Ok(())
}
pub fn accept_stream(
pub async fn accept_stream(
config: Arc<RwLock<Config>>,
stream: &mut (impl Read + Write + Closeable),
stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin),
addr: SocketAddr,
https: bool
) -> Option<()> {
let mut conn = Self::read_request(config.clone(), stream, addr, https, None)?;
let mut conn = read_request(config.clone(), stream, addr, https, None).await?;
if conn.keep_alive && conn.config.enable_keep_alive {
loop {
if !conn.config.support_keep_alive {
conn.stream.close();
conn.stream = conn.config.connect()?;
conn.stream = conn.config.connect().await?;
}
conn = Self::read_request(config.clone(), stream, addr, https, Some(conn))?;
conn = read_request(config.clone(), stream, addr, https, Some(conn)).await?;
}
}
conn.stream.close();
stream.close();
stream.shutdown().await.ok()?;
Some(())
}
}
fn read_request(
async fn read_request(
config: Arc<RwLock<Config>>,
stream: &mut (impl Read + Write + Closeable),
stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin),
addr: SocketAddr,
https: bool,
conn: Option<Connection>
) -> Option<Connection> {
) -> Option<Connection> {
let mut addr = addr;
match &config.read().ok()?.incoming_ip_forwarding {
match &config.read().await.incoming_ip_forwarding {
IpForwarding::Simple => {
let mut header = Vec::new();
{
let mut buf = [0; 1];
while let Ok(1) = stream.read(&mut buf) {
while let Ok(1) = stream.read(&mut buf).await {
let byte = buf[0];
if byte == b'\n' { break }
header.push(byte);
@ -225,20 +187,20 @@ impl FlowgateServer {
},
IpForwarding::Modern => {
let mut ipver = [0; 1];
stream.read(&mut ipver).ok()?;
stream.read(&mut ipver).await.ok()?;
addr = match ipver[0] {
0x01 => {
let mut octets = [0; 4];
stream.read(&mut octets).ok()?;
stream.read(&mut octets).await.ok()?;
let mut port = [0; 2];
stream.read(&mut port).ok()?;
stream.read(&mut port).await.ok()?;
let port = u16::from_be_bytes(port);
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
}, 0x02 => {
let mut octets = [0; 16];
stream.read(&mut octets).ok()?;
stream.read(&mut octets).await.ok()?;
let mut port = [0; 2];
stream.read(&mut port).ok()?;
stream.read(&mut port).await.ok()?;
let port = u16::from_be_bytes(port);
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
}, _ => { return None },
@ -253,7 +215,7 @@ impl FlowgateServer {
let mut buf = [0; 1];
let mut counter = 0;
while let Ok(1) = stream.read(&mut buf) {
while let Ok(1) = stream.read(&mut buf).await {
let byte = buf[0];
head.push(byte);
@ -290,7 +252,7 @@ impl FlowgateServer {
.map(|o| o.contains(&"chunked".to_string()))
.unwrap_or(false);
if let IpForwarding::Header(header) = &config.read().ok()?.incoming_ip_forwarding {
if let IpForwarding::Header(header) = &config.read().await.incoming_ip_forwarding {
if let Some(ip) = headers.iter().find(|o| o.0 == header).map(|o| o.1) {
addr = SocketAddr::from_str(ip).ok()?;
}
@ -308,10 +270,10 @@ impl FlowgateServer {
}
}
let site = config.read().ok()?.get_site(&host)?.clone();
let site = config.read().await.get_site(&host)?.clone();
Connection {
stream: site.connect()?,
stream: site.connect().await?,
config: site,
keep_alive,
host
@ -398,28 +360,27 @@ impl FlowgateServer {
}
}
conn.stream.write_all(&reqbuf).ok()?;
conn.stream.write_all(&reqbuf).await.ok()?;
if content_length > 0 {
let mut read = 0usize;
let mut buf = vec![0; 4096];
while let Ok(size) = stream.read(&mut buf) {
while let Ok(size) = stream.read(&mut buf).await {
if size == 0 { break }
read += size;
buf.truncate(size);
conn.stream.write_all(&buf).ok()?;
conn.stream.write_all(&buf).await.ok()?;
buf = vec![0; 4096];
if read >= content_length { break }
}
} else if is_chunked {
loop {
let mut length = Vec::new();
{
let mut buf = [0; 1];
let mut counter = 0;
while let Ok(1) = stream.read(&mut buf) {
while let Ok(1) = stream.read(&mut buf).await {
let byte = buf[0];
length.push(byte);
@ -429,16 +390,16 @@ impl FlowgateServer {
_ => 0,
};
}
conn.stream.write_all(&length).ok()?;
conn.stream.write_all(&length).await.ok()?;
length.truncate(length.len() - 2);
}
let length = String::from_utf8(length).ok()?;
let length = usize::from_str_radix(length.as_str(), 16).ok()?;
let mut data = vec![0u8; length+2];
stream.read_exact(&mut data).ok()?;
stream.read_exact(&mut data).await.ok()?;
conn.stream.write_all(&data).ok()?;
conn.stream.write_all(&data).await.ok()?;
if length == 0 {
break;
}
@ -452,11 +413,11 @@ impl FlowgateServer {
let mut buf = [0; 1];
let mut counter = 0;
while let Ok(1) = conn.stream.read(&mut buf) {
while let Ok(1) = conn.stream.read(&mut buf).await {
let byte = buf[0];
head.push(byte);
stream.write_all(&buf).ok()?;
stream.write_all(&buf).await.ok()?;
counter = match (counter, byte) {
(0, b'\r') => 1,
@ -497,11 +458,11 @@ impl FlowgateServer {
if content_length > 0 {
let mut read = 0usize;
let mut buf = vec![0; 4096];
while let Ok(size) = conn.stream.read(&mut buf) {
while let Ok(size) = conn.stream.read(&mut buf).await {
if size == 0 { break }
read += size;
buf.truncate(size);
stream.write_all(&buf).ok()?;
stream.write_all(&buf).await.ok()?;
buf = vec![0; 4096];
if read == content_length { break }
}
@ -512,7 +473,7 @@ impl FlowgateServer {
let mut buf = [0; 1];
let mut counter = 0;
while let Ok(1) = conn.stream.read(&mut buf) {
while let Ok(1) = conn.stream.read(&mut buf).await {
let byte = buf[0];
length.push(byte);
@ -522,16 +483,16 @@ impl FlowgateServer {
_ => 0,
};
}
stream.write_all(&length).ok()?;
stream.write_all(&length).await.ok()?;
length.truncate(length.len() - 2);
}
let length = String::from_utf8(length).ok()?;
let length = usize::from_str_radix(length.as_str(), 16).ok()?;
let mut data = vec![0u8; length+2];
conn.stream.read_exact(&mut data).ok()?;
conn.stream.read_exact(&mut data).await.ok()?;
stream.write_all(&data).ok()?;
stream.write_all(&data).await.ok()?;
if length == 0 {
break;
}
@ -539,10 +500,10 @@ impl FlowgateServer {
}
} else {
let mut buf = vec![0;1024];
while let Ok(n) = conn.stream.read(&mut buf) {
while let Ok(n) = conn.stream.read(&mut buf).await {
if n == 0 { break }
buf.truncate(n);
stream.write_all(&buf).ok()?;
stream.write_all(&buf).await.ok()?;
buf = vec![0;1024];
}
}
@ -550,5 +511,4 @@ impl FlowgateServer {
info!("{addr} > {} {}://{}{}", status_seq[0], if https { "https" } else { "http" }, conn.host, status_seq[1]);
Some(conn)
}
}

View File

@ -1,28 +0,0 @@
use openssl::ssl::SslContext;
#[derive(Clone)]
pub struct SslCert {
context: SslContext,
}
fn generate_ctx(cert_file: &str, key_file: &str) -> Option<SslContext> {
use openssl::ssl::{SslFiletype, SslMethod};
let mut ctx = SslContext::builder(SslMethod::tls()).ok().unwrap();
ctx.set_private_key_file(&key_file, SslFiletype::PEM).ok().unwrap();
ctx.set_certificate_file(&cert_file, SslFiletype::PEM).ok().unwrap();
ctx.check_private_key().ok()?;
Some(ctx.build())
}
impl SslCert {
pub fn new(cert_file: &str, key_file: &str) -> Option<SslCert> {
Some(SslCert {
context: generate_ctx(cert_file, key_file)?
})
}
pub fn get_context(&self) -> SslContext {
self.context.clone()
}
}

66
src/flowgate/tls.rs Executable file
View File

@ -0,0 +1,66 @@
use std::sync::Arc;
use rustls::{
crypto::aws_lc_rs::sign::any_supported_type,
pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer},
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
ServerConfig
};
use tokio::{runtime::Handle, sync::RwLock};
use super::config::Config;
#[derive(Clone, Debug)]
pub struct TlsCertificate {
key: CertifiedKey
}
impl TlsCertificate {
pub fn new(cert_file: &str, key_file: &str) -> Option<TlsCertificate> {
let certs = CertificateDer::pem_file_iter(cert_file)
.unwrap()
.map(|cert| cert.unwrap())
.collect();
let private_key = PrivateKeyDer::from_pem_file(key_file).unwrap();
let key = CertifiedKey::new(certs, any_supported_type(&private_key).ok()?);
Some(Self { key })
}
pub fn get_key(&self) -> CertifiedKey {
self.key.clone()
}
}
#[derive(Debug)]
struct ResolvesServerCertWildcard {
config: Arc<RwLock<Config>>,
handle: Handle
}
impl ResolvesServerCertWildcard {
pub async fn new(config: Arc<RwLock<Config>>) -> Self {
Self { config, handle: Handle::current() }
}
}
impl ResolvesServerCert for ResolvesServerCertWildcard {
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
if let Some(cert) = client_hello.server_name()
.and_then(|name| self.handle.block_on(self.config.read()).get_site(name).cloned())
.and_then(|site| site.ssl) {
Some(Arc::new(cert.get_key()))
} else {
None
}
}
}
pub async fn create_server_config(config: Arc<RwLock<Config>>) -> ServerConfig {
ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(ResolvesServerCertWildcard::new(config).await))
}

View File

@ -1,16 +1,19 @@
use std::{fs, path::Path, sync::{Arc, RwLock}};
use std::{fs, path::Path, sync::Arc};
use flowgate::{config::Config, server::FlowgateServer};
use ignore_result::Ignore;
use tokio::sync::RwLock;
fn main() {
#[tokio::main]
async fn main() {
colog::init();
if !Path::new("conf.yml").exists() {
let _ = fs::write("conf.yml", include_bytes!("../conf.yml"));
fs::write("conf.yml", include_bytes!("../conf.yml")).ignore();
}
let config = Arc::new(RwLock::new(Config::parse("conf.yml").unwrap()));
let server = FlowgateServer::new(config.clone());
server.start();
server.start().await;
}