From 4ebf859a9cd9e2e7586b8ee13ccb16e6db1edfbd Mon Sep 17 00:00:00 2001 From: MeexReay Date: Mon, 15 Jul 2024 20:13:41 +0300 Subject: [PATCH] packet builder, compression rewrite, (try) clone connection, getters and setters --- Cargo.toml | 2 +- examples/status_server.rs | 8 +- examples/test_compression.rs | 37 ------- src/lib.rs | 189 ++++++++++++++++++++++++++--------- src/tests.rs | 66 ++++++++++++ 5 files changed, 211 insertions(+), 91 deletions(-) delete mode 100644 examples/test_compression.rs create mode 100644 src/tests.rs diff --git a/Cargo.toml b/Cargo.toml index 404dfd7..6ea3cc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,4 +13,4 @@ edition = "2021" [dependencies] flate2 = "1.0.30" bytebuffer = "2.2.0" -uuid = "1.8.0" \ No newline at end of file +uuid = "1.10.0" \ No newline at end of file diff --git a/examples/status_server.rs b/examples/status_server.rs index 10b3927..39754b5 100644 --- a/examples/status_server.rs +++ b/examples/status_server.rs @@ -41,7 +41,7 @@ fn accept_client(mut conn: MCConnTcp, server: Arc>) -> Re }; if handshake { - if packet.id == 0x00 { + if packet.id() == 0x00 { let mut status = Packet::empty(0x00); let serv = server.lock().unwrap(); @@ -53,12 +53,12 @@ fn accept_client(mut conn: MCConnTcp, server: Arc>) -> Re status.write_string(&motd)?; conn.write_packet(&status)?; - } else if packet.id == 0x01 { + } else if packet.id() == 0x01 { let mut status = Packet::empty(0x01); status.write_long(packet.read_long()?)?; conn.write_packet(&status)?; } - } else if packet.id == 0x00 { + } else if packet.id() == 0x00 { let protocol_version = packet.read_i32_varint()?; let server_address = packet.read_string()?; let server_port = packet.read_unsigned_short()?; @@ -67,7 +67,7 @@ fn accept_client(mut conn: MCConnTcp, server: Arc>) -> Re if next_state != 1 { break; } println!("Client handshake info:"); - println!(" IP: {}", conn.stream.peer_addr().unwrap()); + println!(" IP: {}", conn.get_ref().peer_addr().unwrap()); println!(" Protocol version: {}", protocol_version); println!(" Server address: {}", server_address); println!(" Server port: {}", server_port); diff --git a/examples/test_compression.rs b/examples/test_compression.rs deleted file mode 100644 index 505e548..0000000 --- a/examples/test_compression.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::{net::TcpListener, thread, sync::mpsc::channel}; -use rust_mc_proto::{DataBufferReader, DataBufferWriter, MCConnTcp, Packet}; - -const LONG_TEXT: &str = "some_long_text_wow_123123123123123123123123"; - -fn main() { - let (tx, rx) = channel::<()>(); - - let server_tx = tx.clone(); - thread::spawn(move || { - let listener = TcpListener::bind("localhost:44447").unwrap(); - - server_tx.send(()).unwrap(); - - for stream in listener.incoming() { - let mut stream = MCConnTcp::new(stream.unwrap()); - stream.set_compression(Some(2)); - - let packet = stream.read_packet().unwrap(); - stream.write_packet(&packet).unwrap(); - } - }); - - rx.recv().unwrap(); - - let mut conn = MCConnTcp::connect("localhost:44447").unwrap(); - conn.set_compression(Some(2)); - - let mut packet = Packet::empty(0x12); - packet.write_string(LONG_TEXT).unwrap(); - conn.write_packet(&packet).unwrap(); - - let mut packet = conn.read_packet().unwrap(); - if packet.id == 0x12 && packet.read_string().unwrap() == LONG_TEXT { - println!("success"); - } -} diff --git a/src/lib.rs b/src/lib.rs index 22422a1..14ae155 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,11 @@ -use std::{error::Error, fmt, io::{Read, Write}, net::{TcpStream, ToSocketAddrs}, sync::{Mutex, Arc, atomic::{AtomicUsize, Ordering}}}; +use std::{error::Error, fmt, io::{Read, Write}, net::{TcpStream, ToSocketAddrs}, sync::{Arc, atomic::{AtomicUsize, Ordering}}}; use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression}; use bytebuffer::ByteBuffer; use uuid::Uuid; +#[cfg(test)] +mod tests; + pub trait Zigzag { fn zigzag(&self) -> T; } impl Zigzag for i8 { fn zigzag(&self) -> u8 { ((self << 1) ^ (self >> 7)) as u8 } } impl Zigzag for u8 { fn zigzag(&self) -> i8 { ((self >> 1) as i8) ^ (-((self & 1) as i8)) } } @@ -39,19 +42,11 @@ impl fmt::Display for ProtocolError { impl Error for ProtocolError {} -#[derive(Debug)] +/// Minecraft packet +#[derive(Debug, Clone)] pub struct Packet { - pub id: u8, - pub buffer: ByteBuffer -} - -macro_rules! return_error { - ($ex: expr, $error: expr) => { - match $ex { - Ok(i) => i, - Err(_) => { return Err($error) }, - } - }; + id: u8, + buffer: ByteBuffer } macro_rules! size_varint { @@ -61,7 +56,7 @@ macro_rules! size_varint { let mut size: $type = 0; loop { - let next = return_error!(DataBufferReader::read_byte($self), ProtocolError::VarIntError); + let next = DataBufferReader::read_byte($self).or(Err(ProtocolError::VarIntError))?; size += 1; if shift >= (std::mem::size_of::<$type>() * 8) as $type { @@ -85,7 +80,7 @@ macro_rules! read_varint { let mut decoded: $type = 0; loop { - let next = return_error!(DataBufferReader::read_byte($self), ProtocolError::VarIntError); + let next = DataBufferReader::read_byte($self).or(Err(ProtocolError::VarIntError))?; if shift >= (std::mem::size_of::<$type>() * 8) as $type { return Err(ProtocolError::VarIntError); @@ -107,16 +102,16 @@ macro_rules! write_varint { let mut value: $type = $value; if value == 0 { - Ok(return_error!(DataBufferWriter::write_byte($self, 0), ProtocolError::VarIntError)) + DataBufferWriter::write_byte($self, 0).or(Err(ProtocolError::VarIntError)) } else { while value >= 0b10000000 { let next: u8 = ((value & 0b01111111) as u8) | 0b10000000; value >>= 7; - return_error!(DataBufferWriter::write_byte($self, next), ProtocolError::VarIntError); + DataBufferWriter::write_byte($self, next).or(Err(ProtocolError::VarIntError))?; } - Ok(return_error!(DataBufferWriter::write_byte($self, (value & 0b01111111) as u8), ProtocolError::VarIntError)) + DataBufferWriter::write_byte($self, (value & 0b01111111) as u8).or(Err(ProtocolError::VarIntError)) } }}; } @@ -407,6 +402,47 @@ impl Packet { buffer: ByteBuffer::new() } } + + /// Build packet with lambda + pub fn build( + id: u8, + builder: F + ) -> Result + where F: FnOnce(&mut Packet) -> Result<(), ProtocolError> { + let mut packet = Self::empty(id); + builder(&mut packet)?; + Ok(packet) + } + + /// Get packet id + pub fn id(&self) -> u8 { + self.id + } + + /// Set packet id + pub fn set_id(&mut self, id: u8) { + self.id = id; + } + + /// Get mutable reference of buffer + pub fn buffer(&mut self) -> &mut ByteBuffer { + &mut self.buffer + } + + /// Set packet buffer + pub fn set_buffer(&mut self, buffer: ByteBuffer) { + self.buffer = buffer; + } + + /// Get buffer length + pub fn len(&self) -> usize { + self.buffer.len() + } + + /// Get buffer bytes + pub fn get_bytes(&self) -> Vec { + self.buffer.as_bytes().to_vec() + } } impl DataBufferReader for Packet { @@ -429,8 +465,8 @@ impl DataBufferWriter for Packet { } pub struct MinecraftConnection { - pub stream: T, - compression: Option + stream: T, + compression: Arc } impl MinecraftConnection { @@ -451,7 +487,7 @@ impl MinecraftConnection { Ok(MinecraftConnection { stream, - compression: None + compression: Arc::new(AtomicUsize::new(usize::MAX)) }) } @@ -459,6 +495,22 @@ impl MinecraftConnection { pub fn close(&mut self) { let _ = self.stream.shutdown(std::net::Shutdown::Both); } + + /// Try clone MinecraftConnection with compression and stream + pub fn try_clone(&mut self) -> Result, ProtocolError> { + match self.stream.try_clone() { + Ok(stream) => { + Ok(MinecraftConnection { + stream: stream, + compression: self.compression.clone() + }) + }, + _ => { + Err(ProtocolError::CloneError) + }, + } + + } } impl DataBufferReader for MinecraftConnection { @@ -485,37 +537,78 @@ impl MinecraftConnection { pub fn new(stream: T) -> MinecraftConnection { MinecraftConnection { stream, - compression: None + compression: Arc::new(AtomicUsize::new(usize::MAX)) } } - /// Set compression threashold - pub fn set_compression(&mut self, threashold: Option) { - self.compression = threashold; + /// Set compression threshold + pub fn set_compression(&mut self, threshold: Option) { + self.compression = Arc::new(AtomicUsize::new( + match threshold { + Some(t) => t, + None => usize::MAX, + } + )); + } + + /// Get compression threshold + pub fn get_compression(&self) -> Option { + let threshold = self.compression.load(Ordering::Relaxed); + if threshold == usize::MAX { + None + } else { + Some(threshold) + } + } + + /// Get mutable reference of stream + pub fn get_mut(&mut self) -> &mut T { + &mut self.stream + } + + /// Get immutable reference of stream + pub fn get_ref(&self) -> &T { + &self.stream } /// Read [`Packet`](Packet) from connection pub fn read_packet(&mut self) -> Result { - read_packet(&mut self.stream, Arc::new(AtomicUsize::new(if self.compression.is_none() {usize::MAX} else {self.compression.unwrap()}))) + read_packet_atomic( + &mut self.stream, + self.compression.clone(), + Ordering::Relaxed) } /// Write [`Packet`](Packet) to connection pub fn write_packet(&mut self, packet: &Packet) -> Result<(), ProtocolError> { - write_packet(&mut self.stream, Arc::new(AtomicUsize::new(if self.compression.is_none() {usize::MAX} else {self.compression.unwrap()})), packet) + write_packet_atomic( + &mut self.stream, + self.compression.clone(), + Ordering::Relaxed, + packet) + } +} + +impl MinecraftConnection { + /// Clone MinecraftConnection with compression and stream + pub fn clone(&mut self) -> MinecraftConnection { + MinecraftConnection { + stream: self.stream.clone(), + compression: self.compression.clone() + } } } fn compress_zlib(bytes: &[u8]) -> Result, ProtocolError> { - let mut encoder = ZlibEncoder::new(Vec::new(), Compression::fast()); - return_error!(encoder.write_all(bytes), ProtocolError::ZlibError); - let output = return_error!(encoder.finish(), ProtocolError::ZlibError); - Ok(output) + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(bytes).or(Err(ProtocolError::ZlibError))?; + encoder.finish().or(Err(ProtocolError::ZlibError)) } -fn decompress_zlib(bytes: &[u8], packet_length: usize) -> Result, ProtocolError> { +fn decompress_zlib(bytes: &[u8]) -> Result, ProtocolError> { let mut decoder = ZlibDecoder::new(bytes); - let mut output = vec![0;packet_length]; - return_error!(decoder.read_exact(&mut output), ProtocolError::ZlibError); + let mut output = Vec::new(); + decoder.read_to_end(&mut output).or(Err(ProtocolError::ZlibError))?; Ok(output) } @@ -526,12 +619,12 @@ pub type MCConn = MinecraftConnection; pub type MCConnTcp = MinecraftConnection; /// Read [`Packet`](Packet) from stream, if compression is usize::MAX, compression is disabled -pub fn read_packet(stream: &mut T, compression: Arc) -> Result { +pub fn read_packet_atomic(stream: &mut T, compression: Arc, ordering: Ordering) -> Result { let mut data: Vec; let packet_length = stream.read_usize_varint_size()?; - let compress_threashold = compression.load(Ordering::Relaxed); + let compress_threashold = compression.load(ordering); if compress_threashold != usize::MAX { let data_length = stream.read_usize_varint_size()?; @@ -539,7 +632,7 @@ pub fn read_packet(stream: &mut T, compression: Arc) -> Re data = stream.read_bytes(packet_length.0 - data_length.1)?; if data_length.0 != 0 { - data = decompress_zlib(&data, data_length.0)?; + data = decompress_zlib(&data)?; } } else { data = stream.read_bytes(packet_length.0)?; @@ -549,29 +642,27 @@ pub fn read_packet(stream: &mut T, compression: Arc) -> Re } /// Write [`Packet`](Packet) to stream, if compression is usize::MAX, compression is disabled -pub fn write_packet(stream: &mut T, compression: Arc, packet: &Packet) -> Result<(), ProtocolError> { +pub fn write_packet_atomic(stream: &mut T, compression: Arc, ordering: Ordering, packet: &Packet) -> Result<(), ProtocolError> { let mut buf = ByteBuffer::new(); let mut data_buf = ByteBuffer::new(); data_buf.write_u8_varint(packet.id)?; data_buf.write_buffer(&packet.buffer)?; - let compress_threashold = compression.load(Ordering::Relaxed); + let compress_threshold = compression.load(ordering); - if compress_threashold != usize::MAX { + if compress_threshold != usize::MAX { let mut packet_buf = ByteBuffer::new(); - let mut data = data_buf.as_bytes().to_vec(); - let mut data_length = 0; - - if data.len() >= compress_threashold { - data_length = data.len(); - data = compress_zlib(&data)?; + if data_buf.len() >= compress_threshold { + let compressed_data = compress_zlib(data_buf.as_bytes())?; + packet_buf.write_usize_varint(data_buf.len())?; + packet_buf.write_all(&compressed_data).or(Err(ProtocolError::WriteError))?; + } else { + packet_buf.write_usize_varint(0)?; + packet_buf.write_buffer(&data_buf)?; } - packet_buf.write_usize_varint(data_length)?; - DataBufferWriter::write_bytes(&mut packet_buf, &data)?; - buf.write_usize_varint(packet_buf.len())?; buf.write_buffer(&packet_buf)?; } else { diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..17f8faa --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,66 @@ +use super::*; +use std::{net::TcpListener, thread}; + +#[test] +fn test_compression_server_client() -> Result<(), ProtocolError> { + fn test(first_text: &str) -> Result { + let Ok(mut conn) = MCConnTcp::connect("localhost:44447") else { return test(first_text) }; + conn.set_compression(Some(5)); + + let mut packet = Packet::empty(0x12); + packet.write_string(first_text)?; + conn.write_packet(&packet)?; + + println!("[c -> s] sent packet with text \"{}\"", first_text); + + let mut packet = conn.read_packet()?; + let text = packet.read_string()?; + + println!("[c <- s] read packet with text \"{}\"", text); + + Ok(packet.id() == 0x12 && text == first_text) + } + + thread::spawn(move || -> Result<(), ProtocolError> { + let listener = TcpListener::bind("localhost:44447").or(Err(ProtocolError::StreamConnectError))?; + + for stream in listener.incoming() { + let mut stream = MCConnTcp::new(stream.or(Err(ProtocolError::StreamConnectError))?); + stream.set_compression(Some(5)); + + let mut packet = stream.read_packet()?; + let text = packet.read_string()?; + println!("[s <- c] read packet with text \"{}\"", text); + stream.write_packet(&packet)?; + println!("[s -> c] sent packet with text \"{}\"", text); + } + + Ok(()) + }); + + assert!(test("12bcvf756iuyu,.,.")? && test("a")?); + + Ok(()) +} + +#[test] +fn test_compression_atomic_bytebuffer() -> Result<(), ProtocolError> { + let packet_1 = Packet::build(0x12, |p| { + p.write_bytes(b"1234567890qwertyuiopasdfghjklzxcvbnm") + })?; + + let compression = Arc::new(AtomicUsize::new(5)); + + let mut buffer = ByteBuffer::new(); + + write_packet_atomic(&mut buffer, compression.clone(), Ordering::Acquire, &packet_1)?; + + buffer.set_rpos(0); + buffer.set_wpos(0); + + let mut packet_2 = read_packet_atomic(&mut buffer, compression.clone(), Ordering::Acquire)?; + + assert_eq!(packet_2.read_bytes(36)?, b"1234567890qwertyuiopasdfghjklzxcvbnm"); + + Ok(()) +} \ No newline at end of file