config remove rwlock

This commit is contained in:
MeexReay 2025-04-07 01:41:34 +03:00
parent 903bf23a4c
commit ef9f1b2847
4 changed files with 311 additions and 335 deletions

View File

@ -55,8 +55,7 @@ pub struct Config {
pub https_host: String, pub https_host: String,
pub threadpool_size: usize, pub threadpool_size: usize,
pub connection_timeout: Duration, pub connection_timeout: Duration,
pub incoming_ip_forwarding: IpForwarding, pub incoming_ip_forwarding: IpForwarding
pub websocket_host: Option<String>
} }
impl Config { impl Config {
@ -75,7 +74,6 @@ impl Config {
.map(|o| o.as_str()).flatten() .map(|o| o.as_str()).flatten()
.map(|o| IpForwarding::from_name(o)).flatten() .map(|o| IpForwarding::from_name(o)).flatten()
.unwrap_or(IpForwarding::None); .unwrap_or(IpForwarding::None);
let websocket_host = doc.get("websocket_host").map(|o| o.as_str()).flatten().map(|o| o.to_string());
let mut sites: Vec<SiteConfig> = Vec::new(); let mut sites: Vec<SiteConfig> = Vec::new();
@ -121,8 +119,7 @@ impl Config {
https_host, https_host,
threadpool_size, threadpool_size,
connection_timeout, connection_timeout,
incoming_ip_forwarding, incoming_ip_forwarding
websocket_host
}.clone()) }.clone())
} }

View File

