Skip to content

Instantly share code, notes, and snippets.

@rayepeng
Created May 19, 2021 02:25
Show Gist options
  • Save rayepeng/8f1d3a5936c4563cab58c2a8d4d4c889 to your computer and use it in GitHub Desktop.
Save rayepeng/8f1d3a5936c4563cab58c2a8d4d4c889 to your computer and use it in GitHub Desktop.
残差网络图像分类
def get_img_name(img_dir, format="jpg"):
    """
    Return the names of the files in ``img_dir`` that have extension ``format``.

    :param img_dir: str, directory to scan (not recursive).
    :param format: str, extension to keep, with or without the leading dot
        (``"jpg"`` and ``".jpg"`` are equivalent). Matching is case-sensitive.
        An empty string keeps every entry of the directory.
    :return: list of bare file names (not full paths), sorted alphabetically.
    :raises ValueError: if no entry in ``img_dir`` matches.
    """
    # Anchor the match on the extension separator so that names such as
    # "photojpg" or "a.notjpg" are not mistaken for ".jpg" files.
    suffix = "." + format.lstrip(".") if format else ""
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    img_names = sorted(
        name for name in os.listdir(img_dir) if name.endswith(suffix)
    )
    if not img_names:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names
# Preprocessing pipeline applied to every image before it is fed to the model.
# norm_mean / norm_std are defined elsewhere in the file; presumably they are
# the same channel statistics used at training time — confirm.
inference_transform = transforms.Compose(
    [
        # An int size scales the shorter side to 256 px, keeping aspect ratio.
        transforms.Resize(256),
        # Keep the central 224x224 patch.
        transforms.CenterCrop(224),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel (x - mean) / std.
        transforms.Normalize(norm_mean, norm_std),
    ]
)
def img_transform(img_rgb, transform=None):
    """
    Apply ``transform`` to an image so it is in the form the model reads.

    :param img_rgb: PIL Image (or anything ``transform`` accepts)
    :param transform: callable such as a torchvision ``transforms.Compose``
    :return: whatever ``transform(img_rgb)`` returns (a tensor for the usual
        torchvision pipelines)
    :raises ValueError: if no ``transform`` is given
    """
    if transform is not None:
        return transform(img_rgb)
    raise ValueError("找不到transform!必须有transform对img进行处理")
import torchvision.models as models
# torchvision.models 中封装了很多模型
# 加载残差网络
def get_model(m_path, vis_model=False, num_classes=2):
    """
    Build a ResNet-18 classifier and load trained weights from a checkpoint.

    The torchvision ResNet-18 is created without pretrained weights. Its
    final fully-connected layer is replaced by a new ``nn.Linear`` with
    ``num_classes`` outputs. All parameters are then overwritten with the
    ``'model_state_dict'`` entry of the checkpoint at ``m_path``.

    :param m_path: str, path to a checkpoint saved with ``torch.save``. It must
        be a dict containing the key ``'model_state_dict'``.
    :param vis_model: bool, if True print a layer summary via ``torchsummary``
        for an input of size (3, 224, 224).
    :param num_classes: int, number of outputs of the new head. It must match
        the checkpoint (default 2, the previously hard-coded value).
    :return: the ResNet-18 module with the weights loaded, on the CPU. It is
        left in training mode (the ``nn.Module`` default); call ``.eval()``
        before inference.
    """
    resnet18 = models.resnet18()
    # Swap the default classification head for one sized to this task.
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, num_classes)
    # map_location="cpu": a checkpoint saved from GPU tensors would otherwise
    # need CUDA just to deserialize. The model is built on the CPU here, and
    # load_state_dict copies the values into its parameters, so the result is
    # the same either way.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(m_path, map_location="cpu")
    resnet18.load_state_dict(checkpoint["model_state_dict"])
    if vis_model:
        # Optional dependency, imported only when a summary is requested.
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")
    return resnet18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment