Pierce Lamb (piercelamb) / GitHub Gists
piercelamb / copy_to_s3.py (Created December 19, 2022 20:28)
from typing import Callable, List, NamedTuple, Optional

# Bucket is the boto3 S3 Bucket resource type (e.g. as provided by the boto3-stubs package).
def copy_to_s3(
    bucket: Bucket,
    files_to_copy: List[str],
    row: NamedTuple,
    s3_artifact_path: str,
    process_func: Optional[Callable[[str, List, str, bool], None]] = None,
    reload: bool = False
):
    # Each row carries the GUID that identifies where its files live.
    guid = row.file_location
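The preview cuts off before the copy itself happens. As a point of reference, here is a minimal sketch of uploading a list of local files under an S3 prefix with the boto3 Bucket resource; the helper name, bucket name, and key layout are placeholders for illustration, not the gist's actual logic.

import os

import boto3

def upload_files(bucket, files_to_copy, s3_prefix):
    # Upload each local file under the given prefix, keyed by its basename.
    for local_path in files_to_copy:
        key = f"{s3_prefix}/{os.path.basename(local_path)}"
        bucket.upload_file(Filename=local_path, Key=key)

# Example usage (placeholder bucket name and paths):
# bucket = boto3.resource("s3").Bucket("my-example-bucket")
# upload_files(bucket, ["/tmp/a.json", "/tmp/b.json"], "raw_training_data/guid-123")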
piercelamb / copy_s3_data_in_parallel.py (Created December 19, 2022 20:25)
from typing import Callable, Dict, List, Optional

import numpy as np
import pandas as pd

def copy_s3_data_in_parallel(
    df: pd.DataFrame,
    bucket: str,
    raw_training_data_paths: Dict[str, List[str]],
    s3_artifact_path: str,
    num_processes: int,
    process_func: Optional[Callable[[str, List, str, bool], None]] = None,
    reload: bool = False):
    # Record which artifacts already exist under the target S3 prefix so they can be skipped.
    existing_artifacts_state = get_raw_data_paths(bucket, path_to_filter_for=s3_artifact_path)
    # Split the DataFrame into one chunk per worker process.
    split_df = np.array_split(df, num_processes)
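The preview stops right after the DataFrame is split. A minimal sketch of how chunks like these could be fanned out across worker processes with multiprocessing.Pool follows; copy_chunk and its arguments are hypothetical stand-ins for the gist's real per-chunk worker, not its actual implementation.

from functools import partial
from multiprocessing import Pool

import numpy as np
import pandas as pd

def copy_chunk(chunk: pd.DataFrame, bucket_name: str) -> int:
    # Hypothetical worker: in the real code each row's files would be copied to S3 here.
    copied = 0
    for row in chunk.itertuples(index=False):
        copied += 1  # placeholder for the per-row copy
    return copied

def fan_out(df: pd.DataFrame, bucket_name: str, num_processes: int) -> int:
    chunks = np.array_split(df, num_processes)
    with Pool(processes=num_processes) as pool:
        results = pool.map(partial(copy_chunk, bucket_name=bucket_name), chunks)
    return sum(results)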
piercelamb / inference_folder_structure (Created December 20, 2022 18:04)
EXP-3333-longformer/
    data/
        reconciled_artifacts/
            <all raw training data>
        prepared_data/
            <all encoded training data>
    run_1/
        tuning/ (or training/)
            <all files emitted by training>
        inference/
piercelamb / process_statistics.py (Created December 20, 2022 18:03)
def process_statistics(model_containers, config):
    for model_name, model_data in model_containers.items():
        run_statistics = {}
        for label, counts in model_data['stats'].items():
            # Precision = TP / (TP + FP); recall = TP / (TP + FN).
            recall_denom = counts['false_negatives'] + counts['true_positives']
            precision_denom = counts['false_positives'] + counts['true_positives']
            if (recall_denom > 0) and (precision_denom > 0):
                precision = counts['true_positives'] / precision_denom
                recall = counts['true_positives'] / recall_denom
                if precision + recall > 0:
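The preview truncates just before the score is computed. For reference, a self-contained sketch of the standard per-label F1 from the same counters the gist tracks; the helper name is mine, and returning 0.0 when the score is undefined is an assumption about how the gist handles empty denominators.

def per_label_f1(counts: dict) -> float:
    # F1 is the harmonic mean of precision and recall; return 0.0 when it is undefined.
    tp = counts['true_positives']
    precision_denom = tp + counts['false_positives']
    recall_denom = tp + counts['false_negatives']
    if precision_denom == 0 or recall_denom == 0:
        return 0.0
    precision = tp / precision_denom
    recall = tp / recall_denom
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)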
piercelamb / get_multi_class_statistics.py (Created December 20, 2022 18:02)
def get_multi_class_stats(statistics_counts, predicted_label_id, actual_label_id, id2label):
    actual_label = id2label[actual_label_id]
    statistics_counts[actual_label]['total'] += 1
    predicted_label = id2label[predicted_label_id]
    if predicted_label == actual_label:
        statistics_counts[actual_label]['true_positives'] += 1
    else:
        # Wrong prediction: it counts as a false positive for the predicted label
        # and as a false negative for the actual label.
        statistics_counts[predicted_label]['false_positives'] += 1
        statistics_counts[actual_label]['false_negatives'] += 1  # implied by the comment above; cut off in the preview
piercelamb / get_statistics.py (Created December 20, 2022 18:01)
predicted_label_id = str(result[0])
if config.is_comparison:
    predictions[model_name] = id2label[predicted_label_id]
print(f"{model_name} Predicted Label: {str(id2label[predicted_label_id])}")
f1_metric = model_data['metrics']['f1']
acc_metric = model_data['metrics']['acc']
# Accumulate this prediction/reference pair into the Hugging Face evaluate metrics.
f1_metric.add_batch(
    predictions=[int(predicted_label_id)],
    references=[int(actual_label_id)]
)
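The preview ends before the accuracy metric is updated and the scores are finalized. Below is a minimal, self-contained sketch of how evaluate metrics accumulated with add_batch are typically resolved with compute; the placeholder prediction pairs and the average="weighted" setting for multi-class F1 are assumptions, not necessarily what the gist does.

import evaluate

f1_metric = evaluate.load("f1")
acc_metric = evaluate.load("accuracy")

# Accumulate one (prediction, reference) pair per test artifact.
for predicted_id, actual_id in [(1, 1), (2, 0)]:  # placeholder pairs
    f1_metric.add_batch(predictions=[predicted_id], references=[actual_id])
    acc_metric.add_batch(predictions=[predicted_id], references=[actual_id])

# Resolve the accumulated batches into final scores.
f1_score = f1_metric.compute(average="weighted")  # averaging mode is an assumption
accuracy = acc_metric.compute()
print(f1_score, accuracy)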
piercelamb / collect_inference.py (Created December 20, 2022 17:59)
if config.is_comparison:
    comparison_stats = init_comparison_stats(id2label, config)
for i, raw_instance in enumerate(test_data):
    print("Testing Artifact: " + str(i + 1))
    actual_label_id = str(raw_instance['labels'].item())
    ground_truth_label = id2label[actual_label_id]
    print("\n--------------------------------------------")
    print("Ground Truth Label: " + ground_truth_label)
    if config.is_comparison:
        # When comparing models, track how many test artifacts fall under each ground-truth label.
        comparison_stats[ground_truth_label]['count'] += 1
piercelamb / test_data.py (Created December 20, 2022 17:58)
from datasets import load_from_disk

# Load the encoded test split from local disk and have it return PyTorch tensors.
test_data = load_from_disk(SAGEMAKER_LOCAL_INFERENCE_DATA_DIR)
print(f"Running inference over {test_data.num_rows} samples")
test_data.set_format(type="torch")
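For context, here is a minimal sketch of running a classification model over a torch-formatted datasets object like this one; the data path, the checkpoint path, and the assumption that each record carries input_ids, attention_mask, and labels are illustrative only, not the gist's actual inference code.

import torch
from datasets import load_from_disk
from transformers import AutoModelForSequenceClassification

dataset = load_from_disk("path/to/encoded/test/data")  # hypothetical path
dataset.set_format(type="torch")

model = AutoModelForSequenceClassification.from_pretrained("path/to/checkpoint")  # hypothetical checkpoint
model.eval()

with torch.no_grad():
    for instance in dataset:
        outputs = model(
            input_ids=instance["input_ids"].unsqueeze(0),
            attention_mask=instance["attention_mask"].unsqueeze(0),
        )
        predicted_id = outputs.logits.argmax(dim=-1).item()
        actual_id = instance["labels"].item()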
piercelamb / init_metrics.py (Created December 20, 2022 17:57)
import evaluate

def init_metrics():
    return evaluate.load("f1"), evaluate.load("accuracy")
piercelamb / init_model_stats.py (Created December 20, 2022 17:56)
def init_model_stats(id2label):
    return {
        label: {
            'total': 0,
            'true_positives': 0,
            'false_positives': 0,
            'false_negatives': 0,
            'accuracy': 0.0,
            'f1': 0.0
        }
        # The end of the comprehension is cut off in the preview; it presumably iterates the label names.
        for label in id2label.values()
    }
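Taken together, a small usage sketch of how these counters might be exercised with the functions from the gists above; the toy id2label mapping and its label names are made up for illustration.

# Toy mapping from label ids to label names (illustrative only).
id2label = {'0': 'invoice', '1': 'receipt', '2': 'contract'}

stats = init_model_stats(id2label)

# Simulate two predictions: one correct, one wrong.
get_multi_class_stats(stats, '1', '1', id2label)
get_multi_class_stats(stats, '0', '2', id2label)

print(stats['receipt']['true_positives'])    # 1
print(stats['contract']['false_negatives'])  # 1
print(stats['invoice']['false_positives'])   # 1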