-
-
Save mohammedkhalilia/2b682b67f33c922ffbeaeb9aa1d7f001 to your computer and use it in GitHub Desktop.
retrieve_toponyms_w_osm.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"name": "retrieve_toponyms_w_osm.ipynb", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/mohammedkhalilia/2b682b67f33c922ffbeaeb9aa1d7f001/retrieve_toponyms_w_osm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install transformers\n", | |
"!pip install datasets" | |
], | |
"metadata": { | |
"id": "R-E-k-8J07TE" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import json\n", | |
"from geopy.geocoders import Nominatim\n", | |
"import pandas as pd\n", | |
"from datasets import load_dataset\n", | |
"import copy\n", | |
"import math" | |
], | |
"metadata": { | |
"id": "m82Rg85sBvqL" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Download the data from Hugging Face\n", | |
"dataset = load_dataset(\"rsuwaileh/IDRISI-DA\")" | |
], | |
"metadata": { | |
"id": "8_vJk1pM_JOo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def retrieve_toponym(location_mention, k=3):\n", | |
" \"\"\"\n", | |
" This function takes the location mention text and passes it through\n", | |
" Open Street Map (OSM) to retrieve locations. We will return the top k toponyms\n", | |
"\n", | |
" Parameters:\n", | |
" location_mention (str): text with the location mention\n", | |
" k (int): number of toponyms to retrieve from OSM\n", | |
"\n", | |
" Returns:\n", | |
" toponyms (list): list of dictionaries of toponyms for the location mention\n", | |
" [{\"toponym_id\": <ID>, \"rank\": <RANK>}, ...]\n", | |
" \"\"\"\n", | |
" # Instantiate Nominatim client\n", | |
" app = Nominatim(timeout=10, user_agent=\"tutorial\")\n", | |
"\n", | |
" toponyms = list()\n", | |
"\n", | |
" # Retrieve list of toponyms from OSM\n", | |
" candidates = app.geocode(location_mention, exactly_one=False)\n", | |
"\n", | |
" if candidates is None or len(candidates) == 0:\n", | |
" return\n", | |
"\n", | |
" # Get the top three candidates\n", | |
" candidates = candidates[:k]\n", | |
"\n", | |
" for i, candidate in enumerate(candidates):\n", | |
" toponym = candidate.raw\n", | |
" if 'osm_type' not in toponym:\n", | |
" continue\n", | |
"\n", | |
" # The toponym_id is the first character of the osm_type concatenated with osm_id\n", | |
" top_id = toponym['osm_type'][0] + str(toponym['osm_id'])\n", | |
" toponym = {\"toponym_id\": top_id, \"rank\": i+1}\n", | |
" toponyms.append(toponym)\n", | |
"\n", | |
" return toponyms\n", | |
"\n", | |
"def flatten_data(data):\n", | |
" \"\"\"\n", | |
" We simplify the data by flattening the location mentions into a list, track\n", | |
" each location mention by its loc_id, which is a concatenation of tweet_id and\n", | |
" location_mention_id.\n", | |
"\n", | |
" Parameters:\n", | |
" data (list): original data from huggingface, which is a list of dictionaries\n", | |
"\n", | |
" Returns:\n", | |
" output (list): flattened data, which is a list of location mention dictionaries\n", | |
" [{\"loc_id\": <ID>, \"toponym_id\": <ID>, \"location_mention\": <TEXT>}, ...]\n", | |
" \"\"\"\n", | |
" output = []\n", | |
" for tweet in data:\n", | |
" for lm in tweet[\"location_mentions\"]:\n", | |
" output.append(\n", | |
" {\n", | |
" \"loc_id\": tweet[\"tweet_id\"]+\"_\"+str(lm[\"location_mention_id\"]),\n", | |
" \"toponym_id\": lm[\"target_toponym\"],\n", | |
" \"location_mention\": lm[\"location_mention\"]\n", | |
" }\n", | |
" )\n", | |
" return output\n", | |
"\n", | |
"def evaluate(gold, pred, rank):\n", | |
" \"\"\"\n", | |
" Mean Reciprocal Rank (MRR)\n", | |
"\n", | |
" Parameters:\n", | |
" gold (list): list of location mentions with ground truth toponym\n", | |
" pred (list): list of location mentions with predicted toponym\n", | |
" rank (int): k at which MRR is computed\n", | |
"\n", | |
" Returns:\n", | |
" mrr (float): MRR metric\n", | |
" \"\"\"\n", | |
" ground_truth = pd.DataFrame.from_dict(gold)\n", | |
" predictions = pd.DataFrame.from_dict(pred)\n", | |
"\n", | |
" # Keep toponyms which rank is less than the target rank\n", | |
" predictions = predictions[predictions['rank'] <= rank]\n", | |
"\n", | |
" hits = pd.merge(ground_truth, predictions,\n", | |
" on=[\"loc_id\", \"toponym_id\"], how=\"left\").fillna(math.inf)\n", | |
"\n", | |
" mrr = (1 / hits.groupby('loc_id')['rank'].min()).mean()\n", | |
" return round(mrr, 4)" | |
], | |
"metadata": { | |
"id": "E4KqrAOw_K5c" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"k = 10\n", | |
"labeled_lms = flatten_data(dataset[\"validation\"])\n", | |
"unlabeled_lms = copy.deepcopy(labeled_lms)\n", | |
"\n", | |
"for lm in unlabeled_lms:\n", | |
" del lm[\"toponym_id\"]" | |
], | |
"metadata": { | |
"id": "eLrGIy_a14C_" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"results = list()\n", | |
"\n", | |
"# For each location mention, retrieve the top k toponyms\n", | |
"# Store the results in a list a long with the loc_id, toponym_id,\n", | |
"# rank and location_mention\n", | |
"for lm in unlabeled_lms:\n", | |
" toponyms = retrieve_toponym(lm[\"location_mention\"], k=k)\n", | |
"\n", | |
" if toponyms:\n", | |
" # Merge the toponym and location mention information so later\n", | |
" # we can compute the MRR metric\n", | |
" toponyms = [{**t, **lm} for t in toponyms]\n", | |
" results += toponyms" | |
], | |
"metadata": { | |
"id": "6N_oP_0zbFZx" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Store retrieved toponyms to JSON file\n", | |
"with open(\"predictions.jsonl\", \"w\", encoding=\"utf-8\", newline=\"\") as fh:\n", | |
" for p in results:\n", | |
" fh.write(json.dumps(p, ensure_ascii=False) + \"\\n\")" | |
], | |
"metadata": { | |
"id": "frfabx9N1siV" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Compute MRR@k\n", | |
"# When target toponym returned by OSM is in the top k results\n", | |
"evaluate(labeled_lms, results, 1)" | |
], | |
"metadata": { | |
"id": "tmDiwCqY34Yz" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment