adm18/IMPAX/nni/models.py
2025-09-16 13:20:19 +08:00

27 lines
697 B
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Fully-convolutional stack: entry conv, repeated hidden conv, exit conv.

    Every convolution uses a 3x3 kernel with ``padding=1`` (and the default
    stride of 1), so the spatial size of the input is preserved end to end.
    Only the channel dimension changes: ``in_channels -> hidden_size ->
    out_channels``.

    Args:
        hidden_layer: Number of times the hidden convolution ``conv_x`` is
            applied between the entry and exit convolutions (0 means it is
            skipped entirely).
        hidden_size: Number of channels produced by ``conv_a`` / ``conv_x``.
        in_channels: Channels expected in the input (default 1, matching the
            original hard-coded value).
        out_channels: Channels produced by the final convolution (default 2,
            matching the original hard-coded value). NOTE(review): what the
            output channels represent is not visible here, so confirm it
            against the training/inference code.
    """

    def __init__(self, hidden_layer, hidden_size, *, in_channels=1, out_channels=2):
        super().__init__()
        # Name keeps the original '%d' formatting so existing identifiers
        # (e.g. checkpoint or log names) are unchanged.
        self.name = 'model-%d-%d' % (hidden_layer, hidden_size)
        self.hidden_layer = hidden_layer
        self.conv_a = nn.Conv2d(in_channels, hidden_size, 3, padding=1)
        # NOTE: a single conv_x module is applied hidden_layer times in
        # forward(), so its weights are shared across all hidden iterations
        # and the parameter count does not grow with hidden_layer.
        self.conv_x = nn.Conv2d(hidden_size, hidden_size, 3, padding=1)
        self.conv_z = nn.Conv2d(hidden_size, out_channels, 3, padding=1)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Tensor accepted by ``nn.Conv2d`` with ``in_channels`` channels,
                i.e. shaped ``(N, C, H, W)`` or ``(C, H, W)``.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size
            as ``x``. No activation is applied after the last convolution.
        """
        x = F.relu(self.conv_a(x))
        for _ in range(self.hidden_layer):
            x = F.relu(self.conv_x(x))
        return self.conv_z(x)