基于PointNet网络

编辑中

表达旋转:四元数

原理简析

详细可以看SLAM14讲关于四元数的部分,如果看不懂,可以看相关资料中的虚数视频。 四元数在点云中是一个很好的表示旋转的方式。 一个三维空间点用四元数表示如下

P=[0,x,y,z]=[0,v]P = [0, x, y, z] = [0, v]

三维空间的点是个纯虚四元数(实部为0),三个虚部与空间中的三个轴对应。

怎么用四元数表示一个旋转呢?如下所示

q=[cosθ2nsinθ2]q = \left [ \cos \frac{\theta }{2}, n\sin \frac{\theta }{2} \right ]

左边的cos是实部,也就是说,实部存储的是旋转角度 右边的sin是虚部,n表示的是旋转轴,旋转轴有三个维度,所以虚部有三个数,每个数都要分别乘以一个θ/2。也就是说,虚部存放的是旋转轴,表示物体绕着哪个轴进行旋转

那么,一个点绕着一个轴进行旋转,旋转后的点怎么计算出来呢?

p=qpq1p' = qpq^{-1}

其中,p'是结果,通过原点p左乘旋转四元数,右乘旋转四元数的转置得到,结果就是新的点p的四元数数值

一个单位四元数(即模长为1的四元数)可以表示三维空间中的任意旋转,

用Numpy实现四元数运算

最终输出的结果为

第一行,这是一个四元数,表示绕x轴旋转90度。这个四元数的实部为cos(旋转角/2),虚部的第一个元素为1*sin(旋转角/2),其余元素为0,符合预期,虚部的三个元素表示旋转轴(1, 0, 0)的方向乘以sin(旋转角度/2)。 第二行,表示结果的四元数表示,第一维度为0,也就是实部为0,表示是一个点。三个虚部表示点的坐标,其中第二个虚部非常小,可以看作是0,所以得到的坐标为(0, 0, 2),符合预期。

解决思路

学习旋转角度,也就是分别绕着x轴、y轴、z轴旋转的角度,跟欧拉角差不多,但是用四元数进行计算,能够避免欧拉角的奇异性问题。

损失函数:两个四元数a和b,a乘以b的逆,能得到一个新的四元数,新的四元数表示了一个绕着新的轴进行的一次a到b的相对旋转,也就是说损失是拿这个新的四元数的实部的旋转角,旋转角越小表示两个旋转四元数越接近

网络结构

源码

四元数计算类 Quaternion

数据预处理 Dataset

  1. 随机生成单位四元数(random_unit_quaternion): 这个函数随机生成一个单位四元数表示的旋转,并将其转换为旋转矩阵。同时,生成一个随机平移向量。这个函数的主要作用是为点云数据应用随机的旋转和平移。

  2. 点云数据集类(LoadData): 这个类继承自torch.utils.data.Dataset,用于加载点云数据并进行预处理。主要功能如下:

    • __init__方法中,读取指定文件夹下的所有点云文件,并将其转换为Tensor格式存储。

    • __getitem__方法中,对点云数据进行以下处理: a. 随机选择两个点云(标准点云和被对齐点云)。 b. 随机生成旋转四元数和平移向量。 c. 对点云数据进行随机采样,选取指定数量的点。 d. 将被对齐点云应用随机旋转和平移,得到旋转后的点云。 e. 返回旋转后的点云、标准点云和对应的四元数或平移向量(根据label_is_rotation参数决定)。

    • __len__方法中,返回数据集的长度,这里设置为固定值500。

定义网络 Net

  1. PointNet(点云特征提取网络):

    • 包含三个1D卷积层和相应的批归一化层。

    • 在forward方法中,输入的点云数据依次经过三个1D卷积层和批归一化层,然后通过ReLU激活函数。

    • 最后,对结果进行最大池化操作,得到大小为(batch_size, 1024)的特征表示。

  2. RotationNet(旋转预测网络):

    • 包含一个PointNet实例用于特征提取,以及三个全连接层和相应的批归一化层。

    • 在forward方法中,首先计算输入点云的质心,然后将点云归一化到质心位置。

    • 使用PointNet提取点云特征,然后将两个点云的特征沿维度1拼接。

    • 拼接后的特征依次通过全连接层、批归一化层和ReLU激活函数。

    • 最后,经过一层全连接层并进行L2范数归一化,得到大小为(batch_size, 4)的四元数表示。

    • 保证四元数的第一个分量为正值。

  3. TranslationNet(平移预测网络):

    • 包含一个PointNet实例用于特征提取,以及三个全连接层和相应的批归一化层。

    • 在forward方法中,使用PointNet提取点云特征,然后将两个点云的特征沿维度1拼接。

    • 拼接后的特征依次通过全连接层、批归一化层和ReLU激活函数。

    • 最后,经过一层全连接层,得到大小为(batch_size, 3)的平移向量表示。

四元数损失定义 QuatLoss

这段代码定义了一个四元数损失函数(QuatLoss),它继承自torch.nn.Module。这个损失函数的目的是计算两个四元数之间的角度差异,并使用这个差异作为预测旋转的损失。

forward方法中:

  1. 将标签(label)四元数分解为实部(w)和虚部(x, y, z)。

  2. 计算标签四元数的共轭(label_inv)。

  3. 将预测四元数(y_hat)分解为实部(sa)和虚部(xa, ya, za)。

  4. 将标签共轭四元数(label_inv)分解为实部(sb)和虚部(xb, yb, zb)。

  5. 计算两个四元数的乘积(loss_quat)。

  6. 提取乘积四元数中的实部(cos_angle),它表示输入四元数之间的夹角的余弦值。

  7. 计算损失值:1 减去余弦值的平方(loss)。

  8. 返回损失值的均值。

这个损失函数的作用是衡量预测四元数(y_hat)与标签四元数(label)之间的角度差异。当预测四元数与标签四元数之间的角度相差较大时,损失值较大;当它们之间的角度相差较小时,损失值较小。这有助于神经网络在训练过程中学习到正确的旋转信息。

训练旋转四元数 TrainQuaternion

  1. 设置命令行参数以配置训练过程。

  2. 初始化设备为GPU或CPU,取决于可用硬件。

  3. 创建并配置平移网络模型、优化器、损失函数和学习率衰减策略。

  4. 加载预训练的RotationNet模型,用于将点云旋转到正确的位姿。

  5. 设置数据集和数据加载器。

  6. 创建模型保存路径。

  7. 进行训练循环,每个epoch中:

    • 遍历数据集的批次。

    • 将数据移动到指定设备(GPU或CPU)上。

    • 使用RotationNet网络调整输入点云的位姿。

    • 执行正向传播和反向传播。

    • 更新优化器和学习率。

    • 每隔一定步长保存一次模型

训练平移向量 TrainTranslateVector

  1. 导入所需库和模块,包括PyTorch、自定义的DataLoader、Net等。

  2. 定义get_lr()函数,用于获取优化器的当前学习率。

  3. 设定训练使用的设备(GPU或CPU)。

  4. 定义命令行参数,包括批次大小、工作进程数、训练周期数、输出文件夹等。

  5. 创建TranslationNet实例并将其加载到相应设备上。

  6. 加载预训练的RotationNet模型。

  7. 配置优化器、损失函数和学习率衰减策略。

  8. 设置DataLoader来加载训练数据集。

  9. 创建模型保存路径。

  10. 初始化损失值记录变量。

  11. 进行训练循环,每个循环执行以下操作:

    • 遍历数据集的批次。

    • 将数据移动到指定设备(GPU或CPU)。

    • 使用RotationNet网络调整点云的位姿。

    • 执行正向传播、计算损失、执行反向传播。

    • 更新优化器和学习率。

    • 每隔一定周期保存模型和输出损失值。

测试代码 Test

这段代码实现了使用训练好的RotationNet和TranslationNet进行点云对齐。主要步骤如下:

  1. 导入所需库,包括Open3D、NumPy、PyTorch等。

  2. 加载训练好的RotationNet和TranslationNet模型。

  3. 设置测试数据集和DataLoader。

  4. 定义四元数损失函数。

  5. 设置是否使用ICP算法(迭代最近点算法)进行细对齐。

  6. 遍历测试数据集的批次,执行以下操作:

    • 将输入点云和标准点云传入RotationNet,得到预测的旋转四元数。

    • 使用预测的旋转四元数对输入点云进行旋转。

    • 将旋转后的点云和标准点云传入TranslationNet,得到预测的平移向量。

    • 使用预测的平移向量对输入点云进行平移。

    • 如果使用ICP算法,则对旋转和平移后的点云进行细对齐。

    • 将输入点云、对齐后的点云和标准点云可视化并显示。

在这个脚本中,使用Open3D库进行点云的可视化和处理。通过预测的旋转四元数和平移向量,对输入点云进行旋转和平移,以实现与标准点云的对齐。如果启用了ICP算法,还会对旋转和平移后的点云进行细对齐。最后,将输入点云、对齐后的点云和标准点云可视化并显示。

Last updated