mirror of
https://github.com/azalea-rs/simdnbt.git
synced 2025-08-02 07:26:04 +00:00
fix ub on error and add a test for that
This commit is contained in:
parent
a6f47171bf
commit
97b9d5c76a
6 changed files with 75 additions and 62 deletions
|
@ -87,11 +87,12 @@ Here's a benchmark comparing Simdnbt against a few of the other fastest NBT crat
|
|||
| [hematite_nbt](https://docs.rs/hematite-nbt/latest/nbt/) | 108.91 MiB/s |
|
||||
|
||||
And for writing `complex_player.dat`:
|
||||
| Library | Throughput |
|
||||
| ----------------| ------------ |
|
||||
|
||||
| Library | Throughput |
|
||||
| --------------- | ------------ |
|
||||
| simdnbt::borrow | 2.5914 GiB/s |
|
||||
| azalea_nbt | 2.1096 GiB/s |
|
||||
| simdnbt::owned | 1.9508 GiB/s |
|
||||
| azalea_nbt | 2.1096 GiB/s |
|
||||
| simdnbt::owned | 1.9508 GiB/s |
|
||||
| graphite_binary | 1.7745 GiB/s |
|
||||
|
||||
The tables above were made from the [compare benchmark](https://github.com/azalea-rs/simdnbt/tree/master/simdnbt/benches) in this repo.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::{cell::UnsafeCell, io::Cursor};
|
||||
use std::io::Cursor;
|
||||
|
||||
use byteorder::ReadBytesExt;
|
||||
|
||||
|
@ -19,27 +19,28 @@ pub struct NbtCompound<'a> {
|
|||
}
|
||||
|
||||
impl<'a> NbtCompound<'a> {
|
||||
pub fn read(
|
||||
data: &mut Cursor<&'a [u8]>,
|
||||
alloc: &UnsafeCell<TagAllocator<'a>>,
|
||||
) -> Result<Self, Error> {
|
||||
pub fn read(data: &mut Cursor<&'a [u8]>, alloc: &TagAllocator<'a>) -> Result<Self, Error> {
|
||||
Self::read_with_depth(data, alloc, 0)
|
||||
}
|
||||
|
||||
pub fn read_with_depth(
|
||||
data: &mut Cursor<&'a [u8]>,
|
||||
alloc: &UnsafeCell<TagAllocator<'a>>,
|
||||
alloc: &TagAllocator<'a>,
|
||||
depth: usize,
|
||||
) -> Result<Self, Error> {
|
||||
if depth > MAX_DEPTH {
|
||||
return Err(Error::MaxDepthExceeded);
|
||||
}
|
||||
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap_unchecked() };
|
||||
|
||||
let mut tags = alloc_mut.named.start(depth);
|
||||
let mut tags = alloc.get().named.start(depth);
|
||||
loop {
|
||||
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
|
||||
let tag_type = match data.read_u8() {
|
||||
Ok(tag_type) => tag_type,
|
||||
Err(_) => {
|
||||
alloc.get().named.finish(tags, depth);
|
||||
return Err(Error::UnexpectedEof);
|
||||
}
|
||||
};
|
||||
if tag_type == END_ID {
|
||||
break;
|
||||
}
|
||||
|
@ -47,7 +48,7 @@ impl<'a> NbtCompound<'a> {
|
|||
let tag_name = match read_string(data) {
|
||||
Ok(name) => name,
|
||||
Err(_) => {
|
||||
alloc_mut.named.finish(tags, depth);
|
||||
alloc.get().named.finish(tags, depth);
|
||||
// the only error read_string can return is UnexpectedEof, so this makes it
|
||||
// slightly faster
|
||||
return Err(Error::UnexpectedEof);
|
||||
|
@ -56,14 +57,13 @@ impl<'a> NbtCompound<'a> {
|
|||
let tag = match NbtTag::read_with_type(data, alloc, tag_type, depth) {
|
||||
Ok(tag) => tag,
|
||||
Err(e) => {
|
||||
alloc_mut.named.finish(tags, depth);
|
||||
alloc.get().named.finish(tags, depth);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
tags.push((tag_name, tag));
|
||||
}
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap_unchecked() };
|
||||
let values = alloc_mut.named.finish(tags, depth);
|
||||
let values = alloc.get().named.finish(tags, depth);
|
||||
|
||||
Ok(Self { values })
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::{cell::UnsafeCell, io::Cursor};
|
||||
use std::io::Cursor;
|
||||
|
||||
use byteorder::ReadBytesExt;
|
||||
|
||||
|
@ -37,7 +37,7 @@ pub enum NbtList<'a> {
|
|||
impl<'a> NbtList<'a> {
|
||||
pub fn read(
|
||||
data: &mut Cursor<&'a [u8]>,
|
||||
alloc: &UnsafeCell<TagAllocator<'a>>,
|
||||
alloc: &TagAllocator<'a>,
|
||||
depth: usize,
|
||||
) -> Result<Self, Error> {
|
||||
if depth > MAX_DEPTH {
|
||||
|
@ -57,105 +57,93 @@ impl<'a> NbtList<'a> {
|
|||
DOUBLE_ID => NbtList::Double(RawList::new(read_with_u32_length(data, 8)?)),
|
||||
BYTE_ARRAY_ID => NbtList::ByteArray({
|
||||
let length = read_u32(data)?;
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
let mut tags = alloc_mut.unnamed_bytearray.start(depth);
|
||||
let mut tags = alloc.get().unnamed_bytearray.start(depth);
|
||||
for _ in 0..length {
|
||||
let tag = match read_u8_array(data) {
|
||||
Ok(tag) => tag,
|
||||
Err(e) => {
|
||||
alloc_mut.unnamed_bytearray.finish(tags, depth);
|
||||
alloc.get().unnamed_bytearray.finish(tags, depth);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
tags.push(tag);
|
||||
}
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
alloc_mut.unnamed_bytearray.finish(tags, depth)
|
||||
alloc.get().unnamed_bytearray.finish(tags, depth)
|
||||
}),
|
||||
STRING_ID => NbtList::String({
|
||||
let length = read_u32(data)?;
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
let mut tags = alloc_mut.unnamed_string.start(depth);
|
||||
let mut tags = alloc.get().unnamed_string.start(depth);
|
||||
for _ in 0..length {
|
||||
let tag = match read_string(data) {
|
||||
Ok(tag) => tag,
|
||||
Err(e) => {
|
||||
alloc_mut.unnamed_string.finish(tags, depth);
|
||||
alloc.get().unnamed_string.finish(tags, depth);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
tags.push(tag);
|
||||
}
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
alloc_mut.unnamed_string.finish(tags, depth)
|
||||
alloc.get().unnamed_string.finish(tags, depth)
|
||||
}),
|
||||
LIST_ID => NbtList::List({
|
||||
let length = read_u32(data)?;
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
let mut tags = alloc_mut.unnamed_list.start(depth);
|
||||
let mut tags = alloc.get().unnamed_list.start(depth);
|
||||
for _ in 0..length {
|
||||
let tag = match NbtList::read(data, alloc, depth + 1) {
|
||||
Ok(tag) => tag,
|
||||
Err(e) => {
|
||||
alloc_mut.unnamed_list.finish(tags, depth);
|
||||
alloc.get().unnamed_list.finish(tags, depth);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
tags.push(tag)
|
||||
}
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
alloc_mut.unnamed_list.finish(tags, depth)
|
||||
alloc.get().unnamed_list.finish(tags, depth)
|
||||
}),
|
||||
COMPOUND_ID => NbtList::Compound({
|
||||
let length = read_u32(data)?;
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
let mut tags = alloc_mut.unnamed_compound.start(depth);
|
||||
let mut tags = alloc.get().unnamed_compound.start(depth);
|
||||
for _ in 0..length {
|
||||
let tag = match NbtCompound::read_with_depth(data, alloc, depth + 1) {
|
||||
Ok(tag) => tag,
|
||||
Err(e) => {
|
||||
alloc_mut.unnamed_compound.finish(tags, depth);
|
||||
alloc.get().unnamed_compound.finish(tags, depth);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
tags.push(tag);
|
||||
}
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
alloc_mut.unnamed_compound.finish(tags, depth)
|
||||
alloc.get().unnamed_compound.finish(tags, depth)
|
||||
}),
|
||||
INT_ARRAY_ID => NbtList::IntArray({
|
||||
let length = read_u32(data)?;
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
let mut tags = alloc_mut.unnamed_intarray.start(depth);
|
||||
let mut tags = alloc.get().unnamed_intarray.start(depth);
|
||||
for _ in 0..length {
|
||||
let tag = match read_int_array(data) {
|
||||
Ok(tag) => tag,
|
||||
Err(e) => {
|
||||
alloc_mut.unnamed_intarray.finish(tags, depth);
|
||||
alloc.get().unnamed_intarray.finish(tags, depth);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
tags.push(tag);
|
||||
}
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
alloc_mut.unnamed_intarray.finish(tags, depth)
|
||||
alloc.get().unnamed_intarray.finish(tags, depth)
|
||||
}),
|
||||
LONG_ARRAY_ID => NbtList::LongArray({
|
||||
let length = read_u32(data)?;
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
let mut tags = alloc_mut.unnamed_longarray.start(depth);
|
||||
let mut tags = alloc.get().unnamed_longarray.start(depth);
|
||||
for _ in 0..length {
|
||||
let tag = match read_long_array(data) {
|
||||
Ok(tag) => tag,
|
||||
Err(e) => {
|
||||
alloc_mut.unnamed_longarray.finish(tags, depth);
|
||||
alloc.get().unnamed_longarray.finish(tags, depth);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
tags.push(tag);
|
||||
}
|
||||
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
|
||||
alloc_mut.unnamed_longarray.finish(tags, depth)
|
||||
alloc.get().unnamed_longarray.finish(tags, depth)
|
||||
}),
|
||||
_ => return Err(Error::UnknownTagId(tag_type)),
|
||||
})
|
||||
|
|
|
@ -4,7 +4,7 @@ mod compound;
|
|||
mod list;
|
||||
mod tag_alloc;
|
||||
|
||||
use std::{cell::UnsafeCell, io::Cursor, ops::Deref};
|
||||
use std::{io::Cursor, ops::Deref};
|
||||
|
||||
use byteorder::{ReadBytesExt, BE};
|
||||
|
||||
|
@ -47,7 +47,7 @@ impl<'a> Nbt<'a> {
|
|||
if root_type != COMPOUND_ID {
|
||||
return Err(Error::InvalidRootType(root_type));
|
||||
}
|
||||
let tag_alloc = UnsafeCell::new(TagAllocator::new());
|
||||
let tag_alloc = TagAllocator::new();
|
||||
|
||||
let name = read_string(data)?;
|
||||
let tag = NbtCompound::read_with_depth(data, &tag_alloc, 0)?;
|
||||
|
@ -55,7 +55,7 @@ impl<'a> Nbt<'a> {
|
|||
Ok(Nbt::Some(BaseNbt {
|
||||
name,
|
||||
tag,
|
||||
_tag_alloc: tag_alloc.into_inner(),
|
||||
_tag_alloc: tag_alloc,
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -149,7 +149,7 @@ impl<'a> NbtTag<'a> {
|
|||
|
||||
fn read_with_type(
|
||||
data: &mut Cursor<&'a [u8]>,
|
||||
alloc: &UnsafeCell<TagAllocator<'a>>,
|
||||
alloc: &TagAllocator<'a>,
|
||||
tag_type: u8,
|
||||
depth: usize,
|
||||
) -> Result<Self, Error> {
|
||||
|
@ -186,17 +186,14 @@ impl<'a> NbtTag<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn read(
|
||||
data: &mut Cursor<&'a [u8]>,
|
||||
alloc: &UnsafeCell<TagAllocator<'a>>,
|
||||
) -> Result<Self, Error> {
|
||||
pub fn read(data: &mut Cursor<&'a [u8]>, alloc: &TagAllocator<'a>) -> Result<Self, Error> {
|
||||
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
|
||||
Self::read_with_type(data, alloc, tag_type, 0)
|
||||
}
|
||||
|
||||
pub fn read_optional(
|
||||
data: &mut Cursor<&'a [u8]>,
|
||||
alloc: &UnsafeCell<TagAllocator<'a>>,
|
||||
alloc: &TagAllocator<'a>,
|
||||
) -> Result<Option<Self>, Error> {
|
||||
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
|
||||
if tag_type == END_ID {
|
||||
|
@ -444,4 +441,17 @@ mod tests {
|
|||
}
|
||||
assert_eq!(ints.len(), 1023);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compound_eof() {
|
||||
let mut data = Vec::new();
|
||||
data.write_u8(COMPOUND_ID).unwrap(); // root type
|
||||
data.write_u16::<BE>(0).unwrap(); // root name length
|
||||
data.write_u8(COMPOUND_ID).unwrap(); // first element type
|
||||
data.write_u16::<BE>(0).unwrap(); // first element name length
|
||||
// eof
|
||||
|
||||
let res = Nbt::read(&mut Cursor::new(&data));
|
||||
assert_eq!(res, Err(Error::UnexpectedEof));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
use std::{
|
||||
alloc::{self, Layout},
|
||||
cell::UnsafeCell,
|
||||
fmt,
|
||||
ptr::NonNull,
|
||||
};
|
||||
|
@ -29,7 +30,20 @@ use super::{NbtCompound, NbtList, NbtTag};
|
|||
const MIN_ALLOC_SIZE: usize = 1024;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct TagAllocator<'a> {
|
||||
pub struct TagAllocator<'a>(UnsafeCell<TagAllocatorImpl<'a>>);
|
||||
impl<'a> TagAllocator<'a> {
|
||||
pub fn new() -> Self {
|
||||
Self(UnsafeCell::new(TagAllocatorImpl::new()))
|
||||
}
|
||||
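// Hands out `&mut` from `&self` via the inner UnsafeCell (hence the lint allow below); callers must not hold two of these references at once.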
|
||||
#[allow(clippy::mut_from_ref)]
|
||||
pub fn get(&self) -> &mut TagAllocatorImpl<'a> {
|
||||
unsafe { self.0.get().as_mut().unwrap_unchecked() }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct TagAllocatorImpl<'a> {
|
||||
pub named: IndividualTagAllocator<(&'a Mutf8Str, NbtTag<'a>)>,
|
||||
|
||||
// so remember earlier when i said the depth thing is only necessary because compounds aren't
|
||||
|
@ -44,7 +58,7 @@ pub struct TagAllocator<'a> {
|
|||
pub unnamed_longarray: IndividualTagAllocator<RawList<'a, i64>>,
|
||||
}
|
||||
|
||||
impl<'a> TagAllocator<'a> {
|
||||
impl<'a> TagAllocatorImpl<'a> {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use thiserror::Error;
|
|||
|
||||
use crate::common::MAX_DEPTH;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[derive(Error, Debug, PartialEq)]
|
||||
pub enum Error {
|
||||
#[error("Invalid root type {0}")]
|
||||
InvalidRootType(u8),
|
||||
|
|
Loading…
Add table
Reference in a new issue