fix server deadlock and realize client

This commit is contained in:
MeexReay 2025-06-20 18:37:58 +03:00
parent d273b32eee
commit 364f63a75a
5 changed files with 257 additions and 74 deletions

View File

@ -166,6 +166,12 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]] [[package]]
name = "home" name = "home"
version = "0.5.11" version = "0.5.11"
@ -256,6 +262,16 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "num_cpus"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
dependencies = [
"hermit-abi",
"libc",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.21.3" version = "1.21.3"
@ -417,6 +433,15 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "threadpool"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa"
dependencies = [
"num_cpus",
]
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.18" version = "1.0.18"
@ -428,6 +453,7 @@ name = "unrknize-client"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"rustls", "rustls",
"threadpool",
] ]
[[package]] [[package]]

View File

@ -5,3 +5,4 @@ edition = "2024"
[dependencies] [dependencies]
rustls = "0.23.28" rustls = "0.23.28"
threadpool = "1.8.1"

View File

@ -1,6 +1,19 @@
use std::{env, error::Error, io::{Read, Write}, net::TcpStream, sync::Arc}; use std::{
env,
error::Error,
io::{ErrorKind, Read, Write},
net::{Shutdown, TcpListener, TcpStream},
sync::{Arc, Mutex},
thread,
};
use rustls::{client::{danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}}, pki_types::{CertificateDer, ServerName, UnixTime}, ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme}; use rustls::{
ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme,
StreamOwned,
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
pki_types::{CertificateDer, ServerName, UnixTime},
};
use threadpool::ThreadPool;
#[derive(Debug)] #[derive(Debug)]
pub struct NoCertVerify; pub struct NoCertVerify;
@ -54,20 +67,102 @@ impl ServerCertVerifier for NoCertVerify {
} }
} }
fn accept_client(
stream: Arc<TcpStream>,
config: Arc<ClientConfig>,
sni_domain: String,
remote_stream: Arc<TcpStream>,
) -> Result<(), Box<dyn Error>> {
stream.set_nonblocking(true)?;
remote_stream.set_nonblocking(true)?;
let mut tls = StreamOwned::new(
ClientConnection::new(config, sni_domain.try_into()?)?,
remote_stream.try_clone()?,
);
while tls.conn.is_handshaking() {
match tls.conn.complete_io(&mut tls.sock) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Ok(());
}
}
}
}
let tls = Arc::new(Mutex::new(tls));
thread::spawn({
let tls = tls.clone();
let stream = stream.clone();
move || {
loop {
let mut buffer = vec![0; 4096];
match (&*stream).read(&mut buffer) {
Ok(0) => { break },
Ok(n) => {
println!("from remote {n}");
buffer.truncate(n);
match tls.lock().unwrap().write(&buffer) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
break;
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
break;
}
}
}
}
}
});
let tls = tls.clone();
let stream = stream.clone();
loop {
let mut buffer = vec![0; 4096];
match tls.lock().unwrap().read(&mut buffer) {
Ok(0) => { break },
Ok(n) => {
println!("from tls {n}");
buffer.truncate(n);
match (&*stream).write(&buffer) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
}
}
Ok(())
}
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let mut args = env::args(); let mut args = env::args();
args.next(); args.next();
let sni_domain = args let sni_domain = args.next().expect("missing sni domain argument");
.next() let host = args.next().expect("missing host argument");
.expect("missing sni domain argument"); let local_host = args.next().expect("missing local host argument");
let host = args
.next()
.expect("missing host argument");
let local_host = args
.next()
.expect("missing local host argument");
let mut config = ClientConfig::builder() let mut config = ClientConfig::builder()
.with_root_certificates(RootCertStore::empty()) .with_root_certificates(RootCertStore::empty())
@ -76,12 +171,36 @@ fn main() -> Result<(), Box<dyn Error>> {
let verifier = Arc::new(NoCertVerify); let verifier = Arc::new(NoCertVerify);
config.dangerous().set_certificate_verifier(verifier); config.dangerous().set_certificate_verifier(verifier);
let mut conn = ClientConnection::new(Arc::new(config), sni_domain.try_into()?)?; let config = Arc::new(config);
let mut sock = TcpStream::connect(host).unwrap();
let mut tls = rustls::Stream::new(&mut conn, &mut sock);
tls.write_all(b"hello")?; let threadpool = ThreadPool::new(10);
tls.read_to_end(&mut vec![])?;
let listener = TcpListener::bind(local_host).unwrap();
for stream in listener.incoming() {
let stream = stream.expect("listener got broken");
threadpool.execute({
let config = config.clone();
let sni_domain = sni_domain.clone();
let host = host.clone();
move || {
let Ok(remote_stream) = TcpStream::connect(host) else {
return;
};
let stream = Arc::new(stream);
let remote_stream = Arc::new(remote_stream);
if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) {
println!("error connection: {e:?}");
}
let _ = (&*stream).shutdown(Shutdown::Both);
}
});
}
Ok(()) Ok(())
} }

View File

@ -5,5 +5,5 @@ edition = "2024"
[dependencies] [dependencies]
rcgen = "0.13.2" rcgen = "0.13.2"
rustls = { version = "0.23.28", features = ["std"] } rustls = "0.23.28"
threadpool = "1.8.1" threadpool = "1.8.1"

View File

@ -1,71 +1,105 @@
use std::io::{Read, Write};
use std::{env, thread};
use std::error::Error; use std::error::Error;
use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, TcpListener, TcpStream}; use std::net::{Shutdown, TcpListener, TcpStream};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::{env, thread};
use rcgen::generate_simple_self_signed; use rcgen::generate_simple_self_signed;
use rustls::crypto::aws_lc_rs::sign::any_supported_type; use rustls::crypto::aws_lc_rs::sign::any_supported_type;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::PrivateKeyDer; use rustls::pki_types::PrivateKeyDer;
use rustls::pki_types::pem::PemObject;
use rustls::server::ResolvesServerCertUsingSni; use rustls::server::ResolvesServerCertUsingSni;
use rustls::sign::CertifiedKey; use rustls::sign::CertifiedKey;
use rustls::{ServerConfig, ServerConnection, Stream}; use rustls::{ServerConfig, ServerConnection, StreamOwned};
use threadpool::ThreadPool; use threadpool::ThreadPool;
fn accept_client(stream: &mut TcpStream, config: Arc<ServerConfig>, sni_domain: String, host: String) -> Result<(), Box<dyn Error>> { fn accept_client(
let remote_stream = Arc::new(TcpStream::connect(host)?); stream: Arc<TcpStream>,
config: Arc<ServerConfig>,
sni_domain: String,
remote_stream: Arc<TcpStream>,
) -> Result<(), Box<dyn Error>> {
stream.set_nonblocking(true)?;
remote_stream.set_nonblocking(true)?;
let connection = ServerConnection::new(config)?; let mut tls = StreamOwned::new(ServerConnection::new(config)?, stream.try_clone()?);
let connection = Arc::new(Mutex::new(connection));
{ while tls.conn.is_handshaking() {
let mut conn = connection.lock().unwrap(); match tls.conn.complete_io(&mut tls.sock) {
let mut tls_stream = Stream::new(&mut *conn, stream); Ok(_) => {},
while tls_stream.conn.is_handshaking() { Err(e) => {
tls_stream.conn.complete_io(&mut tls_stream.sock)?; if e.kind() != ErrorKind::WouldBlock {
} return Err(e.into());
}
if tls_stream.conn.server_name() != Some(&sni_domain) { }
return Err("unexpected sni domain".into());
} }
} }
if tls.conn.server_name() != Some(&sni_domain) {
return Err("unexpected sni domain".into());
}
let tls = Arc::new(Mutex::new(tls));
thread::spawn({ thread::spawn({
let mut stream = stream.try_clone()?; let tls = tls.clone();
let connection = connection.clone();
let remote_stream = remote_stream.clone(); let remote_stream = remote_stream.clone();
move || { move || {
loop { loop {
let mut buffer = vec![0; 4096]; let mut buffer = vec![0; 4096];
let n = (&*remote_stream).read(&mut buffer).unwrap(); match (&*remote_stream).read(&mut buffer) {
if n != 0 { Ok(0) => { break },
buffer.truncate(n); Ok(n) => {
println!("from remote {n}");
let mut conn = connection.lock().unwrap(); buffer.truncate(n);
let mut tls_stream = Stream::new(&mut *conn, &mut stream); match tls.lock().unwrap().write(&buffer) {
tls_stream.write_all(&buffer[..n]).unwrap(); Ok(_) => {},
tls_stream.flush().unwrap(); Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
break;
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
break;
}
}
} }
} }
} }
}); });
let mut stream = stream.try_clone()?; let tls = tls.clone();
let connection = connection.clone(); let remote_stream = remote_stream.clone();
loop { loop {
let mut conn = connection.lock().unwrap();
let mut tls_stream = Stream::new(&mut *conn, &mut stream);
let mut buffer = vec![0; 4096]; let mut buffer = vec![0; 4096];
let n = tls_stream.read(&mut buffer)?; match tls.lock().unwrap().read(&mut buffer) {
if n != 0 { Ok(0) => { break },
buffer.truncate(n); Ok(n) => {
(&*remote_stream).write(&buffer)?; println!("from tls {n}");
buffer.truncate(n);
match (&*remote_stream).write(&buffer) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
} }
} }
Ok(())
} }
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
@ -73,37 +107,33 @@ fn main() -> Result<(), Box<dyn Error>> {
args.next(); args.next();
let sni_domain = args let sni_domain = args.next().expect("missing sni domain argument");
.next() let host = args.next().expect("missing host argument");
.expect("missing sni domain argument"); let local_host = args.next().expect("missing local host argument");
let host = args
.next()
.expect("missing host argument");
let local_host = args
.next()
.expect("missing local host argument");
let certified_key = generate_simple_self_signed(vec![sni_domain.to_string()]).unwrap(); let certified_key = generate_simple_self_signed(vec![sni_domain.to_string()]).unwrap();
let certified_key = CertifiedKey::new( let certified_key = CertifiedKey::new(
vec![certified_key.cert.into()], vec![certified_key.cert.into()],
any_supported_type( any_supported_type(&PrivateKeyDer::from_pem_slice(
&PrivateKeyDer::from_pem_slice(certified_key.key_pair.serialize_pem().as_bytes())? certified_key.key_pair.serialize_pem().as_bytes(),
)? )?)?,
); );
let mut resolver = ResolvesServerCertUsingSni::new(); let mut resolver = ResolvesServerCertUsingSni::new();
resolver.add(&sni_domain, certified_key)?; resolver.add(&sni_domain, certified_key)?;
let config = Arc::new(ServerConfig::builder() let config = Arc::new(
.with_no_client_auth() ServerConfig::builder()
.with_cert_resolver(Arc::new(resolver))); .with_no_client_auth()
.with_cert_resolver(Arc::new(resolver)),
);
let threadpool = ThreadPool::new(10); let threadpool = ThreadPool::new(10);
let listener = TcpListener::bind(local_host).unwrap(); let listener = TcpListener::bind(local_host).unwrap();
for stream in listener.incoming() { for stream in listener.incoming() {
let mut stream = stream.expect("listener got broken"); let stream = stream.expect("listener got broken");
threadpool.execute({ threadpool.execute({
let config = config.clone(); let config = config.clone();
@ -111,11 +141,18 @@ fn main() -> Result<(), Box<dyn Error>> {
let host = host.clone(); let host = host.clone();
move || { move || {
if let Err(e) = accept_client(&mut stream, config, sni_domain, host) { let Ok(remote_stream) = TcpStream::connect(host) else {
return;
};
let stream = Arc::new(stream);
let remote_stream = Arc::new(remote_stream);
if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) {
println!("error connection: {e:?}"); println!("error connection: {e:?}");
} }
let _ = stream.shutdown(Shutdown::Both); let _ = (&*stream).shutdown(Shutdown::Both);
} }
}); });
} }