1
2
Fork 0
mirror of https://github.com/mat-1/azalea.git synced 2025-08-02 23:44:38 +00:00

make a_star function use an IndexMap like the pathfinding crate

This commit is contained in:
mat 2024-12-26 07:42:35 +00:00
commit adb56b7eb2
4 changed files with 114 additions and 70 deletions

1
Cargo.lock generated
View file

@ -232,6 +232,7 @@ dependencies = [
"derive_more", "derive_more",
"futures", "futures",
"futures-lite", "futures-lite",
"indexmap",
"nohash-hasher", "nohash-hasher",
"num-format", "num-format",
"num-traits", "num-traits",

View file

@ -34,6 +34,7 @@ bevy_tasks = { workspace = true, features = ["multi_threaded"] }
derive_more = { workspace = true, features = ["deref", "deref_mut"] } derive_more = { workspace = true, features = ["deref", "deref_mut"] }
futures = { workspace = true } futures = { workspace = true }
futures-lite = { workspace = true } futures-lite = { workspace = true }
indexmap = "2.7.0"
nohash-hasher = { workspace = true } nohash-hasher = { workspace = true }
num-format = { workspace = true } num-format = { workspace = true }
num-traits = { workspace = true } num-traits = { workspace = true }

View file

@ -2,12 +2,13 @@ use std::{
cmp::{self}, cmp::{self},
collections::BinaryHeap, collections::BinaryHeap,
fmt::Debug, fmt::Debug,
hash::Hash, hash::{BuildHasherDefault, Hash},
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use indexmap::IndexMap;
use num_format::ToFormattedString; use num_format::ToFormattedString;
use rustc_hash::FxHashMap; use rustc_hash::FxHasher;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
pub struct Path<P, M> pub struct Path<P, M>
@ -37,6 +38,12 @@ pub enum PathfinderTimeout {
Nodes(usize), Nodes(usize),
} }
type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
// Sources:
// - https://en.wikipedia.org/wiki/A*_search_algorithm
// - https://github.com/evenfurther/pathfinding/blob/main/src/directed/astar.rs
// - https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java
pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>( pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
start: P, start: P,
heuristic: HeuristicFn, heuristic: HeuristicFn,
@ -52,77 +59,100 @@ where
{ {
let start_time = Instant::now(); let start_time = Instant::now();
let mut open_set = BinaryHeap::<WeightedNode<P>>::new(); let mut open_set = BinaryHeap::<WeightedNode>::new();
open_set.push(WeightedNode(start, 0.)); open_set.push(WeightedNode {
let mut nodes: FxHashMap<P, Node<P, M>> = FxHashMap::default(); g_score: 0.,
f_score: 0.,
index: 0,
});
let mut nodes: FxIndexMap<P, Node<M>> = IndexMap::default();
nodes.insert( nodes.insert(
start, start,
Node { Node {
position: start,
movement_data: None, movement_data: None,
came_from: None, came_from: usize::MAX,
g_score: f32::default(), g_score: 0.,
f_score: f32::INFINITY,
}, },
); );
let mut best_paths: [P; 7] = [start; 7]; let mut best_paths: [usize; 7] = [0; 7];
let mut best_path_scores: [f32; 7] = [heuristic(start); 7]; let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
let mut num_nodes = 0; let mut num_nodes = 0;
while let Some(WeightedNode(current_node, _)) = open_set.pop() { while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
num_nodes += 1; num_nodes += 1;
if success(current_node) {
let (&node, node_data) = nodes.get_index(index).unwrap();
if success(node) {
debug!("Nodes considered: {num_nodes}"); debug!("Nodes considered: {num_nodes}");
return Path { return Path {
movements: reconstruct_path(nodes, current_node), movements: reconstruct_path(nodes, index),
partial: false, partial: false,
}; };
} }
let current_g_score = nodes if g_score > node_data.g_score {
.get(&current_node) continue;
.map(|n| n.g_score) }
.unwrap_or(f32::INFINITY);
for neighbor in successors(current_node) { for neighbor in successors(node) {
let tentative_g_score = current_g_score + neighbor.cost; let tentative_g_score = g_score + neighbor.cost;
let neighbor_g_score = nodes // let neighbor_heuristic = heuristic(neighbor.movement.target);
.get(&neighbor.movement.target) let neighbor_heuristic;
.map(|n| n.g_score) let neighbor_index;
.unwrap_or(f32::INFINITY);
if neighbor_g_score - tentative_g_score > MIN_IMPROVEMENT { // skip neighbors that don't result in a big enough improvement
let heuristic = heuristic(neighbor.movement.target); if tentative_g_score - g_score < MIN_IMPROVEMENT {
let f_score = tentative_g_score + heuristic; continue;
nodes.insert( }
neighbor.movement.target,
Node { match nodes.entry(neighbor.movement.target) {
position: neighbor.movement.target, indexmap::map::Entry::Occupied(mut e) => {
if e.get().g_score > tentative_g_score {
neighbor_heuristic = heuristic(*e.key());
neighbor_index = e.index();
e.insert(Node {
movement_data: Some(neighbor.movement.data), movement_data: Some(neighbor.movement.data),
came_from: Some(current_node), came_from: index,
g_score: tentative_g_score, g_score: tentative_g_score,
f_score, });
}, } else {
); continue;
open_set.push(WeightedNode(neighbor.movement.target, f_score)); }
}
indexmap::map::Entry::Vacant(e) => {
neighbor_heuristic = heuristic(*e.key());
neighbor_index = e.index();
e.insert(Node {
movement_data: Some(neighbor.movement.data),
came_from: index,
g_score: tentative_g_score,
});
}
}
open_set.push(WeightedNode {
index: neighbor_index,
g_score: tentative_g_score,
f_score: tentative_g_score + neighbor_heuristic,
});
for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() { for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
let node_score = heuristic + tentative_g_score / coefficient; let node_score = neighbor_heuristic + tentative_g_score / coefficient;
if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT { if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
best_paths[coefficient_i] = neighbor.movement.target; best_paths[coefficient_i] = neighbor_index;
best_path_scores[coefficient_i] = node_score; best_path_scores[coefficient_i] = node_score;
} }
} }
} }
}
// check for timeout every ~20ms // check for timeout every ~10ms
if num_nodes % 10000 == 0 { if num_nodes % 10000 == 0 {
let timed_out = match timeout { let timed_out = match timeout {
PathfinderTimeout::Time(max_duration) => start_time.elapsed() > max_duration, PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
PathfinderTimeout::Nodes(max_nodes) => num_nodes > max_nodes, PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
}; };
if timed_out { if timed_out {
// timeout, just return the best path we have so far // timeout, just return the best path we have so far
@ -132,7 +162,7 @@ where
} }
} }
let best_path = determine_best_path(&best_paths, &start); let best_path = determine_best_path(best_paths, 0);
debug!( debug!(
"A* ran at {} nodes per second", "A* ran at {} nodes per second",
@ -146,48 +176,46 @@ where
} }
} }
fn determine_best_path<P>(best_paths: &[P; 7], start: &P) -> P fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
where
P: Eq + Hash + Copy + Debug,
{
// this basically makes sure we don't create a path that's really short // this basically makes sure we don't create a path that's really short
for node in best_paths.iter() { for node in best_paths {
if node != start { if node != start {
return *node; return node;
} }
} }
warn!("No best node found, returning first node"); warn!("No best node found, returning first node");
best_paths[0] best_paths[0]
} }
fn reconstruct_path<P, M>(mut nodes: FxHashMap<P, Node<P, M>>, current: P) -> Vec<Movement<P, M>> fn reconstruct_path<P, M>(
mut nodes: FxIndexMap<P, Node<M>>,
mut current_index: usize,
) -> Vec<Movement<P, M>>
where where
P: Eq + Hash + Copy + Debug, P: Eq + Hash + Copy + Debug,
{ {
let mut path = Vec::new(); let mut path = Vec::new();
let mut current = current; while let Some((&node_position, node)) = nodes.get_index_mut(current_index) {
while let Some(node) = nodes.remove(&current) { if node.came_from == usize::MAX {
if let Some(came_from) = node.came_from {
current = came_from;
} else {
break; break;
} }
current_index = node.came_from;
path.push(Movement { path.push(Movement {
target: node.position, target: node_position,
data: node.movement_data.unwrap(), data: node.movement_data.take().unwrap(),
}); });
} }
path.reverse(); path.reverse();
path path
} }
pub struct Node<P, M> { pub struct Node<M> {
pub position: P,
pub movement_data: Option<M>, pub movement_data: Option<M>,
pub came_from: Option<P>, pub came_from: usize,
pub g_score: f32, pub g_score: f32,
pub f_score: f32,
} }
pub struct Edge<P: Hash + Copy, M> { pub struct Edge<P: Hash + Copy, M> {
@ -218,16 +246,30 @@ impl<P: Hash + Copy + Clone, M: Clone> Clone for Movement<P, M> {
} }
#[derive(PartialEq)] #[derive(PartialEq)]
pub struct WeightedNode<P: PartialEq>(P, f32); pub struct WeightedNode {
index: usize,
g_score: f32,
f_score: f32,
}
impl<P: PartialEq> Ord for WeightedNode<P> { impl Ord for WeightedNode {
fn cmp(&self, other: &Self) -> cmp::Ordering { fn cmp(&self, other: &Self) -> cmp::Ordering {
// intentionally inverted to make the BinaryHeap a min-heap // intentionally inverted to make the BinaryHeap a min-heap
other.1.partial_cmp(&self.1).unwrap_or(cmp::Ordering::Equal) match other
.f_score
.partial_cmp(&self.f_score)
.unwrap_or(cmp::Ordering::Equal)
{
cmp::Ordering::Equal => self
.g_score
.partial_cmp(&other.g_score)
.unwrap_or(cmp::Ordering::Equal),
s => s,
}
} }
} }
impl<P: PartialEq> Eq for WeightedNode<P> {} impl Eq for WeightedNode {}
impl<P: PartialEq> PartialOrd for WeightedNode<P> { impl PartialOrd for WeightedNode {
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> { fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
Some(self.cmp(other)) Some(self.cmp(other))
} }

View file

@ -778,7 +778,7 @@ pub fn check_for_path_obstruction(
new_path new_path
.extend(executing_path.path.iter().skip(patch_end_index).cloned()); .extend(executing_path.path.iter().skip(patch_end_index).cloned());
is_patch_complete = true; is_patch_complete = true;
debug!("the obstruction patch is not partial"); debug!("the obstruction patch is not partial :)");
} else { } else {
debug!( debug!(
"the obstruction patch is partial, throwing away rest of path :(" "the obstruction patch is partial, throwing away rest of path :("