Created
March 25, 2024 17:58
-
-
Save mikelgg93/6fdf361795f5e40084faf4b346363ebb to your computer and use it in GitHub Desktop.
Add the AOI(s) each fixation / gaze point lands on as a new column in your fixations and gaze CSV files.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
import logging | |
import os | |
import tempfile | |
import zipfile | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
import requests | |
# Base URL of the Pupil Cloud REST API (v2); every request below is built on it.
API_URL = "https://api.cloud.pupil-labs.com/v2"

# Configure the root logger at INFO so the script's progress messages show up.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
def decode_img(img_str: str) -> np.array:
    """Decode a base64-encoded image into a single-channel (grayscale) array.

    Args:
        img_str: The bare base64 payload of the image (no ``data:...,``
            prefix — the caller strips it before calling).

    Returns:
        Whatever ``cv2.imdecode`` produces with ``cv2.IMREAD_GRAYSCALE``;
        OpenCV yields ``None`` when the bytes cannot be decoded.
    """
    # Wrap the decoded bytes in a uint8 buffer without copying, as imdecode expects.
    encoded = np.frombuffer(base64.b64decode(img_str), dtype=np.uint8)
    return cv2.imdecode(encoded, cv2.IMREAD_GRAYSCALE)
def get_scale_factor(original_heigh, original_width, max_size=1024) -> float:
    """Return the factor that maps reference-image pixels onto the AOI masks.

    When the longest side of the reference image exceeds ``max_size`` the
    factor shrinks that side to exactly ``max_size``; otherwise no rescaling
    is needed and ``1.0`` is returned. (Previously ``None`` was returned in
    that case, which crashed the later ``coordinate * scale_factor``.)

    NOTE(review): this presumes the Cloud stores AOI masks downscaled so that
    their longest side is at most ``max_size`` — confirm against the API.

    Args:
        original_heigh: Reference image height in pixels.
        original_width: Reference image width in pixels.
        max_size: Longest side, in pixels, of the downscaled masks.

    Returns:
        ``max_size / max(height, width)`` if that side is larger than
        ``max_size``, else ``1.0``.
    """
    longest_side = max(original_heigh, original_width)
    if longest_side > max_size:
        return max_size / longest_side
    return 1.0
def get_aois(
    workspace_id: str,
    project_id: str,
    enrichment_id: str,
    api_key: str,
    timeout: float = 60.0,
) -> "list | None":
    """Fetch the AOI definitions of an enrichment from Pupil Cloud.

    Args:
        workspace_id: Cloud workspace ID.
        project_id: Cloud project ID.
        enrichment_id: Cloud enrichment ID.
        api_key: Value sent in the ``api-key`` header.
        timeout: Seconds ``requests`` waits when connecting and between
            received bytes before raising a Timeout. Without it a stalled
            connection would block the script forever.

    Returns:
        The ``"result"`` entry of the JSON response (``[]`` if absent), or
        ``None`` (after logging the response body) when the server does not
        answer with HTTP 200. NOTE(review): each entry is expected to carry
        ``"name"`` and ``"mask_image_data_url"`` keys — inferred from how the
        script consumes it, not from the API docs.
    """
    logging.info("Getting AOIs masks from Cloud.")
    response = requests.get(
        f"{API_URL}/workspaces/{workspace_id}/projects/{project_id}/enrichments/{enrichment_id}/aois",
        headers={"api-key": api_key},
        timeout=timeout,
    )
    if response.status_code != 200:
        logging.error(response.text)
        return None
    return response.json().get("result", [])
