fix server deadlock and realize client

This commit is contained in:
MeexReay 2025-06-20 18:37:58 +03:00
parent d273b32eee
commit 364f63a75a
5 changed files with 257 additions and 74 deletions

View File

@ -166,6 +166,12 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "home"
version = "0.5.11"
@ -256,6 +262,16 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "num_cpus"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "once_cell"
version = "1.21.3"
@ -417,6 +433,15 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "threadpool"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa"
dependencies = [
"num_cpus",
]
[[package]]
name = "unicode-ident"
version = "1.0.18"
@ -428,6 +453,7 @@ name = "unrknize-client"
version = "0.1.0"
dependencies = [
"rustls",
"threadpool",
]
[[package]]

View File

@ -5,3 +5,4 @@ edition = "2024"
[dependencies]
rustls = "0.23.28"
threadpool = "1.8.1"

View File

@ -1,6 +1,19 @@
use std::{env, error::Error, io::{Read, Write}, net::TcpStream, sync::Arc};
use std::{
env,
error::Error,
io::{ErrorKind, Read, Write},
net::{Shutdown, TcpListener, TcpStream},
sync::{Arc, Mutex},
thread,
};
use rustls::{client::{danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}}, pki_types::{CertificateDer, ServerName, UnixTime}, ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme};
use rustls::{
ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme,
StreamOwned,
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
pki_types::{CertificateDer, ServerName, UnixTime},
};
use threadpool::ThreadPool;
#[derive(Debug)]
pub struct NoCertVerify;
@ -54,20 +67,102 @@ impl ServerCertVerifier for NoCertVerify {
}
}
fn accept_client(
stream: Arc<TcpStream>,
config: Arc<ClientConfig>,
sni_domain: String,
remote_stream: Arc<TcpStream>,
) -> Result<(), Box<dyn Error>> {
stream.set_nonblocking(true)?;
remote_stream.set_nonblocking(true)?;
let mut tls = StreamOwned::new(
ClientConnection::new(config, sni_domain.try_into()?)?,
remote_stream.try_clone()?,
);
while tls.conn.is_handshaking() {
match tls.conn.complete_io(&mut tls.sock) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Ok(());
}
}
}
}
let tls = Arc::new(Mutex::new(tls));
thread::spawn({
let tls = tls.clone();
let stream = stream.clone();
move || {
loop {
let mut buffer = vec![0; 4096];
match (&*stream).read(&mut buffer) {
Ok(0) => { break },
Ok(n) => {
println!("from remote {n}");
buffer.truncate(n);
match tls.lock().unwrap().write(&buffer) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
break;
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
break;
}
}
}
}
}
});
let tls = tls.clone();
let stream = stream.clone();
loop {
let mut buffer = vec![0; 4096];
match tls.lock().unwrap().read(&mut buffer) {
Ok(0) => { break },
Ok(n) => {
println!("from tls {n}");
buffer.truncate(n);
match (&*stream).write(&buffer) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
}
}
Ok(())
}
fn main() -> Result<(), Box<dyn Error>> {
let mut args = env::args();
args.next();
let sni_domain = args
.next()
.expect("missing sni domain argument");
let host = args
.next()
.expect("missing host argument");
let local_host = args
.next()
.expect("missing local host argument");
let sni_domain = args.next().expect("missing sni domain argument");
let host = args.next().expect("missing host argument");
let local_host = args.next().expect("missing local host argument");
let mut config = ClientConfig::builder()
.with_root_certificates(RootCertStore::empty())
@ -76,12 +171,36 @@ fn main() -> Result<(), Box<dyn Error>> {
let verifier = Arc::new(NoCertVerify);
config.dangerous().set_certificate_verifier(verifier);
let mut conn = ClientConnection::new(Arc::new(config), sni_domain.try_into()?)?;
let mut sock = TcpStream::connect(host).unwrap();
let mut tls = rustls::Stream::new(&mut conn, &mut sock);
let config = Arc::new(config);
tls.write_all(b"hello")?;
tls.read_to_end(&mut vec![])?;
let threadpool = ThreadPool::new(10);
let listener = TcpListener::bind(local_host).unwrap();
for stream in listener.incoming() {
let stream = stream.expect("listener got broken");
threadpool.execute({
let config = config.clone();
let sni_domain = sni_domain.clone();
let host = host.clone();
move || {
let Ok(remote_stream) = TcpStream::connect(host) else {
return;
};
let stream = Arc::new(stream);
let remote_stream = Arc::new(remote_stream);
if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) {
println!("error connection: {e:?}");
}
let _ = (&*stream).shutdown(Shutdown::Both);
}
});
}
Ok(())
}

View File

@ -5,5 +5,5 @@ edition = "2024"
[dependencies]
rcgen = "0.13.2"
rustls = { version = "0.23.28", features = ["std"] }
rustls = "0.23.28"
threadpool = "1.8.1"

View File

