@woctezuma
Forked from m-klasen/detr_finetune.md
Created August 19, 2020 15:00

Get pretrained weights:

```shell
wget https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth
```

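Remove the class-head weights, since their shape depends on the number of classes (the pretrained head was trained on COCO):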

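```python
import torch

# Load the COCO-pretrained checkpoint on the CPU
checkpoint = torch.load("detr-r50-e632da11.pth", map_location="cpu")
# Drop the classification head, whose shape depends on the number of classes
del checkpoint["model"]["class_embed.weight"]
del checkpoint["model"]["class_embed.bias"]
torch.save(checkpoint, "detr-r50_no-class-head.pth")
```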
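If you fine-tune from several DETR checkpoints, the same step can be wrapped in a small helper. This is a minimal sketch, assuming the checkpoint is a dict with a "model" entry as above; the function `strip_class_head` and its `prefix` parameter are hypothetical, not part of DETR. It only manipulates dictionaries, so the demo uses placeholder values instead of tensors:

```python
def strip_class_head(checkpoint, prefix="class_embed."):
    """Return (stripped_checkpoint, removed_keys) where the "model" state dict
    no longer has keys starting with `prefix`. Hypothetical helper."""
    model_state = checkpoint["model"]
    kept = {k: v for k, v in model_state.items() if not k.startswith(prefix)}
    removed = sorted(k for k in model_state if k.startswith(prefix))
    # Shallow-copy the outer dict so other entries survive and the input is untouched
    stripped = dict(checkpoint)
    stripped["model"] = kept
    return stripped, removed


if __name__ == "__main__":
    # Toy checkpoint: plain integers stand in for tensors
    toy = {"model": {"class_embed.weight": 0, "class_embed.bias": 1,
                     "bbox_embed.layers.0.weight": 2}}
    stripped, removed = strip_class_head(toy)
    print(removed)                    # ['class_embed.bias', 'class_embed.weight']
    print(sorted(stripped["model"]))  # ['bbox_embed.layers.0.weight']
```

With real weights you would pass the result of `torch.load("detr-r50-e632da11.pth", map_location="cpu")` and save the first returned value with `torch.save`.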

Then make sure to enable non-strict weight loading in main.py, so that the missing class-head weights do not raise an error:

```python
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
```
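Because strict=False silently ignores every mismatched key, not only the class head, it can be worth checking what was actually skipped. In PyTorch, `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`. The sketch below uses a hypothetical helper, `load_checked`, that raises if anything other than the class head is missing; `FakeModel` is a demo-only stand-in for the DETR model so the example runs without torch:

```python
from collections import namedtuple


def load_checked(model, state_dict, allowed_missing=("class_embed.",)):
    """Load `state_dict` with strict=False, then raise if anything other than
    keys starting with one of `allowed_missing` was skipped.
    Hypothetical helper, not part of DETR."""
    result = model.load_state_dict(state_dict, strict=False)
    unexpected = list(result.unexpected_keys)
    bad_missing = [k for k in result.missing_keys
                   if not k.startswith(tuple(allowed_missing))]
    if unexpected or bad_missing:
        raise RuntimeError(
            f"unexpected keys: {unexpected}; missing keys: {bad_missing}")
    # Only allowed keys were missing; they keep their freshly initialised values
    return list(result.missing_keys)


# Demo-only stand-in for an nn.Module: it returns a named tuple with the same
# two fields that torch.nn.Module.load_state_dict reports.
_LoadResult = namedtuple("_LoadResult", ["missing_keys", "unexpected_keys"])


class FakeModel:
    def __init__(self, keys):
        self.keys = set(keys)

    def load_state_dict(self, state_dict, strict=True):
        missing = sorted(self.keys - set(state_dict))
        unexpected = sorted(set(state_dict) - self.keys)
        return _LoadResult(missing, unexpected)


if __name__ == "__main__":
    model = FakeModel(["input_proj.weight", "class_embed.weight", "class_embed.bias"])
    print(load_checked(model, {"input_proj.weight": 0}))
    # ['class_embed.bias', 'class_embed.weight']
```

In main.py the call would become `load_checked(model_without_ddp, checkpoint['model'])`; when resuming from your own fine-tuned checkpoint nothing should be missing and it returns an empty list.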

Your dataset should ideally be in the COCO format. Either write your own dataset builder, or rename your train/valid/annotation files to match the COCO dataset layout. For a custom builder, add the following to datasets/coco.py:

```python
# Add to datasets/coco.py: Path, CocoDetection and make_coco_transforms
# are already imported or defined in that module.
def build_your_dataset(image_set, args):
    root = Path(args.coco_path)
    assert root.exists(), f'provided COCO path {root} does not exist'
    # Expected layout: <root>/train, <root>/valid, <root>/annotations/{train,valid}.json
    PATHS = {
        "train": (root / "train", root / "annotations" / "train.json"),
        "val": (root / "valid", root / "annotations" / "valid.json"),
    }

    img_folder, ann_file = PATHS[image_set]
    dataset = CocoDetection(img_folder, ann_file,
                            transforms=make_coco_transforms(image_set),
                            return_masks=args.masks)
    return dataset
```
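The builder above expects this layout under --coco_path: image folders root/train and root/valid, plus root/annotations/train.json and root/annotations/valid.json. As a sketch, the hypothetical helper below writes that skeleton with a minimal COCO-style annotation file (top-level images, annotations and categories lists; boxes as [x, y, width, height] in pixels), which can help check the plumbing before using real data. The file name, image size, box and the single "widget" category are placeholders:

```python
import json
from pathlib import Path


def minimal_coco(file_name="img_0001.jpg", width=640, height=480):
    """One image with one box, in COCO detection format (placeholder values)."""
    x, y, w, h = 10.0, 20.0, 100.0, 50.0  # COCO boxes are [x, y, width, height]
    return {
        "images": [{"id": 1, "file_name": file_name,
                    "width": width, "height": height}],
        "annotations": [{"id": 1, "image_id": 1, "category_id": 1,
                         "bbox": [x, y, w, h], "area": w * h, "iscrowd": 0}],
        "categories": [{"id": 1, "name": "widget"}],
    }


def make_skeleton(root):
    """Create the folders and annotation files that build_your_dataset reads.
    Hypothetical helper; image files still have to be added separately."""
    root = Path(root)
    (root / "annotations").mkdir(parents=True, exist_ok=True)
    for split in ("train", "valid"):
        (root / split).mkdir(exist_ok=True)
        ann_file = root / "annotations" / f"{split}.json"
        ann_file.write_text(json.dumps(minimal_coco()), encoding="utf-8")
    # Relative paths, using forward slashes on every OS
    return sorted(p.relative_to(root).as_posix() for p in root.rglob("*"))
```

Calling `make_skeleton("data")` would produce a tree matching `--coco_path data` in the training command.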

In datasets/__init__.py, import your builder and add it as an option:

```python
from .coco import build as build_coco
from .coco import build_your_dataset


def build_dataset(image_set, args):
    if args.dataset_file == 'coco':
        return build_coco(image_set, args)
    if args.dataset_file == 'your_dataset':
        return build_your_dataset(image_set, args)
    [...]
```

Lastly, define how many classes you have in models/detr.py:

```python
def build(args):
    [...]
    if args.dataset_file == 'your_dataset':
        num_classes = 4
    [...]
```
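Be careful with what num_classes means: a comment in upstream models/detr.py explains that it corresponds to max_obj_id + 1, i.e. the largest category id in your dataset plus one (COCO's largest id is 90, hence 91), rather than simply the number of categories. The value 4 above therefore fits a dataset whose largest category id is 3. A minimal sketch for deriving the value from a COCO annotation file; the function name is hypothetical:

```python
import json


def num_classes_from_coco(ann_file):
    """Return max category id + 1, the value DETR's build() expects as
    num_classes. Hypothetical helper, not part of DETR."""
    with open(ann_file, encoding="utf-8") as f:
        categories = json.load(f)["categories"]
    return max(cat["id"] for cat in categories) + 1
```

For the layout used here, `num_classes_from_coco("data/annotations/train.json")` gives the value to hard-code in build().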

Run training, for example:

```shell
python main.py --dataset_file your_dataset --coco_path data --epochs 50 --lr=1e-4 --batch_size=2 --num_workers=4 --output_dir="outputs" --resume="detr-r50_no-class-head.pth"
```

@parthkvv

parthkvv commented Oct 17, 2022

Hi, thanks a lot for this. I followed these steps and was able to begin training. However, I am facing a problem with loading the trained model for inference.

Following this notebook, I wrote this code to load my saved model and prepare it for inference on an image (most of my code is based on that notebook):

```python
detr = DETRdemo(num_classes=13)
state_dict = torch.load("detr-main\\output\\checkpoint.pth", map_location='cpu')
detr.load_state_dict(state_dict['model'], strict=False)
detr.eval();
```

Using the pretrained demo model, loaded as in the notebook, I successfully obtained both bounding boxes and class predictions on a sample image:

```python
detr = DETRdemo(num_classes=91)
state_dict = torch.hub.load_state_dict_from_url(
    url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
    map_location='cpu', check_hash=True)
detr.load_state_dict(state_dict)
detr.eval();
```

However, I could not do the same with my trained checkpoint.
Moreover, setting strict=True in the former code raises the following error:

```
RuntimeError: Error(s) in loading state_dict for DETRdemo:
Missing key(s) in state_dict:
```

I have tried other methods from here and here, but none of them worked.

Any advice would be helpful.
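One way to narrow this down is to compare the parameter names the model expects with those stored in the checkpoint. The checkpoint written by main.py holds the full DETR model (keys such as class_embed.* and bbox_embed.*), whereas DETRdemo from the demo notebook is a simplified re-implementation with its own module names, so strict=False likely ends up loading almost nothing. Loading into a model built the same way as during training (for example with build_model from the DETR repo's models package) would avoid the mismatch. A minimal sketch of the comparison; `diff_keys` is a hypothetical helper and the key names in the test are illustrative:

```python
def diff_keys(model_keys, checkpoint_keys):
    """Compare a model's parameter names with those in a checkpoint.

    Returns (missing, unexpected, matched): names the model needs but the
    checkpoint lacks, names only in the checkpoint, and the overlap.
    Accepts any iterables of names, e.g. state dicts (iterating gives keys).
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    matched = sorted(model_keys & checkpoint_keys)
    return missing, unexpected, matched


# With real objects (requires torch and the notebook's DETRdemo):
#   ckpt = torch.load("detr-main/output/checkpoint.pth", map_location="cpu")
#   missing, unexpected, matched = diff_keys(detr.state_dict(), ckpt["model"])
#   print(len(matched), "of", len(detr.state_dict()), "parameters would load")
```

If `matched` is empty or tiny, the two models simply do not share parameter names, and no loading flag will fix that.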
