Commit 7a435e5d authored by Alexander Dielen

fixed initialization for the linear model

parent 8cac1989
......@@ -42,7 +42,7 @@ Precomputed distances are optional and only required for evaluation purposes
python train.py --mode lstm
python train.py --mode linear
python train.py --mode linear --epochs 4000
Model checkpoints and predictions on the validation and test sets are saved
to `out/`.
......
......@@ -20,6 +20,8 @@ class SpiralLayer(torch.nn.Module):
super(SpiralLayer, self).__init__()
if linear:
self.layer = nn.Linear(in_features * seq_length, out_features)
torch.nn.init.xavier_uniform_(self.layer.weight, gain=1)
torch.nn.init.constant_(self.layer.bias, 0)
else:
self.layer = nn.LSTM(in_features, out_features, batch_first=True)
self.linear = linear
......@@ -48,6 +50,14 @@ class SpiralNet(torch.nn.Module):
self.fc2 = nn.Linear(outputs[2], 256)
self.fc3 = nn.Linear(256, n_classes)
if linear:
torch.nn.init.xavier_uniform_(self.fc1.weight, gain=1)
torch.nn.init.xavier_uniform_(self.fc2.weight, gain=1)
torch.nn.init.xavier_uniform_(self.fc3.weight, gain=1)
torch.nn.init.constant_(self.fc1.bias, 0)
torch.nn.init.constant_(self.fc2.bias, 0)
torch.nn.init.constant_(self.fc3.bias, 0)
def forward(self, x, indices):
x = F.dropout(x, p=0.3, training=self.training)
x = F.relu(self.fc1(x))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment