use std::{ error::Error, io::{BufRead, BufReader, Read, Write}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream}, str::FromStr, sync::Arc }; use ignore_result::Ignore; use log::{debug, info}; use rustls::{ServerConnection, StreamOwned}; use threadpool::ThreadPool; use super::{ tls::create_server_config, config::{ Config, SiteConfig, IpForwarding } }; pub struct FlowgateServer { config: Arc, } struct Connection { stream: BufReader, config: SiteConfig, keep_alive: bool, host: String, } impl FlowgateServer { pub fn new(config: Arc) -> Self { FlowgateServer { config } } pub fn run(self) { self.start().join(); } pub fn start(self) -> ThreadPool { let local_self = Arc::new(self); let threadpool = ThreadPool::new(local_self.config.threadpool_size); let mut handles = Vec::new(); handles.push(local_self.clone().start_http(&threadpool)); handles.push(local_self.clone().start_https(&threadpool)); threadpool } pub fn start_http(self: Arc, threadpool: &ThreadPool) { threadpool.execute({ let local_self = self.clone(); let threadpool = threadpool.clone(); move || { local_self.run_http(&threadpool).ignore(); } }) } pub fn start_https(self: Arc, threadpool: &ThreadPool) { threadpool.execute({ let local_self = self.clone(); let threadpool = threadpool.clone(); move || { local_self.run_https(&threadpool).ignore(); } }) } pub fn run_http(self: Arc, threadpool: &ThreadPool) -> Result<(), Box> { if let Some(host) = self.config.http_host.clone() { let listener = TcpListener::bind(&host)?; info!("HTTP server runned on {}", &host); for stream in listener.incoming() { if let Ok(mut stream) = stream { let local_self = self.clone(); let Ok(addr) = stream.peer_addr() else { continue }; let Ok(_) = stream.set_write_timeout(Some(local_self.config.connection_timeout)) else { continue }; let Ok(_) = stream.set_read_timeout(Some(local_self.config.connection_timeout)) else { continue }; threadpool.execute(move || { debug!("{} open connection", addr); local_self.accept_stream( &mut stream, addr, false ); debug!("{} close connection", addr); }); } } } Ok(()) } pub fn run_https(self: Arc, threadpool: &ThreadPool) -> Result<(), Box> { if let Some(host) = self.config.https_host.clone() { let listener = TcpListener::bind(&host)?; info!("HTTPS server runned on {}", &host); let config = Arc::new(create_server_config(self.config.clone())); for stream in listener.incoming() { if let Ok(stream) = stream { let local_self = self.clone(); let config = config.clone(); let Ok(addr) = stream.peer_addr() else { continue }; let Ok(_) = stream.set_write_timeout(Some(local_self.config.connection_timeout)) else { continue }; let Ok(_) = stream.set_read_timeout(Some(local_self.config.connection_timeout)) else { continue }; threadpool.execute(move || { let Ok(connection) = ServerConnection::new(config) else { return }; debug!("{} open connection", addr); let mut stream = StreamOwned::new(connection, stream); while stream.conn.is_handshaking() { let Ok(_) = stream.conn.complete_io(&mut stream.sock) else { debug!("{} close connection", addr);return }; } local_self.accept_stream( &mut stream, addr, true ); debug!("{} close connection", addr); }); } } } Ok(()) } fn accept_stream( self: Arc, stream: &mut (impl Read + Write + Shutdown), addr: SocketAddr, https: bool ) -> Option<()> { let mut conn = self.clone().read_request(stream, addr, https, None)?; if conn.keep_alive && conn.config.enable_keep_alive { loop { if !conn.config.support_keep_alive { conn.stream.shutdown(); conn.stream = BufReader::new(conn.config.connect()?); } conn = self.clone().read_request(stream, addr, https, Some(conn))?; } } conn.stream.shutdown(); stream.shutdown(); Some(()) } fn read_request( self: Arc, stream: &mut (impl Read + Write + Shutdown), addr: SocketAddr, https: bool, conn: Option ) -> Option { let mut addr = addr; let mut stream = BufReader::new(stream); match &self.config.incoming_ip_forwarding { IpForwarding::Simple => { let mut header = Vec::new(); stream.read_until(b'\n', &mut header).ok()?; header.truncate(header.len()-1); addr = SocketAddr::from_str(&String::from_utf8(header).ok()?).ok()?; }, IpForwarding::Modern => { let mut header = [0]; stream.read(&mut header).ok()?; addr = match header[0] { 0x01 => { let mut octets = [0; 4]; stream.read(&mut octets).ok()?; let mut port = [0; 2]; stream.read(&mut port).ok()?; let port = u16::from_be_bytes(port); SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port)) }, 0x02 => { let mut octets = [0; 16]; stream.read(&mut octets).ok()?; let mut port = [0; 2]; stream.read(&mut port).ok()?; let port = u16::from_be_bytes(port); SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0)) }, _ => { return None }, }; }, _ => {} } let mut raw_status = read_until(&mut stream, b"\r\n")?; let mut request = Vec::new(); request.append(&mut raw_status.clone()); raw_status.truncate(raw_status.len() - 2); let status = String::from_utf8(raw_status.clone()).ok()?; let status = status.split(" ").collect::>(); debug!("{} {} read status", addr, status[1]); let mut content_length = 0; let mut host = None; let mut is_chunked = false; let mut keep_alive = false; let mut headers = Vec::new(); loop { let mut header = read_until(&mut stream, b"\r\n")?; header.truncate(header.len() - 2); if header.is_empty() { break; } let header = String::from_utf8(header).ok()?; let (key, value) = header.split_once(": ")?; headers.push((key.to_string(), value.to_string())); match key.to_lowercase().as_str() { "transfer-encoding" => { if value.contains("chunked") { is_chunked = true; } }, "host" => { host = Some(value.to_string()); }, "connection" => { keep_alive = value.to_lowercase().contains("keep-alive"); }, "content-length" => { content_length = value.parse::().ok()?; }, _ => { if let IpForwarding::Header(header) = &self.config.incoming_ip_forwarding { if key.to_lowercase() == header.to_lowercase() { addr = SocketAddr::from_str(value).ok()?; } } }, } } debug!("{} {} read headers", addr, status[1]); let mut conn: Connection = if conn.is_none() { let host = host?; let site = self.config.get_site(&host)?.clone(); Connection { stream: BufReader::new(site.connect()?), config: site, keep_alive, host } } else { conn? }; debug!("{} {} got connection", addr, status[1]); match &conn.config.ip_forwarding { IpForwarding::Simple => { request.append(&mut addr.to_string().as_bytes().to_vec()); request.push(b'\n'); }, IpForwarding::Modern => { match addr.ip() { IpAddr::V4(ip) => { request.push(0x01); request.append(&mut ip.octets().to_vec()); }, IpAddr::V6(ip) => { request.push(0x02); request.append(&mut ip.octets().to_vec()); } } request.append(&mut addr.port().to_be_bytes().to_vec()); }, _ => {} } for (key, value) in headers { let mut value = value.to_string(); match key.to_lowercase().as_str() { "host" => { if let Some(replace_host) = conn.config.replace_host.clone() { value = replace_host; } }, _ => {} } if let IpForwarding::Header(header) = &conn.config.ip_forwarding { if key.to_lowercase() == header.to_lowercase() { continue; } } request.append(&mut key.as_bytes().to_vec()); request.append(&mut b": ".to_vec()); request.append(&mut value.as_bytes().to_vec()); request.append(&mut b"\r\n".to_vec()); } if let IpForwarding::Header(header) = &conn.config.ip_forwarding { request.append(&mut header.as_bytes().to_vec()); request.append(&mut b": ".to_vec()); request.append(&mut addr.to_string().as_bytes().to_vec()); request.append(&mut b"\r\n".to_vec()); } request.append(&mut b"\r\n".to_vec()); debug!("{:?}", String::from_utf8_lossy(&request)); conn.stream.get_mut().write_all(&request).ok()?; debug!("{} {} sent request to server", addr, status[1]); if content_length > 0 { let buffer = stream.buffer().to_vec(); conn.stream.get_mut().write_all(&buffer).ok()?; stream.consume(buffer.len()); let mut read = buffer.len(); debug!("{} {} send part of body to server", addr, status[1]); while read < content_length { let mut buf = vec![0; 4096]; let size = conn.stream.get_mut().read(&mut buf).ok()?; if size == 0 { break } buf.truncate(size); read += size; debug!("{} {} send response body part {} to clientr", addr, status[1], size); stream.get_mut().write_all(&buf).ok()?; } } else if is_chunked { transfer_chunked(&mut stream, conn.stream.get_mut())?; } else { let buffer = stream.buffer().to_vec(); conn.stream.get_mut().write_all(&buffer).ok()?; stream.consume(buffer.len()); } debug!("{} {} send body to server", addr, status[1]); if conn.config.support_keep_alive { let mut response = Vec::new(); let raw_status = read_until(&mut conn.stream, b"\r\n")?; response.append(&mut raw_status.clone()); let mut content_length = 0; let mut is_chunked = false; loop { let mut header = read_until(&mut conn.stream, b"\r\n")?; response.append(&mut header.clone()); if header.len() == 2 { break; } header.truncate(header.len() - 2); let header = String::from_utf8(header).ok()?; let (key, value) = header.split_once(": ")?; match key.to_lowercase().as_str() { "transfer-encoding" => { if value.contains("chunked") { is_chunked = true; } }, "content-length" => { content_length = value.parse::().ok()?; }, _ => {} } } stream.get_mut().write_all(&response).ok()?; debug!("{} {} send response header to clientr", addr, status[1]); if content_length > 0 { let buffer = conn.stream.buffer().to_vec(); stream.get_mut().write_all(&buffer).ok()?; conn.stream.consume(buffer.len()); debug!("{} {} send response body part {} to clientr", addr, status[1], buffer.len()); let mut read = buffer.len(); while read < content_length { let mut buf = vec![0; 4096]; let size = conn.stream.get_mut().read(&mut buf).ok()?; if size == 0 { break } buf.truncate(size); read += size; debug!("{} {} send response body part {} to clientr", addr, status[1], size); stream.get_mut().write_all(&buf).ok()?; } } else if is_chunked { transfer_chunked(&mut conn.stream, stream.get_mut())?; } else { let buffer = conn.stream.buffer().to_vec(); stream.get_mut().write_all(&buffer).ok()?; conn.stream.consume(buffer.len()); } debug!("{} {} send response body to clientr", addr, status[1]); } else { let buffer = conn.stream.buffer(); stream.get_mut().write_all(buffer).ok()?; conn.stream.consume(buffer.len()); let mut buf = vec![0;4096]; while let Ok(n) = conn.stream.get_mut().read(&mut buf) { if n == 0 { break } buf.truncate(n); stream.get_mut().write_all(&buf).ok()?; buf = vec![0;4096]; } } info!("{addr} > {} {}://{}{}", status[0], if https { "https" } else { "http" }, conn.host, status[1]); Some(conn) } } fn read_until(stream: &mut impl BufRead, delimiter: &[u8]) -> Option> { let mut data = Vec::new(); let last_byte = *delimiter.last()?; loop { let mut buf = Vec::new(); let buf_len = stream.read_until(last_byte, &mut buf).ok()?; debug!("read buf len until {} {:?}", buf_len, String::from_utf8_lossy(&buf)); if buf_len == 0 { return None } data.append(&mut buf); if data.ends_with(delimiter) { break; } } Some(data) } fn transfer_chunked(src: &mut impl BufRead, dest: &mut impl Write) -> Option<()> { loop { let mut length = read_until(src, b"\r\n")?; dest.write_all(&length).ok()?; length.truncate(length.len()-2); let length = String::from_utf8(length).ok()?; let length = usize::from_str_radix(length.as_str(), 16).ok()?; let mut data = vec![0u8; length+2]; src.read_exact(&mut data).ok()?; dest.write_all(&data).ok()?; if length == 0 { break; } } Some(()) } pub trait Shutdown { fn shutdown(&self); } impl Shutdown for TcpStream { fn shutdown(&self) { TcpStream::shutdown(self, std::net::Shutdown::Both).ignore(); } } impl Shutdown for BufReader { fn shutdown(&self) { self.get_ref().shutdown(); } } impl Shutdown for StreamOwned { fn shutdown(&self) { self.sock.shutdown(); } }