PCR

class PointNet(nn.Module):
    """PointNet-style network that regresses a unit quaternion from a point cloud.

    Three pointwise (kernel size 1) convolutions lift every point from 3 to
    1024 channels, a max over the point axis collapses the cloud into one
    1024-d feature vector, and a three-layer fully connected head maps that
    vector to 4 values which are L2-normalised before being returned.

    Each normalised layer owns its own ``BatchNorm1d``. The fully connected
    head uses ``bn4``/``bn5`` instead of re-using the convolutional ``bn2``/
    ``bn1``; sharing them would force one set of affine weights and running
    statistics to serve two unrelated activation distributions.

    NOTE: ``bn4``/``bn5`` add entries to the ``state_dict``; checkpoints
    written by a version that shared the layers will not load unchanged.
    """

    def __init__(self):
        super(PointNet, self).__init__()
        # Per-point feature extractor (shared weights across points).
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        # Regression head operating on the pooled global feature.
        self.fc1 = nn.Linear(1024, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 4)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(64)

    def forward(self, x):
        """Map a batch of point clouds to unit-length 4-vectors.

        Args:
            x: Tensor of shape (batch, 3, n_points); the channel axis must
                have size 3 to match ``conv1``.

        Returns:
            Tensor of shape (batch, 4) whose rows have unit L2 norm.
        """
        x = torch.relu(self.bn1(self.conv1(x)))  # conv 3 -> 64, batch norm, ReLU
        x = torch.relu(self.bn2(self.conv2(x)))  # conv 64 -> 128, batch norm, ReLU
        x = self.bn3(self.conv3(x))  # conv 128 -> 1024, batch norm (no ReLU)
        x = torch.max(x, 2, keepdim=True)[0]  # max-pool over the point axis
        x = x.view(-1, 1024)  # (batch, 1024, 1) -> (batch, 1024)
        x = torch.relu(self.bn4(self.fc1(x)))
        x = torch.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x).float()
        # Project onto the unit sphere so the output is a valid unit quaternion.
        x = torch.nn.functional.normalize(x, p=2, dim=1)
        return x
class QuatLoss(nn.Module):
    """Mean squared rotation angle between predicted and target quaternions.

    For quaternions in (w, x, y, z) order the relative rotation is
    ``y_hat * conj(label)``; its real part ``r`` gives the rotation angle as
    ``2 * acos(r)``. The loss is the batch mean of that angle squared. The
    angle interpretation only holds when both inputs are unit quaternions.

    Args:
        eps: Margin kept between the real part and +/-1 before ``acos``.
            Without it, rounding can push the value outside [-1, 1] (NaN
            loss) and an exact match hits the infinite slope of ``acos`` at
            1 (NaN gradient). The default is sized for float32 inputs.

    NOTE(review): ``q`` and ``-q`` describe the same rotation, but this loss
    treats them as maximally different; confirm that is intended.
    """

    def __init__(self, eps=1e-7):
        super(QuatLoss, self).__init__()
        self.eps = eps

    def forward(self, y_hat, label):
        """Compute the loss.

        Args:
            y_hat: Predicted quaternions, shape (..., 4), order (w, x, y, z).
            label: Target quaternions, same shape and order as ``y_hat``.

        Returns:
            Scalar tensor: mean of ``(2 * acos(r)) ** 2`` over all leading
            dimensions.
        """
        # Real part of y_hat * conj(label). Conjugation flips the sign of the
        # vector part, so the Hamilton product's real component reduces to a
        # plain dot product; the other three components are never used.
        real = (y_hat * label).sum(dim=-1)
        # Keep acos inside its open domain so values and gradients stay finite.
        bound = 1.0 - self.eps
        real = real.clamp(-bound, bound)
        angle = torch.acos(real) * 2
        return torch.mean(torch.pow(angle, 2))
from torch.utils.data import dataset
import open3d as o3d
import os
import numpy as np
import Quaternion as quat
n

def random_unit_quaternion():
    """Draw a random rotation as a unit quaternion plus its rotation matrix.

    The rotation angle is uniform in [0, 2*pi). The rotation axis is built
    from three independent uniform [0, 1) samples normalised to unit length,
    so every axis component is non-negative.

    NOTE(review): the axis never leaves the non-negative octant, so this does
    not sample all rotations; an earlier comment also described the angle
    range as 0-180 degrees while the code draws 0-360. Confirm the intent.

    Returns:
        Tuple ``(q, R)`` where ``q`` is a length-4 array in (w, x, y, z)
        order and ``R`` is whatever ``Quaternion.getRotationMatrix()``
        returns for it (presumably a 3x3 rotation matrix).
    """
    # Angle first, axis second: keep this draw order so seeded runs match.
    theta = np.random.rand() * np.pi * 2
    half_angle = theta / 2

    # Unit-length rotation axis.
    axis = np.random.rand(3)
    axis = axis / np.sqrt((axis * axis).sum())

    # q = (cos(theta/2), sin(theta/2) * axis)
    q = np.concatenate(([np.cos(half_angle)], axis * np.sin(half_angle)))

    # Convert to a rotation matrix through the project's Quaternion helper.
    R = quat.Quaternion.from_numpy(q).getRotationMatrix()

    return q, R


class LoadData(dataset.Dataset):
    """Dataset of randomly rotated point samples drawn from reference models.

    Every file found under ``folder_path`` (except ``.DS_Store``) is read
    once with Open3D and kept in memory. Each item picks a model, samples
    ``points_count`` of its points, applies a fresh random rotation from
    ``random_unit_quaternion()`` (expected in module scope) and returns the
    rotated points with the quaternion as the label.

    Args:
        points_count: Number of points sampled per item.
        folder_path: Directory that is walked recursively for model files.
        epoch_size: Value reported by ``len()``. Items are generated
            randomly, so this sets the epoch length rather than the number
            of distinct samples.

    Raises:
        FileNotFoundError: If no model file is found under ``folder_path``.
    """

    def __init__(self, points_count, folder_path='StandardModel', epoch_size=500):
        super(LoadData, self).__init__()
        self.file_paths = []  # kept for compatibility; never populated here
        self.points_count = points_count
        self.epoch_size = epoch_size
        self.count = 0
        self.pcds = []

        for root, dirs, files in os.walk(folder_path):
            for file in files:
                if file == '.DS_Store':
                    continue
                path = os.path.join(root, file)
                self.count = self.count + 1
                self.pcds.append(o3d.io.read_point_cloud(path))
                print(path)

        # Fail here with a clear message instead of a ZeroDivisionError on
        # the first __getitem__ call.
        if not self.pcds:
            raise FileNotFoundError(
                'no point cloud files found under ' + repr(folder_path))

    def _sample_indices(self, available):
        """Pick ``points_count`` indices out of ``range(available)``.

        Sampling is without replacement when the cloud is large enough and
        falls back to sampling with replacement when it is not, instead of
        raising.
        """
        with_replacement = available < self.points_count
        return np.random.choice(
            available, size=self.points_count, replace=with_replacement)

    def __getitem__(self, index):
        """Return ``(rotated_points, quaternion)`` for one random sample.

        ``index`` only selects the model (modulo the number of models); the
        sampled points and the rotation are drawn fresh on every call.
        """
        pcd = self.pcds[index % self.count]
        points_before = np.asarray(pcd.points)
        q, R = random_unit_quaternion()  # random rotation and its matrix
        sampled_indices = self._sample_indices(points_before.shape[0])
        points_before = points_before[sampled_indices]
        # Rotate every sampled point: (R @ P^T)^T keeps one point per row.
        points_after = np.matmul(R, points_before.T).T
        return points_after, q

    def __len__(self):
        return self.epoch_size
import argparse
import os
import time

import torch
from torch.utils.data import dataloader

import DataLoader
import QuatLoss
import TransformLoss
from Net import PointNet

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=500, help='input batch size')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--nepoch', type=int, default=200000, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='seg', help='outputfolder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--dataset', type=str, default='StandardModel', help="dataset path")
parser.add_argument('--npoints', type=int, default=500, help="sample point counts")

opt = parser.parse_args()

net = PointNet()  # instantiate the network
net.to(device)

# Timestamp used as a prefix so checkpoints of different runs do not collide.
name = str(int(time.time()))

# Create the checkpoint directory up front: torch.save does not create
# missing directories, and the first save only happens after a full epoch.
save_dir = './Save/'
os.makedirs(save_dir, exist_ok=True)

optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
criterion = QuatLoss.QuatLoss()

train_dataset = DataLoader.LoadData(points_count=opt.npoints)
train_loader = dataloader.DataLoader(dataset=train_dataset, batch_size=opt.batchSize, shuffle=True)

# Running totals for the mean loss reported (and reset) after every epoch.
loss_sum = 0
loss_count = 0

# Training loop.
for epoch in range(opt.nepoch):
    for points, label in train_loader:
        points = points.float().to(device)
        label = label.float().to(device)
        # Swap the last two axes so the coordinate axis becomes the channel
        # axis expected by the network's Conv1d layers.
        x = points.transpose(1, 2)
        optimizer.zero_grad()
        y_hat = net(x)
        loss = criterion(y_hat, label)
        loss_sum += loss.item()
        loss_count += 1
        loss.backward()
        optimizer.step()

    # Guard against an empty loader so reporting never divides by zero.
    mean_loss = loss_sum / loss_count if loss_count else float('nan')

    # Periodic checkpoint; the file name records the epoch and its mean loss.
    if epoch % 50 == 0:
        torch.save(net, save_dir + name + '_epoch_' + str(epoch) + '_loss_' + str(mean_loss) + '.module')

    print("epoch ", epoch, " loss ", mean_loss, " lr ", get_lr(optimizer))
    loss_sum = 0
    loss_count = 0


# Final checkpoint after the last epoch.
torch.save(net, save_dir + name + '.module')

Last updated