Interaction and Estimation
# agent.py
import numpy as np
class Agent:
    longitudinal_buffer = 2.  # safety margin around the conflict point [m]
    a_pass = 0.1

    def __init__(self, agent_id, state, yield_intention=0.5, cooperative=True):
        self.state = state.astype(float)  # [position along path, speed]
        self.yield_intention = yield_intention
        self.other_yield_intention = 0.5  # belief that the other agent yields
        self.other_last_action = None
        self.last_action = 0.
        self.id = agent_id
        self.cooperative = cooperative
    def propagate(self, action, time_duration=0.1):
        # Limit jerk to 2 m/s^3, add actuation noise, then integrate forward.
        jerk = (action - self.last_action) / time_duration
        if abs(jerk) > 2:
            jerk = jerk / abs(jerk) * 2
            action = self.last_action + jerk * time_duration
        action += np.random.randn() * 0.05
        ds = np.array([max(self.state[1] * time_duration + 0.5 * action * time_duration**2, 0),
                       action * time_duration])
        self.state += ds
        self.state[1] = min(max(0, self.state[1]), 11.1)  # clamp speed to [0, 11.1] m/s
        self.last_action = action
    @staticmethod
    def contingency_policy(ego, other):
        # Deceleration for ego to yield: let the other agent clear the
        # conflict point (s = 0) first. Returns a non-positive acceleration.
        s0, v0 = ego.state[0], ego.state[1]
        s1, v1 = other.state[0], other.state[1]
        s0 += Agent.longitudinal_buffer
        s1 -= Agent.longitudinal_buffer
        if s0 > 0:
            # Ego has already passed the conflict point.
            return 0.
        elif s0 < 0 and s1 > 0 or s0 > 0 and s1 > s0:
            # The other agent has already cleared the conflict point.
            if v0 < v1:
                return 0.
            else:
                # Brake to match the leader's speed over the remaining gap.
                b = max(1e-5, (v1 - v0)**2 / (2 * (other.state[0] - ego.state[0])))
                return -b
        # Time until agent 1 arrives at the conflict point.
        t1 = max(-s1 / max(v1, 1e-3), 1e-3)
        # Stay behind the conflict point until then: s0 + v0 t - 0.5 b t^2 <= 0
        #   => b >= (s0 + v0 t) / (0.5 t^2)
        # and end no faster than the other agent: v0 - b t <= v1
        #   => b >= (v0 - v1) / t
        b1 = max((s0 + v0 * t1) / (0.5 * t1**2), 0.)
        b2 = max((v0 - v1) / t1, 0.)
        # Cap at the deceleration that stops ego before the other agent.
        b3 = max(0., v0**2 / (2 * max(s1 - s0, 1e-3)))
        b = min(max(b1, b2), b3)
        tc = max((v0 - v1) / max(b, 1e-5), 0.)  # time to match speeds (unused)
        return np.clip(-b, a_min=-10, a_max=0)
    @staticmethod
    def aggressive_policy(ego, other):
        # Acceleration for ego to pass: clear the conflict point (s = 0)
        # before the other agent arrives. Returns a non-negative acceleration.
        s0, v0 = ego.state[0], ego.state[1]
        s1, v1 = other.state[0], other.state[1]
        s0 -= Agent.longitudinal_buffer
        s1 += Agent.longitudinal_buffer
        if s1 > 0 and s1 > s0:
            # The other agent has already passed and is ahead: passing is
            # effectively infeasible, so return a prohibitively large value.
            return 1e3
        elif s1 < 0 and s0 > 0 or s1 > 0 and s0 > s1:
            # Ego has already cleared the conflict point.
            return 0
        # Time until agent 1 arrives at the conflict point.
        t1 = max(-s1 / max(v1, 1e-3), 1e-3)
        # Clear the conflict point by then: s0 + v0 t + 0.5 a t^2 >= 0
        #   => a >= (-s0 - v0 t) / (0.5 t^2)
        # and end no slower than the other agent: v0 + a t >= v1
        #   => a >= (v1 - v0) / t
        a1 = (-s0 - v0 * t1) / (0.5 * t1**2)
        a2 = (v1 - v0) / t1
        a = max(a1, a2)
        ta = max((v1 - v0) / max(a, 1e-5), 0)  # time to match speeds (unused)
        return np.clip(a, a_min=0, a_max=10)
    def observe(self, other_action):
        self.other_last_action = other_action
    def update_intention(self, other_agent):
        # Candidate actions for both agents (computed for reference; the
        # observation model below uses only the last observed action).
        aggr_action_0 = Agent.aggressive_policy(self, other_agent)
        cont_action_0 = Agent.contingency_policy(self, other_agent)
        aggr_action_1 = Agent.aggressive_policy(other_agent, self)
        cont_action_1 = Agent.contingency_policy(other_agent, self)
        # Bayesian filter for intention estimation, x in {0, 1} meaning yielding:
        # Prediction: p(x_k | y_{1:k-1}) =
        #     sum_{x_{k-1}} p(x_k | x_{k-1}) p(x_{k-1} | y_{1:k-1})
        # Update: p(x_k | y_{1:k}) = (1 / Z_k) p(y_k | x_k) p(x_k | y_{1:k-1})
        # Observation model: the harder the other agent brakes, the more
        # likely it is yielding (actions live in [-5, 2]).
        obs_prob = np.clip((2 + 1e-3 - self.other_last_action) /
                           (2 + 1e-3 + 5), 0, 1)
        # Normalizer Z_k, then the posterior yield probability.
        z = obs_prob * self.other_yield_intention + \
            (1 - obs_prob) * (1 - self.other_yield_intention)
        self.other_yield_intention = obs_prob * self.other_yield_intention / z
        self.yield_intention = 1 - self.other_yield_intention
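    # Worked update step with illustrative numbers (not from the gist): with
    # a prior other_yield_intention of 0.5 and an observed action of -1.0 m/s^2,
    #   obs_prob  = (2.001 - (-1.0)) / 7.001 ~= 0.43
    #   z         = 0.43 * 0.5 + 0.57 * 0.5  = 0.50
    #   posterior = 0.43 * 0.5 / 0.50        ~= 0.43
    # A hard brake of -5.0 m/s^2 instead gives obs_prob = 1.0 and drives the
    # posterior yield probability all the way to 1.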
    def policy(self, other_agent):
        if self.state[0] > 0:
            # Past the conflict point: plain car following.
            a = self.idm_policy(other_agent)
            return np.clip(a, a_min=-5, a_max=2)
        c0 = Agent.contingency_policy(self, other_agent)
        a0 = Agent.aggressive_policy(self, other_agent)
        c1 = Agent.contingency_policy(other_agent, self)
        a1 = Agent.aggressive_policy(other_agent, self)
        p0 = c0 if abs(c0) < abs(a0) else a0       # ego's cheaper action
        coop_p0 = a0 if abs(c0) < abs(a0) else c0  # ego's costlier action
        if not self.cooperative:
            return p0
        p1 = c1 if abs(c1) < abs(a1) else a1       # the other agent's cheaper action
        # Joint cost of each passing order: each agent pays the effort of its
        # role (ego yields while the other passes, or vice versa).
        cost_ego_yield = abs(c0) + abs(a1)
        cost_ego_pass = abs(a0) + abs(c1)
        if min(cost_ego_pass, cost_ego_yield) < 4:
            # Both orders are cheap: no real conflict, follow IDM.
            a = self.idm_policy(other_agent)
        elif cost_ego_pass > cost_ego_yield + 1:
            a = c0 * 1.1  # yielding is clearly cheaper jointly
        elif cost_ego_yield > cost_ego_pass + 1:
            a = a0 * 1.1  # passing is clearly cheaper jointly
        else:
            a = self.idm_policy(other_agent)
        # elif abs(p0) + 2 < abs(p1):
        #     a = p0
        # elif abs(p0) > abs(p1) + 2:
        #     a = coop_p0
        return np.clip(a, a_min=-5, a_max=2)
    def idm_policy(self, other_agent):
        # Intelligent Driver Model with standard parameters.
        s0 = 1.0      # minimum gap [m]
        v0 = 12.      # desired speed [m/s]
        T = 1.5       # desired time headway [s]
        a_max = 0.73  # maximum acceleration [m/s^2]
        b_max = 1.67  # comfortable deceleration [m/s^2]
        delta = 4
        v = self.state[1]
        if other_agent is None or self.state[0] > other_agent.state[0]:
            # No leader ahead: free driving.
            dv = 0
            s = 1000
        else:
            dv = self.state[1] - other_agent.state[1]
            s = max(other_agent.state[0] - self.state[0] - Agent.longitudinal_buffer * 2, 1e-3)
        s_star = s0 + v * T + v * dv / (2 * (np.sqrt(a_max * b_max)))
        a = a_max * (1 - (v / v0)**delta - (s_star / s)**2)
        return a
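A minimal usage sketch of the two policy primitives on a concrete state (the file above is saved as agent.py, which the driver script below imports; the numbers here are illustrative, not from the gist):

import numpy as np
from agent import Agent

# Two agents approaching a shared conflict point at s = 0.
ego = Agent(0, np.array([-30.0, 8.0]))     # 30 m from the conflict point, 8 m/s
other = Agent(1, np.array([-20.0, 10.0]))  # 20 m from the conflict point, 10 m/s

print(Agent.contingency_policy(ego, other))  # braking (<= 0) for ego to yield
print(Agent.aggressive_policy(ego, other))   # acceleration (>= 0) for ego to pass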
# Driver script, saved alongside agent.py.
import numpy as np
from matplotlib import pyplot as plt
from copy import copy
from agent import Agent
import tqdm
from multiprocessing import Pool
import itertools
def experiment(states, plot=False, T=10.):
    # Roll out one interaction: a cooperative ego (agent 0) against a
    # non-cooperative agent 1. Returns True if both pass the conflict point
    # without coming within 2 m of each other.
    s0, s1, v0, v1 = states
    a0 = Agent(0, np.array([s0, v0]))
    a1 = Agent(1, np.array([s1, v1]), cooperative=False)
    state_list = [[], []]
    time_stamps = np.arange(0, T, 0.1)
    res = True
    for t in time_stamps:
        actions = [a0.policy(a1), a1.policy(a0)]
        if plot:
            print(f"t={t}, state:{a0.state}, {a1.state}\tactions:{actions}")
        a0.propagate(actions[0])
        a1.propagate(actions[1])
        a0.observe(actions[1])
        a1.observe(actions[0])
        # a0.update_intention(a1)
        # a1.update_intention(a0)
        state_list[0].append(copy(a0.state))
        state_list[1].append(copy(a1.state))
        if a0.state[0] > 0 and a1.state[0] > 0 and abs(a0.state[0] - a1.state[0]) < 2.0:
            # Both past the conflict point but too close: count as a failure.
            print(s0, s1, v0, v1)
            res = False
            break
    state_list_arr = np.array(state_list)
    if res and plot:
        # Positions: dashed before the conflict point, solid after.
        plt.plot(time_stamps[state_list_arr[0, :, 0] < 0], state_list_arr[0,
                 state_list_arr[0, :, 0] < 0, 0], linewidth=5, linestyle='--', c='skyblue')
        plt.plot(time_stamps[state_list_arr[0, :, 0] > 0], state_list_arr[0,
                 state_list_arr[0, :, 0] > 0, 0], linewidth=5, c='skyblue')
        plt.plot(time_stamps[state_list_arr[1, :, 0] < 0], state_list_arr[1,
                 state_list_arr[1, :, 0] < 0, 0], linestyle='--', c='orange')
        plt.plot(time_stamps[state_list_arr[1, :, 0] > 0],
                 state_list_arr[1, state_list_arr[1, :, 0] > 0, 0], c='orange')
        plt.plot(time_stamps, np.zeros(time_stamps.shape), c='green')
        plt.show()
        # Speed profiles.
        plt.plot(time_stamps, state_list_arr[0, :, 1], c='skyblue')
        plt.plot(time_stamps, state_list_arr[1, :, 1], c='orange')
        plt.show()
    elif plot:
        tlen = state_list_arr.shape[1]
        plt.plot(time_stamps[:tlen], state_list_arr[0, :, 0], c='skyblue')
        plt.plot(time_stamps[:tlen], state_list_arr[1, :, 0], c='orange')
        plt.plot(time_stamps, np.zeros(time_stamps.shape), c='green')
        plt.show()
    return res
def run_all():
    # Exhaustive serial sweep over initial conditions.
    succ, total = 0, 0
    fail_list = []
    for s0 in tqdm.tqdm(np.arange(-50, -10, 2.)):
        for s1 in np.arange(-50, -10, 2.):
            for v0 in np.arange(0, 11.1, 1):
                for v1 in np.arange(0, 11.1, 1):
                    res = experiment((s0, s1, v0, v1))
                    total += 1
                    if res:
                        succ += 1
                    else:
                        print(s0, s1, v0, v1)
                        fail_list.append([s0, s1, v0, v1])
        print(f"Successful rate: {succ / total}")
    print(f"Final successful rate: {succ / total}")
    print(fail_list)
def run_parallel():
    # Parallel sweep over initial conditions using a process pool.
    pool = Pool(8)
    succ, total = 0, 0
    for result in tqdm.tqdm(pool.imap(
            experiment, itertools.product(np.arange(-50, -20, 1.), np.arange(-50, -20, 1.),
                                          np.arange(0, 11.1, 1), np.arange(0, 11.1, 1)))):
        total += 1
        if result:
            succ += 1
        else:
            print(f"Successful rate so far: {succ / total}")
    print(f"Final successful rate: {succ / total}")
if __name__ == '__main__':
    # print(experiment([-30.0, -20.0, 5, 3.0], True, T=10))
    # print(experiment([-45.0, -47.0, 10.0, 11.0], True))
    print(experiment([-21.0, -23.0, 10.0, 11.0], True))
    # run_parallel()
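One practical tweak (a sketch appended to the script above, not part of the gist): pool.imap consumes itertools.product lazily, so tqdm cannot infer the sweep size; computing the grid size up front and passing total= restores the progress bar's ETA.

def run_parallel_with_progress():
    # Same sweep as run_parallel, with an explicit total for tqdm.
    grids = [np.arange(-50, -20, 1.), np.arange(-50, -20, 1.),
             np.arange(0, 11.1, 1), np.arange(0, 11.1, 1)]
    total = int(np.prod([len(g) for g in grids]))  # known sweep size
    with Pool(8) as pool:
        results = list(tqdm.tqdm(pool.imap(experiment, itertools.product(*grids)),
                                 total=total))
    print(f"Successful rate: {sum(results) / total}")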
This is a very simple demo of longitudinal, spatiotemporal interaction between agents. In these scripts, the agents estimate the joint cost of each passing order and use it to find a reasonable gap to merge into traffic. The most important concept here is the JOINT COST, which is what makes everything work as expected.
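To make the joint-cost comparison concrete, here is a minimal sketch of the decision policy() performs, with illustrative numbers that do not come from the gist:

# Effort each agent would need under each passing order (illustrative):
c0, a0 = -1.5, 2.5   # ego: brake to yield vs. accelerate to pass
c1, a1 = -0.5, 3.0   # other agent: brake to yield vs. accelerate to pass

cost_ego_yield = abs(c0) + abs(a1)  # ego yields, other passes: 4.5
cost_ego_pass = abs(a0) + abs(c1)   # ego passes, other yields: 3.0

# Commit to the clearly cheaper joint plan; the 1 m/s^2 margin avoids
# oscillating between the two orders when the costs are close.
if cost_ego_pass + 1 < cost_ego_yield:
    print("ego passes first")       # printed for these numbers
elif cost_ego_yield + 1 < cost_ego_pass:
    print("ego yields")
else:
    print("fall back to IDM car following")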