Skip to content

Instantly share code, notes, and snippets.

@simonw

simonw/table.py Secret

Created November 19, 2021 02:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simonw/281eac9c73b062c3469607ad86470eb2 to your computer and use it in GitHub Desktop.
Temporary table plugin, refs https://github.com/simonw/datasette/issues/878
from datasette.database import QueryInterrupted
from datasette.utils import (
append_querystring,
escape_sqlite,
CustomJSONEncoder,
to_css_class,
path_from_row_pks,
await_me_maybe,
is_url,
path_with_replaced_args,
path_with_removed_args,
)
from datasette.utils.asgi import Response, NotFound, Forbidden
from datasette.views.base import DatasetteError
from datasette import hookimpl
from asyncinject import AsyncInject, inject
from pprint import pformat
import json
from datasette.utils import sqlite3
from datasette.plugins import pm
import base64
import markupsafe
import urllib
import pint
# Shared pint unit registry - used when rendering values from columns that
# declare units in table metadata.
ureg = pint.UnitRegistry()

# HTML templates for rendering an expanded foreign-key cell as a link to the
# referenced row. LINK_WITH_LABEL is used when a human-readable label exists
# and differs from the raw value; LINK_WITH_VALUE when it does not.
LINK_WITH_LABEL = (
    '<a href="{base_url}{database}/{table}/{link_id}">{label}</a>&nbsp;<em>{id}</em>'
)
LINK_WITH_VALUE = '<a href="{base_url}{database}/{table}/{link_id}">{id}</a>'
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder for values that turn up in Datasette query results.

    Handles sqlite3 rows/cursors and raw bytes; anything else falls back to
    its repr() so serialization never raises.

    NOTE: this deliberately shadows the ``CustomJSONEncoder`` name imported
    from ``datasette.utils`` at the top of the file.
    """

    def default(self, obj):
        # SQLite rows serialize as plain tuples of their column values.
        if isinstance(obj, sqlite3.Row):
            return tuple(obj)
        # Cursors serialize as the list of rows they would yield.
        if isinstance(obj, sqlite3.Cursor):
            return list(obj)
        # Everything that is not bytes: fall back to repr() rather than
        # raising TypeError like the base class would.
        if not isinstance(obj, bytes):
            return repr(obj)
        # Bytes: prefer a plain string when they are valid UTF-8 ...
        try:
            return obj.decode("utf8")
        except UnicodeDecodeError:
            # ... otherwise wrap a base64 rendering in a marker object.
            return {
                "$base64": True,
                "encoded": base64.b64encode(obj).decode("latin1"),
            }
class Table(AsyncInject):
    """Experimental table view built on asyncinject dependency resolution.

    Each ``@inject`` method names its dependencies as parameters and
    asyncinject resolves them (potentially concurrently). ``view`` is the
    route entry point; it triggers ``main``, which pulls in ``database``
    and ``table_and_format``.

    Refs https://github.com/simonw/datasette/issues/878 - much of the real
    table view (filters, facets, sorting, pagination) is stubbed out here.
    """

    view_name = "table"

    @inject
    async def database(self, request, datasette):
        """Resolve the Database object named in the URL; 404 if unknown."""
        # TODO: all that nasty hash resolving stuff can go here
        db_name = request.url_vars["db_name"]
        try:
            db = datasette.databases[db_name]
        except KeyError:
            raise NotFound(f"Database '{db_name}' does not exist")
        return db

    @inject
    async def table_and_format(self, request, database, datasette):
        """Split the "table.format" URL component; format defaults to html.

        NOTE(review): ``split(".", 2)`` returns three items for a name with
        two or more dots, which would break the two-item unpack in main() -
        looks like this wants ``rsplit(".", 1)``; confirm before relying on
        dotted table names.
        """
        table_and_format = request.url_vars["table_and_format"]
        # TODO: be a lot smarter here
        if "." in table_and_format:
            return table_and_format.split(".", 2)
        else:
            return table_and_format, "html"

    @inject
    async def main(self, request, database, table_and_format, datasette):
        """Build and execute the table SQL, then render JSON or HTML.

        Returns a Response; raises NotFound for missing tables and (via
        check_permissions) Forbidden when access is denied.
        """
        # TODO: if this is actually a canned query, dispatch to it
        table, format = table_and_format
        is_view = bool(await database.get_view_definition(table))
        table_exists = bool(await database.table_exists(table))
        if not is_view and not table_exists:
            raise NotFound(f"Table not found: {table}")
        # Permission waterfall: most specific (table) checked first.
        await check_permissions(
            datasette,
            request,
            [
                ("view-table", (database.name, table)),
                ("view-database", database.name),
                "view-instance",
            ],
        )
        # "private" means an anonymous actor would NOT be allowed to see this.
        private = not await datasette.permission_allowed(
            None, "view-table", (database.name, table), default=True
        )
        pks = await database.primary_keys(table)
        table_columns = await database.table_columns(table)
        specified_columns = await columns_to_select(datasette, database, table, request)
        select_specified_columns = ", ".join(
            escape_sqlite(t) for t in specified_columns
        )
        select_all_columns = ", ".join(escape_sqlite(t) for t in table_columns)
        # Tables with no explicit primary key fall back to SQLite's rowid.
        use_rowid = not pks and not is_view
        if use_rowid:
            select_specified_columns = f"rowid, {select_specified_columns}"
            select_all_columns = f"rowid, {select_all_columns}"
            order_by = "rowid"
            order_by_pks = "rowid"
        else:
            order_by_pks = ", ".join([escape_sqlite(pk) for pk in pks])
            order_by = order_by_pks
        if is_view:
            order_by = ""
        nocount = request.args.get("_nocount")
        nofacet = request.args.get("_nofacet")
        # Shaped JSON responses never need counts or facets.
        if request.args.get("_shape") in ("array", "object"):
            nocount = True
            nofacet = True
        # Next, a TON of SQL to build where_params and filters and suchlike
        # skipping that and jumping straight to...
        where_clauses = []
        where_clause = ""
        if where_clauses:
            where_clause = f"where {' and '.join(where_clauses)} "
        from_sql = "from {table_name} {where}".format(
            table_name=escape_sqlite(table),
            where=("where {} ".format(" and ".join(where_clauses)))
            if where_clauses
            else "",
        )
        from_sql_params = {}
        params = {}
        count_sql = f"select count(*) {from_sql}"
        sql_no_order_no_limit = (
            "select {select_all_columns} from {table_name} {where}".format(
                select_all_columns=select_all_columns,
                table_name=escape_sqlite(table),
                where=where_clause,
            )
        )
        page_size = 100
        offset = " offset 0"
        # NOTE(review): order_by is interpolated with no "order by" keyword
        # here, so a non-empty order_by produces "... {where}rowid limit" -
        # presumably WIP; confirm against the finished datasette table view.
        # page_size + 1 fetches one extra row to detect a next page.
        sql = "select {select_specified_columns} from {table_name} {where}{order_by} limit {page_size}{offset}".format(
            select_specified_columns=select_specified_columns,
            table_name=escape_sqlite(table),
            where=where_clause,
            order_by=order_by,
            page_size=page_size + 1,
            offset=offset,
        )
        # Fetch rows
        results = await database.execute(sql, params, truncate=True)
        columns = [r[0] for r in results.description]
        rows = list(results.rows)
        # Fetch count
        filtered_table_rows_count = None
        if count_sql:
            try:
                count_rows = list(await database.execute(count_sql, from_sql_params))
                filtered_table_rows_count = count_rows[0][0]
            except QueryInterrupted:
                # Counting timed out - render without a row count.
                pass

        # Stub replacement for the real Filters machinery the template expects.
        class Filters:
            def lookups(self):
                return []

            def selections(self):
                return []

        display_columns, display_rows = await self.display_columns_and_rows(
            datasette,
            database.name,
            table,
            results.description,
            rows,
            link_column=not is_view,
            truncate_cells=datasette.setting("truncate_cells_html"),
        )
        # Two context layers: "json" is the API payload, "html" adds the
        # template-only extras. Commented-out keys are the not-yet-ported
        # parts of the real table view.
        vars = {
            "json": {
                # THIS STUFF is from the regular JSON
                "database": database.name,
                "table": table,
                "is_view": is_view,
                # "human_description_en": human_description_en,
                "rows": rows[:page_size],
                "truncated": results.truncated,
                "filtered_table_rows_count": filtered_table_rows_count,
                # "expanded_columns": expanded_columns,
                # "expandable_columns": expandable_columns,
                "columns": columns,
                "primary_keys": pks,
                # "units": units,
                "query": {"sql": sql, "params": params},
                # "facet_results": facet_results,
                # "suggested_facets": suggested_facets,
                # "next": next_value and str(next_value) or None,
                # "next_url": next_url,
                "private": private,
                "allow_execute_sql": await datasette.permission_allowed(
                    request.actor, "execute-sql", database, default=True
                ),
            },
            "html": {
                # ... this is the HTML special stuff
                "table_actions": lambda: [],  # table_actions,
                # "supports_search": bool(fts_table),
                # "search": search or "",
                "use_rowid": use_rowid,
                "filters": Filters(),
                "display_columns": display_columns,
                # "filter_columns": filter_columns,
                "display_rows": display_rows,
                # "facets_timed_out": facets_timed_out,
                # "sorted_facet_results": sorted(
                #     facet_results.values(),
                #     key=lambda f: (len(f["results"]), f["name"]),
                #     reverse=True,
                # ),
                # "show_facet_counts": special_args.get("_facet_size") == "max",
                # "extra_wheres_for_ui": extra_wheres_for_ui,
                # "form_hidden_args": form_hidden_args,
                # "is_sortable": any(c["sortable"] for c in display_columns),
                "path_with_replaced_args": path_with_replaced_args,
                "path_with_removed_args": path_with_removed_args,
                "append_querystring": append_querystring,
                "request": request,
                # "sort": sort,
                # "sort_desc": sort_desc,
                "disable_sort": is_view,
                "custom_table_templates": [
                    f"_table-{to_css_class(database.name)}-{to_css_class(table)}.html",
                    f"_table-table-{to_css_class(database.name)}-{to_css_class(table)}.html",
                    "_table.html",
                ],
                "metadata": {},  # metadata,
                # "view_definition": await db.get_view_definition(table),
                # "table_definition": await db.get_table_definition(table),
                # And extra stuff from BaseView
                "renderers": {},
            },
        }
        # I'm just trying to get HTML to work for the moment
        if format == "json":
            # Debug-style JSON dump: includes locals(), serialized leniently.
            return Response.json(
                dict(vars, locals=locals()), default=CustomJSONEncoder().default
            )
        else:
            context = vars["json"]
            context.update(vars["html"])
            return await self.render_html(datasette, ["table.html"], request, context)

    async def display_columns_and_rows(
        self,
        datasette,
        database,
        table,
        description,
        rows,
        link_column=False,
        truncate_cells=0,
    ):
        """Returns columns, rows for specified table - including fancy foreign key treatment"""
        db = datasette.databases[database]
        table_metadata = datasette.table_metadata(database, table)
        column_descriptions = table_metadata.get("columns") or {}
        column_details = {col.name: col for col in await db.table_column_details(table)}
        sortable_columns = await sortable_columns_for_table(
            datasette, database, table, True
        )
        pks = await db.primary_keys(table)
        pks_for_display = pks
        if not pks_for_display:
            pks_for_display = ["rowid"]
        # Build the column header dicts the template iterates over.
        columns = []
        for r in description:
            if r[0] == "rowid" and "rowid" not in column_details:
                # Implicit rowid column - not in table_info, so fake its type.
                type_ = "integer"
                notnull = 0
            else:
                type_ = column_details[r[0]].type
                notnull = column_details[r[0]].notnull
            columns.append(
                {
                    "name": r[0],
                    "sortable": r[0] in sortable_columns,
                    "is_pk": r[0] in pks_for_display,
                    "type": type_,
                    "notnull": notnull,
                    "description": column_descriptions.get(r[0]),
                }
            )
        column_to_foreign_key_table = {
            fk["column"]: fk["other_table"]
            for fk in await db.foreign_keys_for_table(table)
        }
        cell_rows = []
        base_url = datasette.setting("base_url")
        for row in rows:
            cells = []
            # Unless we are a view, the first column is a link - either to the rowid
            # or to the simple or compound primary key
            if link_column:
                is_special_link_column = len(pks) != 1
                pk_path = path_from_row_pks(row, pks, not pks, False)
                cells.append(
                    {
                        "column": pks[0] if len(pks) == 1 else "Link",
                        "value_type": "pk",
                        "is_special_link_column": is_special_link_column,
                        "raw": pk_path,
                        "value": markupsafe.Markup(
                            '<a href="{base_url}{database}/{table}/{flat_pks_quoted}">{flat_pks}</a>'.format(
                                base_url=base_url,
                                database=database,
                                table=urllib.parse.quote_plus(table),
                                flat_pks=str(markupsafe.escape(pk_path)),
                                flat_pks_quoted=path_from_row_pks(row, pks, not pks),
                            )
                        ),
                    }
                )
            for value, column_dict in zip(row, columns):
                column = column_dict["name"]
                if link_column and len(pks) == 1 and column == pks[0]:
                    # If there's a simple primary key, don't repeat the value as it's
                    # already shown in the link column.
                    continue
                # First let the plugins have a go
                # pylint: disable=no-member
                plugin_display_value = None
                for candidate in pm.hook.render_cell(
                    value=value,
                    column=column,
                    table=table,
                    database=database,
                    datasette=datasette,
                ):
                    # render_cell hooks may be async; first non-None wins.
                    candidate = await await_me_maybe(candidate)
                    if candidate is not None:
                        plugin_display_value = candidate
                        break
                if plugin_display_value:
                    display_value = plugin_display_value
                elif isinstance(value, bytes):
                    # Binary values render as a download link, not inline.
                    display_value = markupsafe.Markup(
                        '<a class="blob-download" href="{}">&lt;Binary:&nbsp;{}&nbsp;byte{}&gt;</a>'.format(
                            datasette.urls.row_blob(
                                database,
                                table,
                                path_from_row_pks(row, pks, not pks),
                                column,
                            ),
                            len(value),
                            "" if len(value) == 1 else "s",
                        )
                    )
                elif isinstance(value, dict):
                    # It's an expanded foreign key - display link to other row
                    label = value["label"]
                    value = value["value"]
                    # The table we link to depends on the column
                    other_table = column_to_foreign_key_table[column]
                    link_template = (
                        LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE
                    )
                    display_value = markupsafe.Markup(
                        link_template.format(
                            database=database,
                            base_url=base_url,
                            table=urllib.parse.quote_plus(other_table),
                            link_id=urllib.parse.quote_plus(str(value)),
                            id=str(markupsafe.escape(value)),
                            label=str(markupsafe.escape(label)) or "-",
                        )
                    )
                elif value in ("", None):
                    # Render empty/null as a non-breaking space so the table
                    # cell keeps its height.
                    display_value = markupsafe.Markup("&nbsp;")
                elif is_url(str(value).strip()):
                    display_value = markupsafe.Markup(
                        '<a href="{url}">{url}</a>'.format(
                            url=markupsafe.escape(value.strip())
                        )
                    )
                elif column in table_metadata.get("units", {}) and value != "":
                    # Interpret units using pint
                    value = value * ureg(table_metadata["units"][column])
                    # Pint uses floating point which sometimes introduces errors in the compact
                    # representation, which we have to round off to avoid ugliness. In the vast
                    # majority of cases this rounding will be inconsequential. I hope.
                    value = round(value.to_compact(), 6)
                    display_value = markupsafe.Markup(
                        f"{value:~P}".replace(" ", "&nbsp;")
                    )
                else:
                    display_value = str(value)
                    if truncate_cells and len(display_value) > truncate_cells:
                        # \u2026 is the horizontal ellipsis character.
                        display_value = display_value[:truncate_cells] + "\u2026"
                cells.append(
                    {
                        "column": column,
                        "value": display_value,
                        "raw": value,
                        "value_type": "none"
                        if value is None
                        else str(type(value).__name__),
                    }
                )
            cell_rows.append(Row(cells))
        if link_column:
            # Add the link column header.
            # If it's a simple primary key, we have to remove and re-add that column name at
            # the beginning of the header row.
            first_column = None
            if len(pks) == 1:
                columns = [col for col in columns if col["name"] != pks[0]]
                first_column = {
                    "name": pks[0],
                    "sortable": len(pks) == 1,
                    "is_pk": True,
                    "type": column_details[pks[0]].type,
                    "notnull": column_details[pks[0]].notnull,
                }
            else:
                first_column = {
                    "name": "Link",
                    "sortable": False,
                    "is_pk": False,
                    "type": "",
                    "notnull": 0,
                }
            columns = [first_column] + columns
        return columns, cell_rows

    async def view(self, request, datasette):
        """Route entry point - delegates to main() via asyncinject."""
        return await self.main(request=request, datasette=datasette)

    async def render_html(self, datasette, templates, request, context=None):
        """Render the first matching template with context plus base extras."""
        context = context or {}
        template = datasette.jinja_env.select_template(templates)
        template_context = {
            **context,
            **{
                # Hard-coded for now; the real view derives a per-db color.
                "database_color": lambda database: "ff0000",
                # The chosen template is marked with a leading "*".
                "select_templates": [
                    f"{'*' if template_name == template.name else ''}{template_name}"
                    for template_name in templates
                ],
            },
        }
        return Response.html(
            await datasette.render_template(
                template, template_context, request=request, view_name=self.view_name
            )
        )
@hookimpl
def register_routes():
    """Register the experimental /t/<db>/<table>.<format> route."""
    pattern = r"/t/(?P<db_name>[^/]+)/(?P<table_and_format>[^/]+?$)"
    return [(pattern, Table().view)]
async def check_permissions(datasette, request, permissions):
    """Check a cascade of permissions for the current actor.

    permissions is a list of (action, resource) tuples or 'action' strings,
    ordered most-specific first. Each is checked with default=None; the first
    explicit verdict wins: an allow returns immediately, a deny raises
    Forbidden(action). None means "no opinion" and falls through.

    Raises ValueError for a malformed permissions entry. (Previously this was
    an ``assert``, which is silently stripped under ``python -O``.)
    """
    for permission in permissions:
        if isinstance(permission, str):
            action = permission
            resource = None
        elif isinstance(permission, (tuple, list)) and len(permission) == 2:
            action, resource = permission
        else:
            raise ValueError(
                "permission should be string or tuple of two items: {}".format(
                    repr(permission)
                )
            )
        ok = await datasette.permission_allowed(
            request.actor,
            action,
            resource=resource,
            default=None,
        )
        if ok is not None:
            if ok:
                return
            else:
                raise Forbidden(action)
async def columns_to_select(datasette, database, table, request):
    """Work out which columns to SELECT, honoring ?_col= and ?_nocol=.

    ?_col= restricts output to the primary keys plus the named columns;
    ?_nocol= removes columns (primary keys may not be removed). Invalid
    column names raise DatasetteError with a 400 status.
    """
    all_columns = await database.table_columns(table)
    pks = await database.primary_keys(table)
    selected = list(all_columns)
    if "_col" in request.args:
        requested = request.args.getlist("_col")
        invalid = [name for name in requested if name not in all_columns]
        if invalid:
            raise DatasetteError(
                "_col={} - invalid columns".format(", ".join(invalid)),
                status=400,
            )
        # Start from the primary keys, then the requested columns,
        # de-duplicated while keeping their order.
        selected = list(pks)
        selected.extend(dict.fromkeys(requested))
    if "_nocol" in request.args:
        # Return all columns EXCEPT these
        excluded = request.args.getlist("_nocol")
        invalid = [
            name
            for name in excluded
            if (name not in all_columns) or (name in pks)
        ]
        if invalid:
            raise DatasetteError(
                "_nocol={} - invalid columns".format(", ".join(invalid)),
                status=400,
            )
        selected = [name for name in selected if name not in excluded]
    return selected
async def sortable_columns_for_table(datasette, database, table, use_rowid):
    """Return the set of column names permitted in sort parameters.

    Table metadata may pin down an explicit "sortable_columns" list;
    otherwise every column is sortable. When use_rowid is true, "rowid"
    is included as well.
    """
    db = datasette.databases[database]
    meta = datasette.table_metadata(database, table)
    if "sortable_columns" in meta:
        allowed = set(meta["sortable_columns"])
    else:
        allowed = set(await db.table_columns(table))
    if use_rowid:
        allowed.add("rowid")
    return allowed
class Row:
    """A rendered table row: a list of cell dicts with dict-style access.

    Each cell dict carries at least "column", "raw" (the underlying value)
    and "value" (the display rendering). Iterating yields the cell dicts;
    row[key] returns the raw value; row.display(key) the rendered one.
    """

    def __init__(self, cells):
        # cells: list of cell dicts in display order
        self.cells = cells

    def __iter__(self):
        return iter(self.cells)

    def __getitem__(self, key):
        """Return the raw value for column *key*; raise KeyError if absent."""
        for cell in self.cells:
            if cell["column"] == key:
                return cell["raw"]
        # Include the missing key in the exception (the bare `raise KeyError`
        # here previously gave no hint which column was being looked up).
        raise KeyError(key)

    def display(self, key):
        """Return the display value for column *key*, or None if absent."""
        for cell in self.cells:
            if cell["column"] == key:
                return cell["value"]
        return None

    def __str__(self):
        # JSON dump of raw values, skipping the synthetic link column;
        # default=repr keeps non-JSON-native values from raising.
        d = {
            key: self[key]
            for key in [
                c["column"] for c in self.cells if not c.get("is_special_link_column")
            ]
        }
        return json.dumps(d, default=repr, indent=2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment