# Source: GitHub gist mikelgg93/a250811d59885e791cbeeb99fd12ef55 (last active February 2, 2023).
# Code for https://docs.pupil-labs.com/alpha-lab/gaze-metrics-in-aois/ in vanilla Python.
# Note: GitHub flagged this file as possibly containing bidirectional Unicode text; review it
# in an editor that reveals hidden Unicode characters.
import argparse  # For parsing arguments
import logging  # For logging
import os
import platform
import tkinter as tk  # For GUI
from enum import Enum  # For enumerating the metrics
from tkinter import filedialog

import cv2  # For selecting AOIs
import matplotlib as mpl  # For plotting
import matplotlib.pyplot as plt
import numpy as np  # For numerical operations
import pandas as pd  # For data manipulation
import seaborn as sns
from matplotlib import patches

# Disable pandas' chained-assignment (SettingWithCopy) warnings for this script.
pd.set_option("mode.chained_assignment", None)
# Set the context of the plots in seaborn
sns.set_context("paper")

# Preparing the logger. The previous bare `logging.getLogger(...)` call discarded its
# result; bind it so the named logger is actually reachable. Note that basicConfig
# below configures the root logger, which is what `logging.info(...)` calls use.
logger = logging.getLogger("defineAOIs")
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
class MetricClasses(Enum):
    """Metrics that can be requested on the command line via ``--metric``.

    ``all`` selects every metric; each other member selects exactly one plot.
    """

    all = 0  # plot every metric below
    hit_rate = 1  # percentage of recordings with at least one point in the AOI
    first_contact = 2  # mean time from section start to the first point in the AOI
    dwell_time = 3  # mean summed fixation duration in the AOI (fixations only)
class DataClasses(Enum):
    """Which exported stream the metrics are computed from (``--type``)."""

    fixations = 0  # rows of fixations.csv
    gaze = 1  # rows of gaze.csv
def get_path():
    """Ask the user for the input directory through a Tk folder dialog.

    Returns:
        The selected directory path.

    Raises:
        SystemExit: if the dialog is cancelled, or if the selected directory lacks
            any of reference_image.jpeg, fixations.csv, gaze.csv or sections.csv.
    """
    root = tk.Tk()
    root.withdraw()  # show only the dialog, not an empty root window
    try:
        path = filedialog.askdirectory(title="Select the directory")
    finally:
        # Release the Tk root once the dialog has closed (it was previously leaked).
        root.destroy()
    # A cancelled dialog yields an empty value; without this check the existence
    # tests below would silently be evaluated against the current working directory.
    if not path:
        error = "No directory was selected"
        logging.error(error)
        raise SystemExit(error)
    required = ("reference_image.jpeg", "fixations.csv", "gaze.csv", "sections.csv")
    missing = [name for name in required if not os.path.exists(os.path.join(path, name))]
    if missing:
        error = (
            "The selected folder does not contain a reference_image.jpeg, "
            "fixations.csv, gaze.csv or sections.csv files "
            f"(missing: {', '.join(missing)})"
        )
        logging.error(error)
        raise SystemExit(error)
    return path
def _parse_args():
    """Parse the command line and convert ``--metric``/``--type`` to enum members.

    Returns:
        argparse.Namespace whose ``metric`` is a ``MetricClasses`` member and whose
        ``type`` is a ``DataClasses`` member.
    """
    parser = argparse.ArgumentParser(description="Pupil Labs - AOI Annotation")
    # Restricting values to the enum member names yields a readable argparse error
    # instead of a KeyError on a misspelled value.
    parser.add_argument(
        "--metric",
        default=MetricClasses.all.name,
        choices=[metric.name for metric in MetricClasses],
        type=str,
    )
    parser.add_argument("--input_path", default=None, type=str)
    parser.add_argument("--output_path", default=None, type=str)
    parser.add_argument("--aois", default=None, type=str)
    # NOTE(review): --start/--end are parsed but not read anywhere in this script.
    parser.add_argument("--start", default="recording.begin", type=str)
    parser.add_argument("--end", default="recording.end", type=str)
    parser.add_argument(
        "--type",
        default=DataClasses.fixations.name,
        choices=[data_class.name for data_class in DataClasses],
        type=str,
    )
    parser.add_argument("-s", "--scatter", action="store_true")
    args = parser.parse_args()
    # Report selected arguments
    logging.info("args: %s", args)
    args.metric = MetricClasses[args.metric]
    logging.info("metric: %s", args.metric)
    args.type = DataClasses[args.type]
    logging.info("type: %s", args.type)
    return args


def _load_aois(args, reference_image_bgr):
    """Return the AOI rectangles as a DataFrame with x, y, width, height columns.

    Order of precedence: the ``--aois`` CSV, then ``aoi_ids.csv`` in the input
    folder, otherwise the user draws the AOIs on a downscaled reference image and
    they are saved to ``aoi_ids.csv`` in the output folder.

    Raises:
        SystemExit: if the user closes the selector without drawing any AOI.
    """
    if args.aois is not None:
        # Load the AOIs from a csv file
        return pd.read_csv(args.aois)
    # NOTE(review): AOIs are looked up in the input folder but saved to the output
    # folder, so they are only reused automatically when both folders are the same.
    existing = os.path.join(args.input_path, "aoi_ids.csv")
    if os.path.exists(existing):
        logging.info("AOIs already defined")
        args.aois = existing
        return pd.read_csv(args.aois)
    # Resize the image before labelling AOIs so it stays within the screen boundaries
    scaling_factor = 0.25
    scaled_image = cv2.resize(
        reference_image_bgr, dsize=None, fx=scaling_factor, fy=scaling_factor
    )
    scaled_aois = cv2.selectROIs("AOI Annotation", scaled_image)
    cv2.destroyAllWindows()
    # selectROIs may return an empty tuple when nothing is drawn; normalise to (N, 4)
    scaled_aois = np.asarray(scaled_aois, dtype=np.float64).reshape(-1, 4)
    if len(scaled_aois) == 0:
        error = "No AOIs were selected"
        logging.error(error)
        raise SystemExit(error)
    # Scale the AOI positions back to reference-image pixels
    aois = pd.DataFrame(
        scaled_aois / scaling_factor, columns=["x", "y", "width", "height"]
    )
    aois.to_csv(os.path.join(args.output_path, "aoi_ids.csv"), index=False)
    return aois


def _assign_aois(data, aois, args):
    """Add an ``AOI`` column with the id of the AOI each point lies in (NaN if none).

    When AOIs overlap, the one listed last wins because rows are assigned in order.
    Initialising the column up front keeps later groupbys working even when no
    point falls inside any AOI.
    """
    data["AOI"] = np.nan
    for row in aois.itertuples():
        in_aoi = check_in_rect(data, [row.x, row.y, row.width, row.height], args)
        data.loc[in_aoi, "AOI"] = row.Index
    return data


def _align_timestamps(data, sections_df, args):
    """Add ``aligned timestamp [s]``: seconds since the start of each point's section."""
    field_ts = (
        "start timestamp [ns]"
        if args.type == DataClasses.fixations
        else "timestamp [ns]"
    )
    data["aligned timestamp [s]"] = np.nan
    sections = sections_df.set_index("section id")
    # Series.items replaces Series.iteritems, which pandas 2.0 removed.
    for section_id, start_time in sections["section start time [ns]"].items():
        in_section = data["section id"] == section_id
        logging.info(
            "The section {} starts at {} and has {} points".format(
                section_id,
                start_time,
                int(in_section.sum()),
            )
        )
        data.loc[in_section, "aligned timestamp [s]"] = (
            data.loc[in_section, field_ts] - start_time
        ) / 1e9
    return data


def _plot_metric(metric, title, label, reference_image, aois, data, args):
    """Show ``metric`` per AOI as a bar chart next to the AOIs colored by its value.

    ``label`` is used both for the bar chart's y axis and for the colorbar.
    """
    fig, ax = plt.subplots(1, 2, figsize=(18, 6))
    # plot_color_patches finishes with plt.show(), so the title must be set before
    # calling it; setting it afterwards only touches a figure that was already shown.
    fig.suptitle(f"{title} - {args.type}")
    sns.barplot(x=metric.index.to_numpy(dtype=np.int64), y=metric, ax=ax[0])
    ax[0].set_xlabel("AOI ID")
    ax[0].set_ylabel(label)
    plot_color_patches(
        reference_image,
        aois,
        metric,
        ax[1],
        args,
        colorbar=True,
        unit_label=label,
        data=data,
    )


