@piercelamb
piercelamb / inference_folder_structure
Created December 20, 2022 18:04
inference_folder_structure
EXP-3333-longformer/
    data/
        reconciled_artifacts/
            <all raw training data>
        prepared_data/
            <all encoded training data>
    run_1/
        tuning/ (or training/)
            <all files emitted by training>
        inference/
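For orientation, a minimal sketch of how a run's paths might be resolved in code; the directory names come from the tree above, while the helper name and experiment-root argument are assumptions:

from pathlib import Path

# Hypothetical helper: builds the per-run paths implied by the layout above.
def run_paths(experiment_root, run_num, tuned=True):
    root = Path(experiment_root)
    run_dir = root / f"run_{run_num}"
    return {
        "raw_data": root / "data" / "reconciled_artifacts",
        "prepared_data": root / "data" / "prepared_data",
        "model_dir": run_dir / ("tuning" if tuned else "training"),
        "inference_dir": run_dir / "inference",
    }

paths = run_paths("EXP-3333-longformer", run_num=1)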
@piercelamb
piercelamb / process_statistics.py
Created December 20, 2022 18:03
process_statistics
def process_statistics(model_containers, config):
    for model_name, model_data in model_containers.items():
        run_statistics = {}
        for label, counts in model_data['stats'].items():
            recall_denom = counts['false_negatives'] + counts['true_positives']
            precision_denom = counts['false_positives'] + counts['true_positives']
            if (recall_denom > 0) and (precision_denom > 0):
                precision = counts['true_positives'] / precision_denom
                recall = counts['true_positives'] / recall_denom
                if precision + recall > 0:
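The preview cuts off just before the per-label score is computed. A hedged sketch of how the inner loop typically continues, assuming the standard F1 formula and that results are collected into run_statistics (the exact keys are a guess, not the author's code):

                    # Standard harmonic mean of precision and recall
                    f1 = 2 * (precision * recall) / (precision + recall)
                else:
                    f1 = 0.0
                run_statistics[label] = {
                    'precision': precision,
                    'recall': recall,
                    'f1': f1,
                    'support': counts['total'],
                }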
@piercelamb
piercelamb / get_multi_class_statistics.py
Created December 20, 2022 18:02
get_multi_class_statistics
def get_multi_class_stats(statistics_counts, predicted_label_id, actual_label_id, id2label):
    actual_label = id2label[actual_label_id]
    statistics_counts[actual_label]['total'] += 1
    predicted_label = id2label[predicted_label_id]
    if predicted_label == actual_label:
        statistics_counts[actual_label]['true_positives'] += 1
    else:
        # wrong prediction, the prediction is thus a false positive
        # and the actual label is a false negative
        statistics_counts[predicted_label]['false_positives'] += 1
        statistics_counts[actual_label]['false_negatives'] += 1
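A quick usage sketch (the two labels and ids here are made up) showing how the counters move for one wrong prediction:

id2label = {'0': 'invoice', '1': 'receipt'}
stats = init_model_stats(id2label)

# The model predicted 'receipt' but the ground truth was 'invoice'
get_multi_class_stats(stats, predicted_label_id='1', actual_label_id='0', id2label=id2label)

assert stats['invoice']['total'] == 1
assert stats['invoice']['false_negatives'] == 1
assert stats['receipt']['false_positives'] == 1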
@piercelamb
piercelamb / get_statistics.py
Created December 20, 2022 18:01
get_statistics
predicted_label_id = str(result[0])
if config.is_comparison:
    predictions[model_name] = id2label[predicted_label_id]
print(f"{model_name} Predicted Label: {str(id2label[predicted_label_id])}")
f1_metric = model_data['metrics']['f1']
acc_metric = model_data['metrics']['acc']
f1_metric.add_batch(
    predictions=[int(predicted_label_id)],
    references=[int(actual_label_id)]
)
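The preview stops after the F1 update. A minimal sketch of the matching accuracy update and how the Hugging Face evaluate metrics would later be read out; the weighted average setting is an assumption, not confirmed by the snippet:

acc_metric.add_batch(
    predictions=[int(predicted_label_id)],
    references=[int(actual_label_id)]
)

# After the loop over test_data, each metric aggregates its accumulated batches:
f1_score = f1_metric.compute(average="weighted")["f1"]
accuracy = acc_metric.compute()["accuracy"]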
@piercelamb
piercelamb / collect_inference.py
Created December 20, 2022 17:59
collect_inference
if config.is_comparison:
    comparison_stats = init_comparison_stats(id2label, config)
for i, raw_instance in enumerate(test_data):
    print("Testing Artifact: "+str(i+1))
    actual_label_id = str(raw_instance['labels'].item())
    ground_truth_label = id2label[actual_label_id]
    print("\n--------------------------------------------")
    print("Ground Truth Label: " + ground_truth_label)
    if config.is_comparison:
        comparison_stats[ground_truth_label]['count'] += 1
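The preview ends before the inner loop over the loaded models. Based on the other snippets (model_containers, get_statistics), it plausibly continues along these lines; the forward-pass details and the torch import are assumptions, not the author's code:

    predictions = {}
    for model_name, model_data in model_containers.items():
        model = model_data['loaded_model']
        with torch.no_grad():
            outputs = model(
                input_ids=raw_instance['input_ids'].unsqueeze(0),
                attention_mask=raw_instance['attention_mask'].unsqueeze(0)
            )
        # argmax over the logits gives the predicted class id
        result = outputs.logits.argmax(dim=-1).tolist()
        # result[0] is what the get_statistics snippet above consumes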
@piercelamb
piercelamb / test_data.py
Created December 20, 2022 17:58
test_data
from datasets import load_from_disk

# Load the encoded test split saved to the local SageMaker inference data dir
test_data = load_from_disk(SAGEMAKER_LOCAL_INFERENCE_DATA_DIR)
print(f"Running inference over {test_data.num_rows} samples")
test_data.set_format(type="torch")
@piercelamb
piercelamb / init_metrics.py
Created December 20, 2022 17:57
init_metrics
import evaluate

def init_metrics():
    # Hugging Face evaluate metrics for F1 and accuracy
    return evaluate.load("f1"), evaluate.load("accuracy")
@piercelamb
piercelamb / init_model_stats.py
Created December 20, 2022 17:56
init_model_stats
def init_model_stats(id2label):
    # One counter dict per label name (the same label keys used in get_multi_class_stats)
    return {
        label: {
            'total': 0,
            'true_positives': 0,
            'false_positives': 0,
            'false_negatives': 0,
            'accuracy': 0.0,
            'f1': 0.0
        }
        for label in id2label.values()
    }
@piercelamb
piercelamb / get_loaded_model.py
Created December 20, 2022 17:55
get_loaded_model
def get_loaded_model(config, model_name):
    model_parent_path = f"{config.s3_parent_dir}/run_{config.run_num}/"
    if folder_exists(bucket, model_parent_path + "tuning"):
        model_parent_path += "tuning"
    elif folder_exists(bucket, model_parent_path + "training"):
        model_parent_path += "training"
    else:
        # it's a comparison job
        model_parent_path += f"{model_name}_training"
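folder_exists and bucket are not shown in this preview. One plausible way to implement such a check with boto3 (an assumption, not necessarily the author's helper):

import boto3

s3 = boto3.client("s3")

def folder_exists(bucket_name, prefix):
    # An S3 "folder" exists if at least one object key sits under the prefix
    response = s3.list_objects_v2(
        Bucket=bucket_name,
        Prefix=prefix.rstrip("/") + "/",
        MaxKeys=1
    )
    return response.get("KeyCount", 0) > 0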
@piercelamb
piercelamb / model_containers.py
Created December 20, 2022 17:53
model_containers
model_containers = {}
for model_name in config.model_names:
    loaded_model = get_loaded_model(config, model_name)
    init_stats = init_model_stats(id2label)
    f1_metric, acc_metric = init_metrics()
    model_containers[model_name] = {
        'loaded_model': loaded_model,
        'stats': init_stats,
        'metrics': {
            'f1': f1_metric,
            'acc': acc_metric
        }
    }