Created
May 19, 2021 02:25
-
-
Save rayepeng/8f1d3a5936c4563cab58c2a8d4d4c889 to your computer and use it in GitHub Desktop.
残差网络图像分类
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_img_name(img_dir, format="jpg"):
    """Return the names of files in *img_dir* that have the extension *format*.

    Matching is done on the real file extension (``"<name>.<format>"``) and is
    case-insensitive. So ``"a.JPG"`` matches ``format="jpg"``, but a name such
    as ``"a_jpg"`` (no dot) does not. A leading dot in *format* is accepted:
    ``".jpg"`` behaves like ``"jpg"``.

    :param img_dir: str, directory to scan (not recursive)
    :param format: str, file extension, e.g. ``"jpg"``
    :return: list of str, bare file names (not full paths), sorted so the
        order does not depend on the filesystem's listing order
    :raises ValueError: if no file in *img_dir* has the requested extension
    """
    # The parameter name shadows the builtin ``format``. It is kept as-is so
    # existing keyword callers (``format="png"``) keep working.
    suffix = "." + format.lstrip(".").lower()
    img_names = sorted(
        name for name in os.listdir(img_dir) if name.lower().endswith(suffix)
    )
    if not img_names:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names
# Deterministic preprocessing applied to each image before inference.
# NOTE(review): norm_mean / norm_std are defined elsewhere in the file;
# presumably the per-channel statistics used during training -- confirm.
inference_transform = transforms.Compose(
    [
        transforms.Resize(256),  # int size: shorter side -> 256 px, aspect kept
        transforms.CenterCrop(224),  # take the central 224x224 patch
        transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
        transforms.Normalize(norm_mean, norm_std),  # per-channel standardize
    ]
)
def img_transform(img_rgb, transform=None):
    """Convert an image into the form the model consumes.

    :param img_rgb: PIL Image to preprocess
    :param transform: callable (e.g. a torchvision transform) applied to the image
    :return: whatever *transform* returns (a tensor for torchvision pipelines)
    :raises ValueError: if no *transform* is supplied
    """
    if transform is not None:
        return transform(img_rgb)
    # A transform is mandatory: the raw image cannot be fed to the model as-is.
    raise ValueError("找不到transform!必须有transform对img进行处理")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision.models as models | |
# torchvision.models 中封装了很多模型 | |
# 加载残差网络 | |
def get_model(m_path, vis_model=False, num_classes=2):
    """Build a ResNet-18 classifier and load trained weights from *m_path*.

    The default fully-connected head is replaced by a new ``nn.Linear`` with
    *num_classes* outputs, so the architecture matches the checkpoint before
    its weights are loaded.

    :param m_path: str, path to a checkpoint file readable by ``torch.load``.
        It must contain the key ``'model_state_dict'``.
    :param vis_model: bool, if True print a layer summary via ``torchsummary``
        for a (3, 224, 224) input
    :param num_classes: int, output size of the final linear layer. It must
        match the checkpoint. The default of 2 is the previously hard-coded value.
    :return: the ResNet-18 module with the loaded weights. It is not moved to
        any device and is not switched to eval mode here. Call ``.eval()``
        before inference if the caller does not already do so.
    """
    resnet18 = models.resnet18()
    # Swap the classification head: keep its input width and change the
    # number of outputs to match the trained checkpoint.
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, num_classes)

    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine. The weights end up in the (CPU-constructed) model either way.
    # NOTE(review): torch.load unpickles arbitrary objects. Only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(m_path, map_location="cpu")
    resnet18.load_state_dict(checkpoint['model_state_dict'])

    if vis_model:
        # Optional dependency, imported only when a summary is requested.
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")
    return resnet18
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment