Skip to content

Instantly share code, notes, and snippets.

@BuckyI
Last active January 27, 2024 14:40
Show Gist options
  • Save BuckyI/ad620eeecbba526bdc987ceb2ce22221 to your computer and use it in GitHub Desktop.
Save BuckyI/ad620eeecbba526bdc987ceb2ce22221 to your computer and use it in GitHub Desktop.
Use multiple sensors (observations) to estimate the unknown parameters. 使用多个环境光传感器计算光源位置、强度
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
X0 = 1  # Sensor spacing: distance between adjacent sensors in the array.


def illuminate(idx, x, y, intensity):
    """Simulate the output of a linear array of light sensors.

    Sensor ``i`` is located at ``i * X0`` along the x axis and the light
    source at ``(x, y)``. Each reading falls off quadratically with the
    incidence angle ``theta``::

        intensity * (1 - (2 * theta / pi) ** 2)

    Args:
        idx: Sensor index, or an array-like of sensor indices (lists and
            tuples are accepted as well as NumPy arrays).
        x: x coordinate of the light source.
        y: y coordinate of the light source.
        intensity: Reading produced at normal incidence (``theta == 0``).

    Returns:
        The modelled reading(s), one per entry of ``idx``.
    """
    # Convert up front: a plain list multiplied by an integer X0 would be
    # *repeated* rather than scaled, silently giving wrong sensor positions.
    offsets = np.asarray(idx) * X0 - x
    theta = np.arctan2(offsets, y)  # incidence angle
    # NOTE: readings are deliberately not clamped at zero; clamping might make
    # the gradient vanish when this function is used as a fitting model.
    return intensity * (1 - (2 * theta / np.pi) ** 2)
def generate_observation(x, y, intensity, noise=0.0, sensor_num=8):
    """Generate simulated readings of the sensor array for one light source.

    Args:
        x: x coordinate of the light source.
        y: y coordinate of the light source.
        intensity: Intensity of the light source.
        noise: Noise strength; each reading gets an independent standard
            normal sample scaled by this factor.
        sensor_num: Number of sensors in the array.

    Returns:
        A tuple ``(sensor_idx, readings)``: the sensor indices
        ``0 .. sensor_num - 1`` and the noisy reading of each sensor.
    """
    # Pass raw indices: illuminate() multiplies by X0 itself, so scaling the
    # indices here as well would apply the sensor spacing twice.
    sensor_idx = np.arange(sensor_num)
    readings = illuminate(sensor_idx, x, y, intensity)
    noise_data = noise * np.random.randn(sensor_num)
    return sensor_idx, readings + noise_data
def fit(x, y, plot=True, *, p0=(10.0, 10.0, 1000.0), maxfev=5000):
    """Fit the ``illuminate`` model to a set of sensor readings.

    Args:
        x: Sensor indices, passed to ``illuminate`` as its first argument.
        y: Observed reading of each sensor.
        plot: When true, print the fitted parameters and show a plot of the
            data together with the fitted model.
        p0: Initial guess for ``(x, y, intensity)`` of the light source.
        maxfev: Maximum number of function evaluations allowed for the fit.

    Returns:
        The fitted parameters ``[x, y, intensity]``.

    Raises:
        RuntimeError: Propagated from ``curve_fit`` when the optimisation
            does not converge.
    """
    # The covariance estimate returned by curve_fit is not used here.
    params, _ = curve_fit(illuminate, x, y, p0=p0, maxfev=maxfev)

    if plot:
        print("优化后的参数:", params)
        plt.scatter(x, y, label="Data")
        plt.plot(x, illuminate(x, *params), "r", label="Fit")
        plt.legend()
        plt.show()

    return params
def batch_test():
    """Fit a series of light-source positions and report the position error.

    For each true source position on the line ``y = 5 * x + 50`` a noisy
    observation is generated and fitted with ``curve_fit``. True positions,
    accepted estimates and sensor positions are plotted, and the RMSE of the
    accepted estimates is printed.
    """
    light_intensity = 6500
    noise = 0.05
    sensor_num = 7  # number of sensors

    # NOTE: the sensor spacing is the module-level X0 read by illuminate();
    # assigning a local X0 here would not reach the simulation. To model
    # sensors spread across a 2 m wide car window, set the module constant
    # to 2 / sensor_num instead.

    # True light-source positions.
    x = np.arange(10, -10, -1)
    y = 5 * x + 50
    plt.scatter(x, y, label="real", marker="x")  # positions under test

    # Ground truth is recorded only for fits that are accepted, so that the
    # RMSE below always compares arrays of equal length.
    x_true, y_true = [], []
    x_hat, y_hat = [], []
    sensor_pos = np.array([])

    for src_x, src_y in zip(x, y):
        # Generate the sensor data for this source position.
        X, Y = generate_observation(src_x, src_y, light_intensity, noise, sensor_num)
        # Positions the model actually uses: illuminate() scales its first
        # argument by X0.
        sensor_pos = X * X0

        try:
            params, _ = curve_fit(
                illuminate,
                X,
                Y,
                p0=[10.0, 10.0, 1000.0],  # initial parameter guess
            )
        except (RuntimeError, ValueError):
            # curve_fit did not converge or rejected its input.
            print("error", src_x, src_y)
            continue

        # Treat an estimate that ran away to a huge coordinate as a failure.
        if params[0] > 10000 or params[1] > 10000:
            print("error", src_x, src_y)
            continue

        x_true.append(src_x)
        y_true.append(src_y)
        x_hat.append(params[0])
        y_hat.append(params[1])

    plt.scatter(x_hat, y_hat, label="predicted", marker="x")
    plt.scatter(sensor_pos, [0] * len(sensor_pos), label="sensor", marker="v")
    plt.legend()

    if x_hat:
        rmse_x = np.sqrt(np.mean((np.array(x_hat) - np.array(x_true)) ** 2))
        rmse_y = np.sqrt(np.mean((np.array(y_hat) - np.array(y_true)) ** 2))
    else:
        # No fit was accepted, so there is nothing to average.
        rmse_x = rmse_y = float("nan")
    print("均方根误差:", rmse_x, rmse_y)
    plt.show()
if __name__ == "__main__":
    # Single-fit demo, currently disabled (a high-beam headlight gives an
    # illuminance of about 6500 lux):
    #   X, Y = generate_observation(40, 50, 1000, 10, sensor_num=8)
    #   fit(X, Y)
    batch_test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment