1
0
Fork 0
mirror of https://github.com/azalea-rs/simdnbt.git synced 2025-08-02 07:26:04 +00:00

fix ub on error and add a test for that

This commit is contained in:
mat 2024-05-13 00:39:55 +00:00
parent a6f47171bf
commit 97b9d5c76a
6 changed files with 75 additions and 62 deletions

View file

@ -87,11 +87,12 @@ Here's a benchmark comparing Simdnbt against a few of the other fastest NBT crat
| [hematite_nbt](https://docs.rs/hematite-nbt/latest/nbt/) | 108.91 MiB/s | | [hematite_nbt](https://docs.rs/hematite-nbt/latest/nbt/) | 108.91 MiB/s |
And for writing `complex_player.dat`: And for writing `complex_player.dat`:
| Library | Throughput |
| ----------------| ------------ | | Library | Throughput |
| --------------- | ------------ |
| simdnbt::borrow | 2.5914 GiB/s | | simdnbt::borrow | 2.5914 GiB/s |
| azalea_nbt | 2.1096 GiB/s | | azalea_nbt | 2.1096 GiB/s |
| simdnbt::owned | 1.9508 GiB/s | | simdnbt::owned | 1.9508 GiB/s |
| graphite_binary | 1.7745 GiB/s | | graphite_binary | 1.7745 GiB/s |
The tables above were made from the [compare benchmark](https://github.com/azalea-rs/simdnbt/tree/master/simdnbt/benches) in this repo. The tables above were made from the [compare benchmark](https://github.com/azalea-rs/simdnbt/tree/master/simdnbt/benches) in this repo.

View file

@ -1,4 +1,4 @@
use std::{cell::UnsafeCell, io::Cursor}; use std::io::Cursor;
use byteorder::ReadBytesExt; use byteorder::ReadBytesExt;
@ -19,27 +19,28 @@ pub struct NbtCompound<'a> {
} }
impl<'a> NbtCompound<'a> { impl<'a> NbtCompound<'a> {
pub fn read( pub fn read(data: &mut Cursor<&'a [u8]>, alloc: &TagAllocator<'a>) -> Result<Self, Error> {
data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>,
) -> Result<Self, Error> {
Self::read_with_depth(data, alloc, 0) Self::read_with_depth(data, alloc, 0)
} }
pub fn read_with_depth( pub fn read_with_depth(
data: &mut Cursor<&'a [u8]>, data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>, alloc: &TagAllocator<'a>,
depth: usize, depth: usize,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
if depth > MAX_DEPTH { if depth > MAX_DEPTH {
return Err(Error::MaxDepthExceeded); return Err(Error::MaxDepthExceeded);
} }
let alloc_mut = unsafe { alloc.get().as_mut().unwrap_unchecked() }; let mut tags = alloc.get().named.start(depth);
let mut tags = alloc_mut.named.start(depth);
loop { loop {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?; let tag_type = match data.read_u8() {
Ok(tag_type) => tag_type,
Err(_) => {
alloc.get().named.finish(tags, depth);
return Err(Error::UnexpectedEof);
}
};
if tag_type == END_ID { if tag_type == END_ID {
break; break;
} }
@ -47,7 +48,7 @@ impl<'a> NbtCompound<'a> {
let tag_name = match read_string(data) { let tag_name = match read_string(data) {
Ok(name) => name, Ok(name) => name,
Err(_) => { Err(_) => {
alloc_mut.named.finish(tags, depth); alloc.get().named.finish(tags, depth);
// the only error read_string can return is UnexpectedEof, so this makes it // the only error read_string can return is UnexpectedEof, so this makes it
// slightly faster // slightly faster
return Err(Error::UnexpectedEof); return Err(Error::UnexpectedEof);
@ -56,14 +57,13 @@ impl<'a> NbtCompound<'a> {
let tag = match NbtTag::read_with_type(data, alloc, tag_type, depth) { let tag = match NbtTag::read_with_type(data, alloc, tag_type, depth) {
Ok(tag) => tag, Ok(tag) => tag,
Err(e) => { Err(e) => {
alloc_mut.named.finish(tags, depth); alloc.get().named.finish(tags, depth);
return Err(e); return Err(e);
} }
}; };
tags.push((tag_name, tag)); tags.push((tag_name, tag));
} }
let alloc_mut = unsafe { alloc.get().as_mut().unwrap_unchecked() }; let values = alloc.get().named.finish(tags, depth);
let values = alloc_mut.named.finish(tags, depth);
Ok(Self { values }) Ok(Self { values })
} }

View file

@ -1,4 +1,4 @@
use std::{cell::UnsafeCell, io::Cursor}; use std::io::Cursor;
use byteorder::ReadBytesExt; use byteorder::ReadBytesExt;
@ -37,7 +37,7 @@ pub enum NbtList<'a> {
impl<'a> NbtList<'a> { impl<'a> NbtList<'a> {
pub fn read( pub fn read(
data: &mut Cursor<&'a [u8]>, data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>, alloc: &TagAllocator<'a>,
depth: usize, depth: usize,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
if depth > MAX_DEPTH { if depth > MAX_DEPTH {
@ -57,105 +57,93 @@ impl<'a> NbtList<'a> {
DOUBLE_ID => NbtList::Double(RawList::new(read_with_u32_length(data, 8)?)), DOUBLE_ID => NbtList::Double(RawList::new(read_with_u32_length(data, 8)?)),
BYTE_ARRAY_ID => NbtList::ByteArray({ BYTE_ARRAY_ID => NbtList::ByteArray({
let length = read_u32(data)?; let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; let mut tags = alloc.get().unnamed_bytearray.start(depth);
let mut tags = alloc_mut.unnamed_bytearray.start(depth);
for _ in 0..length { for _ in 0..length {
let tag = match read_u8_array(data) { let tag = match read_u8_array(data) {
Ok(tag) => tag, Ok(tag) => tag,
Err(e) => { Err(e) => {
alloc_mut.unnamed_bytearray.finish(tags, depth); alloc.get().unnamed_bytearray.finish(tags, depth);
return Err(e); return Err(e);
} }
}; };
tags.push(tag); tags.push(tag);
} }
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; alloc.get().unnamed_bytearray.finish(tags, depth)
alloc_mut.unnamed_bytearray.finish(tags, depth)
}), }),
STRING_ID => NbtList::String({ STRING_ID => NbtList::String({
let length = read_u32(data)?; let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; let mut tags = alloc.get().unnamed_string.start(depth);
let mut tags = alloc_mut.unnamed_string.start(depth);
for _ in 0..length { for _ in 0..length {
let tag = match read_string(data) { let tag = match read_string(data) {
Ok(tag) => tag, Ok(tag) => tag,
Err(e) => { Err(e) => {
alloc_mut.unnamed_string.finish(tags, depth); alloc.get().unnamed_string.finish(tags, depth);
return Err(e); return Err(e);
} }
}; };
tags.push(tag); tags.push(tag);
} }
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; alloc.get().unnamed_string.finish(tags, depth)
alloc_mut.unnamed_string.finish(tags, depth)
}), }),
LIST_ID => NbtList::List({ LIST_ID => NbtList::List({
let length = read_u32(data)?; let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; let mut tags = alloc.get().unnamed_list.start(depth);
let mut tags = alloc_mut.unnamed_list.start(depth);
for _ in 0..length { for _ in 0..length {
let tag = match NbtList::read(data, alloc, depth + 1) { let tag = match NbtList::read(data, alloc, depth + 1) {
Ok(tag) => tag, Ok(tag) => tag,
Err(e) => { Err(e) => {
alloc_mut.unnamed_list.finish(tags, depth); alloc.get().unnamed_list.finish(tags, depth);
return Err(e); return Err(e);
} }
}; };
tags.push(tag) tags.push(tag)
} }
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; alloc.get().unnamed_list.finish(tags, depth)
alloc_mut.unnamed_list.finish(tags, depth)
}), }),
COMPOUND_ID => NbtList::Compound({ COMPOUND_ID => NbtList::Compound({
let length = read_u32(data)?; let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; let mut tags = alloc.get().unnamed_compound.start(depth);
let mut tags = alloc_mut.unnamed_compound.start(depth);
for _ in 0..length { for _ in 0..length {
let tag = match NbtCompound::read_with_depth(data, alloc, depth + 1) { let tag = match NbtCompound::read_with_depth(data, alloc, depth + 1) {
Ok(tag) => tag, Ok(tag) => tag,
Err(e) => { Err(e) => {
alloc_mut.unnamed_compound.finish(tags, depth); alloc.get().unnamed_compound.finish(tags, depth);
return Err(e); return Err(e);
} }
}; };
tags.push(tag); tags.push(tag);
} }
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; alloc.get().unnamed_compound.finish(tags, depth)
alloc_mut.unnamed_compound.finish(tags, depth)
}), }),
INT_ARRAY_ID => NbtList::IntArray({ INT_ARRAY_ID => NbtList::IntArray({
let length = read_u32(data)?; let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; let mut tags = alloc.get().unnamed_intarray.start(depth);
let mut tags = alloc_mut.unnamed_intarray.start(depth);
for _ in 0..length { for _ in 0..length {
let tag = match read_int_array(data) { let tag = match read_int_array(data) {
Ok(tag) => tag, Ok(tag) => tag,
Err(e) => { Err(e) => {
alloc_mut.unnamed_intarray.finish(tags, depth); alloc.get().unnamed_intarray.finish(tags, depth);
return Err(e); return Err(e);
} }
}; };
tags.push(tag); tags.push(tag);
} }
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; alloc.get().unnamed_intarray.finish(tags, depth)
alloc_mut.unnamed_intarray.finish(tags, depth)
}), }),
LONG_ARRAY_ID => NbtList::LongArray({ LONG_ARRAY_ID => NbtList::LongArray({
let length = read_u32(data)?; let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; let mut tags = alloc.get().unnamed_longarray.start(depth);
let mut tags = alloc_mut.unnamed_longarray.start(depth);
for _ in 0..length { for _ in 0..length {
let tag = match read_long_array(data) { let tag = match read_long_array(data) {
Ok(tag) => tag, Ok(tag) => tag,
Err(e) => { Err(e) => {
alloc_mut.unnamed_longarray.finish(tags, depth); alloc.get().unnamed_longarray.finish(tags, depth);
return Err(e); return Err(e);
} }
}; };
tags.push(tag); tags.push(tag);
} }
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() }; alloc.get().unnamed_longarray.finish(tags, depth)
alloc_mut.unnamed_longarray.finish(tags, depth)
}), }),
_ => return Err(Error::UnknownTagId(tag_type)), _ => return Err(Error::UnknownTagId(tag_type)),
}) })

View file

@ -4,7 +4,7 @@ mod compound;
mod list; mod list;
mod tag_alloc; mod tag_alloc;
use std::{cell::UnsafeCell, io::Cursor, ops::Deref}; use std::{io::Cursor, ops::Deref};
use byteorder::{ReadBytesExt, BE}; use byteorder::{ReadBytesExt, BE};
@ -47,7 +47,7 @@ impl<'a> Nbt<'a> {
if root_type != COMPOUND_ID { if root_type != COMPOUND_ID {
return Err(Error::InvalidRootType(root_type)); return Err(Error::InvalidRootType(root_type));
} }
let tag_alloc = UnsafeCell::new(TagAllocator::new()); let tag_alloc = TagAllocator::new();
let name = read_string(data)?; let name = read_string(data)?;
let tag = NbtCompound::read_with_depth(data, &tag_alloc, 0)?; let tag = NbtCompound::read_with_depth(data, &tag_alloc, 0)?;
@ -55,7 +55,7 @@ impl<'a> Nbt<'a> {
Ok(Nbt::Some(BaseNbt { Ok(Nbt::Some(BaseNbt {
name, name,
tag, tag,
_tag_alloc: tag_alloc.into_inner(), _tag_alloc: tag_alloc,
})) }))
} }
@ -149,7 +149,7 @@ impl<'a> NbtTag<'a> {
fn read_with_type( fn read_with_type(
data: &mut Cursor<&'a [u8]>, data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>, alloc: &TagAllocator<'a>,
tag_type: u8, tag_type: u8,
depth: usize, depth: usize,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
@ -186,17 +186,14 @@ impl<'a> NbtTag<'a> {
} }
} }
pub fn read( pub fn read(data: &mut Cursor<&'a [u8]>, alloc: &TagAllocator<'a>) -> Result<Self, Error> {
data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>,
) -> Result<Self, Error> {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?; let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
Self::read_with_type(data, alloc, tag_type, 0) Self::read_with_type(data, alloc, tag_type, 0)
} }
pub fn read_optional( pub fn read_optional(
data: &mut Cursor<&'a [u8]>, data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>, alloc: &TagAllocator<'a>,
) -> Result<Option<Self>, Error> { ) -> Result<Option<Self>, Error> {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?; let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
if tag_type == END_ID { if tag_type == END_ID {
@ -444,4 +441,17 @@ mod tests {
} }
assert_eq!(ints.len(), 1023); assert_eq!(ints.len(), 1023);
} }
/// Regression test: input that ends right after a nested compound's header must be reported
/// as `Error::UnexpectedEof`.
#[test]
fn compound_eof() {
    // root type
    let mut data = vec![COMPOUND_ID];
    // root name length (big-endian u16, zero)
    data.extend_from_slice(&0u16.to_be_bytes());
    // first element type
    data.push(COMPOUND_ID);
    // first element name length (big-endian u16, zero)
    data.extend_from_slice(&0u16.to_be_bytes());
    // eof: nothing follows the nested compound's name

    assert_eq!(
        Nbt::read(&mut Cursor::new(&data)),
        Err(Error::UnexpectedEof)
    );
}
} }

View file

@ -17,6 +17,7 @@
use std::{ use std::{
alloc::{self, Layout}, alloc::{self, Layout},
cell::UnsafeCell,
fmt, fmt,
ptr::NonNull, ptr::NonNull,
}; };
@ -29,7 +30,20 @@ use super::{NbtCompound, NbtList, NbtTag};
const MIN_ALLOC_SIZE: usize = 1024; const MIN_ALLOC_SIZE: usize = 1024;
#[derive(Default)] #[derive(Default)]
pub struct TagAllocator<'a> { pub struct TagAllocator<'a>(UnsafeCell<TagAllocatorImpl<'a>>);
impl<'a> TagAllocator<'a> {
    /// Creates an allocator wrapping a freshly constructed [`TagAllocatorImpl`].
    pub fn new() -> Self {
        Self(UnsafeCell::new(TagAllocatorImpl::new()))
    }

    /// Returns a mutable reference to the inner [`TagAllocatorImpl`] through a shared
    /// reference.
    ///
    /// No borrow tracking is performed. Callers must use the returned reference immediately
    /// and must not hold it across another call to `get`, otherwise two live `&mut`
    /// references to the same allocator would exist, which is undefined behavior.
    // `mut_from_ref` is silenced deliberately: handing out `&mut` from `&self` is the whole
    // purpose of this wrapper around `UnsafeCell`.
    #[allow(clippy::mut_from_ref)]
    pub fn get(&self) -> &mut TagAllocatorImpl<'a> {
        // SAFETY: `UnsafeCell::get` never returns a null pointer and the pointee lives as long
        // as `self`. Exclusivity of the resulting reference is NOT checked here; it is the
        // caller's responsibility (see the doc comment above).
        unsafe { &mut *self.0.get() }
    }
}
#[derive(Default)]
pub struct TagAllocatorImpl<'a> {
pub named: IndividualTagAllocator<(&'a Mutf8Str, NbtTag<'a>)>, pub named: IndividualTagAllocator<(&'a Mutf8Str, NbtTag<'a>)>,
// so remember earlier when i said the depth thing is only necessary because compounds aren't // so remember earlier when i said the depth thing is only necessary because compounds aren't
@ -44,7 +58,7 @@ pub struct TagAllocator<'a> {
pub unnamed_longarray: IndividualTagAllocator<RawList<'a, i64>>, pub unnamed_longarray: IndividualTagAllocator<RawList<'a, i64>>,
} }
impl<'a> TagAllocator<'a> { impl<'a> TagAllocatorImpl<'a> {
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }

View file

@ -2,7 +2,7 @@ use thiserror::Error;
use crate::common::MAX_DEPTH; use crate::common::MAX_DEPTH;
#[derive(Error, Debug)] #[derive(Error, Debug, PartialEq)]
pub enum Error { pub enum Error {
#[error("Invalid root type {0}")] #[error("Invalid root type {0}")]
InvalidRootType(u8), InvalidRootType(u8),