import os
import platform
import cv2 # For selecting AOIs
import numpy as np # For numerical operations
import pandas as pd # For data manipulation
pd.set_option("mode.chained_assignment", None)
import matplotlib as mpl # For plotting
import matplotlib.pyplot as plt
from matplotlib import patches
import seaborn as sns
import tkinter as tk # For GUI
from tkinter import filedialog
import logging # For logging
import argparse # For parsing arguments
from enum import Enum # For enumerating the metrics
sns.set_context("paper") # Set the context of the plots in seaborn
# Preparing the logger
logging.getLogger("defineAOIs")
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
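# Example invocation (the script filename below is a placeholder, not part of this
# gist; the input folder is expected to contain reference_image.jpeg, fixations.csv,
# gaze.csv and sections.csv, e.g. a Pupil Cloud Reference Image Mapper export):
#   python aoi_metrics.py --input_path ./reference_image_mapper_export \
#       --metric hit_rate --type fixations --scatter
# Omitting --input_path opens a folder-selection dialog; omitting --aois reuses an
# existing aoi_ids.csv or lets you draw the AOIs on the reference image.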
# What metrics are available?
class MetricClasses(Enum):
    all = 0
    hit_rate = 1
    first_contact = 2
    dwell_time = 3


class DataClasses(Enum):
    fixations = 0
    gaze = 1
def get_path():
    root = tk.Tk()
    root.withdraw()
    msg = "Select the directory"
    arguments = {"title": msg}
    # if platform.system() == "Darwin":
    #     arguments["message"] = msg
    path = filedialog.askdirectory(**arguments)
    # Check that the folder contains the required files
    if (
        not os.path.exists(os.path.join(path, "fixations.csv"))
        or not os.path.exists(os.path.join(path, "gaze.csv"))
        or not os.path.exists(os.path.join(path, "sections.csv"))
        or not os.path.exists(os.path.join(path, "reference_image.jpeg"))
    ):
        error = (
            "The selected folder does not contain the required files: "
            "reference_image.jpeg, fixations.csv, gaze.csv and sections.csv"
        )
        logging.error(error)
        raise SystemExit(error)
    return path
def main():
    parser = argparse.ArgumentParser(description="Pupil Labs - AOI Annotation")
    parser.add_argument("--metric", default=MetricClasses.all, type=str)
    parser.add_argument("--input_path", default=None, type=str)
    parser.add_argument("--output_path", default=None, type=str)
    parser.add_argument("--aois", default=None, type=str)
    parser.add_argument("--start", default="recording.begin", type=str)
    parser.add_argument("--end", default="recording.end", type=str)
    parser.add_argument("--type", default=DataClasses.fixations, type=str)
    parser.add_argument("-s", "--scatter", action="store_true")
    parser.set_defaults(scatter=False)
    args = parser.parse_args()

    # Report the selected arguments
    logging.info("args: %s", args)
    # If the metric was passed as a string, convert it to its enum value
    if isinstance(args.metric, str):
        args.metric = MetricClasses[args.metric]
    logging.info("metric: %s", args.metric)
    # Same for the data type (fixations or gaze)
    if isinstance(args.type, str):
        args.type = DataClasses[args.type]
    logging.info("type: %s", args.type)
    # If the reference image folder path is not provided or does not exist, ask the user to select one
    if args.input_path is None or not os.path.exists(args.input_path):
        args.input_path = get_path()
    # If the output path is not provided or does not exist, use the input path
    if args.output_path is None or not os.path.exists(args.output_path):
        args.output_path = args.input_path
    logging.info("Input path: %s", args.input_path)

    # Load the reference image (OpenCV reads it in BGR order)
    reference_image_bgr = cv2.imread(
        os.path.join(args.input_path, "reference_image.jpeg")
    )
    # Convert the image from BGR to RGB for plotting with matplotlib
    reference_image = cv2.cvtColor(reference_image_bgr, cv2.COLOR_BGR2RGB)
    # If the AOIs are not provided, ask the user to select them
    if args.aois is not None:
        # Load the AOIs from a csv file
        aois = pd.read_csv(args.aois)
    else:
        # If there is an aoi_ids.csv file in the input path, use it
        if os.path.exists(os.path.join(args.input_path, "aoi_ids.csv")):
            logging.info("AOIs already defined")
            args.aois = os.path.join(args.input_path, "aoi_ids.csv")
            aois = pd.read_csv(args.aois)
        else:
            # Downscale the image before labelling AOIs so it stays within the screen boundaries
            scaling_factor = 0.25
            scaled_image = reference_image_bgr.copy()
            scaled_image = cv2.resize(
                scaled_image, dsize=None, fx=scaling_factor, fy=scaling_factor
            )
            # Mark the AOIs
            scaled_aois = cv2.selectROIs("AOI Annotation", scaled_image)
            cv2.destroyAllWindows()
            # Scale the AOI positions back to the original image size
            aois = scaled_aois / scaling_factor
            # Store the AOIs in a pandas DataFrame
            aois = pd.DataFrame(aois, columns=["x", "y", "width", "height"])
            # Save the AOIs to a csv file
            aois.to_csv(os.path.join(args.output_path, "aoi_ids.csv"), index=False)
    logging.info("Areas of interest:")
    logging.info(aois)
    # Plot the AOIs
    plot_color_patches(reference_image, aois, pd.Series(aois.index), plt.gca(), args)
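    # For reference, aoi_ids.csv is a plain CSV with one AOI per row and the columns
    # defined above (x, y, width, height); the values below are purely illustrative:
    #   x,y,width,height
    #   120.0,80.0,400.0,300.0
    #   650.0,90.0,380.0,310.0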
    # Load the sections, fixations and gaze files into pandas DataFrames
    logging.info("Loading files ...")
    sections_df = pd.read_csv(os.path.join(args.input_path, "sections.csv"))
    logging.info(sections_df["start event name"].unique())
    logging.info(sections_df["end event name"].unique())
    fixations_df = pd.read_csv(os.path.join(args.input_path, "fixations.csv"))
    logging.info("A total of %d fixations were found", len(fixations_df))
    gaze_df = pd.read_csv(os.path.join(args.input_path, "gaze.csv"))
    logging.info("A total of %d gaze points were found", len(gaze_df))

    # Use fixations or gaze data, depending on the selected type
    data_df = fixations_df if args.type == DataClasses.fixations else gaze_df
    field_detected = (
        "fixation detected in reference image"
        if args.type == DataClasses.fixations
        else "gaze detected in reference image"
    )
    # Keep only the points that were detected in the reference image
    data = data_df[data_df[field_detected]]
    # Check which AOI each point falls into
    for row in aois.itertuples():
        data_in_aoi = data.loc[
            check_in_rect(data, [row.x, row.y, row.width, row.height], args)
        ]
        data.loc[data_in_aoi.index, "AOI"] = row.Index
    logging.info(
        "A total of %d %s points were detected inside AOIs",
        data["AOI"].notna().sum(),
        args.type,
    )

    # AOIs that have never been gazed at do not show up in the fixations data,
    # so we need to set their hit rate to 0 manually
    hits = data.groupby(["recording id", "AOI"]).size() > 0
    hit_rate = hits.groupby("AOI").sum() / data["recording id"].nunique() * 100
    for aoi_id in range(len(aois)):
        if aoi_id not in hit_rate.index:
            hit_rate.loc[aoi_id] = 0
    hit_rate.sort_index(inplace=True)
    logging.info("Hit rate per AOI:")
    logging.info(hit_rate.head())
    # Compute the time to first contact relative to the start of each section
    sections_df.set_index("section id", inplace=True)
    for section_id, start_time in sections_df["section start time [ns]"].items():
        data_indices = data.loc[data["section id"] == section_id].index
        logging.info(
            "The section {} starts at {} and has {} points".format(
                section_id,
                start_time,
                len(data_indices),
            )
        )
        field_ts = (
            "start timestamp [ns]"
            if args.type == DataClasses.fixations
            else "timestamp [ns]"
        )
        data.loc[data_indices, "aligned timestamp [s]"] = (
            data.loc[data_indices, field_ts] - start_time
        ) / 1e9
    first_contact = data.groupby(["section id", "AOI"])["aligned timestamp [s]"].min()
    first_contact = first_contact.groupby("AOI").mean()
    logging.info(first_contact)

    # Compute the dwell time per AOI (only meaningful for fixations)
    if args.type == DataClasses.fixations:
        dwell_time = data.groupby(["recording id", "AOI"])["duration [ms]"].sum()
        dwell_time = dwell_time.groupby("AOI").mean()
        dwell_time /= 1000  # convert ms to s
        logging.info(dwell_time.head())
    # Plot the results
    if args.metric in (MetricClasses.hit_rate, MetricClasses.all):
        fig, ax = plt.subplots(1, 2, figsize=(18, 6))
        sns.barplot(x=hit_rate.index.to_numpy(dtype=np.int64), y=hit_rate, ax=ax[0])
        ax[0].set_xlabel("AOI ID")
        ax[0].set_ylabel("Hit Rate [%]")
        plot_color_patches(
            reference_image,
            aois,
            hit_rate,
            ax[1],
            args,
            colorbar=True,
            unit_label="Hit Rate [%]",
            data=data,
        )
        fig.suptitle(f"Hit Rate - {args.type}")
    if args.metric in (MetricClasses.first_contact, MetricClasses.all):
        fig, ax = plt.subplots(1, 2, figsize=(18, 6))
        sns.barplot(
            x=first_contact.index.to_numpy(dtype=np.int64), y=first_contact, ax=ax[0]
        )
        ax[0].set_xlabel("AOI ID")
        ax[0].set_ylabel("Time to first contact [s]")
        plot_color_patches(
            reference_image,
            aois,
            first_contact,
            ax[1],
            args,
            colorbar=True,
            unit_label="Time to first contact [s]",
            data=data,
        )
        fig.suptitle(f"First Contact - {args.type}")
    if args.type == DataClasses.fixations:
        if args.metric in (MetricClasses.dwell_time, MetricClasses.all):
            fig, ax = plt.subplots(1, 2, figsize=(18, 6))
            sns.barplot(
                x=dwell_time.index.to_numpy(dtype=np.int64), y=dwell_time, ax=ax[0]
            )
            ax[0].set_xlabel("AOI ID")
            ax[0].set_ylabel("Dwell Time [s]")
            plot_color_patches(
                reference_image,
                aois,
                dwell_time,
                ax[1],
                args,
                colorbar=True,
                unit_label="Dwell Time [s]",
                data=data,
            )
            fig.suptitle(f"Dwell Time - {args.type}")
    else:
        logging.info("Dwell time is only available for fixations data")

    pd.reset_option("mode.chained_assignment")
    logging.info("Done")
def check_in_rect(data, rectangle_coordinates, args):
    # Return a boolean Series marking which rows fall inside the given rectangle
    rect_x, rect_y, rect_width, rect_height = rectangle_coordinates
    if args.type == DataClasses.fixations:
        fieldx = "fixation x [px]"
        fieldy = "fixation y [px]"
    elif args.type == DataClasses.gaze:
        fieldx = "gaze position in reference image x [px]"
        fieldy = "gaze position in reference image y [px]"
    x_hit = data[fieldx].between(rect_x, rect_x + rect_width)
    y_hit = data[fieldy].between(rect_y, rect_y + rect_height)
    return x_hit & y_hit
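# Toy example of the mask returned by check_in_rect (illustrative values only, assuming
# args.type == DataClasses.fixations): for an AOI [x=0, y=0, width=100, height=100],
# a fixation at (50, 60) maps to True and one at (150, 60) maps to False.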
def plot_color_patches(
    image,
    aoi_positions,
    values,
    ax,
    args,
    alpha=0.3,
    colorbar=False,
    unit_label="",
    data=None,
):
    ax.imshow(image)
    ax.axis("off")
    # Normalize the patch values to the [0, 1] range for the colormap
    values_normed = values.astype(np.float32)
    values_normed -= values_normed.min()
    values_normed /= values_normed.max()
    colors = mpl.cm.get_cmap("viridis")
    for aoi_id, aoi_val in values_normed.items():
        aoi_id = int(aoi_id)
        aoi = [
            aoi_positions.x[aoi_id],
            aoi_positions.y[aoi_id],
            aoi_positions.width[aoi_id],
            aoi_positions.height[aoi_id],
        ]
        ax.add_patch(
            patches.Rectangle(
                aoi[:2],
                *aoi[2:],
                alpha=alpha,
                facecolor=colors(aoi_val),
                edgecolor=colors(aoi_val),
                linewidth=5,
            )
        )
        ax.text(aoi[0] + 20, aoi[1] + 120, f"{aoi_id}", color="black")
    if colorbar:
        norm = mpl.colors.Normalize(vmin=values.min(), vmax=values.max())
        cb = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=colors), ax=ax)
        cb.set_label(unit_label)
    if data is not None and args.scatter:
        if args.type == DataClasses.fixations:
            field0 = "fixation detected in reference image"
            field1 = "fixation x [px]"
            field2 = "fixation y [px]"
        elif args.type == DataClasses.gaze:
            field0 = "gaze detected in reference image"
            field1 = "gaze position in reference image x [px]"
            field2 = "gaze position in reference image y [px]"
        data_in = data[data[field0]]
        ax.scatter(data_in[field1], data_in[field2], s=20, color="red", alpha=0.8)
    plt.show()
    return ax
if __name__ == "__main__":
    main()