fix server deadlock and realize client
This commit is contained in:
parent
d273b32eee
commit
364f63a75a
26
unrknize-client/Cargo.lock
generated
26
unrknize-client/Cargo.lock
generated
@ -166,6 +166,12 @@ version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
||||
|
||||
[[package]]
|
||||
name = "home"
|
||||
version = "0.5.11"
|
||||
@ -256,6 +262,16 @@ dependencies = [
|
||||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
@ -417,6 +433,15 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "threadpool"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa"
|
||||
dependencies = [
|
||||
"num_cpus",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.18"
|
||||
@ -428,6 +453,7 @@ name = "unrknize-client"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"threadpool",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -5,3 +5,4 @@ edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
rustls = "0.23.28"
|
||||
threadpool = "1.8.1"
|
||||
|
@ -1,6 +1,19 @@
|
||||
use std::{env, error::Error, io::{Read, Write}, net::TcpStream, sync::Arc};
|
||||
use std::{
|
||||
env,
|
||||
error::Error,
|
||||
io::{ErrorKind, Read, Write},
|
||||
net::{Shutdown, TcpListener, TcpStream},
|
||||
sync::{Arc, Mutex},
|
||||
thread,
|
||||
};
|
||||
|
||||
use rustls::{client::{danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}}, pki_types::{CertificateDer, ServerName, UnixTime}, ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme};
|
||||
use rustls::{
|
||||
ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme,
|
||||
StreamOwned,
|
||||
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
|
||||
pki_types::{CertificateDer, ServerName, UnixTime},
|
||||
};
|
||||
use threadpool::ThreadPool;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NoCertVerify;
|
||||
@ -54,20 +67,102 @@ impl ServerCertVerifier for NoCertVerify {
|
||||
}
|
||||
}
|
||||
|
||||
fn accept_client(
|
||||
stream: Arc<TcpStream>,
|
||||
config: Arc<ClientConfig>,
|
||||
sni_domain: String,
|
||||
remote_stream: Arc<TcpStream>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
stream.set_nonblocking(true)?;
|
||||
remote_stream.set_nonblocking(true)?;
|
||||
|
||||
let mut tls = StreamOwned::new(
|
||||
ClientConnection::new(config, sni_domain.try_into()?)?,
|
||||
remote_stream.try_clone()?,
|
||||
);
|
||||
|
||||
while tls.conn.is_handshaking() {
|
||||
match tls.conn.complete_io(&mut tls.sock) {
|
||||
Ok(_) => {},
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tls = Arc::new(Mutex::new(tls));
|
||||
|
||||
thread::spawn({
|
||||
let tls = tls.clone();
|
||||
let stream = stream.clone();
|
||||
|
||||
move || {
|
||||
loop {
|
||||
let mut buffer = vec![0; 4096];
|
||||
match (&*stream).read(&mut buffer) {
|
||||
Ok(0) => { break },
|
||||
Ok(n) => {
|
||||
println!("from remote {n}");
|
||||
buffer.truncate(n);
|
||||
match tls.lock().unwrap().write(&buffer) {
|
||||
Ok(_) => {},
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let tls = tls.clone();
|
||||
let stream = stream.clone();
|
||||
|
||||
loop {
|
||||
let mut buffer = vec![0; 4096];
|
||||
match tls.lock().unwrap().read(&mut buffer) {
|
||||
Ok(0) => { break },
|
||||
Ok(n) => {
|
||||
println!("from tls {n}");
|
||||
buffer.truncate(n);
|
||||
match (&*stream).write(&buffer) {
|
||||
Ok(_) => {},
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let mut args = env::args();
|
||||
|
||||
args.next();
|
||||
|
||||
let sni_domain = args
|
||||
.next()
|
||||
.expect("missing sni domain argument");
|
||||
let host = args
|
||||
.next()
|
||||
.expect("missing host argument");
|
||||
let local_host = args
|
||||
.next()
|
||||
.expect("missing local host argument");
|
||||
let sni_domain = args.next().expect("missing sni domain argument");
|
||||
let host = args.next().expect("missing host argument");
|
||||
let local_host = args.next().expect("missing local host argument");
|
||||
|
||||
let mut config = ClientConfig::builder()
|
||||
.with_root_certificates(RootCertStore::empty())
|
||||
@ -76,12 +171,36 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
let verifier = Arc::new(NoCertVerify);
|
||||
config.dangerous().set_certificate_verifier(verifier);
|
||||
|
||||
let mut conn = ClientConnection::new(Arc::new(config), sni_domain.try_into()?)?;
|
||||
let mut sock = TcpStream::connect(host).unwrap();
|
||||
let mut tls = rustls::Stream::new(&mut conn, &mut sock);
|
||||
let config = Arc::new(config);
|
||||
|
||||
tls.write_all(b"hello")?;
|
||||
tls.read_to_end(&mut vec![])?;
|
||||
let threadpool = ThreadPool::new(10);
|
||||
|
||||
let listener = TcpListener::bind(local_host).unwrap();
|
||||
|
||||
for stream in listener.incoming() {
|
||||
let stream = stream.expect("listener got broken");
|
||||
|
||||
threadpool.execute({
|
||||
let config = config.clone();
|
||||
let sni_domain = sni_domain.clone();
|
||||
let host = host.clone();
|
||||
|
||||
move || {
|
||||
let Ok(remote_stream) = TcpStream::connect(host) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let stream = Arc::new(stream);
|
||||
let remote_stream = Arc::new(remote_stream);
|
||||
|
||||
if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) {
|
||||
println!("error connection: {e:?}");
|
||||
}
|
||||
|
||||
let _ = (&*stream).shutdown(Shutdown::Both);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -5,5 +5,5 @@ edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
rcgen = "0.13.2"
|
||||
rustls = { version = "0.23.28", features = ["std"] }
|
||||
rustls = "0.23.28"
|
||||
threadpool = "1.8.1"
|
||||
|
@ -1,71 +1,105 @@
|
||||
use std::io::{Read, Write};
|
||||
use std::{env, thread};
|
||||
use std::error::Error;
|
||||
use std::io::{ErrorKind, Read, Write};
|
||||
use std::net::{Shutdown, TcpListener, TcpStream};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::{env, thread};
|
||||
|
||||
use rcgen::generate_simple_self_signed;
|
||||
use rustls::crypto::aws_lc_rs::sign::any_supported_type;
|
||||
use rustls::pki_types::pem::PemObject;
|
||||
use rustls::pki_types::PrivateKeyDer;
|
||||
use rustls::pki_types::pem::PemObject;
|
||||
use rustls::server::ResolvesServerCertUsingSni;
|
||||
use rustls::sign::CertifiedKey;
|
||||
use rustls::{ServerConfig, ServerConnection, Stream};
|
||||
use rustls::{ServerConfig, ServerConnection, StreamOwned};
|
||||
use threadpool::ThreadPool;
|
||||
|
||||
fn accept_client(stream: &mut TcpStream, config: Arc<ServerConfig>, sni_domain: String, host: String) -> Result<(), Box<dyn Error>> {
|
||||
let remote_stream = Arc::new(TcpStream::connect(host)?);
|
||||
fn accept_client(
|
||||
stream: Arc<TcpStream>,
|
||||
config: Arc<ServerConfig>,
|
||||
sni_domain: String,
|
||||
remote_stream: Arc<TcpStream>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
stream.set_nonblocking(true)?;
|
||||
remote_stream.set_nonblocking(true)?;
|
||||
|
||||
let connection = ServerConnection::new(config)?;
|
||||
let connection = Arc::new(Mutex::new(connection));
|
||||
|
||||
{
|
||||
let mut conn = connection.lock().unwrap();
|
||||
let mut tls_stream = Stream::new(&mut *conn, stream);
|
||||
while tls_stream.conn.is_handshaking() {
|
||||
tls_stream.conn.complete_io(&mut tls_stream.sock)?;
|
||||
}
|
||||
let mut tls = StreamOwned::new(ServerConnection::new(config)?, stream.try_clone()?);
|
||||
|
||||
if tls_stream.conn.server_name() != Some(&sni_domain) {
|
||||
return Err("unexpected sni domain".into());
|
||||
while tls.conn.is_handshaking() {
|
||||
match tls.conn.complete_io(&mut tls.sock) {
|
||||
Ok(_) => {},
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tls.conn.server_name() != Some(&sni_domain) {
|
||||
return Err("unexpected sni domain".into());
|
||||
}
|
||||
|
||||
let tls = Arc::new(Mutex::new(tls));
|
||||
|
||||
thread::spawn({
|
||||
let mut stream = stream.try_clone()?;
|
||||
let connection = connection.clone();
|
||||
let tls = tls.clone();
|
||||
let remote_stream = remote_stream.clone();
|
||||
|
||||
move || {
|
||||
loop {
|
||||
let mut buffer = vec![0; 4096];
|
||||
let n = (&*remote_stream).read(&mut buffer).unwrap();
|
||||
if n != 0 {
|
||||
buffer.truncate(n);
|
||||
|
||||
let mut conn = connection.lock().unwrap();
|
||||
let mut tls_stream = Stream::new(&mut *conn, &mut stream);
|
||||
tls_stream.write_all(&buffer[..n]).unwrap();
|
||||
tls_stream.flush().unwrap();
|
||||
match (&*remote_stream).read(&mut buffer) {
|
||||
Ok(0) => { break },
|
||||
Ok(n) => {
|
||||
println!("from remote {n}");
|
||||
buffer.truncate(n);
|
||||
match tls.lock().unwrap().write(&buffer) {
|
||||
Ok(_) => {},
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut stream = stream.try_clone()?;
|
||||
let connection = connection.clone();
|
||||
|
||||
loop {
|
||||
let mut conn = connection.lock().unwrap();
|
||||
let mut tls_stream = Stream::new(&mut *conn, &mut stream);
|
||||
let tls = tls.clone();
|
||||
let remote_stream = remote_stream.clone();
|
||||
|
||||
loop {
|
||||
let mut buffer = vec![0; 4096];
|
||||
let n = tls_stream.read(&mut buffer)?;
|
||||
if n != 0 {
|
||||
buffer.truncate(n);
|
||||
(&*remote_stream).write(&buffer)?;
|
||||
match tls.lock().unwrap().read(&mut buffer) {
|
||||
Ok(0) => { break },
|
||||
Ok(n) => {
|
||||
println!("from tls {n}");
|
||||
buffer.truncate(n);
|
||||
match (&*remote_stream).write(&buffer) {
|
||||
Ok(_) => {},
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if e.kind() != ErrorKind::WouldBlock {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
@ -73,37 +107,33 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
args.next();
|
||||
|
||||
let sni_domain = args
|
||||
.next()
|
||||
.expect("missing sni domain argument");
|
||||
let host = args
|
||||
.next()
|
||||
.expect("missing host argument");
|
||||
let local_host = args
|
||||
.next()
|
||||
.expect("missing local host argument");
|
||||
let sni_domain = args.next().expect("missing sni domain argument");
|
||||
let host = args.next().expect("missing host argument");
|
||||
let local_host = args.next().expect("missing local host argument");
|
||||
|
||||
let certified_key = generate_simple_self_signed(vec![sni_domain.to_string()]).unwrap();
|
||||
let certified_key = CertifiedKey::new(
|
||||
vec![certified_key.cert.into()],
|
||||
any_supported_type(
|
||||
&PrivateKeyDer::from_pem_slice(certified_key.key_pair.serialize_pem().as_bytes())?
|
||||
)?
|
||||
vec![certified_key.cert.into()],
|
||||
any_supported_type(&PrivateKeyDer::from_pem_slice(
|
||||
certified_key.key_pair.serialize_pem().as_bytes(),
|
||||
)?)?,
|
||||
);
|
||||
|
||||
let mut resolver = ResolvesServerCertUsingSni::new();
|
||||
resolver.add(&sni_domain, certified_key)?;
|
||||
|
||||
let config = Arc::new(ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(Arc::new(resolver)));
|
||||
let config = Arc::new(
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_cert_resolver(Arc::new(resolver)),
|
||||
);
|
||||
|
||||
let threadpool = ThreadPool::new(10);
|
||||
|
||||
let listener = TcpListener::bind(local_host).unwrap();
|
||||
|
||||
for stream in listener.incoming() {
|
||||
let mut stream = stream.expect("listener got broken");
|
||||
let stream = stream.expect("listener got broken");
|
||||
|
||||
threadpool.execute({
|
||||
let config = config.clone();
|
||||
@ -111,14 +141,21 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
let host = host.clone();
|
||||
|
||||
move || {
|
||||
if let Err(e) = accept_client(&mut stream, config, sni_domain, host) {
|
||||
let Ok(remote_stream) = TcpStream::connect(host) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let stream = Arc::new(stream);
|
||||
let remote_stream = Arc::new(remote_stream);
|
||||
|
||||
if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) {
|
||||
println!("error connection: {e:?}");
|
||||
}
|
||||
|
||||
let _ = stream.shutdown(Shutdown::Both);
|
||||
let _ = (&*stream).shutdown(Shutdown::Both);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user