Created
March 14, 2025 11:50
-
-
Save slinton5/f1fb6cfed2ad57c9a673ff287b5ebc2e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from torch.utils.data import Dataset | |
from tqdm import tqdm | |
import os | |
import imageio | |
from mpl_toolkits.mplot3d import Axes3D | |
# Compute device: run on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# 物理参数配置(已优化) | |
class Config: | |
# 三维场景参数 | |
voxel_size = 64 # 降低分辨率加速计算 | |
scene_size = 1.0 # 缩小场景尺寸 | |
relay_z = 0.0 # 中介面z坐标 | |
# 激光参数 | |
pulse_fwhm = 100e-12 # 增大脉冲宽度 | |
wavelength = 532e-9 | |
beam_waist = 0.1 # 增大初始束腰 | |
divergence = 0.005 # 调整发散角 | |
# SPAD参数 | |
gate_width = 10e-9 # 增大门宽 | |
gate_steps = 50 # 减少时间门数量 | |
gate_step = 200e-12 # 增大时间步长 | |
# 材料参数(已优化) | |
materials = { | |
'mirror': {'F0': 0.98, 'roughness': 0.05, 'k': 0.02}, | |
'rough': {'F0': 0.4, 'roughness': 0.8, 'k': 0.1} | |
} | |
def gaussian_beam_intensity(X, Y, Z, laser_pos): | |
"""改进的激光扩束模型""" | |
# 计算相对位置 | |
dx = X - laser_pos[0] | |
dy = Y - laser_pos[1] | |
dz = Z - laser_pos[2] | |
# 计算传播距离(添加最小量防止除零) | |
r = torch.sqrt(dx ** 2 + dy ** 2 + dz ** 2 + 1e-8) | |
# 计算束腰半径 | |
w_z = Config.beam_waist * torch.sqrt(1 + (r * Config.divergence) ** 2) | |
# 计算横向强度 | |
radial = (dx ** 2 + dy ** 2) / (w_z ** 2 + 1e-8) | |
intensity = (Config.beam_waist / w_z) ** 2 * torch.exp(-2 * radial) | |
return intensity.clamp(0, 1) | |
def cook_torrance_brdf(wi, wo, normal, material_params):
    """Evaluate a Cook-Torrance microfacet BRDF (Schlick Fresnel, Beckmann D).

    Args:
        wi: incident direction(s), expected to be unit vectors along the last
            dimension. Only directions on the same side as ``normal`` contribute,
            because every dot product is clamped to [0, 1].
        wo: outgoing direction(s), same convention as ``wi``.
        normal: surface normal(s), broadcast against ``wi``/``wo``.
        material_params: mapping with ``'F0'`` (reflectance at normal incidence
            in Schlick's approximation) and ``'roughness'`` (the Beckmann
            slope parameter m; alpha = m**2).

    Returns:
        Tensor of BRDF values with the last (vector) dimension reduced.
        NaN/inf values are replaced by 0, and the result is capped at 1.
    """
    F0 = material_params['F0']
    alpha = material_params['roughness'] ** 2 + 1e-8  # m^2 of the Beckmann lobe
    eps = 1e-8

    # Half vector between the two directions; eps guards the wi == -wo case.
    h = (wi + wo) / (torch.norm(wi + wo, dim=-1, keepdim=True) + eps)

    # Cosines, clamped so back-facing configurations contribute nothing.
    ndoth = torch.sum(normal * h, dim=-1).clamp(0, 1)
    ndotwi = torch.sum(normal * wi, dim=-1).clamp(0, 1)
    ndotwo = torch.sum(normal * wo, dim=-1).clamp(0, 1)
    # wo.h (equal to wi.h for unit vectors): the angle at which Schlick's
    # Fresnel term and the Cook-Torrance masking term are evaluated.
    vdoth = torch.sum(wo * h, dim=-1).clamp(0, 1)

    # Fresnel term (Schlick): F = F0 + (1 - F0) * (1 - v.h)^5.
    F = F0 + (1 - F0) * (1 - vdoth) ** 5

    # Geometric attenuation (Cook-Torrance):
    #   G = min(1, 2(n.h)(n.wo)/(wo.h), 2(n.h)(n.wi)/(wo.h)).
    # Dividing by n.wo / n.wi instead of wo.h (as before) collapsed G to
    # min(1, 2 n.h), ignoring masking entirely.
    G1 = (2 * ndoth * ndotwo) / (vdoth + eps)
    G2 = (2 * ndoth * ndotwi) / (vdoth + eps)
    G = torch.minimum(G1, G2).clamp(max=1.0)

    # Beckmann distribution: exp(-tan^2(theta_h) / m^2) / (pi m^2 cos^4(theta_h)).
    # tan^2 is computed exactly from cos instead of approximating it by theta^2.
    cos_h = ndoth.clamp(min=1e-6)
    cos2 = cos_h ** 2
    tan2 = (1 - cos2) / cos2
    D = torch.exp(-tan2 / alpha) / (np.pi * alpha * cos2 ** 2 + eps)

    # Final BRDF: F G D / (4 (n.wi)(n.wo)).
    brdf = (F * G * D) / (4 * ndotwi * ndotwo + eps)

    # Scrub numerical blow-ups at grazing angles and cap the value at 1.
    return torch.nan_to_num(brdf, nan=0.0, posinf=0.0, neginf=0.0).clamp(max=1.0)
def generate_voxel_scene():
    """Build the test scene: an empty cubic grid with one cube at its centre.

    The grid has ``Config.voxel_size`` cells per axis and lives on ``device``.
    The central cube spans about a quarter of the grid edge and its voxels are
    set to 0.8. The number of non-zero voxels is printed as a debug aid.

    Returns:
        Float tensor of shape (voxel_size, voxel_size, voxel_size).
    """
    n = Config.voxel_size
    centre = n // 2
    half_edge = (n // 4) // 2
    lo, hi = centre - half_edge, centre + half_edge

    scene = torch.zeros((n, n, n), device=device)
    scene[lo:hi, lo:hi, lo:hi] = 0.8

    # Debug output: how many voxels are occupied.
    active = (scene > 0).sum().item()
    print(f"Scene contains {active} active voxels")
    return scene
def simulate_nlos(scene, material):
    """Simulate a confocal NLOS measurement of ``scene`` for one relay material.

    Builds a coordinate grid over [-0.5, 0.5]^3 with ``Config.voxel_size``
    points per axis. Each voxel of ``scene`` is weighted by the laser beam
    profile, a Cook-Torrance BRDF for ``material``, and a Gaussian temporal
    response. The weighted voxels are projected along the last axis and
    accumulated over ``Config.gate_steps`` time gates. Debug ranges are
    printed along the way.

    Args:
        scene: voxel tensor on ``device`` that must broadcast against the
            (voxel_size,)*3 grid.
        material: key into ``Config.materials`` (e.g. 'mirror' or 'rough').
            An unknown key raises KeyError.

    Returns:
        numpy uint16 array of shape (voxel_size, voxel_size), scaled so its
        maximum is 65535. It is all zeros if no signal was accumulated.

    NOTE(review): ``brdf`` below is a single scalar, because the directions
    and normal are fixed. After max-normalisation it cancels, so 'mirror' and
    'rough' produce identical images. Per-voxel directions would be needed
    for the material to matter.
    """
    n = Config.voxel_size

    # Coordinate grid.
    x = torch.linspace(-0.5, 0.5, n, device=device)
    X, Y, Z = torch.meshgrid(x, x, x, indexing='ij')

    # Laser and SPAD share one position (confocal).
    laser_pos = torch.tensor([0, 0, -0.1], device=device)

    # Beam intensity at every voxel.
    beam = gaussian_beam_intensity(X, Y, Z, laser_pos)
    print(f"Beam intensity: {beam.min().item():.2e} - {beam.max().item():.2e}")

    # Three-bounce path length.
    d1 = torch.sqrt((X - laser_pos[0]) ** 2 +
                    (Y - laser_pos[1]) ** 2 +
                    (Z - laser_pos[2]) ** 2)
    d2 = 2 * torch.sqrt(X ** 2 + Y ** 2 + Z ** 2)  # relay-surface reflection
    d3 = d1  # return path
    total_time = (d1 + d2 + d3) / 3e8  # path length / speed of light

    # BRDF. wi/wo point from the surface toward the laser/SPAD, which sits at
    # z = -0.1, below the z = 0 relay plane, so they are -z. The BRDF clamps
    # its dot products to [0, 1], so the lit face's normal must also be -z.
    # With the previous +z normal every cosine clamped to 0, the BRDF was
    # exactly 0, and every measurement came out empty.
    normal = torch.tensor([0, 0, -1], device=device, dtype=torch.float32)
    wi = wo = torch.tensor([0, 0, -1], device=device, dtype=torch.float32)
    brdf = cook_torrance_brdf(wi, wo, normal, Config.materials[material])
    print(f"BRDF values: {brdf.min().item():.2e} - {brdf.max().item():.2e}")

    # Time-gate integration.
    measurement = torch.zeros((n, n), device=device)
    # Gate times 0, step, ..., (gate_steps - 1) * step, built from integer
    # indices. A float-step arange can round to an extra gate.
    time_points = torch.arange(Config.gate_steps, device=device) * Config.gate_step
    # Gaussian temporal response: sigma = FWHM / (2 sqrt(2 ln 2)) ≈ FWHM / 2.3548.
    two_sigma_sq = 2 * (Config.pulse_fwhm / 2.3548) ** 2
    # Time-independent weight, hoisted out of the loop.
    static_weight = scene * beam * brdf
    for t in tqdm(time_points, desc=f"Simulating {material}"):
        time_window = torch.exp(-(total_time - t) ** 2 / two_sigma_sq)
        signal = static_weight * time_window
        measurement += signal.sum(dim=2)

    # Debug output.
    print(f"Raw measurement range: {measurement.min().item():.2e} - {measurement.max().item():.2e}")

    # Normalise to the full uint16 range, guarding the empty case.
    peak = measurement.max()
    if peak > 0:
        measurement = (measurement / peak * 65535).clamp(0, 65535)
    else:
        print("Warning: Empty measurement!")
        measurement = torch.zeros_like(measurement)
    return measurement.cpu().numpy().astype(np.uint16)
class NLOSDataset(Dataset):
    """Dataset that synthesises paired (rough, mirror) NLOS images on demand.

    Each access builds the test scene, simulates it for both relay materials,
    writes the two images as PNG files, and returns them.

    NOTE(review): ``generate_voxel_scene()`` takes no arguments and, as far
    as this file shows, builds the same cube every time. Every index
    therefore yields the same pair. Add randomisation there if distinct
    samples are intended.
    """

    def __init__(self, num_samples=100, output_dir="dataset"):
        """
        Args:
            num_samples: number of items reported by ``__len__``.
            output_dir: directory the per-sample PNGs are written to. It is
                created on first access if missing. The default matches the
                previous hard-coded location.
        """
        self.num_samples = num_samples
        self.output_dir = output_dir

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Simulate sample ``idx``, save it under ``output_dir``, and return ``(rough, mirror)``.

        Both arrays are the uint16 images produced by ``simulate_nlos``.
        """
        scene = generate_voxel_scene()

        # Simulate both relay materials.
        mirror = simulate_nlos(scene, 'mirror')
        rough = simulate_nlos(scene, 'rough')

        # Save the sample.
        os.makedirs(self.output_dir, exist_ok=True)
        imageio.imwrite(os.path.join(self.output_dir, f"{idx}_mirror.png"), mirror)
        imageio.imwrite(os.path.join(self.output_dir, f"{idx}_rough.png"), rough)
        return rough, mirror
if __name__ == "__main__":
    # Build one test scene and save a projection of it along the last axis.
    test_scene = generate_voxel_scene()
    projection = test_scene.sum(dim=2).cpu().numpy()
    plt.figure()
    plt.imshow(projection, cmap='gray')
    plt.title("Test Scene Projection")
    plt.savefig("scene_visualization.png")

    # Simulate the same scene with both relay materials.
    print("\nGenerating test sample...")
    mirror_img = simulate_nlos(test_scene, 'mirror')
    rough_img = simulate_nlos(test_scene, 'rough')

    # Save the two results side by side.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((mirror_img, "Mirror Relay"), (rough_img, "Rough Relay"))
    for panel_ax, (panel_img, panel_title) in zip(ax, panels):
        panel_ax.imshow(panel_img, cmap='gray')
        panel_ax.set_title(panel_title)
    plt.savefig("result_comparison.png")

    # Tell the user which files were written.
    for line in ("\n验证输出文件:",
                 "1. 场景投影图: scene_visualization.png",
                 "2. 结果对比图: result_comparison.png"):
        print(line)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment