diff --git a/src/main.rs b/src/main.rs index 97cfdba..56837d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ use rand::{distr::Alphanumeric, Rng}; use clap::Parser; use rustls::{pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, ServerConfig, ServerConnection, StreamOwned}; +use tungstenite::{accept, Bytes, Message, WebSocket}; #[derive(Clone)] @@ -186,123 +187,244 @@ fn accept_stream( messages: Arc>>, accounts: Arc>> ) -> Result<(), Box> { - let mut buf = vec![0]; - stream.read_exact(&mut buf)?; + if args.enable_wrac { + let mut websocket = accept(stream).unwrap(); - if buf[0] == 0x00 { - let mut messages = messages.read().unwrap().clone(); + while let Ok(msg) = websocket.read() { + if let Some(data) = if msg.is_binary() { + Some(msg.into_data().to_vec()) + } else if msg.is_text() { + msg.into_text().ok().map(|o| o.as_bytes().to_vec()) + } else { + None + } { + let mut data = data; + let Some(id) = data.drain(..1).next() else { return Ok(()) }; - if let Some(splash) = &args.splash { - stream.write_all((splash.len() + messages.len()).to_string().as_bytes())?; + if id == 0x00 { + let mut messages = messages.read().unwrap().clone(); - let mut id = vec![0]; - stream.read_exact(&mut id)?; - - if id[0] == 0x01 { - messages.append(&mut splash.clone().as_bytes().to_vec()); - stream.write_all(&messages)?; - } else if id[0] == 0x02 { - let mut buf = vec![0; 10]; - let size = stream.read(&mut buf)?; - buf.truncate(size); - - let len: usize = String::from_utf8(buf)?.parse()?; - stream.write_all(&messages[(len - splash.len())..])?; - } - } else { - stream.write_all(messages.len().to_string().as_bytes())?; + if data.is_empty() { + if let Some(splash) = &args.splash { + websocket.write(Message::Binary(Bytes::from((messages.len() + splash.len()).to_string().as_bytes().to_vec())))?; + } else { + websocket.write(Message::Binary(Bytes::from(messages.len().to_string().as_bytes().to_vec())))?; + } + } else { + let Some(id) = data.drain(..1).next() else { return Ok(()) }; - let mut id = vec![0]; - stream.read_exact(&mut id)?; + if id == 0x01 { + if let Some(splash) = &args.splash { + messages.append(&mut splash.clone().as_bytes().to_vec()); + } + websocket.write(Message::Binary(Bytes::from(messages)))?; + } else if id == 0x02 { + let last_size: usize = String::from_utf8(data)?.parse()?; + if let Some(splash) = &args.splash { + websocket.write(Message::Binary(Bytes::from(messages[(last_size - splash.len())..].to_vec())))?; + } else { + websocket.write(Message::Binary(Bytes::from(messages[last_size..].to_vec())))?; + } + } + } + } else if id == 0x01 { + if !args.auth_only { + add_message(&mut data, messages.clone(), Some(addr.ip()), args.sanitize, args.messages_file.clone())?; + } + } else if id == 0x02 { + let msg = String::from_utf8_lossy(&data).to_string(); + + let mut segments = msg.split("\n"); + + let Some(name) = segments.next() else { return Ok(()) }; + let Some(password) = segments.next() else { return Ok(()) }; + let Some(text) = segments.next() else { return Ok(()) }; + + let mut sent = false; - if id[0] == 0x01 { - stream.write_all(&messages)?; - } else if id[0] == 0x02 { - let mut buf = vec![0; 10]; - let size = stream.read(&mut buf)?; - buf.truncate(size); + for user in accounts.read().unwrap().iter() { + if user.name() == name { + if user.check_password(password) { + add_message(&mut text.as_bytes().to_vec(), messages.clone(), None, args.sanitize, args.messages_file.clone())?; + } else { + websocket.write(Message::Binary(Bytes::from(vec![0x02])))?; + } + sent = true; + break; + } + } + + if !sent { + websocket.write(Message::Binary(Bytes::from(vec![0x01])))?; + } + } else if id == 0x03 { + let msg = String::from_utf8_lossy(&data).to_string(); + + let mut segments = msg.split("\n"); + + let Some(name) = segments.next() else { return Ok(()) }; + let Some(password) = segments.next() else { return Ok(()) }; + + let addr = addr.ip().to_string(); + + let now: i64 = Local::now().timestamp_millis(); - let len: usize = String::from_utf8(buf)?.parse()?; - stream.write_all(&messages[len..])?; + let mut continue_send = false; + + for user in accounts.read().unwrap().iter() { + if user.name() == name { + websocket.write(Message::Binary(Bytes::from(vec![0x01])))?; + continue_send = true; + break; + } + if user.addr() == addr && ((now - user.date()) as usize) < 1000 * args.register_timeout { + websocket.write(Message::Binary(Bytes::from(vec![0x01])))?; + continue_send = true; + break; + } + } + + if continue_send { + continue; + } + + let account = Account::new(name.to_string(), password.to_string(), addr, now); + + if let Some(accounts_file) = args.accounts_file.clone() { + let mut file = OpenOptions::new() + .write(true) + .append(true) + .create(true) + .open(accounts_file)?; + + file.write_all(&account.to_bytes())?; + file.write_all(b"\n")?; + file.flush()?; + } + + accounts.write().unwrap().push(account); + } } } - } else if buf[0] == 0x01 { - if !args.auth_only { + } else { + let mut buf = vec![0]; + stream.read_exact(&mut buf)?; + + if buf[0] == 0x00 { + let mut messages = messages.read().unwrap().clone(); + + if let Some(splash) = &args.splash { + stream.write_all((splash.len() + messages.len()).to_string().as_bytes())?; + + let mut id = vec![0]; + stream.read_exact(&mut id)?; + + if id[0] == 0x01 { + messages.append(&mut splash.clone().as_bytes().to_vec()); + stream.write_all(&messages)?; + } else if id[0] == 0x02 { + let mut buf = vec![0; 10]; + let size = stream.read(&mut buf)?; + buf.truncate(size); + + let len: usize = String::from_utf8(buf)?.parse()?; + stream.write_all(&messages[(len - splash.len())..])?; + } + } else { + stream.write_all(messages.len().to_string().as_bytes())?; + + let mut id = vec![0]; + stream.read_exact(&mut id)?; + + if id[0] == 0x01 { + stream.write_all(&messages)?; + } else if id[0] == 0x02 { + let mut buf = vec![0; 10]; + let size = stream.read(&mut buf)?; + buf.truncate(size); + + let len: usize = String::from_utf8(buf)?.parse()?; + stream.write_all(&messages[len..])?; + } + } + } else if buf[0] == 0x01 { + if !args.auth_only { + let mut buf = vec![0; 1024]; + let size = stream.read(&mut buf)?; + buf.truncate(size); + + add_message(&mut buf, messages.clone(), Some(addr.ip()), args.sanitize, args.messages_file.clone())?; + } + } else if buf[0] == 0x02 { + let mut buf = vec![0; 8192]; + let size = stream.read(&mut buf)?; + buf.truncate(size); + + let msg = String::from_utf8_lossy(&buf).to_string(); + + let mut segments = msg.split("\n"); + + let Some(name) = segments.next() else { return Ok(()) }; + let Some(password) = segments.next() else { return Ok(()) }; + let Some(text) = segments.next() else { return Ok(()) }; + + for user in accounts.read().unwrap().iter() { + if user.name() == name { + if user.check_password(password) { + add_message(&mut text.as_bytes().to_vec(), messages.clone(), None, args.sanitize, args.messages_file.clone())?; + } else { + stream.write_all(&[0x02])?; + } + return Ok(()); + } + } + + stream.write_all(&[0x01])?; + } else if buf[0] == 0x03 { let mut buf = vec![0; 1024]; let size = stream.read(&mut buf)?; buf.truncate(size); - - add_message(&mut buf, messages.clone(), Some(addr.ip()), args.sanitize, args.messages_file.clone())?; - } - } else if buf[0] == 0x02 { - let mut buf = vec![0; 8192]; - let size = stream.read(&mut buf)?; - buf.truncate(size); - let msg = String::from_utf8_lossy(&buf).to_string(); + let msg = String::from_utf8_lossy(&buf).to_string(); - let mut segments = msg.split("\n"); + let mut segments = msg.split("\n"); - let Some(name) = segments.next() else { return Ok(()) }; - let Some(password) = segments.next() else { return Ok(()) }; - let Some(text) = segments.next() else { return Ok(()) }; + let Some(name) = segments.next() else { return Ok(()) }; + let Some(password) = segments.next() else { return Ok(()) }; - for user in accounts.read().unwrap().iter() { - if user.name() == name { - if user.check_password(password) { - add_message(&mut text.as_bytes().to_vec(), messages.clone(), None, args.sanitize, args.messages_file.clone())?; - } else { - stream.write_all(&[0x02])?; + let addr = addr.ip().to_string(); + + let now: i64 = Local::now().timestamp_millis(); + + for user in accounts.read().unwrap().iter() { + if user.name() == name { + stream.write_all(&[0x01])?; + return Ok(()); + } + if user.addr() == addr && ((now - user.date()) as usize) < 1000 * args.register_timeout { + stream.write_all(&[0x01])?; + return Ok(()); } - return Ok(()); } - } - stream.write_all(&[0x01])?; - } else if buf[0] == 0x03 { - let mut buf = vec![0; 1024]; - let size = stream.read(&mut buf)?; - buf.truncate(size); + let account = Account::new(name.to_string(), password.to_string(), addr, now); - let msg = String::from_utf8_lossy(&buf).to_string(); + if let Some(accounts_file) = args.accounts_file.clone() { + let mut file = OpenOptions::new() + .write(true) + .append(true) + .create(true) + .open(accounts_file)?; - let mut segments = msg.split("\n"); - - let Some(name) = segments.next() else { return Ok(()) }; - let Some(password) = segments.next() else { return Ok(()) }; - - let addr = addr.ip().to_string(); - - let now: i64 = Local::now().timestamp_millis(); - - for user in accounts.read().unwrap().iter() { - if user.name() == name { - stream.write_all(&[0x01])?; - return Ok(()); - } - if user.addr() == addr && ((now - user.date()) as usize) < 1000 * args.register_timeout { - stream.write_all(&[0x01])?; - return Ok(()); + file.write_all(&account.to_bytes())?; + file.write_all(b"\n")?; + file.flush()?; } + + println!("user registered: {name}"); + + accounts.write().unwrap().push(account); } - - let account = Account::new(name.to_string(), password.to_string(), addr, now); - - if let Some(accounts_file) = args.accounts_file.clone() { - let mut file = OpenOptions::new() - .write(true) - .append(true) - .create(true) - .open(accounts_file)?; - - file.write_all(&account.to_bytes())?; - file.write_all(b"\n")?; - file.flush()?; - } - - println!("user registered: {name}"); - - accounts.write().unwrap().push(account); } Ok(())