You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
class GhostNet(nn.Module):
    """Feature-extraction wrapper around the model returned by ``ghostnet()``.

    The constructor builds the model, optionally loads weights from
    ``model_data/ghostnet_weights.pth``, and then deletes the attributes
    ``global_pool``, ``conv_head``, ``act2``, ``classifier`` and ``blocks[9]``
    (presumably the classification head -- confirm against ``ghostnet``).
    ``forward`` returns the outputs of blocks 4, 6 and 8.
    """

    def __init__(self, pretrained=True):
        """Build the trimmed model.

        Args:
            pretrained: When true, load the state dict stored at
                ``model_data/ghostnet_weights.pth`` before trimming.
        """
        # The constructor must be the dunder ``__init__``; a plain ``init``
        # method is never invoked on instantiation, and ``nn.Module`` has no
        # ``init`` for ``super()`` to call.
        super().__init__()
        model = ghostnet()
        if pretrained:
            # map_location="cpu" lets a checkpoint that was saved from a GPU
            # load on a CPU-only host; load_state_dict copies the values into
            # the model's own parameters afterwards.
            state_dict = torch.load(
                "model_data/ghostnet_weights.pth", map_location="cpu"
            )
            # Weights are loaded BEFORE the head is removed, so the checkpoint
            # keys must match the full, untrimmed ``ghostnet()`` model.
            model.load_state_dict(state_dict)
        del model.global_pool
        del model.conv_head
        del model.act2
        del model.classifier
        del model.blocks[9]
        self.model = model

    def forward(self, x):
        """Run the stem and all remaining blocks; return three feature maps.

        Args:
            x: Input passed to ``self.model.conv_stem``.

        Returns:
            A list with the outputs of blocks 4, 6 and 8, in that order.
        """
        x = self.model.conv_stem(x)
        x = self.model.bn1(x)
        x = self.model.act1(x)
        # The original collected the outputs of blocks 2, 4, 6, 8 and then
        # dropped the first one; collecting 4, 6, 8 directly is equivalent.
        wanted = (4, 6, 8)
        feature_maps = []
        for idx, block in enumerate(self.model.blocks):
            x = block(x)
            if idx in wanted:
                feature_maps.append(x)
        return feature_maps
class GhostNet(nn.Module):
    """Wrapper that holds a trimmed ``ghostnet()`` model in ``self.model``.

    After optional weight loading, the attributes ``global_pool``,
    ``conv_head``, ``act2``, ``classifier`` and the entry ``blocks[9]`` are
    deleted from the model (presumably the classification head -- confirm
    against the ``ghostnet`` definition).
    """

    def __init__(self, pretrained=True):
        """Build the trimmed model.

        Args:
            pretrained: When true, load the state dict stored at
                ``model_data/ghostnet_weights.pth`` before trimming.
        """
        # Must be the dunder ``__init__``: a method named ``init`` is never
        # called by ``GhostNet(...)`` and ``nn.Module`` defines no ``init``.
        super().__init__()
        model = ghostnet()
        if pretrained:
            # Load onto the CPU so a GPU-saved checkpoint also works on a
            # machine without CUDA.
            state_dict = torch.load(
                "model_data/ghostnet_weights.pth", map_location="cpu"
            )
            # The checkpoint is applied to the FULL model, before any layer is
            # deleted, so its keys and shapes must match ``ghostnet()`` as
            # constructed here.
            model.load_state_dict(state_dict)
        del model.global_pool
        del model.conv_head
        del model.act2
        del model.classifier
        del model.blocks[9]
        self.model = model
代码需要加载 model_data/ghostnet_weights.pth，但加载时发现权重与模型不匹配。请问我的 GhostNet 分类模型的特征提取网络是否需要重新训练？
The text was updated successfully, but these errors were encountered: