diff --git a/src/main.rs b/src/main.rs index bd29bdd..cf94a77 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,28 @@ -use std::{collections::HashMap, error::Error, fs::{self, OpenOptions}, io::{Cursor, Read, Write}, net::{IpAddr, SocketAddr, TcpListener}, sync::{Arc, RwLock}, thread, time::Duration}; +use std::{ + collections::HashMap, + error::Error, + fs::{self, OpenOptions}, + io::{Cursor, Read, Write}, + net::{IpAddr, SocketAddr, TcpListener}, + sync::{ + Arc, RwLock, + atomic::{AtomicUsize, Ordering}, + }, + thread, + time::Duration, +}; use bRAC::{chat::format_message, util::sanitize_text}; use chrono::{DateTime, Local, TimeZone}; use md5::{Digest, Md5}; -use rand::{distr::Alphanumeric, Rng}; +use rand::{Rng, distr::Alphanumeric}; use clap::Parser; -use rustls::{pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, ServerConfig, ServerConnection, StreamOwned}; -use tungstenite::{accept, Bytes, Message}; - +use rustls::{ + ServerConfig, ServerConnection, StreamOwned, + pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject}, +}; +use tungstenite::{Bytes, Message, accept}; fn load_accounts(accounts_file: Option) -> Vec { if let Some(accounts_file) = accounts_file.clone() { @@ -40,27 +54,26 @@ fn load_messages(messages_file: Option) -> Vec { } pub struct Context { + args: Arc, messages_file: Option, accounts_file: Option, - messages: RwLock>, + messages: RwLock>, accounts: RwLock>, + messages_offset: AtomicUsize, + notifications: RwLock>>, timeouts: RwLock>, - messages_offset: RwLock, - notifications: RwLock>> } impl Context { - fn new( - messages_file: Option, - accounts_file: Option - ) -> Self { + fn new(args: Arc, messages_file: Option, accounts_file: Option) -> Self { Self { + args, messages_file: messages_file.clone(), accounts_file: accounts_file.clone(), messages: RwLock::new(load_messages(messages_file.clone())), accounts: RwLock::new(load_accounts(accounts_file.clone())), timeouts: RwLock::new(HashMap::new()), - messages_offset: RwLock::new(0), + messages_offset: AtomicUsize::default(), notifications: RwLock::new(HashMap::new()), } } @@ -71,19 +84,28 @@ impl Context { .write(true) .append(true) .create(true) - .open(messages_file).expect("error messages file open"); + .open(messages_file) + .expect("error messages file open"); file.write_all(&msg).expect("error messages file write"); file.flush().expect("error messages file flush"); } self.messages.write().unwrap().append(&mut msg.clone()); + + let content = self.messages.read().unwrap().clone(); + + if content.len() > self.args.messages_total_limit { + let offset = content.len() - self.args.messages_total_limit; + *self.messages.write().unwrap() = content[offset..].to_vec(); + self.messages_offset.store(offset, Ordering::SeqCst); + } } fn get_account_by_addr(&self, addr: &str) -> Option { for acc in self.accounts.read().unwrap().iter().rev() { if acc.addr() == addr { - return Some(acc.clone()) + return Some(acc.clone()); } } None @@ -92,7 +114,7 @@ impl Context { fn get_account(&self, name: &str) -> Option { for acc in self.accounts.read().unwrap().iter() { if acc.name() == name { - return Some(acc.clone()) + return Some(acc.clone()); } } None @@ -104,9 +126,11 @@ impl Context { .write(true) .append(true) .create(true) - .open(accounts_file).expect("error accounts file open"); + .open(accounts_file) + .expect("error accounts file open"); - file.write_all(&acc.to_bytes()).expect("error accounts file write"); + file.write_all(&acc.to_bytes()) + .expect("error accounts file write"); file.write_all(b"\n").expect("error accounts file write"); file.flush().expect("error accounts file flush"); } @@ -121,7 +145,7 @@ pub struct Account { pass: Vec, salt: String, addr: String, - date: i64 + date: i64, } fn password_hash(name: &str, pass: &str, salt: &str) -> Vec { @@ -147,7 +171,7 @@ impl Account { name: name.clone(), salt: salt.clone(), addr, - date + date, } } @@ -224,14 +248,14 @@ impl Account { salt, pass, addr, - date + date, } } } fn message_prefix(time_millis: i64, address: Option) -> String { let datetime: DateTime = Local.timestamp_millis_opt(time_millis).unwrap(); - + format!( "[{}]{} ", datetime.format("%d.%m.%Y %H:%M"), @@ -244,20 +268,25 @@ fn message_prefix(time_millis: i64, address: Option) -> String { } fn add_message( - buf: &mut Vec, - context: Arc, + buf: &mut Vec, + context: Arc, addr: Option, - sanitize: bool + sanitize: bool, ) -> Result<(), Box> { let mut msg = Vec::new(); - msg.append(&mut message_prefix( - Local::now().timestamp_millis(), - addr.map(|o| o.to_string()) - ).as_bytes().to_vec()); + msg.append( + &mut message_prefix(Local::now().timestamp_millis(), addr.map(|o| o.to_string())) + .as_bytes() + .to_vec(), + ); if sanitize { - msg.append(&mut sanitize_text(&String::from_utf8_lossy(&buf.clone())).as_bytes().to_vec()); + msg.append( + &mut sanitize_text(&String::from_utf8_lossy(&buf.clone())) + .as_bytes() + .to_vec(), + ); } else { msg.append(buf); } @@ -274,10 +303,9 @@ fn add_message( } fn accept_wrac_stream( - stream: impl Read + Write, + stream: impl Read + Write, addr: SocketAddr, - context: Arc, - args: Arc + ctx: Arc, ) -> Result<(), Box> { let mut websocket = match accept(stream) { Ok(i) => i, @@ -289,56 +317,92 @@ fn accept_wrac_stream( Message::Binary(o) => Some(o.to_vec()), Message::Text(o) => Some(o.as_bytes().to_vec()), Message::Close(_) => return Ok(()), - _ => None + _ => None, } { let mut data = data; - let Some(id) = data.drain(..1).next() else { return Ok(()) }; + let Some(id) = data.drain(..1).next() else { + return Ok(()); + }; if id == 0x00 { - let mut messages = context.messages.read().unwrap().clone(); + let messages = ctx.messages.read().unwrap().clone(); + + let offset = ctx.messages_offset.load(Ordering::SeqCst); + + let mut messages = if offset > 0 { + let mut buf = vec![0; offset]; + buf.append(&mut messages.clone()); + buf + } else { + messages + }; if data.is_empty() { - if let Some(splash) = &args.splash { - websocket.write(Message::Binary(Bytes::from((messages.len() + splash.len()).to_string().as_bytes().to_vec())))?; + if let Some(splash) = &ctx.args.splash { + websocket.write(Message::Binary(Bytes::from( + (messages.len() + splash.len() + offset) + .to_string() + .as_bytes() + .to_vec(), + )))?; } else { - websocket.write(Message::Binary(Bytes::from(messages.len().to_string().as_bytes().to_vec())))?; + websocket.write(Message::Binary(Bytes::from( + (messages.len() + offset).to_string().as_bytes().to_vec(), + )))?; } websocket.flush()?; } else { - let Some(id) = data.drain(..1).next() else { return Ok(()) }; + let Some(id) = data.drain(..1).next() else { + return Ok(()); + }; if id == 0x01 { - if let Some(splash) = &args.splash { + if let Some(splash) = &ctx.args.splash { messages.append(&mut splash.clone().as_bytes().to_vec()); } websocket.write(Message::Binary(Bytes::from(messages)))?; websocket.flush()?; } else if id == 0x02 { let last_size: usize = String::from_utf8(data)?.parse()?; - if let Some(splash) = &args.splash { - websocket.write(Message::Binary(Bytes::from(messages[(last_size - splash.len())..].to_vec())))?; + if let Some(splash) = &ctx.args.splash { + websocket.write(Message::Binary(Bytes::from( + messages[(last_size - splash.len())..].to_vec(), + )))?; } else { - websocket.write(Message::Binary(Bytes::from(messages[last_size..].to_vec())))?; + websocket.write(Message::Binary(Bytes::from( + messages[last_size..].to_vec(), + )))?; } websocket.flush()?; } } } else if id == 0x01 { - if !args.auth_only { - add_message(&mut data, context.clone(), Some(addr.ip()), args.sanitize)?; + if !ctx.args.auth_only { + add_message(&mut data, ctx.clone(), Some(addr.ip()), ctx.args.sanitize)?; } } else if id == 0x02 { let msg = String::from_utf8_lossy(&data).to_string(); - + let mut segments = msg.split("\n"); - - let Some(name) = segments.next() else { return Ok(()) }; - let Some(password) = segments.next() else { return Ok(()) }; - let Some(text) = segments.next() else { return Ok(()) }; - - if let Some(acc) = context.get_account(name) { + + let Some(name) = segments.next() else { + return Ok(()); + }; + let Some(password) = segments.next() else { + return Ok(()); + }; + let Some(text) = segments.next() else { + return Ok(()); + }; + + if let Some(acc) = ctx.get_account(name) { if acc.check_password(password) { - add_message(&mut text.as_bytes().to_vec(), context.clone(), None, args.sanitize)?; + add_message( + &mut text.as_bytes().to_vec(), + ctx.clone(), + None, + ctx.args.sanitize, + )?; } else { websocket.write(Message::Binary(Bytes::from(vec![0x02])))?; websocket.flush()?; @@ -349,59 +413,75 @@ fn accept_wrac_stream( } } else if id == 0x03 { let msg = String::from_utf8_lossy(&data).to_string(); - + let mut segments = msg.split("\n"); - - let Some(name) = segments.next() else { return Ok(()) }; - let Some(password) = segments.next() else { return Ok(()) }; - + + let Some(name) = segments.next() else { + return Ok(()); + }; + let Some(password) = segments.next() else { + return Ok(()); + }; + let addr = addr.ip().to_string(); - + let now: i64 = Local::now().timestamp_millis(); - if context.get_account(name).is_some() || ( - if let Some(acc) = context.get_account_by_addr(&addr) { - ((now - acc.date()) as usize) < 1000 * args.register_timeout + if ctx.get_account(name).is_some() + || (if let Some(acc) = ctx.get_account_by_addr(&addr) { + ((now - acc.date()) as usize) < 1000 * ctx.args.register_timeout } else { false - } - ) { + }) + { websocket.write(Message::Binary(Bytes::from(vec![0x01])))?; websocket.flush()?; continue; } - + let account = Account::new(name.to_string(), password.to_string(), addr, now); println!("user registered: {name}"); - - context.push_account(account); + + ctx.push_account(account); } } } - Ok(()) } fn accept_rac_stream( - mut stream: impl Read + Write, + mut stream: impl Read + Write, addr: SocketAddr, - context: Arc, - args: Arc + ctx: Arc, ) -> Result<(), Box> { let mut buf = vec![0]; stream.read_exact(&mut buf)?; if buf[0] == 0x00 { - let mut messages = context.messages.read().unwrap().clone(); + let messages = ctx.messages.read().unwrap().clone(); - if let Some(splash) = &args.splash { - stream.write_all((splash.len() + messages.len()).to_string().as_bytes())?; + let offset = ctx.messages_offset.load(Ordering::SeqCst); + + let mut messages = if offset > 0 { + let mut buf = vec![0; offset]; + buf.append(&mut messages.clone()); + buf + } else { + messages + }; + + if let Some(splash) = &ctx.args.splash { + stream.write_all( + (splash.len() + messages.len() + offset) + .to_string() + .as_bytes(), + )?; let mut id = vec![0]; stream.read_exact(&mut id)?; - + if id[0] == 0x01 { messages.append(&mut splash.clone().as_bytes().to_vec()); stream.write_all(&messages)?; @@ -409,12 +489,12 @@ fn accept_rac_stream( let mut buf = vec![0; 10]; let size = stream.read(&mut buf)?; buf.truncate(size); - + let len: usize = String::from_utf8(buf)?.parse()?; stream.write_all(&messages[(len - splash.len())..])?; } } else { - stream.write_all(messages.len().to_string().as_bytes())?; + stream.write_all((messages.len() + offset).to_string().as_bytes())?; let mut id = vec![0]; stream.read_exact(&mut id)?; @@ -431,12 +511,12 @@ fn accept_rac_stream( } } } else if buf[0] == 0x01 { - if !args.auth_only { + if !ctx.args.auth_only { let mut buf = vec![0; 1024]; let size = stream.read(&mut buf)?; buf.truncate(size); - - add_message(&mut buf, context.clone(), Some(addr.ip()), args.sanitize)?; + + add_message(&mut buf, ctx.clone(), Some(addr.ip()), ctx.args.sanitize)?; } } else if buf[0] == 0x02 { let mut buf = vec![0; 8192]; @@ -447,13 +527,24 @@ fn accept_rac_stream( let mut segments = msg.split("\n"); - let Some(name) = segments.next() else { return Ok(()) }; - let Some(password) = segments.next() else { return Ok(()) }; - let Some(text) = segments.next() else { return Ok(()) }; + let Some(name) = segments.next() else { + return Ok(()); + }; + let Some(password) = segments.next() else { + return Ok(()); + }; + let Some(text) = segments.next() else { + return Ok(()); + }; - if let Some(acc) = context.get_account(name) { + if let Some(acc) = ctx.get_account(name) { if acc.check_password(password) { - add_message(&mut text.as_bytes().to_vec(), context.clone(), None, args.sanitize)?; + add_message( + &mut text.as_bytes().to_vec(), + ctx.clone(), + None, + ctx.args.sanitize, + )?; } else { stream.write_all(&[0x02])?; } @@ -469,20 +560,24 @@ fn accept_rac_stream( let mut segments = msg.split("\n"); - let Some(name) = segments.next() else { return Ok(()) }; - let Some(password) = segments.next() else { return Ok(()) }; + let Some(name) = segments.next() else { + return Ok(()); + }; + let Some(password) = segments.next() else { + return Ok(()); + }; let addr = addr.ip().to_string(); let now: i64 = Local::now().timestamp_millis(); - if context.get_account(name).is_some() || ( - if let Some(acc) = context.get_account_by_addr(&addr) { - ((now - acc.date()) as usize) < 1000 * args.register_timeout + if ctx.get_account(name).is_some() + || (if let Some(acc) = ctx.get_account_by_addr(&addr) { + ((now - acc.date()) as usize) < 1000 * ctx.args.register_timeout } else { false - } - ) { + }) + { stream.write_all(&[0x01])?; return Ok(()); } @@ -491,86 +586,98 @@ fn accept_rac_stream( println!("user registered: {name}"); - context.push_account(account); + ctx.push_account(account); } Ok(()) } fn accept_stream( - stream: impl Read + Write, + stream: impl Read + Write, addr: SocketAddr, - context: Arc, - args: Arc + ctx: Arc, ) -> Result<(), Box> { - if args.enable_wrac { - accept_wrac_stream(stream, addr, context, args)?; + if ctx.args.enable_wrac { + accept_wrac_stream(stream, addr, ctx)?; } else { - accept_rac_stream(stream, addr, context, args)?; + accept_rac_stream(stream, addr, ctx)?; } Ok(()) } -fn run_normal_listener( - context: Arc, - args: Arc -) { - let listener = TcpListener::bind(&args.host).expect("error trying bind to the provided addr"); +fn run_normal_listener(ctx: Arc) { + let listener = + TcpListener::bind(&ctx.args.host).expect("error trying bind to the provided addr"); for stream in listener.incoming() { let Ok(stream) = stream else { continue }; - let context = context.clone(); - let args = args.clone(); + let ctx = ctx.clone(); thread::spawn(move || { - let Ok(addr) = stream.peer_addr() else { return; }; - match accept_stream(stream, addr, context, args) { - Ok(_) => {}, - Err(e) => { println!("{}", e) }, + let Ok(addr) = stream.peer_addr() else { + return; + }; + match accept_stream(stream, addr, ctx) { + Ok(_) => {} + Err(e) => { + println!("{}", e) + } } }); } } -fn run_secure_listener( - context: Arc, - args: Arc -) { - let listener = TcpListener::bind(&args.host).expect("error trying bind to the provided addr"); +fn run_secure_listener(ctx: Arc) { + let listener = + TcpListener::bind(&ctx.args.host).expect("error trying bind to the provided addr"); - let server_config = Arc::new(ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(CertificateDer::pem_file_iter( - args.ssl_cert.clone().expect("--ssl-cert is required")) - .unwrap() - .map(|cert| cert.unwrap()) - .collect(), - PrivateKeyDer::from_pem_file( - args.ssl_key.clone().expect("--ssl-key is required")).unwrap() - ).unwrap()); + let server_config = Arc::new( + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert( + CertificateDer::pem_file_iter( + ctx.args.ssl_cert.clone().expect("--ssl-cert is required"), + ) + .unwrap() + .map(|cert| cert.unwrap()) + .collect(), + PrivateKeyDer::from_pem_file( + ctx.args.ssl_key.clone().expect("--ssl-key is required"), + ) + .unwrap(), + ) + .unwrap(), + ); for stream in listener.incoming() { let Ok(stream) = stream else { continue }; - let context = context.clone(); - let args = args.clone(); + let ctx = ctx.clone(); let server_config = server_config.clone(); thread::spawn(move || { - let Ok(addr) = stream.peer_addr() else { return; }; + let Ok(addr) = stream.peer_addr() else { + return; + }; - let Ok(connection) = ServerConnection::new(server_config) else { return }; + let Ok(connection) = ServerConnection::new(server_config) else { + return; + }; let mut stream = StreamOwned::new(connection, stream); while stream.conn.is_handshaking() { - let Ok(_) = stream.conn.complete_io(&mut stream.sock) else { return }; + let Ok(_) = stream.conn.complete_io(&mut stream.sock) else { + return; + }; } - match accept_stream(stream, addr, context, args) { - Ok(_) => {}, - Err(e) => { println!("{}", e) }, + match accept_stream(stream, addr, ctx) { + Ok(_) => {} + Err(e) => { + println!("{}", e) + } } }); } @@ -580,7 +687,7 @@ fn run_secure_listener( #[command(version)] struct Args { /// Server host - #[arg(short='H', long)] + #[arg(short = 'H', long)] host: String, /// Sanitize messages @@ -592,23 +699,23 @@ struct Args { auth_only: bool, /// Splash message - #[arg(short='S', long)] + #[arg(short = 'S', long)] splash: Option, /// Save messages to file - #[arg(short='M', long)] + #[arg(short = 'M', long)] messages_file: Option, /// Save accounts to file - #[arg(short='A', long)] + #[arg(short = 'A', long)] accounts_file: Option, /// Register timeout in seconds - #[arg(short='r', long, default_value_t = 600)] + #[arg(short = 'r', long, default_value_t = 600)] register_timeout: usize, /// Message timeout in seconds - #[arg(short='m', long, default_value_t = 5)] + #[arg(short = 'm', long, default_value_t = 5)] message_timeout: usize, /// Message limit in bytes @@ -620,7 +727,7 @@ struct Args { messages_total_limit: usize, /// Enable SSL (RACS) - #[arg(short='l', long)] + #[arg(short = 'l', long)] enable_ssl: bool, /// Set ssl certificate path (x509) @@ -632,21 +739,24 @@ struct Args { ssl_cert: Option, /// Enable WRAC - #[arg(short='w', long)] + #[arg(short = 'w', long)] enable_wrac: bool, } - fn main() { let args = Arc::new(Args::parse()); - let context = Arc::new(Context::new(args.messages_file.clone(), args.accounts_file.clone())); + let context = Arc::new(Context::new( + args.clone(), + args.messages_file.clone(), + args.accounts_file.clone(), + )); println!("Server started on {}", &args.host); if args.enable_ssl { - run_secure_listener(context, args); + run_secure_listener(context); } else { - run_normal_listener(context, args); + run_normal_listener(context); } }