def get_data(
    workspace_id: str,
    project_id: str,
    enrichment_id: str,
    api_key: str,
    saving_path: str,
    chunk_size: int = 128,
    timeout: float = 60.0,
) -> "tuple[pd.DataFrame, pd.DataFrame, int, int] | None":
    """Download an enrichment export and load its fixations, gaze and image size.

    The export zip is streamed to disk inside a temporary directory created
    under ``saving_path`` (which must exist). ``fixations.csv``, ``gaze.csv``
    and ``reference_image.jpeg`` are extracted from it, and the temporary
    directory is removed again before returning.

    Args:
        workspace_id: Cloud workspace ID.
        project_id: Cloud project ID.
        enrichment_id: Cloud enrichment ID.
        api_key: Value sent in the ``api-key`` header.
        saving_path: Directory (``~`` is expanded) that hosts the temporary
            download directory.
        chunk_size: Bytes per chunk when writing the download to disk.
        timeout: Seconds ``requests`` waits when connecting and between
            received bytes before raising a Timeout.

    Returns:
        ``(fixations_df, gaze_df, height, width)`` where height/width are the
        reference image's pixel size, or ``None`` (after logging the response
        body) when the server does not answer with HTTP 200.

    Raises:
        RuntimeError: If the extracted reference image cannot be decoded.
        KeyError: (from ``zipfile``) if one of the expected files is missing.
    """
    logging.info("Getting fixations data from Cloud.")
    # stream=True makes iter_content really stream to disk instead of first
    # buffering the whole export in memory; the with-block closes the connection.
    with requests.get(
        f"{API_URL}/workspaces/{workspace_id}/projects/{project_id}/enrichments/{enrichment_id}/export",
        headers={"api-key": api_key},
        stream=True,
        timeout=timeout,
    ) as response:
        if response.status_code != 200:
            logging.error(response.text)
            return None
        with tempfile.TemporaryDirectory(
            dir=os.path.expanduser(saving_path)
        ) as tmp_dir:
            zip_path = os.path.join(tmp_dir, "file.zip")
            with open(zip_path, "wb") as tmp_zip_file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    tmp_zip_file.write(chunk)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                for member in ("fixations.csv", "gaze.csv", "reference_image.jpeg"):
                    zip_ref.extract(member, path=tmp_dir)
            ref_img_path = os.path.join(tmp_dir, "reference_image.jpeg")
            ref_img = cv2.imread(ref_img_path)
            # cv2.imread signals failure by returning None instead of raising.
            if ref_img is None:
                raise RuntimeError(f"Could not decode reference image {ref_img_path}")
            height, width = ref_img.shape[:2]
            return (
                pd.read_csv(os.path.join(tmp_dir, "fixations.csv")),
                pd.read_csv(os.path.join(tmp_dir, "gaze.csv")),
                height,
                width,
            )
def find_aoi_save(
    df: pd.DataFrame,
    x_name: str,
    y_name: str,
    filename: str,
    *,
    column: str = "fixated_aoi",
    scale_factor: "float | None" = None,
    aoi_masks: "list | None" = None,
    aoi_names: "list | None" = None,
    download_path=None,
) -> None:
    """Label every row of ``df`` with the AOIs its point lands on, then save it as CSV.

    Each (x, y) pair is multiplied by ``scale_factor``, truncated toward zero
    to integer pixel indices, and looked up in every AOI mask as
    ``mask[y, x]``. Rows are y (bounded by ``mask.shape[0]``) and columns are
    x (bounded by ``mask.shape[1]``). The label column then holds:

    * ``"A, B"`` - comma-separated names of all AOIs containing the point;
    * ``"None"`` - valid coordinates that hit no AOI (or fall outside the masks);
    * ``None``   - (an empty CSV cell) rows whose x or y value is missing.

    Args:
        df: Table to label; the label column is added to it in place.
        x_name: Column holding the x coordinate in reference-image pixels.
        y_name: Column holding the y coordinate in reference-image pixels.
        filename: CSV file name written inside ``download_path``.
        column: Name of the label column to write.
        scale_factor: Reference-image -> mask scaling. If omitted, the
            module-level ``scale_factor`` is used. A value that is still
            ``None`` is treated as 1.0 (no rescaling).
        aoi_masks: 2-D boolean masks, one per AOI (module-level ``aoi_masks``
            if omitted).
        aoi_names: AOI names parallel to ``aoi_masks`` (module-level
            ``aoi_names`` if omitted).
        download_path: Output directory (module-level ``download_path`` if
            omitted).

    Raises:
        NameError: If an omitted argument has no module-level fallback.
    """

    def _fallback(value, name):
        # Keep the original call style working: the script's __main__ section
        # defines these names at module level.
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError(
                f"{name!r} was not passed and is not defined at module level"
            ) from None

    if scale_factor is None:
        scale_factor = globals().get("scale_factor")
    if scale_factor is None:
        scale_factor = 1.0
    aoi_masks = _fallback(aoi_masks, "aoi_masks")
    aoi_names = _fallback(aoi_names, "aoi_names")
    download_path = _fallback(download_path, "download_path")

    n_rows = len(df)
    xs = df[x_name].to_numpy(dtype=float, na_value=np.nan)
    ys = df[y_name].to_numpy(dtype=float, na_value=np.nan)
    valid = ~(np.isnan(xs) | np.isnan(ys))

    # Integer pixel indices; astype truncates toward zero exactly like int().
    px = np.zeros(n_rows, dtype=np.int64)
    py = np.zeros(n_rows, dtype=np.int64)
    px[valid] = (xs[valid] * scale_factor).astype(np.int64)
    py[valid] = (ys[valid] * scale_factor).astype(np.int64)

    hits = [[] for _ in range(n_rows)]
    for mask, aoi_name in zip(aoi_masks, aoi_names):
        mask = np.asarray(mask)
        mask_height, mask_width = mask.shape[0], mask.shape[1]
        in_bounds = (
            valid
            & (px >= 0) & (px < mask_width)
            & (py >= 0) & (py < mask_height)
        )
        candidates = np.flatnonzero(in_bounds)
        on_aoi = candidates[mask[py[candidates], px[candidates]].astype(bool)]
        for row_pos in on_aoi:
            hits[row_pos].append(aoi_name)

    # Assign positionally in one go: avoids per-cell .at writes, which fail
    # when storing a list into a column that does not exist yet.
    df[column] = [
        (", ".join(names) if names else "None") if is_valid else None
        for names, is_valid in zip(hits, valid)
    ]
    logging.info(
        "%d of %d rows (%d with coordinates) fall on at least one AOI.",
        sum(1 for names in hits if names),
        n_rows,
        int(valid.sum()),
    )
    df.to_csv(os.path.join(download_path, filename), index=False)
