utils.py 3.33 KB
Newer Older
Alexander Dielen's avatar
Alexander Dielen committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
import openmesh as om
import numpy as np

import os
import pickle

import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import faust_data


def _next_ring(mesh, last_ring, other):
    """Return the ring of vertices surrounding ``last_ring``, in spiral order.

    Each vertex index in ``last_ring`` is visited in order. Its neighbours
    (as yielded by ``mesh.vv``) are scanned cyclically, starting just after
    the first neighbour that belongs to ``last_ring``. If no neighbour
    belongs to ``last_ring``, the scan starts at the first neighbour. Every
    scanned index that is not in ``last_ring``, not in ``other`` and not
    already collected is appended to the result.

    Args:
        mesh: mesh whose ``vv`` circulator is queried (``extract_spirals``
            passes the result of ``om.read_trimesh``).
        last_ring: vertex indices of the current, outermost ring.
        other: vertex indices that must never be returned (the caller
            passes the spiral built so far).

    Returns:
        List of new vertex indices, without duplicates, in discovery order.
    """
    # Sets are built once so every membership test below is O(1) instead of
    # a linear scan over the (growing) spiral and result lists.
    ring_members = set(last_ring)
    excluded = ring_members.union(other)
    res = []
    collected = set()

    for vh1 in last_ring:
        # Materialize the circulator once; it was previously walked twice.
        neighbours = [vh2.idx() for vh2 in mesh.vv(om.VertexHandle(vh1))]

        # Index of the first neighbour lying on last_ring, if any.
        first = next((i for i, idx in enumerate(neighbours)
                      if idx in ring_members), None)
        if first is None:
            ordered = neighbours
        else:
            # Vertices after the first last_ring neighbour, then those before
            # it: a cyclic walk that keeps the spiral's orientation.
            ordered = neighbours[first + 1:] + neighbours[:first]

        for idx in ordered:
            if idx not in excluded and idx not in collected:
                collected.add(idx)
                res.append(idx)
    return res


def extract_spirals(filename, seq_length):
    """Compute spiral vertex sequences for every vertex of a triangle mesh.

    For each vertex ``v`` a spiral starts at ``v`` itself, continues with its
    one-ring, and then with the successive rings produced by ``_next_ring``
    until at least ``seq_length`` indices are available. One spiral is built
    for each cyclic rotation of the one-ring, i.e. once per possible starting
    neighbour.

    Args:
        filename: path of the mesh file, loaded with ``om.read_trimesh``.
        seq_length: number of vertex indices kept in each spiral.

    Returns:
        A list with one entry per vertex, in ``mesh.vertices()`` order. Each
        entry is a list of spirals, one per rotation of that vertex's
        one-ring. Each spiral is truncated to ``seq_length`` indices. It is
        shorter only when no further new vertices are reachable.
    """
    mesh = om.read_trimesh(filename)
    spirals = []
    for vh0 in mesh.vertices():
        reference_one_ring = [vh1.idx() for vh1 in mesh.vv(vh0)]
        rotated_spirals = []
        for shift in range(len(reference_one_ring)):
            spiral = [vh0.idx()]
            # Rotate the one-ring so this spiral starts at a different
            # neighbour of vh0.
            last_ring = list(np.roll(reference_one_ring, -shift))
            next_ring = _next_ring(mesh, last_ring, spiral)
            spiral.extend(last_ring)
            # Stop once the next ring is empty: _next_ring of an empty ring is
            # always empty, so continuing would loop forever on small or
            # disconnected components.
            while next_ring and len(spiral) + len(next_ring) < seq_length:
                last_ring = next_ring
                next_ring = _next_ring(mesh, last_ring, spiral)
                spiral.extend(last_ring)
            spiral.extend(next_ring)
            rotated_spirals.append(spiral[:seq_length])
        spirals.append(rotated_spirals)
    return spirals


def extract_and_save(idx, seq_length=None, input_dir=None, output_dir=None):
    """Compute the spirals of mesh ``tr_reg_<idx>`` and pickle them.

    Reads ``tr_reg_XXX.ply`` from ``input_dir`` and writes ``tr_reg_XXX.pkl``
    to ``output_dir``, where ``XXX`` is ``idx`` zero-padded to three digits.
    Nothing is done if the pickle already exists, so a batch run can be
    resumed.

    Args:
        idx: integer mesh index used in the file names.
        seq_length: spiral length forwarded to ``extract_spirals``.
        input_dir: directory containing the ``.ply`` files.
        output_dir: directory receiving the ``.pkl`` files.
    """
    basename = 'tr_reg_{:03}'.format(idx)
    ply_filename = os.path.join(input_dir, basename + '.ply')
    pkl_filename = os.path.join(output_dir, basename + '.pkl')
    if os.path.isfile(pkl_filename):
        return
    print('Computing spirals for {}.ply'.format(basename))
    spirals = extract_spirals(ply_filename, seq_length)
    # Write to a temporary file first and rename it afterwards. An interrupted
    # run then never leaves a truncated pickle that the existence check
    # above would accept as finished.
    tmp_filename = pkl_filename + '.tmp'
    with open(tmp_filename, 'wb') as f:
        pickle.dump(spirals, f)
    os.rename(tmp_filename, pkl_filename)


def save_benchmark(filename, errors):
    """Plot a cumulative correspondence-accuracy curve and save it to a file.

    For every radius r in ``np.arange(0.0, 0.2, 0.001)`` the curve shows the
    fraction of ``errors`` that are <= r. The curve is drawn together with
    the reference curves stored as ``.npy`` files under
    ``faust_data.data_dir + 'references/'``.

    Args:
        filename: output path passed to ``fig.savefig``.
        errors: 1-D sequence of per-correspondence errors (presumably
            geodesic distances, per the x-axis label; TODO confirm).

    Raises:
        ValueError: if ``errors`` is empty.
    """
    x = np.arange(0.0, 0.2, 0.001)

    errors = np.sort(errors)
    n = len(errors)
    if n == 0:
        raise ValueError('errors must not be empty')

    # Count of errors <= each radius. This replaces a manual sweep that read
    # errors[m] before checking m < n and so raised IndexError once every
    # error was below a sampled radius.
    y = np.searchsorted(errors, x, side='right') / float(n)

    references = [
        ('monet_raw.npy', 'tab:green', 'MoNet (raw)'),
        ('gcnn_symmetric.npy', 'tab:orange', 'GCNN (symmetric)'),
        ('gcnn_asymmetric.npy', 'tab:blue', 'GCNN (asymmetric)'),
    ]

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y, 'tab:red', linewidth=2.2, label='our method')
        for fname, color, label in references:
            arr = np.load(faust_data.data_dir + 'references/' + fname)
            ax.plot(arr[0], arr[1], color, linewidth=2.2, label=label)
        ax.legend(loc='lower right')
        ax.set(xlabel='geodesic radius', ylabel='% correct correspondences')
        ax.grid()
        ax.axis([0, 0.2, 0, 1])
        ax.set_xticks(np.arange(0.0, 0.25, 0.05))
        ax.set_yticks(np.arange(0.0, 1.20, 0.20))
        fig.tight_layout()

        fig.savefig(filename)
    finally:
        # Release the figure from pyplot even if loading a reference fails.
        plt.close(fig)