This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the evaluation-set directory path from the current working
# directory, then list the entries it contains. os.listdir returns bare
# entry names (not full paths) in arbitrary order.
eval_image_path = os.path.join(os.getcwd(), eval_set_local_dir)
list_eval_image = os.listdir(eval_image_path)
for img_path in list_eval_image: | |
eval_image_ = os.path.join(eval_image_path, img_path) | |
print(eval_image_) | |
eval_image = Image.open(eval_image_) | |
with torch.no_grad(): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fetch the evaluation dataset attached to the experiment by name, then
# map every label name to whatever `get_label` returns for it on that dataset.
picsellia_eval_ds = experiment.get_dataset(name=eval_version_name)
labels_picsellia = {k: picsellia_eval_ds.get_label(k) for k in label_names}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reload both the image processor and the object-detection model from the
# directory the fine-tuned model was saved to.
image_processor = AutoImageProcessor.from_pretrained(finetuned_output_dir)
model = AutoModelForObjectDetection.from_pretrained(finetuned_output_dir)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Send the trained model to Picsellia: every entry found directly under the
# save directory is stored on the experiment, using the entry's own name as
# the stored name and its joined path as the file to upload.
model_latest_path = os.path.join(cwd, save_dir)
file_list = os.listdir(model_latest_path)
print(file_list)
for files in file_list:
    # os.listdir yields bare names, so rebuild the full path for the upload.
    # NOTE(review): sub-directories are not filtered out here — confirm
    # that `experiment.store` accepts them or that none are present.
    model_latest_files = os.path.join(model_latest_path, files)
    experiment.store(files, model_latest_files)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Save the fine-tuned model to the output directory, then build the path to
# that directory relative to the current working directory.
trainer.save_model(finetuned_output_dir)
finetuned_model_path = os.path.join(os.getcwd(), finetuned_output_dir)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Initialize the trainer callback, handing it the Picsellia experiment it
# keeps a reference to.
picsellia_callback = detr.CustomPicselliaCallback(experiment=experiment)
trainer = Trainer( | |
model=model, | |
args=training_args, | |
data_collator=collate_fn, | |
train_dataset=train_dataset, | |
tokenizer=image_processor, | |
callbacks=[picsellia_callback], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomPicselliaCallback(TrainerCallback): | |
def __init__(self, experiment: Experiment): | |
self.experiment=experiment | |
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control, **kwargs): | |
print("Starting training") | |
def on_train_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs): | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Log hyperparameters to Picsellia: the training arguments are converted to
# a plain dict and logged on the experiment as a table named
# "hyper-parameters".
training_hyp_params = training_args.to_dict()
experiment.log("hyper-parameters", training_hyp_params, type=LogType.TABLE)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Training the DETR model: start from the pretrained checkpoint and attach
# this dataset's label mappings.
model = AutoModelForObjectDetection.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
    # Do not raise when checkpoint weights differ in size from the model's
    # (e.g. a head sized for a different number of labels).
    ignore_mismatched_sizes=True,
)
training_args = TrainingArguments( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load transformer image processor for DetrModel
checkpoint = "facebook/detr-resnet-50"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
# Convert targets to DETR format, and resize + normalize both images and
# targets; results are returned as PyTorch tensors ("pt").
# NOTE(review): `images_trans` and `targets` are not defined in this
# snippet — presumably this call lives inside a batch-transform function;
# confirm against the full source.
encoding = image_processor(images=images_trans, annotations=targets, return_tensors="pt")