fix traffic issues

This commit is contained in:
MeexReay 2025-06-21 01:44:06 +03:00
parent 7d6f3bb72d
commit 48e35c00e6
2 changed files with 60 additions and 18 deletions

View File

@ -3,8 +3,8 @@ use std::{
error::Error, error::Error,
io::{ErrorKind, Read, Write}, io::{ErrorKind, Read, Write},
net::{Shutdown, TcpListener, TcpStream}, net::{Shutdown, TcpListener, TcpStream},
sync::{Arc, Mutex}, sync::{atomic::{AtomicBool, Ordering}, Arc, Mutex},
thread, thread, time::Duration,
}; };
use rustls::{ use rustls::{
@ -72,6 +72,7 @@ fn accept_client(
config: Arc<ClientConfig>, config: Arc<ClientConfig>,
sni_domain: String, sni_domain: String,
remote_stream: Arc<TcpStream>, remote_stream: Arc<TcpStream>,
threadpool: ThreadPool
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
stream.set_nonblocking(true)?; stream.set_nonblocking(true)?;
remote_stream.set_nonblocking(true)?; remote_stream.set_nonblocking(true)?;
@ -92,21 +93,24 @@ fn accept_client(
} }
} }
let break_it = Arc::new(AtomicBool::new(false));
let tls = Arc::new(Mutex::new(tls)); let tls = Arc::new(Mutex::new(tls));
thread::spawn({ threadpool.execute({
let tls = tls.clone(); let tls = tls.clone();
let stream = stream.clone(); let stream = stream.clone();
let break_it = break_it.clone();
move || { move || {
loop { while !break_it.load(Ordering::SeqCst) {
let mut buffer = vec![0; 4096]; let mut buffer = vec![0; 4096];
match (&*stream).read(&mut buffer) { match tls.lock().unwrap().read(&mut buffer) {
Ok(0) => { break }, Ok(0) => { break },
Ok(n) => { Ok(n) => {
// println!("from remote {n}"); // println!("from remote {n}");
buffer.truncate(n); buffer.truncate(n);
match tls.lock().unwrap().write(&buffer) { match (&*stream).write(&buffer) {
Ok(_) => {}, Ok(_) => {},
Err(e) => { Err(e) => {
if e.kind() != ErrorKind::WouldBlock { if e.kind() != ErrorKind::WouldBlock {
@ -121,24 +125,31 @@ fn accept_client(
} }
} }
} }
thread::sleep(Duration::from_millis(2));
} }
break_it.store(true, Ordering::SeqCst);
} }
}); });
let tls = tls.clone(); let tls = tls.clone();
let stream = stream.clone(); let stream = stream.clone();
loop { while !break_it.load(Ordering::SeqCst) {
let mut buffer = vec![0; 4096]; let mut buffer = vec![0; 4096];
match tls.lock().unwrap().read(&mut buffer) { match (&*stream).read(&mut buffer) {
Ok(0) => { break }, Ok(0) => {
break_it.store(true, Ordering::SeqCst);
break
},
Ok(n) => { Ok(n) => {
// println!("from tls {n}"); // println!("from tls {n}");
buffer.truncate(n); buffer.truncate(n);
match (&*stream).write(&buffer) { match tls.lock().unwrap().write(&buffer) {
Ok(_) => {}, Ok(_) => {},
Err(e) => { Err(e) => {
if e.kind() != ErrorKind::WouldBlock { if e.kind() != ErrorKind::WouldBlock {
break_it.store(true, Ordering::SeqCst);
return Err(e.into()); return Err(e.into());
} }
} }
@ -146,10 +157,12 @@ fn accept_client(
} }
Err(e) => { Err(e) => {
if e.kind() != ErrorKind::WouldBlock { if e.kind() != ErrorKind::WouldBlock {
break_it.store(true, Ordering::SeqCst);
return Err(e.into()); return Err(e.into());
} }
} }
} }
thread::sleep(Duration::from_millis(2));
} }
Ok(()) Ok(())
@ -173,7 +186,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let config = Arc::new(config); let config = Arc::new(config);
let threadpool = ThreadPool::new(10); let threadpool = ThreadPool::new(100);
let listener = TcpListener::bind(local_host).unwrap(); let listener = TcpListener::bind(local_host).unwrap();
@ -184,20 +197,27 @@ fn main() -> Result<(), Box<dyn Error>> {
let config = config.clone(); let config = config.clone();
let sni_domain = sni_domain.clone(); let sni_domain = sni_domain.clone();
let host = host.clone(); let host = host.clone();
let threadpool = threadpool.clone();
move || { move || {
let Ok(remote_stream) = TcpStream::connect(host) else { let Ok(remote_stream) = TcpStream::connect(host) else {
return; return;
}; };
let addr = stream.peer_addr().unwrap();
println!("connected {}", addr);
let stream = Arc::new(stream); let stream = Arc::new(stream);
let remote_stream = Arc::new(remote_stream); let remote_stream = Arc::new(remote_stream);
if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) { if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream, threadpool) {
println!("error connection: {e:?}"); println!("error connection: {e:?}");
} }
let _ = (&*stream).shutdown(Shutdown::Both); let _ = (&*stream).shutdown(Shutdown::Both);
println!("disconnected {}", addr);
} }
}); });
} }

View File

@ -1,7 +1,9 @@
use std::error::Error; use std::error::Error;
use std::io::{ErrorKind, Read, Write}; use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, TcpListener, TcpStream}; use std::net::{Shutdown, TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{env, thread}; use std::{env, thread};
use rcgen::generate_simple_self_signed; use rcgen::generate_simple_self_signed;
@ -18,6 +20,7 @@ fn accept_client(
config: Arc<ServerConfig>, config: Arc<ServerConfig>,
sni_domain: String, sni_domain: String,
remote_stream: Arc<TcpStream>, remote_stream: Arc<TcpStream>,
threadpool: ThreadPool
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
stream.set_nonblocking(true)?; stream.set_nonblocking(true)?;
remote_stream.set_nonblocking(true)?; remote_stream.set_nonblocking(true)?;
@ -39,14 +42,17 @@ fn accept_client(
return Err("unexpected sni domain".into()); return Err("unexpected sni domain".into());
} }
let break_it = Arc::new(AtomicBool::new(false));
let tls = Arc::new(Mutex::new(tls)); let tls = Arc::new(Mutex::new(tls));
thread::spawn({ threadpool.execute({
let tls = tls.clone(); let tls = tls.clone();
let remote_stream = remote_stream.clone(); let remote_stream = remote_stream.clone();
let break_it = break_it.clone();
move || { move || {
loop { while !break_it.load(Ordering::SeqCst) {
let mut buffer = vec![0; 4096]; let mut buffer = vec![0; 4096];
match (&*remote_stream).read(&mut buffer) { match (&*remote_stream).read(&mut buffer) {
Ok(0) => { break }, Ok(0) => { break },
@ -68,17 +74,23 @@ fn accept_client(
} }
} }
} }
thread::sleep(Duration::from_millis(2));
} }
break_it.store(true, Ordering::SeqCst);
} }
}); });
let tls = tls.clone(); let tls = tls.clone();
let remote_stream = remote_stream.clone(); let remote_stream = remote_stream.clone();
loop { while !break_it.load(Ordering::SeqCst) {
let mut buffer = vec![0; 4096]; let mut buffer = vec![0; 4096];
match tls.lock().unwrap().read(&mut buffer) { match tls.lock().unwrap().read(&mut buffer) {
Ok(0) => { break }, Ok(0) => {
break_it.store(true, Ordering::SeqCst);
break;
},
Ok(n) => { Ok(n) => {
// println!("from tls {n}"); // println!("from tls {n}");
buffer.truncate(n); buffer.truncate(n);
@ -86,6 +98,7 @@ fn accept_client(
Ok(_) => {}, Ok(_) => {},
Err(e) => { Err(e) => {
if e.kind() != ErrorKind::WouldBlock { if e.kind() != ErrorKind::WouldBlock {
break_it.store(true, Ordering::SeqCst);
return Err(e.into()); return Err(e.into());
} }
} }
@ -93,10 +106,12 @@ fn accept_client(
} }
Err(e) => { Err(e) => {
if e.kind() != ErrorKind::WouldBlock { if e.kind() != ErrorKind::WouldBlock {
break_it.store(true, Ordering::SeqCst);
return Err(e.into()); return Err(e.into());
} }
} }
} }
thread::sleep(Duration::from_millis(2));
} }
Ok(()) Ok(())
@ -128,7 +143,7 @@ fn main() -> Result<(), Box<dyn Error>> {
.with_cert_resolver(Arc::new(resolver)), .with_cert_resolver(Arc::new(resolver)),
); );
let threadpool = ThreadPool::new(10); let threadpool = ThreadPool::new(100);
let listener = TcpListener::bind(local_host).unwrap(); let listener = TcpListener::bind(local_host).unwrap();
@ -139,20 +154,27 @@ fn main() -> Result<(), Box<dyn Error>> {
let config = config.clone(); let config = config.clone();
let sni_domain = sni_domain.clone(); let sni_domain = sni_domain.clone();
let host = host.clone(); let host = host.clone();
let threadpool = threadpool.clone();
move || { move || {
let Ok(remote_stream) = TcpStream::connect(host) else { let Ok(remote_stream) = TcpStream::connect(host) else {
return; return;
}; };
let addr = stream.peer_addr().unwrap();
println!("connected {}", addr);
let stream = Arc::new(stream); let stream = Arc::new(stream);
let remote_stream = Arc::new(remote_stream); let remote_stream = Arc::new(remote_stream);
if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream) { if let Err(e) = accept_client(stream.clone(), config, sni_domain, remote_stream, threadpool) {
println!("error connection: {e:?}"); println!("error connection: {e:?}");
} }
let _ = (&*stream).shutdown(Shutdown::Both); let _ = (&*stream).shutdown(Shutdown::Both);
println!("disconnected {}", addr);
} }
}); });
} }