
B导 (Bubbliiiing), I modified the GhostNet network model and have already trained it. Now, when running get_map.py, I run into the following problem #45

Open
WDQGO opened this issue Nov 9, 2022 · 1 comment

Comments

@WDQGO

WDQGO commented Nov 9, 2022

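import torch
import torch.nn as nn

from nets.ghostnet import ghostnet  # assumption: ghostnet() builder lives in the repo's nets/ghostnet.py


class GhostNet(nn.Module):
    def __init__(self, pretrained=True):
        super(GhostNet, self).__init__()
        model = ghostnet()
        if pretrained:
            # ImageNet classification weights; this is the load that fails
            # with a mismatch once the GhostNet architecture has been modified.
            state_dict = torch.load("model_data/ghostnet_weights.pth")
            model.load_state_dict(state_dict)
        # Drop the classification head and the final block (index 9);
        # only the feature extractor is kept.
        del model.global_pool
        del model.conv_head
        del model.act2
        del model.classifier
        del model.blocks[9]
        self.model = model

    def forward(self, x):
        # Stem: conv -> BN -> activation
        x = self.model.conv_stem(x)
        x = self.model.bn1(x)
        x = self.model.act1(x)
        feature_maps = []

        # Collect the outputs of blocks 2, 4, 6 and 8 as multi-scale features
        for idx, block in enumerate(self.model.blocks):
            x = block(x)
            if idx in [2, 4, 6, 8]:
                feature_maps.append(x)
        # Return only the last three scales (outputs of blocks 4, 6 and 8)
        return feature_maps[1:]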

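It needs to load model_data/ghostnet_weights.pth, and the weights turn out not to match the model. Does the feature-extraction network (backbone) of my GhostNet classification model need to be retrained?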
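If the aim is to reuse whatever still fits from model_data/ghostnet_weights.pth after changing the architecture, a common PyTorch pattern is to copy only the tensors whose key and shape both match the modified model and leave everything else at its random initialisation. Note that `load_state_dict(..., strict=False)` on its own is not enough: it tolerates missing or unexpected keys but still raises on shape mismatches. This is a generic sketch, not code from this repository; the helper name `load_matching_weights` is hypothetical.

```python
from typing import List, Tuple

import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, pretrained_state: dict) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: copy only pretrained tensors whose key AND shape match.

    Returns (loaded_keys, skipped_keys) so the caller can see which
    pretrained tensors were not used.
    """
    model_state = model.state_dict()
    loaded, skipped = [], []
    for key, value in pretrained_state.items():
        if key in model_state and model_state[key].shape == value.shape:
            model_state[key] = value
            loaded.append(key)
        else:
            # Key absent from the modified model, or its shape changed
            skipped.append(key)
    # model_state still contains every key the model expects (untouched ones
    # keep their current values), so a strict load succeeds.
    model.load_state_dict(model_state)
    return loaded, skipped
```

With the issue's code, this would replace the plain `model.load_state_dict(state_dict)` call, e.g. `load_matching_weights(model, torch.load("model_data/ghostnet_weights.pth", map_location="cpu"))`. Parameters whose pretrained counterparts were skipped start from scratch, so the more of the architecture was changed, the less there is to gain. This only matters for initialisation before training: a checkpoint saved after training the full model typically already contains the backbone's weights.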

@bubbliiiing
Owner

Why would you need to load it? I don't understand.
