Last active
January 27, 2024 14:40
-
-
Save BuckyI/ad620eeecbba526bdc987ceb2ce22221 to your computer and use it in GitHub Desktop.
Use multiple sensors (observations) to estimate the unknown parameters. 使用多个环境光传感器计算光源位置、强度 (use several ambient-light sensors to compute the light source's position and intensity).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
from scipy.optimize import curve_fit | |
# Spacing between adjacent sensors (传感器摆放间距), in the same length unit
# as the light-source coordinates: illuminate() places sensor `idx` at idx * X0.
X0 = 1
def illuminate(idx, x, y, intensity):
    """Simulate the readings of the light-sensor array.

    Sensor ``idx`` sits at horizontal position ``idx * X0``. Its reading falls
    off quadratically with the incidence angle ``theta`` of light coming from
    a source at ``(x, y)``::

        reading = intensity * (1 - (2 * theta / pi) ** 2)

    Args:
        idx: sensor index (scalar or array); multiplied by the module-level
            spacing ``X0`` to obtain the sensor's horizontal position.
        x: horizontal coordinate of the light source.
        y: vertical coordinate of the light source.
        intensity: source intensity, i.e. the reading at zero incidence angle.

    Returns:
        The simulated reading(s), broadcast from ``idx``. Readings are NOT
        clipped at 0: clipping (``value[value < 0] = 0``) could make the
        gradient vanish and hamper ``curve_fit``.
    """
    horizontal_offset = idx * X0 - x
    # Angle of the incoming ray measured from the vertical (y) axis.
    incidence = np.arctan2(horizontal_offset, y)
    return intensity * (1 - (2 * incidence / np.pi) ** 2)
def generate_observation(x, y, intensity, noise=0.0, sensor_num=8):
    """Simulate one noisy reading per sensor for a light source at ``(x, y)``.

    Args:
        x: horizontal coordinate of the light source.
        y: vertical coordinate of the light source.
        intensity: intensity of the light source.
        noise: standard deviation of the additive Gaussian noise
            (0 gives noise-free readings).
        sensor_num: number of sensors in the array.

    Returns:
        ``(sensor_idx, readings)``: the sensor indices ``0 .. sensor_num - 1``
        (the x-data expected by ``illuminate`` / ``curve_fit``) and the noisy
        readings.
    """
    # illuminate() already converts an index into a position via idx * X0,
    # so pass plain indices. Pre-scaling by X0 here applied the spacing twice
    # and only gave correct positions because X0 == 1.
    sensor_idx = np.arange(sensor_num)
    sensor_data = illuminate(sensor_idx, x, y, intensity)
    noise_data = noise * np.random.randn(len(sensor_idx))
    return sensor_idx, sensor_data + noise_data
def fit(x, y, plot=True):
    """Fit the light-source model ``illuminate`` to one set of sensor readings.

    Args:
        x: sensor x-data in the form ``illuminate`` expects for ``idx``.
        y: measured sensor readings.
        plot: if true, print the fitted parameters and show the readings
            together with the fitted curve.

    Returns:
        The fitted parameters ``[x, y, intensity]`` of the light source.
    """
    initial_guess = [10.0, 10.0, 1000.0]  # starting point for (x, y, intensity)
    best, _covariance = curve_fit(
        illuminate, x, y, p0=initial_guess, maxfev=5000
    )
    if not plot:
        return best

    print("优化后的参数:", best)
    plt.scatter(x, y, label="Data")
    plt.plot(x, illuminate(x, *best), "r", label="Fit")
    plt.legend()
    plt.show()
    return best
def batch_test():
    """Sweep a simulated light source along a line and measure fit accuracy.

    True source positions lie on ``y = 5 * x + 50`` for x = 10, 9, ..., -9.
    For each position a noisy sensor observation is simulated and fitted
    with ``curve_fit``. Failed or diverged fits are reported and skipped.

    The true positions, recovered positions and sensor positions are plotted.
    The root-mean-square error of the recovered x / y coordinates, taken over
    the successful fits only, is printed.

    Returns:
        ``(rmse_x, rmse_y)``; both are NaN when no fit succeeded.
    """
    light_intensity = 6500
    noise = 0.05
    sensor_num = 7  # number of sensors
    p0 = [10.0, 10.0, 1000.0]  # initial guess for (x, y, intensity)
    # The sensor spacing actually simulated is the module-level X0 read by
    # illuminate(). A local ``X0 = 2 / sensor_num`` (meant to model sensors
    # on a 2 m wide window) would only shadow it: the simulation is
    # unaffected, and the sensors would be drawn at the wrong positions.
    sensor_pos = np.arange(sensor_num) * X0

    # True light-source positions.
    x = np.arange(10, -10, -1)
    y = 5 * x + 50
    plt.scatter(x, y, label="real", marker="x")

    x_true, y_true, x_hat, y_hat = [], [], [], []
    for xi, yi in zip(x, y):
        X, Y = generate_observation(xi, yi, light_intensity, noise, sensor_num)
        try:
            params, _covariance = curve_fit(illuminate, X, Y, p0=p0, maxfev=5000)
        except (RuntimeError, ValueError):
            # RuntimeError: no convergence within maxfev;
            # ValueError: non-finite data or invalid options.
            print("error", xi, yi)
            continue
        position = params[:2]
        # Treat non-finite or runaway estimates (of either sign) as failures.
        if not np.all(np.isfinite(position)) or np.any(np.abs(position) > 10000):
            print("error", xi, yi)
            continue
        # Keep each estimate paired with the true position it came from.
        x_true.append(xi)
        y_true.append(yi)
        x_hat.append(params[0])
        y_hat.append(params[1])

    plt.scatter(x_hat, y_hat, label="predicted", marker="x")
    plt.scatter(sensor_pos, [0] * sensor_num, label="sensor", marker="v")
    plt.legend()

    # Compare only successful estimates with their own true positions;
    # comparing against the full x / y arrays fails once any fit is skipped.
    if x_hat:
        rmse_x = np.sqrt(np.mean((np.array(x_hat) - np.array(x_true)) ** 2))
        rmse_y = np.sqrt(np.mean((np.array(y_hat) - np.array(y_true)) ** 2))
    else:
        rmse_x = rmse_y = float("nan")
    print("均方根误差:", rmse_x, rmse_y)
    plt.show()
    return rmse_x, rmse_y
if __name__ == "__main__":
    # Single-observation demo (for scale: a high-beam headlight gives
    # roughly 6500 lux):
    #     X, Y = generate_observation(40, 50, 1000, 10, sensor_num=8)
    #     fit(X, Y)
    batch_test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment