mhyatt000 commited on
Commit
0f67d39
1 Parent(s): ee8c70c
Files changed (2) hide show
  1. badnet.py +39 -0
  2. temp +7 -0
badnet.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+
4
+ class BadNet(nn.Module):
5
+
6
+ # def __init__(self, input_channels, output_num):
7
+ def __init__(self, 3072,10):
8
+ super().__init__()
9
+
10
+ self.conv1 = nn.Sequential(
11
+ nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
12
+ nn.ReLU(),
13
+ nn.AvgPool2d(kernel_size=2, stride=2)
14
+ )
15
+
16
+ self.conv2 = nn.Sequential(
17
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
18
+ nn.ReLU(),
19
+ nn.AvgPool2d(kernel_size=2, stride=2)
20
+ )
21
+ fc1_input_features = 800 if input_channels == 3 else 512
22
+ self.fc1 = nn.Sequential(
23
+ nn.Linear(in_features=fc1_input_features, out_features=512),
24
+ nn.ReLU()
25
+ )
26
+ self.fc2 = nn.Sequential(
27
+ nn.Linear(in_features=512, out_features=output_num),
28
+ nn.Softmax(dim=-1)
29
+ )
30
+ self.dropout = nn.Dropout(p=.5)
31
+
32
+ def forward(self, x):
33
+ x = self.conv1(x)
34
+ x = self.conv2(x)
35
+
36
+ x = x.view(x.size(0), -1)
37
+ x = self.fc1(x)
38
+ x = self.fc2(x)
39
+ return x
temp ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ model = load_model(
2
+ basic_model_path,
3
+ model_type="badnet",
4
+ input_channels=train_data_loader.dataset.channels,
5
+ output_num=train_data_loader.dataset.class_num,
6
+ device=device,
7
+ )