# PCR

```python
class PointNet(nn.Module):
    """PointNet-style regressor mapping a point cloud to a unit quaternion.

    Input ``x`` has shape ``(batch, 3, num_points)`` (``conv1`` has 3 input
    channels and ``Conv1d`` expects ``(N, C, L)``). Output has shape
    ``(batch, 4)`` and is L2-normalised along dim 1.

    NOTE(review): models pickled from the earlier version of this class (which
    reused ``bn2``/``bn1`` in the head) do not have ``bn4``/``bn5`` and cannot
    run this ``forward`` without re-training or manual conversion.
    """

    def __init__(self):
        super(PointNet, self).__init__()
        # Shared per-point MLP implemented as 1x1 convolutions.
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        # Regression head on the pooled global feature.
        self.fc1 = nn.Linear(1024, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 4)
        # Dedicated BatchNorms for the head. Reusing bn2/bn1 here made the
        # per-point conv features and the global FC features share affine
        # parameters and running statistics, although they are different
        # quantities.
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(64)

    def forward(self, x):
        """Return ``(batch, 4)`` unit-norm quaternions for ``(batch, 3, N)`` input."""
        x = torch.relu(self.bn1(self.conv1(x)))  # conv 3 -> 64, BN, ReLU
        x = torch.relu(self.bn2(self.conv2(x)))  # conv 64 -> 128, BN, ReLU
        x = self.bn3(self.conv3(x))  # conv 128 -> 1024, BN (no ReLU)
        # Max over the point axis: order-invariant global feature (B, 1024, 1).
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)  # (B, 1024)
        x = torch.relu(self.bn4(self.fc1(x)))  # (B, 128)
        x = torch.relu(self.bn5(self.fc2(x)))  # (B, 64)
        x = self.fc3(x).float()  # (B, 4)
        # Project onto the unit sphere so the output is a unit quaternion.
        x = torch.nn.functional.normalize(x, p=2, dim=1)
        return x
```

```python
class QuatLoss(nn.Module):
    """Mean squared geodesic rotation angle between two batches of quaternions.

    For unit quaternions ``p`` (prediction) and ``q`` (label), both laid out as
    ``[w, x, y, z]`` along the last dimension, the relative rotation is
    ``p * conj(q)``. Its scalar part equals the 4D dot product ``<p, q>``, and
    the rotation angle between ``p`` and ``q`` is ``2 * acos(|<p, q>|)``.
    The absolute value is needed because ``q`` and ``-q`` encode the same
    rotation.

    Args:
        eps: ``|<p, q>|`` is clamped to ``[0, 1 - eps]`` so that ``acos`` never
            sees values slightly above 1 (floating-point error -> NaN) and its
            gradient stays finite when prediction and label coincide.
    """

    def __init__(self, eps=1e-7):
        super(QuatLoss, self).__init__()
        self.eps = eps

    def forward(self, y_hat, label):
        """Return the mean of ``(2 * acos(|<y_hat, label>|)) ** 2``.

        Args:
            y_hat: predicted quaternions, shape ``(..., 4)``; assumed unit norm.
            label: target quaternions, shape ``(..., 4)``; assumed unit norm.

        Returns:
            Scalar tensor: squared rotation angle (radians^2) averaged over
            all leading dimensions.
        """
        # Scalar part of y_hat * conj(label): sa*w + xa*x + ya*y + za*z.
        real = torch.sum(y_hat * label, dim=-1)
        cos_half = real.abs().clamp(max=1.0 - self.eps)
        angle = 2 * torch.acos(cos_half)
        return torch.mean(angle ** 2)
```

