diff --git a/Cargo.lock b/Cargo.lock index 0f04e3c..dfce6bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -87,6 +87,7 @@ dependencies = [ "rusty_pool", "serde_json", "tokio", + "tokio-io-timeout", "urlencoding", ] @@ -374,9 +375,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.125" +version = "1.0.127" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" +checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" dependencies = [ "itoa", "memchr", @@ -453,6 +454,16 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-macros" version = "2.4.0" diff --git a/Cargo.toml b/Cargo.toml index cf2186c..4a7d1d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,10 @@ keywords = ["http", "server", "site", "async"] [dependencies] urlencoding = "2.1.3" -serde_json = "1.0.125" +serde_json = "1.0.127" tokio = { version = "1.39.3", features = ["full"] } -rusty_pool = "0.7.0" \ No newline at end of file +rusty_pool = "0.7.0" +tokio-io-timeout = "1.2.0" + +[features] +http_rrs = [] \ No newline at end of file diff --git a/src/ezhttp/error.rs b/src/ezhttp/error.rs index e2086c6..9392f45 100644 --- a/src/ezhttp/error.rs +++ b/src/ezhttp/error.rs @@ -13,6 +13,7 @@ pub enum HttpError { WriteHeadError, WriteBodyError, InvalidStatus, + RequstError } impl std::fmt::Display for HttpError { diff --git a/src/ezhttp/handler.rs b/src/ezhttp/handler.rs new file mode 100644 index 0000000..f8dd978 --- /dev/null +++ b/src/ezhttp/handler.rs @@ -0,0 +1,86 @@ +use super::{HttpError, HttpRequest, HttpServer, Stream}; + +use std::{future::Future, pin::Pin, sync::Arc}; +use tokio::{net::TcpStream, sync::Mutex}; +use tokio_io_timeout::TimeoutStream; + +#[cfg(feature = "http_rrs")] +use {super::read_line_lf, std::net::{ToSocketAddrs, SocketAddr}}; + +pub type Handler = Box>, TimeoutStream) -> Pin + Send>> + Send + Sync>; + + +/// Default connection handler +/// Turns input to request and response to output +pub async fn handler_connection( + server: Arc>, + mut sock: Stream +) { + let Ok(addr) = sock.get_ref().peer_addr() else { return; }; + + let req = match HttpRequest::read(sock.get_mut(), &addr).await { + Ok(i) => i, + Err(e) => { + server.lock().await.on_error(e).await; + return; + } + }; + + let resp = match server.lock().await.on_request(&req).await { + Some(i) => i, + None => { + server.lock().await.on_error(HttpError::RequstError).await; + return; + } + }; + + match resp.write(sock.get_mut()).await { + Ok(_) => {}, + Err(e) => { + server.lock().await.on_error(e).await; + return; + }, + } +} + +#[cfg(feature = "http_rrs")] +/// HTTP_RRS handler +pub async fn handler_http_rrs( + server: Arc>, + mut sock: Stream, +) { + let addr = match read_line_lf(sock.get_mut()).await { + Ok(i) => i, + Err(e) => { + server.lock().await.on_error(e).await; + return; + } + } + .to_socket_addrs() + .unwrap() + .collect::>()[0]; + + let req = match HttpRequest::read(sock.get_mut(), &addr).await { + Ok(i) => i, + Err(e) => { + server.lock().await.on_error(e).await; + return; + } + }; + + let resp = match server.lock().await.on_request(&req).await { + Some(i) => i, + None => { + server.lock().await.on_error(HttpError::RequstError).await; + return; + } + }; + + match resp.write(sock.get_mut()).await { + Ok(_) => {}, + Err(e) => { + server.lock().await.on_error(e).await; + return; + }, + } +} \ No newline at end of file diff --git a/src/ezhttp/mod.rs b/src/ezhttp/mod.rs index 15c8d89..397a9dd 100644 --- a/src/ezhttp/mod.rs +++ b/src/ezhttp/mod.rs @@ -1,94 +1,64 @@ +use std::sync::atomic::{AtomicBool, Ordering}; use std::{ boxed::Box, error::Error, future::Future, - io::Read, - net::{TcpListener, TcpStream}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::Arc, time::Duration, }; +use tokio::io::AsyncReadExt; +use rusty_pool::ThreadPool; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::Mutex; +use tokio_io_timeout::TimeoutStream; + pub mod error; pub mod headers; pub mod request; pub mod response; pub mod starter; +pub mod handler; pub use error::*; pub use headers::*; pub use request::*; pub use response::*; -use rusty_pool::ThreadPool; pub use starter::*; -use tokio::sync::Mutex; +pub use handler::*; -fn read_line(data: &mut impl Read) -> Result { - let mut bytes = Vec::new(); - for byte in data.bytes() { - let byte = match byte { - Ok(i) => i, - Err(_) => return Err(HttpError::ReadLineEof), - }; - bytes.push(byte); - - if byte == 0x0A { +async fn read_line(data: &mut (impl AsyncReadExt + Unpin)) -> Result { + let mut line = Vec::new(); + loop { + let mut buffer = vec![0;1]; + data.read_exact(&mut buffer).await.or(Err(HttpError::ReadLineEof))?; + let char = buffer[0]; + line.push(char); + if char == 0x0a { break; } } - - match String::from_utf8(bytes) { - Ok(i) => Ok(i), - Err(_) => Err(HttpError::ReadLineUnknown), - } + String::from_utf8(line).or(Err(HttpError::ReadLineUnknown)) } -fn read_line_crlf(data: &mut impl Read) -> Result { - match read_line(data) { +async fn read_line_crlf(data: &mut (impl AsyncReadExt + Unpin)) -> Result { + match read_line(data).await { Ok(i) => Ok(i[..i.len() - 2].to_string()), Err(e) => Err(e), } } -fn read_line_lf(data: &mut impl Read) -> Result { - match read_line(data) { +#[cfg(feature = "http_rrs")] +async fn read_line_lf(data: &mut (impl AsyncReadExt + Unpin)) -> Result { + match read_line(data).await { Ok(i) => Ok(i[..i.len() - 1].to_string()), Err(e) => Err(e), } } -fn rem_first(value: &str) -> &str { - let mut chars = value.chars(); - chars.next(); - chars.as_str() -} - -fn split(text: String, delimiter: &str, times: usize) -> Vec { - match times { - 0 => text.split(delimiter).map(|v| v.to_string()).collect(), - 1 => { - let mut v: Vec = Vec::new(); - match text.split_once(delimiter) { - Some(i) => { - v.push(i.0.to_string()); - v.push(i.1.to_string()); - } - None => { - v.push(text); - } - } - v - } - _ => text - .splitn(times, delimiter) - .map(|v| v.to_string()) - .collect(), - } -} +pub type Stream = TimeoutStream; /// Async http server trait pub trait HttpServer { @@ -98,45 +68,43 @@ pub trait HttpServer { &mut self, req: &HttpRequest, ) -> impl Future> + Send; + fn on_error( + &mut self, + _: HttpError + ) -> impl Future + Send { + async {} + } } -async fn start_server_with_threadpool( - server: S, +async fn start_server_with_threadpool( + server: T, host: &str, timeout: Option, threads: usize, - rrs: bool, + handler: Handler, running: Arc, ) -> Result<(), Box> where - S: HttpServer + Send + 'static, + T: HttpServer + Send + 'static, { let threadpool = ThreadPool::new(threads, threads * 10, Duration::from_secs(60)); let server = Arc::new(Mutex::new(server)); - let listener = TcpListener::bind(host)?; + let listener = TcpListener::bind(host).await?; let host_clone = String::from(host).clone(); let server_clone = server.clone(); server_clone.lock().await.on_start(&host_clone).await; while running.load(Ordering::Acquire) { - let (sock, _) = match listener.accept() { - Ok(i) => i, - Err(_) => { - continue; - } - }; + let Ok((sock, _)) = listener.accept().await else { continue; }; + let mut sock = TimeoutStream::new(sock); - sock.set_read_timeout(timeout).unwrap(); - sock.set_write_timeout(timeout).unwrap(); + sock.set_read_timeout(timeout); + sock.set_write_timeout(timeout); let now_server = Arc::clone(&server); - if !rrs { - threadpool.spawn(handle_connection(now_server, sock)); - } else { - threadpool.spawn(handle_connection_rrs(now_server, sock)); - } + threadpool.spawn((&handler)(now_server, sock)); } threadpool.join(); @@ -146,41 +114,33 @@ where Ok(()) } -async fn start_server_new_thread( - server: S, +async fn start_server_new_thread( + server: T, host: &str, timeout: Option, - rrs: bool, + handler: Handler, running: Arc, ) -> Result<(), Box> where - S: HttpServer + Send + 'static, + T: HttpServer + Send + 'static, { let server = Arc::new(Mutex::new(server)); - let listener = TcpListener::bind(host)?; + let listener = TcpListener::bind(host).await?; let host_clone = String::from(host).clone(); let server_clone = server.clone(); server_clone.lock().await.on_start(&host_clone).await; while running.load(Ordering::Acquire) { - let (sock, _) = match listener.accept() { - Ok(i) => i, - Err(_) => { - continue; - } - }; + let Ok((sock, _)) = listener.accept().await else { continue; }; + let mut sock = TimeoutStream::new(sock); - sock.set_read_timeout(timeout).unwrap(); - sock.set_write_timeout(timeout).unwrap(); + sock.set_read_timeout(timeout); + sock.set_write_timeout(timeout); let now_server = Arc::clone(&server); - if !rrs { - tokio::spawn(handle_connection(now_server, sock)); - } else { - tokio::spawn(handle_connection_rrs(now_server, sock)); - } + tokio::spawn((&handler)(now_server, sock)); } server.lock().await.on_close().await; @@ -188,41 +148,33 @@ where Ok(()) } -async fn start_server_sync( - server: S, +async fn start_server_sync( + server: T, host: &str, timeout: Option, - rrs: bool, + handler: Handler, running: Arc, ) -> Result<(), Box> where - S: HttpServer + Send + 'static, + T: HttpServer + Send + 'static, { let server = Arc::new(Mutex::new(server)); - let listener = TcpListener::bind(host)?; + let listener = TcpListener::bind(host).await?; let host_clone = String::from(host).clone(); let server_clone = server.clone(); server_clone.lock().await.on_start(&host_clone).await; while running.load(Ordering::Acquire) { - let (sock, _) = match listener.accept() { - Ok(i) => i, - Err(_) => { - continue; - } - }; + let Ok((sock, _)) = listener.accept().await else { continue; }; + let mut sock = TimeoutStream::new(sock); - sock.set_read_timeout(timeout).unwrap(); - sock.set_write_timeout(timeout).unwrap(); + sock.set_read_timeout(timeout); + sock.set_write_timeout(timeout); let now_server = Arc::clone(&server); - if !rrs { - handle_connection(now_server, sock).await; - } else { - handle_connection_rrs(now_server, sock).await; - } + handler(now_server, sock).await; } server.lock().await.on_close().await; @@ -230,60 +182,18 @@ where Ok(()) } -async fn handle_connection( - server: Arc>, - mut sock: TcpStream -) { - let Ok(addr) = sock.peer_addr() else { return; }; - - let req = match HttpRequest::read(&mut sock, &addr) { - Ok(i) => i, - Err(_) => { - return; - } - }; - - let resp = match server.lock().await.on_request(&req).await { - Some(i) => i, - None => { - return; - } - }; - - let _ = resp.write(&mut sock); -} - -async fn handle_connection_rrs( - server: Arc>, - mut sock: TcpStream, -) { - let req = match HttpRequest::read_with_rrs(&mut sock) { - Ok(i) => i, - Err(_) => { - return; - } - }; - let resp = match server.lock().await.on_request(&req).await { - Some(i) => i, - None => { - return; - } - }; - let _ = resp.write(&mut sock); -} - /// Start [`HttpServer`](HttpServer) on some host /// /// Use [`HttpServerStarter`](HttpServerStarter) to set more options -pub async fn start_server( - server: S, +pub async fn start_server( + server: T, host: &str ) -> Result<(), Box> { start_server_new_thread( server, host, None, - false, + Box::new(move |a, b| Box::pin(handler_connection(a, b))), Arc::new(AtomicBool::new(true)), ).await } diff --git a/src/ezhttp/request.rs b/src/ezhttp/request.rs index 5590905..97ae2f4 100644 --- a/src/ezhttp/request.rs +++ b/src/ezhttp/request.rs @@ -1,11 +1,11 @@ -use super::{read_line_crlf, read_line_lf, rem_first, split, Headers, HttpError}; +use super::{read_line_crlf, Headers, HttpError}; use serde_json::Value; use std::{ fmt::{Debug, Display}, - io::{Read, Write}, - net::{IpAddr, SocketAddr, ToSocketAddrs}, + net::SocketAddr, }; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; /// Http request #[derive(Debug, Clone)] @@ -38,39 +38,28 @@ impl HttpRequest { } /// Read http request from stream - pub fn read(data: &mut impl Read, addr: &SocketAddr) -> Result { - let octets = match addr.ip() { - IpAddr::V4(ip) => ip.octets(), - _ => [127, 0, 0, 1], - }; + pub async fn read(data: &mut (impl AsyncReadExt + Unpin), addr: &SocketAddr) -> Result { + let ip_str = addr.to_string(); - let ip_str = octets[0].to_string() - + "." - + &octets[1].to_string() - + "." - + &octets[2].to_string() - + "." - + &octets[3].to_string(); - - let status = split( - match read_line_crlf(data) { - Ok(i) => i, - Err(e) => return Err(e), + let status: Vec = match read_line_crlf(data).await { + Ok(i) => { + i.splitn(3, " ") + .map(|s| s.to_string()) + .collect() }, - " ", - 3, - ); + Err(e) => return Err(e), + }; let method = status[0].clone(); let (page, query) = match status[1].split_once("?") { Some(i) => (i.0.to_string(), Some(i.1)), - None => (status[1].clone(), None), + None => (status[1].to_string(), None), }; let mut headers = Headers::new(); loop { - let text = match read_line_crlf(data) { + let text = match read_line_crlf(data).await { Ok(i) => i, Err(_) => return Err(HttpError::InvalidHeaders), }; @@ -121,7 +110,7 @@ impl HttpRequest { let mut buf: Vec = Vec::new(); buf.resize(content_size - reqdata.len(), 0); - match data.read_exact(&mut buf) { + match data.read_exact(&mut buf).await { Ok(i) => i, Err(_) => return Err(HttpError::InvalidContent), }; @@ -166,7 +155,7 @@ impl HttpRequest { } "application/x-www-form-urlencoded" => { if body.starts_with("?") { - body = rem_first(body.as_str()).to_string() + body = body.as_str()[1..].to_string() } for ele in body.split("&") { @@ -201,20 +190,6 @@ impl HttpRequest { }) } - /// Read http request with http_rrs support - pub fn read_with_rrs(data: &mut impl Read) -> Result { - let addr = match read_line_lf(data) { - Ok(i) => i, - Err(e) => { - return Err(e); - } - } - .to_socket_addrs() - .unwrap() - .collect::>()[0]; - HttpRequest::read(data, &addr) - } - /// Set params to query in url pub fn params_to_page(&mut self) { let mut query = String::new(); @@ -246,7 +221,7 @@ impl HttpRequest { /// Write http request to stream /// /// [`params`](Self::params) is not written to the stream, you need to use [`params_to_json`](Self::params_to_json) or [`params_to_page`](Self::params_to_page) - pub fn write(self, data: &mut impl Write) -> Result<(), HttpError> { + pub async fn write(self, data: &mut (impl AsyncWriteExt + Unpin)) -> Result<(), HttpError> { let mut head: String = String::new(); head.push_str(&self.method); head.push_str(" "); @@ -263,13 +238,13 @@ impl HttpRequest { head.push_str("\r\n"); - match data.write_all(head.as_bytes()) { + match data.write_all(head.as_bytes()).await { Ok(i) => i, Err(_) => return Err(HttpError::WriteHeadError), }; if !self.data.is_empty() { - match data.write_all(&self.data) { + match data.write_all(&self.data).await { Ok(i) => i, Err(_) => return Err(HttpError::WriteBodyError), }; diff --git a/src/ezhttp/response.rs b/src/ezhttp/response.rs index 8dcb0c6..6beb5ef 100644 --- a/src/ezhttp/response.rs +++ b/src/ezhttp/response.rs @@ -1,10 +1,8 @@ use super::{read_line_crlf, Headers, HttpError}; use serde_json::Value; -use std::{ - fmt::{Debug, Display}, - io::{Read, Write}, -}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use std::fmt::{Debug, Display}; /// Http response #[derive(Debug, Clone)] @@ -61,8 +59,8 @@ impl HttpResponse { } /// Read http response from stream - pub fn read(data: &mut impl Read) -> Result { - let status = match read_line_crlf(data) { + pub async fn read(data: &mut (impl AsyncReadExt + Unpin)) -> Result { + let status = match read_line_crlf(data).await { Ok(i) => i, Err(e) => { return Err(e); @@ -77,7 +75,7 @@ impl HttpResponse { let mut headers = Headers::new(); loop { - let text = match read_line_crlf(data) { + let text = match read_line_crlf(data).await { Ok(i) => i, Err(_) => return Err(HttpError::InvalidHeaders), }; @@ -106,7 +104,7 @@ impl HttpResponse { let mut buf: Vec = Vec::new(); buf.resize(content_size - reqdata.len(), 0); - match data.read_exact(&mut buf) { + match data.read_exact(&mut buf).await { Ok(i) => i, Err(_) => return Err(HttpError::InvalidContent), }; @@ -117,7 +115,7 @@ impl HttpResponse { loop { let mut buf: Vec = vec![0; 1024 * 32]; - let buf_len = match data.read(&mut buf) { + let buf_len = match data.read(&mut buf).await { Ok(i) => i, Err(_) => { break; @@ -138,7 +136,7 @@ impl HttpResponse { } /// Write http response to stream - pub fn write(self, data: &mut impl Write) -> Result<(), &str> { + pub async fn write(self, data: &mut (impl AsyncWriteExt + Unpin)) -> Result<(), HttpError> { let mut head: String = String::new(); head.push_str("HTTP/1.1 "); head.push_str(&self.status_code); @@ -153,14 +151,14 @@ impl HttpResponse { head.push_str("\r\n"); - match data.write_all(head.as_bytes()) { + match data.write_all(head.as_bytes()).await { Ok(i) => i, - Err(_) => return Err("write head error"), + Err(_) => return Err(HttpError::WriteHeadError), }; - match data.write_all(&self.data) { + match data.write_all(&self.data).await { Ok(i) => i, - Err(_) => return Err("write body error"), + Err(_) => return Err(HttpError::WriteHeadError), }; Ok(()) diff --git a/src/ezhttp/starter.rs b/src/ezhttp/starter.rs index ce7d0b0..30c1f4e 100644 --- a/src/ezhttp/starter.rs +++ b/src/ezhttp/starter.rs @@ -1,28 +1,25 @@ use tokio::task::JoinHandle; +use tokio::sync::Mutex; +use tokio::net::TcpStream; +use tokio_io_timeout::TimeoutStream; use super::{ - start_server_new_thread, start_server_sync, - start_server_with_threadpool, HttpServer, + start_server_new_thread, + start_server_sync, + start_server_with_threadpool, + handler_connection, + Handler, + HttpServer, }; +use std::pin::Pin; use std::{ - error::Error, - sync::{ + error::Error, future::Future, sync::{ atomic::{AtomicBool, Ordering}, Arc, - }, - time::Duration, + }, time::Duration }; -/// Http server start builder -pub struct HttpServerStarter { - http_server: T, - support_http_rrs: bool, - timeout: Option, - host: String, - threads: usize, -} - /// Running http server pub struct RunningHttpServer { thread: JoinHandle<()>, @@ -41,12 +38,21 @@ impl RunningHttpServer { } } +/// Http server start builder +pub struct HttpServerStarter { + http_server: T, + handler: Handler, + timeout: Option, + host: String, + threads: usize, +} + impl HttpServerStarter { /// Create new HttpServerStarter pub fn new(http_server: T, host: &str) -> Self { HttpServerStarter { http_server, - support_http_rrs: false, + handler: Box::new(move |a, b| Box::pin(handler_connection(a, b))), timeout: None, host: host.to_string(), threads: 0, @@ -56,25 +62,25 @@ impl HttpServerStarter { /// Set http server pub fn http_server(mut self, http_server: T) -> Self { self.http_server = http_server; - return self; + self } /// Set if http_rrs is supported - pub fn support_http_rrs(mut self, support_http_rrs: bool) -> Self { - self.support_http_rrs = support_http_rrs; - return self; + pub fn handler(mut self, handler: impl Fn(Arc>, TimeoutStream) -> Pin + Send>> + Send + Sync + 'static) -> Self { + self.handler = Box::new(move |a, b| Box::pin(handler(a, b))); + self } /// Set timeout for read & write pub fn timeout(mut self, timeout: Option) -> Self { self.timeout = timeout; - return self; + self } /// Set host pub fn host(mut self, host: String) -> Self { self.host = host; - return self; + self } /// Set threads in threadpool and return builder @@ -83,17 +89,17 @@ impl HttpServerStarter { /// 1 thread means that all connections are processed in the main thread pub fn threads(mut self, threads: usize) -> Self { self.threads = threads; - return self; + self } /// Get http server - pub fn get_http_server(self) -> T { - self.http_server + pub fn get_http_server(&self) -> &T { + &self.http_server } /// Get if http_rrs is supported - pub fn get_support_http_rrs(&self) -> bool { - self.support_http_rrs + pub fn get_handler(&self) -> &Handler { + &self.handler } /// Get timeout for read & write @@ -109,7 +115,7 @@ impl HttpServerStarter { /// Get threads in threadpool /// /// 0 threads means that a new thread is created for each connection \ - /// 1 thread means that all connections are processed in the main thread + /// 1 thread means that all connections are processed in the one thread pub fn get_threads(&self) -> usize { self.threads } @@ -119,16 +125,16 @@ impl HttpServerStarter { let running = Arc::new(AtomicBool::new(true)); if self.threads == 0 { - start_server_new_thread(self.http_server, &self.host, self.timeout, self.support_http_rrs, running).await + start_server_new_thread(self.http_server, &self.host, self.timeout, self.handler, running).await } else if self.threads == 1 { - start_server_sync(self.http_server, &self.host, self.timeout, self.support_http_rrs, running).await + start_server_sync(self.http_server, &self.host, self.timeout, self.handler, running).await } else { start_server_with_threadpool( self.http_server, &self.host, self.timeout, self.threads, - self.support_http_rrs, + self.handler, running, ).await } @@ -145,7 +151,7 @@ impl HttpServerStarter { self.http_server, &self.host, self.timeout, - self.support_http_rrs, + self.handler, running_clone, ).await .expect("http server error"); @@ -156,7 +162,7 @@ impl HttpServerStarter { self.http_server, &self.host, self.timeout, - self.support_http_rrs, + self.handler, running_clone, ).await .expect("http server error"); @@ -168,7 +174,7 @@ impl HttpServerStarter { &self.host, self.timeout, self.threads, - self.support_http_rrs, + self.handler, running_clone, ).await .expect("http server error")