diff --git a/README.md b/README.md index 036d79f1ae1670f5727032ee25823462ff7997c2..43b82096f376abc4648d96a3cf72c1e6c7e7fe4b 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Precomputed distances are optional and only required for evaluation purposes python train.py --mode lstm - python train.py --mode linear + python train.py --mode linear --epochs 4000 Model checkpoints and predictions on the validation and test sets are saved to `out/`. diff --git a/train.py b/train.py index 0eca2a6d3f87b54721f20ee1d225643e6053672f..341c3550eb28d934664e7ebe09ca8cc1d5b3b6ee 100644 --- a/train.py +++ b/train.py @@ -20,6 +20,8 @@ class SpiralLayer(torch.nn.Module): super(SpiralLayer, self).__init__() if linear: self.layer = nn.Linear(in_features * seq_length, out_features) + torch.nn.init.xavier_uniform_(self.layer.weight, gain=1) + torch.nn.init.constant_(self.layer.bias, 0) else: self.layer = nn.LSTM(in_features, out_features, batch_first=True) self.linear = linear @@ -48,6 +50,14 @@ class SpiralNet(torch.nn.Module): self.fc2 = nn.Linear(outputs[2], 256) self.fc3 = nn.Linear(256, n_classes) + if linear: + torch.nn.init.xavier_uniform_(self.fc1.weight, gain=1) + torch.nn.init.xavier_uniform_(self.fc2.weight, gain=1) + torch.nn.init.xavier_uniform_(self.fc3.weight, gain=1) + torch.nn.init.constant_(self.fc1.bias, 0) + torch.nn.init.constant_(self.fc2.bias, 0) + torch.nn.init.constant_(self.fc3.bias, 0) + def forward(self, x, indices): x = F.dropout(x, p=0.3, training=self.training) x = F.relu(self.fc1(x))