import torch.nn as nn
from torchvision import models
class MobileNetV2Classifier(nn.Module):
    """Binary classifier built on a torchvision MobileNetV2 backbone.

    The backbone's stock classifier is replaced with a small head
    (BatchNorm1d -> Dropout -> Linear -> ReLU -> BatchNorm1d -> Dropout ->
    Linear(., 1) -> Sigmoid), so the model emits one probability per sample.

    Args:
        train_base: When False (the default), the convolutional ``features``
            part of the backbone is frozen. Its parameters are excluded from
            gradient computation, and it is also kept in eval mode so its
            BatchNorm running statistics are not updated either (see
            :meth:`train`). When True, the whole network is trainable.
        weights: Weights forwarded to ``torchvision.models.mobilenet_v2``.
            Defaults to ``MobileNet_V2_Weights.DEFAULT`` (the previous
            hard-coded choice). Pass ``None`` for a randomly initialised
            backbone, e.g. to avoid a weight download in tests.
        hidden_units: Width of the hidden layer in the replacement head.
        dropout: Dropout probability, used twice in the replacement head.
    """

    def __init__(
        self,
        train_base=False,
        weights=models.MobileNet_V2_Weights.DEFAULT,
        hidden_units=128,
        dropout=0.5,
    ):
        super().__init__()
        # Recorded so train() can keep a frozen backbone in eval mode.
        self._freeze_base = not train_base

        self.base_model = models.mobilenet_v2(weights=weights)
        # Only ``features`` is (un)frozen. The head built below consists of
        # freshly created layers and is therefore always trainable.
        self.base_model.features.requires_grad_(train_base)

        # classifier[1] of the stock head is read for its ``in_features``
        # (a Linear layer): that is the pooled feature width of the backbone.
        in_features = self.base_model.classifier[1].in_features
        self.base_model.classifier = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Dropout(dropout),
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_units),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, 1),
            # Probability output, meant to pair with nn.BCELoss.
            # NOTE(review): nn.BCEWithLogitsLoss on raw logits is more
            # numerically stable, but dropping the Sigmoid would change
            # what forward() returns to existing callers.
            nn.Sigmoid(),
        )

        # Module.__init__ sets training mode without going through train(),
        # so apply the frozen-backbone policy explicitly right away.
        self.train(self.training)

    def train(self, mode=True):
        """Switch train/eval mode while keeping a frozen backbone in eval mode.

        ``requires_grad=False`` alone does not stop BatchNorm layers from
        updating their running mean/variance during training-mode forward
        passes. A "frozen" backbone would therefore still drift toward the new
        dataset's statistics. When the base is frozen, ``features`` is put
        back into eval mode after the normal recursive mode switch. Since
        ``eval()`` delegates to ``train(False)``, this also covers ``eval()``.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self, as ``nn.Module.train`` does.
        """
        super().train(mode)
        if self._freeze_base:
            self.base_model.features.eval()
        return self

    def forward(self, x):
        """Return per-sample probabilities of shape ``(N, 1)``.

        ``x`` is passed unchanged to the MobileNetV2 backbone. It is presumably
        an ImageNet-normalised image batch of shape ``(N, 3, H, W)``; confirm
        this against the data pipeline.
        """
        return self.base_model(x)