```python
from torch.utils.data import dataset
import open3d as o3d
import os
import numpy as np
import Quaternion as quat
# Random rotation sampling used to build (rotated cloud, quaternion) pairs.


def random_unit_quaternion():
    """Sample a random rotation as a unit quaternion plus its rotation matrix.

    The rotation angle is drawn uniformly from ``[0, 2*pi)`` radians
    (0-360 degrees). The rotation axis is drawn uniformly from the unit sphere.

    Returns:
        q: numpy array ``[w, x, y, z]`` of unit norm (scalar part first).
        R: the matrix returned by ``quat.Quaternion.getRotationMatrix`` for
            ``q`` (presumably 3x3 -- confirm against the Quaternion module).
    """
    # Rotation angle in radians, uniform in [0, 2*pi).
    angle = np.random.rand() * np.pi * 2
    # Scalar (real) part of the quaternion.
    w = np.cos(angle / 2)

    # Rotation axis: a normalised isotropic Gaussian sample is uniform on the
    # sphere. (np.random.rand would give only non-negative components, i.e.
    # axes restricted to a single octant.)
    axis = np.random.randn(3)
    norm = np.linalg.norm(axis)
    while norm < 1e-12:  # practically never happens; guards the division
        axis = np.random.randn(3)
        norm = np.linalg.norm(axis)
    axis = axis / norm

    # Vector part is the unit axis scaled by sin(angle / 2).
    s = np.sin(angle / 2)
    q = np.concatenate(([w], axis * s))

    # Convert the quaternion into a rotation matrix.
    quat_obj = quat.Quaternion.from_numpy(q)
    R = quat_obj.getRotationMatrix()

    return q, R


class LoadData(dataset.Dataset):
    """Dataset of randomly rotated point-cloud samples labelled by quaternion.

    Every file under ``folder_path`` (except ``.DS_Store``) is loaded once with
    ``o3d.io.read_point_cloud``. Each item picks a cloud (cycling by index),
    samples ``points_count`` of its points, and rotates them by a freshly drawn
    random rotation. The label is that rotation's quaternion.

    Args:
        points_count: number of points sampled per item.
        folder_path: directory searched recursively for point-cloud files.
            Defaults to ``'StandardModel'`` (the previously hard-coded path).
        length: value reported by ``len()``. Items use fresh random rotations,
            so this need not equal the number of files. Defaults to 500
            (the previously hard-coded value).

    Raises:
        ValueError: if no point-cloud file is found under ``folder_path``.
    """

    def __init__(self, points_count, folder_path='StandardModel', length=500):
        super(LoadData, self).__init__()
        self.file_paths = []
        self.points_count = points_count
        self.length = length
        self.count = 0
        self.pcds = []

        for root, dirs, files in os.walk(folder_path):
            dirs.sort()  # deterministic traversal -> stable index -> cloud mapping
            for file in sorted(files):
                if file == '.DS_Store':  # macOS Finder metadata, not a cloud
                    continue
                path = os.path.join(root, file)
                self.file_paths.append(path)
                self.pcds.append(o3d.io.read_point_cloud(path))
                self.count = self.count + 1
                print(path)

        if self.count == 0:
            # Fail early instead of a ZeroDivisionError in __getitem__.
            raise ValueError('no point-cloud files found under %r' % (folder_path,))

    @staticmethod
    def _sample_indices(available, wanted):
        """Return ``wanted`` random indices into ``range(available)``.

        Samples without replacement when the cloud has enough points, and
        falls back to sampling with replacement otherwise. The fallback avoids
        the ValueError that ``replace=False`` raises for small clouds.
        """
        return np.random.choice(available, size=wanted, replace=available < wanted)

    def __getitem__(self, index):
        """Return ``(rotated_points, quaternion)`` for the cloud at ``index % count``."""
        pcd = self.pcds[index % self.count]
        points_before = np.asarray(pcd.points)
        q, R = random_unit_quaternion()  # random rotation for this item
        sampled_indices = self._sample_indices(points_before.shape[0], self.points_count)
        points_before = points_before[sampled_indices]  # random subset of points
        points_after = np.matmul(R, points_before.T).T  # apply the rotation
        label = q  # the quaternion is the regression target
        return points_after, label

    def __len__(self):
        return self.length
```

```python
import argparse
import time
import torch
import DataLoader
from Net import PointNet
import torch
from torch.utils.data import dataloader
import QuatLoss
import TransformLoss

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` if the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

import os  # local to the script body: needed to create the checkpoint directory

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=500, help='input batch size')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--nepoch', type=int, default=200000, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='seg', help='outputfolder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--dataset', type=str, default='StandardModel', help="dataset path")
parser.add_argument('--npoints', type=int, default=500, help="sample point counts")

opt = parser.parse_args()
# NOTE(review): --workers, --outf, --model and --dataset are parsed but not
# used anywhere below.

net = PointNet()  # instantiate the network
net.to(device)

name = str(int(time.time()))  # run id used in checkpoint file names

optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
# optimizer = torch.optim.SGD(params=net.parameters(), lr=opt.lr, momentum=0.3)
# criterion = torch.nn.MSELoss()  # TransformLoss.TransformLoss()
# scheduler = StepLR(optimizer, step_size=1000, gamma=0.95)  # learning-rate decay
# criterion = TransformLoss.TransformLoss()
criterion = QuatLoss.QuatLoss()

train_dataset = DataLoader.LoadData(points_count=opt.npoints)
train_loader = dataloader.DataLoader(dataset=train_dataset, batch_size=opt.batchSize, shuffle=True)

save_dir = './Save/'
os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories

loss_sum = 0
loss_count = 0

# Training loop.
for epoch in range(opt.nepoch):
    for points, label in train_loader:
        points = points.float().to(device)  # expected (B, npoints, 3)
        label = label.float().to(device)
        x = points.transpose(1, 2)  # -> (B, 3, npoints) for Conv1d
        optimizer.zero_grad()
        y_hat = net(x).float()
        loss = criterion(y_hat, label)
        loss_sum += loss.item()
        loss_count += 1
        loss.backward()
        optimizer.step()
        # scheduler.step()  # update the learning rate

    # Mean training loss over this epoch (nan if the loader yielded no batch).
    avg_loss = loss_sum / loss_count if loss_count else float('nan')

    if epoch % 50 == 0:
        torch.save(net, save_dir + name + '_epoch_' + str(epoch) + '_loss_' + str(avg_loss) + '.module')

    print("epoch ", epoch, " loss ", avg_loss, " lr ", get_lr(optimizer))
    loss_sum = 0
    loss_count = 0


torch.save(net, save_dir + name + '.module')

```


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available on this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://akitumn.gitbook.io/pointcloudwork/pcr.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
