1
2
Fork 0
mirror of https://github.com/mat-1/azalea.git synced 2025-08-02 06:16:04 +00:00

make a_star function use an IndexMap like the pathfinding crate

This commit is contained in:
mat 2024-12-26 07:42:35 +00:00
parent 3c83e5b24a
commit adb56b7eb2
4 changed files with 114 additions and 70 deletions

1
Cargo.lock generated
View file

@ -232,6 +232,7 @@ dependencies = [
"derive_more",
"futures",
"futures-lite",
"indexmap",
"nohash-hasher",
"num-format",
"num-traits",

View file

@ -34,6 +34,7 @@ bevy_tasks = { workspace = true, features = ["multi_threaded"] }
derive_more = { workspace = true, features = ["deref", "deref_mut"] }
futures = { workspace = true }
futures-lite = { workspace = true }
indexmap = "2.7.0"
nohash-hasher = { workspace = true }
num-format = { workspace = true }
num-traits = { workspace = true }

View file

@ -2,12 +2,13 @@ use std::{
cmp::{self},
collections::BinaryHeap,
fmt::Debug,
hash::Hash,
hash::{BuildHasherDefault, Hash},
time::{Duration, Instant},
};
use indexmap::IndexMap;
use num_format::ToFormattedString;
use rustc_hash::FxHashMap;
use rustc_hash::FxHasher;
use tracing::{debug, trace, warn};
pub struct Path<P, M>
@ -37,6 +38,12 @@ pub enum PathfinderTimeout {
Nodes(usize),
}
type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
// Sources:
// - https://en.wikipedia.org/wiki/A*_search_algorithm
// - https://github.com/evenfurther/pathfinding/blob/main/src/directed/astar.rs
// - https://github.com/cabaletta/baritone/blob/1.19.4/src/main/java/baritone/pathing/calc/AbstractNodeCostSearch.java
pub fn a_star<P, M, HeuristicFn, SuccessorsFn, SuccessFn>(
start: P,
heuristic: HeuristicFn,
@ -52,77 +59,100 @@ where
{
let start_time = Instant::now();
let mut open_set = BinaryHeap::<WeightedNode<P>>::new();
open_set.push(WeightedNode(start, 0.));
let mut nodes: FxHashMap<P, Node<P, M>> = FxHashMap::default();
let mut open_set = BinaryHeap::<WeightedNode>::new();
open_set.push(WeightedNode {
g_score: 0.,
f_score: 0.,
index: 0,
});
let mut nodes: FxIndexMap<P, Node<M>> = IndexMap::default();
nodes.insert(
start,
Node {
position: start,
movement_data: None,
came_from: None,
g_score: f32::default(),
f_score: f32::INFINITY,
came_from: usize::MAX,
g_score: 0.,
},
);
let mut best_paths: [P; 7] = [start; 7];
let mut best_paths: [usize; 7] = [0; 7];
let mut best_path_scores: [f32; 7] = [heuristic(start); 7];
let mut num_nodes = 0;
while let Some(WeightedNode(current_node, _)) = open_set.pop() {
while let Some(WeightedNode { index, g_score, .. }) = open_set.pop() {
num_nodes += 1;
if success(current_node) {
let (&node, node_data) = nodes.get_index(index).unwrap();
if success(node) {
debug!("Nodes considered: {num_nodes}");
return Path {
movements: reconstruct_path(nodes, current_node),
movements: reconstruct_path(nodes, index),
partial: false,
};
}
let current_g_score = nodes
.get(&current_node)
.map(|n| n.g_score)
.unwrap_or(f32::INFINITY);
if g_score > node_data.g_score {
continue;
}
for neighbor in successors(current_node) {
let tentative_g_score = current_g_score + neighbor.cost;
let neighbor_g_score = nodes
.get(&neighbor.movement.target)
.map(|n| n.g_score)
.unwrap_or(f32::INFINITY);
if neighbor_g_score - tentative_g_score > MIN_IMPROVEMENT {
let heuristic = heuristic(neighbor.movement.target);
let f_score = tentative_g_score + heuristic;
nodes.insert(
neighbor.movement.target,
Node {
position: neighbor.movement.target,
movement_data: Some(neighbor.movement.data),
came_from: Some(current_node),
g_score: tentative_g_score,
f_score,
},
);
open_set.push(WeightedNode(neighbor.movement.target, f_score));
for neighbor in successors(node) {
let tentative_g_score = g_score + neighbor.cost;
// let neighbor_heuristic = heuristic(neighbor.movement.target);
let neighbor_heuristic;
let neighbor_index;
for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
let node_score = heuristic + tentative_g_score / coefficient;
if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
best_paths[coefficient_i] = neighbor.movement.target;
best_path_scores[coefficient_i] = node_score;
// skip neighbors that don't result in a big enough improvement
if tentative_g_score - g_score < MIN_IMPROVEMENT {
continue;
}
match nodes.entry(neighbor.movement.target) {
indexmap::map::Entry::Occupied(mut e) => {
if e.get().g_score > tentative_g_score {
neighbor_heuristic = heuristic(*e.key());
neighbor_index = e.index();
e.insert(Node {
movement_data: Some(neighbor.movement.data),
came_from: index,
g_score: tentative_g_score,
});
} else {
continue;
}
}
indexmap::map::Entry::Vacant(e) => {
neighbor_heuristic = heuristic(*e.key());
neighbor_index = e.index();
e.insert(Node {
movement_data: Some(neighbor.movement.data),
came_from: index,
g_score: tentative_g_score,
});
}
}
open_set.push(WeightedNode {
index: neighbor_index,
g_score: tentative_g_score,
f_score: tentative_g_score + neighbor_heuristic,
});
for (coefficient_i, &coefficient) in COEFFICIENTS.iter().enumerate() {
let node_score = neighbor_heuristic + tentative_g_score / coefficient;
if best_path_scores[coefficient_i] - node_score > MIN_IMPROVEMENT {
best_paths[coefficient_i] = neighbor_index;
best_path_scores[coefficient_i] = node_score;
}
}
}
// check for timeout every ~20ms
// check for timeout every ~10ms
if num_nodes % 10000 == 0 {
let timed_out = match timeout {
PathfinderTimeout::Time(max_duration) => start_time.elapsed() > max_duration,
PathfinderTimeout::Nodes(max_nodes) => num_nodes > max_nodes,
PathfinderTimeout::Time(max_duration) => start_time.elapsed() >= max_duration,
PathfinderTimeout::Nodes(max_nodes) => num_nodes >= max_nodes,
};
if timed_out {
// timeout, just return the best path we have so far
@ -132,7 +162,7 @@ where
}
}
let best_path = determine_best_path(&best_paths, &start);
let best_path = determine_best_path(best_paths, 0);
debug!(
"A* ran at {} nodes per second",
@ -146,48 +176,46 @@ where
}
}
fn determine_best_path<P>(best_paths: &[P; 7], start: &P) -> P
where
P: Eq + Hash + Copy + Debug,
{
fn determine_best_path(best_paths: [usize; 7], start: usize) -> usize {
// this basically makes sure we don't create a path that's really short
for node in best_paths.iter() {
for node in best_paths {
if node != start {
return *node;
return node;
}
}
warn!("No best node found, returning first node");
best_paths[0]
}
fn reconstruct_path<P, M>(mut nodes: FxHashMap<P, Node<P, M>>, current: P) -> Vec<Movement<P, M>>
fn reconstruct_path<P, M>(
mut nodes: FxIndexMap<P, Node<M>>,
mut current_index: usize,
) -> Vec<Movement<P, M>>
where
P: Eq + Hash + Copy + Debug,
{
let mut path = Vec::new();
let mut current = current;
while let Some(node) = nodes.remove(&current) {
if let Some(came_from) = node.came_from {
current = came_from;
} else {
while let Some((&node_position, node)) = nodes.get_index_mut(current_index) {
if node.came_from == usize::MAX {
break;
}
current_index = node.came_from;
path.push(Movement {
target: node.position,
data: node.movement_data.unwrap(),
target: node_position,
data: node.movement_data.take().unwrap(),
});
}
path.reverse();
path
}
pub struct Node<P, M> {
pub position: P,
pub struct Node<M> {
pub movement_data: Option<M>,
pub came_from: Option<P>,
pub came_from: usize,
pub g_score: f32,
pub f_score: f32,
}
pub struct Edge<P: Hash + Copy, M> {
@ -218,16 +246,30 @@ impl<P: Hash + Copy + Clone, M: Clone> Clone for Movement<P, M> {
}
#[derive(PartialEq)]
pub struct WeightedNode<P: PartialEq>(P, f32);
pub struct WeightedNode {
index: usize,
g_score: f32,
f_score: f32,
}
impl<P: PartialEq> Ord for WeightedNode<P> {
impl Ord for WeightedNode {
fn cmp(&self, other: &Self) -> cmp::Ordering {
// intentionally inverted to make the BinaryHeap a min-heap
other.1.partial_cmp(&self.1).unwrap_or(cmp::Ordering::Equal)
match other
.f_score
.partial_cmp(&self.f_score)
.unwrap_or(cmp::Ordering::Equal)
{
cmp::Ordering::Equal => self
.g_score
.partial_cmp(&other.g_score)
.unwrap_or(cmp::Ordering::Equal),
s => s,
}
}
}
impl<P: PartialEq> Eq for WeightedNode<P> {}
impl<P: PartialEq> PartialOrd for WeightedNode<P> {
impl Eq for WeightedNode {}
impl PartialOrd for WeightedNode {
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
Some(self.cmp(other))
}

View file

@ -778,7 +778,7 @@ pub fn check_for_path_obstruction(
new_path
.extend(executing_path.path.iter().skip(patch_end_index).cloned());
is_patch_complete = true;
debug!("the obstruction patch is not partial");
debug!("the obstruction patch is not partial :)");
} else {
debug!(
"the obstruction patch is partial, throwing away rest of path :("