def main():
    """Assign fixations/gaze to AOIs on the reference image and plot per-AOI metrics.

    Metrics, as computed below:
      * hit rate: percentage of recordings (among those with any point on the
        reference image) that have at least one point in the AOI;
      * first contact: time from section start to the first point in the AOI,
        averaged over sections;
      * dwell time (fixations only): summed fixation duration in the AOI per
        recording, averaged over recordings, in seconds.

    Raises:
        SystemExit: if the reference image cannot be read or no AOI is selected.
    """
    args = _parse_args()
    # If the reference image folder path is not provided or does not exist, ask the user
    if args.input_path is None or not os.path.exists(args.input_path):
        args.input_path = get_path()
    # If the output path is not provided or does not exist, use the input path
    if args.output_path is None or not os.path.exists(args.output_path):
        args.output_path = args.input_path
    logging.info("Input path: %s", args.input_path)

    image_path = os.path.join(args.input_path, "reference_image.jpeg")
    reference_image_bgr = cv2.imread(image_path)
    # cv2.imread returns None instead of raising when the file is missing/unreadable
    if reference_image_bgr is None:
        error = f"Could not read the reference image {image_path}"
        logging.error(error)
        raise SystemExit(error)
    # OpenCV loads images as BGR; matplotlib expects RGB
    reference_image = cv2.cvtColor(reference_image_bgr, cv2.COLOR_BGR2RGB)

    aois = _load_aois(args, reference_image_bgr)
    logging.info("Areas of interest:")
    logging.info(aois)
    # Plot the AOIs, colored by their id
    plot_color_patches(reference_image, aois, pd.Series(aois.index), plt.gca(), args)

    # Load the sections, fixations and gaze files onto pandas DataFrames
    logging.info("Loading files ...")
    sections_df = pd.read_csv(os.path.join(args.input_path, "sections.csv"))
    logging.info(sections_df["start event name"].unique())
    logging.info(sections_df["end event name"].unique())
    fixations_df = pd.read_csv(os.path.join(args.input_path, "fixations.csv"))
    logging.info("A total of %d fixations were found", len(fixations_df))
    gaze_df = pd.read_csv(os.path.join(args.input_path, "gaze.csv"))
    logging.info("A total of %d gaze points were found", len(gaze_df))

    is_fixations = args.type == DataClasses.fixations
    data_df = fixations_df if is_fixations else gaze_df
    field_detected = (
        "fixation detected in reference image"
        if is_fixations
        else "gaze detected in reference image"
    )
    # Keep only points mapped onto the reference image. Copy so the columns added
    # below go into an independent frame rather than a slice of data_df.
    data = data_df[data_df[field_detected]].copy()
    data = _assign_aois(data, aois, args)
    logging.info(
        "A total of %d %s points were detected in AOIs",
        int(data["AOI"].notna().sum()),
        args.type,
    )

    # Hit rate
    hits = data.groupby(["recording id", "AOI"]).size() > 0
    hit_rate = hits.groupby("AOI").sum() / data["recording id"].nunique() * 100
    # AOIs that have never been gazed at do not show up in the grouped data,
    # so report them explicitly with a hit rate of 0
    hit_rate = hit_rate.reindex(aois.index, fill_value=0)
    logging.info("Hit rate per AOI:")
    logging.info(hit_rate.head())

    # Time to first contact, relative to the start of the respective section
    data = _align_timestamps(data, sections_df, args)
    first_contact = data.groupby(["section id", "AOI"])["aligned timestamp [s]"].min()
    first_contact = first_contact.groupby("AOI").mean()
    logging.info(first_contact)

    # Dwell time for the respective AOI (fixation durations are in ms)
    if is_fixations:
        dwell_time = data.groupby(["recording id", "AOI"])["duration [ms]"].sum()
        dwell_time = dwell_time.groupby("AOI").mean() / 1000
        logging.info(dwell_time.head())

    # Plot the results
    show_all = args.metric == MetricClasses.all
    if show_all or args.metric == MetricClasses.hit_rate:
        _plot_metric(
            hit_rate, "Hit Rate", "Hit Rate [%]", reference_image, aois, data, args
        )
    if show_all or args.metric == MetricClasses.first_contact:
        _plot_metric(
            first_contact,
            "First Contact",
            "Time to first contact [s]",
            reference_image,
            aois,
            data,
            args,
        )
    if is_fixations:
        if show_all or args.metric == MetricClasses.dwell_time:
            _plot_metric(
                dwell_time,
                "Dwell Time",
                "Dwell Time [s]",
                reference_image,
                aois,
                data,
                args,
            )
    else:
        logging.info("Dwell time is only available for fixations data")
    pd.reset_option("mode.chained_assignment")
    logging.info("Done")
def check_in_rect(data, rectangle_coordinates, args):
    """Flag the rows of ``data`` whose point lies inside an axis-aligned rectangle.

    Args:
        data: DataFrame holding fixation or gaze coordinates in reference-image pixels.
        rectangle_coordinates: ``(x, y, width, height)`` of the rectangle.
        args: parsed arguments; ``args.type`` selects the fixation or gaze columns.

    Returns:
        Boolean Series aligned with ``data``; ``Series.between`` makes every edge inclusive.
    """
    left, top, width, height = rectangle_coordinates
    if args.type == DataClasses.fixations:
        x_column, y_column = "fixation x [px]", "fixation y [px]"
    elif args.type == DataClasses.gaze:
        x_column = "gaze position in reference image x [px]"
        y_column = "gaze position in reference image y [px]"
    right, bottom = left + width, top + height
    within_x = data[x_column].between(left, right)
    within_y = data[y_column].between(top, bottom)
    return within_x & within_y
def plot_color_patches(
    image,
    aoi_positions,
    values,
    ax,
    args,
    alpha=0.3,
    colorbar=False,
    unit_label="",
    data=None,
):
    """Draw ``image`` on ``ax`` with each AOI overlaid as a rectangle colored by value.

    Args:
        image: image shown in the background (main() passes the RGB reference image).
        aoi_positions: DataFrame with ``x``, ``y``, ``width`` and ``height`` columns,
            indexed by AOI id.
        values: Series mapping AOI id -> value; only AOIs present in it are drawn.
        ax: matplotlib axes to draw on.
        args: parsed arguments; ``args.type`` selects the coordinate columns and
            ``args.scatter`` toggles the scatter overlay of ``data``.
        alpha: transparency of the rectangles.
        colorbar: whether to add a colorbar spanning the raw ``values`` range.
        unit_label: label of the colorbar.
        data: optional points to scatter on the image when ``args.scatter`` is set.

    Returns:
        ``ax``. Note that ``plt.show()`` is called before returning.
    """
    ax.imshow(image)
    ax.axis("off")
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap is available
    # in both old and new versions.
    colors = plt.get_cmap("viridis")
    # Normalize patch values to [0, 1] for the colormap. A constant series (e.g. a
    # single AOI, or identical hit rates) would otherwise divide by zero and give NaN,
    # which the colormap renders with its (by default fully transparent) "bad" color.
    values_normed = values.astype(np.float32)
    value_min = values_normed.min()
    value_range = values_normed.max() - value_min
    values_normed -= value_min
    if value_range > 0:
        values_normed /= value_range
    # Series.items replaces Series.iteritems, which pandas 2.0 removed.
    for aoi_id, aoi_val in values_normed.items():
        aoi_id = int(aoi_id)
        x, y, width, height = (
            aoi_positions.loc[aoi_id, column] for column in ("x", "y", "width", "height")
        )
        color = colors(aoi_val)
        ax.add_patch(
            patches.Rectangle(
                (x, y),
                width,
                height,
                alpha=alpha,
                facecolor=color,
                edgecolor=color,
                linewidth=5,
            )
        )
        ax.text(x + 20, y + 120, f"{aoi_id}", color="black")
    if colorbar:
        norm = mpl.colors.Normalize(vmin=values.min(), vmax=values.max())
        cb = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=colors), ax=ax)
        cb.set_label(unit_label)
    if data is not None and args.scatter:
        if args.type == DataClasses.fixations:
            field0 = "fixation detected in reference image"
            field1 = "fixation x [px]"
            field2 = "fixation y [px]"
        elif args.type == DataClasses.gaze:
            field0 = "gaze detected in reference image"
            field1 = "gaze position in reference image x [px]"
            field2 = "gaze position in reference image y [px]"
        data_in = data[data[field0].eq(True)]
        ax.scatter(data_in[field1], data_in[field2], s=20, color="red", alpha=0.8)
    plt.show()
    return ax
if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported as a module.
    main()
# (GitHub gist page footer: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment")