@tbillington
Last active January 30, 2025 18:56
Galaxy Brain - Utility AI inspired by https://github.com/zkat/big-brain

Galaxy Brain

Utility AI inspired by 💖 @zkat https://github.com/zkat/big-brain/

Written for Bevy 0.14.

This library was written for my specific needs. It doesn't support all the features of Big Brain, and I would probably still change things if I needed to use it more in the future.

The upside is that it's only about 300 lines of code, so understanding what's happening and modifying it should be straightforward.

Minimal Example

#[derive(Component)]
struct PrintHelloAction {
    msg: String,
}

impl Action for PrintHelloAction {
    type Scorer = PrintHelloActionScore;
}

fn init_print_hello_action(In(i): In<Entity>, mut cmd: Commands) {
    cmd.entity(i).insert(PrintHelloAction {
        msg: "Hello!".to_string(),
    });
}

fn hello_action_exec(query: Query<&PrintHelloAction>) {
    for hello in query.iter() {
        println!("{}", hello.msg);
    }
}

#[derive(Component)]
struct PrintHelloActionScore {
    score: f32,
}

impl Scorer for PrintHelloActionScore {
    fn score(&self) -> f32 {
        self.score
    }
}

fn update_hello_action_score(mut query: Query<&mut PrintHelloActionScore>) {
    for mut print_hello_scorer in query.iter_mut() {
        print_hello_scorer.score = 1.0;
    }
}

fn setup(mut cmd: Commands) {
    cmd.spawn((
        Name::new("Agent"),
        PrintHelloActionScore { score: 0.0 },
        AiBundle::<ActionPickerHighestScore>::default(),
    ));
}

fn main() {
    App::new()
        .add_plugins((MinimalPlugins, GalaxyBrainPlugins::new(Update)))
        .register_action::<PrintHelloAction, _>(init_print_hello_action)
        .add_systems(Startup, setup)
        .add_systems(
            Update,
            (
                update_hello_action_score.in_set(GalaxyBrainSet::UserScorers),
                hello_action_exec.in_set(GalaxyBrainSet::UserActions),
            ),
        )
        .run();
}

Differences to Big Brain / Limitations

Actions & Scorers live on the "agent" entity

Actions & Scorers are kept on the base entity, so each agent is limited to one of each action/scorer type. I liked the much lower entity count and the easier debugging experience, but it is definitely a limitation of this approach (short of generic/runtime-component/reflect shenanigans). It also simplifies most action systems: the components you need, such as Transform, are often on the same entity, so you avoid the double query and Query::get indirection that every Big Brain scorer/action system needs.

TL;DR: If the action component exists, the action is taking place.

This also combines nicely with actions/scorers living on the agent to remove 2 levels of nesting within action/scorer systems.

There is no ActionState. All actions are initialised via a one-shot system that receives the Entity it's being inserted on as an In parameter. If you need to do something on init or finish, you can use hooks or observers, or code it as part of your action impl. I went for custom events + an observer to signal that an action has finished; see the cleaning_bots.rs example. I found constantly matching on ActionState pretty tedious, and it gave me nothing I couldn't handle in hooks/observers when needed.

This also means that within a single update schedule run, an action can be initialised, run, finish, and be removed.
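The transition rule behind this can be sketched without Bevy. The snippet below is a simplified, standalone model rather than the library's code: `ActionId`, `Transition`, and `decide` are hypothetical names, and action ids are plain integers standing in for `ComponentId`s.

```rust
/// Hypothetical stand-in for a Bevy `ComponentId`.
type ActionId = u32;

/// What the transition step should do for one agent this frame.
#[derive(Debug, PartialEq)]
enum Transition {
    /// Leave the agent as it is.
    Keep,
    /// Remove the current action component and start nothing.
    Stop(ActionId),
    /// Remove the current action (if any), then initialise a new one.
    Start { remove: Option<ActionId>, init: ActionId },
}

/// `best` is the highest-scoring action (if any), `current` the running one.
fn decide(best: Option<ActionId>, current: Option<ActionId>) -> Transition {
    match (best, current) {
        // The winner is already running: nothing to do.
        (Some(new), Some(cur)) if new == cur => Transition::Keep,
        // A different (or first) winner: swap it in.
        (Some(new), cur) => Transition::Start { remove: cur, init: new },
        // No scores at all: stop whatever is running.
        (None, Some(cur)) => Transition::Stop(cur),
        (None, None) => Transition::Keep,
    }
}

fn main() {
    println!("{:?}", decide(Some(2), Some(1)));
    println!("{:?}", decide(Some(1), Some(1)));
    println!("{:?}", decide(None, Some(1)));
}
```

Because the decision only compares ids, "is this action running?" reduces to "is its component present?", which is the point of the TL;DR above.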

User initialises the Action

I disliked how Big Brain actions required you to have Actions in a "default"/"template" state, and the default build would clone that. What if there wasn't a valid default state? I didn't want to force user level code to have unnecessary Options etc. It also means there are no required traits for actions other than Component.
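As a rough illustration of the idea (not the library's implementation), the sketch below keeps a map from action id to an initialiser function and calls it with the target entity. `Entity`, `ActionId`, `World`, and `Registry` here are hypothetical toy types, not Bevy's.

```rust
use std::collections::HashMap;

/// Hypothetical stand-ins for Bevy's `Entity` and `ComponentId`.
type Entity = u32;
type ActionId = &'static str;

/// Toy "world": records which action data was attached to which entity.
#[derive(Default)]
struct World {
    attached: Vec<(Entity, String)>,
}

/// Maps each registered action to the function that initialises it.
#[derive(Default)]
struct Registry {
    creators: HashMap<ActionId, Box<dyn Fn(Entity, &mut World)>>,
}

impl Registry {
    fn register(&mut self, id: ActionId, init: impl Fn(Entity, &mut World) + 'static) {
        self.creators.insert(id, Box::new(init));
    }

    /// Run the initialiser for `id` on `entity`; returns false if unregistered.
    fn start(&self, id: ActionId, entity: Entity, world: &mut World) -> bool {
        match self.creators.get(id) {
            Some(init) => {
                init(entity, world);
                true
            }
            None => false,
        }
    }
}

fn main() {
    let mut registry = Registry::default();
    // The initialiser builds the action from whatever state it wants;
    // no `Default` or `Clone` is needed on the action data itself.
    registry.register("PrintHello", |entity, world| {
        world.attached.push((entity, "Hello!".to_string()));
    });
    let mut world = World::default();
    registry.start("PrintHello", 7, &mut world);
    println!("{:?}", world.attached);
}
```

The user-supplied function decides how the action is constructed, so the action type needs no template value.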

No library level Composite Actions

Composite actions, aka Steps, are done at the user level, not the library level. This simplifies the library implementation and gives the user greater flexibility in how they are implemented. I usually implemented them as just an enum, see CountAndPrintActionManager in basic.rs. It makes it easy to debug which step the action is on. I also had the base "action" struct known to the library act as a sort of manager, which was responsible for setting up the observers that handled sub-states finishing and transitioned them to the next state. It was extra boilerplate at the user level, but I found the full control to be worth it.

You can even skip using an enum if you like, see init_chop_tree_action_manager in chop_tree.rs for an example.
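A minimal, Bevy-free sketch of the "steps as an enum" idea follows. It is only modelled on the shape of CountAndPrintActionManager: `Step` and `tick` are hypothetical names, and the real example drives transitions with events and observers rather than a single function.

```rust
/// Hypothetical two-step composite action.
#[derive(Debug, PartialEq)]
enum Step {
    CountToThree { count: u8 },
    PrintSuccess,
    Finished,
}

/// Advance the composite action by one tick and return the next step.
fn tick(step: Step) -> Step {
    match step {
        // The third count completes this step and hands over to the next one.
        Step::CountToThree { count } if count >= 2 => Step::PrintSuccess,
        Step::CountToThree { count } => Step::CountToThree { count: count + 1 },
        Step::PrintSuccess => Step::Finished,
        Step::Finished => Step::Finished,
    }
}

fn main() {
    let mut step = Step::CountToThree { count: 0 };
    while step != Step::Finished {
        // Printing the enum shows exactly which step the action is on.
        println!("{step:?}");
        step = tick(step);
    }
}
```

Keeping the current step in one enum value is what makes it easy to inspect in a debugger or log.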


What I wanted to change after using Big Brain

  • Multi-step actions have no context. E.g. if an agent is in a Walking step and wants to change state, it can't do any "clean up".
  • No easy way to see current "action"
  • No easy way to use hysteresis
  • Logic spread over many entities made debugging more involved
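
One way to get hysteresis when the action component lives on the agent is to let the scorer check whether its own action is already running. The sketch below mirrors the logic of score_return_and_charge_action in the cleaning bots example, pulled out into a plain function; the function name and the `LOW_BATTERY` constant are illustrative assumptions.

```rust
/// `charge` is the battery level in 0.0..=1.0; `is_charging` is true while the
/// "return and charge" action component is present on the agent.
fn return_and_charge_score(charge: f32, is_charging: bool) -> f32 {
    const LOW_BATTERY: f32 = 0.2;
    // Once started, keep the top score until the action finishes by itself,
    // instead of dropping out as soon as the charge rises above the threshold.
    if is_charging || charge < LOW_BATTERY {
        1.0
    } else {
        0.0
    }
}

fn main() {
    println!("{}", return_and_charge_score(0.1, false));
    println!("{}", return_and_charge_score(0.5, true));
    println!("{}", return_and_charge_score(0.5, false));
}
```

Without the `is_charging` check, a bot would stop charging the moment its charge passed the threshold and then immediately need to return again.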

use std::{any::TypeId, cmp::Ordering, collections::HashMap};
use bevy::{
app::PluginGroupBuilder,
ecs::{
component::ComponentId,
intern::Interned,
schedule::ScheduleLabel,
system::{IntoObserverSystem, SystemId},
},
prelude::*,
};
pub mod prelude {
pub use super::{
Action, ActionFinished, GalaxyBrainPlugins, GalaxyBrainSet,
ObserveWithComponentLifetime as _, RegisterAction as _, Scorer,
};
}
pub struct GalaxyBrainPlugins {
schedule: Interned<dyn ScheduleLabel>,
}
impl GalaxyBrainPlugins {
pub fn new(schedule: impl ScheduleLabel) -> Self {
Self {
schedule: schedule.intern(),
}
}
}
impl PluginGroup for GalaxyBrainPlugins {
fn build(self) -> PluginGroupBuilder {
PluginGroupBuilder::start::<Self>()
.add(GalaxyBrainCorePlugin::new(self.schedule))
.add(HighestScorePickerPlugin)
}
}
pub struct GalaxyBrainCorePlugin {
schedule: Interned<dyn ScheduleLabel>,
}
impl GalaxyBrainCorePlugin {
pub fn new(schedule: impl ScheduleLabel) -> Self {
Self {
schedule: schedule.intern(),
}
}
}
#[derive(Resource, Debug)]
pub struct GalaxyBrainConfig {
schedule: Interned<dyn ScheduleLabel>,
action_creators: HashMap<ComponentId, SystemId<Entity, ()>>,
pub registered_actions_info: Vec<RegisteredActionInfo>,
}
#[derive(Debug)]
pub struct RegisteredActionInfo {
pub action_id: ComponentId,
pub action_name: String,
pub scorer_id: ComponentId,
pub scorer_type_id: TypeId,
}
impl Plugin for GalaxyBrainCorePlugin {
fn build(&self, app: &mut App) {
app.register_type::<Scores>();
app.register_type::<CurrentAction>();
app.configure_sets(
self.schedule.intern(),
(
(
GalaxyBrainSet::UserScorers,
GalaxyBrainSet::ClearPreviousScores,
),
GalaxyBrainSet::ScoreCollection,
GalaxyBrainSet::ActionPicking,
GalaxyBrainSet::TransitionActions,
GalaxyBrainSet::UserActions,
)
.chain(),
);
app.insert_resource(GalaxyBrainConfig {
schedule: self.schedule.intern(),
action_creators: HashMap::new(),
registered_actions_info: Vec::new(),
});
app.add_systems(
self.schedule.intern(),
(
clear_scores.in_set(GalaxyBrainSet::ClearPreviousScores),
transition_actions.in_set(GalaxyBrainSet::TransitionActions),
),
);
app.world_mut().spawn((
Name::new("GalaxyBrain::HandleActionFinished Observer"),
Observer::new(handle_action_finished),
));
}
}
pub struct HighestScorePickerPlugin;
impl Plugin for HighestScorePickerPlugin {
fn build(&self, app: &mut App) {
let schedule = app.world().resource::<GalaxyBrainConfig>().schedule;
app.add_systems(
schedule,
action_picker_highest_score.in_set(GalaxyBrainSet::ActionPicking),
);
}
}
#[derive(SystemSet, Debug, Hash, PartialEq, Eq, Clone)]
pub enum GalaxyBrainSet {
UserScorers,
ClearPreviousScores,
ScoreCollection,
ActionPicking,
TransitionActions,
UserActions,
}
pub trait Action: Sized {
type Scorer: Component + Scorer;
}
pub trait Scorer {
fn score(&self) -> f32;
}
pub trait RegisterAction {
fn register_action<T: Component + Action, M>(
&mut self,
system: impl IntoSystem<Entity, (), M> + 'static,
) -> &mut Self;
}
impl RegisterAction for App {
fn register_action<T: Component + Action, M>(
&mut self,
init_action: impl IntoSystem<Entity, (), M> + 'static,
) -> &mut Self {
let world = self.world_mut();
let sys_id = world.register_system(init_action);
let action_id = world.init_component::<T>();
let action_name = world.components().get_name(action_id).unwrap().to_owned();
let scorer_id = world.init_component::<T::Scorer>();
let mut config = world.resource_mut::<GalaxyBrainConfig>();
config.action_creators.insert(action_id, sys_id);
config.registered_actions_info.push(RegisteredActionInfo {
action_id,
action_name,
scorer_id,
scorer_type_id: TypeId::of::<T::Scorer>(),
});
let schedule = config.schedule;
self.add_systems(
schedule,
make_collect_score::<T::Scorer>(action_id).in_set(GalaxyBrainSet::ScoreCollection),
);
self
}
}
#[derive(Bundle, Default)]
pub struct AiBundle<Picker: Component + Default> {
pub scores: Scores,
pub current_action: CurrentAction,
pub picker: Picker,
}
/// Holds all the scores from [`Scorer`]s.
#[derive(Component, Reflect, Default, Debug)]
pub struct Scores {
pub scores: Vec<(ComponentId, f32)>,
}
#[derive(Component, Reflect, Default, Debug)]
pub struct CurrentAction {
pub component: Option<ComponentId>,
pub last_action: Option<ComponentId>,
}
#[derive(Component, Default, Debug)]
pub struct ActionPickerHighestScore;
fn action_picker_highest_score(mut query: Query<&mut Scores, With<ActionPickerHighestScore>>) {
query.iter_mut().for_each(|mut scores| {
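// Sort descending by score. `total_cmp` is a total order, so equal scores and
// NaN cannot produce an inconsistent comparator.
scores
.scores
.sort_unstable_by(|(_, a), (_, b)| -> Ordering { b.total_cmp(a) });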
});
}
fn make_collect_score<C: Component + Scorer>(
component_id: ComponentId,
) -> impl Fn(Query<'_, '_, (&mut Scores, &C)>) {
move |mut query: Query<(&mut Scores, &C)>| {
query.iter_mut().for_each(|(mut scores, scorer)| {
scores.scores.push((component_id, scorer.score()));
});
}
}
fn clear_scores(mut query: Query<&mut Scores>) {
query.iter_mut().for_each(|mut scores| {
scores.scores.clear();
});
}
fn transition_actions(
mut query: Query<(&mut CurrentAction, &Scores, Entity)>,
mut cmd: Commands,
config: Res<GalaxyBrainConfig>,
) {
query
.iter_mut()
.for_each(|(mut current_action, scores, entity)| {
let last_action = current_action.component;
macro_rules! remove_current_action {
($id: ident) => {
cmd.entity(entity).remove_by_id($id);
};
}
macro_rules! init_and_set_new_action {
($id: ident) => {
let action_creator_sys = config
.action_creators
.get($id)
.expect("expected action creator to have been registered");
cmd.run_system_with_input(*action_creator_sys, entity);
current_action.component = Some(*$id);
if last_action.is_none() || last_action.is_some_and(|l| l != *$id) {
current_action.last_action = last_action;
}
};
}
match (scores.scores.first(), current_action.component) {
(Some((new_action_id, _)), Some(current_action_id)) => {
if *new_action_id != current_action_id {
remove_current_action!(current_action_id);
init_and_set_new_action!(new_action_id);
}
}
(Some((new_action_id, _)), None) => {
init_and_set_new_action!(new_action_id);
}
(None, Some(current_action_id)) => {
remove_current_action!(current_action_id);
}
(None, None) => {}
};
});
}
#[derive(Event)]
pub struct ActionFinished;
fn handle_action_finished(
trigger: Trigger<ActionFinished>,
mut query: Query<&mut CurrentAction>,
mut cmd: Commands,
) {
if let Ok(mut current_action) = query.get_mut(trigger.entity()) {
if let Some(current_action_id) = current_action.component {
cmd.entity(trigger.entity()).remove_by_id(current_action_id);
current_action.component = None;
}
}
}
pub trait ObserveWithComponentLifetime {
fn observe_with_component_lifetime<C: Component, E: Event, B: Bundle, M>(
&mut self,
entity: Entity,
system: impl IntoObserverSystem<E, B, M>,
) -> &mut Self;
}
// TODO: make a version of this so multiple observers can be created that share a single cleanup observer
impl ObserveWithComponentLifetime for Commands<'_, '_> {
fn observe_with_component_lifetime<C: Component, E: Event, B: Bundle, M>(
&mut self,
entity: Entity,
system: impl IntoObserverSystem<E, B, M>,
) -> &mut Self {
// User's observer observing supplied entity
let mut user_observer = Observer::new(system);
user_observer.watch_entity(entity);
let user_observer = self
.spawn((Name::new("GalaxyBrain::User Observer"), user_observer))
.id();
// Observer that removes user's observer and itself upon component removal
let mut cleanup_observer_cmd = self.spawn_empty();
let cleanup_observer_id = cleanup_observer_cmd.id();
let mut cleanup_observer =
Observer::new(move |_: Trigger<OnRemove, C>, mut cmd: Commands| {
cmd.entity(user_observer).despawn();
cmd.entity(cleanup_observer_id).despawn();
});
cleanup_observer.watch_entity(entity);
cleanup_observer_cmd.insert((Name::new("GalaxyBrain::Cleanup Observer"), cleanup_observer));
self
}
}
// basic.rs
use bevy::{
ecs::component::{ComponentHooks, StorageType},
prelude::*,
};
use galaxy_brain::{
Action, ActionFinished, ActionPickerHighestScore, AiBundle, GalaxyBrainPlugins, GalaxyBrainSet,
RegisterAction as _, Scorer, Scores,
};
fn main() {
App::new()
.add_plugins((MinimalPlugins, GalaxyBrainPlugins::new(Update)))
.register_action::<SleepAction, _>(init_sleep_action)
.register_action::<EatAction, _>(init_eat_action)
.register_action::<CountAndPrintActionManager, _>(init_count_and_print_action)
.add_systems(Startup, setup)
.add_systems(
Update,
(
(update_sleep_action_score, update_eat_action_score)
.in_set(GalaxyBrainSet::UserScorers),
(print_scores, print_winning_action)
.chain()
.after(GalaxyBrainSet::TransitionActions)
.before(GalaxyBrainSet::UserActions),
(
sleep_action_exec,
eat_action_exec,
count_to_three_action_impl,
print_success_action_impl,
)
.in_set(GalaxyBrainSet::UserActions),
),
)
.observe(handle_count_and_print_stage_event)
.observe(on_remove_eat_action)
.set_runner(|mut app| {
app.update();
println!("--------------------------------------------------");
app.update();
println!("--------------------------------------------------");
app.update();
println!("--------------------------------------------------");
app.update();
// println!("--------------------------------------------------");
// app.update();
AppExit::Success
})
.run();
}
fn setup(mut cmd: Commands) {
cmd.spawn((
SleepActionScore { score: 0.0 },
EatActionScore { score: 0.11 },
// CountAndPrintScorer {},
AiBundle::<ActionPickerHighestScore>::default(),
));
}
#[derive(Debug)]
enum ActionStageResult {
Success,
#[allow(unused)]
Failure,
}
#[derive(Event, Debug)]
struct ActionStageEvent<T> {
action_manager: T,
#[allow(unused)]
result: ActionStageResult,
}
#[derive(Component, Debug)]
enum CountAndPrintActionManager {
CountToThreeAction,
PrintSuccessAction,
}
fn init_count_and_print_action(In(i): In<Entity>, mut cmd: Commands) {
println!("Init CountAndPrintActionManager");
cmd.entity(i).insert((
CountAndPrintActionManager::CountToThreeAction,
CountToThreeAction { count: 0 },
));
}
impl Action for CountAndPrintActionManager {
type Scorer = CountAndPrintScorer;
}
#[derive(Component)]
struct CountToThreeAction {
count: u8,
}
fn count_to_three_action_impl(
mut query: Query<(&mut CountToThreeAction, Entity)>,
mut cmd: Commands,
) {
query.iter_mut().for_each(|(mut action, entity)| {
action.count += 1;
println!("count: {}", action.count);
if action.count == 3 {
cmd.trigger_targets(
ActionStageEvent {
action_manager: CountAndPrintActionManager::CountToThreeAction,
result: ActionStageResult::Success,
},
entity,
);
}
});
}
fn handle_count_and_print_stage_event(
trigger: Trigger<ActionStageEvent<CountAndPrintActionManager>>,
mut cmd: Commands,
) {
println!("stage_event {:?}", trigger.event());
match trigger.event().action_manager {
CountAndPrintActionManager::CountToThreeAction => {
cmd.entity(trigger.entity())
.remove::<CountToThreeAction>()
.insert(PrintSuccessAction { done: false });
}
CountAndPrintActionManager::PrintSuccessAction => {
cmd.entity(trigger.entity()).remove::<PrintSuccessAction>();
cmd.trigger_targets(ActionFinished, trigger.entity());
}
}
}
#[derive(Component)]
struct PrintSuccessAction {
done: bool,
}
fn print_success_action_impl(
mut query: Query<(&mut PrintSuccessAction, Entity)>,
mut cmd: Commands,
) {
query.iter_mut().for_each(|(mut action, entity)| {
if !action.done {
println!("Success!");
action.done = true;
cmd.trigger_targets(
ActionStageEvent {
action_manager: CountAndPrintActionManager::PrintSuccessAction,
result: ActionStageResult::Success,
},
entity,
);
}
});
}
#[derive(Component)]
struct CountAndPrintScorer {}
impl Scorer for CountAndPrintScorer {
fn score(&self) -> f32 {
0.8
}
}
fn init_sleep_action(In(i): In<Entity>, mut cmd: Commands) {
println!("Init sleep action!");
cmd.entity(i).insert(SleepAction {});
}
fn init_eat_action(In(i): In<Entity>, mut cmd: Commands) {
println!("Init eat action!");
cmd.entity(i).insert(EatAction {});
}
fn sleep_action_exec(query: Query<&SleepAction>) {
for _ in query.iter() {
println!("Sleep!");
}
}
fn eat_action_exec(query: Query<&EatAction>) {
for _ in query.iter() {
println!("Eat!");
}
}
fn print_winning_action(world: &World, query: Query<&Scores>) {
// world.run_system_once_with(input, system);
for scores in query.iter() {
for (action_id, score) in scores.scores.iter().take(1) {
let info = world.components().get_name(*action_id).unwrap();
println!("{info:?}: {score}");
}
}
}
fn print_scores(scores: Query<&Scores>) {
scores.iter().for_each(|scores| {
println!("{:?}", scores.scores);
});
}
#[derive(Component)]
struct SleepAction {}
impl Action for SleepAction {
type Scorer = SleepActionScore;
}
#[derive(Component)]
struct SleepActionScore {
score: f32,
}
impl Scorer for SleepActionScore {
fn score(&self) -> f32 {
self.score
}
}
fn update_sleep_action_score(mut query: Query<&mut SleepActionScore>) {
query.iter_mut().for_each(|mut score| {
score.score += 0.05;
});
}
struct EatAction {}
impl Component for EatAction {
const STORAGE_TYPE: StorageType = StorageType::Table;
fn register_component_hooks(hooks: &mut ComponentHooks) {
hooks.on_remove(|_, _, _| {
println!("Eat Action Removed");
});
}
}
fn on_remove_eat_action(_: Trigger<OnRemove, EatAction>) {
println!("on_remove_eat_action");
}
impl Action for EatAction {
type Scorer = EatActionScore;
}
#[derive(Component)]
struct EatActionScore {
score: f32,
}
impl Scorer for EatActionScore {
fn score(&self) -> f32 {
self.score
}
}
fn update_eat_action_score(mut query: Query<&mut EatActionScore>) {
query.iter_mut().for_each(|_| {
// score.score = time.delta_seconds() * 2.0;
});
}
// cleaning_bots.rs
use std::f32::consts::PI;
use bevy::{color::palettes::tailwind::*, math::bounding::*, prelude::*};
use galaxy_brain::{
Action, ActionFinished, ActionPickerHighestScore, AiBundle, GalaxyBrainPlugins, GalaxyBrainSet,
ObserveWithComponentLifetime, RegisterAction, Scorer,
};
use rand::prelude::*;
fn main() {
let mut app = App::new();
app.add_plugins((
DefaultPlugins.set(WindowPlugin {
primary_window: Some(Window {
name: Some("Cleaning Bots".to_string()),
present_mode: bevy::window::PresentMode::AutoVsync,
// resizable: false,
// resize_constraints: WindowResizeConstraints {
// max_width: 1280.0,
// max_height: 720.0,
// ..default()
// },
..default()
}),
..default()
}),
GalaxyBrainPlugins::new(FixedUpdate),
));
app.register_action::<CleanActionManager, _>(init_clean_action);
app.register_action::<SearchActionManager, _>(init_search_action);
app.register_action::<ReturnAndChargeActionManager, _>(init_return_and_charge_action);
app.add_systems(Startup, (setup_scene, setup_ui));
app.add_systems(Update, (handle_ui, rotate_sweepers, render_scene).chain());
app.add_systems(FixedUpdate, populate_dirt);
app.add_systems(
FixedUpdate,
(
(
score_clean_action,
score_return_and_charge_action,
score_search_action,
)
.in_set(GalaxyBrainSet::UserScorers),
(
suck_dirt_action_exec,
movement_action_exec,
charge_action_exec,
)
.in_set(GalaxyBrainSet::UserActions),
// print_winning_action
// .after(GalaxyBrainSet::TransitionActions)
// .before(GalaxyBrainSet::UserActions),
),
);
app.run();
}
fn _print_winning_action(world: &World, query: Query<&galaxy_brain::Scores>) {
// world.run_system_once_with(input, system);
for scores in query.iter() {
for (action_id, score) in scores.scores.iter() {
let name = world.components().get_name(*action_id).unwrap();
let name = name.strip_prefix("cleaning_bots::").unwrap();
print!("{name} {score} | ");
}
println!();
}
}
// BEGIN AI ------------------------------------------------
// AI Actions
// Clean Action Manager
#[derive(Component)]
struct CleanActionManager {
dirt: Entity,
}
impl Action for CleanActionManager {
type Scorer = CleanScorer;
}
#[derive(Component)]
struct CleanScorer {
dirt: Option<Entity>,
}
impl Scorer for CleanScorer {
fn score(&self) -> f32 {
if self.dirt.is_some() {
0.5
} else {
0.0
}
}
}
#[derive(Component)]
struct SuckDirtAction {
dirt: Entity,
}
#[derive(Event)]
struct SuckDirtActionComplete;
fn score_clean_action(
mut query: Query<(&mut CleanScorer, &Transform, Option<&CleanActionManager>)>,
dirt: Query<(&Transform, &Dirt, Entity)>,
) {
for (mut scorer, bot_tns, cleaning_action) in query.iter_mut() {
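// While a clean action is running, keep scoring its target so the bot does
// not retarget mid-action. (Previously this value was overwritten right away.)
if let Some(cleaning_action) = cleaning_action {
scorer.dirt = Some(cleaning_action.dirt);
continue;
}
// Otherwise forget any stale target before scanning for visible dirt.
scorer.dirt = None;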
let bounding_circle = {
let (search_cone, translation, rotation) = bot_vision(bot_tns);
search_cone.bounding_circle(translation, rotation)
};
for (dirt_tns, dirt, dirt_entity) in dirt.iter() {
if bounding_circle
.contains(&BoundingCircle::new(dirt_tns.translation.xy(), dirt.amount))
{
scorer.dirt = Some(dirt_entity);
break;
}
}
}
}
fn init_clean_action(
In(entity): In<Entity>,
bot: Query<&CleanScorer>,
dirt: Query<&Dirt>,
mut cmd: Commands,
) {
let dirt_entity = bot.get(entity).unwrap().dirt.unwrap();
let dirt = dirt.get(dirt_entity).unwrap();
// MovementAction finished
cmd.observe_with_component_lifetime::<CleanActionManager, _, _, _>(
entity,
move |trigger: Trigger<MovementActionOutcome>, mut cmd: Commands| match trigger.event() {
MovementActionOutcome::Success => {
cmd.entity(trigger.entity())
.remove::<MovementAction>()
.insert(SuckDirtAction { dirt: dirt_entity });
}
MovementActionOutcome::Fail => {
cmd.entity(trigger.entity()).remove::<MovementAction>();
cmd.trigger_targets(ActionFinished, trigger.entity());
}
},
);
// SuckDirtAction finished
cmd.observe_with_component_lifetime::<CleanActionManager, _, _, _>(
entity,
|trigger: Trigger<SuckDirtActionComplete>, mut cmd: Commands| {
cmd.entity(trigger.entity()).remove::<SuckDirtAction>();
cmd.trigger_targets(ActionFinished, trigger.entity());
},
);
// Handle the dirt being despawned out from under us
cmd.entity(dirt_entity)
.observe(move |_: Trigger<OnRemove, Dirt>, mut cmd: Commands| {
cmd.entity(entity)
.remove::<(MovementAction, SuckDirtAction)>();
cmd.trigger_targets(ActionFinished, entity);
});
cmd.entity(entity).insert((
CleanActionManager { dirt: dirt_entity },
MovementAction::Entity(dirt_entity, dirt.size),
));
}
fn suck_dirt_action_exec(
query: Query<(&SuckDirtAction, Entity)>,
mut dirt: Query<(&mut Dirt, Entity)>,
time: Res<Time>,
mut cmd: Commands,
) {
for (action, entity) in query.iter() {
let mut done = false;
if let Ok((mut dirt, dirt_entity)) = dirt.get_mut(action.dirt) {
dirt.amount -= time.delta_seconds();
if dirt.amount <= 0.0 {
cmd.entity(dirt_entity).despawn();
done = true;
}
} else {
done = true;
}
if done {
cmd.trigger_targets(SuckDirtActionComplete, entity);
}
}
}
// Search Action Manager
#[derive(Component)]
struct SearchActionManager;
impl Action for SearchActionManager {
type Scorer = SearchScorer;
}
const SEARCH_POSITION_RADIUS: f32 = 150.0;
#[derive(Component)]
struct SearchScorer {
score: f32,
}
impl Scorer for SearchScorer {
fn score(&self) -> f32 {
self.score
}
}
fn score_search_action(mut query: Query<(&mut SearchScorer, &CleaningBot)>) {
for (mut scorer, bot) in query.iter_mut() {
scorer.score = bot.charge * 0.5;
}
}
fn init_search_action(
In(entity): In<Entity>,
bot: Query<&Transform>,
arena: Query<&Arena>,
mut cmd: Commands,
) {
let bot_pos = bot.get(entity).unwrap().translation;
let arena = arena.single();
cmd.observe_with_component_lifetime::<SearchActionManager, _, _, _>(
entity,
|trigger: Trigger<MovementActionOutcome>, mut cmd: Commands| match trigger.event() {
MovementActionOutcome::Success | MovementActionOutcome::Fail => {
cmd.entity(trigger.entity()).remove::<MovementAction>();
cmd.trigger_targets(ActionFinished, trigger.entity());
}
},
);
let mut rng = rand::thread_rng();
let search_target = loop {
let search_target =
Circle::new(SEARCH_POSITION_RADIUS).sample_interior(&mut rng) + bot_pos.xy();
if arena.size.contains(search_target) {
break search_target;
}
};
cmd.entity(entity)
.insert((SearchActionManager, MovementAction::Position(search_target)));
}
// Return And Charge Action Manager
#[derive(Component)]
struct ReturnAndChargeActionManager;
impl Action for ReturnAndChargeActionManager {
type Scorer = ReturnAndChargeScorer;
}
#[derive(Component)]
struct ReturnAndChargeScorer {
score: f32,
}
impl Scorer for ReturnAndChargeScorer {
fn score(&self) -> f32 {
self.score
}
}
fn score_return_and_charge_action(
mut query: Query<(
&mut ReturnAndChargeScorer,
Has<ReturnAndChargeActionManager>,
&CleaningBot,
)>,
) {
for (mut scorer, is_charging, bot) in query.iter_mut() {
if is_charging {
scorer.score = 1.0;
} else {
scorer.score = if bot.charge < 0.2 { 1.0 } else { 0.0 }
}
}
}
fn init_return_and_charge_action(
In(entity): In<Entity>,
bot: Query<&CleaningBot>,
mut cmd: Commands,
) {
let bot = bot.get(entity).unwrap();
cmd.observe_with_component_lifetime::<ReturnAndChargeActionManager, _, _, _>(
entity,
|trigger: Trigger<MovementActionOutcome>, mut cmd: Commands| match trigger.event() {
MovementActionOutcome::Success | MovementActionOutcome::Fail => {
cmd.entity(trigger.entity())
.remove::<MovementAction>()
.insert(ChargeAction);
}
},
);
cmd.observe_with_component_lifetime::<ReturnAndChargeActionManager, _, _, _>(
entity,
|trigger: Trigger<ChargeActionComplete>, mut cmd: Commands| {
cmd.entity(trigger.entity()).remove::<ChargeAction>();
cmd.trigger_targets(ActionFinished, trigger.entity());
},
);
cmd.entity(entity).insert((
ReturnAndChargeActionManager,
MovementAction::Position(bot.charge_location),
));
}
// Movement Action
#[derive(Component)]
enum MovementAction {
Position(Vec2),
Entity(Entity, f32),
}
#[derive(Event)]
enum MovementActionOutcome {
Success,
Fail,
}
const BOT_SPEED: f32 = 30.0;
const BOT_ROTATE_SPEED: f32 = 1.0;
const BOT_MOVE_CHARGE_DRAIN: f32 = 1.0 / 30.0;
fn movement_action_exec(
mut query: Query<(&MovementAction, &mut CleaningBot, Entity)>,
mut tns: Query<&mut Transform>,
time: Res<Time>,
mut cmd: Commands,
) {
for (movement, mut bot, entity) in query.iter_mut() {
bot.charge = (bot.charge - BOT_MOVE_CHARGE_DRAIN * time.delta_seconds()).max(0.0);
let (target_pos, buffer) = match movement {
MovementAction::Position(pos) => (pos.extend(0.0), 0.0),
MovementAction::Entity(target, buffer) => {
let Ok(tns) = tns.get(*target) else {
cmd.trigger_targets(MovementActionOutcome::Fail, entity);
continue;
};
(tns.translation, *buffer + bot.radius)
}
};
let Ok(mut tns) = tns.get_mut(entity) else {
cmd.trigger_targets(MovementActionOutcome::Fail, entity);
continue;
};
let current_rot = Rot2::radians(tns.rotation.to_euler(EulerRot::YXZ).0);
let target_rot = {
let dir = target_pos - tns.translation;
Rot2::radians(dir.y.atan2(dir.x) - PI / 2.0)
};
let angle_between = current_rot.angle_between(target_rot);
if angle_between != 0.0 {
let max_rotate = BOT_ROTATE_SPEED * time.delta_seconds();
tns.rotate_y(angle_between.clamp(-max_rotate, max_rotate));
}
if angle_between.to_degrees().abs() < 5.0 {
tns.translation = tns.translation
+ (target_pos - tns.translation).clamp_length_max(BOT_SPEED * time.delta_seconds());
}
if tns.translation.distance(target_pos) <= buffer {
cmd.trigger_targets(MovementActionOutcome::Success, entity);
}
}
}
// Charge Action
#[derive(Component)]
struct ChargeAction;
#[derive(Event)]
struct ChargeActionComplete;
const CHARGE_TIME: f32 = 4.0;
fn charge_action_exec(
mut query: Query<(&mut CleaningBot, Entity), With<ChargeAction>>,
time: Res<Time>,
mut cmd: Commands,
) {
for (mut bot, entity) in query.iter_mut() {
bot.charge = (bot.charge + (1.0 / CHARGE_TIME) * time.delta_seconds()).clamp(0.0, 1.0);
if bot.charge >= 1.0 {
cmd.trigger_targets(ChargeActionComplete, entity);
}
}
}
// END AI --------------------------------------------------
// Components
#[derive(Component)]
struct CleaningBot {
charge: f32,
charge_location: Vec2,
radius: f32,
color: Color,
}
#[derive(Component)]
struct SweeperRotation {
rot: f32,
}
#[derive(Component)]
struct ChargeStation {
radius: f32,
color: Color,
}
#[derive(Component)]
struct Dirt {
size: f32,
amount: f32,
color: Color,
}
#[derive(Component)]
struct Arena {
size: Rect,
color: Color,
}
#[derive(Component)]
enum TimeUiThing {
Slower,
Faster,
PlayPause,
Speed,
}
// Setup
fn setup_scene(window: Query<&Window>, mut cmd: Commands) {
let window = window.single();
let window_size = window.size();
let scale = 0.7;
let mut cam_bundle = Camera2dBundle::default();
cam_bundle.projection.scaling_mode =
bevy::render::camera::ScalingMode::FixedHorizontal(window_size.x);
cam_bundle.projection.scale = scale;
cmd.spawn(cam_bundle);
cmd.spawn(Arena {
size: Rect::from_center_size(Vec2::ZERO, window_size * scale),
color: YELLOW_700.into(),
});
for (pos, charge) in [
(Vec2::new(-150.0, 0.0), 1.0),
(Vec2::new(0.0, 0.0), 0.5),
(Vec2::new(150.0, 0.0), 0.0),
] {
cmd.spawn((
ChargeStation {
radius: 20.0,
color: GREEN_400.into(),
},
SpatialBundle::from_transform(Transform::from_translation(pos.extend(0.0))),
));
cmd.spawn((
CleaningBot {
charge,
charge_location: pos,
radius: 15.0,
color: RED_500.into(),
},
SweeperRotation { rot: 0.0 },
CleanScorer { dirt: None },
SearchScorer { score: 0.0 },
ReturnAndChargeScorer { score: 0.0 },
AiBundle::<ActionPickerHighestScore>::default(),
SpatialBundle::from_transform(Transform::from_translation(pos.extend(0.0))),
));
}
}
fn setup_ui(mut cmd: Commands) {
cmd.spawn(NodeBundle {
style: Style {
position_type: PositionType::Absolute,
top: Val::Px(0.),
width: Val::Percent(100.),
justify_content: JustifyContent::Center,
..default()
},
..default()
})
.with_children(|p| {
macro_rules! ui_button {
($ui_thing: expr, $text: expr) => {
p.spawn((
$ui_thing,
ButtonBundle {
style: Style {
width: Val::Px(75.),
height: Val::Px(65.),
justify_content: JustifyContent::Center,
align_items: AlignItems::Center,
..default()
},
..default()
},
))
.with_children(|p| {
p.spawn(TextBundle::from_section(
$text,
TextStyle {
font_size: 20.0,
color: Color::srgb(0.9, 0.9, 0.9),
..default()
},
));
});
};
}
ui_button!(TimeUiThing::PlayPause, "Pause");
ui_button!(TimeUiThing::Speed, "1.0");
ui_button!(TimeUiThing::Slower, "Slower");
ui_button!(TimeUiThing::Faster, "Faster");
});
}
// Rendering
fn render_scene(
arena: Query<&Arena>,
charge_stations: Query<(&Transform, &ChargeStation)>,
dirt: Query<(&Transform, &Dirt)>,
cleaning_bots: Query<(
&Transform,
&CleaningBot,
&SweeperRotation,
Option<&MovementAction>,
)>,
transform_q: Query<&Transform>,
mut gizmos: Gizmos,
) {
let arena = arena.single();
gizmos.rect_2d(
arena.size.center(),
Rot2::IDENTITY,
arena.size.size() - Vec2::splat(2.0),
arena.color,
);
for (tns, station) in charge_stations.iter() {
gizmos.rect_2d(
tns.translation.xy(),
0.0,
Vec2::splat(station.radius * 2.0),
station.color,
);
}
for (tns, dirt) in dirt.iter() {
gizmos.circle_2d(tns.translation.xy(), dirt.amount, dirt.color);
}
for (tns, bot, sweeper_rotation, movement) in cleaning_bots.iter() {
let bot_pos = tns.translation.xy();
let bot_rot = tns.rotation.to_euler(EulerRot::YXZ).0;
// Body
gizmos.circle_2d(bot_pos, bot.radius, bot.color);
// Eyes
for eye_angle_offset in [-0.4, 0.4] {
gizmos.circle_2d(
bot_pos + Rot2::radians(bot_rot + eye_angle_offset) * Vec2::new(0.0, 8.0),
2.0,
Srgba::WHITE,
);
}
// Sweepers
for angle in
// [0., PI / 3., 2. * PI / 3., PI, 4. * PI / 3., 5. * PI / 3.]
[0.0, PI / 2.0, PI, 3.0 * PI / 2.0]
{
const SWEEPER_ROTATE_SPEED: f32 = 2.0;
let sweeper_len = 4.0;
let sweeper_left = bot_pos + Rot2::radians(bot_rot + PI / 10.0) * Vec2::new(0.0, 15.0);
let sweeper_right = bot_pos + Rot2::radians(bot_rot - PI / 10.0) * Vec2::new(0.0, 15.0);
let sweeper_rotation = sweeper_rotation.rot * SWEEPER_ROTATE_SPEED;
let left_angle = Rot2::radians(bot_rot + angle - sweeper_rotation);
let right_angle = Rot2::radians(bot_rot + angle + sweeper_rotation);
let wtf_offset_1 = 0.45;
let wtf_offset_2 = 0.2;
let lar = Rot2::radians(bot_rot).angle_between(left_angle);
if (lar - wtf_offset_1) < (PI / 2.0) && (lar - wtf_offset_2) > (-PI / 2.0) {
gizmos.line_2d(
sweeper_left,
sweeper_left + left_angle * Vec2::new(0.0, sweeper_len),
NEUTRAL_300,
);
}
let rawr = Rot2::radians(bot_rot).angle_between(right_angle);
if (rawr + wtf_offset_2) < PI / 2.0 && (rawr + wtf_offset_1) > -PI / 2.0 {
gizmos.line_2d(
sweeper_right,
sweeper_right + right_angle * Vec2::new(0.0, sweeper_len),
NEUTRAL_300,
);
}
}
let battery_y = 35.0;
let battery_size = Vec2::new(30.0, 12.0);
// Battery indicator
gizmos.rect_2d(
bot_pos + Vec2::new(0.0, battery_y),
0.0,
battery_size,
NEUTRAL_500,
);
let battery_right = bot_pos + Vec2::new(battery_size.x / 2.0, battery_y);
gizmos.line_2d(
battery_right + Vec2::new(2.0, 4.0),
battery_right + Vec2::new(2.0, -4.0),
NEUTRAL_500,
);
let left = -13.0;
let right = 15.0;
let battery_line_color = Hsla::from(RED_500).mix(&Hsla::from(GREEN_500), bot.charge);
for i in 0..10 {
let i = i as f32 / 10.0;
let x = left.lerp(right, i);
if i > bot.charge {
continue;
}
gizmos.line_2d(
bot_pos + Vec2::new(x, battery_y - 5.0),
bot_pos + Vec2::new(x, battery_y + 5.0),
battery_line_color,
);
}
if let Some(movement) = movement {
// Vision cone
let (primitive, position, angle) = bot_vision(tns);
gizmos.primitive_2d(&primitive, position, angle, EMERALD_300.with_alpha(0.05));
// Target indicator
if let MovementAction::Entity(target, buffer) = movement {
// The target may already be despawned, so skip the arrow instead of panicking.
if let Ok(target_tns) = transform_q.get(*target) {
let target_pos = target_tns.translation.xy();
let distance = bot_pos.distance(target_pos);
let target_pos = bot_pos.lerp(target_pos, (distance - *buffer) / distance);
gizmos.arrow_2d(bot_pos, target_pos, BLUE_400.with_alpha(0.1));
}
}
}
}
}
fn rotate_sweepers(mut query: Query<&mut SweeperRotation, With<SuckDirtAction>>, time: Res<Time>) {
for mut sweeper in query.iter_mut() {
sweeper.rot += time.delta_seconds();
}
}
// UI
const SIM_SPEEDS: &[f32] = &[0.0, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0];
fn handle_ui(
buttons: Query<(Option<&Interaction>, &TimeUiThing, &Children), Changed<Interaction>>,
time_holder: Query<(&TimeUiThing, &Children)>,
mut text: Query<&mut Text>,
mut time: ResMut<Time<Virtual>>,
) {
macro_rules! update_speed_text {
() => {
let speed_entity = time_holder
.iter()
.find_map(|(t, c)| {
if matches!(t, TimeUiThing::Speed) {
Some(c[0])
} else {
None
}
})
.unwrap();
let mut text = text.get_mut(speed_entity).unwrap();
text.sections[0].value = format!("{:.1}", time.relative_speed());
};
}
for (interaction, ui_thing, children) in buttons.iter() {
if !matches!(interaction, Some(Interaction::Pressed)) {
continue;
}
match ui_thing {
TimeUiThing::Slower => {
if let Some((idx, _)) = SIM_SPEEDS
.iter()
.enumerate()
.find(|spd| *spd.1 == time.relative_speed())
{
time.set_relative_speed(SIM_SPEEDS[idx.saturating_sub(1)]);
} else {
time.set_relative_speed(1.0);
}
update_speed_text!();
}
TimeUiThing::Faster => {
if let Some((idx, _)) = SIM_SPEEDS
.iter()
.enumerate()
.find(|spd| *spd.1 == time.relative_speed())
{
time.set_relative_speed(SIM_SPEEDS[(idx + 1).min(SIM_SPEEDS.len() - 1)]);
} else {
time.set_relative_speed(1.0);
}
update_speed_text!();
}
TimeUiThing::PlayPause => {
let mut text = text.get_mut(children[0]).unwrap();
if time.is_paused() {
time.unpause();
text.sections[0].value = "Pause".to_string();
} else {
time.pause();
text.sections[0].value = "Play".to_string();
}
}
_ => {}
}
}
}
// Dirt
fn populate_dirt(
dirt: Query<(), With<Dirt>>,
arena: Query<&Arena>,
mut cmd: Commands,
time: Res<Time>,
) {
let arena = arena.single();
const MAX_DIRT: usize = 30;
let chance_to_spawn = (1.0
- (dirt.iter().len().clamp(0, MAX_DIRT) as f32 * (1.0 / MAX_DIRT as f32)))
* 10.0
* time.delta_seconds();
let mut rng = rand::thread_rng();
if chance_to_spawn > rng.gen::<f32>() {
let amount = (rng.gen::<f32>() + 1.0) * 2.5;
cmd.spawn((
Dirt {
size: amount,
amount,
color: AMBER_300.into(),
},
SpatialBundle::from_transform(Transform::from_xyz(
(rng.gen::<f32>() - 0.5) * (arena.size.width() - 20.0),
(rng.gen::<f32>() - 0.5) * (arena.size.height() - 20.0),
0.0,
)),
));
}
}
fn bot_vision(tns: &Transform) -> (CircularSector, Vec2, f32) {
(
CircularSector::from_radians(100.0, PI / 2.0),
tns.translation.xy(),
tns.rotation.to_euler(EulerRot::YXZ).0,
)
}
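
The `Slower` and `Faster` arms of `handle_ui` in cleaning_bots.rs repeat the same lookup in `SIM_SPEEDS`. A minimal sketch of a shared helper, assuming the same speed table. `step_speed` is a hypothetical name that is not part of the gist, and the fallback to `1.0` mirrors the existing `else` branches.

```rust
/// Same table as `SIM_SPEEDS` in cleaning_bots.rs.
const SIM_SPEEDS: &[f32] = &[0.0, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0];

/// Hypothetical helper (not in the gist): returns the table entry `step`
/// positions away from `current`, clamped to the ends of the table.
/// Falls back to 1.0 when `current` is not a table entry, as the original
/// `else` branches do.
fn step_speed(current: f32, step: isize) -> f32 {
    match SIM_SPEEDS.iter().position(|&s| s == current) {
        Some(idx) => {
            let last = SIM_SPEEDS.len() as isize - 1;
            let next = (idx as isize + step).clamp(0, last);
            SIM_SPEEDS[next as usize]
        }
        None => 1.0,
    }
}
```

With a helper like this, the two arms could reduce to `time.set_relative_speed(step_speed(time.relative_speed(), -1))` and the `+1` equivalent. The exact float comparison is only reasonable because the speed is always set from the table itself.
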
use galaxy_brain::prelude::*;
use plot_gizmos::GizmoLineGraph;
use crate::{ai, foliage, pathing::PathingGrid, prelude::*, quest};
use super::{
ai_debug::{ReflectScorerDebugIntoGraph, ScorerDebugIntoGraph},
movement,
};
pub(super) fn chop_tree_plugin(app: &mut App) {
app.register_type::<ChopTreeActionManager>();
app.register_type::<ChopTreeActionScorer>();
app.register_action::<ChopTreeActionManager, _>(init_chop_tree_action_manager);
app.add_systems(
FixedLogicUpdate,
(
chop_tree_action_scorer.in_set(GalaxyBrainSet::UserScorers),
chop_tree_in_progress_action.in_set(GalaxyBrainSet::UserActions),
),
);
}
#[derive(Reflect, Debug)]
pub(super) struct ChopTreeActionManager {
tree: Entity,
}
impl Component for ChopTreeActionManager {
const STORAGE_TYPE: StorageType = StorageType::Table;
fn register_component_hooks(hooks: &mut ComponentHooks) {
hooks.on_remove(|mut w, e, _| {
w.commands()
.entity(e)
.remove::<(movement::MoveAlongPathAction, ChopTreeInProgressAction)>();
});
}
}
impl Action for ChopTreeActionManager {
type Scorer = ChopTreeActionScorer;
}
#[derive(Component, Reflect, Default, Debug)]
#[reflect(ScorerDebugIntoGraph)]
pub(super) struct ChopTreeActionScorer {
pub(super) quest_requirements: Option<(quest::QuestId, /* cost? */ u32)>,
pub(super) tree: Option<(Entity, /* distance */ f32)>,
}
impl ChopTreeActionScorer {
fn quest_requirement(&self) -> f32 {
self.quest_requirements.is_some().ifel(1., 0.)
}
fn tree_dist(&self) -> f32 {
let tree_max_range = 20.;
let Some((_, distance)) = self.tree else {
return 0.;
};
ai::logistic(distance / tree_max_range, 50., -0.8, 1., 0.5)
}
}
impl Scorer for ChopTreeActionScorer {
fn score(&self) -> f32 {
self.quest_requirement() * self.tree_dist()
}
}
impl ScorerDebugIntoGraph for ChopTreeActionScorer {
fn write_scores(&self, x: f32, graph: &mut GizmoLineGraph) {
if graph.lines.is_empty() {
graph.lines.push(plot_gizmos::GizmoLine {
name: "tree_dist".to_string(),
points: Vec::new(),
color: Srgba::RED.into(),
});
}
graph.lines[0].points.push(Vec2::new(x, self.tree_dist()));
}
}
fn chop_tree_action_scorer(
mut query: Query<(
&mut ChopTreeActionScorer,
Entity,
&Transform,
Option<&super::PawnQuestState2>,
)>,
quest: Query<&quest::Quest>,
trees: Query<(&Transform, Entity), With<foliage::Tree>>, // building_manager: Res<building::BuildingManager>,
) {
// TODO: do we need to score trees even after starting the chop_tree action ?
query
.iter_mut()
.for_each(|(mut scorer, _, pawn, quest_id)| {
let Some(quest_id) = quest_id else {
scorer.quest_requirements = None;
scorer.tree = None;
return;
};
let Some(quest) = quest.iter().find(|q| q.id == quest_id.quest) else {
scorer.quest_requirements = None;
scorer.tree = None;
return;
};
let need_wood = quest
.goal
.iter()
.any(|item| item.kind == crate::storage::ItemKind::Wood);
if !need_wood {
scorer.quest_requirements = None;
scorer.tree = None;
return;
}
scorer.quest_requirements = Some((quest.id, 1));
scorer.tree = trees
.iter()
.reduce(|acc, tree| {
if tree.0.translation.distance_squared(pawn.translation)
< acc.0.translation.distance_squared(pawn.translation)
{
tree
} else {
acc
}
})
.map(|(t, e)| (e, t.translation.distance(pawn.translation) / GRID_SIZE));
});
}
fn init_chop_tree_action_manager(
In(pawn): In<Entity>,
scorer: Query<&ChopTreeActionScorer>,
tns: Query<&Transform>,
pathing: Res<PathingGrid>,
mut cmd: Commands,
) {
let start_pos = get!(tns, pawn).translation.to_grid_space();
let score = get_logging!(scorer, pawn);
let Some((tree, _)) = score.tree else {
error!("chop tree action picked with no tree selected");
return;
};
let tree_pos = get!(tns, tree).translation.to_grid_space();
cmd.entity(pawn).insert((
ChopTreeActionManager { tree },
movement::MovementPathAsyncBundle::new(start_pos, tree_pos, pathing.clone_for_async_task()),
));
// Movement finished
cmd.observe_with_component_lifetime::<movement::MoveAlongPathAction, _, _, _>(
pawn,
move |trigger: Trigger<movement::MovementActionOutcome>, mut cmd: Commands| {
cmd.entity(trigger.entity())
.remove::<movement::MoveAlongPathAction>()
.insert(ChopTreeInProgressAction { tree });
cmd.observe_with_component_lifetime::<ChopTreeInProgressAction, _, _, _>(
trigger.entity(),
|trigger: Trigger<ChopTreeInProgressActionOutcome>, mut cmd: Commands| {
cmd.entity(trigger.entity())
.remove::<ChopTreeInProgressAction>();
cmd.trigger_targets(ActionFinished, trigger.entity());
},
);
},
);
// TODO: remove this observer on action finished
// Handle the tree being despawned/finished by someone else
cmd.entity(tree).observe_named(
"Pawn::HandleChopTreeInProgressRemoved Observer",
move |_: Trigger<OnRemove, foliage::Tree>, mut cmd: Commands| {
cmd.trigger_targets(ActionFinished, pawn);
},
);
}
#[derive(Component, Reflect, Debug)]
pub(super) struct ChopTreeInProgressAction {
tree: Entity,
}
#[derive(Event)]
pub(super) enum ChopTreeInProgressActionOutcome {
Success,
}
fn chop_tree_in_progress_action(
mut query: Query<(&ChopTreeInProgressAction, Entity)>,
mut trees: Query<&mut foliage::Tree>,
mut cmd: Commands,
time: Res<Time>,
) {
let chop_speed = 0.5; // use something from the character here ?
query.iter_mut().for_each(|(action, entity)| {
let Ok(mut tree) = trees.get_mut(action.tree) else {
return;
};
tree.health -= chop_speed * time.delta_seconds();
if tree.health <= 0. {
cmd.trigger_targets(
foliage::TreeDestroyedEvent {
chopper: Some(entity),
},
action.tree,
);
cmd.trigger_targets(ChopTreeInProgressActionOutcome::Success, entity);
}
})
}
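
`ChopTreeActionScorer` multiplies a hard gate (does the current quest need wood?) by a distance falloff for the nearest tree. The std-only sketch below shows that shape without Bevy. `logistic_falloff`, its parameters, and the tuple-based positions are assumptions for illustration. It is not the gist's `ai::logistic`, whose definition does not appear in this listing.

```rust
/// Hypothetical response curve: near 1.0 for small `x`, near 0.0 for large `x`.
/// This is NOT the gist's `ai::logistic`; that function is not shown in the source.
fn logistic_falloff(x: f32, steepness: f32, midpoint: f32) -> f32 {
    1.0 / (1.0 + (steepness * (x - midpoint)).exp())
}

/// Picks the tree closest to `pawn` and returns its id and distance.
/// Squared distances are compared so the square root is taken only once.
fn nearest_tree(pawn: (f32, f32), trees: &[(u32, (f32, f32))]) -> Option<(u32, f32)> {
    let dist_sq = |p: (f32, f32)| (p.0 - pawn.0).powi(2) + (p.1 - pawn.1).powi(2);
    trees
        .iter()
        .min_by(|a, b| dist_sq(a.1).total_cmp(&dist_sq(b.1)))
        .map(|&(id, pos)| (id, dist_sq(pos).sqrt()))
}

/// Gate multiplied by falloff, mirroring `quest_requirement() * tree_dist()`.
fn chop_score(needs_wood: bool, tree: Option<(u32, f32)>, max_range: f32) -> f32 {
    let gate = if needs_wood { 1.0 } else { 0.0 };
    let falloff = match tree {
        // Steepness 10.0 and midpoint 0.5 are illustrative values only.
        Some((_, distance)) => logistic_falloff(distance / max_range, 10.0, 0.5),
        None => 0.0,
    };
    gate * falloff
}
```

Because the gate is a multiplier, an agent without a wood quest scores exactly zero no matter how close a tree is, so the action can never win the pick on distance alone.
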
use bevy::ecs::component::ComponentId;
use galaxy_brain::{GalaxyBrainConfig, GalaxyBrainSet, Scores};
use plot_gizmos::GizmoLineGraph;
use crate::prelude::*;
pub(super) fn ai_debug_plugin(app: &mut App) {
app.add_systems(
FixedLogicUpdate,
(ra, render_utility_scores_debug_graph)
.chain()
.after(GalaxyBrainSet::ActionPicking),
);
}
#[derive(Component)]
pub(crate) struct UtilityScoresDebugGraph {
pub(crate) visible: bool,
graph: GizmoLineGraph,
action_graphs: HashMap<ComponentId, GizmoLineGraph>,
offset: Vec2,
}
impl UtilityScoresDebugGraph {
pub(crate) fn new(size: Vec2, offset: Vec2, visible: bool) -> Self {
Self {
visible,
graph: GizmoLineGraph {
lines: Vec::new(),
bounds: Rect::from_center_size(offset, size),
},
action_graphs: HashMap::default(),
offset,
}
}
}
fn ra(world: &mut World, mut tick: Local<u32>) {
world.resource_scope(|world, bc: Mut<GalaxyBrainConfig>| {
// let tids = bc.scorers.iter().map(|x| x.1).collect::<Vec<_>>();
let mut dbg_entities_q = world.query_filtered::<Entity, With<UtilityScoresDebugGraph>>();
let dbg_entities = dbg_entities_q.iter(world).collect::<Vec<_>>();
*tick += 1;
let tick = *tick as f32;
for e in dbg_entities {
for action_info in bc.registered_actions_info.iter() {
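// NOTE: debug-only hack. `world2` aliases `world` mutably while `cmp` still
// borrows from it, which is not sound under Rust's aliasing rules.
let world2 = unsafe { &mut *(world as *mut World) };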
let Ok(cmp) = get_reflect(world, e, action_info.scorer_type_id) else {
continue;
};
let tr = world.resource::<AppTypeRegistry>().read();
let Some(reflect_scorer_debug) =
tr.get_type_data::<ReflectScorerDebugIntoGraph>(cmp.type_id())
else {
continue;
};
let scorer_debug: &dyn ScorerDebugIntoGraph =
reflect_scorer_debug.get(cmp).unwrap();
if let Some(mut graph) = world2.get_mut::<UtilityScoresDebugGraph>(e) {
let max_data_points = 100;
let bounds = graph.graph.bounds;
let graph = graph
.action_graphs
.entry(action_info.scorer_id)
.or_insert_with(|| GizmoLineGraph {
bounds,
lines: Vec::new(),
});
// Prepare graph
// let graph = &mut graph.graph;
for line in graph.lines.iter_mut() {
while line.points.len() >= max_data_points {
line.points.remove(0);
}
}
scorer_debug.write_scores(tick, graph);
}
}
}
});
#[inline]
fn get_reflect(
world: &World,
entity: Entity,
type_id: std::any::TypeId,
) -> Result<&dyn Reflect, GetComponentReflectError> {
let Some(component_id) = world.components().get_id(type_id) else {
return Err(GetComponentReflectError::NoCorrespondingComponentId(
type_id,
));
};
let Some(comp_ptr) = world.get_by_id(entity, component_id) else {
let component_name = world
.components()
.get_name(component_id)
.map(ToString::to_string);
return Err(GetComponentReflectError::EntityDoesNotHaveComponent {
entity,
type_id,
component_id,
component_name,
});
};
let Some(type_registry) = world
.get_resource::<AppTypeRegistry>()
.map(|atr| atr.read())
else {
return Err(GetComponentReflectError::MissingAppTypeRegistry);
};
let Some(reflect_from_ptr) =
type_registry.get_type_data::<bevy::reflect::ReflectFromPtr>(type_id)
else {
return Err(GetComponentReflectError::MissingReflectFromPtrTypeData(
type_id,
));
};
// SAFETY:
// - `comp_ptr` is guaranteed to point to an object of type `type_id`
// - `reflect_from_ptr` was constructed for type `type_id`
// - Assertion that checks this equality is present
unsafe {
assert_eq!(
reflect_from_ptr.type_id(),
type_id,
"Mismatch between Ptr's type_id and ReflectFromPtr's type_id",
);
Ok(reflect_from_ptr.as_reflect(comp_ptr))
}
}
#[allow(unused)]
#[derive(Debug)]
pub enum GetComponentReflectError {
NoCorrespondingComponentId(std::any::TypeId),
EntityDoesNotHaveComponent {
/// The given [`Entity`].
entity: Entity,
/// The given [`TypeId`].
type_id: std::any::TypeId,
/// The [`ComponentId`] corresponding to the given [`TypeId`].
component_id: ComponentId,
/// The name corresponding to the [`Component`] with the given [`TypeId`], or `None`
/// if not available.
component_name: Option<String>,
},
MissingAppTypeRegistry,
MissingReflectFromPtrTypeData(std::any::TypeId),
}
}
#[derive(Component)]
struct AiDbgImmediateShittyText;
fn render_utility_scores_debug_graph(
mut query: Query<(
&mut UtilityScoresDebugGraph,
&Scores,
// &quest_ai::AcquireQuestScorer
)>,
shitty_ai_text: Query<Entity, With<AiDbgImmediateShittyText>>,
mut gizmos: Gizmos,
mut tick: Local<u32>,
camera: Query<(&Transform, &OrthographicProjection), With<MainCamera>>,
ai_config: Res<GalaxyBrainConfig>,
// type_registry: Res<AppTypeRegistry>,
mut cmd: Commands,
) {
let (camera, projection) = get_single!(camera);
let line_colors = [
"#e60049", "#0bb4ff", "#50e991", "#e6d800", "#9b19f5", "#ffa300", "#dc0ab4", "#b3d4ff",
"#00bfa0",
]
.map(Srgba::hex)
.map(Result::unwrap)
.map(Color::from);
// ["#e60049", "#0bb4ff", "#50e991", "#e6d800", "#9b19f5", "#ffa300", "#dc0ab4", "#b3d4ff", "#00bfa0"]
shitty_ai_text.iter().for_each(|entity| {
cmd.entity(entity).despawn();
});
*tick += 1;
let tick = *tick;
let max_data_points = 100;
let projection = Vec2::splat(projection.scale);
let cam_xy = camera.translation.xy() * Vec2::splat(1. / projection.x);
for (
mut graph,
scores,
// acquire_quest
) in query.iter_mut()
{
let visible = graph.visible;
let offset = graph.offset;
// Prepare graph
let main_graph = &mut graph.graph;
for line in main_graph.lines.iter_mut() {
while line.points.len() >= max_data_points {
line.points.remove(0);
}
}
let mut score_texts = Vec::new();
for (cid, score) in scores.scores.iter().copied() {
let action_name = ai_config
.registered_actions_info
.iter()
.find(|s| s.action_id == cid)
.unwrap()
.action_name
.as_str();
let action_name = action_name
.rsplit_once("::")
.map(|split| split.1)
.unwrap_or(action_name);
let line_idx = main_graph
.lines
.iter()
.enumerate()
.find_map(|(i, l)| (l.name == action_name).then_some(i));
let line_idx = match line_idx {
Some(idx) => idx,
None => {
let line_color = line_colors
.iter()
.find(|c| main_graph.lines.iter().all(|l| l.color != **c))
.cloned()
.unwrap_or_else(|| {
let mut h = bevy::utils::AHasher::default();
// let mut h = std::hash::DefaultHasher::new();
std::hash::Hash::hash(&action_name, &mut h);
let hue = (std::hash::Hasher::finish(&h) % 360) as f32;
let line_color = Hsla::hsl(hue, 1.0, 0.5);
line_color.into()
});
main_graph.lines.push(plot_gizmos::GizmoLine {
name: action_name.to_owned(),
points: Vec::new(),
color: line_color,
});
main_graph.lines.len() - 1
}
};
main_graph.lines[line_idx]
.points
.push(Vec2::new(tick as f32, score));
score_texts.push((
format!("{}: {:.2}", action_name, score),
score,
main_graph.lines[line_idx].color,
));
}
// Highest score first. `total_cmp` gives a total order, including for equal scores.
score_texts.sort_unstable_by(|(_, a, _), (_, b, _)| b.total_cmp(a));
// acquire_quest.write_scores(tick as f32, graph);
if !visible {
continue;
}
// Draw graph
main_graph.bounds = Rect::from_center_size(cam_xy + offset, main_graph.bounds.size());
gizmos.rect_2d(
main_graph.bounds.center() * projection,
0.,
(main_graph.bounds.size() + Vec2::splat(2.0)) * projection,
Color::BLACK.with_alpha(0.2),
);
let (mut min, mut max) = main_graph
.min_max()
.unwrap_or((Vec2::splat(0.), Vec2::splat(1.)));
min.y = 0.;
max.x = min.x + max_data_points as f32;
max.y = 1.;
let projected_bounds = Rect {
min: main_graph.bounds.min * projection,
max: main_graph.bounds.max * projection,
};
main_graph.render_unfiltered(projected_bounds, min, max, &mut gizmos);
let mut y_height = main_graph.bounds.min.y;
// Draw action & scores
{
for (i, (text, _, color)) in score_texts.into_iter().enumerate() {
let font_size = 12.;
let tl = ((cam_xy
+ offset
+ Vec2::new(
0.,
(-main_graph.bounds.size().y * 0.5) - font_size - (i as f32 * font_size),
))
* projection)
.extend(2.);
y_height = tl.y / projection.y - font_size;
cmd.spawn((
Text2dBundle {
text: Text::from_section(
text,
TextStyle {
font_size,
color,
..default()
},
),
transform: Transform::from_translation(tl)
.with_scale(projection.extend(1.)),
..default()
},
AiDbgImmediateShittyText,
));
}
}
let main_graph_bounds = main_graph.bounds;
let action_graphs = &mut graph.action_graphs;
let font_size = 12.;
for (cid, g) in action_graphs.iter_mut() {
y_height -= main_graph_bounds.size().y / 2.0 + font_size;
g.bounds = Rect::from_center_size(
Vec2::new(cam_xy.x + offset.x, y_height),
// camera.translation.xy()
// + offset
// + Vec2::new(0., -((i + 1) as f32) * main_graph_bounds.size().y * 1.5),
main_graph_bounds.size(),
);
y_height -= main_graph_bounds.size().y / 2.0;
gizmos.rect_2d(
g.bounds.center() * projection,
0.,
(g.bounds.size() + Vec2::splat(2.0)) * projection,
Color::BLACK.with_alpha(0.2),
);
let (mut min, mut max) = g.min_max().unwrap_or((Vec2::splat(0.), Vec2::splat(1.)));
min.y = 0.;
max.x = min.x + max_data_points as f32;
max.y = 1.;
let projected_bounds = Rect {
min: g.bounds.min * projection,
max: g.bounds.max * projection,
};
g.render_unfiltered(projected_bounds, min, max, &mut gizmos);
let action_name: &str = ai_config
.registered_actions_info
.iter()
.find(|s| s.scorer_id == *cid)
.unwrap()
.action_name
.as_ref();
let action_name = action_name
.rsplit_once("::")
.map(|split| split.1)
.unwrap_or(action_name);
// println!("{}", action_name);
cmd.spawn((
Text2dBundle {
text: Text::from_section(
action_name,
TextStyle {
font_size,
..default()
},
),
transform: Transform::from_translation(
(g.bounds.center() * projection).extend(2.),
)
.with_scale(projection.extend(1.)),
..default()
},
AiDbgImmediateShittyText,
));
let last_idx = g
.lines
.iter()
.fold(0, |acc, l| l.points.len().saturating_sub(1).max(acc));
let mut parts = g
.lines
.iter()
.filter_map(|l| l.points.get(last_idx).map(|p| (p.y, l.name.as_str())))
.collect::<Vec<_>>();
// Highest score first. `total_cmp` gives a total order, including for equal scores.
parts.sort_unstable_by(|(a, _), (b, _)| b.total_cmp(a));
let x = g.bounds.center().x;
for (score, part) in parts {
y_height -= font_size;
cmd.spawn((
Text2dBundle {
text: Text::from_section(
format!("{part}: {score:.2}"),
TextStyle {
font_size,
..default()
},
),
transform: Transform::from_translation(
(Vec2::new(x, y_height) * projection).extend(2.),
)
.with_scale(projection.extend(1.)),
..default()
},
AiDbgImmediateShittyText,
));
}
}
}
}
#[reflect_trait]
pub(super) trait ScorerDebugIntoGraph {
fn write_scores(&self, x: f32, graph: &mut GizmoLineGraph);
}
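
Both `ra` and `render_utility_scores_debug_graph` trim each line's history with `points.remove(0)`, which shifts every remaining element on each call. A sketch of the same sliding window on a `VecDeque`, assuming the point storage could be changed (`GizmoLine::points` is a `Vec` in the source, and `ScoreHistory` is a hypothetical type):

```rust
use std::collections::VecDeque;

/// Hypothetical fixed-capacity history of (tick, score) samples.
struct ScoreHistory {
    points: VecDeque<(f32, f32)>,
    capacity: usize,
}

impl ScoreHistory {
    fn new(capacity: usize) -> Self {
        // A capacity of zero is bumped to one so `push` always terminates.
        let capacity = capacity.max(1);
        Self {
            points: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Appends a sample, dropping the oldest ones once `capacity` is reached.
    fn push(&mut self, tick: f32, score: f32) {
        while self.points.len() >= self.capacity {
            self.points.pop_front();
        }
        self.points.push_back((tick, score));
    }
}
```

At 100 points per line the difference is unlikely to matter for a debug overlay, so this is mostly worth considering if the history grows.
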
@tbillington
Author

Video of cleaning_bots.rs in action:

2024-07-25_17-52-16.mp4

See ai_debug.rs for an example of using reflection to get a debug UI for scorers.
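
Roughly, each scorer type registers a way to view itself as a `ScorerDebugIntoGraph` trait object, so the debug system can go from a type-erased component to the trait without naming the concrete type. A std-only sketch of that lookup, with `std::any::Any` standing in for Bevy's reflection. `DebugRegistry`, `ScorerDebug`, and `DistanceScorer` are hypothetical names, and this is not how Bevy's type registry is implemented.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Stand-in for the gist's `ScorerDebugIntoGraph` trait.
trait ScorerDebug {
    fn write_scores(&self, x: f32, out: &mut Vec<(String, f32, f32)>);
}

/// Function that tries to view a type-erased value as the trait.
type Caster = fn(&dyn Any) -> Option<&dyn ScorerDebug>;

/// Hypothetical registry keyed by concrete type.
#[derive(Default)]
struct DebugRegistry {
    casters: HashMap<TypeId, Caster>,
}

impl DebugRegistry {
    /// Records how to downcast `T` and re-expose it as `&dyn ScorerDebug`.
    fn register<T: ScorerDebug + 'static>(&mut self) {
        fn cast<T: ScorerDebug + 'static>(value: &dyn Any) -> Option<&dyn ScorerDebug> {
            value.downcast_ref::<T>().map(|v| v as &dyn ScorerDebug)
        }
        self.casters.insert(TypeId::of::<T>(), cast::<T>);
    }

    /// Looks up the caster by the value's runtime type; `None` if unregistered.
    fn get<'a>(&self, value: &'a dyn Any) -> Option<&'a dyn ScorerDebug> {
        let caster = self.casters.get(&value.type_id())?;
        caster(value)
    }
}

/// Example scorer that reports a single named value.
struct DistanceScorer {
    distance_score: f32,
}

impl ScorerDebug for DistanceScorer {
    fn write_scores(&self, x: f32, out: &mut Vec<(String, f32, f32)>) {
        out.push(("tree_dist".to_string(), x, self.distance_score));
    }
}
```

In the gist, `#[reflect_trait]` and `#[reflect(ScorerDebugIntoGraph)]` play the role of `register`, and `get_type_data::<ReflectScorerDebugIntoGraph>` plays the role of `get`.
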

Screenshot 2024-08-05 at 2 09 37 PM
