From 1ab27e531e14035286b0fe20e9fbf9b4f5c88108 Mon Sep 17 00:00:00 2001 From: mat Date: Sun, 12 May 2024 00:49:17 -0500 Subject: [PATCH] optimize compound tag allocations --- simdnbt/src/borrow/compound.rs | 44 ++++++++++++++++++++------------ simdnbt/src/borrow/list.rs | 14 +++++++---- simdnbt/src/borrow/mod.rs | 46 ++++++++++++++++++++++++++-------- 3 files changed, 73 insertions(+), 31 deletions(-) diff --git a/simdnbt/src/borrow/compound.rs b/simdnbt/src/borrow/compound.rs index 997c96d..1719cb2 100644 --- a/simdnbt/src/borrow/compound.rs +++ b/simdnbt/src/borrow/compound.rs @@ -1,4 +1,4 @@ -use std::io::Cursor; +use std::{cell::UnsafeCell, io::Cursor}; use byteorder::ReadBytesExt; @@ -10,28 +10,34 @@ use crate::{ Error, Mutf8Str, }; -use super::{list::NbtList, NbtTag}; +use super::{list::NbtList, tag_alloc::TagAllocator, NbtTag}; /// A list of named tags. The order of the tags is preserved. #[derive(Debug, Default, PartialEq, Clone)] pub struct NbtCompound<'a> { - values: Vec<(&'a Mutf8Str, NbtTag<'a>)>, + values: &'a [(&'a Mutf8Str, NbtTag<'a>)], } -impl<'a> NbtCompound<'a> { - pub fn from_values(values: Vec<(&'a Mutf8Str, NbtTag<'a>)>) -> Self { - Self { values } +impl<'a, 'b> NbtCompound<'a> { + pub fn read( + data: &mut Cursor<&'a [u8]>, + alloc: &UnsafeCell>, + ) -> Result { + Self::read_with_depth(data, alloc, 0) } - pub fn read(data: &mut Cursor<&'a [u8]>) -> Result { - Self::read_with_depth(data, 0) - } - - pub fn read_with_depth(data: &mut Cursor<&'a [u8]>, depth: usize) -> Result { + pub fn read_with_depth( + data: &mut Cursor<&'a [u8]>, + alloc: &UnsafeCell>, + depth: usize, + ) -> Result { if depth > MAX_DEPTH { return Err(Error::MaxDepthExceeded); } - let mut values = Vec::with_capacity(4); + + let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; + + let mut tags = alloc_mut.start_compound(depth); loop { let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?; if tag_type == END_ID { @@ -39,13 +45,19 @@ impl<'a> NbtCompound<'a> { } let tag_name = read_string(data)?; - values.push((tag_name, NbtTag::read_with_type(data, tag_type, depth)?)); + tags.push( + tag_name, + NbtTag::read_with_type(data, alloc, tag_type, depth)?, + ); } + let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; + let values = alloc_mut.finish_compound(tags); + Ok(Self { values }) } pub fn write(&self, data: &mut Vec) { - for (name, tag) in &self.values { + for (name, tag) in self.values { // reserve 4 bytes extra so we can avoid reallocating for small tags data.reserve(1 + 2 + name.len() + 4); // SAFETY: We just reserved enough space for the tag ID, the name length, the name, and @@ -109,7 +121,7 @@ impl<'a> NbtCompound<'a> { pub fn get(&self, name: &str) -> Option<&NbtTag<'a>> { let name = Mutf8Str::from_str(name); let name = name.as_ref(); - for (key, value) in &self.values { + for (key, value) in self.values { if key == &name { return Some(value); } @@ -121,7 +133,7 @@ impl<'a> NbtCompound<'a> { pub fn contains(&self, name: &str) -> bool { let name = Mutf8Str::from_str(name); let name = name.as_ref(); - for (key, _) in &self.values { + for (key, _) in self.values { if key == &name { return true; } diff --git a/simdnbt/src/borrow/list.rs b/simdnbt/src/borrow/list.rs index 4cffd2b..2315186 100644 --- a/simdnbt/src/borrow/list.rs +++ b/simdnbt/src/borrow/list.rs @@ -1,4 +1,4 @@ -use std::io::Cursor; +use std::{cell::UnsafeCell, io::Cursor}; use byteorder::ReadBytesExt; @@ -13,7 +13,7 @@ use crate::{ Error, Mutf8Str, }; -use super::{read_u32, NbtCompound, MAX_DEPTH}; +use super::{read_u32, tag_alloc::TagAllocator, NbtCompound, MAX_DEPTH}; /// A list of NBT tags of a single type. #[repr(u8)] @@ -35,7 +35,11 @@ pub enum NbtList<'a> { LongArray(Vec>) = LONG_ARRAY_ID, } impl<'a> NbtList<'a> { - pub fn read(data: &mut Cursor<&'a [u8]>, depth: usize) -> Result { + pub fn read( + data: &mut Cursor<&'a [u8]>, + alloc: &UnsafeCell>, + depth: usize, + ) -> Result { if depth > MAX_DEPTH { return Err(Error::MaxDepthExceeded); } @@ -74,7 +78,7 @@ impl<'a> NbtList<'a> { // arbitrary number to prevent big allocations let mut lists = Vec::with_capacity(length.min(128) as usize); for _ in 0..length { - lists.push(NbtList::read(data, depth + 1)?) + lists.push(NbtList::read(data, alloc, depth + 1)?) } lists }), @@ -83,7 +87,7 @@ impl<'a> NbtList<'a> { // arbitrary number to prevent big allocations let mut compounds = Vec::with_capacity(length.min(128) as usize); for _ in 0..length { - compounds.push(NbtCompound::read_with_depth(data, depth + 1)?) + compounds.push(NbtCompound::read_with_depth(data, alloc, depth + 1)?) } compounds }), diff --git a/simdnbt/src/borrow/mod.rs b/simdnbt/src/borrow/mod.rs index 52b0e4d..90513c6 100644 --- a/simdnbt/src/borrow/mod.rs +++ b/simdnbt/src/borrow/mod.rs @@ -2,8 +2,9 @@ mod compound; mod list; +mod tag_alloc; -use std::{io::Cursor, ops::Deref}; +use std::{cell::UnsafeCell, io::Cursor, ops::Deref}; use byteorder::{ReadBytesExt, BE}; @@ -17,13 +18,15 @@ use crate::{ Error, Mutf8Str, }; +use self::tag_alloc::TagAllocator; pub use self::{compound::NbtCompound, list::NbtList}; /// A complete NBT container. This contains a name and a compound tag. -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub struct BaseNbt<'a> { name: &'a Mutf8Str, tag: NbtCompound<'a>, + tag_alloc: TagAllocator<'a>, } #[derive(Debug, PartialEq, Default)] @@ -43,10 +46,16 @@ impl<'a> Nbt<'a> { if root_type != COMPOUND_ID { return Err(Error::InvalidRootType(root_type)); } - let name = read_string(data)?; - let tag = NbtCompound::read_with_depth(data, 0)?; + let tag_alloc = UnsafeCell::new(TagAllocator::new()); - Ok(Nbt::Some(BaseNbt { name, tag })) + let name = read_string(data)?; + let tag = NbtCompound::read_with_depth(data, &tag_alloc, 0)?; + + Ok(Nbt::Some(BaseNbt { + name, + tag, + tag_alloc: tag_alloc.into_inner(), + })) } pub fn write(&self, data: &mut Vec) { @@ -83,6 +92,15 @@ impl<'a> BaseNbt<'a> { self.name } } + +impl PartialEq for BaseNbt<'_> { + fn eq(&self, other: &Self) -> bool { + // we don't need to compare the tag allocator since comparing `tag` will + // still compare the values of the tags + self.name == other.name && self.tag == other.tag + } +} + impl<'a> Deref for BaseNbt<'a> { type Target = NbtCompound<'a>; @@ -130,6 +148,7 @@ impl<'a> NbtTag<'a> { fn read_with_type( data: &mut Cursor<&'a [u8]>, + alloc: &UnsafeCell>, tag_type: u8, depth: usize, ) -> Result { @@ -154,9 +173,10 @@ impl<'a> NbtTag<'a> { )), BYTE_ARRAY_ID => Ok(NbtTag::ByteArray(read_with_u32_length(data, 1)?)), STRING_ID => Ok(NbtTag::String(read_string(data)?)), - LIST_ID => Ok(NbtTag::List(NbtList::read(data, depth + 1)?)), + LIST_ID => Ok(NbtTag::List(NbtList::read(data, alloc, depth + 1)?)), COMPOUND_ID => Ok(NbtTag::Compound(NbtCompound::read_with_depth( data, + alloc, depth + 1, )?)), INT_ARRAY_ID => Ok(NbtTag::IntArray(read_int_array(data)?)), @@ -165,17 +185,23 @@ impl<'a> NbtTag<'a> { } } - pub fn read(data: &mut Cursor<&'a [u8]>) -> Result { + pub fn read( + data: &mut Cursor<&'a [u8]>, + alloc: &UnsafeCell>, + ) -> Result { let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?; - Self::read_with_type(data, tag_type, 0) + Self::read_with_type(data, alloc, tag_type, 0) } - pub fn read_optional(data: &mut Cursor<&'a [u8]>) -> Result, Error> { + pub fn read_optional( + data: &mut Cursor<&'a [u8]>, + alloc: &UnsafeCell>, + ) -> Result, Error> { let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?; if tag_type == END_ID { return Ok(None); } - Ok(Some(Self::read_with_type(data, tag_type, 0)?)) + Ok(Some(Self::read_with_type(data, alloc, tag_type, 0)?)) } pub fn byte(&self) -> Option {