diff --git a/src/flowgate/config.rs b/src/flowgate/config.rs index a22cce4..c00b559 100755 --- a/src/flowgate/config.rs +++ b/src/flowgate/config.rs @@ -55,8 +55,7 @@ pub struct Config { pub https_host: String, pub threadpool_size: usize, pub connection_timeout: Duration, - pub incoming_ip_forwarding: IpForwarding, - pub websocket_host: Option + pub incoming_ip_forwarding: IpForwarding } impl Config { @@ -75,7 +74,6 @@ impl Config { .map(|o| o.as_str()).flatten() .map(|o| IpForwarding::from_name(o)).flatten() .unwrap_or(IpForwarding::None); - let websocket_host = doc.get("websocket_host").map(|o| o.as_str()).flatten().map(|o| o.to_string()); let mut sites: Vec = Vec::new(); @@ -121,8 +119,7 @@ impl Config { https_host, threadpool_size, connection_timeout, - incoming_ip_forwarding, - websocket_host + incoming_ip_forwarding }.clone()) } diff --git a/src/flowgate/server.rs b/src/flowgate/server.rs index e2c4f9d..579d1e3 100755 --- a/src/flowgate/server.rs +++ b/src/flowgate/server.rs @@ -5,8 +5,6 @@ use std::{ sync::Arc }; -use tokio::sync::RwLock; - use ignore_result::Ignore; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, @@ -26,7 +24,7 @@ use super::config::{ }; pub struct FlowgateServer { - config: Arc>, + config: Arc, } struct Connection { @@ -37,50 +35,45 @@ struct Connection { } impl FlowgateServer { - pub fn new(config: Arc>) -> Self { + pub fn new(config: Arc) -> Self { FlowgateServer { config } } - pub async fn start(&self) { - tokio::spawn({ - let config = self.config.clone(); + pub async fn start(self) { + let local_self = Arc::new(self); - async move { - Self::run_http(config).await.ignore(); - } + tokio::spawn({ + let local_self = local_self.clone(); + async move { local_self.run_http().await.ignore(); } }); tokio::spawn({ - let config = self.config.clone(); - - async move { - Self::run_https(config).await.ignore(); - } + let local_self = local_self.clone(); + async move { local_self.run_https().await.ignore(); } }); } pub async fn run_http( - config: Arc> + self: Arc ) -> Result<(), Box> { - let listener = TcpListener::bind(&config.read().await.http_host).await?; + let listener = TcpListener::bind(&self.config.http_host).await?; - info!("HTTP server runned on {}", &config.read().await.http_host); + info!("HTTP server runned on {}", &self.config.http_host); loop { let Ok((stream, addr)) = listener.accept().await else { break }; - let config = config.clone(); + let local_self = self.clone(); tokio::spawn(async move { let mut stream = TimeoutStream::new(stream); - stream.set_write_timeout(Some(config.read().await.connection_timeout)); - stream.set_read_timeout(Some(config.read().await.connection_timeout)); + stream.set_write_timeout(Some(local_self.config.connection_timeout)); + stream.set_read_timeout(Some(local_self.config.connection_timeout)); let mut stream = Box::pin(stream); - Self::accept_stream( - config, + local_self.accept_stream( &mut stream, addr, false @@ -92,29 +85,28 @@ impl FlowgateServer { } pub async fn run_https( - config: Arc> + self: Arc ) -> Result<(), Box> { - let listener = TcpListener::bind(&config.read().await.https_host).await?; - let acceptor = TlsAcceptor::from(Arc::new(create_server_config(config.clone()).await)); + let listener = TcpListener::bind(&self.config.https_host).await?; + let acceptor = TlsAcceptor::from(Arc::new(create_server_config(self.config.clone()).await)); - info!("HTTPS server runned on {}", &config.read().await.https_host); + info!("HTTPS server runned on {}", &self.config.https_host); loop { let Ok((stream, addr)) = listener.accept().await else { break }; - let config = config.clone(); + let local_self = self.clone(); let acceptor = acceptor.clone(); tokio::spawn(async move { let mut stream = TimeoutStream::new(stream); - stream.set_write_timeout(Some(config.read().await.connection_timeout)); - stream.set_read_timeout(Some(config.read().await.connection_timeout)); + stream.set_write_timeout(Some(local_self.config.connection_timeout)); + stream.set_read_timeout(Some(local_self.config.connection_timeout)); let Ok(mut stream) = acceptor.accept(Box::pin(stream)).await else { return }; - Self::accept_stream( - config, + local_self.accept_stream( &mut stream, addr, true @@ -125,13 +117,13 @@ impl FlowgateServer { Ok(()) } - pub async fn accept_stream( - config: Arc>, + async fn accept_stream( + self: Arc, stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin), addr: SocketAddr, https: bool ) -> Option<()> { - let mut conn = read_request(config.clone(), stream, addr, https, None).await?; + let mut conn = self.clone().read_request(stream, addr, https, None).await?; if conn.keep_alive && conn.config.enable_keep_alive { loop { @@ -139,7 +131,7 @@ impl FlowgateServer { conn.stream.shutdown().await.ignore(); conn.stream = conn.config.connect().await?; } - conn = read_request(config.clone(), stream, addr, https, Some(conn)).await?; + conn = self.clone().read_request(stream, addr, https, Some(conn)).await?; } } @@ -148,267 +140,66 @@ impl FlowgateServer { Some(()) } -} -async fn read_request( - config: Arc>, - stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin), - addr: SocketAddr, - https: bool, - conn: Option -) -> Option { - let mut addr = addr; + async fn read_request( + self: Arc, + stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin), + addr: SocketAddr, + https: bool, + conn: Option + ) -> Option { + let mut addr = addr; - match &config.read().await.incoming_ip_forwarding { - IpForwarding::Simple => { - let mut header = Vec::new(); + match &self.config.incoming_ip_forwarding { + IpForwarding::Simple => { + let mut header = Vec::new(); - { - let mut buf = [0; 1]; + { + let mut buf = [0; 1]; - while let Ok(1) = stream.read(&mut buf).await { - let byte = buf[0]; - if byte == b'\n' { break } - header.push(byte); + while let Ok(1) = stream.read(&mut buf).await { + let byte = buf[0]; + if byte == b'\n' { break } + header.push(byte); + } } - } - addr = SocketAddr::from_str(&String::from_utf8(header).ok()?).ok()?; - }, - IpForwarding::Modern => { - let mut ipver = [0; 1]; - stream.read(&mut ipver).await.ok()?; - addr = match ipver[0] { - 0x01 => { - let mut octets = [0; 4]; - stream.read(&mut octets).await.ok()?; - let mut port = [0; 2]; - stream.read(&mut port).await.ok()?; - let port = u16::from_be_bytes(port); - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port)) - }, 0x02 => { - let mut octets = [0; 16]; - stream.read(&mut octets).await.ok()?; - let mut port = [0; 2]; - stream.read(&mut port).await.ok()?; - let port = u16::from_be_bytes(port); - SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0)) - }, _ => { return None }, - }; - }, - _ => { } - } - - let mut head = Vec::new(); - - { - let mut buf = [0; 1]; - let mut counter = 0; - - while let Ok(1) = stream.read(&mut buf).await { - let byte = buf[0]; - head.push(byte); - - counter = match (counter, byte) { - (0, b'\r') => 1, - (1, b'\n') => 2, - (2, b'\r') => 3, - (3, b'\n') => break, - _ => 0, - }; + addr = SocketAddr::from_str(&String::from_utf8(header).ok()?).ok()?; + }, + IpForwarding::Modern => { + let mut ipver = [0; 1]; + stream.read(&mut ipver).await.ok()?; + addr = match ipver[0] { + 0x01 => { + let mut octets = [0; 4]; + stream.read(&mut octets).await.ok()?; + let mut port = [0; 2]; + stream.read(&mut port).await.ok()?; + let port = u16::from_be_bytes(port); + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port)) + }, 0x02 => { + let mut octets = [0; 16]; + stream.read(&mut octets).await.ok()?; + let mut port = [0; 2]; + stream.read(&mut port).await.ok()?; + let port = u16::from_be_bytes(port); + SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0)) + }, _ => { return None }, + }; + }, + _ => { } } - head.truncate(head.len() - 4); - } - - if head.is_empty() { return None; } - - let head_str = String::from_utf8(head.clone()).ok()?; - let head_str = head_str.trim_matches(char::from(0)).to_string(); - - let mut head_lines = head_str.split("\r\n"); - - let status = head_lines.next()?; - let status_seq: Vec<&str> = status.split(" ").collect(); - - let headers: Vec<(&str, &str)> = head_lines - .filter(|l| l.contains(": ")) - .map(|l| l.split_once(": ").unwrap()) - .collect(); - - let is_chunked = headers.iter() - .find(|o| o.0.to_lowercase() == "transfer-encoding") - .map(|o| o.1.split(",").map(|x| x.trim_matches(' ').to_string()).collect::>()) - .map(|o| o.contains(&"chunked".to_string())) - .unwrap_or(false); - - if let IpForwarding::Header(header) = &config.read().await.incoming_ip_forwarding { - if let Some(ip) = headers.iter().find(|o| o.0 == header).map(|o| o.1) { - addr = SocketAddr::from_str(ip).ok()?; - } - } - - let mut conn: Connection = if conn.is_none() { - let mut host = String::new(); - let mut keep_alive = false; - - for (key, value) in &headers { - match key.to_lowercase().as_str() { - "host" => host = value.to_string(), - "connection" => keep_alive = *value == "keep-alive", - _ => {} - } - } - - let site = config.read().await.get_site(&host)?.clone(); - - Connection { - stream: site.connect().await?, - config: site, - keep_alive, - host - } - } else { - conn? - }; - - let content_length = headers - .iter() - .filter(|(k, _)| k.to_lowercase() == "content-length") - .next() - .map(|o| o.1.parse().ok()) - .flatten() - .unwrap_or(0usize); - - let mut reqbuf: Vec = Vec::new(); - - if let Some(replace_host) = conn.config.replace_host.clone() { - let mut new_head = Vec::new(); - let mut is_status = true; - - for line in head_str.split("\r\n") { - if is_status { - new_head.append(&mut line.as_bytes().to_vec()); - is_status = false; - } else { - new_head.append(&mut b"\r\n".to_vec()); - let (key, _) = line.split_once(": ")?; - if key.to_lowercase() == "host" { - new_head.append(&mut key.as_bytes().to_vec()); - new_head.append(&mut b": ".to_vec()); - new_head.append(&mut replace_host.as_bytes().to_vec()); - } else { - new_head.append(&mut line.as_bytes().to_vec()); - } - } - } - - head = new_head; - } - - match &conn.config.ip_forwarding { - IpForwarding::Header(header) => { - reqbuf.append(&mut status.to_string().as_bytes().to_vec()); - reqbuf.append(&mut b"\r\n".to_vec()); - for (key, value) in String::from_utf8(head.clone()).ok()? - .split("\r\n") - .skip(1) - .filter_map(|o| o.split_once(": ")) { - if *key.to_lowercase() == header.to_lowercase() { continue } - reqbuf.append(&mut key.to_string().as_bytes().to_vec()); - reqbuf.append(&mut b": ".to_vec()); - reqbuf.append(&mut value.to_string().as_bytes().to_vec()); - reqbuf.append(&mut b"\r\n".to_vec()); - } - reqbuf.append(&mut header.as_bytes().to_vec()); - reqbuf.append(&mut b": ".to_vec()); - reqbuf.append(&mut addr.to_string().as_bytes().to_vec()); - reqbuf.append(&mut b"\r\n\r\n".to_vec()); - }, - IpForwarding::Simple => { - reqbuf.append(&mut addr.to_string().as_bytes().to_vec()); - reqbuf.push(b'\n'); - reqbuf.append(&mut head.clone()); - reqbuf.append(&mut b"\r\n\r\n".to_vec()); - }, - IpForwarding::Modern => { - reqbuf.push(if addr.is_ipv4() { 0x01 } else { 0x02 }); - match addr.ip() { - IpAddr::V4(ip) => { - reqbuf.append(&mut ip.octets().to_vec()); - }, IpAddr::V6(ip) => { - reqbuf.append(&mut ip.octets().to_vec()); - } - } - reqbuf.append(&mut addr.port().to_be_bytes().to_vec()); - reqbuf.append(&mut head.clone()); - reqbuf.append(&mut b"\r\n\r\n".to_vec()); - }, - IpForwarding::None => { - reqbuf.append(&mut head.clone()); - reqbuf.append(&mut b"\r\n\r\n".to_vec()); - } - } - - conn.stream.write_all(&reqbuf).await.ok()?; - - if content_length > 0 { - let mut read = 0usize; - let mut buf = vec![0; 4096]; - while let Ok(size) = stream.read(&mut buf).await { - if size == 0 { break } - read += size; - buf.truncate(size); - conn.stream.write_all(&buf).await.ok()?; - buf = vec![0; 4096]; - if read >= content_length { break } - } - } else if is_chunked { - loop { - let mut length = Vec::new(); - { - let mut buf = [0; 1]; - let mut counter = 0; - - while let Ok(1) = stream.read(&mut buf).await { - let byte = buf[0]; - length.push(byte); - - counter = match (counter, byte) { - (0, b'\r') => 1, - (1, b'\n') => break, - _ => 0, - }; - } - conn.stream.write_all(&length).await.ok()?; - - length.truncate(length.len() - 2); - } - let length = String::from_utf8(length).ok()?; - let length = usize::from_str_radix(length.as_str(), 16).ok()?; - let mut data = vec![0u8; length+2]; - stream.read_exact(&mut data).await.ok()?; - - conn.stream.write_all(&data).await.ok()?; - if length == 0 { - break; - } - } - } - - if conn.config.support_keep_alive { let mut head = Vec::new(); { let mut buf = [0; 1]; let mut counter = 0; - while let Ok(1) = conn.stream.read(&mut buf).await { + while let Ok(1) = stream.read(&mut buf).await { let byte = buf[0]; head.push(byte); - stream.write_all(&buf).await.ok()?; - counter = match (counter, byte) { (0, b'\r') => 1, (1, b'\n') => 2, @@ -424,37 +215,144 @@ async fn read_request( if head.is_empty() { return None; } let head_str = String::from_utf8(head.clone()).ok()?; - let head_str = head_str.trim_matches(char::from(0)); + let head_str = head_str.trim_matches(char::from(0)).to_string(); - let headers = head_str.split("\r\n") - .skip(1) + let mut head_lines = head_str.split("\r\n"); + + let status = head_lines.next()?; + let status_seq: Vec<&str> = status.split(" ").collect(); + + let headers: Vec<(&str, &str)> = head_lines .filter(|l| l.contains(": ")) .map(|l| l.split_once(": ").unwrap()) - .map(|(k,v)| (k.to_lowercase(),v.to_string())) - .collect::>(); + .collect(); - let content_length = headers.iter() - .find(|(k, _)| k == "content-length") - .map(|o| o.1.parse().ok()) - .flatten() - .unwrap_or(0usize); - let is_chunked = headers.iter() .find(|o| o.0.to_lowercase() == "transfer-encoding") .map(|o| o.1.split(",").map(|x| x.trim_matches(' ').to_string()).collect::>()) .map(|o| o.contains(&"chunked".to_string())) .unwrap_or(false); + if let IpForwarding::Header(header) = &self.config.incoming_ip_forwarding { + if let Some(ip) = headers.iter().find(|o| o.0 == header).map(|o| o.1) { + addr = SocketAddr::from_str(ip).ok()?; + } + } + + let mut conn: Connection = if conn.is_none() { + let mut host = String::new(); + let mut keep_alive = false; + + for (key, value) in &headers { + match key.to_lowercase().as_str() { + "host" => host = value.to_string(), + "connection" => keep_alive = *value == "keep-alive", + _ => {} + } + } + + let site = self.config.get_site(&host)?.clone(); + + Connection { + stream: site.connect().await?, + config: site, + keep_alive, + host + } + } else { + conn? + }; + + let content_length = headers + .iter() + .filter(|(k, _)| k.to_lowercase() == "content-length") + .next() + .map(|o| o.1.parse().ok()) + .flatten() + .unwrap_or(0usize); + + let mut reqbuf: Vec = Vec::new(); + + if let Some(replace_host) = conn.config.replace_host.clone() { + let mut new_head = Vec::new(); + let mut is_status = true; + + for line in head_str.split("\r\n") { + if is_status { + new_head.append(&mut line.as_bytes().to_vec()); + is_status = false; + } else { + new_head.append(&mut b"\r\n".to_vec()); + let (key, _) = line.split_once(": ")?; + if key.to_lowercase() == "host" { + new_head.append(&mut key.as_bytes().to_vec()); + new_head.append(&mut b": ".to_vec()); + new_head.append(&mut replace_host.as_bytes().to_vec()); + } else { + new_head.append(&mut line.as_bytes().to_vec()); + } + } + } + + head = new_head; + } + + match &conn.config.ip_forwarding { + IpForwarding::Header(header) => { + reqbuf.append(&mut status.to_string().as_bytes().to_vec()); + reqbuf.append(&mut b"\r\n".to_vec()); + for (key, value) in String::from_utf8(head.clone()).ok()? + .split("\r\n") + .skip(1) + .filter_map(|o| o.split_once(": ")) { + if *key.to_lowercase() == header.to_lowercase() { continue } + reqbuf.append(&mut key.to_string().as_bytes().to_vec()); + reqbuf.append(&mut b": ".to_vec()); + reqbuf.append(&mut value.to_string().as_bytes().to_vec()); + reqbuf.append(&mut b"\r\n".to_vec()); + } + reqbuf.append(&mut header.as_bytes().to_vec()); + reqbuf.append(&mut b": ".to_vec()); + reqbuf.append(&mut addr.to_string().as_bytes().to_vec()); + reqbuf.append(&mut b"\r\n\r\n".to_vec()); + }, + IpForwarding::Simple => { + reqbuf.append(&mut addr.to_string().as_bytes().to_vec()); + reqbuf.push(b'\n'); + reqbuf.append(&mut head.clone()); + reqbuf.append(&mut b"\r\n\r\n".to_vec()); + }, + IpForwarding::Modern => { + reqbuf.push(if addr.is_ipv4() { 0x01 } else { 0x02 }); + match addr.ip() { + IpAddr::V4(ip) => { + reqbuf.append(&mut ip.octets().to_vec()); + }, IpAddr::V6(ip) => { + reqbuf.append(&mut ip.octets().to_vec()); + } + } + reqbuf.append(&mut addr.port().to_be_bytes().to_vec()); + reqbuf.append(&mut head.clone()); + reqbuf.append(&mut b"\r\n\r\n".to_vec()); + }, + IpForwarding::None => { + reqbuf.append(&mut head.clone()); + reqbuf.append(&mut b"\r\n\r\n".to_vec()); + } + } + + conn.stream.write_all(&reqbuf).await.ok()?; + if content_length > 0 { let mut read = 0usize; let mut buf = vec![0; 4096]; - while let Ok(size) = conn.stream.read(&mut buf).await { + while let Ok(size) = stream.read(&mut buf).await { if size == 0 { break } read += size; buf.truncate(size); - stream.write_all(&buf).await.ok()?; + conn.stream.write_all(&buf).await.ok()?; buf = vec![0; 4096]; - if read == content_length { break } + if read >= content_length { break } } } else if is_chunked { loop { @@ -463,7 +361,7 @@ async fn read_request( let mut buf = [0; 1]; let mut counter = 0; - while let Ok(1) = conn.stream.read(&mut buf).await { + while let Ok(1) = stream.read(&mut buf).await { let byte = buf[0]; length.push(byte); @@ -473,32 +371,126 @@ async fn read_request( _ => 0, }; } - stream.write_all(&length).await.ok()?; + conn.stream.write_all(&length).await.ok()?; length.truncate(length.len() - 2); } let length = String::from_utf8(length).ok()?; let length = usize::from_str_radix(length.as_str(), 16).ok()?; let mut data = vec![0u8; length+2]; - conn.stream.read_exact(&mut data).await.ok()?; + stream.read_exact(&mut data).await.ok()?; - stream.write_all(&data).await.ok()?; + conn.stream.write_all(&data).await.ok()?; if length == 0 { break; } } } - } else { - let mut buf = vec![0;1024]; - while let Ok(n) = conn.stream.read(&mut buf).await { - if n == 0 { break } - buf.truncate(n); - stream.write_all(&buf).await.ok()?; - buf = vec![0;1024]; + + if conn.config.support_keep_alive { + let mut head = Vec::new(); + + { + let mut buf = [0; 1]; + let mut counter = 0; + + while let Ok(1) = conn.stream.read(&mut buf).await { + let byte = buf[0]; + head.push(byte); + + stream.write_all(&buf).await.ok()?; + + counter = match (counter, byte) { + (0, b'\r') => 1, + (1, b'\n') => 2, + (2, b'\r') => 3, + (3, b'\n') => break, + _ => 0, + }; + } + + head.truncate(head.len() - 4); + } + + if head.is_empty() { return None; } + + let head_str = String::from_utf8(head.clone()).ok()?; + let head_str = head_str.trim_matches(char::from(0)); + + let headers = head_str.split("\r\n") + .skip(1) + .filter(|l| l.contains(": ")) + .map(|l| l.split_once(": ").unwrap()) + .map(|(k,v)| (k.to_lowercase(),v.to_string())) + .collect::>(); + + let content_length = headers.iter() + .find(|(k, _)| k == "content-length") + .map(|o| o.1.parse().ok()) + .flatten() + .unwrap_or(0usize); + + let is_chunked = headers.iter() + .find(|o| o.0.to_lowercase() == "transfer-encoding") + .map(|o| o.1.split(",").map(|x| x.trim_matches(' ').to_string()).collect::>()) + .map(|o| o.contains(&"chunked".to_string())) + .unwrap_or(false); + + if content_length > 0 { + let mut read = 0usize; + let mut buf = vec![0; 4096]; + while let Ok(size) = conn.stream.read(&mut buf).await { + if size == 0 { break } + read += size; + buf.truncate(size); + stream.write_all(&buf).await.ok()?; + buf = vec![0; 4096]; + if read == content_length { break } + } + } else if is_chunked { + loop { + let mut length = Vec::new(); + { + let mut buf = [0; 1]; + let mut counter = 0; + + while let Ok(1) = conn.stream.read(&mut buf).await { + let byte = buf[0]; + length.push(byte); + + counter = match (counter, byte) { + (0, b'\r') => 1, + (1, b'\n') => break, + _ => 0, + }; + } + stream.write_all(&length).await.ok()?; + + length.truncate(length.len() - 2); + } + let length = String::from_utf8(length).ok()?; + let length = usize::from_str_radix(length.as_str(), 16).ok()?; + let mut data = vec![0u8; length+2]; + conn.stream.read_exact(&mut data).await.ok()?; + + stream.write_all(&data).await.ok()?; + if length == 0 { + break; + } + } + } + } else { + let mut buf = vec![0;1024]; + while let Ok(n) = conn.stream.read(&mut buf).await { + if n == 0 { break } + buf.truncate(n); + stream.write_all(&buf).await.ok()?; + buf = vec![0;1024]; + } } + + info!("{addr} > {} {}://{}{}", status_seq[0], if https { "https" } else { "http" }, conn.host, status_seq[1]); + + Some(conn) } - - info!("{addr} > {} {}://{}{}", status_seq[0], if https { "https" } else { "http" }, conn.host, status_seq[1]); - - Some(conn) -} \ No newline at end of file +} diff --git a/src/flowgate/tls.rs b/src/flowgate/tls.rs index 7b51c7f..27d109e 100755 --- a/src/flowgate/tls.rs +++ b/src/flowgate/tls.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, thread}; +use std::sync::Arc; use rustls::{ crypto::aws_lc_rs::sign::any_supported_type, @@ -8,8 +8,6 @@ use rustls::{ ServerConfig }; -use tokio::{runtime::Handle, sync::RwLock}; - use super::config::Config; @@ -37,29 +35,19 @@ impl TlsCertificate { #[derive(Debug)] struct ResolvesServerCertWildcard { - config: Arc>, - handle: Handle + config: Arc } impl ResolvesServerCertWildcard { - pub async fn new(config: Arc>) -> Self { - Self { config, handle: Handle::current() } + pub async fn new(config: Arc) -> Self { + Self { config } } } impl ResolvesServerCert for ResolvesServerCertWildcard { fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { if let Some(cert) = client_hello.server_name() - .and_then(|name| { - thread::spawn({ - let handle = self.handle.clone(); - let config = self.config.clone(); - - move || { - handle.block_on(config.read()).clone() - } - }).join().unwrap().get_site(name).cloned() - }) + .and_then(|name| self.config.get_site(name).cloned()) .and_then(|site| site.ssl) { Some(Arc::new(cert.get_key())) } else { @@ -68,7 +56,7 @@ impl ResolvesServerCert for ResolvesServerCertWildcard { } } -pub async fn create_server_config(config: Arc>) -> ServerConfig { +pub async fn create_server_config(config: Arc) -> ServerConfig { ServerConfig::builder() .with_no_client_auth() .with_cert_resolver(Arc::new(ResolvesServerCertWildcard::new(config).await)) diff --git a/src/main.rs b/src/main.rs index 6581ee2..cce2935 100755 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,6 @@ use std::{fs, path::Path, sync::Arc}; use flowgate::{config::Config, server::FlowgateServer}; use ignore_result::Ignore; -use tokio::sync::RwLock; #[tokio::main] async fn main() { @@ -12,7 +11,7 @@ async fn main() { fs::write("conf.yml", include_bytes!("../conf.yml")).ignore(); } - let config = Arc::new(RwLock::new(Config::parse("conf.yml").unwrap())); + let config = Arc::new(Config::parse("conf.yml").unwrap()); let server = FlowgateServer::new(config.clone()); server.start().await;