
@THEFASHIONGEEK
Created February 20, 2020 07:24
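import sys
sys.path.append("Monk_Object_Detection/5_pytorch_retinanet/lib/")  # assumed location of the Monk RetinaNet library
from train_detector import Detector  # assumed module name; adjust to your checkout
root_dir, coco_dir, img_dir, set_dir = "Dataset", "coco", "images", "train"  # hypothetical placeholder paths
gtf = Detector()
gtf.Train_Dataset(root_dir, coco_dir, img_dir, set_dir, batch_size=16, use_gpu=True)  # training data loader
gtf.Model(model_name="resnet50", gpu_devices=[0, 1, 2, 3])  # ResNet-50 backbone, spread over 4 GPUs
gtf.Set_Hyperparams(lr=0.0001, val_interval=1, print_interval=20)  # validate every epoch, log every 20 iterations
gtf.Train(num_epochs=10, output_model_name="final_model.pt")  # train 10 epochs, save weights to final_model.pt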
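`Train_Dataset` takes the dataset location as four separate path pieces. Before launching a multi-GPU run it can help to confirm those pieces resolve to real files. The sketch below assumes a COCO-style layout of `<root_dir>/<coco_dir>/<img_dir>/<set_dir>/` for images and `<root_dir>/<coco_dir>/annotations/instances_<set_dir>.json` for labels; that layout and the helper names `coco_paths` and `check_layout` are assumptions for illustration, not part of the library, so check them against the documentation for your version.

```python
import os


def coco_paths(root_dir, coco_dir, img_dir, set_dir):
    """Build (image folder, annotation file) paths.

    ASSUMED layout (hypothetical, verify against your library version):
      <root_dir>/<coco_dir>/<img_dir>/<set_dir>/                  -> images
      <root_dir>/<coco_dir>/annotations/instances_<set_dir>.json  -> labels
    """
    base = os.path.join(root_dir, coco_dir)
    images = os.path.join(base, img_dir, set_dir)
    annotations = os.path.join(base, "annotations", f"instances_{set_dir}.json")
    return images, annotations


def check_layout(root_dir, coco_dir, img_dir, set_dir):
    """Return the list of expected paths that do not exist (empty means OK)."""
    images, annotations = coco_paths(root_dir, coco_dir, img_dir, set_dir)
    return [p for p in (images, annotations) if not os.path.exists(p)]


if __name__ == "__main__":
    missing = check_layout("Dataset", "coco", "images", "train")
    for path in missing:
        print("missing:", path)
```

An empty result from `check_layout` means both expected paths exist; anything it returns points at the piece of the layout that needs fixing before calling `gtf.Train`.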