class QuatLoss(nn.Module):
def __init__(self):
super(QuatLoss, self).__init__()
def forward(self, y_hat, label):
w, x, y, z = label.split(1, dim=-1)
label_inv = torch.cat([w, -x, -y, -z], dim=-1) # label共轭
sa, xa, ya, za = y_hat.split(1, dim=-1)
sb, xb, yb, zb = label_inv.split(1, dim=-1)
loss_quat = torch.cat([
sa * sb - xa * xb - ya * yb - za * zb,
sa * xb + xa * sb + ya * zb - za * yb,
sa * yb - xa * zb + ya * sb + za * xb,
sa * zb + xa * yb - ya * xb + za * sb
], dim=-1)
loss = torch.pow(torch.acos(loss_quat[:, 0]) * 2, 2)
return torch.mean(loss)
from torch.utils.data import dataset
import open3d as o3d
import os
import numpy as np
import Quaternion as quat
n
def random_unit_quaternion():
# 随机生成旋转角(0 到 180 度),并将其转换为弧度
angle = np.random.rand() * np.pi * 2
# 计算实部
w = np.cos(angle / 2)
# 随机生成虚部
x, y, z = np.random.rand(3)
v_norm = np.sqrt(x ** 2 + y ** 2 + z ** 2)
# 单位化虚部向量
x, y, z = x / v_norm, y / v_norm, z / v_norm
# 计算旋转轴的缩放系数
s = np.sin(angle / 2)
# 根据实部和虚部组合成单位四元数
q = np.array([w, x * s, y * s, z * s])
# 将四元数转换为旋转矩阵
quat_obj = quat.Quaternion.from_numpy(q)
R = quat_obj.getRotationMatrix()
return q, R
class LoadData(dataset.Dataset):
def __init__(self, points_count):
super(LoadData, self).__init__()
folder_path = 'StandardModel'
self.file_paths = []
self.points_count = points_count
self.count = 0
self.pcds = []
for root, dirs, files in os.walk(folder_path):
for file in files:
if file == '.DS_Store':
continue
self.count = self.count + 1
self.pcds.append(o3d.io.read_point_cloud(os.path.join(root, file)))
print(os.path.join(root, file))
def __getitem__(self, index):
pcd = self.pcds[index % self.count]
points_before = np.asarray(pcd.points)
q, R = random_unit_quaternion() # 获取随机旋转四元数
sampled_indices = np.random.choice(points_before.shape[0], size=self.points_count, replace=False)
points_before = points_before[sampled_indices] # 从模型中随机采样n个点
points_after = np.matmul(R, points_before.T).T # 将采样到的点运用 随机旋转四元数
label = q # 将四元数作为标签值
return points_after, label # 返回随机旋转后的点云 以及对应的四元数
def __len__(self):
return 500 # self.count