-
-
Save simonw/281eac9c73b062c3469607ad86470eb2 to your computer and use it in GitHub Desktop.
Temporary table plugin, refs https://github.com/simonw/datasette/issues/878
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datasette.database import QueryInterrupted | |
from datasette.utils import ( | |
append_querystring, | |
escape_sqlite, | |
CustomJSONEncoder, | |
to_css_class, | |
path_from_row_pks, | |
await_me_maybe, | |
is_url, | |
path_with_replaced_args, | |
path_with_removed_args, | |
) | |
from datasette.utils.asgi import Response, NotFound, Forbidden | |
from datasette.views.base import DatasetteError | |
from datasette import hookimpl | |
from asyncinject import AsyncInject, inject | |
from pprint import pformat | |
import json | |
from datasette.utils import sqlite3 | |
from datasette.plugins import pm | |
import base64 | |
import markupsafe | |
import urllib | |
import pint | |
# Unit registry used to attach/format units declared in table metadata ("units").
ureg = pint.UnitRegistry()

# HTML templates for expanded foreign-key cells.
# LINK_WITH_LABEL is used when the foreign row has a label distinct from its id;
# LINK_WITH_VALUE when the label and the id are the same.
LINK_WITH_LABEL = (
    '<a href="{base_url}{database}/{table}/{link_id}">{label}</a> <em>{id}</em>'
)
LINK_WITH_VALUE = '<a href="{base_url}{database}/{table}/{link_id}">{id}</a>'
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that understands sqlite3 rows/cursors and binary blobs.

    NOTE(review): this deliberately shadows the CustomJSONEncoder imported
    from datasette.utils above — confirm that is intentional.
    """

    def default(self, obj):
        # sqlite3.Row is sequence-like: serialize as a plain tuple
        if isinstance(obj, sqlite3.Row):
            return tuple(obj)
        # A cursor yields rows: materialize them
        if isinstance(obj, sqlite3.Cursor):
            return list(obj)
        if not isinstance(obj, bytes):
            # Anything else falls back to repr() instead of raising TypeError
            return repr(obj)
        # Binary data: emit as text when it is valid UTF-8, otherwise as a
        # tagged base64 object so callers can recover the original bytes.
        try:
            return obj.decode("utf8")
        except UnicodeDecodeError:
            encoded = base64.b64encode(obj).decode("latin1")
            return {"$base64": True, "encoded": encoded}
class Table(AsyncInject):
    """Experimental table view built on asyncinject dependency injection.

    Each @inject method declares its dependencies as parameters; asyncinject
    resolves them (request, datasette, database, table_and_format) before
    calling main().  Refs https://github.com/simonw/datasette/issues/878
    """

    view_name = "table"

    @inject
    async def database(self, request, datasette):
        """Resolve the ``db_name`` URL variable to a Database object.

        Raises NotFound (404) when no database of that name is attached.
        """
        # TODO: all that nasty hash resolving stuff can go here
        db_name = request.url_vars["db_name"]
        try:
            db = datasette.databases[db_name]
        except KeyError:
            raise NotFound(f"Database '{db_name}' does not exist")
        return db

    @inject
    async def table_and_format(self, request, database, datasette):
        """Split the ``table_and_format`` URL variable into (table, format).

        Format defaults to "html" when there is no "." in the variable.
        """
        table_and_format = request.url_vars["table_and_format"]
        # TODO: be a lot smarter here
        if "." in table_and_format:
            # maxsplit=1 guarantees exactly two parts.  The previous
            # maxsplit=2 could yield THREE items for names like "a.b.json",
            # which crashed the two-item unpack in main().
            return tuple(table_and_format.split(".", 1))
        else:
            return table_and_format, "html"

    @inject
    async def main(self, request, database, table_and_format, datasette):
        """Render the table page as HTML, or its JSON debug representation."""
        # TODO: if this is actually a canned query, dispatch to it
        table, fmt = table_and_format
        is_view = bool(await database.get_view_definition(table))
        table_exists = bool(await database.table_exists(table))
        if not is_view and not table_exists:
            raise NotFound(f"Table not found: {table}")
        # Most specific permission first; the first explicit allow/deny wins
        await check_permissions(
            datasette,
            request,
            [
                ("view-table", (database.name, table)),
                ("view-database", database.name),
                "view-instance",
            ],
        )
        # "private" = an anonymous actor would NOT be allowed to see this table
        private = not await datasette.permission_allowed(
            None, "view-table", (database.name, table), default=True
        )
        pks = await database.primary_keys(table)
        table_columns = await database.table_columns(table)
        specified_columns = await columns_to_select(datasette, database, table, request)
        select_specified_columns = ", ".join(
            escape_sqlite(t) for t in specified_columns
        )
        select_all_columns = ", ".join(escape_sqlite(t) for t in table_columns)
        # Tables without a primary key are addressed by SQLite's rowid instead
        use_rowid = not pks and not is_view
        if use_rowid:
            select_specified_columns = f"rowid, {select_specified_columns}"
            select_all_columns = f"rowid, {select_all_columns}"
            order_by = "rowid"
            order_by_pks = "rowid"
        else:
            order_by_pks = ", ".join([escape_sqlite(pk) for pk in pks])
            order_by = order_by_pks
        if is_view:
            # Views have no stable rowid/pk ordering
            order_by = ""
        # Placeholders for the not-yet-ported facet/count machinery; kept so
        # they show up in the locals() JSON debug dump below.
        nocount = request.args.get("_nocount")
        nofacet = request.args.get("_nofacet")
        if request.args.get("_shape") in ("array", "object"):
            nocount = True
            nofacet = True
        # Next, a TON of SQL to build where_params and filters and suchlike
        # skipping that and jumping straight to...
        where_clauses = []
        where_clause = ""
        if where_clauses:
            where_clause = f"where {' and '.join(where_clauses)} "
        from_sql = "from {table_name} {where}".format(
            table_name=escape_sqlite(table),
            where=("where {} ".format(" and ".join(where_clauses)))
            if where_clauses
            else "",
        )
        from_sql_params = {}
        params = {}
        count_sql = f"select count(*) {from_sql}"
        sql_no_order_no_limit = (
            "select {select_all_columns} from {table_name} {where}".format(
                select_all_columns=select_all_columns,
                table_name=escape_sqlite(table),
                where=where_clause,
            )
        )
        page_size = 100
        offset = " offset 0"
        # Fetch page_size + 1 rows so we can detect whether a next page exists.
        # The order_by fragment needs the "order by " keyword prefix — without
        # it SQLite parses the column list as a table alias and the rows come
        # back unordered (previous code interpolated the bare column names).
        sql = "select {select_specified_columns} from {table_name} {where}{order_by} limit {page_size}{offset}".format(
            select_specified_columns=select_specified_columns,
            table_name=escape_sqlite(table),
            where=where_clause,
            order_by=f"order by {order_by} " if order_by else "",
            page_size=page_size + 1,
            offset=offset,
        )
        # Fetch rows
        results = await database.execute(sql, params, truncate=True)
        columns = [r[0] for r in results.description]
        rows = list(results.rows)
        # Fetch count — tolerate interruption (count can time out on big tables)
        filtered_table_rows_count = None
        if count_sql:
            try:
                count_rows = list(await database.execute(count_sql, from_sql_params))
                filtered_table_rows_count = count_rows[0][0]
            except QueryInterrupted:
                pass

        # Stub standing in for the real Filters machinery (not ported yet)
        class Filters:
            def lookups(self):
                return []

            def selections(self):
                return []

        display_columns, display_rows = await self.display_columns_and_rows(
            datasette,
            database.name,
            table,
            results.description,
            rows,
            link_column=not is_view,
            truncate_cells=datasette.setting("truncate_cells_html"),
        )
        template_vars = {
            "json": {
                # THIS STUFF is from the regular JSON
                "database": database.name,
                "table": table,
                "is_view": is_view,
                # "human_description_en": human_description_en,
                "rows": rows[:page_size],
                "truncated": results.truncated,
                "filtered_table_rows_count": filtered_table_rows_count,
                # "expanded_columns": expanded_columns,
                # "expandable_columns": expandable_columns,
                "columns": columns,
                "primary_keys": pks,
                # "units": units,
                "query": {"sql": sql, "params": params},
                # "facet_results": facet_results,
                # "suggested_facets": suggested_facets,
                # "next": next_value and str(next_value) or None,
                # "next_url": next_url,
                "private": private,
                # NOTE(review): the other permission checks pass database.name
                # as the resource — this passes the Database object; confirm.
                "allow_execute_sql": await datasette.permission_allowed(
                    request.actor, "execute-sql", database, default=True
                ),
            },
            "html": {
                # ... this is the HTML special stuff
                "table_actions": lambda: [],  # table_actions,
                # "supports_search": bool(fts_table),
                # "search": search or "",
                "use_rowid": use_rowid,
                "filters": Filters(),
                "display_columns": display_columns,
                # "filter_columns": filter_columns,
                "display_rows": display_rows,
                # "facets_timed_out": facets_timed_out,
                # "sorted_facet_results": sorted(
                #     facet_results.values(),
                #     key=lambda f: (len(f["results"]), f["name"]),
                #     reverse=True,
                # ),
                # "show_facet_counts": special_args.get("_facet_size") == "max",
                # "extra_wheres_for_ui": extra_wheres_for_ui,
                # "form_hidden_args": form_hidden_args,
                # "is_sortable": any(c["sortable"] for c in display_columns),
                "path_with_replaced_args": path_with_replaced_args,
                "path_with_removed_args": path_with_removed_args,
                "append_querystring": append_querystring,
                "request": request,
                # "sort": sort,
                # "sort_desc": sort_desc,
                "disable_sort": is_view,
                "custom_table_templates": [
                    f"_table-{to_css_class(database.name)}-{to_css_class(table)}.html",
                    f"_table-table-{to_css_class(database.name)}-{to_css_class(table)}.html",
                    "_table.html",
                ],
                "metadata": {},  # metadata,
                # "view_definition": await db.get_view_definition(table),
                # "table_definition": await db.get_table_definition(table),
                # And extra stuff from BaseView
                "renderers": {},
            },
        }
        # I'm just trying to get HTML to work for the moment
        if fmt == "json":
            # Debug output: dumps locals() too so everything is inspectable
            return Response.json(
                dict(template_vars, locals=locals()),
                default=CustomJSONEncoder().default,
            )
        else:
            context = template_vars["json"]
            context.update(template_vars["html"])
            return await self.render_html(datasette, ["table.html"], request, context)

    async def display_columns_and_rows(
        self,
        datasette,
        database,
        table,
        description,
        rows,
        link_column=False,
        truncate_cells=0,
    ):
        """Returns columns, rows for specified table - including fancy foreign key treatment

        ``columns`` is a list of dicts (name/sortable/is_pk/type/notnull/description);
        ``rows`` is a list of Row objects whose cells carry both the raw and the
        rendered (HTML-safe) value for each column.
        """
        db = datasette.databases[database]
        table_metadata = datasette.table_metadata(database, table)
        column_descriptions = table_metadata.get("columns") or {}
        column_details = {col.name: col for col in await db.table_column_details(table)}
        sortable_columns = await sortable_columns_for_table(
            datasette, database, table, True
        )
        pks = await db.primary_keys(table)
        pks_for_display = pks
        if not pks_for_display:
            pks_for_display = ["rowid"]
        columns = []
        for r in description:
            # rowid is implicit so it has no entry in column_details
            if r[0] == "rowid" and "rowid" not in column_details:
                type_ = "integer"
                notnull = 0
            else:
                type_ = column_details[r[0]].type
                notnull = column_details[r[0]].notnull
            columns.append(
                {
                    "name": r[0],
                    "sortable": r[0] in sortable_columns,
                    "is_pk": r[0] in pks_for_display,
                    "type": type_,
                    "notnull": notnull,
                    "description": column_descriptions.get(r[0]),
                }
            )
        column_to_foreign_key_table = {
            fk["column"]: fk["other_table"]
            for fk in await db.foreign_keys_for_table(table)
        }
        cell_rows = []
        base_url = datasette.setting("base_url")
        for row in rows:
            cells = []
            # Unless we are a view, the first column is a link - either to the rowid
            # or to the simple or compound primary key
            if link_column:
                is_special_link_column = len(pks) != 1
                pk_path = path_from_row_pks(row, pks, not pks, False)
                cells.append(
                    {
                        "column": pks[0] if len(pks) == 1 else "Link",
                        "value_type": "pk",
                        "is_special_link_column": is_special_link_column,
                        "raw": pk_path,
                        "value": markupsafe.Markup(
                            '<a href="{base_url}{database}/{table}/{flat_pks_quoted}">{flat_pks}</a>'.format(
                                base_url=base_url,
                                database=database,
                                table=urllib.parse.quote_plus(table),
                                flat_pks=str(markupsafe.escape(pk_path)),
                                flat_pks_quoted=path_from_row_pks(row, pks, not pks),
                            )
                        ),
                    }
                )
            for value, column_dict in zip(row, columns):
                column = column_dict["name"]
                if link_column and len(pks) == 1 and column == pks[0]:
                    # If there's a simple primary key, don't repeat the value as it's
                    # already shown in the link column.
                    continue
                # First let the plugins have a go
                # pylint: disable=no-member
                plugin_display_value = None
                for candidate in pm.hook.render_cell(
                    value=value,
                    column=column,
                    table=table,
                    database=database,
                    datasette=datasette,
                ):
                    candidate = await await_me_maybe(candidate)
                    if candidate is not None:
                        # First plugin that returns a non-None value wins
                        plugin_display_value = candidate
                        break
                if plugin_display_value:
                    display_value = plugin_display_value
                elif isinstance(value, bytes):
                    # Binary cells become a download link for the blob
                    display_value = markupsafe.Markup(
                        '<a class="blob-download" href="{}"><Binary: {} byte{}></a>'.format(
                            datasette.urls.row_blob(
                                database,
                                table,
                                path_from_row_pks(row, pks, not pks),
                                column,
                            ),
                            len(value),
                            "" if len(value) == 1 else "s",
                        )
                    )
                elif isinstance(value, dict):
                    # It's an expanded foreign key - display link to other row
                    label = value["label"]
                    value = value["value"]
                    # The table we link to depends on the column
                    other_table = column_to_foreign_key_table[column]
                    link_template = (
                        LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE
                    )
                    display_value = markupsafe.Markup(
                        link_template.format(
                            database=database,
                            base_url=base_url,
                            table=urllib.parse.quote_plus(other_table),
                            link_id=urllib.parse.quote_plus(str(value)),
                            id=str(markupsafe.escape(value)),
                            label=str(markupsafe.escape(label)) or "-",
                        )
                    )
                elif value in ("", None):
                    # NOTE(review): upstream Datasette renders "&nbsp;" here;
                    # this literal space may be a lost non-breaking space - confirm
                    display_value = markupsafe.Markup(" ")
                elif is_url(str(value).strip()):
                    # Auto-link URL-shaped values
                    display_value = markupsafe.Markup(
                        '<a href="{url}">{url}</a>'.format(
                            url=markupsafe.escape(value.strip())
                        )
                    )
                elif column in table_metadata.get("units", {}) and value != "":
                    # Interpret units using pint
                    value = value * ureg(table_metadata["units"][column])
                    # Pint uses floating point which sometimes introduces errors in the compact
                    # representation, which we have to round off to avoid ugliness. In the vast
                    # majority of cases this rounding will be inconsequential. I hope.
                    value = round(value.to_compact(), 6)
                    display_value = markupsafe.Markup(
                        f"{value:~P}".replace(" ", " ")
                    )
                else:
                    display_value = str(value)
                    if truncate_cells and len(display_value) > truncate_cells:
                        display_value = display_value[:truncate_cells] + "\u2026"
                cells.append(
                    {
                        "column": column,
                        "value": display_value,
                        "raw": value,
                        "value_type": "none"
                        if value is None
                        else str(type(value).__name__),
                    }
                )
            cell_rows.append(Row(cells))

        if link_column:
            # Add the link column header.
            # If it's a simple primary key, we have to remove and re-add that column name at
            # the beginning of the header row.
            first_column = None
            if len(pks) == 1:
                columns = [col for col in columns if col["name"] != pks[0]]
                first_column = {
                    "name": pks[0],
                    "sortable": len(pks) == 1,
                    "is_pk": True,
                    "type": column_details[pks[0]].type,
                    "notnull": column_details[pks[0]].notnull,
                }
            else:
                first_column = {
                    "name": "Link",
                    "sortable": False,
                    "is_pk": False,
                    "type": "",
                    "notnull": 0,
                }
            columns = [first_column] + columns
        return columns, cell_rows

    async def view(self, request, datasette):
        """Route entry point — kick off the asyncinject resolution via main()."""
        return await self.main(request=request, datasette=datasette)

    async def render_html(self, datasette, templates, request, context=None):
        """Render the first matching template from ``templates`` to a Response."""
        context = context or {}
        template = datasette.jinja_env.select_template(templates)
        template_context = {
            **context,
            **{
                "database_color": lambda database: "ff0000",
                # Mark the selected template with a "*" for template debugging
                "select_templates": [
                    f"{'*' if template_name == template.name else ''}{template_name}"
                    for template_name in templates
                ],
            },
        }
        return Response.html(
            await datasette.render_template(
                template, template_context, request=request, view_name=self.view_name
            )
        )
@hookimpl
def register_routes():
    """Register the /t/<db_name>/<table_and_format> route for the Table view."""
    route_pattern = r"/t/(?P<db_name>[^/]+)/(?P<table_and_format>[^/]+?$)"
    return [(route_pattern, Table().view)]
async def check_permissions(datasette, request, permissions):
    """permissions is a list of (action, resource) tuples or 'action' strings

    Checks each permission in order; the first explicit result wins:
    an allow returns immediately, a deny raises Forbidden.  A None result
    (no opinion) falls through to the next permission in the list.
    """
    for permission in permissions:
        # Normalize each entry into an (action, resource) pair
        if isinstance(permission, str):
            action, resource = permission, None
        elif isinstance(permission, (tuple, list)) and len(permission) == 2:
            action, resource = permission
        else:
            assert (
                False
            ), "permission should be string or tuple of two items: {}".format(
                repr(permission)
            )
        outcome = await datasette.permission_allowed(
            request.actor,
            action,
            resource=resource,
            default=None,
        )
        if outcome is None:
            continue
        if not outcome:
            raise Forbidden(action)
        return
async def columns_to_select(datasette, database, table, request):
    """Work out which columns to select, honoring ?_col= and ?_nocol=.

    Defaults to every table column.  ``_col`` restricts the selection to the
    primary keys plus the requested columns (de-duplicated, order preserved);
    ``_nocol`` removes columns.  Invalid column names raise DatasetteError 400.
    """
    all_columns = await database.table_columns(table)
    pks = await database.primary_keys(table)
    selected = list(all_columns)
    if "_col" in request.args:
        requested = request.args.getlist("_col")
        invalid = [name for name in requested if name not in all_columns]
        if invalid:
            raise DatasetteError(
                "_col={} - invalid columns".format(", ".join(invalid)),
                status=400,
            )
        # Primary keys always come first, then the requested columns
        selected = list(pks)
        # De-duplicate maintaining order:
        selected.extend(dict.fromkeys(requested))
    if "_nocol" in request.args:
        # Return all columns EXCEPT these; removing a pk is not allowed
        excluded = request.args.getlist("_nocol")
        invalid = [
            name
            for name in excluded
            if (name not in all_columns) or (name in pks)
        ]
        if invalid:
            raise DatasetteError(
                "_nocol={} - invalid columns".format(", ".join(invalid)),
                status=400,
            )
        selected = [name for name in selected if name not in excluded]
    return selected
async def sortable_columns_for_table(datasette, database, table, use_rowid):
    """Return the set of column names that may be used for sorting.

    Table metadata can restrict this via "sortable_columns"; otherwise every
    table column is sortable.  ``use_rowid`` additionally allows rowid.
    """
    db = datasette.databases[database]
    metadata = datasette.table_metadata(database, table)
    if "sortable_columns" in metadata:
        sortable = set(metadata["sortable_columns"])
    else:
        sortable = set(await db.table_columns(table))
    if use_rowid:
        sortable.add("rowid")
    return sortable
class Row:
    """Wrapper around one rendered table row.

    Each cell is a dict carrying at least "column", "raw" (the underlying
    value) and "value" (the rendered, HTML-safe display value).  Templates
    can iterate the cells, index by column name for the raw value, or call
    display() for the rendered value.
    """

    def __init__(self, cells):
        # cells: list of cell dicts, in display order
        self.cells = cells

    def __iter__(self):
        return iter(self.cells)

    def __getitem__(self, key):
        """Return the raw value for the named column."""
        for cell in self.cells:
            if cell["column"] == key:
                return cell["raw"]
        # Include the missing column name in the exception (previously a
        # bare ``raise KeyError`` gave no hint which lookup failed)
        raise KeyError(key)

    def display(self, key):
        """Return the rendered display value for the column, or None."""
        for cell in self.cells:
            if cell["column"] == key:
                return cell["value"]
        return None

    def __str__(self):
        # JSON dump of raw values, skipping the synthetic "Link" column
        d = {
            key: self[key]
            for key in [
                c["column"] for c in self.cells if not c.get("is_special_link_column")
            ]
        }
        return json.dumps(d, default=repr, indent=2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment