flowgate/src/server.rs

539 lines
17 KiB
Rust
Executable File

use std::{
error::Error, io::{BufRead, BufReader, Read, Write},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream},
str::FromStr, sync::Arc
};
use ignore_result::Ignore;
use log::{debug, info};
use rustls::{ServerConnection, StreamOwned};
use threadpool::ThreadPool;
use super::{
tls::create_server_config,
config::{
Config,
SiteConfig,
IpForwarding
}
};
/// Reverse proxy server: accepts HTTP and HTTPS clients and forwards each
/// request to the backend site selected by the request's `Host` header.
pub struct FlowgateServer {
    /// Shared configuration (listen addresses, sites, timeouts, pool size).
    config: Arc<Config>,
}
/// Per-client proxy state that is carried from one keep-alive request to the
/// next on the same client connection.
struct Connection {
    /// Buffered connection to the backend server of `config`.
    stream: BufReader<TcpStream>,
    /// Site configuration chosen from the first request's `Host` header.
    config: SiteConfig,
    /// Whether the first request's `Connection` header contained `keep-alive`.
    keep_alive: bool,
    /// `Host` header value of the first request; used in the access log line.
    host: String,
}
impl FlowgateServer {
pub fn new(config: Arc<Config>) -> Self {
FlowgateServer { config }
}
pub fn run(self) {
self.start().join();
}
pub fn start(self) -> ThreadPool {
let local_self = Arc::new(self);
let threadpool = ThreadPool::new(local_self.config.threadpool_size);
let mut handles = Vec::new();
handles.push(local_self.clone().start_http(&threadpool));
handles.push(local_self.clone().start_https(&threadpool));
threadpool
}
pub fn start_http(self: Arc<Self>, threadpool: &ThreadPool) {
threadpool.execute({
let local_self = self.clone();
let threadpool = threadpool.clone();
move || { local_self.run_http(&threadpool).ignore(); }
})
}
pub fn start_https(self: Arc<Self>, threadpool: &ThreadPool) {
threadpool.execute({
let local_self = self.clone();
let threadpool = threadpool.clone();
move || { local_self.run_https(&threadpool).ignore(); }
})
}
pub fn run_http(self: Arc<Self>, threadpool: &ThreadPool) -> Result<(), Box<dyn Error>> {
if let Some(host) = self.config.http_host.clone() {
let listener = TcpListener::bind(&host)?;
info!("HTTP server runned on {}", &host);
for stream in listener.incoming() {
if let Ok(mut stream) = stream {
let local_self = self.clone();
let Ok(addr) = stream.peer_addr() else { continue };
let Ok(_) = stream.set_write_timeout(Some(local_self.config.connection_timeout)) else { continue };
let Ok(_) = stream.set_read_timeout(Some(local_self.config.connection_timeout)) else { continue };
threadpool.execute(move || {
debug!("{} open connection", addr);
local_self.accept_stream(
&mut stream,
addr,
false
);
debug!("{} close connection", addr);
});
}
}
}
Ok(())
}
pub fn run_https(self: Arc<Self>, threadpool: &ThreadPool) -> Result<(), Box<dyn Error>> {
if let Some(host) = self.config.https_host.clone() {
let listener = TcpListener::bind(&host)?;
info!("HTTPS server runned on {}", &host);
let config = Arc::new(create_server_config(self.config.clone()));
for stream in listener.incoming() {
if let Ok(stream) = stream {
let local_self = self.clone();
let config = config.clone();
let Ok(addr) = stream.peer_addr() else { continue };
let Ok(_) = stream.set_write_timeout(Some(local_self.config.connection_timeout)) else { continue };
let Ok(_) = stream.set_read_timeout(Some(local_self.config.connection_timeout)) else { continue };
threadpool.execute(move || {
let Ok(connection) = ServerConnection::new(config) else { return };
debug!("{} open connection", addr);
let mut stream = StreamOwned::new(connection, stream);
while stream.conn.is_handshaking() {
let Ok(_) = stream.conn.complete_io(&mut stream.sock) else { debug!("{} close connection", addr);return };
}
local_self.accept_stream(
&mut stream,
addr,
true
);
debug!("{} close connection", addr);
});
}
}
}
Ok(())
}
fn accept_stream(
self: Arc<Self>,
stream: &mut (impl Read + Write + Shutdown),
addr: SocketAddr,
https: bool
) -> Option<()> {
let mut conn = self.clone().read_request(stream, addr, https, None)?;
if conn.keep_alive && conn.config.enable_keep_alive {
loop {
if !conn.config.support_keep_alive {
conn.stream.shutdown();
conn.stream = BufReader::new(conn.config.connect()?);
}
conn = self.clone().read_request(stream, addr, https, Some(conn))?;
}
}
conn.stream.shutdown();
stream.shutdown();
Some(())
}
fn read_request(
self: Arc<Self>,
stream: &mut (impl Read + Write + Shutdown),
addr: SocketAddr,
https: bool,
conn: Option<Connection>
) -> Option<Connection> {
let mut addr = addr;
let mut stream = BufReader::new(stream);
match &self.config.incoming_ip_forwarding {
IpForwarding::Simple => {
let mut header = Vec::new();
stream.read_until(b'\n', &mut header).ok()?;
header.truncate(header.len()-1);
addr = SocketAddr::from_str(&String::from_utf8(header).ok()?).ok()?;
},
IpForwarding::Modern => {
let mut header = [0];
stream.read(&mut header).ok()?;
addr = match header[0] {
0x01 => {
let mut octets = [0; 4];
stream.read(&mut octets).ok()?;
let mut port = [0; 2];
stream.read(&mut port).ok()?;
let port = u16::from_be_bytes(port);
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(octets), port))
}, 0x02 => {
let mut octets = [0; 16];
stream.read(&mut octets).ok()?;
let mut port = [0; 2];
stream.read(&mut port).ok()?;
let port = u16::from_be_bytes(port);
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(octets), port, 0, 0))
}, _ => { return None },
};
},
_ => {}
}
let mut raw_status = read_until(&mut stream, b"\r\n")?;
let mut request = Vec::new();
request.append(&mut raw_status.clone());
raw_status.truncate(raw_status.len() - 2);
let status = String::from_utf8(raw_status.clone()).ok()?;
let status = status.split(" ").collect::<Vec<&str>>();
debug!("{} {} read status", addr, status[1]);
let mut content_length = 0;
let mut host = None;
let mut is_chunked = false;
let mut keep_alive = false;
let mut headers = Vec::new();
loop {
let mut header = read_until(&mut stream, b"\r\n")?;
header.truncate(header.len() - 2);
if header.is_empty() {
break;
}
let header = String::from_utf8(header).ok()?;
let (key, value) = header.split_once(": ")?;
headers.push((key.to_string(), value.to_string()));
match key.to_lowercase().as_str() {
"transfer-encoding" => {
if value.contains("chunked") {
is_chunked = true;
}
},
"host" => {
host = Some(value.to_string());
},
"connection" => {
keep_alive = value.to_lowercase().contains("keep-alive");
},
"content-length" => {
content_length = value.parse::<usize>().ok()?;
},
_ => {
if let IpForwarding::Header(header) = &self.config.incoming_ip_forwarding {
if key.to_lowercase() == header.to_lowercase() {
addr = SocketAddr::from_str(value).ok()?;
}
}
},
}
}
debug!("{} {} read headers", addr, status[1]);
let mut conn: Connection = if conn.is_none() {
let host = host?;
let site = self.config.get_site(&host)?.clone();
Connection {
stream: BufReader::new(site.connect()?),
config: site,
keep_alive,
host
}
} else {
conn?
};
debug!("{} {} got connection", addr, status[1]);
match &conn.config.ip_forwarding {
IpForwarding::Simple => {
request.append(&mut addr.to_string().as_bytes().to_vec());
request.push(b'\n');
},
IpForwarding::Modern => {
match addr.ip() {
IpAddr::V4(ip) => {
request.push(0x01);
request.append(&mut ip.octets().to_vec());
}, IpAddr::V6(ip) => {
request.push(0x02);
request.append(&mut ip.octets().to_vec());
}
}
request.append(&mut addr.port().to_be_bytes().to_vec());
},
_ => {}
}
for (key, value) in headers {
let mut value = value.to_string();
match key.to_lowercase().as_str() {
"host" => {
if let Some(replace_host) = conn.config.replace_host.clone() {
value = replace_host;
}
},
_ => {}
}
if let IpForwarding::Header(header) = &conn.config.ip_forwarding {
if key.to_lowercase() == header.to_lowercase() {
continue;
}
}
request.append(&mut key.as_bytes().to_vec());
request.append(&mut b": ".to_vec());
request.append(&mut value.as_bytes().to_vec());
request.append(&mut b"\r\n".to_vec());
}
if let IpForwarding::Header(header) = &conn.config.ip_forwarding {
request.append(&mut header.as_bytes().to_vec());
request.append(&mut b": ".to_vec());
request.append(&mut addr.to_string().as_bytes().to_vec());
request.append(&mut b"\r\n".to_vec());
}
request.append(&mut b"\r\n".to_vec());
debug!("{:?}", String::from_utf8_lossy(&request));
conn.stream.get_mut().write_all(&request).ok()?;
debug!("{} {} sent request to server", addr, status[1]);
if content_length > 0 {
let buffer = stream.buffer().to_vec();
conn.stream.get_mut().write_all(&buffer).ok()?;
stream.consume(buffer.len());
let mut read = buffer.len();
debug!("{} {} send part of body to server", addr, status[1]);
while read < content_length {
let mut buf = vec![0; 4096];
let size = conn.stream.get_mut().read(&mut buf).ok()?;
if size == 0 { break }
buf.truncate(size);
read += size;
debug!("{} {} send response body part {} to clientr", addr, status[1], size);
stream.get_mut().write_all(&buf).ok()?;
}
} else if is_chunked {
transfer_chunked(&mut stream, conn.stream.get_mut())?;
} else {
let buffer = stream.buffer().to_vec();
conn.stream.get_mut().write_all(&buffer).ok()?;
stream.consume(buffer.len());
}
debug!("{} {} send body to server", addr, status[1]);
if conn.config.support_keep_alive {
let mut response = Vec::new();
let raw_status = read_until(&mut conn.stream, b"\r\n")?;
response.append(&mut raw_status.clone());
let mut content_length = 0;
let mut is_chunked = false;
loop {
let mut header = read_until(&mut conn.stream, b"\r\n")?;
response.append(&mut header.clone());
if header.len() == 2 {
break;
}
header.truncate(header.len() - 2);
let header = String::from_utf8(header).ok()?;
let (key, value) = header.split_once(": ")?;
match key.to_lowercase().as_str() {
"transfer-encoding" => {
if value.contains("chunked") {
is_chunked = true;
}
},
"content-length" => {
content_length = value.parse::<usize>().ok()?;
},
_ => {}
}
}
stream.get_mut().write_all(&response).ok()?;
debug!("{} {} send response header to clientr", addr, status[1]);
if content_length > 0 {
let buffer = conn.stream.buffer().to_vec();
stream.get_mut().write_all(&buffer).ok()?;
conn.stream.consume(buffer.len());
debug!("{} {} send response body part {} to clientr", addr, status[1], buffer.len());
let mut read = buffer.len();
while read < content_length {
let mut buf = vec![0; 4096];
let size = conn.stream.get_mut().read(&mut buf).ok()?;
if size == 0 { break }
buf.truncate(size);
read += size;
debug!("{} {} send response body part {} to clientr", addr, status[1], size);
stream.get_mut().write_all(&buf).ok()?;
}
} else if is_chunked {
transfer_chunked(&mut conn.stream, stream.get_mut())?;
} else {
let buffer = conn.stream.buffer().to_vec();
stream.get_mut().write_all(&buffer).ok()?;
conn.stream.consume(buffer.len());
}
debug!("{} {} send response body to clientr", addr, status[1]);
} else {
let buffer = conn.stream.buffer();
stream.get_mut().write_all(buffer).ok()?;
conn.stream.consume(buffer.len());
let mut buf = vec![0;4096];
while let Ok(n) = conn.stream.get_mut().read(&mut buf) {
if n == 0 { break }
buf.truncate(n);
stream.get_mut().write_all(&buf).ok()?;
buf = vec![0;4096];
}
}
info!("{addr} > {} {}://{}{}", status[0], if https { "https" } else { "http" }, conn.host, status[1]);
Some(conn)
}
}
/// Reads from `stream` until the accumulated bytes end with `delimiter`.
///
/// The returned buffer includes the delimiter. Returns `None` if `delimiter`
/// is empty, on an I/O error, or if EOF is reached before the delimiter.
fn read_until(stream: &mut impl BufRead, delimiter: &[u8]) -> Option<Vec<u8>> {
    // `BufRead::read_until` stops at a single byte, so stop at the delimiter's
    // final byte and re-check the whole suffix after every piece.
    let &terminator = delimiter.last()?;
    let mut collected = Vec::new();
    while !collected.ends_with(delimiter) {
        let mut piece = Vec::new();
        let read = stream.read_until(terminator, &mut piece).ok()?;
        debug!("read buf len until {} {:?}", read, String::from_utf8_lossy(&piece));
        if read == 0 {
            return None;
        }
        collected.extend(piece);
    }
    Some(collected)
}
/// Relays one HTTP/1.1 chunked message body from `src` to `dest`, byte for byte.
///
/// Everything belonging to the body is forwarded unchanged: each chunk-size
/// line (including any `;extension`), the chunk data with its CRLF, the final
/// zero-size chunk, optional trailer fields and the closing empty line. Bytes
/// after the body are left unread in `src`.
///
/// Returns `None` on an I/O error, an EOF before the body is complete, or a
/// malformed chunk-size line. By then `dest` may have received part of the body.
fn transfer_chunked(src: &mut impl BufRead, dest: &mut impl Write) -> Option<()> {
    // Reads one CRLF-terminated line into `line` and forwards it to `dest`.
    fn relay_line(src: &mut impl BufRead, dest: &mut impl Write, line: &mut Vec<u8>) -> Option<()> {
        line.clear();
        src.read_until(b'\n', line).ok()?;
        if !line.ends_with(b"\r\n") {
            return None;
        }
        dest.write_all(line.as_slice()).ok()
    }

    let mut line = Vec::new();
    loop {
        relay_line(src, dest, &mut line)?;
        let size_line = std::str::from_utf8(&line[..line.len() - 2]).ok()?;
        // Chunk extensions (`;name=value`) are forwarded above but do not affect the size.
        let size_field = size_line.split(';').next()?.trim();
        let size = u64::from_str_radix(size_field, 16).ok()?;
        if size == 0 {
            break;
        }
        // Stream the chunk and its CRLF rather than allocating a buffer whose
        // size the peer controls.
        let expected = size.checked_add(2)?;
        let copied = std::io::copy(&mut src.by_ref().take(expected), dest).ok()?;
        if copied != expected {
            return None;
        }
    }
    // After the last chunk: optional trailer fields, then an empty line.
    loop {
        relay_line(src, dest, &mut line)?;
        if line.len() == 2 {
            return Some(());
        }
    }
}
/// Tears a connection down in both directions without reporting errors.
///
/// The method returns nothing; the `TcpStream` implementation discards any
/// error reported by the OS.
pub trait Shutdown {
    /// Shuts the underlying socket down for reading and writing.
    fn shutdown(&self);
}

impl Shutdown for TcpStream {
    fn shutdown(&self) {
        // Path call resolves to the inherent `TcpStream::shutdown`, not this trait.
        let how = std::net::Shutdown::Both;
        TcpStream::shutdown(self, how).ignore();
    }
}

impl<T: Shutdown> Shutdown for BufReader<T> {
    fn shutdown(&self) {
        let inner: &T = self.get_ref();
        inner.shutdown();
    }
}

impl<C, T> Shutdown for StreamOwned<C, T>
where
    T: Read + Write + Shutdown,
{
    fn shutdown(&self) {
        // Only the transport socket is shut down; the TLS session is not touched.
        Shutdown::shutdown(&self.sock);
    }
}