@THEFASHIONGEEK
Created April 9, 2020 16:20
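from train_detector import Detector
# Hypothetical placeholder paths: point these at your own COCO-format dataset
root_dir, coco_dir, img_dir = "dataset/", "coco", "images"
gtf = Detector()
# Training data: batches of 32, image size 300, 3 data-loader workers
gtf.Train_Dataset(root_dir, coco_dir, img_dir, batch_size=32, image_size=300, num_workers=3)
# MobileNet-based detector on a single GPU
gtf.Model(model_name="mobilenet", use_gpu=True, ngpu=1)
# Optimizer settings; jaccard_threshold is presumably the IoU cutoff for box matching
gtf.Set_HyperParams(lr=0.0001, momentum=0.9, weight_decay=0.0005, gamma=0.1, jaccard_threshold=0.5)
# Train for 10 epochs, logging iterations; weights saved to "weights" every 10 epochs
gtf.Train(epochs=10, log_iters=True, output_weights_dir="weights", saved_epoch_interval=10)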
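The `jaccard_threshold=0.5` passed to `Set_HyperParams` is presumably the Jaccard (intersection-over-union) overlap cutoff used in SSD-style training to decide which prior boxes count as matches for a ground-truth box. That reading is an assumption suggested by `image_size=300` and the parameter name. The sketch below is a minimal, framework-free illustration of that idea. `jaccard` and `match_priors` are hypothetical helpers, not part of `train_detector`, and real SSD matching also forces each ground-truth box to claim its single best prior.

```python
from typing import List, Tuple

# A box is (x_min, y_min, x_max, y_max) in any consistent coordinate system.
Box = Tuple[float, float, float, float]

# Hypothetical helpers for illustration only; not part of train_detector.
def jaccard(a: Box, b: Box) -> float:
    """Intersection-over-union (Jaccard overlap) of two axis-aligned boxes."""
    ix_min, iy_min = max(a[0], b[0]), max(a[1], b[1])
    ix_max, iy_max = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix_max - ix_min) * max(0.0, iy_max - iy_min)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def match_priors(priors: List[Box], truths: List[Box], threshold: float = 0.5) -> List[int]:
    """For each prior, return the index of its best-overlapping ground-truth box,
    or -1 (background) when that best overlap is below `threshold`."""
    matches = []
    for p in priors:
        if not truths:
            matches.append(-1)
            continue
        overlaps = [jaccard(p, t) for t in truths]
        best = max(range(len(truths)), key=overlaps.__getitem__)
        matches.append(best if overlaps[best] >= threshold else -1)
    return matches


if __name__ == "__main__":
    priors = [(0, 0, 10, 10), (5, 5, 15, 15), (20, 20, 30, 30)]
    truths = [(0, 0, 10, 10), (4, 4, 14, 14)]
    print(match_priors(priors, truths, threshold=0.5))  # [0, 1, -1]
```

Raising the threshold makes fewer priors count as positives; lowering it admits looser matches.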