if __name__ == "__main__":
    # Script entry point: download an enrichment export, fetch its AOI masks
    # and write fixations_on_aois.csv / gaze_on_aois.csv to --download_path.
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Add the AOI hit by each fixation / gaze point as a column of the "
            "enrichment's fixations and gaze CSV files."
        )
    )
    parser.add_argument(
        "--enrichment_url",
        type=str,
        default=None,
        help="The URL of your enrichment",
    )
    parser.add_argument(
        "--download_path",
        type=str,
        default=Path("~/Desktop/"),
        help="Where to download the results download_path",
    )
    parser.add_argument(
        "--API_KEY",
        type=str,
        default=None,
        help="The API key to use",
    )
    args = parser.parse_args()

    API_KEY = args.API_KEY
    if API_KEY is None:
        raise Exception("--API_KEY is None. Please provide an API_KEY")
    if args.enrichment_url is None:
        raise Exception("--enrichment_url is None. Please provide a enrichment url")
    # Check the raw argument: after Path() the value can never be None.
    if args.download_path is None:
        raise Exception("--download_path is None. Please provide a download_path")
    # Expand "~" once and make sure the directory exists, because the temporary
    # download directory is created inside it and the CSVs are written there.
    download_path = Path(args.download_path).expanduser()
    download_path.mkdir(parents=True, exist_ok=True)

    # The IDs sit at fixed "/"-separated positions of the enrichment URL.
    url_parts = args.enrichment_url.split("/")
    if len(url_parts) < 9:
        raise Exception(
            "Could not parse workspace/project/enrichment IDs from "
            f"--enrichment_url {args.enrichment_url!r}"
        )
    workspace_id, project_id, enrichment_id = url_parts[4], url_parts[6], url_parts[8]

    data = get_data(
        workspace_id=workspace_id,
        project_id=project_id,
        enrichment_id=enrichment_id,
        saving_path=download_path,
        api_key=API_KEY,
    )
    if data is None:
        raise Exception("Failed to download the enrichment export from Pupil Cloud.")
    fixation_df, gaze_df, height, width = data

    scale_factor = get_scale_factor(height, width)
    if scale_factor is None:
        # Treat a missing factor as "no rescaling" rather than crashing later
        # on coordinate * None.
        scale_factor = 1.0

    aois = get_aois(
        workspace_id=workspace_id,
        project_id=project_id,
        enrichment_id=enrichment_id,
        api_key=API_KEY,
    )
    if aois is None:
        raise Exception("Failed to fetch the AOI masks from Pupil Cloud.")
    aoi_names = [aoi["name"] for aoi in aois]
    aoi_masks = []
    for aoi in aois:
        # The mask comes as a data URL; the base64 payload follows the comma.
        image_data = decode_img(aoi["mask_image_data_url"].split(",")[1])
        aoi_masks.append(image_data > 0)

    fixation_df["fixated_aoi"] = None
    # NOTE(review): find_aoi_save writes its labels to "fixated_aoi" by
    # default, so this column stays empty in gaze_on_aois.csv.
    gaze_df["gazed_aoi"] = None

    logging.info("Checking if fixation point is on any AOI...")
    find_aoi_save(
        df=fixation_df,
        x_name="fixation x [px]",
        y_name="fixation y [px]",
        filename="fixations_on_aois.csv",
    )
    logging.info("Checking if gaze point is on any AOI...")
    find_aoi_save(
        df=gaze_df,
        x_name="gaze position in reference image x [px]",
        y_name="gaze position in reference image y [px]",
        filename="gaze_on_aois.csv",
    )
    logging.info(f"Done!, you can find it at {download_path}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
certifi==2024.2.2
charset-normalizer==3.3.2
idna==3.6
numpy==1.26.4
opencv-python-headless==4.9.0.80
pandas==2.2.1
pip==24.0
python-dateutil==2.9.0.post0
pytz==2024.1
requests==2.31.0
setuptools==69.2.0
six==1.16.0
tzdata==2024.1
urllib3==2.2.1
wheel==0.43.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment