diff --git a/examples/parallel_sites.rs b/examples/parallel_sites.rs index 794e9fc..28036fe 100644 --- a/examples/parallel_sites.rs +++ b/examples/parallel_sites.rs @@ -17,7 +17,7 @@ impl EzSite { } impl HttpServer for EzSite { - async fn on_request(&mut self, req: &HttpRequest) -> Option { + async fn on_request(&self, req: &HttpRequest) -> Option { // println!("{} > {} {}", req.addr, req.method, req.page); if req.page == "/" { diff --git a/examples/simple_site.rs b/examples/simple_site.rs index 0e8d345..3714487 100644 --- a/examples/simple_site.rs +++ b/examples/simple_site.rs @@ -12,7 +12,7 @@ impl EzSite { } } - fn ok_response(&mut self, content: String) -> HttpResponse { + fn ok_response(&self, content: String) -> HttpResponse { HttpResponse::from_string( Headers::from(vec![("Content-Type", "text/html")]), "200 OK".to_string(), @@ -20,7 +20,7 @@ impl EzSite { ) } - fn not_found_response(&mut self, content: String) -> HttpResponse { + fn not_found_response(&self, content: String) -> HttpResponse { HttpResponse::from_string( Headers::from(vec![("Content-Type", "text/html")]), "404 Not Found".to_string(), @@ -28,7 +28,7 @@ impl EzSite { ) } - async fn get_main_page(&mut self, req: &HttpRequest) -> Option { + async fn get_main_page(&self, req: &HttpRequest) -> Option { if req.page == "/" { Some(self.ok_response(self.main_page.clone())) } else { @@ -36,13 +36,13 @@ impl EzSite { } } - async fn get_unknown_page(&mut self, req: &HttpRequest) -> Option { + async fn get_unknown_page(&self, req: &HttpRequest) -> Option { Some(self.not_found_response(format!("

404 Error

Not Found {}", &req.page))) } } impl HttpServer for EzSite { - async fn on_request(&mut self, req: &HttpRequest) -> Option { + async fn on_request(&self, req: &HttpRequest) -> Option { println!("{} > {} {}", req.addr, req.method, req.page); if let Some(resp) = self.get_main_page(req).await { diff --git a/src/ezhttp/handler.rs b/src/ezhttp/handler.rs index d6462a2..d012b8b 100644 --- a/src/ezhttp/handler.rs +++ b/src/ezhttp/handler.rs @@ -1,18 +1,18 @@ -use super::{HttpError, HttpRequest, HttpServer, Stream}; +use super::{HttpRequest, HttpServer, Stream}; use std::{future::Future, pin::Pin, sync::Arc}; -use tokio::{net::TcpStream, sync::Mutex}; +use tokio::{net::TcpStream, sync::RwLock}; use tokio_io_timeout::TimeoutStream; #[cfg(feature = "http_rrs")] use {super::read_line_lf, std::net::{ToSocketAddrs, SocketAddr}}; -pub type Handler = Box>, TimeoutStream) -> Pin + Send>> + Send + Sync>; +pub type Handler = Box>, TimeoutStream) -> Pin + Send>> + Send + Sync>; /// Default connection handler /// Turns input to request and response to output -pub async fn handler_connection( - server: Arc>, +pub async fn handler_connection( + server: Arc>, mut sock: Stream ) { let Ok(addr) = sock.get_ref().peer_addr() else { return; }; @@ -20,23 +20,27 @@ pub async fn handler_connection( let req = match HttpRequest::read(sock.get_mut(), &addr).await { Ok(i) => i, Err(e) => { - server.lock().await.on_error(e).await; + server.write().await.on_error(e).await; return; } }; - let resp = match server.lock().await.on_request(&req).await { + let resp = match server.read().await.on_request(&req).await { Some(i) => i, None => { - server.lock().await.on_error(HttpError::RequstError).await; - return; + match server.write().await.on_request_mut(&req).await { + Some(i) => i, + None => { + return; + } + } } }; match resp.write(sock.get_mut()).await { Ok(_) => {}, Err(e) => { - server.lock().await.on_error(e).await; + server.write().await.on_error(e).await; return; }, } @@ -51,14 +55,14 @@ macro_rules! pin_handler { #[cfg(feature = "http_rrs")] /// HTTP_RRS handler -pub async fn handler_http_rrs( - server: Arc>, +pub async fn handler_http_rrs( + server: Arc>, mut sock: Stream, ) { let addr = match read_line_lf(sock.get_mut()).await { Ok(i) => i, Err(e) => { - server.lock().await.on_error(e).await; + server.write().await.on_error(e).await; return; } } @@ -69,23 +73,27 @@ pub async fn handler_http_rrs( let req = match HttpRequest::read(sock.get_mut(), &addr).await { Ok(i) => i, Err(e) => { - server.lock().await.on_error(e).await; + server.write().await.on_error(e).await; return; } }; - let resp = match server.lock().await.on_request(&req).await { + let resp = match server.read().await.on_request(&req).await { Some(i) => i, None => { - server.lock().await.on_error(HttpError::RequstError).await; - return; + match server.write().await.on_request_mut(&req).await { + Some(i) => i, + None => { + return; + } + } } }; match resp.write(sock.get_mut()).await { Ok(_) => {}, Err(e) => { - server.lock().await.on_error(e).await; + server.write().await.on_error(e).await; return; }, } diff --git a/src/ezhttp/mod.rs b/src/ezhttp/mod.rs index 9a17643..9d5e268 100644 --- a/src/ezhttp/mod.rs +++ b/src/ezhttp/mod.rs @@ -10,7 +10,7 @@ use std::{ use tokio::io::AsyncReadExt; use rusty_pool::ThreadPool; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::Mutex; +use tokio::sync::RwLock; use tokio_io_timeout::TimeoutStream; pub mod error; @@ -27,6 +27,8 @@ pub use response::*; pub use starter::*; pub use handler::*; +use crate::pin_handler; + async fn read_line(data: &mut (impl AsyncReadExt + Unpin)) -> Result { let mut line = Vec::new(); @@ -64,9 +66,15 @@ pub trait HttpServer { fn on_start(&mut self, host: &str) -> impl Future + Send; fn on_close(&mut self) -> impl Future + Send; fn on_request( - &mut self, + &self, req: &HttpRequest, ) -> impl Future> + Send; + fn on_request_mut( + &mut self, + _: &HttpRequest, + ) -> impl Future> + Send { + async { None } + } fn on_error( &mut self, _: HttpError @@ -87,12 +95,12 @@ where T: HttpServer + Send + 'static, { let threadpool = ThreadPool::new(threads, threads * 10, Duration::from_secs(60)); - let server = Arc::new(Mutex::new(server)); + let server = Arc::new(RwLock::new(server)); let listener = TcpListener::bind(host).await?; let host_clone = String::from(host).clone(); let server_clone = server.clone(); - server_clone.lock().await.on_start(&host_clone).await; + server_clone.write().await.on_start(&host_clone).await; while running.load(Ordering::Acquire) { let Ok((sock, _)) = listener.accept().await else { continue; }; @@ -108,7 +116,7 @@ where threadpool.join(); - server.lock().await.on_close().await; + server.write().await.on_close().await; Ok(()) } @@ -123,12 +131,12 @@ async fn start_server_new_thread( where T: HttpServer + Send + 'static, { - let server = Arc::new(Mutex::new(server)); + let server = Arc::new(RwLock::new(server)); let listener = TcpListener::bind(host).await?; let host_clone = String::from(host).clone(); let server_clone = server.clone(); - server_clone.lock().await.on_start(&host_clone).await; + server_clone.write().await.on_start(&host_clone).await; while running.load(Ordering::Acquire) { let Ok((sock, _)) = listener.accept().await else { continue; }; @@ -142,7 +150,7 @@ where tokio::spawn((&handler)(now_server, sock)); } - server.lock().await.on_close().await; + server.write().await.on_close().await; Ok(()) } @@ -157,12 +165,12 @@ async fn start_server_sync( where T: HttpServer + Send + 'static, { - let server = Arc::new(Mutex::new(server)); + let server = Arc::new(RwLock::new(server)); let listener = TcpListener::bind(host).await?; let host_clone = String::from(host).clone(); let server_clone = server.clone(); - server_clone.lock().await.on_start(&host_clone).await; + server_clone.write().await.on_start(&host_clone).await; while running.load(Ordering::Acquire) { let Ok((sock, _)) = listener.accept().await else { continue; }; @@ -176,7 +184,7 @@ where handler(now_server, sock).await; } - server.lock().await.on_close().await; + server.write().await.on_close().await; Ok(()) } @@ -184,7 +192,7 @@ where /// Start [`HttpServer`](HttpServer) on some host /// /// Use [`HttpServerStarter`](HttpServerStarter) to set more options -pub async fn start_server( +pub async fn start_server( server: T, host: &str ) -> Result<(), Box> { @@ -192,7 +200,7 @@ pub async fn start_server( server, host, None, - Box::new(move |a, b| Box::pin(handler_connection(a, b))), + pin_handler!(handler_connection), Arc::new(AtomicBool::new(true)), ).await } diff --git a/src/ezhttp/starter.rs b/src/ezhttp/starter.rs index decc285..120e2e7 100644 --- a/src/ezhttp/starter.rs +++ b/src/ezhttp/starter.rs @@ -39,7 +39,7 @@ pub struct HttpServerStarter { threads: usize, } -impl HttpServerStarter { +impl HttpServerStarter { /// Create new HttpServerStarter pub fn new(http_server: T, host: &str) -> Self { HttpServerStarter { diff --git a/src/main.rs b/src/main.rs index 3e4de53..538ceaf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,7 @@ impl EzSite { } impl HttpServer for EzSite { - async fn on_request(&mut self, req: &HttpRequest) -> Option { + async fn on_request(&self, req: &HttpRequest) -> Option { println!("{} > {} {}", req.addr, req.method, req.page); if req.page == "/" {