diff --git a/azalea-protocol/src/connect.rs b/azalea-protocol/src/connect.rs index d3617b3f..3d910d3a 100644 --- a/azalea-protocol/src/connect.rs +++ b/azalea-protocol/src/connect.rs @@ -81,7 +81,7 @@ impl HandshakeConnection { /// Write a packet to the server pub async fn write(&mut self, packet: HandshakePacket) { - write_packet(packet, &mut self.stream).await; + write_packet(packet, &mut self.stream, None).await; } } @@ -92,7 +92,7 @@ impl GameConnection { /// Write a packet to the server pub async fn write(&mut self, packet: GamePacket) { - write_packet(packet, &mut self.stream).await; + write_packet(packet, &mut self.stream, self.compression_threshold).await; } } @@ -103,7 +103,7 @@ impl StatusConnection { /// Write a packet to the server pub async fn write(&mut self, packet: StatusPacket) { - write_packet(packet, &mut self.stream).await; + write_packet(packet, &mut self.stream, None).await; } } @@ -115,7 +115,7 @@ impl LoginConnection { /// Write a packet to the server pub async fn write(&mut self, packet: LoginPacket) { - write_packet(packet, &mut self.stream).await; + write_packet(packet, &mut self.stream, self.compression_threshold).await; } pub fn set_compression_threshold(&mut self, threshold: i32) { diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index 96946f29..7b94135a 100644 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -40,7 +40,7 @@ where // this is always true in multiplayer, false in singleplayer static VALIDATE_DECOMPRESSED: bool = true; -static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 8388608; +pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 8388608; async fn compression_decoder( stream: &mut R, @@ -102,7 +102,6 @@ where { let mut buf = frame_splitter(stream).await?; if let Some(compression_threshold) = compression_threshold { - println!("compression_decoder"); buf = compression_decoder(&mut buf.as_slice(), compression_threshold).await?; } let packet = packet_decoder(&mut buf.as_slice(), flow).await?; diff --git a/azalea-protocol/src/write.rs b/azalea-protocol/src/write.rs index bf9fd0aa..4ae9f1c1 100644 --- a/azalea-protocol/src/write.rs +++ b/azalea-protocol/src/write.rs @@ -1,31 +1,65 @@ -use tokio::{io::AsyncWriteExt, net::TcpStream}; +use std::io::Read; -use crate::{mc_buf::Writable, packets::ProtocolPacket}; +use crate::{mc_buf::Writable, packets::ProtocolPacket, read::MAXIMUM_UNCOMPRESSED_LENGTH}; +use async_compression::tokio::bufread::ZlibEncoder; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpStream, +}; -pub async fn write_packet(packet: impl ProtocolPacket, stream: &mut TcpStream) { - // TODO: implement compression - - // packet structure: - // length (varint) + id (varint) + data - - // write the packet id - let mut id_and_data_buf = vec![]; - id_and_data_buf - .write_varint(packet.id() as i32) - .expect("Writing packet id failed"); - packet.write(&mut id_and_data_buf); - - // write the packet data - - // make a new buffer that has the length at the beginning - // and id+data at the end - let mut complete_buf: Vec = Vec::new(); - complete_buf - .write_varint(id_and_data_buf.len() as i32) - .expect("Writing packet length failed"); - complete_buf.append(&mut id_and_data_buf); - - // finally, write and flush to the stream - stream.write_all(&complete_buf).await.unwrap(); - stream.flush().await.unwrap(); +fn frame_prepender(data: &mut Vec) -> Result, String> { + let mut buf = Vec::new(); + buf.write_varint(data.len() as i32) + .map_err(|e| e.to_string())?; + buf.append(data); + Ok(buf) +} + +fn packet_encoder(packet: &P) -> Result, String> { + let mut buf = Vec::new(); + buf.write_varint(packet.id() as i32) + .map_err(|e| e.to_string())?; + packet.write(&mut buf); + if buf.len() > MAXIMUM_UNCOMPRESSED_LENGTH as usize { + return Err(format!( + "Packet too big (is {} bytes, should be less than {}): {:?}", + buf.len(), + MAXIMUM_UNCOMPRESSED_LENGTH, + packet + )); + } + Ok(buf) +} + +async fn compression_encoder(data: &[u8], compression_threshold: u32) -> Result, String> { + let n = data.len(); + // if it's less than the compression threshold, don't compress + if n < compression_threshold as usize { + let mut buf = Vec::new(); + buf.write_varint(0).map_err(|e| e.to_string())?; + buf.write_all(data).await.map_err(|e| e.to_string())?; + Ok(buf) + } else { + // otherwise, compress + let mut deflater = ZlibEncoder::new(data); + // write deflated data to buf + let mut buf = Vec::new(); + deflater + .read_to_end(&mut buf) + .await + .map_err(|e| e.to_string())?; + Ok(buf) + } +} + +pub async fn write_packet

(packet: P, stream: &mut TcpStream, compression_threshold: Option) +where + P: ProtocolPacket + std::fmt::Debug, +{ + let mut buf = packet_encoder(&packet).unwrap(); + if let Some(threshold) = compression_threshold { + buf = compression_encoder(&buf, threshold).await.unwrap(); + } + buf = frame_prepender(&mut buf).unwrap(); + stream.write_all(&buf).await.unwrap(); }