packet builder, compression rewrite, (try) clone connection, getters and setters

This commit is contained in:
MeexReay 2024-07-15 20:13:41 +03:00
parent 1f72b280d5
commit 4ebf859a9c
5 changed files with 211 additions and 91 deletions

View File

@ -13,4 +13,4 @@ edition = "2021"
[dependencies]
flate2 = "1.0.30"
bytebuffer = "2.2.0"
uuid = "1.8.0"
uuid = "1.10.0"

View File

@ -41,7 +41,7 @@ fn accept_client(mut conn: MCConnTcp, server: Arc<Mutex<MinecraftServer>>) -> Re
};
if handshake {
if packet.id == 0x00 {
if packet.id() == 0x00 {
let mut status = Packet::empty(0x00);
let serv = server.lock().unwrap();
@ -53,12 +53,12 @@ fn accept_client(mut conn: MCConnTcp, server: Arc<Mutex<MinecraftServer>>) -> Re
status.write_string(&motd)?;
conn.write_packet(&status)?;
} else if packet.id == 0x01 {
} else if packet.id() == 0x01 {
let mut status = Packet::empty(0x01);
status.write_long(packet.read_long()?)?;
conn.write_packet(&status)?;
}
} else if packet.id == 0x00 {
} else if packet.id() == 0x00 {
let protocol_version = packet.read_i32_varint()?;
let server_address = packet.read_string()?;
let server_port = packet.read_unsigned_short()?;
@ -67,7 +67,7 @@ fn accept_client(mut conn: MCConnTcp, server: Arc<Mutex<MinecraftServer>>) -> Re
if next_state != 1 { break; }
println!("Client handshake info:");
println!(" IP: {}", conn.stream.peer_addr().unwrap());
println!(" IP: {}", conn.get_ref().peer_addr().unwrap());
println!(" Protocol version: {}", protocol_version);
println!(" Server address: {}", server_address);
println!(" Server port: {}", server_port);

View File

@ -1,37 +0,0 @@
use std::{net::TcpListener, thread, sync::mpsc::channel};
use rust_mc_proto::{DataBufferReader, DataBufferWriter, MCConnTcp, Packet};
const LONG_TEXT: &str = "some_long_text_wow_123123123123123123123123";
fn main() {
let (tx, rx) = channel::<()>();
let server_tx = tx.clone();
thread::spawn(move || {
let listener = TcpListener::bind("localhost:44447").unwrap();
server_tx.send(()).unwrap();
for stream in listener.incoming() {
let mut stream = MCConnTcp::new(stream.unwrap());
stream.set_compression(Some(2));
let packet = stream.read_packet().unwrap();
stream.write_packet(&packet).unwrap();
}
});
rx.recv().unwrap();
let mut conn = MCConnTcp::connect("localhost:44447").unwrap();
conn.set_compression(Some(2));
let mut packet = Packet::empty(0x12);
packet.write_string(LONG_TEXT).unwrap();
conn.write_packet(&packet).unwrap();
let mut packet = conn.read_packet().unwrap();
if packet.id == 0x12 && packet.read_string().unwrap() == LONG_TEXT {
println!("success");
}
}

View File

@ -1,8 +1,11 @@
use std::{error::Error, fmt, io::{Read, Write}, net::{TcpStream, ToSocketAddrs}, sync::{Mutex, Arc, atomic::{AtomicUsize, Ordering}}};
use std::{error::Error, fmt, io::{Read, Write}, net::{TcpStream, ToSocketAddrs}, sync::{Arc, atomic::{AtomicUsize, Ordering}}};
use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression};
use bytebuffer::ByteBuffer;
use uuid::Uuid;
#[cfg(test)]
mod tests;
pub trait Zigzag<T> { fn zigzag(&self) -> T; }
impl Zigzag<u8> for i8 { fn zigzag(&self) -> u8 { ((self << 1) ^ (self >> 7)) as u8 } }
impl Zigzag<i8> for u8 { fn zigzag(&self) -> i8 { ((self >> 1) as i8) ^ (-((self & 1) as i8)) } }
@ -39,19 +42,11 @@ impl fmt::Display for ProtocolError {
impl Error for ProtocolError {}
#[derive(Debug)]
/// Minecraft packet
#[derive(Debug, Clone)]
pub struct Packet {
pub id: u8,
pub buffer: ByteBuffer
}
macro_rules! return_error {
($ex: expr, $error: expr) => {
match $ex {
Ok(i) => i,
Err(_) => { return Err($error) },
}
};
id: u8,
buffer: ByteBuffer
}
macro_rules! size_varint {
@ -61,7 +56,7 @@ macro_rules! size_varint {
let mut size: $type = 0;
loop {
let next = return_error!(DataBufferReader::read_byte($self), ProtocolError::VarIntError);
let next = DataBufferReader::read_byte($self).or(Err(ProtocolError::VarIntError))?;
size += 1;
if shift >= (std::mem::size_of::<$type>() * 8) as $type {
@ -85,7 +80,7 @@ macro_rules! read_varint {
let mut decoded: $type = 0;
loop {
let next = return_error!(DataBufferReader::read_byte($self), ProtocolError::VarIntError);
let next = DataBufferReader::read_byte($self).or(Err(ProtocolError::VarIntError))?;
if shift >= (std::mem::size_of::<$type>() * 8) as $type {
return Err(ProtocolError::VarIntError);
@ -107,16 +102,16 @@ macro_rules! write_varint {
let mut value: $type = $value;
if value == 0 {
Ok(return_error!(DataBufferWriter::write_byte($self, 0), ProtocolError::VarIntError))
DataBufferWriter::write_byte($self, 0).or(Err(ProtocolError::VarIntError))
} else {
while value >= 0b10000000 {
let next: u8 = ((value & 0b01111111) as u8) | 0b10000000;
value >>= 7;
return_error!(DataBufferWriter::write_byte($self, next), ProtocolError::VarIntError);
DataBufferWriter::write_byte($self, next).or(Err(ProtocolError::VarIntError))?;
}
Ok(return_error!(DataBufferWriter::write_byte($self, (value & 0b01111111) as u8), ProtocolError::VarIntError))
DataBufferWriter::write_byte($self, (value & 0b01111111) as u8).or(Err(ProtocolError::VarIntError))
}
}};
}
@ -407,6 +402,47 @@ impl Packet {
buffer: ByteBuffer::new()
}
}
/// Build packet with lambda
pub fn build<F>(
id: u8,
builder: F
) -> Result<Packet, ProtocolError>
where F: FnOnce(&mut Packet) -> Result<(), ProtocolError> {
let mut packet = Self::empty(id);
builder(&mut packet)?;
Ok(packet)
}
/// Get packet id
pub fn id(&self) -> u8 {
self.id
}
/// Set packet id
pub fn set_id(&mut self, id: u8) {
self.id = id;
}
/// Get mutable reference of buffer
pub fn buffer(&mut self) -> &mut ByteBuffer {
&mut self.buffer
}
/// Set packet buffer
pub fn set_buffer(&mut self, buffer: ByteBuffer) {
self.buffer = buffer;
}
/// Get buffer length
pub fn len(&self) -> usize {
self.buffer.len()
}
/// Get buffer bytes
pub fn get_bytes(&self) -> Vec<u8> {
self.buffer.as_bytes().to_vec()
}
}
impl DataBufferReader for Packet {
@ -429,8 +465,8 @@ impl DataBufferWriter for Packet {
}
pub struct MinecraftConnection<T: Read + Write> {
pub stream: T,
compression: Option<usize>
stream: T,
compression: Arc<AtomicUsize>
}
impl MinecraftConnection<TcpStream> {
@ -451,7 +487,7 @@ impl MinecraftConnection<TcpStream> {
Ok(MinecraftConnection {
stream,
compression: None
compression: Arc::new(AtomicUsize::new(usize::MAX))
})
}
@ -459,6 +495,22 @@ impl MinecraftConnection<TcpStream> {
pub fn close(&mut self) {
let _ = self.stream.shutdown(std::net::Shutdown::Both);
}
/// Try clone MinecraftConnection with compression and stream
pub fn try_clone(&mut self) -> Result<MinecraftConnection<TcpStream>, ProtocolError> {
match self.stream.try_clone() {
Ok(stream) => {
Ok(MinecraftConnection {
stream: stream,
compression: self.compression.clone()
})
},
_ => {
Err(ProtocolError::CloneError)
},
}
}
}
impl<T: Read + Write> DataBufferReader for MinecraftConnection<T> {
@ -485,37 +537,78 @@ impl<T: Read + Write> MinecraftConnection<T> {
pub fn new(stream: T) -> MinecraftConnection<T> {
MinecraftConnection {
stream,
compression: None
compression: Arc::new(AtomicUsize::new(usize::MAX))
}
}
/// Set compression threashold
pub fn set_compression(&mut self, threashold: Option<usize>) {
self.compression = threashold;
/// Set compression threshold
pub fn set_compression(&mut self, threshold: Option<usize>) {
self.compression = Arc::new(AtomicUsize::new(
match threshold {
Some(t) => t,
None => usize::MAX,
}
));
}
/// Get compression threshold
pub fn get_compression(&self) -> Option<usize> {
let threshold = self.compression.load(Ordering::Relaxed);
if threshold == usize::MAX {
None
} else {
Some(threshold)
}
}
/// Get mutable reference of stream
pub fn get_mut(&mut self) -> &mut T {
&mut self.stream
}
/// Get immutable reference of stream
pub fn get_ref(&self) -> &T {
&self.stream
}
/// Read [`Packet`](Packet) from connection
pub fn read_packet(&mut self) -> Result<Packet, ProtocolError> {
read_packet(&mut self.stream, Arc::new(AtomicUsize::new(if self.compression.is_none() {usize::MAX} else {self.compression.unwrap()})))
read_packet_atomic(
&mut self.stream,
self.compression.clone(),
Ordering::Relaxed)
}
/// Write [`Packet`](Packet) to connection
pub fn write_packet(&mut self, packet: &Packet) -> Result<(), ProtocolError> {
write_packet(&mut self.stream, Arc::new(AtomicUsize::new(if self.compression.is_none() {usize::MAX} else {self.compression.unwrap()})), packet)
write_packet_atomic(
&mut self.stream,
self.compression.clone(),
Ordering::Relaxed,
packet)
}
}
impl<T: Read + Write + Clone> MinecraftConnection<T> {
/// Clone MinecraftConnection with compression and stream
pub fn clone(&mut self) -> MinecraftConnection<T> {
MinecraftConnection {
stream: self.stream.clone(),
compression: self.compression.clone()
}
}
}
fn compress_zlib(bytes: &[u8]) -> Result<Vec<u8>, ProtocolError> {
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::fast());
return_error!(encoder.write_all(bytes), ProtocolError::ZlibError);
let output = return_error!(encoder.finish(), ProtocolError::ZlibError);
Ok(output)
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
encoder.write_all(bytes).or(Err(ProtocolError::ZlibError))?;
encoder.finish().or(Err(ProtocolError::ZlibError))
}
fn decompress_zlib(bytes: &[u8], packet_length: usize) -> Result<Vec<u8>, ProtocolError> {
fn decompress_zlib(bytes: &[u8]) -> Result<Vec<u8>, ProtocolError> {
let mut decoder = ZlibDecoder::new(bytes);
let mut output = vec![0;packet_length];
return_error!(decoder.read_exact(&mut output), ProtocolError::ZlibError);
let mut output = Vec::new();
decoder.read_to_end(&mut output).or(Err(ProtocolError::ZlibError))?;
Ok(output)
}
@ -526,12 +619,12 @@ pub type MCConn<T> = MinecraftConnection<T>;
pub type MCConnTcp = MinecraftConnection<TcpStream>;
/// Read [`Packet`](Packet) from stream, if compression is usize::MAX, compression is disabled
pub fn read_packet<T: Read>(stream: &mut T, compression: Arc<AtomicUsize>) -> Result<Packet, ProtocolError> {
pub fn read_packet_atomic<T: Read>(stream: &mut T, compression: Arc<AtomicUsize>, ordering: Ordering) -> Result<Packet, ProtocolError> {
let mut data: Vec<u8>;
let packet_length = stream.read_usize_varint_size()?;
let compress_threashold = compression.load(Ordering::Relaxed);
let compress_threashold = compression.load(ordering);
if compress_threashold != usize::MAX {
let data_length = stream.read_usize_varint_size()?;
@ -539,7 +632,7 @@ pub fn read_packet<T: Read>(stream: &mut T, compression: Arc<AtomicUsize>) -> Re
data = stream.read_bytes(packet_length.0 - data_length.1)?;
if data_length.0 != 0 {
data = decompress_zlib(&data, data_length.0)?;
data = decompress_zlib(&data)?;
}
} else {
data = stream.read_bytes(packet_length.0)?;
@ -549,29 +642,27 @@ pub fn read_packet<T: Read>(stream: &mut T, compression: Arc<AtomicUsize>) -> Re
}
/// Write [`Packet`](Packet) to stream, if compression is usize::MAX, compression is disabled
pub fn write_packet<T: Write>(stream: &mut T, compression: Arc<AtomicUsize>, packet: &Packet) -> Result<(), ProtocolError> {
pub fn write_packet_atomic<T: Write>(stream: &mut T, compression: Arc<AtomicUsize>, ordering: Ordering, packet: &Packet) -> Result<(), ProtocolError> {
let mut buf = ByteBuffer::new();
let mut data_buf = ByteBuffer::new();
data_buf.write_u8_varint(packet.id)?;
data_buf.write_buffer(&packet.buffer)?;
let compress_threashold = compression.load(Ordering::Relaxed);
let compress_threshold = compression.load(ordering);
if compress_threashold != usize::MAX {
if compress_threshold != usize::MAX {
let mut packet_buf = ByteBuffer::new();
let mut data = data_buf.as_bytes().to_vec();
let mut data_length = 0;
if data.len() >= compress_threashold {
data_length = data.len();
data = compress_zlib(&data)?;
if data_buf.len() >= compress_threshold {
let compressed_data = compress_zlib(data_buf.as_bytes())?;
packet_buf.write_usize_varint(data_buf.len())?;
packet_buf.write_all(&compressed_data).or(Err(ProtocolError::WriteError))?;
} else {
packet_buf.write_usize_varint(0)?;
packet_buf.write_buffer(&data_buf)?;
}
packet_buf.write_usize_varint(data_length)?;
DataBufferWriter::write_bytes(&mut packet_buf, &data)?;
buf.write_usize_varint(packet_buf.len())?;
buf.write_buffer(&packet_buf)?;
} else {

66
src/tests.rs Normal file
View File

@ -0,0 +1,66 @@
use super::*;
use std::{net::TcpListener, thread};
#[test]
fn test_compression_server_client() -> Result<(), ProtocolError> {
fn test(first_text: &str) -> Result<bool, ProtocolError> {
let Ok(mut conn) = MCConnTcp::connect("localhost:44447") else { return test(first_text) };
conn.set_compression(Some(5));
let mut packet = Packet::empty(0x12);
packet.write_string(first_text)?;
conn.write_packet(&packet)?;
println!("[c -> s] sent packet with text \"{}\"", first_text);
let mut packet = conn.read_packet()?;
let text = packet.read_string()?;
println!("[c <- s] read packet with text \"{}\"", text);
Ok(packet.id() == 0x12 && text == first_text)
}
thread::spawn(move || -> Result<(), ProtocolError> {
let listener = TcpListener::bind("localhost:44447").or(Err(ProtocolError::StreamConnectError))?;
for stream in listener.incoming() {
let mut stream = MCConnTcp::new(stream.or(Err(ProtocolError::StreamConnectError))?);
stream.set_compression(Some(5));
let mut packet = stream.read_packet()?;
let text = packet.read_string()?;
println!("[s <- c] read packet with text \"{}\"", text);
stream.write_packet(&packet)?;
println!("[s -> c] sent packet with text \"{}\"", text);
}
Ok(())
});
assert!(test("12bcvf756iuyu,.,.")? && test("a")?);
Ok(())
}
#[test]
fn test_compression_atomic_bytebuffer() -> Result<(), ProtocolError> {
let packet_1 = Packet::build(0x12, |p| {
p.write_bytes(b"1234567890qwertyuiopasdfghjklzxcvbnm")
})?;
let compression = Arc::new(AtomicUsize::new(5));
let mut buffer = ByteBuffer::new();
write_packet_atomic(&mut buffer, compression.clone(), Ordering::Acquire, &packet_1)?;
buffer.set_rpos(0);
buffer.set_wpos(0);
let mut packet_2 = read_packet_atomic(&mut buffer, compression.clone(), Ordering::Acquire)?;
assert_eq!(packet_2.read_bytes(36)?, b"1234567890qwertyuiopasdfghjklzxcvbnm");
Ok(())
}