@ -5,8 +5,6 @@ use std::{
sync::Arc sync::Arc
}; };
use tokio::sync::RwLock;
use ignore_result::Ignore; use ignore_result::Ignore;
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
@ -26,7 +24,7 @@ use super::config::{
}; };
pub struct FlowgateServer { pub struct FlowgateServer {
config: Arc<RwLock<Config>>, config: Arc<Config>,
} }
struct Connection { struct Connection {
@ -37,50 +35,45 @@ struct Connection {
} }
impl FlowgateServer { impl FlowgateServer {
pub fn new(config: Arc<RwLock<Config>>) -> Self { pub fn new(config: Arc<Config>) -> Self {
FlowgateServer { config } FlowgateServer { config }
} }
pub async fn start(&self) { pub async fn start(self) {
tokio::spawn({ let local_self = Arc::new(self);
let config = self.config.clone();
async move { tokio::spawn({
Self::run_http(config).await.ignore(); let local_self = local_self.clone();
} async move { local_self.run_http().await.ignore(); }
}); });
tokio::spawn({ tokio::spawn({
let config = self.config.clone(); let local_self = local_self.clone();
async move { local_self.run_https().await.ignore(); }
async move {
Self::run_https(config).await.ignore();
}
}); });
} }
pub async fn run_http( pub async fn run_http(
config: Arc<RwLock<Config>> self: Arc<Self>
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(&config.read().await.http_host).await?; let listener = TcpListener::bind(&self.config.http_host).await?;
info!("HTTP server runned on {}", &config.read().await.http_host); info!("HTTP server runned on {}", &self.config.http_host);
loop { loop {
let Ok((stream, addr)) = listener.accept().await else { break }; let Ok((stream, addr)) = listener.accept().await else { break };
let config = config.clone(); let local_self = self.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut stream = TimeoutStream::new(stream); let mut stream = TimeoutStream::new(stream);
stream.set_write_timeout(Some(config.read().await.connection_timeout)); stream.set_write_timeout(Some(local_self.config.connection_timeout));
stream.set_read_timeout(Some(config.read().await.connection_timeout)); stream.set_read_timeout(Some(local_self.config.connection_timeout));
let mut stream = Box::pin(stream); let mut stream = Box::pin(stream);
Self::accept_stream( local_self.accept_stream(
config,
&mut stream, &mut stream,
addr, addr,
false false
@ -92,29 +85,28 @@ impl FlowgateServer {
} }
pub async fn run_https( pub async fn run_https(
config: Arc<RwLock<Config>> self: Arc<Self>
) -> Result<(), Box<dyn Error>> { ) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(&config.read().await.https_host).await?; let listener = TcpListener::bind(&self.config.https_host).await?;
let acceptor = TlsAcceptor::from(Arc::new(create_server_config(config.clone()).await)); let acceptor = TlsAcceptor::from(Arc::new(create_server_config(self.config.clone()).await));
info!("HTTPS server runned on {}", &config.read().await.https_host); info!("HTTPS server runned on {}", &self.config.https_host);
loop { loop {
let Ok((stream, addr)) = listener.accept().await else { break }; let Ok((stream, addr)) = listener.accept().await else { break };
let config = config.clone(); let local_self = self.clone();
let acceptor = acceptor.clone(); let acceptor = acceptor.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut stream = TimeoutStream::new(stream); let mut stream = TimeoutStream::new(stream);
stream.set_write_timeout(Some(config.read().await.connection_timeout)); stream.set_write_timeout(Some(local_self.config.connection_timeout));
stream.set_read_timeout(Some(config.read().await.connection_timeout)); stream.set_read_timeout(Some(local_self.config.connection_timeout));
let Ok(mut stream) = acceptor.accept(Box::pin(stream)).await else { return }; let Ok(mut stream) = acceptor.accept(Box::pin(stream)).await else { return };
Self::accept_stream( local_self.accept_stream(
config,
&mut stream, &mut stream,
addr, addr,
true true
@ -125,13 +117,13 @@ impl FlowgateServer {
Ok(()) Ok(())
} }
pub async fn accept_stream( async fn accept_stream(
config: Arc<RwLock<Config>>, self: Arc<Self>,
stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin), stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin),
addr: SocketAddr, addr: SocketAddr,
https: bool https: bool
) -> Option<()> { ) -> Option<()> {
let mut conn = read_request(config.clone(), stream, addr, https, None).await?; let mut conn = self.clone().read_request(stream, addr, https, None).await?;
if conn.keep_alive && conn.config.enable_keep_alive { if conn.keep_alive && conn.config.enable_keep_alive {
loop { loop {
@ -139,7 +131,7 @@ impl FlowgateServer {
conn.stream.shutdown().await.ignore(); conn.stream.shutdown().await.ignore();
conn.stream = conn.config.connect().await?; conn.stream = conn.config.connect().await?;
} }
conn = read_request(config.clone(), stream, addr, https, Some(conn)).await?; conn = self.clone().read_request(stream, addr, https, Some(conn)).await?;
} }
} }
@ -148,267 +140,66 @@ impl FlowgateServer {
Some(()) Some(())
} }
}
async fn read_request( async fn read_request(
config: Arc<RwLock<Config>>, self: Arc<Self>,
stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin), stream: &mut (impl AsyncReadExt + AsyncWriteExt + Unpin),
addr: SocketAddr, addr: SocketAddr,
https: bool, https: bool,
conn: Option<Connection> conn: Option<Connection>
) -> Option<Connection> { ) -> Option<Connection> {
let mut addr = addr; let mut addr = addr;
match &config.read().await.incoming_ip_forwarding { match &self.config.incoming_ip_forwarding {
IpForwarding::Simple => { IpForwarding::Simple => {
let mut header = Vec::new(); let mut header = Vec::new();
{ {
let mut buf = [0; 1]; let mut buf = [0; 1];
while let Ok(1) = stream.read(&mut buf).await { while let Ok(1) = stream.read(&mut buf).await {
let byte = buf[0]; let byte = buf[0];
if byte == b'\n' { break } if byte == b'\n' { break }
header.push(byte); header.push(byte);
}
} }
}
addr = SocketAddr::from_str(&String::from_utf8(header).ok()?).ok()?; addr = SocketAddr::from_str(&String::from_utf8(header).ok()?).ok()?;
}, },
IpForwarding::Modern => { IpForwarding::Modern => {
let mut ipver = [0; 1]; let mut ipver = [0; 1];
stream.read(&mut ipver).await.ok()?; stream.read(&mut ipver).await.ok()?;
addr = match ipver[0] { addr = match ipver[0] {
0x01 => { 0x01 => {
let mut octets = [0; 4]; let mut octets = [0; 4];
stream.read(&mut octets).await.ok()?; stream.read(&mut octets).await.ok()?;
let mut port = [0; 2]; let mut port = [0; 2];
stream.read(&mut port).await.ok()?; stream.read(&mut port).await.ok()?;
let port = u16::from_be_bytes(port); let port = u16::from_be_bytes(port);
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port)) SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
}, 0x02 => { }, 0x02 => {
let mut octets = [0; 16]; let mut octets = [0; 16];
stream.read(&mut octets).await.ok()?; stream.read(&mut octets).await.ok()?;
let mut port = [0; 2]; let mut port = [0; 2];
stream.read(&mut port).await.ok()?; stream.read(&mut port).await.ok()?;
let port = u16::from_be_bytes(port); let port = u16::from_be_bytes(port);
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0)) SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
}, _ => { return None }, }, _ => { return None },
}; };
}, },
_ => { } _ => { }
}
let mut head = Vec::new();
{
let mut buf = [0; 1];
let mut counter = 0;
while let Ok(1) = stream.read(&mut buf).await {
let byte = buf[0];
head.push(byte);
counter = match (counter, byte) {
(0, b'\r') => 1,
(1, b'\n') => 2,
(2, b'\r') => 3,
(3, b'\n') => break,
_ => 0,
};
} }
head.truncate(head.len() - 4);
}
if head.is_empty() { return None; }
let head_str = String::from_utf8(head.clone()).ok()?;
let head_str = head_str.trim_matches(char::from(0)).to_string();
let mut head_lines = head_str.split("\r\n");
let status = head_lines.next()?;
let status_seq: Vec<&str> = status.split(" ").collect();
let headers: Vec<(&str, &str)> = head_lines
.filter(|l| l.contains(": "))
.map(|l| l.split_once(": ").unwrap())
.collect();
let is_chunked = headers.iter()
.find(|o| o.0.to_lowercase() == "transfer-encoding")
.map(|o| o.1.split(",").map(|x| x.trim_matches(' ').to_string()).collect::<Vec<String>>())
.map(|o| o.contains(&"chunked".to_string()))
.unwrap_or(false);
if let IpForwarding::Header(header) = &config.read().await.incoming_ip_forwarding {
if let Some(ip) = headers.iter().find(|o| o.0 == header).map(|o| o.1) {
addr = SocketAddr::from_str(ip).ok()?;
}
}
let mut conn: Connection = if conn.is_none() {
let mut host = String::new();
let mut keep_alive = false;
for (key, value) in &headers {
match key.to_lowercase().as_str() {
"host" => host = value.to_string(),
"connection" => keep_alive = *value == "keep-alive",
_ => {}
}
}
let site = config.read().await.get_site(&host)?.clone();
Connection {
stream: site.connect().await?,
config: site,
keep_alive,
host
}
} else {
conn?
};
let content_length = headers
.iter()
.filter(|(k, _)| k.to_lowercase() == "content-length")
.next()
.map(|o| o.1.parse().ok())
.flatten()
.unwrap_or(0usize);
let mut reqbuf: Vec<u8> = Vec::new();
if let Some(replace_host) = conn.config.replace_host.clone() {
let mut new_head = Vec::new();
let mut is_status = true;
for line in head_str.split("\r\n") {
if is_status {
new_head.append(&mut line.as_bytes().to_vec());
is_status = false;
} else {
new_head.append(&mut b"\r\n".to_vec());
let (key, _) = line.split_once(": ")?;
if key.to_lowercase() == "host" {
new_head.append(&mut key.as_bytes().to_vec());
new_head.append(&mut b": ".to_vec());
new_head.append(&mut replace_host.as_bytes().to_vec());
} else {
new_head.append(&mut line.as_bytes().to_vec());
}
}
}
head = new_head;
}
match &conn.config.ip_forwarding {
IpForwarding::Header(header) => {
reqbuf.append(&mut status.to_string().as_bytes().to_vec());
reqbuf.append(&mut b"\r\n".to_vec());
for (key, value) in String::from_utf8(head.clone()).ok()?
.split("\r\n")
.skip(1)
.filter_map(|o| o.split_once(": ")) {
if *key.to_lowercase() == header.to_lowercase() { continue }
reqbuf.append(&mut key.to_string().as_bytes().to_vec());
reqbuf.append(&mut b": ".to_vec());
reqbuf.append(&mut value.to_string().as_bytes().to_vec());
reqbuf.append(&mut b"\r\n".to_vec());
}
reqbuf.append(&mut header.as_bytes().to_vec());
reqbuf.append(&mut b": ".to_vec());
reqbuf.append(&mut addr.to_string().as_bytes().to_vec());
reqbuf.append(&mut b"\r\n\r\n".to_vec());
},
IpForwarding::Simple => {
reqbuf.append(&mut addr.to_string().as_bytes().to_vec());
reqbuf.push(b'\n');
reqbuf.append(&mut head.clone());
reqbuf.append(&mut b"\r\n\r\n".to_vec());
},
IpForwarding::Modern => {
reqbuf.push(if addr.is_ipv4() { 0x01 } else { 0x02 });
match addr.ip() {
IpAddr::V4(ip) => {
reqbuf.append(&mut ip.octets().to_vec());
}, IpAddr::V6(ip) => {
reqbuf.append(&mut ip.octets().to_vec());
}
}
reqbuf.append(&mut addr.port().to_be_bytes().to_vec());
reqbuf.append(&mut head.clone());
reqbuf.append(&mut b"\r\n\r\n".to_vec());
},
IpForwarding::None => {
reqbuf.append(&mut head.clone());
reqbuf.append(&mut b"\r\n\r\n".to_vec());
}
}
conn.stream.write_all(&reqbuf).await.ok()?;
if content_length > 0 {
let mut read = 0usize;
let mut buf = vec![0; 4096];
while let Ok(size) = stream.read(&mut buf).await {
if size == 0 { break }
read += size;
buf.truncate(size);
conn.stream.write_all(&buf).await.ok()?;
buf = vec![0; 4096];
if read >= content_length { break }
}
} else if is_chunked {
loop {
let mut length = Vec::new();
{
let mut buf = [0; 1];
let mut counter = 0;
while let Ok(1) = stream.read(&mut buf).await {
let byte = buf[0];
length.push(byte);
counter = match (counter, byte) {
(0, b'\r') => 1,
(1, b'\n') => break,
_ => 0,
};
}
conn.stream.write_all(&length).await.ok()?;
length.truncate(length.len() - 2);
}
let length = String::from_utf8(length).ok()?;
let length = usize::from_str_radix(length.as_str(), 16).ok()?;
let mut data = vec![0u8; length+2];
stream.read_exact(&mut data).await.ok()?;
conn.stream.write_all(&data).await.ok()?;
if length == 0 {
break;
}
}
}
if conn.config.support_keep_alive {
let mut head = Vec::new(); let mut head = Vec::new();
{ {
let mut buf = [0; 1]; let mut buf = [0; 1];
let mut counter = 0; let mut counter = 0;
while let Ok(1) = conn.stream.read(&mut buf).await { while let Ok(1) = stream.read(&mut buf).await {
let byte = buf[0]; let byte = buf[0];
head.push(byte); head.push(byte);
stream.write_all(&buf).await.ok()?;
counter = match (counter, byte) { counter = match (counter, byte) {
(0, b'\r') => 1, (0, b'\r') => 1,
(1, b'\n') => 2, (1, b'\n') => 2,
@ -424,37 +215,144 @@ async fn read_request(
if head.is_empty() { return None; } if head.is_empty() { return None; }
let head_str = String::from_utf8(head.clone()).ok()?; let head_str = String::from_utf8(head.clone()).ok()?;
let head_str = head_str.trim_matches(char::from(0)); let head_str = head_str.trim_matches(char::from(0)).to_string();
let headers = head_str.split("\r\n") let mut head_lines = head_str.split("\r\n");
.skip(1)
let status = head_lines.next()?;
let status_seq: Vec<&str> = status.split(" ").collect();
let headers: Vec<(&str, &str)> = head_lines
.filter(|l| l.contains(": ")) .filter(|l| l.contains(": "))
.map(|l| l.split_once(": ").unwrap()) .map(|l| l.split_once(": ").unwrap())
.map(|(k,v)| (k.to_lowercase(),v.to_string())) .collect();
.collect::<Vec<(String,String)>>();
let content_length = headers.iter()
.find(|(k, _)| k == "content-length")
.map(|o| o.1.parse().ok())
.flatten()
.unwrap_or(0usize);
let is_chunked = headers.iter() let is_chunked = headers.iter()
.find(|o| o.0.to_lowercase() == "transfer-encoding") .find(|o| o.0.to_lowercase() == "transfer-encoding")
.map(|o| o.1.split(",").map(|x| x.trim_matches(' ').to_string()).collect::<Vec<String>>()) .map(|o| o.1.split(",").map(|x| x.trim_matches(' ').to_string()).collect::<Vec<String>>())
.map(|o| o.contains(&"chunked".to_string())) .map(|o| o.contains(&"chunked".to_string()))
.unwrap_or(false); .unwrap_or(false);
if let IpForwarding::Header(header) = &self.config.incoming_ip_forwarding {
if let Some(ip) = headers.iter().find(|o| o.0 == header).map(|o| o.1) {
addr = SocketAddr::from_str(ip).ok()?;
}
}
let mut conn: Connection = if conn.is_none() {
let mut host = String::new();
let mut keep_alive = false;
for (key, value) in &headers {
match key.to_lowercase().as_str() {
"host" => host = value.to_string(),
"connection" => keep_alive = *value == "keep-alive",
_ => {}
}
}
let site = self.config.get_site(&host)?.clone();
Connection {
stream: site.connect().await?,
config: site,
keep_alive,
host
}
} else {
conn?
};
let content_length = headers
.iter()
.filter(|(k, _)| k.to_lowercase() == "content-length")
.next()
.map(|o| o.1.parse().ok())
.flatten()
.unwrap_or(0usize);
let mut reqbuf: Vec<u8> = Vec::new();
if let Some(replace_host) = conn.config.replace_host.clone() {
let mut new_head = Vec::new();
let mut is_status = true;
for line in head_str.split("\r\n") {
if is_status {
new_head.append(&mut line.as_bytes().to_vec());
is_status = false;
} else {
new_head.append(&mut b"\r\n".to_vec());
let (key, _) = line.split_once(": ")?;
if key.to_lowercase() == "host" {
new_head.append(&mut key.as_bytes().to_vec());
new_head.append(&mut b": ".to_vec());
new_head.append(&mut replace_host.as_bytes().to_vec());
} else {
new_head.append(&mut line.as_bytes().to_vec());
}
}
}
head = new_head;
}
match &conn.config.ip_forwarding {
IpForwarding::Header(header) => {
reqbuf.append(&mut status.to_string().as_bytes().to_vec());
reqbuf.append(&mut b"\r\n".to_vec());
for (key, value) in String::from_utf8(head.clone()).ok()?
.split("\r\n")
.skip(1)
.filter_map(|o| o.split_once(": ")) {
if *key.to_lowercase() == header.to_lowercase() { continue }
reqbuf.append(&mut key.to_string().as_bytes().to_vec());
reqbuf.append(&mut b": ".to_vec());
reqbuf.append(&mut value.to_string().as_bytes().to_vec());
reqbuf.append(&mut b"\r\n".to_vec());
}
reqbuf.append(&mut header.as_bytes().to_vec());
reqbuf.append(&mut b": ".to_vec());
reqbuf.append(&mut addr.to_string().as_bytes().to_vec());
reqbuf.append(&mut b"\r\n\r\n".to_vec());
},
IpForwarding::Simple => {
reqbuf.append(&mut addr.to_string().as_bytes().to_vec());
reqbuf.push(b'\n');
reqbuf.append(&mut head.clone());
reqbuf.append(&mut b"\r\n\r\n".to_vec());
},
IpForwarding::Modern => {
reqbuf.push(if addr.is_ipv4() { 0x01 } else { 0x02 });
match addr.ip() {
IpAddr::V4(ip) => {
reqbuf.append(&mut ip.octets().to_vec());
}, IpAddr::V6(ip) => {
reqbuf.append(&mut ip.octets().to_vec());
}
}
reqbuf.append(&mut addr.port().to_be_bytes().to_vec());
reqbuf.append(&mut head.clone());
reqbuf.append(&mut b"\r\n\r\n".to_vec());
},
IpForwarding::None => {
reqbuf.append(&mut head.clone());
reqbuf.append(&mut b"\r\n\r\n".to_vec());
}
}
conn.stream.write_all(&reqbuf).await.ok()?;
if content_length > 0 { if content_length > 0 {
let mut read = 0usize; let mut read = 0usize;
let mut buf = vec![0; 4096]; let mut buf = vec![0; 4096];
while let Ok(size) = conn.stream.read(&mut buf).await { while let Ok(size) = stream.read(&mut buf).await {
if size == 0 { break } if size == 0 { break }
read += size; read += size;
buf.truncate(size); buf.truncate(size);
stream.write_all(&buf).await.ok()?; conn.stream.write_all(&buf).await.ok()?;
buf = vec![0; 4096]; buf = vec![0; 4096];
if read == content_length { break } if read >= content_length { break }
} }
} else if is_chunked { } else if is_chunked {
loop { loop {
@ -463,7 +361,7 @@ async fn read_request(
let mut buf = [0; 1]; let mut buf = [0; 1];
let mut counter = 0; let mut counter = 0;
while let Ok(1) = conn.stream.read(&mut buf).await { while let Ok(1) = stream.read(&mut buf).await {
let byte = buf[0]; let byte = buf[0];
length.push(byte); length.push(byte);
@ -473,32 +371,126 @@ async fn read_request(
_ => 0, _ => 0,
}; };
} }
stream.write_all(&length).await.ok()?; conn.stream.write_all(&length).await.ok()?;
length.truncate(length.len() - 2); length.truncate(length.len() - 2);
} }
let length = String::from_utf8(length).ok()?; let length = String::from_utf8(length).ok()?;
let length = usize::from_str_radix(length.as_str(), 16).ok()?; let length = usize::from_str_radix(length.as_str(), 16).ok()?;
let mut data = vec![0u8; length+2]; let mut data = vec![0u8; length+2];
conn.stream.read_exact(&mut data).await.ok()?; stream.read_exact(&mut data).await.ok()?;
stream.write_all(&data).await.ok()?; conn.stream.write_all(&data).await.ok()?;
if length == 0 { if length == 0 {
break; break;
} }
} }
} }
} else {
let mut buf = vec![0;1024]; if conn.config.support_keep_alive {
while let Ok(n) = conn.stream.read(&mut buf).await { let mut head = Vec::new();
if n == 0 { break }
buf.truncate(n); {
stream.write_all(&buf).await.ok()?; let mut buf = [0; 1];
buf = vec![0;1024]; let mut counter = 0;
while let Ok(1) = conn.stream.read(&mut buf).await {
let byte = buf[0];
head.push(byte);
stream.write_all(&buf).await.ok()?;
counter = match (counter, byte) {
(0, b'\r') => 1,
(1, b'\n') => 2,
(2, b'\r') => 3,
(3, b'\n') => break,
_ => 0,
};
}
head.truncate(head.len() - 4);
}
if head.is_empty() { return None; }
let head_str = String::from_utf8(head.clone()).ok()?;
let head_str = head_str.trim_matches(char::from(0));
let headers = head_str.split("\r\n")
.skip(1)
.filter(|l| l.contains(": "))
.map(|l| l.split_once(": ").unwrap())
.map(|(k,v)| (k.to_lowercase(),v.to_string()))
.collect::<Vec<(String,String)>>();
let content_length = headers.iter()
.find(|(k, _)| k == "content-length")
.map(|o| o.1.parse().ok())
.flatten()
.unwrap_or(0usize);
let is_chunked = headers.iter()
.find(|o| o.0.to_lowercase() == "transfer-encoding")
.map(|o| o.1.split(",").map(|x| x.trim_matches(' ').to_string()).collect::<Vec<String>>())
.map(|o| o.contains(&"chunked".to_string()))
.unwrap_or(false);
if content_length > 0 {
let mut read = 0usize;
let mut buf = vec![0; 4096];
while let Ok(size) = conn.stream.read(&mut buf).await {
if size == 0 { break }
read += size;
buf.truncate(size);
stream.write_all(&buf).await.ok()?;
buf = vec![0; 4096];
if read == content_length { break }
}
} else if is_chunked {
loop {
let mut length = Vec::new();
{
let mut buf = [0; 1];
let mut counter = 0;
while let Ok(1) = conn.stream.read(&mut buf).await {
let byte = buf[0];
length.push(byte);
counter = match (counter, byte) {
(0, b'\r') => 1,
(1, b'\n') => break,
_ => 0,
};
}
stream.write_all(&length).await.ok()?;
length.truncate(length.len() - 2);
}
let length = String::from_utf8(length).ok()?;
let length = usize::from_str_radix(length.as_str(), 16).ok()?;
let mut data = vec![0u8; length+2];
conn.stream.read_exact(&mut data).await.ok()?;
stream.write_all(&data).await.ok()?;
if length == 0 {
break;
}
}
}
} else {
let mut buf = vec![0;1024];
while let Ok(n) = conn.stream.read(&mut buf).await {
if n == 0 { break }
buf.truncate(n);
stream.write_all(&buf).await.ok()?;
buf = vec![0;1024];
}
} }
info!("{addr} > {} {}://{}{}", status_seq[0], if https { "https" } else { "http" }, conn.host, status_seq[1]);
Some(conn)
} }
}
info!("{addr} > {} {}://{}{}", status_seq[0], if https { "https" } else { "http" }, conn.host, status_seq[1]);
Some(conn)
}

View File

@ -1,4 +1,4 @@
use std::{sync::Arc, thread}; use std::sync::Arc;
use rustls::{ use rustls::{
crypto::aws_lc_rs::sign::any_supported_type, crypto::aws_lc_rs::sign::any_supported_type,
@ -8,8 +8,6 @@ use rustls::{
ServerConfig ServerConfig
}; };
use tokio::{runtime::Handle, sync::RwLock};
use super::config::Config; use super::config::Config;
@ -37,29 +35,19 @@ impl TlsCertificate {
#[derive(Debug)] #[derive(Debug)]
struct ResolvesServerCertWildcard { struct ResolvesServerCertWildcard {
config: Arc<RwLock<Config>>, config: Arc<Config>
handle: Handle
} }
impl ResolvesServerCertWildcard { impl ResolvesServerCertWildcard {
pub async fn new(config: Arc<RwLock<Config>>) -> Self { pub async fn new(config: Arc<Config>) -> Self {
Self { config, handle: Handle::current() } Self { config }
} }
} }
impl ResolvesServerCert for ResolvesServerCertWildcard { impl ResolvesServerCert for ResolvesServerCertWildcard {
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> { fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
if let Some(cert) = client_hello.server_name() if let Some(cert) = client_hello.server_name()
.and_then(|name| { .and_then(|name| self.config.get_site(name).cloned())
thread::spawn({
let handle = self.handle.clone();
let config = self.config.clone();
move || {
handle.block_on(config.read()).clone()
}
}).join().unwrap().get_site(name).cloned()
})
.and_then(|site| site.ssl) { .and_then(|site| site.ssl) {
Some(Arc::new(cert.get_key())) Some(Arc::new(cert.get_key()))
} else { } else {
@ -68,7 +56,7 @@ impl ResolvesServerCert for ResolvesServerCertWildcard {
} }
} }
pub async fn create_server_config(config: Arc<RwLock<Config>>) -> ServerConfig { pub async fn create_server_config(config: Arc<Config>) -> ServerConfig {
ServerConfig::builder() ServerConfig::builder()
.with_no_client_auth() .with_no_client_auth()
.with_cert_resolver(Arc::new(ResolvesServerCertWildcard::new(config).await)) .with_cert_resolver(Arc::new(ResolvesServerCertWildcard::new(config).await))

View File

@ -2,7 +2,6 @@ use std::{fs, path::Path, sync::Arc};
use flowgate::{config::Config, server::FlowgateServer}; use flowgate::{config::Config, server::FlowgateServer};
use ignore_result::Ignore; use ignore_result::Ignore;
use tokio::sync::RwLock;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -12,7 +11,7 @@ async fn main() {
fs::write("conf.yml", include_bytes!("../conf.yml")).ignore(); fs::write("conf.yml", include_bytes!("../conf.yml")).ignore();
} }
let config = Arc::new(RwLock::new(Config::parse("conf.yml").unwrap())); let config = Arc::new(Config::parse("conf.yml").unwrap());
let server = FlowgateServer::new(config.clone()); let server = FlowgateServer::new(config.clone());
server.start().await; server.start().await;