diff --git a/unrknize-client/src/main.rs b/unrknize-client/src/main.rs index 2e5a244..e591e69 100644 --- a/unrknize-client/src/main.rs +++ b/unrknize-client/src/main.rs @@ -3,8 +3,8 @@ use std::{ error::Error, io::{ErrorKind, Read, Write}, net::{Shutdown, TcpListener, TcpStream}, - sync::{Arc, Mutex}, - thread, + sync::{atomic::{AtomicBool, Ordering}, Arc, Mutex}, + thread, time::Duration, }; use rustls::{ @@ -72,6 +72,7 @@ fn accept_client( config: Arc, sni_domain: String, remote_stream: Arc, + threadpool: ThreadPool ) -> Result<(), Box> { stream.set_nonblocking(true)?; remote_stream.set_nonblocking(true)?; @@ -92,21 +93,24 @@ fn accept_client( } } + let break_it = Arc::new(AtomicBool::new(false)); + let tls = Arc::new(Mutex::new(tls)); - thread::spawn({ + threadpool.execute({ let tls = tls.clone(); let stream = stream.clone(); + let break_it = break_it.clone(); move || { - loop { + while !break_it.load(Ordering::SeqCst) { let mut buffer = vec![0; 4096]; - match (&*stream).read(&mut buffer) { + match tls.lock().unwrap().read(&mut buffer) { Ok(0) => { break }, Ok(n) => { // println!("from remote {n}"); buffer.truncate(n); - match tls.lock().unwrap().write(&buffer) { + match (&*stream).write(&buffer) { Ok(_) => {}, Err(e) => { if e.kind() != ErrorKind::WouldBlock { @@ -121,24 +125,31 @@ fn accept_client( } } } + thread::sleep(Duration::from_millis(2)); } + + break_it.store(true, Ordering::SeqCst); } }); let tls = tls.clone(); let stream = stream.clone(); - loop { + while !break_it.load(Ordering::SeqCst) { let mut buffer = vec![0; 4096]; - match tls.lock().unwrap().read(&mut buffer) { - Ok(0) => { break }, + match (&*stream).read(&mut buffer) { + Ok(0) => { + break_it.store(true, Ordering::SeqCst); + break + }, Ok(n) => { // println!("from tls {n}"); buffer.truncate(n); - match (&*stream).write(&buffer) { + match tls.lock().unwrap().write(&buffer) { Ok(_) => {}, Err(e) => { if e.kind() != ErrorKind::WouldBlock { + break_it.store(true, Ordering::SeqCst); return Err(e.into()); } } @@ -146,10 +157,12 @@ fn accept_client( } Err(e) => { if e.kind() != ErrorKind::WouldBlock { + break_it.store(true, Ordering::SeqCst); return Err(e.into()); } } } + thread::sleep(Duration::from_millis(2)); } Ok(()) @@ -173,7 +186,7 @@ fn main() -> Result<(), Box> { let config = Arc::new(config); - let threadpool = ThreadPool::new(10); + let threadpool = ThreadPool::new(100); let listener = TcpListener::bind(local_host).unwrap(); @@ -184,20 +197,27 @@ fn main() -> Result<(), Box> { let config = config.clone(); let sni_domain = sni_domain.clone(); let host = host.clone(); + let threadpool = threadpool.clone(); move || { let Ok(remote_stream) = TcpStream::connect(host) else { return; }; + let addr = stream.peer_addr().unwrap(); + + println!("connected {}", addr); + let stream = Arc::new(stream); let remote_stream = Arc::new(remote_stream); - if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) { + if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream, threadpool) { println!("error connection: {e:?}"); } let _ = (&*stream).shutdown(Shutdown::Both); + + println!("disconnected {}", addr); } }); } diff --git a/unrknize-server/src/main.rs b/unrknize-server/src/main.rs index 7a9d63e..84ec2c5 100644 --- a/unrknize-server/src/main.rs +++ b/unrknize-server/src/main.rs @@ -1,7 +1,9 @@ use std::error::Error; use std::io::{ErrorKind, Read, Write}; use std::net::{Shutdown, TcpListener, TcpStream}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; +use std::time::Duration; use std::{env, thread}; use rcgen::generate_simple_self_signed; @@ -18,6 +20,7 @@ fn accept_client( config: Arc, sni_domain: String, remote_stream: Arc, + threadpool: ThreadPool ) -> Result<(), Box> { stream.set_nonblocking(true)?; remote_stream.set_nonblocking(true)?; @@ -39,14 +42,17 @@ fn accept_client( return Err("unexpected sni domain".into()); } + let break_it = Arc::new(AtomicBool::new(false)); + let tls = Arc::new(Mutex::new(tls)); - thread::spawn({ + threadpool.execute({ let tls = tls.clone(); let remote_stream = remote_stream.clone(); + let break_it = break_it.clone(); move || { - loop { + while !break_it.load(Ordering::SeqCst) { let mut buffer = vec![0; 4096]; match (&*remote_stream).read(&mut buffer) { Ok(0) => { break }, @@ -68,17 +74,23 @@ fn accept_client( } } } + thread::sleep(Duration::from_millis(2)); } + + break_it.store(true, Ordering::SeqCst); } }); let tls = tls.clone(); let remote_stream = remote_stream.clone(); - loop { + while !break_it.load(Ordering::SeqCst) { let mut buffer = vec![0; 4096]; match tls.lock().unwrap().read(&mut buffer) { - Ok(0) => { break }, + Ok(0) => { + break_it.store(true, Ordering::SeqCst); + break; + }, Ok(n) => { // println!("from tls {n}"); buffer.truncate(n); @@ -86,6 +98,7 @@ fn accept_client( Ok(_) => {}, Err(e) => { if e.kind() != ErrorKind::WouldBlock { + break_it.store(true, Ordering::SeqCst); return Err(e.into()); } } @@ -93,10 +106,12 @@ fn accept_client( } Err(e) => { if e.kind() != ErrorKind::WouldBlock { + break_it.store(true, Ordering::SeqCst); return Err(e.into()); } } } + thread::sleep(Duration::from_millis(2)); } Ok(()) @@ -128,7 +143,7 @@ fn main() -> Result<(), Box> { .with_cert_resolver(Arc::new(resolver)), ); - let threadpool = ThreadPool::new(10); + let threadpool = ThreadPool::new(100); let listener = TcpListener::bind(local_host).unwrap(); @@ -139,20 +154,27 @@ fn main() -> Result<(), Box> { let config = config.clone(); let sni_domain = sni_domain.clone(); let host = host.clone(); + let threadpool = threadpool.clone(); move || { let Ok(remote_stream) = TcpStream::connect(host) else { return; }; + let addr = stream.peer_addr().unwrap(); + + println!("connected {}", addr); + let stream = Arc::new(stream); let remote_stream = Arc::new(remote_stream); - if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) { + if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream, threadpool) { println!("error connection: {e:?}"); } let _ = (&*stream).shutdown(Shutdown::Both); + + println!("disconnected {}", addr); } }); }