1
0
Fork 0
mirror of https://github.com/azalea-rs/simdnbt.git synced 2025-08-02 07:26:04 +00:00

Merge pull request #4 from azalea-rs/reuse-allocs

Optimize simdnbt::borrow and probably introduce UB
This commit is contained in:
mat 2024-05-12 21:39:51 -05:00 committed by GitHub
commit 4d660133d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 497 additions and 93 deletions

View file

@ -60,10 +60,12 @@ The most significant and simple optimization you can do is switching to an alloc
## Implementation details
Simdnbt currently makes use of SIMD instructions for two things:
- swapping the endianness of int arrays
- checking if a string is plain ascii for faster mutf8 to utf8 conversion
- swapping the endianness of int arrays
- checking if a string is plain ascii for faster mutf8 to utf8 conversion
Simdnbt ~~cheats~~ takes some shortcuts to be this fast:
1. it requires a reference to the original data (to avoid cloning)
2. it doesn't validate/decode the mutf-8 strings at decode-time
@ -75,7 +77,7 @@ Here's a benchmark comparing Simdnbt against a few of the other fastest NBT crat
| Library | Throughput |
| --------------------------------------------------------------------------- | ------------ |
| [simdnbt::borrow](https://docs.rs/simdnbt/latest/simdnbt/borrow/index.html) | 717.45 MiB/s |
| [simdnbt::borrow](https://docs.rs/simdnbt/latest/simdnbt/borrow/index.html) | 1.7619 GiB/s |
| [simdnbt::owned](https://docs.rs/simdnbt/latest/simdnbt/owned/index.html) | 329.10 MiB/s |
| [shen_nbt5](https://docs.rs/shen-nbt5/latest/shen_nbt5/) | 306.58 MiB/s |
| [azalea_nbt](https://docs.rs/azalea-nbt/latest/azalea_nbt/) | 297.28 MiB/s |
@ -85,14 +87,14 @@ Here's a benchmark comparing Simdnbt against a few of the other fastest NBT crat
| [hematite_nbt](https://docs.rs/hematite-nbt/latest/nbt/) | 108.91 MiB/s |
And for writing `complex_player.dat`:
| Library | Throughput |
| ----------------| ------------ |
| --------------- | ------------ |
| simdnbt::borrow | 2.5914 GiB/s |
| azalea_nbt | 2.1096 GiB/s |
| simdnbt::owned | 1.9508 GiB/s |
| graphite_binary | 1.7745 GiB/s |
The tables above were made from the [compare benchmark](https://github.com/azalea-rs/simdnbt/tree/master/simdnbt/benches) in this repo.
Note that the benchmark is somewhat unfair, since `simdnbt::borrow` doesn't fully decode some things like strings and integer arrays until they're used.
Also keep in mind that if you run your own benchmark you'll get different numbers, but the speeds should be about the same relative to each other.

View file

@ -138,9 +138,9 @@ fn bench(c: &mut Criterion) {
// bench_read_file("hello_world.nbt", c);
// bench_read_file("bigtest.nbt", c);
// bench_read_file("simple_player.dat", c);
// bench_read_file("complex_player.dat", c);
bench_read_file("complex_player.dat", c);
// bench_read_file("level.dat", c);
bench_read_file("inttest1023.nbt", c);
// bench_read_file("inttest1023.nbt", c);
}
criterion_group!(compare, bench);

View file

@ -50,10 +50,11 @@ fn bench_file(filename: &str, c: &mut Criterion) {
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
fn bench(c: &mut Criterion) {
// bench_file("bigtest.nbt", c);
// bench_file("simple_player.dat", c);
bench_file("bigtest.nbt", c);
bench_file("simple_player.dat", c);
bench_file("complex_player.dat", c);
// bench_file("level.dat", c);
bench_file("level.dat", c);
// bench_file("stringtest.nbt", c);
// bench_file("inttest16.nbt", c);

View file

@ -1,4 +1,4 @@
use std::io::Cursor;
use std::{io::Cursor, mem::MaybeUninit};
use byteorder::ReadBytesExt;
@ -10,42 +10,87 @@ use crate::{
Error, Mutf8Str,
};
use super::{list::NbtList, NbtTag};
use super::{list::NbtList, tag_alloc::TagAllocator, NbtTag};
/// A list of named tags. The order of the tags is preserved.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct NbtCompound<'a> {
values: Vec<(&'a Mutf8Str, NbtTag<'a>)>,
values: &'a [(&'a Mutf8Str, NbtTag<'a>)],
}
impl<'a> NbtCompound<'a> {
pub fn from_values(values: Vec<(&'a Mutf8Str, NbtTag<'a>)>) -> Self {
Self { values }
pub fn read(data: &mut Cursor<&'a [u8]>, alloc: &TagAllocator<'a>) -> Result<Self, Error> {
Self::read_with_depth(data, alloc, 0)
}
pub fn read(data: &mut Cursor<&'a [u8]>) -> Result<Self, Error> {
Self::read_with_depth(data, 0)
}
pub fn read_with_depth(data: &mut Cursor<&'a [u8]>, depth: usize) -> Result<Self, Error> {
pub fn read_with_depth(
data: &mut Cursor<&'a [u8]>,
alloc: &TagAllocator<'a>,
depth: usize,
) -> Result<Self, Error> {
if depth > MAX_DEPTH {
return Err(Error::MaxDepthExceeded);
}
let mut values = Vec::with_capacity(4);
let mut tags = alloc.get().named.start(depth);
let mut tags_buffer = unsafe {
MaybeUninit::<[MaybeUninit<(&Mutf8Str, NbtTag<'a>)>; 4]>::uninit().assume_init()
};
let mut tags_buffer_len: usize = 0;
loop {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
let tag_type = match data.read_u8() {
Ok(tag_type) => tag_type,
Err(_) => {
alloc.get().named.finish(tags, depth);
return Err(Error::UnexpectedEof);
}
};
if tag_type == END_ID {
break;
}
let tag_name = read_string(data)?;
values.push((tag_name, NbtTag::read_with_type(data, tag_type, depth)?));
let tag_name = match read_string(data) {
Ok(name) => name,
Err(_) => {
alloc.get().named.finish(tags, depth);
// the only error read_string can return is UnexpectedEof, so this makes it
// slightly faster
return Err(Error::UnexpectedEof);
}
};
let tag = match NbtTag::read_with_type(data, alloc, tag_type, depth) {
Ok(tag) => tag,
Err(e) => {
alloc.get().named.finish(tags, depth);
return Err(e);
}
};
tags_buffer[tags_buffer_len] = MaybeUninit::new((tag_name, tag));
tags_buffer_len += 1;
if tags_buffer_len == tags_buffer.len() {
// writing the tags in groups like this is slightly faster
for i in 0..tags_buffer_len {
tags.push(unsafe { tags_buffer.get_unchecked(i).assume_init_read() });
}
tags_buffer_len = 0;
}
}
for i in 0..tags_buffer_len {
tags.push(unsafe { tags_buffer.get_unchecked(i).assume_init_read() });
}
let values = alloc.get().named.finish(tags, depth);
Ok(Self { values })
}
pub fn write(&self, data: &mut Vec<u8>) {
for (name, tag) in &self.values {
for (name, tag) in self.values {
// reserve 4 bytes extra so we can avoid reallocating for small tags
data.reserve(1 + 2 + name.len() + 4);
// SAFETY: We just reserved enough space for the tag ID, the name length, the name, and
@ -109,7 +154,7 @@ impl<'a> NbtCompound<'a> {
pub fn get(&self, name: &str) -> Option<&NbtTag<'a>> {
let name = Mutf8Str::from_str(name);
let name = name.as_ref();
for (key, value) in &self.values {
for (key, value) in self.values {
if key == &name {
return Some(value);
}
@ -121,7 +166,7 @@ impl<'a> NbtCompound<'a> {
pub fn contains(&self, name: &str) -> bool {
let name = Mutf8Str::from_str(name);
let name = name.as_ref();
for (key, _) in &self.values {
for (key, _) in self.values {
if key == &name {
return true;
}

View file

@ -13,7 +13,7 @@ use crate::{
Error, Mutf8Str,
};
use super::{read_u32, NbtCompound, MAX_DEPTH};
use super::{read_u32, tag_alloc::TagAllocator, NbtCompound, MAX_DEPTH};
/// A list of NBT tags of a single type.
#[repr(u8)]
@ -27,15 +27,19 @@ pub enum NbtList<'a> {
Long(RawList<'a, i64>) = LONG_ID,
Float(RawList<'a, f32>) = FLOAT_ID,
Double(RawList<'a, f64>) = DOUBLE_ID,
ByteArray(Vec<&'a [u8]>) = BYTE_ARRAY_ID,
String(Vec<&'a Mutf8Str>) = STRING_ID,
List(Vec<NbtList<'a>>) = LIST_ID,
Compound(Vec<NbtCompound<'a>>) = COMPOUND_ID,
IntArray(Vec<RawList<'a, i32>>) = INT_ARRAY_ID,
LongArray(Vec<RawList<'a, i64>>) = LONG_ARRAY_ID,
ByteArray(&'a [&'a [u8]]) = BYTE_ARRAY_ID,
String(&'a [&'a Mutf8Str]) = STRING_ID,
List(&'a [NbtList<'a>]) = LIST_ID,
Compound(&'a [NbtCompound<'a>]) = COMPOUND_ID,
IntArray(&'a [RawList<'a, i32>]) = INT_ARRAY_ID,
LongArray(&'a [RawList<'a, i64>]) = LONG_ARRAY_ID,
}
impl<'a> NbtList<'a> {
pub fn read(data: &mut Cursor<&'a [u8]>, depth: usize) -> Result<Self, Error> {
pub fn read(
data: &mut Cursor<&'a [u8]>,
alloc: &TagAllocator<'a>,
depth: usize,
) -> Result<Self, Error> {
if depth > MAX_DEPTH {
return Err(Error::MaxDepthExceeded);
}
@ -53,57 +57,93 @@ impl<'a> NbtList<'a> {
DOUBLE_ID => NbtList::Double(RawList::new(read_with_u32_length(data, 8)?)),
BYTE_ARRAY_ID => NbtList::ByteArray({
let length = read_u32(data)?;
// arbitrary number to prevent big allocations
let mut arrays = Vec::with_capacity(length.min(128) as usize);
let mut tags = alloc.get().unnamed_bytearray.start(depth);
for _ in 0..length {
arrays.push(read_u8_array(data)?)
let tag = match read_u8_array(data) {
Ok(tag) => tag,
Err(e) => {
alloc.get().unnamed_bytearray.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
arrays
alloc.get().unnamed_bytearray.finish(tags, depth)
}),
STRING_ID => NbtList::String({
let length = read_u32(data)?;
// arbitrary number to prevent big allocations
let mut strings = Vec::with_capacity(length.min(128) as usize);
let mut tags = alloc.get().unnamed_string.start(depth);
for _ in 0..length {
strings.push(read_string(data)?)
let tag = match read_string(data) {
Ok(tag) => tag,
Err(e) => {
alloc.get().unnamed_string.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
strings
alloc.get().unnamed_string.finish(tags, depth)
}),
LIST_ID => NbtList::List({
let length = read_u32(data)?;
// arbitrary number to prevent big allocations
let mut lists = Vec::with_capacity(length.min(128) as usize);
let mut tags = alloc.get().unnamed_list.start(depth);
for _ in 0..length {
lists.push(NbtList::read(data, depth + 1)?)
let tag = match NbtList::read(data, alloc, depth + 1) {
Ok(tag) => tag,
Err(e) => {
alloc.get().unnamed_list.finish(tags, depth);
return Err(e);
}
};
tags.push(tag)
}
lists
alloc.get().unnamed_list.finish(tags, depth)
}),
COMPOUND_ID => NbtList::Compound({
let length = read_u32(data)?;
// arbitrary number to prevent big allocations
let mut compounds = Vec::with_capacity(length.min(128) as usize);
let mut tags = alloc.get().unnamed_compound.start(depth);
for _ in 0..length {
compounds.push(NbtCompound::read_with_depth(data, depth + 1)?)
let tag = match NbtCompound::read_with_depth(data, alloc, depth + 1) {
Ok(tag) => tag,
Err(e) => {
alloc.get().unnamed_compound.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
compounds
alloc.get().unnamed_compound.finish(tags, depth)
}),
INT_ARRAY_ID => NbtList::IntArray({
let length = read_u32(data)?;
// arbitrary number to prevent big allocations
let mut arrays = Vec::with_capacity(length.min(128) as usize);
let mut tags = alloc.get().unnamed_intarray.start(depth);
for _ in 0..length {
arrays.push(read_int_array(data)?)
let tag = match read_int_array(data) {
Ok(tag) => tag,
Err(e) => {
alloc.get().unnamed_intarray.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
arrays
alloc.get().unnamed_intarray.finish(tags, depth)
}),
LONG_ARRAY_ID => NbtList::LongArray({
let length = read_u32(data)?;
// arbitrary number to prevent big allocations
let mut arrays = Vec::with_capacity(length.min(128) as usize);
let mut tags = alloc.get().unnamed_longarray.start(depth);
for _ in 0..length {
arrays.push(read_long_array(data)?)
let tag = match read_long_array(data) {
Ok(tag) => tag,
Err(e) => {
alloc.get().unnamed_longarray.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
arrays
alloc.get().unnamed_longarray.finish(tags, depth)
}),
_ => return Err(Error::UnknownTagId(tag_type)),
})
@ -118,7 +158,7 @@ impl<'a> NbtList<'a> {
unchecked_push(data, COMPOUND_ID);
unchecked_extend(data, &(compounds.len() as u32).to_be_bytes());
}
for compound in compounds {
for compound in *compounds {
compound.write(data);
}
return;
@ -155,13 +195,13 @@ impl<'a> NbtList<'a> {
}
NbtList::String(strings) => {
write_u32(data, strings.len() as u32);
for string in strings {
for string in *strings {
write_string(data, string);
}
}
NbtList::List(lists) => {
write_u32(data, lists.len() as u32);
for list in lists {
for list in *lists {
list.write(data);
}
}
@ -170,13 +210,13 @@ impl<'a> NbtList<'a> {
}
NbtList::IntArray(int_arrays) => {
write_u32(data, int_arrays.len() as u32);
for array in int_arrays {
for array in *int_arrays {
write_with_u32_length(data, 4, array.as_big_endian());
}
}
NbtList::LongArray(long_arrays) => {
write_u32(data, long_arrays.len() as u32);
for array in long_arrays {
for array in *long_arrays {
write_with_u32_length(data, 8, array.as_big_endian());
}
}
@ -229,7 +269,7 @@ impl<'a> NbtList<'a> {
_ => None,
}
}
pub fn byte_arrays(&self) -> Option<&Vec<&[u8]>> {
pub fn byte_arrays(&self) -> Option<&[&[u8]]> {
match self {
NbtList::ByteArray(byte_arrays) => Some(byte_arrays),
_ => None,

View file

@ -2,6 +2,7 @@
mod compound;
mod list;
mod tag_alloc;
use std::{io::Cursor, ops::Deref};
@ -17,13 +18,16 @@ use crate::{
Error, Mutf8Str,
};
use self::tag_alloc::TagAllocator;
pub use self::{compound::NbtCompound, list::NbtList};
/// A complete NBT container. This contains a name and a compound tag.
#[derive(Debug, PartialEq)]
#[derive(Debug)]
pub struct BaseNbt<'a> {
name: &'a Mutf8Str,
tag: NbtCompound<'a>,
// we need to keep this around so it's not deallocated
_tag_alloc: TagAllocator<'a>,
}
#[derive(Debug, PartialEq, Default)]
@ -43,10 +47,16 @@ impl<'a> Nbt<'a> {
if root_type != COMPOUND_ID {
return Err(Error::InvalidRootType(root_type));
}
let name = read_string(data)?;
let tag = NbtCompound::read_with_depth(data, 0)?;
let tag_alloc = TagAllocator::new();
Ok(Nbt::Some(BaseNbt { name, tag }))
let name = read_string(data)?;
let tag = NbtCompound::read_with_depth(data, &tag_alloc, 0)?;
Ok(Nbt::Some(BaseNbt {
name,
tag,
_tag_alloc: tag_alloc,
}))
}
pub fn write(&self, data: &mut Vec<u8>) {
@ -83,6 +93,15 @@ impl<'a> BaseNbt<'a> {
self.name
}
}
impl PartialEq for BaseNbt<'_> {
    /// Equality compares only the parsed `name` and root `tag`; the internal
    /// tag allocator is just backing storage and carries no extra state worth
    /// comparing.
    fn eq(&self, other: &Self) -> bool {
        // we don't need to compare the tag allocator since comparing `tag` will
        // still compare the values of the tags
        self.name == other.name && self.tag == other.tag
    }
}
impl<'a> Deref for BaseNbt<'a> {
type Target = NbtCompound<'a>;
@ -101,35 +120,45 @@ impl<'a> BaseNbt<'a> {
}
/// A single NBT tag.
#[repr(u8)]
#[derive(Debug, PartialEq, Clone)]
pub enum NbtTag<'a> {
Byte(i8) = BYTE_ID,
Short(i16) = SHORT_ID,
Int(i32) = INT_ID,
Long(i64) = LONG_ID,
Float(f32) = FLOAT_ID,
Double(f64) = DOUBLE_ID,
ByteArray(&'a [u8]) = BYTE_ARRAY_ID,
String(&'a Mutf8Str) = STRING_ID,
List(NbtList<'a>) = LIST_ID,
Compound(NbtCompound<'a>) = COMPOUND_ID,
IntArray(RawList<'a, i32>) = INT_ARRAY_ID,
LongArray(RawList<'a, i64>) = LONG_ARRAY_ID,
Byte(i8),
Short(i16),
Int(i32),
Long(i64),
Float(f32),
Double(f64),
ByteArray(&'a [u8]),
String(&'a Mutf8Str),
List(NbtList<'a>),
Compound(NbtCompound<'a>),
IntArray(RawList<'a, i32>),
LongArray(RawList<'a, i64>),
}
impl<'a> NbtTag<'a> {
/// Get the numerical ID of the tag type.
#[inline]
pub fn id(&self) -> u8 {
// SAFETY: Because `Self` is marked `repr(u8)`, its layout is a `repr(C)`
// `union` between `repr(C)` structs, each of which has the `u8`
// discriminant as its first field, so we can read the discriminant
// without offsetting the pointer.
unsafe { *<*const _>::from(self).cast::<u8>() }
match self {
NbtTag::Byte(_) => BYTE_ID,
NbtTag::Short(_) => SHORT_ID,
NbtTag::Int(_) => INT_ID,
NbtTag::Long(_) => LONG_ID,
NbtTag::Float(_) => FLOAT_ID,
NbtTag::Double(_) => DOUBLE_ID,
NbtTag::ByteArray(_) => BYTE_ARRAY_ID,
NbtTag::String(_) => STRING_ID,
NbtTag::List(_) => LIST_ID,
NbtTag::Compound(_) => COMPOUND_ID,
NbtTag::IntArray(_) => INT_ARRAY_ID,
NbtTag::LongArray(_) => LONG_ARRAY_ID,
}
}
#[inline(always)]
fn read_with_type(
data: &mut Cursor<&'a [u8]>,
alloc: &TagAllocator<'a>,
tag_type: u8,
depth: usize,
) -> Result<Self, Error> {
@ -154,9 +183,10 @@ impl<'a> NbtTag<'a> {
)),
BYTE_ARRAY_ID => Ok(NbtTag::ByteArray(read_with_u32_length(data, 1)?)),
STRING_ID => Ok(NbtTag::String(read_string(data)?)),
LIST_ID => Ok(NbtTag::List(NbtList::read(data, depth + 1)?)),
LIST_ID => Ok(NbtTag::List(NbtList::read(data, alloc, depth + 1)?)),
COMPOUND_ID => Ok(NbtTag::Compound(NbtCompound::read_with_depth(
data,
alloc,
depth + 1,
)?)),
INT_ARRAY_ID => Ok(NbtTag::IntArray(read_int_array(data)?)),
@ -165,17 +195,20 @@ impl<'a> NbtTag<'a> {
}
}
pub fn read(data: &mut Cursor<&'a [u8]>) -> Result<Self, Error> {
pub fn read(data: &mut Cursor<&'a [u8]>, alloc: &TagAllocator<'a>) -> Result<Self, Error> {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
Self::read_with_type(data, tag_type, 0)
Self::read_with_type(data, alloc, tag_type, 0)
}
pub fn read_optional(data: &mut Cursor<&'a [u8]>) -> Result<Option<Self>, Error> {
pub fn read_optional(
data: &mut Cursor<&'a [u8]>,
alloc: &TagAllocator<'a>,
) -> Result<Option<Self>, Error> {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
if tag_type == END_ID {
return Ok(None);
}
Ok(Some(Self::read_with_type(data, tag_type, 0)?))
Ok(Some(Self::read_with_type(data, alloc, tag_type, 0)?))
}
pub fn byte(&self) -> Option<i8> {
@ -417,4 +450,17 @@ mod tests {
}
assert_eq!(ints.len(), 1023);
}
#[test]
fn compound_eof() {
let mut data = Vec::new();
data.write_u8(COMPOUND_ID).unwrap(); // root type
data.write_u16::<BE>(0).unwrap(); // root name length
data.write_u8(COMPOUND_ID).unwrap(); // first element type
data.write_u16::<BE>(0).unwrap(); // first element name length
// eof
let res = Nbt::read(&mut Cursor::new(&data));
assert_eq!(res, Err(Error::UnexpectedEof));
}
}

View file

@ -0,0 +1,270 @@
//! Some tags, like compounds and arrays, contain other tags. The naive approach would be to just
//! use `Vec`s or `HashMap`s, but this is inefficient and leads to many small allocations.
//!
//! Instead, the idea for this is essentially that we'd have two big Vec for every tag (one for
//! named tags and one for unnamed tags), and then compounds/arrays simply contain a slice of this
//! vec.
//!
//! This almost works, but there are two main issues:
//! - compounds aren't length-prefixed, so we can't pre-allocate at the beginning of compounds for
//! the rest of that compound
//! - resizing a vec might move it in memory, invalidating all of our slices to it
//!
//! Solving the first problem isn't that hard, since we can have a separate vec for every "depth"
//! (so compounds in compounds don't share the same vec).
//! To solve the second problem, I chose to implement a special data structure
//! that relies on low-level allocations so we can guarantee that our allocations don't move in memory.
use std::{
alloc::{self, Layout},
cell::UnsafeCell,
fmt,
ptr::NonNull,
};
use crate::{raw_list::RawList, Mutf8Str};
use super::{NbtCompound, NbtList, NbtTag};
// this value appears to have the best results on my pc when testing with complex_player.dat
const MIN_ALLOC_SIZE: usize = 1024;
/// Wrapper that hands out mutable access to the underlying [`TagAllocatorImpl`]
/// through a shared reference, using `UnsafeCell` for interior mutability.
#[derive(Default)]
pub struct TagAllocator<'a>(UnsafeCell<TagAllocatorImpl<'a>>);
impl<'a> TagAllocator<'a> {
    /// Creates an allocator with no backing memory yet.
    pub fn new() -> Self {
        Self(UnsafeCell::new(TagAllocatorImpl::new()))
    }
    // shhhhh
    /// Returns a mutable reference to the allocator state from `&self`.
    ///
    /// NOTE(review): this mints `&mut` from `&` with no runtime guard, so two
    /// overlapping `get()` calls would alias a `&mut` — unsound under Rust's
    /// aliasing rules. The parsing code keeps each borrow short-lived, but
    /// confirm no caller holds two results of `get()` at once.
    #[allow(clippy::mut_from_ref)]
    pub fn get(&self) -> &mut TagAllocatorImpl<'a> {
        // SAFETY (as written): `UnsafeCell::get` never returns null, so the
        // `None` arm of `as_mut` is unreachable and `unwrap_unchecked` is fine
        // on that axis; the aliasing concern above still applies.
        unsafe { self.0.get().as_mut().unwrap_unchecked() }
    }
}
/// The actual allocator state: one [`IndividualTagAllocator`] per kind of
/// collection that gets built while parsing.
#[derive(Default)]
pub struct TagAllocatorImpl<'a> {
    // named (key, tag) pairs that live directly inside compounds
    pub named: IndividualTagAllocator<(&'a Mutf8Str, NbtTag<'a>)>,
    // so remember earlier when i said the depth thing is only necessary because compounds aren't
    // length prefixed? ... well soooo i decided to make arrays store per-depth separately too to
    // avoid exploits where an array with a big length is sent to force it to immediately allocate
    // a lot
    pub unnamed_list: IndividualTagAllocator<NbtList<'a>>,
    pub unnamed_compound: IndividualTagAllocator<NbtCompound<'a>>,
    pub unnamed_bytearray: IndividualTagAllocator<&'a [u8]>,
    pub unnamed_string: IndividualTagAllocator<&'a Mutf8Str>,
    pub unnamed_intarray: IndividualTagAllocator<RawList<'a, i32>>,
    pub unnamed_longarray: IndividualTagAllocator<RawList<'a, i64>>,
}
impl<'a> TagAllocatorImpl<'a> {
    /// Creates an empty allocator; no memory is allocated until tags are pushed.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Allocator for one kind of tag collection. Each nesting depth gets its own
/// growable allocation so nested collections don't interleave their elements.
pub struct IndividualTagAllocator<T> {
    // it's a vec because of the depth thing mentioned earlier, index in the vec = depth
    current: Vec<TagsAllocation<T>>,
    // we also have to keep track of old allocations so we can deallocate them later
    previous: Vec<Vec<TagsAllocation<T>>>,
}
impl<T> IndividualTagAllocator<T>
where
    T: Clone,
{
    /// Begins a new contiguous run of tags at the given depth, creating the
    /// per-depth slots on demand.
    pub fn start(&mut self, depth: usize) -> ContiguousTagsAllocator<T> {
        // grow the per-depth vecs until index `depth` is valid; compounds and
        // arrays share the depth counter, so several pushes may be needed
        while self.current.len() <= depth {
            self.current.push(Default::default());
            self.previous.push(Default::default());
        }
        start_allocating_tags(self.current[depth].clone())
    }
    /// Ends a run started by [`Self::start`], returning the slice of tags that
    /// were pushed during it.
    pub fn finish<'a>(&mut self, alloc: ContiguousTagsAllocator<T>, depth: usize) -> &'a [T] {
        let cur = &mut self.current[depth];
        let prev = &mut self.previous[depth];
        finish_allocating_tags(alloc, cur, prev)
    }
}
impl<T> Default for IndividualTagAllocator<T> {
    /// An allocator with no per-depth slots yet; they are created lazily by
    /// `start`.
    fn default() -> Self {
        Self {
            current: Vec::new(),
            previous: Vec::new(),
        }
    }
}
impl<T> Drop for IndividualTagAllocator<T> {
    /// Frees every allocation this allocator handed out: the live per-depth
    /// allocations plus all superseded ones retained in `previous`.
    fn drop(&mut self) {
        self.current.iter_mut().for_each(TagsAllocation::deallocate);
        self.previous
            .iter_mut()
            .flatten()
            .for_each(TagsAllocation::deallocate);
    }
}
/// Begins appending a contiguous run of tags onto the tail of `alloc`.
/// A zero-capacity `alloc` means no memory exists yet, so the run will own
/// whatever allocation `grow` creates.
fn start_allocating_tags<T>(alloc: TagsAllocation<T>) -> ContiguousTagsAllocator<T> {
    ContiguousTagsAllocator {
        is_new_allocation: alloc.cap == 0,
        alloc,
        size: 0,
    }
}
/// Ends a contiguous run: returns the slice of the `size` tags that were just
/// pushed (the tail of the backing allocation) and installs the possibly-moved
/// allocation as the new current one for this depth.
///
/// NOTE(review): the returned lifetime `'a` is unconstrained by any input, so
/// the caller is trusted to keep the owning `TagAllocator` alive for `'a`;
/// confirm every call site upholds that.
fn finish_allocating_tags<'a, T>(
    alloc: ContiguousTagsAllocator<T>,
    current_alloc: &mut TagsAllocation<T>,
    previous_allocs: &mut Vec<TagsAllocation<T>>,
) -> &'a [T] {
    // SAFETY (as written): the last `alloc.size` elements, starting at
    // `ptr + len - size`, were initialized by `push`/`extend_from_slice`.
    let slice = unsafe {
        std::slice::from_raw_parts(
            alloc
                .alloc
                .ptr
                .as_ptr()
                .add(alloc.alloc.len)
                .sub(alloc.size),
            alloc.size,
        )
    };
    // the freshly-finished allocation becomes the current one for this depth
    let previous_allocation_at_that_depth = std::mem::replace(current_alloc, alloc.alloc);
    // if this run started its own allocation, the allocation it displaced must
    // be remembered so it can still be deallocated on drop
    if alloc.is_new_allocation {
        previous_allocs.push(previous_allocation_at_that_depth);
    }
    slice
}
/// A raw growable buffer (pointer + capacity + initialized length) shared by
/// all the contiguous runs appended at one depth.
#[derive(Clone)]
pub struct TagsAllocation<T> {
    ptr: NonNull<T>,
    cap: usize,
    len: usize,
}
impl<T> Default for TagsAllocation<T> {
    /// An empty allocation: dangling pointer, zero capacity, zero length.
    /// No heap memory is allocated until the first grow.
    fn default() -> Self {
        Self {
            ptr: NonNull::dangling(),
            cap: 0,
            len: 0,
        }
    }
}
impl<T> TagsAllocation<T> {
    /// Drops the `len` initialized elements, then frees the backing memory.
    /// A zero-capacity allocation was never heap-allocated (dangling pointer),
    /// so it is skipped entirely.
    fn deallocate(&mut self) {
        if self.cap == 0 {
            return;
        }
        // call drop on the tags too
        unsafe {
            std::ptr::drop_in_place(std::slice::from_raw_parts_mut(
                self.ptr.as_ptr().cast::<T>(),
                self.len,
            ));
        }
        // SAFETY (as written): `ptr` was allocated with exactly this array
        // layout (`cap` elements of `T`), as required by `dealloc`.
        unsafe {
            alloc::dealloc(
                self.ptr.as_ptr().cast(),
                Layout::array::<T>(self.cap).unwrap(),
            )
        }
    }
}
// this is created when we start allocating a compound tag
pub struct ContiguousTagsAllocator<T> {
    // a clone of the depth's current allocation; this run's elements are
    // appended at its tail
    alloc: TagsAllocation<T>,
    /// whether we created a new allocation for this compound (as opposed to reusing an existing
    /// one).
    /// this is used to determine whether we're allowed to deallocate it when growing, and whether
    /// we should add this allocation to `all_allocations`
    is_new_allocation: bool,
    /// the size of this individual compound allocation. the size of the full allocation is in
    /// `alloc.len`.
    size: usize,
}
impl<T> ContiguousTagsAllocator<T> {
    /// Grows the backing storage so at least one more element fits.
    ///
    /// If this run owns its allocation, it is reallocated in place (doubling,
    /// with a floor of `MIN_ALLOC_SIZE` elements). If the allocation is shared
    /// with earlier runs (or empty), a fresh allocation is made and only this
    /// run's trailing `size` elements are moved into it.
    #[inline(never)]
    fn grow(&mut self) {
        let new_cap = if self.is_new_allocation {
            // this makes sure we don't allocate 0 bytes
            std::cmp::max(self.alloc.cap * 2, MIN_ALLOC_SIZE)
        } else {
            // reuse the previous cap, since it's not unlikely that we'll have
            // another compound with a similar size
            self.alloc.cap
        };
        let new_layout = Layout::array::<T>(new_cap).unwrap();
        let new_ptr = if self.is_new_allocation && self.alloc.cap != 0 {
            // we own this allocation, so realloc it in place. note that
            // `realloc` takes the new size in *bytes* (not elements), and it
            // already preserves the existing contents, so no copy is needed
            // afterwards — copying from the old pointer after realloc would
            // read freed memory.
            let old_layout = Layout::array::<T>(self.alloc.cap).unwrap();
            unsafe {
                alloc::realloc(
                    self.alloc.ptr.as_ptr() as *mut u8,
                    old_layout,
                    new_layout.size(),
                ) as *mut T
            }
        } else {
            // the allocation is empty or shared with earlier runs: allocate
            // fresh memory and move this run's elements — the last `size`
            // ones, ending at `len` — into it. the originals are left behind
            // but never dropped, because the shared allocation's recorded
            // `len` (held by the depth's `current` slot) does not include them.
            let new_ptr = unsafe { alloc::alloc(new_layout) } as *mut T;
            if !new_ptr.is_null() && self.size > 0 {
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        self.alloc.ptr.as_ptr().add(self.alloc.len - self.size),
                        new_ptr,
                        self.size,
                    );
                }
            }
            self.is_new_allocation = true;
            new_ptr
        };
        // panic on allocation failure rather than continuing with null
        self.alloc.ptr = NonNull::new(new_ptr).unwrap();
        self.alloc.cap = new_cap;
        // the (re)allocation now holds exactly this run's elements
        self.alloc.len = self.size;
    }
    /// Appends all elements of `slice` to the current run, growing as needed.
    ///
    /// NOTE(review): this bitwise-copies out of a borrowed slice, duplicating
    /// ownership for non-`Copy` `T`; confirm callers only pass data whose
    /// originals are forgotten or never dropped.
    #[inline]
    pub fn extend_from_slice(&mut self, slice: &[T]) {
        while self.alloc.len + slice.len() > self.alloc.cap {
            self.grow();
        }
        // copy the slice
        unsafe {
            let end = self.alloc.ptr.as_ptr().add(self.alloc.len);
            std::ptr::copy_nonoverlapping(slice.as_ptr(), end, slice.len());
        }
        self.alloc.len += slice.len();
        self.size += slice.len();
    }
    /// Appends one element to the current run, growing as needed.
    #[inline]
    pub fn push(&mut self, value: T) {
        // check if we need to reallocate
        if self.alloc.len == self.alloc.cap {
            self.grow();
        }
        // push the new tag
        unsafe {
            let end = self.alloc.ptr.as_ptr().add(self.alloc.len);
            std::ptr::write(end, value);
        }
        self.alloc.len += 1;
        self.size += 1;
    }
}
impl<'a> fmt::Debug for TagAllocator<'a> {
    /// Intentionally opaque: the allocator's internals are raw pointers with
    /// nothing useful to print.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("TagAllocator")
    }
}

View file

@ -2,7 +2,7 @@ use thiserror::Error;
use crate::common::MAX_DEPTH;
#[derive(Error, Debug)]
#[derive(Error, Debug, PartialEq)]
pub enum Error {
#[error("Invalid root type {0}")]
InvalidRootType(u8),