Skip to content

Instantly share code, notes, and snippets.

@piercelamb
Created December 20, 2022 17:59
Show Gist options
  • Save piercelamb/8923bde24ae2513d930d30187985baaa to your computer and use it in GitHub Desktop.
Save piercelamb/8923bde24ae2513d930d30187985baaa to your computer and use it in GitHub Desktop.
collect_inference
# Run every test instance through every loaded model, one instance at a time,
# and compute each model's predicted class index for that instance.
# When config.is_comparison is set, per-ground-truth-label statistics are also
# accumulated in comparison_stats, and each model is fed a model-specific batch.
# NOTE(review): this is an excerpt; `predictions` and `result` are presumably
# consumed by code after this point (not visible here).
if config.is_comparison:
    # Keyed by ground-truth label name; each entry has at least a 'count' field
    # (incremented below).
    comparison_stats = init_comparison_stats(id2label, config)
for i, raw_instance in enumerate(test_data):
    print("Testing Artifact: "+str(i+1))
    # 'labels' must be a single-element tensor (.item()). id2label is indexed
    # with a *string* id, so its keys are presumably str (e.g. loaded from
    # JSON) — TODO confirm.
    actual_label_id = str(raw_instance['labels'].item())
    ground_truth_label = id2label[actual_label_id]
    print("\n--------------------------------------------")
    print("Ground Truth Label: " + ground_truth_label)
    if config.is_comparison:
        # Tally how many test instances carry this ground-truth label.
        comparison_stats[ground_truth_label]['count'] += 1
    # Per-instance predictions; presumably filled per model further down.
    predictions = {}
    for model_name, model_data in model_containers.items():
        if config.is_comparison:
            # NOTE(review): this rebinds raw_instance, so every model after the
            # first receives the batch already produced for the previous model.
            # If get_model_specific_batch selects/renames keys out of a combined
            # batch, bind its result to a new name instead — confirm against
            # its definition and any later use of raw_instance.
            raw_instance = get_model_specific_batch(raw_instance, model_name)
        # Build the model inputs. Drop 'labels' (presumably so the forward pass
        # only sees model inputs), and prepend a size-1 dimension to every
        # remaining tensor (presumably the batch dimension the model expects).
        instance = {}
        for input_key, input_value in raw_instance.items():
            if input_key != 'labels':
                instance[input_key] = input_value.unsqueeze(0)
        # Inference only: gradient tracking is disabled for the forward pass.
        with torch.no_grad():
            loaded_model = model_data['loaded_model']
            prediction = loaded_model(**instance)
            # Index of the largest logit along the last axis, i.e. the predicted
            # class id (assumes logits shaped (1, num_labels) — TODO confirm).
            # .detach() is redundant under no_grad but harmless.
            result = prediction['logits'].detach().cpu().numpy().argmax(-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment