diff --git a/unrknize-client/Cargo.lock b/unrknize-client/Cargo.lock index e42501a..4e78af2 100644 --- a/unrknize-client/Cargo.lock +++ b/unrknize-client/Cargo.lock @@ -166,6 +166,12 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "home" version = "0.5.11" @@ -256,6 +262,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -417,6 +433,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "threadpool" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" +dependencies = [ + "num_cpus", +] + [[package]] name = "unicode-ident" version = "1.0.18" @@ -428,6 +453,7 @@ name = "unrknize-client" version = "0.1.0" dependencies = [ "rustls", + "threadpool", ] [[package]] diff --git a/unrknize-client/Cargo.toml b/unrknize-client/Cargo.toml index 9649659..3f32607 100644 --- a/unrknize-client/Cargo.toml +++ b/unrknize-client/Cargo.toml @@ -5,3 +5,4 @@ edition = "2024" [dependencies] rustls = "0.23.28" +threadpool = "1.8.1" diff --git a/unrknize-client/src/main.rs b/unrknize-client/src/main.rs index fd9bae6..5c57b49 100644 --- a/unrknize-client/src/main.rs +++ b/unrknize-client/src/main.rs @@ -1,6 +1,19 @@ -use std::{env, error::Error, io::{Read, Write}, net::TcpStream, sync::Arc}; +use std::{ + env, + error::Error, + io::{ErrorKind, Read, Write}, + net::{Shutdown, TcpListener, TcpStream}, + sync::{Arc, Mutex}, + thread, +}; -use rustls::{client::{danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}}, pki_types::{CertificateDer, ServerName, UnixTime}, ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme}; +use rustls::{ + ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme, + StreamOwned, + client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + pki_types::{CertificateDer, ServerName, UnixTime}, +}; +use threadpool::ThreadPool; #[derive(Debug)] pub struct NoCertVerify; @@ -54,20 +67,102 @@ impl ServerCertVerifier for NoCertVerify { } } +fn accept_client( + stream: Arc, + config: Arc, + sni_domain: String, + remote_stream: Arc, +) -> Result<(), Box> { + stream.set_nonblocking(true)?; + remote_stream.set_nonblocking(true)?; + + let mut tls = StreamOwned::new( + ClientConnection::new(config, sni_domain.try_into()?)?, + remote_stream.try_clone()?, + ); + + while tls.conn.is_handshaking() { + match tls.conn.complete_io(&mut tls.sock) { + Ok(_) => {}, + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + return Ok(()); + } + } + } + } + + let tls = Arc::new(Mutex::new(tls)); + + thread::spawn({ + let tls = tls.clone(); + let stream = stream.clone(); + + move || { + loop { + let mut buffer = vec![0; 4096]; + match (&*stream).read(&mut buffer) { + Ok(0) => { break }, + Ok(n) => { + println!("from remote {n}"); + buffer.truncate(n); + match tls.lock().unwrap().write(&buffer) { + Ok(_) => {}, + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + break; + } + } + } + } + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + break; + } + } + } + } + } + }); + + let tls = tls.clone(); + let stream = stream.clone(); + + loop { + let mut buffer = vec![0; 4096]; + match tls.lock().unwrap().read(&mut buffer) { + Ok(0) => { break }, + Ok(n) => { + println!("from tls {n}"); + buffer.truncate(n); + match (&*stream).write(&buffer) { + Ok(_) => {}, + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + return Err(e.into()); + } + } + } + } + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + return Err(e.into()); + } + } + } + } + + Ok(()) +} + fn main() -> Result<(), Box> { let mut args = env::args(); args.next(); - let sni_domain = args - .next() - .expect("missing sni domain argument"); - let host = args - .next() - .expect("missing host argument"); - let local_host = args - .next() - .expect("missing local host argument"); + let sni_domain = args.next().expect("missing sni domain argument"); + let host = args.next().expect("missing host argument"); + let local_host = args.next().expect("missing local host argument"); let mut config = ClientConfig::builder() .with_root_certificates(RootCertStore::empty()) @@ -76,12 +171,36 @@ fn main() -> Result<(), Box> { let verifier = Arc::new(NoCertVerify); config.dangerous().set_certificate_verifier(verifier); - let mut conn = ClientConnection::new(Arc::new(config), sni_domain.try_into()?)?; - let mut sock = TcpStream::connect(host).unwrap(); - let mut tls = rustls::Stream::new(&mut conn, &mut sock); + let config = Arc::new(config); - tls.write_all(b"hello")?; - tls.read_to_end(&mut vec![])?; + let threadpool = ThreadPool::new(10); + + let listener = TcpListener::bind(local_host).unwrap(); + + for stream in listener.incoming() { + let stream = stream.expect("listener got broken"); + + threadpool.execute({ + let config = config.clone(); + let sni_domain = sni_domain.clone(); + let host = host.clone(); + + move || { + let Ok(remote_stream) = TcpStream::connect(host) else { + return; + }; + + let stream = Arc::new(stream); + let remote_stream = Arc::new(remote_stream); + + if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) { + println!("error connection: {e:?}"); + } + + let _ = (&*stream).shutdown(Shutdown::Both); + } + }); + } Ok(()) } diff --git a/unrknize-server/Cargo.toml b/unrknize-server/Cargo.toml index 7b1bd18..5a02f8f 100644 --- a/unrknize-server/Cargo.toml +++ b/unrknize-server/Cargo.toml @@ -5,5 +5,5 @@ edition = "2024" [dependencies] rcgen = "0.13.2" -rustls = { version = "0.23.28", features = ["std"] } +rustls = "0.23.28" threadpool = "1.8.1" diff --git a/unrknize-server/src/main.rs b/unrknize-server/src/main.rs index 855491a..2238bc7 100644 --- a/unrknize-server/src/main.rs +++ b/unrknize-server/src/main.rs @@ -1,71 +1,105 @@ -use std::io::{Read, Write}; -use std::{env, thread}; use std::error::Error; +use std::io::{ErrorKind, Read, Write}; use std::net::{Shutdown, TcpListener, TcpStream}; use std::sync::{Arc, Mutex}; +use std::{env, thread}; use rcgen::generate_simple_self_signed; use rustls::crypto::aws_lc_rs::sign::any_supported_type; -use rustls::pki_types::pem::PemObject; use rustls::pki_types::PrivateKeyDer; +use rustls::pki_types::pem::PemObject; use rustls::server::ResolvesServerCertUsingSni; use rustls::sign::CertifiedKey; -use rustls::{ServerConfig, ServerConnection, Stream}; +use rustls::{ServerConfig, ServerConnection, StreamOwned}; use threadpool::ThreadPool; -fn accept_client(stream: &mut TcpStream, config: Arc, sni_domain: String, host: String) -> Result<(), Box> { - let remote_stream = Arc::new(TcpStream::connect(host)?); +fn accept_client( + stream: Arc, + config: Arc, + sni_domain: String, + remote_stream: Arc, +) -> Result<(), Box> { + stream.set_nonblocking(true)?; + remote_stream.set_nonblocking(true)?; - let connection = ServerConnection::new(config)?; - let connection = Arc::new(Mutex::new(connection)); - - { - let mut conn = connection.lock().unwrap(); - let mut tls_stream = Stream::new(&mut *conn, stream); - while tls_stream.conn.is_handshaking() { - tls_stream.conn.complete_io(&mut tls_stream.sock)?; - } + let mut tls = StreamOwned::new(ServerConnection::new(config)?, stream.try_clone()?); - if tls_stream.conn.server_name() != Some(&sni_domain) { - return Err("unexpected sni domain".into()); + while tls.conn.is_handshaking() { + match tls.conn.complete_io(&mut tls.sock) { + Ok(_) => {}, + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + return Err(e.into()); + } + } } } + if tls.conn.server_name() != Some(&sni_domain) { + return Err("unexpected sni domain".into()); + } + + let tls = Arc::new(Mutex::new(tls)); + thread::spawn({ - let mut stream = stream.try_clone()?; - let connection = connection.clone(); + let tls = tls.clone(); let remote_stream = remote_stream.clone(); move || { loop { let mut buffer = vec![0; 4096]; - let n = (&*remote_stream).read(&mut buffer).unwrap(); - if n != 0 { - buffer.truncate(n); - - let mut conn = connection.lock().unwrap(); - let mut tls_stream = Stream::new(&mut *conn, &mut stream); - tls_stream.write_all(&buffer[..n]).unwrap(); - tls_stream.flush().unwrap(); + match (&*remote_stream).read(&mut buffer) { + Ok(0) => { break }, + Ok(n) => { + println!("from remote {n}"); + buffer.truncate(n); + match tls.lock().unwrap().write(&buffer) { + Ok(_) => {}, + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + break; + } + } + } + } + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + break; + } + } } } } }); - let mut stream = stream.try_clone()?; - let connection = connection.clone(); - - loop { - let mut conn = connection.lock().unwrap(); - let mut tls_stream = Stream::new(&mut *conn, &mut stream); + let tls = tls.clone(); + let remote_stream = remote_stream.clone(); + loop { let mut buffer = vec![0; 4096]; - let n = tls_stream.read(&mut buffer)?; - if n != 0 { - buffer.truncate(n); - (&*remote_stream).write(&buffer)?; + match tls.lock().unwrap().read(&mut buffer) { + Ok(0) => { break }, + Ok(n) => { + println!("from tls {n}"); + buffer.truncate(n); + match (&*remote_stream).write(&buffer) { + Ok(_) => {}, + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + return Err(e.into()); + } + } + } + } + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + return Err(e.into()); + } + } } } + + Ok(()) } fn main() -> Result<(), Box> { @@ -73,37 +107,33 @@ fn main() -> Result<(), Box> { args.next(); - let sni_domain = args - .next() - .expect("missing sni domain argument"); - let host = args - .next() - .expect("missing host argument"); - let local_host = args - .next() - .expect("missing local host argument"); + let sni_domain = args.next().expect("missing sni domain argument"); + let host = args.next().expect("missing host argument"); + let local_host = args.next().expect("missing local host argument"); let certified_key = generate_simple_self_signed(vec![sni_domain.to_string()]).unwrap(); let certified_key = CertifiedKey::new( - vec![certified_key.cert.into()], - any_supported_type( - &PrivateKeyDer::from_pem_slice(certified_key.key_pair.serialize_pem().as_bytes())? - )? + vec![certified_key.cert.into()], + any_supported_type(&PrivateKeyDer::from_pem_slice( + certified_key.key_pair.serialize_pem().as_bytes(), + )?)?, ); let mut resolver = ResolvesServerCertUsingSni::new(); resolver.add(&sni_domain, certified_key)?; - let config = Arc::new(ServerConfig::builder() - .with_no_client_auth() - .with_cert_resolver(Arc::new(resolver))); + let config = Arc::new( + ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)), + ); let threadpool = ThreadPool::new(10); let listener = TcpListener::bind(local_host).unwrap(); for stream in listener.incoming() { - let mut stream = stream.expect("listener got broken"); + let stream = stream.expect("listener got broken"); threadpool.execute({ let config = config.clone(); @@ -111,14 +141,21 @@ fn main() -> Result<(), Box> { let host = host.clone(); move || { - if let Err(e) = accept_client(&mut stream, config, sni_domain, host) { + let Ok(remote_stream) = TcpStream::connect(host) else { + return; + }; + + let stream = Arc::new(stream); + let remote_stream = Arc::new(remote_stream); + + if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) { println!("error connection: {e:?}"); } - let _ = stream.shutdown(Shutdown::Both); + let _ = (&*stream).shutdown(Shutdown::Both); } }); } Ok(()) -} \ No newline at end of file +}