@ -1,71 +1,105 @@
use std::io::{Read, Write};
use std::{env, thread};
use std::error::Error;
use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, TcpListener, TcpStream};
use std::sync::{Arc, Mutex};
use std::{env, thread};
use rcgen::generate_simple_self_signed;
use rustls::crypto::aws_lc_rs::sign::any_supported_type;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::PrivateKeyDer;
use rustls::pki_types::pem::PemObject;
use rustls::server::ResolvesServerCertUsingSni;
use rustls::sign::CertifiedKey;
use rustls::{ServerConfig, ServerConnection, Stream};
use rustls::{ServerConfig, ServerConnection, StreamOwned};
use threadpool::ThreadPool;
fn accept_client(stream: &mut TcpStream, config: Arc<ServerConfig>, sni_domain: String, host: String) -> Result<(), Box<dyn Error>> {
let remote_stream = Arc::new(TcpStream::connect(host)?);
fn accept_client(
stream: Arc<TcpStream>,
config: Arc<ServerConfig>,
sni_domain: String,
remote_stream: Arc<TcpStream>,
) -> Result<(), Box<dyn Error>> {
stream.set_nonblocking(true)?;
remote_stream.set_nonblocking(true)?;
let connection = ServerConnection::new(config)?;
let connection = Arc::new(Mutex::new(connection));
let mut tls = StreamOwned::new(ServerConnection::new(config)?, stream.try_clone()?);
{
let mut conn = connection.lock().unwrap();
let mut tls_stream = Stream::new(&mut *conn, stream);
while tls_stream.conn.is_handshaking() {
tls_stream.conn.complete_io(&mut tls_stream.sock)?;
}
if tls_stream.conn.server_name() != Some(&sni_domain) {
return Err("unexpected sni domain".into());
while tls.conn.is_handshaking() {
match tls.conn.complete_io(&mut tls.sock) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
}
}
if tls.conn.server_name() != Some(&sni_domain) {
return Err("unexpected sni domain".into());
}
let tls = Arc::new(Mutex::new(tls));
thread::spawn({
let mut stream = stream.try_clone()?;
let connection = connection.clone();
let tls = tls.clone();
let remote_stream = remote_stream.clone();
move || {
loop {
let mut buffer = vec![0; 4096];
let n = (&*remote_stream).read(&mut buffer).unwrap();
if n != 0 {
buffer.truncate(n);
let mut conn = connection.lock().unwrap();
let mut tls_stream = Stream::new(&mut *conn, &mut stream);
tls_stream.write_all(&buffer[..n]).unwrap();
tls_stream.flush().unwrap();
match (&*remote_stream).read(&mut buffer) {
Ok(0) => { break },
Ok(n) => {
println!("from remote {n}");
buffer.truncate(n);
match tls.lock().unwrap().write(&buffer) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
break;
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
break;
}
}
}
}
}
});
let mut stream = stream.try_clone()?;
let connection = connection.clone();
let tls = tls.clone();
let remote_stream = remote_stream.clone();
loop {
let mut conn = connection.lock().unwrap();
let mut tls_stream = Stream::new(&mut *conn, &mut stream);
let mut buffer = vec![0; 4096];
let n = tls_stream.read(&mut buffer)?;
if n != 0 {
buffer.truncate(n);
(&*remote_stream).write(&buffer)?;
match tls.lock().unwrap().read(&mut buffer) {
Ok(0) => { break },
Ok(n) => {
println!("from tls {n}");
buffer.truncate(n);
match (&*remote_stream).write(&buffer) {
Ok(_) => {},
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
}
}
Err(e) => {
if e.kind() != ErrorKind::WouldBlock {
return Err(e.into());
}
}
}
}
Ok(())
}
fn main() -> Result<(), Box<dyn Error>> {
@ -73,37 +107,33 @@ fn main() -> Result<(), Box<dyn Error>> {
args.next();
let sni_domain = args
.next()
.expect("missing sni domain argument");
let host = args
.next()
.expect("missing host argument");
let local_host = args
.next()
.expect("missing local host argument");
let sni_domain = args.next().expect("missing sni domain argument");
let host = args.next().expect("missing host argument");
let local_host = args.next().expect("missing local host argument");
let certified_key = generate_simple_self_signed(vec![sni_domain.to_string()]).unwrap();
let certified_key = CertifiedKey::new(
vec![certified_key.cert.into()],
any_supported_type(
&PrivateKeyDer::from_pem_slice(certified_key.key_pair.serialize_pem().as_bytes())?
)?
any_supported_type(&PrivateKeyDer::from_pem_slice(
certified_key.key_pair.serialize_pem().as_bytes(),
)?)?,
);
let mut resolver = ResolvesServerCertUsingSni::new();
resolver.add(&sni_domain, certified_key)?;
let config = Arc::new(ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver)));
let config = Arc::new(
ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver)),
);
let threadpool = ThreadPool::new(10);
let listener = TcpListener::bind(local_host).unwrap();
for stream in listener.incoming() {
let mut stream = stream.expect("listener got broken");
let stream = stream.expect("listener got broken");
threadpool.execute({
let config = config.clone();
@ -111,11 +141,18 @@ fn main() -> Result<(), Box<dyn Error>> {
let host = host.clone();
move || {
if let Err(e) = accept_client(&mut stream, config, sni_domain, host) {
let Ok(remote_stream) = TcpStream::connect(host) else {
return;
};
let stream = Arc::new(stream);
let remote_stream = Arc::new(remote_stream);
if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) {
println!("error connection: {e:?}");
}
let _ = stream.shutdown(Shutdown::Both);
let _ = (&*stream).shutdown(Shutdown::Both);
}
});
}