1
0
Fork 0
mirror of https://github.com/azalea-rs/simdnbt.git synced 2025-08-02 07:26:04 +00:00

fix UB (undefined behavior) on error and add a test for that

This commit is contained in:
mat 2024-05-13 00:39:55 +00:00
parent a6f47171bf
commit 97b9d5c76a
6 changed files with 75 additions and 62 deletions

View file

@ -87,8 +87,9 @@ Here's a benchmark comparing Simdnbt against a few of the other fastest NBT crat
| [hematite_nbt](https://docs.rs/hematite-nbt/latest/nbt/) | 108.91 MiB/s |
And for writing `complex_player.dat`:
| Library | Throughput |
| ----------------| ------------ |
| --------------- | ------------ |
| simdnbt::borrow | 2.5914 GiB/s |
| azalea_nbt | 2.1096 GiB/s |
| simdnbt::owned | 1.9508 GiB/s |

View file

@ -1,4 +1,4 @@
use std::{cell::UnsafeCell, io::Cursor};
use std::io::Cursor;
use byteorder::ReadBytesExt;
@ -19,27 +19,28 @@ pub struct NbtCompound<'a> {
}
impl<'a> NbtCompound<'a> {
pub fn read(
data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>,
) -> Result<Self, Error> {
pub fn read(data: &mut Cursor<&'a [u8]>, alloc: &TagAllocator<'a>) -> Result<Self, Error> {
Self::read_with_depth(data, alloc, 0)
}
pub fn read_with_depth(
data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>,
alloc: &TagAllocator<'a>,
depth: usize,
) -> Result<Self, Error> {
if depth > MAX_DEPTH {
return Err(Error::MaxDepthExceeded);
}
let alloc_mut = unsafe { alloc.get().as_mut().unwrap_unchecked() };
let mut tags = alloc_mut.named.start(depth);
let mut tags = alloc.get().named.start(depth);
loop {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
let tag_type = match data.read_u8() {
Ok(tag_type) => tag_type,
Err(_) => {
alloc.get().named.finish(tags, depth);
return Err(Error::UnexpectedEof);
}
};
if tag_type == END_ID {
break;
}
@ -47,7 +48,7 @@ impl<'a> NbtCompound<'a> {
let tag_name = match read_string(data) {
Ok(name) => name,
Err(_) => {
alloc_mut.named.finish(tags, depth);
alloc.get().named.finish(tags, depth);
// the only error read_string can return is UnexpectedEof, so this makes it
// slightly faster
return Err(Error::UnexpectedEof);
@ -56,14 +57,13 @@ impl<'a> NbtCompound<'a> {
let tag = match NbtTag::read_with_type(data, alloc, tag_type, depth) {
Ok(tag) => tag,
Err(e) => {
alloc_mut.named.finish(tags, depth);
alloc.get().named.finish(tags, depth);
return Err(e);
}
};
tags.push((tag_name, tag));
}
let alloc_mut = unsafe { alloc.get().as_mut().unwrap_unchecked() };
let values = alloc_mut.named.finish(tags, depth);
let values = alloc.get().named.finish(tags, depth);
Ok(Self { values })
}

View file

@ -1,4 +1,4 @@
use std::{cell::UnsafeCell, io::Cursor};
use std::io::Cursor;
use byteorder::ReadBytesExt;
@ -37,7 +37,7 @@ pub enum NbtList<'a> {
impl<'a> NbtList<'a> {
pub fn read(
data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>,
alloc: &TagAllocator<'a>,
depth: usize,
) -> Result<Self, Error> {
if depth > MAX_DEPTH {
@ -57,105 +57,93 @@ impl<'a> NbtList<'a> {
DOUBLE_ID => NbtList::Double(RawList::new(read_with_u32_length(data, 8)?)),
BYTE_ARRAY_ID => NbtList::ByteArray({
let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
let mut tags = alloc_mut.unnamed_bytearray.start(depth);
let mut tags = alloc.get().unnamed_bytearray.start(depth);
for _ in 0..length {
let tag = match read_u8_array(data) {
Ok(tag) => tag,
Err(e) => {
alloc_mut.unnamed_bytearray.finish(tags, depth);
alloc.get().unnamed_bytearray.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
alloc_mut.unnamed_bytearray.finish(tags, depth)
alloc.get().unnamed_bytearray.finish(tags, depth)
}),
STRING_ID => NbtList::String({
let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
let mut tags = alloc_mut.unnamed_string.start(depth);
let mut tags = alloc.get().unnamed_string.start(depth);
for _ in 0..length {
let tag = match read_string(data) {
Ok(tag) => tag,
Err(e) => {
alloc_mut.unnamed_string.finish(tags, depth);
alloc.get().unnamed_string.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
alloc_mut.unnamed_string.finish(tags, depth)
alloc.get().unnamed_string.finish(tags, depth)
}),
LIST_ID => NbtList::List({
let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
let mut tags = alloc_mut.unnamed_list.start(depth);
let mut tags = alloc.get().unnamed_list.start(depth);
for _ in 0..length {
let tag = match NbtList::read(data, alloc, depth + 1) {
Ok(tag) => tag,
Err(e) => {
alloc_mut.unnamed_list.finish(tags, depth);
alloc.get().unnamed_list.finish(tags, depth);
return Err(e);
}
};
tags.push(tag)
}
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
alloc_mut.unnamed_list.finish(tags, depth)
alloc.get().unnamed_list.finish(tags, depth)
}),
COMPOUND_ID => NbtList::Compound({
let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
let mut tags = alloc_mut.unnamed_compound.start(depth);
let mut tags = alloc.get().unnamed_compound.start(depth);
for _ in 0..length {
let tag = match NbtCompound::read_with_depth(data, alloc, depth + 1) {
Ok(tag) => tag,
Err(e) => {
alloc_mut.unnamed_compound.finish(tags, depth);
alloc.get().unnamed_compound.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
alloc_mut.unnamed_compound.finish(tags, depth)
alloc.get().unnamed_compound.finish(tags, depth)
}),
INT_ARRAY_ID => NbtList::IntArray({
let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
let mut tags = alloc_mut.unnamed_intarray.start(depth);
let mut tags = alloc.get().unnamed_intarray.start(depth);
for _ in 0..length {
let tag = match read_int_array(data) {
Ok(tag) => tag,
Err(e) => {
alloc_mut.unnamed_intarray.finish(tags, depth);
alloc.get().unnamed_intarray.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
alloc_mut.unnamed_intarray.finish(tags, depth)
alloc.get().unnamed_intarray.finish(tags, depth)
}),
LONG_ARRAY_ID => NbtList::LongArray({
let length = read_u32(data)?;
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
let mut tags = alloc_mut.unnamed_longarray.start(depth);
let mut tags = alloc.get().unnamed_longarray.start(depth);
for _ in 0..length {
let tag = match read_long_array(data) {
Ok(tag) => tag,
Err(e) => {
alloc_mut.unnamed_longarray.finish(tags, depth);
alloc.get().unnamed_longarray.finish(tags, depth);
return Err(e);
}
};
tags.push(tag);
}
let alloc_mut = unsafe { alloc.get().as_mut().unwrap() };
alloc_mut.unnamed_longarray.finish(tags, depth)
alloc.get().unnamed_longarray.finish(tags, depth)
}),
_ => return Err(Error::UnknownTagId(tag_type)),
})

View file

@ -4,7 +4,7 @@ mod compound;
mod list;
mod tag_alloc;
use std::{cell::UnsafeCell, io::Cursor, ops::Deref};
use std::{io::Cursor, ops::Deref};
use byteorder::{ReadBytesExt, BE};
@ -47,7 +47,7 @@ impl<'a> Nbt<'a> {
if root_type != COMPOUND_ID {
return Err(Error::InvalidRootType(root_type));
}
let tag_alloc = UnsafeCell::new(TagAllocator::new());
let tag_alloc = TagAllocator::new();
let name = read_string(data)?;
let tag = NbtCompound::read_with_depth(data, &tag_alloc, 0)?;
@ -55,7 +55,7 @@ impl<'a> Nbt<'a> {
Ok(Nbt::Some(BaseNbt {
name,
tag,
_tag_alloc: tag_alloc.into_inner(),
_tag_alloc: tag_alloc,
}))
}
@ -149,7 +149,7 @@ impl<'a> NbtTag<'a> {
fn read_with_type(
data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>,
alloc: &TagAllocator<'a>,
tag_type: u8,
depth: usize,
) -> Result<Self, Error> {
@ -186,17 +186,14 @@ impl<'a> NbtTag<'a> {
}
}
pub fn read(
data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>,
) -> Result<Self, Error> {
pub fn read(data: &mut Cursor<&'a [u8]>, alloc: &TagAllocator<'a>) -> Result<Self, Error> {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
Self::read_with_type(data, alloc, tag_type, 0)
}
pub fn read_optional(
data: &mut Cursor<&'a [u8]>,
alloc: &UnsafeCell<TagAllocator<'a>>,
alloc: &TagAllocator<'a>,
) -> Result<Option<Self>, Error> {
let tag_type = data.read_u8().map_err(|_| Error::UnexpectedEof)?;
if tag_type == END_ID {
@ -444,4 +441,17 @@ mod tests {
}
assert_eq!(ints.len(), 1023);
}
#[test]
fn compound_eof() {
    // Two identical headers back to back: first the root compound, then its first element
    // (also a compound). Each header is a tag type byte followed by a zero-length name.
    let mut data = Vec::new();
    for _ in 0..2 {
        data.write_u8(COMPOUND_ID).unwrap(); // tag type
        data.write_u16::<BE>(0).unwrap(); // name length
    }

    // The input stops here, before the nested compound gets any contents, so reading it
    // must report an unexpected end of input.
    assert_eq!(
        Nbt::read(&mut Cursor::new(&data)),
        Err(Error::UnexpectedEof)
    );
}
}

View file

@ -17,6 +17,7 @@
use std::{
alloc::{self, Layout},
cell::UnsafeCell,
fmt,
ptr::NonNull,
};
@ -29,7 +30,20 @@ use super::{NbtCompound, NbtList, NbtTag};
const MIN_ALLOC_SIZE: usize = 1024;
#[derive(Default)]
pub struct TagAllocator<'a> {
pub struct TagAllocator<'a>(UnsafeCell<TagAllocatorImpl<'a>>);
impl<'a> TagAllocator<'a> {
    /// Creates an allocator wrapping a freshly constructed [`TagAllocatorImpl`].
    pub fn new() -> Self {
        let inner = TagAllocatorImpl::new();
        Self(UnsafeCell::new(inner))
    }

    /// Returns a mutable reference to the wrapped [`TagAllocatorImpl`] through a shared
    /// reference to `self`.
    ///
    /// NOTE(review): nothing in this method stops two of the returned references from being
    /// alive at the same time; callers appear to be responsible for never letting them
    /// overlap. Confirm against the call sites.
    // Handing out `&mut` from `&self` is the whole purpose of this wrapper, so the lint is
    // silenced on purpose.
    #[allow(clippy::mut_from_ref)]
    pub fn get(&self) -> &mut TagAllocatorImpl<'a> {
        let inner = self.0.get();
        // SAFETY: `UnsafeCell::get` returns a pointer to the cell's own contents, so it is
        // never null and stays valid for as long as `self` is borrowed. Exclusivity of the
        // resulting reference is NOT checked here.
        unsafe { &mut *inner }
    }
}
#[derive(Default)]
pub struct TagAllocatorImpl<'a> {
pub named: IndividualTagAllocator<(&'a Mutf8Str, NbtTag<'a>)>,
// so remember earlier when i said the depth thing is only necessary because compounds aren't
@ -44,7 +58,7 @@ pub struct TagAllocator<'a> {
pub unnamed_longarray: IndividualTagAllocator<RawList<'a, i64>>,
}
impl<'a> TagAllocator<'a> {
impl<'a> TagAllocatorImpl<'a> {
pub fn new() -> Self {
Self::default()
}

View file

@ -2,7 +2,7 @@ use thiserror::Error;
use crate::common::MAX_DEPTH;
#[derive(Error, Debug)]
#[derive(Error, Debug, PartialEq)]
pub enum Error {
#[error("Invalid root type {0}")]
InvalidRootType(u8),