# A Convolutional Decoder for Point Clouds
This repository contains the code for the paper "A Convolutional Decoder for Point Clouds using Adaptive Instance Normalization".
```
@article{Lim:2019:ConvolutionalDecoder,
  author  = "Lim, Isaak and Ibing, Moritz and Kobbelt, Leif",
  title   = "A Convolutional Decoder for Point Clouds using Adaptive Instance Normalization",
  journal = "Computer Graphics Forum",
  volume  = 38,
  number  = 5,
  year    = 2019
}
```
The paper as well as further data can be found on the [project page][page].
## Usage
We used our own sampling of the [ShapeNet][shapenet] dataset, which can be downloaded via this [link][sampling].
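The HDF5 layout matches the loader in this repository (datasets `points`, `model_id`, and `synset_id`); a minimal sketch for inspecting the downloaded file (the printed shape is an assumption based on the sampling size):
```python
import h5py

with h5py.File("data/shapenet_v2_16000.hdf5", "r") as f:
    print(f["points"].shape)                      # e.g. (n_models, 3, 16000)
    print(f["model_id"][:3], f["synset_id"][:3])  # per-model and per-class ids
```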
To evaluate a model, run:
```
python test_autoencoder.py --model-path pretrained/model_full_transformations.state \
    --data-path "data/shapenet_v2_16000.hdf5" --split-path "data/core_train_val_test.csv"
```
Pretrained weights for our models trained with 3 or 9 layers of adaptive instance normalization can be found under `models/pretrained/`.
When evaluating the version that uses only 3 layers, add the option `--adain-layer 3`.
To train a new model, run:
```
python train_autoencoder.py --run-name full --data-path "data/shapenet_v2_16000.hdf5" \
    --split-path "data/core_train_val_test.csv"
```
The weights will be saved under `models/`, while the training progress is logged under `runs/`.
There are further command-line options (e.g., to select the losses used); run the scripts with `--help` for a full list.
[shapenet]: https://www.shapenet.org/
[sampling]: https://www.graphics.rwth-aachen.de/blockstorage/free/shapenet_v2_16000.zip
[page]: https://graphics.rwth-aachen.de/publication/03303
import h5py
import numpy as np
import pandas as pd
import torch.utils.data
class Shapenet(torch.utils.data.Dataset):
def __init__(self, filename, split_file, mode='train', normalization=None, transform=None):
super().__init__()
assert mode in ['train', 'val', 'test', 'complete'], "mode '%s' is not supported" % mode
self.mode = mode
self.transform = transform
if mode in ['train', 'val', 'test']:
split = pd.read_csv(split_file)
ids = split.loc[split['split'] == mode]['modelId']
ids = set(ids)
with h5py.File(filename, "r") as f:
indices = []
for i, x in enumerate(f['model_id'][:].tolist()):
if x in ids:
indices.append(i)
self.data = torch.from_numpy(f['points'][indices])
self.label = f['synset_id'][indices]
else:
with h5py.File(filename, "r") as f:
self.data = torch.from_numpy(f['points'][:])
self.label = f['synset_id'][:]
label_dict = {key: value for (value, key) in enumerate(np.unique(self.label))}
self.label = torch.from_numpy(np.vectorize(label_dict.get)(self.label))
if normalization is not None:
self.data = torch.stack([normalization(self.data[i]) for i in range(self.data.shape[0])])
self.cls_count = torch.zeros(self.label.max().item() + 1, dtype=torch.int)
for i in range(self.label.max().item() + 1):
self.cls_count[i] = (self.label == i).sum().int()
self.model_ids = torch.tensor(range(self.data.shape[0]))
def __getitem__(self, index):
out = self.data[self.model_ids[index]]
cls_label = self.label[self.model_ids[index]]
if self.transform is not None:
out = self.transform(out)
return out, cls_label
def __len__(self):
return self.model_ids.shape[0]
def n_classes(self):
return self.label.max().item() + 1
def restrict_class(self, classes):
if not classes:
self.model_ids = torch.tensor(range(self.data.shape[0]))
else:
self.model_ids = torch.cat([(self.label == i).nonzero() for i in classes], dim=0).squeeze(1)
def normalize(points):
result = points - torch.mean(points, dim=1).unsqueeze(1).expand_as(points)
rad = torch.max(torch.norm(result, dim=0))
return result / rad
def normalize_unit_cube(points):
bb_max = points.max(-1)[0]
bb_min = points.min(-1)[0]
length = (bb_max - bb_min).max()
mean = (bb_max + bb_min) / 2.0
scale = 1.0 / length
res = (points - mean.unsqueeze(1)) * scale
return res.clamp(-0.5, 0.5)
def subsample(points, n=1024):
return points[:, :n]
def random_subsample(points, n=1024):
perm = torch.randperm(points.shape[1])[:n]
return torch.index_select(points, 1, perm)
def random_jitter(points, sigma=0.01, clip=0.05):
jittered_data = torch.clamp(sigma * torch.randn(points.shape), -1 * clip, clip)
jittered_data += points
return jittered_data
def random_scale(points, min_s=0.66, max_s=1.5, anisotrop=False):
if anisotrop:
s = torch.FloatTensor(3).uniform_(min_s, max_s).unsqueeze(1).expand_as(points)
else:
s = torch.FloatTensor(1).uniform_(min_s, max_s).unsqueeze(1).expand_as(points)
res = points * s
return res
def random_translation(points, max_t=0.2):
translation = torch.FloatTensor(3).uniform_(-max_t, max_t).unsqueeze(1).expand_as(points)
res = points + translation
return res
def get_rotation_matrix(angles):
s = torch.sin(angles)
c = torch.cos(angles)
rot_m = torch.zeros(3, 3).to(angles)
rot_m[0, 0] = c[1] * c[2]
rot_m[0, 1] = -c[1] * s[2]
rot_m[0, 2] = s[1]
rot_m[1, 0] = c[0] * s[2] + c[2] * s[0] * s[1]
rot_m[1, 1] = c[0] * c[2] - s[0] * s[1] * s[2]
rot_m[1, 2] = -c[1] * s[0]
rot_m[2, 0] = s[0] * s[2] - c[0] * c[2] * s[1]
rot_m[2, 1] = c[2] * s[0] + c[0] * s[1] * s[2]
rot_m[2, 2] = c[0] * c[1]
return rot_m
def random_rotation(points, normalize=True):
center = (points.max(1)[0] + points.min(1)[0]) * 0.5
angles = torch.rand(3) * 2 * np.pi
rot_mat = get_rotation_matrix(angles)
# points are 3 x n
r_points = rot_mat @ (points - center.unsqueeze(1))
if normalize:
r_points = normalize_unit_cube(r_points)
return r_points
def random_transform(points):
return random_jitter(random_scale(random_translation(points)))
def batch_augmentation(points, augmentation):
tList = [augmentation(m) for m in torch.unbind(points, dim=0)]
res = torch.stack(tList, dim=0)
return res
import torch
# input is expected in form (b x) c x n
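# computes all pairwise squared Euclidean distances via
# ||x||^2 + ||y||^2 - 2 * x^T y (results may be slightly negative due to rounding)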
def dist_mat_squared(x, y):
assert x.dim() == 3 or x.dim() == 2
xx = torch.sum(x ** 2, dim=-2).unsqueeze(-1)
yy = torch.sum(y ** 2, dim=-2)
if x.dim() == 3:
yy = yy.unsqueeze(-2)
dists = torch.bmm(x.transpose(2, 1), y)
else:
dists = torch.matmul(x.t(), y)
dists *= -2
dists += yy
dists += xx
return dists
def dist_norm_p(x, y, p=2):
d = dist_mat_squared(x, y)
if x.dim() == 2:
dists_1 = (x - y[:, d.min(-1)[1]]).norm(dim=0, p=p)
dists_2 = (x[:, d.min(-2)[1]] - y).norm(dim=0, p=p)
else: # dim is 3
b_d_ind_1 = d.min(-1)[1]
b_d_ind_2 = d.min(-2)[1]
b_dists_1 = []
b_dists_2 = []
for i in range(b_d_ind_1.shape[0]):
b_dists_1.append((x[i] - y[i, :, b_d_ind_1[i]]).norm(dim=0, p=p))
b_dists_2.append((x[i, :, b_d_ind_2[i]] - y[i]).norm(dim=0, p=p))
dists_1 = torch.stack(b_dists_1)
dists_2 = torch.stack(b_dists_2)
return dists_1, dists_2
def dist_norm(x, y, p=2, points_p=2):
dists_1, dists_2 = dist_norm_p(x, y, points_p)
return dists_1.norm(p=p, dim=-1) + dists_2.norm(p=p, dim=-1)
def chamfer(x, y, weights_x=None, weights_y=None):
d = dist_mat_squared(x, y)
dist1 = d.min(-1)[0]
dist2 = d.min(-2)[0]
if weights_x is not None:
dist1 = dist1 * weights_x * x.shape[-1]
if weights_y is not None:
dist2 = dist2 * weights_y * y.shape[-1]
return dist1.mean(-1) + dist2.mean(-1)
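# Sanity-check sketch (illustrative, not part of the original file):
# dist_mat_squared should agree with torch.cdist, and the Chamfer distance of
# a cloud to itself is (near) zero.
if __name__ == '__main__':
    x = torch.rand(4, 3, 128)   # b x c x n
    y = torch.rand(4, 3, 256)
    d = dist_mat_squared(x, y)  # 4 x 128 x 256
    ref = torch.cdist(x.transpose(2, 1), y.transpose(2, 1)).pow(2)
    assert torch.allclose(d, ref, atol=1e-4)
    assert chamfer(x, x).max().item() < 1e-5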
import math
import torch
import torch.nn as nn
import util
class GridEncoder(nn.Module):
def __init__(self, prep, grid_size):
        super().__init__()
self.grid_size = grid_size
self.preprocessing = prep
def initialize_grid_ball(self, x):
if x.dim() == 2:
x = x.unsqueeze(0)
        # input is expected to lie in the range [-0.5, 0.5]
        assert (x.min() >= -0.5)
        assert (x.max() <= 0.5)
        # map points into grid coordinates in the range [-0.5, grid_size - 0.5]
        reshaped = (x + 0.5) * self.grid_size - 0.5
ind1 = reshaped.floor().clamp(0.0, self.grid_size - 1)
ind2 = reshaped.ceil().clamp(0.0, self.grid_size - 1)
        # gather the 8 surrounding cell corners (all floor/ceil combinations)
        ind = torch.stack([torch.cat([a[:, 0, :].unsqueeze(1),
                                      b[:, 1, :].unsqueeze(1),
                                      c[:, 2, :].unsqueeze(1)], dim=1)
                           for a in (ind1, ind2)
                           for b in (ind1, ind2)
                           for c in (ind1, ind2)], dim=-1)
# generate offset vectors
res = reshaped.unsqueeze(-1).repeat([1, 1, 1, 8]) - ind
# reshape indices
ind = ind[:, 0, :, :] * self.grid_size * self.grid_size + ind[:, 1, :, :] * self.grid_size + ind[:, 2, :, :]
ind = ind.long()
        # binary weight that checks whether the point lies in the grid ball
        dist = res.norm(dim=1).detach()
        weight = (dist < 0.87).float().detach()  # 0.87 ~ half the diagonal of a grid cube
return res, weight, ind
def forward(self, x, per_point_features=False):
b, _, n = x.size()
# for each point find 8 nearest gridcells
res, weight, indices = self.initialize_grid_ball(x) # b x 3 x n x 8
res = self.preprocessing(res) # b x c x n x k
if per_point_features:
per_point_f = res.clone().view(res.shape[0], -1, res.shape[2])
cell_indices = indices.clone()
c = res.shape[1]
weight = weight.unsqueeze(1).expand_as(res)
res = res * weight # zero out weights of points outside of ball
# sum up features of points inside ball
x = torch.zeros(b, c, self.grid_size * self.grid_size * self.grid_size).to(res.device)
count = torch.zeros(b, c, self.grid_size * self.grid_size * self.grid_size).to(x)
res = res.contiguous().view(b, c, 8 * n)
weight = weight.contiguous().view(b, c, 8 * n)
indices = indices.view(b, -1)
        indices.clamp_(0, self.grid_size ** 3 - 1)
for i in range(b):
x[i].index_add_(1, indices[i], res[i])
count[i].index_add_(1, indices[i], weight[i])
        # average rather than sum, so the number of points per cell has no effect
        count = torch.max(count, torch.tensor([1.0]).to(weight.device))  # avoid division by zero
x /= count
x = x.view(b, -1, self.grid_size, self.grid_size, self.grid_size) # b x c x grid_size x grid_size x grid_size
if per_point_features:
return x, per_point_f, cell_indices
else:
return x
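# Shape sketch for the encoder above (illustrative; the single 1x1 Conv2d
# stands in for the per-point preprocessing network used in models.py):
if __name__ == '__main__':
    enc = GridEncoder(nn.Conv2d(3, 16, 1), grid_size=8)
    pts = torch.rand(2, 3, 500) - 0.5   # b x 3 x n in [-0.5, 0.5]^3
    print(enc(pts).shape)               # torch.Size([2, 16, 8, 8, 8])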
# Generate points on a box grid given generator parameters for each cell and the number of points
class PointCloudGenerator(nn.Module):
def __init__(self, generator, rnd_dim=2, res=16):
        super().__init__()
self.base_dim = rnd_dim
self.generator = generator
grid = util.meshgrid(res)
self.o = (((grid + 0.5) / res) - 0.5).view(3, -1)
self.s = res
def forward(self, x, dens, n_points):
b, c, g, _, _ = x.shape
self.o = self.o.to(x.device)
# Sample Density
n = util.densSample(dens, n_points)
# We call self.generator with the corresponding box descriptor and 2 random features for each point in the cell
# The output is then offset to the correct position in the grid
# this function is only efficient if the maximum number of points per grid cell is small
n = n.view(b, -1)
x = x.view(b, c, -1)
gen_inp = []
gen_off = []
for i in range(b):
            indices = []  # cell indices, repeated once per point to be generated in that cell
for j in range(1, n[i].max() + 1):
ind = (n[i] >= j).nonzero().squeeze(-1)
indices.append(ind)
indices = torch.cat(indices)
x_ind = x[i, :, indices]
o_ind = self.o[:, indices]
b_rnd = torch.rand(self.base_dim, n_points).to(x_ind.device) * 2.0 - 1.0
b_inp = torch.cat([x_ind, b_rnd], dim=0)
gen_inp.append(b_inp)
gen_off.append(o_ind)
gen_inp = torch.stack(gen_inp)
gen_off = torch.stack(gen_off)
out = self.generator(gen_inp)
norm = out.norm(dim=1)
reg = (norm - (math.sqrt(3) / self.s)).clamp(0) # twice the size needed to cover a grid-cell
return out + gen_off, reg
def forward_fixed_pattern(self, x, dens, n):
b, c, g, _, _ = x.shape
self.o = self.o.to(x.device)
N = util.densSample(dens, n)
# We call self.generator with the corresponding box descriptor and 2 random features for each point in the cell
# The output is then offset to the correct position in the grid
# this function is only efficient if the maximum number of points per grid cell is small
N = N.view(b, -1)
x = x.view(b, c, -1)
gen_inp = []
gen_off = []
for i in range(b):
batch_inp = []
batch_off = []
for j in range(1, N.max() + 1):
ind = (N[i] == j).nonzero().squeeze(-1)
                if ind.shape[0] != 0:
x_ind = x[i, :, ind].repeat([1, j])
o_ind = self.o[:, ind].repeat([1, j])
b_rnd = util.fixed_sample(j, ind.shape[0]).to(x_ind) * 2.0 - 1.0
b_inp = torch.cat([x_ind, b_rnd], dim=0)
batch_inp.append(b_inp)
batch_off.append(o_ind)
gen_inp.append(torch.cat(batch_inp, dim=1))
gen_off.append(torch.cat(batch_off, dim=1))
gen_inp = torch.stack(gen_inp)
gen_off = torch.stack(gen_off)
out = self.generator(gen_inp)
norm = out.norm(dim=1)
reg = (norm - (math.sqrt(3) / (self.s))).clamp(0) # twice the size needed to cover a gridcell
return out + gen_off, reg
class AdaptiveDecoder(nn.Module):
def __init__(self, decoder, n_classes=None, max_layer=None):
        super().__init__()
assert (isinstance(decoder, nn.ModuleList))
self.decoder = decoder
self.slices = []
self.norm_indices = []
self.conditional = n_classes is not None
first = True
for i, l in enumerate(self.decoder):
if isinstance(l, nn.InstanceNorm3d):
if first:
if self.conditional:
self.inp = nn.Linear(n_classes, l.num_features * 2 * 2 * 2)
else:
self.inp = nn.Parameter(torch.randn([1, l.num_features, 2, 2, 2]))
first = False
self.norm_indices.append(i)
self.slices.append(l.num_features * 2)
if max_layer is None:
self.max_layer = len(self.norm_indices)
else:
self.max_layer = max_layer
def forward(self, w, cls=None):
size = 0
j = 0
b = w.shape[0]
if self.conditional:
            x = self.inp(cls).view(b, -1, 2, 2, 2)  # in the conditional case, cls is expected to be a one-hot vector
else:
x = self.inp.repeat([b, 1, 1, 1, 1])
for i, l in enumerate(self.decoder):
x = l(x)
if j < self.max_layer and i == self.norm_indices[j]:
s = w[:, size:size + self.slices[j], None, None, None]
size += self.slices[j]
x = x * s[:, :self.slices[j] // 2]
x = x + s[:, self.slices[j] // 2:]
j += 1
return x
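# Sketch of the AdaIN mechanism above (illustrative, with assumed layer sizes):
# after every InstanceNorm3d, a slice of the style vector w provides a
# per-channel scale s and shift t, i.e. x <- x * s + t.
if __name__ == '__main__':
    dec = AdaptiveDecoder(nn.ModuleList([
        nn.InstanceNorm3d(8),
        nn.Conv3d(8, 4, 3, padding=1),
        nn.InstanceNorm3d(4),
    ]))
    w = torch.randn(2, sum(dec.slices))  # 2 * (8 + 4) style parameters per sample
    print(dec(w).shape)                  # torch.Size([2, 4, 2, 2, 2]), grown from the learned 2x2x2 seed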
import torch
import torch.nn as nn
import torch.nn.functional as fn
import layer
class GridAutoEncoderAdaIN(nn.Module):
def __init__(self, rnd_dim=2, h_dim=62, enc_p=0, dec_p=0, adain_layer=None, filled_cls=True):
super().__init__()
self.grid_size = 32
self.filled_cls = filled_cls
self.grid_encoder = layer.GridEncoder(
nn.Sequential(
nn.Conv2d(3, 8, 1, bias=False),
nn.BatchNorm2d(8),
nn.ELU(),
nn.Conv2d(8, 16, 1, bias=False),
nn.BatchNorm2d(16),
nn.ELU(),
nn.Conv2d(16, 32, 1, bias=False),
nn.BatchNorm2d(32),
nn.ELU(),
nn.Conv2d(32, 32, 1, bias=False),
nn.BatchNorm2d(32)),
self.grid_size)
self.encoder = nn.Sequential(
nn.Conv3d(32, 64, 3, padding=1, bias=False),
nn.Dropout3d(enc_p),
nn.BatchNorm3d(64),
nn.ELU(),
nn.Conv3d(64, 64, 3, padding=1, bias=False),
nn.BatchNorm3d(64),
nn.ELU(),
nn.Conv3d(64, 64, 3, padding=1, bias=False),
nn.BatchNorm3d(64),
nn.ELU(),
nn.MaxPool3d(2), # 16
nn.Conv3d(64, 128, 3, padding=1, bias=False),
nn.BatchNorm3d(128),
nn.ELU(),
nn.Conv3d(128, 128, 3, padding=1, bias=False),
nn.BatchNorm3d(128),
nn.ELU(),
nn.MaxPool3d(2), # 8
nn.Conv3d(128, 256, 3, padding=1, bias=False),
nn.BatchNorm3d(256),
nn.ELU(),
nn.Conv3d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm3d(256),
nn.ELU(),
nn.MaxPool3d(2), # 4
nn.Conv3d(256, 512, 3, padding=1, bias=False),
nn.BatchNorm3d(512),
nn.ELU(),
nn.Conv3d(512, 512, 3, padding=1, bias=False),
nn.BatchNorm3d(512),
nn.ELU(),
nn.MaxPool3d(2), # 2
nn.Conv3d(512, 512, 3, padding=1, bias=False),
nn.BatchNorm3d(512),
nn.ELU(),
nn.Conv3d(512, 1024, 2, padding=0, bias=False),
nn.BatchNorm3d(1024),
nn.ELU(),
)
self.decoder = layer.AdaptiveDecoder(nn.ModuleList([
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(512),
nn.Conv3d(512, 512, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(512),
nn.ELU(),
nn.Upsample(scale_factor=2, mode='trilinear'), # 4
nn.Conv3d(512, 512, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(512),
nn.ELU(),
nn.Conv3d(512, 256, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(256),
nn.ELU(),
nn.Upsample(scale_factor=2, mode='trilinear'), # 8
nn.Conv3d(256, 256, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(256),
nn.ELU(),
nn.Conv3d(256, 128, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(128),
nn.ELU(),
nn.Upsample(scale_factor=2, mode='trilinear'), # 16
nn.Conv3d(128, 128, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(128),
nn.ELU(),
nn.Conv3d(128, 64, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(64),
nn.ELU(),
nn.Upsample(scale_factor=2, mode='trilinear'), # 32
nn.Conv3d(64, 64, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(64),
nn.ELU(),
nn.Conv3d(64, h_dim, 3, padding=1, bias=False),
nn.Dropout3d(dec_p),
nn.InstanceNorm3d(h_dim)
]), max_layer=adain_layer)
self.generator = layer.PointCloudGenerator(
nn.Sequential(nn.Conv1d(h_dim + rnd_dim, 64, 1),
nn.ELU(),
nn.Conv1d(64, 64, 1),
nn.ELU(),
nn.Conv1d(64, 32, 1),
nn.ELU(),
nn.Conv1d(32, 32, 1),
nn.ELU(),
nn.Conv1d(32, 16, 1),
nn.ELU(),
nn.Conv1d(16, 16, 1),
nn.ELU(),
nn.Conv1d(16, 8, 1),
nn.ELU(),
nn.Conv1d(8, 3, 1)),
rnd_dim=rnd_dim, res=self.grid_size)
self.density_estimator = nn.Sequential(
nn.Conv3d(h_dim, 16, 1, bias=False),
nn.BatchNorm3d(16),
nn.ELU(),
nn.Conv3d(16, 8, 1, bias=False),
nn.BatchNorm3d(8),
nn.ELU(),
nn.Conv3d(8, 4, 1, bias=False),
nn.BatchNorm3d(4),
nn.ELU(),
nn.Conv3d(4, 2, 1),
)
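        # map the 1024-d latent code to one (scale, shift) pair per AdaIN layer of the decoder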
self.adaptive = nn.Sequential(
nn.Linear(1024, sum(self.decoder.slices))
)
def encode(self, x):
b = x.shape[0]
x = self.grid_encoder(x)
z = self.encoder(x).view(b, -1)
return z
def generate_points(self, w, n_points=5000, regular_sampling=True):
b = w.shape[0]
x_rec = self.decoder(w)
est = self.density_estimator(x_rec)
dens = fn.relu(est[:, 0])
dens_cls = est[:, 1].unsqueeze(1)
dens = dens.view(b, -1)
dens_s = dens.sum(-1).unsqueeze(1)
mask = dens_s < 1e-12
ones = torch.ones_like(dens_s)
dens_s[mask] = ones[mask]
dens = dens / dens_s
dens = dens.view(b, 1, self.grid_size, self.grid_size, self.grid_size)
if self.filled_cls:
filled = torch.sigmoid(dens_cls).round()
dens_ = filled * dens
for i in range(b):
if dens_[i].sum().item() < 1e-12:
dens_[i] = dens[i]
else:
dens_ = dens
if regular_sampling:
cloud, reg = self.generator.forward_fixed_pattern(x_rec, dens_, n_points)
else:
cloud, reg = self.generator(x_rec, dens_, n_points)
return cloud, dens, dens_cls.squeeze(), reg
def decode(self, z, n_points=5000, regular_sampling=True):
b = z.shape[0]
w = self.adaptive(z.view(b, -1))
return self.generate_points(w, n_points, regular_sampling)
def forward(self, x, n_points=5000, regular_sampling=True):
z = self.encode(x)
return self.decode(z, n_points, regular_sampling)
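# End-to-end sketch (illustrative; relies on the repository's util module that
# layer.py imports): encode a batch of clouds in [-0.5, 0.5]^3 and reconstruct.
if __name__ == '__main__':
    model = GridAutoEncoderAdaIN(rnd_dim=2, h_dim=62).eval()
    x = torch.rand(2, 3, 2048) - 0.5   # b x 3 x n inside the unit cube
    with torch.no_grad():
        cloud, dens, dens_cls, reg = model(x, n_points=1024)
    print(cloud.shape)                 # torch.Size([2, 3, 1024])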
import argparse
from collections import OrderedDict
import numpy as np
import torch
from tqdm import tqdm
import data
import models
from dists import chamfer
def undo_normalize(points, mean, scale):
res = points / scale.unsqueeze(1).unsqueeze(1)
res = res + mean.unsqueeze(2).expand_as(points)
return res
def normalize_unit_cube(points):
bb_max = points.max(-1)[0]
bb_min = points.min(-1)[0]
length = (bb_max - bb_min).max()
mean = (bb_max + bb_min) / 2.0
scale = 1.0 / length
res = (points - mean.unsqueeze(1)) * scale
return res.clamp(-0.5, 0.5), mean, scale
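# Round-trip sketch (illustrative, not part of the original file): per-model
# normalization followed by undo_normalize should recover the original clouds.
if __name__ == '__main__':
    clouds = torch.rand(4, 3, 1024) - 0.5
    normed, means, scales = zip(*[normalize_unit_cube(c) for c in clouds])
    restored = undo_normalize(torch.stack(normed), torch.stack(means), torch.stack(scales))
    print((restored - clouds).abs().max())  # ~0 (up to the clamp)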