27 lines
697 B
Python
27 lines
697 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class Net(nn.Module):
    """Fully convolutional network: input conv, a repeated hidden conv, output conv.

    Every convolution uses a 3x3 kernel with ``padding=1`` (default stride 1),
    so the spatial size of the input is preserved through the network.

    Args:
        hidden_layer: Number of times the hidden convolution is applied in
            ``forward``. ``0`` skips the hidden stage entirely.
        hidden_size: Number of channels produced by the input conv and
            consumed/produced by the hidden conv.
        in_channels: Channels of the input tensor. Defaults to ``1``, the
            value that was previously hard-coded.
        out_channels: Channels of the output tensor. Defaults to ``2``, the
            value that was previously hard-coded.

    Attributes:
        name: Identifier string ``'model-<hidden_layer>-<hidden_size>'``.
        hidden_layer: Stored copy of the ``hidden_layer`` argument.
    """

    def __init__(self, hidden_layer, hidden_size, *, in_channels=1, out_channels=2):
        super().__init__()
        # Identifier built only from the depth/width hyper-parameters; the
        # format is kept exactly as before so existing names do not change.
        self.name = 'model-%d-%d' % (hidden_layer, hidden_size)
        self.hidden_layer = hidden_layer

        self.conv_a = nn.Conv2d(in_channels, hidden_size, 3, padding=1)
        # A single module: every hidden iteration in forward() reuses these
        # same weights. NOTE(review): confirm weight sharing is intentional;
        # switching to separate per-layer convs would change state_dict keys.
        self.conv_x = nn.Conv2d(hidden_size, hidden_size, 3, padding=1)
        self.conv_z = nn.Conv2d(hidden_size, out_channels, 3, padding=1)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Tensor accepted by ``nn.Conv2d`` with ``in_channels`` channels,
                e.g. of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size as
            ``x``. No activation is applied after the final convolution.
        """
        x = F.relu(self.conv_a(x))
        for _ in range(self.hidden_layer):
            x = F.relu(self.conv_x(x))
        return self.conv_z(x)
|