-
-
Save goeo-/58b75fa8661e54278a7b6274ad021160 to your computer and use it in GitHub Desktop.
bsky labeler impl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import binascii
import datetime
import os
import re
import time
from base64 import urlsafe_b64decode
from contextlib import asynccontextmanager
from typing import Annotated, Literal, Union

import aiosqlite
import cbrrr
import httpx
import uvicorn
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature, decode_dss_signature
from fastapi import FastAPI, Depends, WebSocket, WebSocketDisconnect
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from multiformats import multibase, multicodec
from pydantic import BaseModel, Field, ValidationError
from pydantic.functional_validators import AfterValidator
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the label database (and shared HTTP client) for the app's lifetime.

    Opens ``labels.db``, switches it to WAL journal mode, creates the
    ``label`` table if it does not exist yet and exposes the connection and a
    shared cursor as ``app.con`` / ``app.cur``.  On shutdown -- including
    shutdown caused by an exception -- the connection and the module-level
    ``client`` are closed.
    """
    app.con = await aiosqlite.connect("labels.db")
    try:
        await app.con.execute('pragma journal_mode=wal')
        app.cur = await app.con.cursor()
        # seq doubles as the subscribeLabels sequence number; data holds one
        # pre-encoded DAG-CBOR array of signed labels.
        await app.cur.execute("""
            CREATE TABLE IF NOT EXISTS label(
                seq INTEGER PRIMARY KEY NOT NULL,
                data BLOB NOT NULL
            )
        """)
        yield
    finally:
        # try/finally so an error while serving (or during setup) doesn't leak
        # the sqlite connection or the pooled HTTP connections.
        try:
            await app.con.close()
        finally:
            await client.aclose()
app = FastAPI(lifespan=lifespan)
security = HTTPBearer()
client = httpx.AsyncClient()

# Deployment settings.  Each one can be overridden through an environment
# variable so real secrets (webhook URL, signing key) need not be committed to
# source; the literal defaults are the previous hard-coded placeholder values.
WEBHOOK = os.environ.get('LABELER_WEBHOOK', 'https://discord.com/api/webhooks/AAAAAAAA')
# Only reports filed by this DID are turned into labels.
OWNER_DID = os.environ.get('LABELER_OWNER_DID', 'did:web:genco.me')
# This labeler's own DID: the expected JWT audience and the `src` of every label.
LABELER_DID = os.environ.get('LABELER_DID', 'did:plc:gcbmhqcuvuoz7jgmlanabiuv')
# secp256k1 private scalar, given as hex in LABELER_PRIVATE_KEY; generate one with
# ec.generate_private_key(ec.SECP256K1()).private_numbers().private_value
LABELER_PRIVATE_KEY = (
    int(os.environ['LABELER_PRIVATE_KEY'], 16)
    if 'LABELER_PRIVATE_KEY' in os.environ
    else 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
)
priv_key = ec.derive_private_key(LABELER_PRIVATE_KEY, ec.SECP256K1())
compressed_public_bytes = priv_key.public_key().public_bytes(
    serialization.Encoding.X962,
    serialization.PublicFormat.CompressedPoint
)
# Print the did:key form of the signing key so it can be published in the
# labeler's DID document.
print("running as did:key:" + multibase.encode(
    multicodec.wrap("secp256k1-pub", compressed_public_bytes),
    "base58btc"
))
# Regex fragment shared by DID and ATURI: a did:plc with a 24-character
# [a-z2-7] identifier, or a did:web with a dotted host name.
_DID_BODY = r'did\:(?:plc\:[a-z2-7]{24}|web\:(?:[A-Za-z0-9\-]+\.)+[A-Za-z0-9\-]+)'

DID = Annotated[str, Field(pattern='^' + _DID_BODY + '$')]
# at://<did>/<collection>/<rkey>
ATURI = Annotated[str, Field(pattern=r'^at\:\/\/' + _DID_BODY + r'\/[^\/ ]+\/[^\/ ]+$')]
# base32 CIDv1 strings that start with bafyrei / bafkrei.
CID = Annotated[str, Field(pattern=r'^baf[yk]rei[a-z2-7]{52}$')]
# ISO-8601 timestamp with a mandatory Z or +hh:mm / -hh:mm offset.
DATETIME = Annotated[str, Field(pattern=r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d(\.\d{1,15})?(Z|[+-]\d\d:\d\d)$')]

# Frame header prepended to every #labels message sent to subscribers.
LABELS_HEADER = cbrrr.encode_dag_cbor({'t': '#labels', 'op': 1})
# Complete error frame (an op -1 header followed by the error body), sent when
# a subscriber's cursor matches no stored label row.
ERROR_MSG = b''.join((
    cbrrr.encode_dag_cbor({'op': -1}),
    cbrrr.encode_dag_cbor({'error': 'FutureCursor', 'message': 'Cursor in the future.'}),
))
class User(BaseModel):
    """Authenticated caller, as built by ``get_current_user``."""

    did: DID  # the caller's DID (the JWT ``iss`` claim)
    username: str  # display string assembled from display name and/or handle
    avatar_url: str | None  # None when no avatar blob could be located
class JWTHeader(BaseModel):
    """JOSE header of an incoming bearer JWT.

    Only JWT-typed tokens signed with ES256 (P-256 key) or ES256K
    (secp256k1 key) are accepted.
    """

    typ: Literal["JWT"]
    alg: Literal["ES256", "ES256K"]
def check_current(v: int) -> int:
    """Check that a JWT ``exp`` timestamp still lies in the future.

    Intended as a pydantic ``AfterValidator``; pydantic converts the raised
    ``ValueError`` into a ``ValidationError``.

    :param v: expiry time in seconds since the Unix epoch.
    :returns: ``v`` unchanged when it is later than the current time.
    :raises ValueError: if ``v`` is not later than the current time.
    """
    # An explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O``, which would silently let expired tokens through.
    if v <= time.time():
        raise ValueError(f'{v} is too old')
    return v
class JWTPayload(BaseModel):
    """JWT claims checked before the token's signature is verified."""

    exp: Annotated[int, AfterValidator(check_current)]  # must still be in the future
    # Must equal this labeler's DID exactly.  re.escape makes characters such
    # as '.' (possible in a did:web host) match literally, not as wildcards.
    aud: Annotated[str, Field(pattern='^%s$' % re.escape(LABELER_DID))]
    iss: DID  # issuer; its DID document supplies the key the signature is checked against
class Label(BaseModel):
    """A label as signed, stored and streamed by this labeler.

    Fields left as None are dropped before the label is CBOR-encoded.
    """

    ver: Literal[1]  # label format version
    src: DID  # DID of the issuing labeler
    uri: Union[ATURI, DID]  # subject: a record AT-URI or an account DID
    cid: Union[CID, None] = None  # set when a specific record version is labelled
    val: Annotated[str, Field(max_length=128)]  # the label value itself
    neg: Union[bool, None] = None  # negation flag; create_report sets it for '!'-prefixed reasons
    cts: DATETIME  # creation timestamp
    exp: Union[DATETIME, None] = None  # optional expiry timestamp
    sig: Union[bytes, None] = None  # 64-byte r||s signature over the label without sig
def _b64url_decode(segment: str) -> bytes:
    """Decode one unpadded base64url JWT segment, restoring the '=' padding."""
    # JWS strips padding; a segment needs exactly (-len % 4) '=' characters back.
    return urlsafe_b64decode(segment + '=' * (-len(segment) % 4))


def _dig(obj, *keys):
    """Walk nested dicts along ``keys``; return None as soon as a step is missing or not a dict."""
    for key in keys:
        if not isinstance(obj, dict):
            return None
        obj = obj.get(key)
    return obj


async def get_current_user(bearer: Annotated[HTTPAuthorizationCredentials, Depends(security)]):
    """FastAPI dependency that authenticates the caller from a bearer JWT.

    The token's header and payload are validated.  The issuer's DID document
    is resolved for its atproto signing key, and the signature is verified
    with that key.  The issuer's handle is then checked as a best effort.
    When the handle verifies and a PDS is known, the profile record is
    fetched for a display name and avatar.

    :returns: a ``User`` describing the caller.
    :raises HTTPException: 400 for a malformed or unverifiable token.
    """
    auth = bearer.credentials.split('.')
    if len(auth) != 3:
        raise HTTPException(status_code=400, detail='Bad Bearer token')

    try:
        header, payload, signature = (_b64url_decode(part) for part in auth)
    except ValueError:  # binascii.Error, or non-ASCII input
        raise HTTPException(status_code=400, detail='Bad JWT: could not base64 decode')

    try:
        header = JWTHeader.model_validate_json(header)
    except ValidationError:
        raise HTTPException(status_code=400, detail='Bad JWT header')

    try:
        payload = JWTPayload.model_validate_json(payload)
    except ValidationError:
        raise HTTPException(status_code=400, detail='Bad JWT payload')

    # ES256 / ES256K JWS signatures are the raw 32-byte r and s concatenated.
    if len(signature) != 64:
        raise HTTPException(status_code=400, detail='Bad JWT signature')

    try:
        signing_key, handle, pds = await get_signing_key_handle_pds(payload.iss)
    except InvalidDIDException:
        raise HTTPException(status_code=400, detail='Bad iss did')

    try:
        codec, key_bytes = multicodec.unwrap(multibase.decode(signing_key))
    except (KeyError, ValueError):  # presumably what multiformats raises on bad input -- verify
        raise HTTPException(status_code=400, detail='Bad iss did')

    if header.alg == 'ES256':
        expected_codec, curve = 'p256-pub', ec.SECP256R1()
    else:  # header.alg == 'ES256K'
        expected_codec, curve = 'secp256k1-pub', ec.SECP256K1()
    if codec.name != expected_codec:
        raise HTTPException(status_code=400, detail="Bad JWT: signing key doesn't match JWT alg")

    try:
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(curve, key_bytes)
    except ValueError:  # bytes are not a valid point on the curve
        raise HTTPException(status_code=400, detail='Bad iss did')

    try:
        public_key.verify(
            signature=encode_dss_signature(
                int.from_bytes(signature[:32], 'big'),
                int.from_bytes(signature[32:], 'big'),
            ),
            # The signature covers the still-encoded "header.payload" text.
            data='.'.join(auth[:2]).encode(),
            signature_algorithm=ec.ECDSA(hashes.SHA256()),
        )
    except InvalidSignature:
        raise HTTPException(status_code=400, detail='Bad JWT signature')

    # Best-effort bidirectional check: the handle must resolve back to the issuer.
    verified_handle = False
    if handle:
        try:
            verified_handle = await resolve_handle(handle) == payload.iss
        except (CannotResolveHandleException, AssertionError, KeyError, ValueError, httpx.HTTPError):
            verified_handle = False

    avatar_url = None
    display_name = None
    if pds and verified_handle:
        try:
            response = await client.get(
                f'{pds}/xrpc/com.atproto.repo.getRecord?repo={payload.iss}&collection=app.bsky.actor.profile&rkey=self'
            )
            profile = response.json()
        except (httpx.HTTPError, ValueError):
            # Profile data is cosmetic; fall back to the bare handle.
            profile = None
        avatar_blob = _dig(profile, 'value', 'avatar', 'ref', '$link')
        if avatar_blob:
            avatar_url = f'{pds}/xrpc/com.atproto.sync.getBlob?did={payload.iss}&cid={avatar_blob}'
        display_name = _dig(profile, 'value', 'displayName')

    if display_name:
        username = f'{display_name} ({handle})'
    elif verified_handle:
        username = handle
    else:
        username = f'invalid handle: {handle} ({payload.iss})'
    return User(did=payload.iss, username=username, avatar_url=avatar_url)
class StrongRef(BaseModel):
    """Report subject that points at one specific record version."""

    # Serialized under the JSON key "$type".
    type: Literal['com.atproto.repo.strongRef'] = Field(..., alias='$type')
    uri: ATURI  # at:// URI of the record
    cid: str  # record CID; deliberately not pattern-checked here
class RepoRef(BaseModel):
    """Report subject that points at a whole account."""

    # Serialized under the JSON key "$type".
    type: Literal['com.atproto.admin.defs#repoRef'] = Field(..., alias='$type')
    did: DID  # the reported account
class Report(BaseModel):
    """Request body of com.atproto.moderation.createReport."""

    # Dots escaped so they only match a literal '.'.
    reason_type: Annotated[str, Field(alias='reasonType', pattern=r'^com\.atproto\.moderation\.defs#reason(Spam|Violation|Misleading|Sexual|Rude|Other|Appeal)$')]
    subject: Union[StrongRef, RepoRef]
    # Optional free text. It defaults to None so a body that omits it still
    # validates; create_report then rejects it explicitly.
    reason: Union[str, None] = None
@app.post('/xrpc/com.atproto.moderation.createReport')
async def create_report(current_user: Annotated[User, Depends(get_current_user)], report: Report):
    """Turn a moderation report filed by the owner into a signed label.

    Only reports from ``OWNER_DID`` with a non-empty ``reason`` are accepted.
    The reason becomes the label value.  A leading ``!`` is stripped and
    marks the label as a negation.  A ``StrongRef`` subject is labelled by
    record URI and CID, and a ``RepoRef`` subject by account DID.

    :returns: a createReport-style JSON body.  Returns 401 when the caller is
        not the owner or the reason is empty, and 400 when the resulting label
        fails ``Label`` validation.
    """
    if not report.reason or current_user.did != OWNER_DID:
        return JSONResponse({"message":"unauthorized"}, 401)

    neg = None
    if report.reason.startswith('!'):
        neg = True
        report.reason = report.reason[1:]

    now = timestamp()
    if isinstance(report.subject, StrongRef):
        subject_fields = {'uri': report.subject.uri, 'cid': report.subject.cid}
    else:  # RepoRef: label the whole account
        subject_fields = {'uri': report.subject.did}

    try:
        label = Label(ver=1, src=LABELER_DID, val=report.reason, neg=neg, cts=now, **subject_fields)
    except ValidationError:
        # For example, a reason longer than 128 characters, or a CID/URI the
        # stricter Label patterns reject.  Without this it surfaced as a 500.
        return JSONResponse({"message": "invalid label"}, 400)
    await new_labels([label])

    # NOTE(review): app.cur is shared; another insert may run while new_labels
    # awaits its broadcasts, so lastrowid is not guaranteed to be this label's seq.
    return JSONResponse(content={
        "id": app.cur.lastrowid,
        "reasonType": report.reason_type,
        "reason": report.reason,
        "subject": report.subject.model_dump(by_alias=True),
        "reportedBy": current_user.did,
        "createdAt": now
    })
class CannotResolveHandleException(Exception):
    """Raised by ``resolve_handle`` when a handle could not be resolved to a DID."""
async def resolve_handle(handle):
    """Resolve an atproto handle to the DID it claims.

    Looks up the ``_atproto.<handle>`` TXT record through Cloudflare's
    DNS-over-HTTPS JSON endpoint.  If that yields nothing usable, falls back
    to ``https://<handle>/.well-known/atproto-did``.

    :param handle: the handle (a host name) to resolve.
    :returns: the DID string the handle points at.
    :raises CannotResolveHandleException: if neither method yields a DID.
    """
    try:
        res = await client.get(
            'https://1.1.1.1/dns-query',
            # params= URL-encodes the handle, which comes from an untrusted DID document.
            params={'name': f'_atproto.{handle}', 'type': 'TXT'},
            headers={'Accept': 'application/dns-json'},
        )
        body = res.json()
    except (httpx.HTTPError, ValueError):
        # A failed DNS lookup is not fatal; try the HTTPS method instead.
        body = None

    answers = body.get('Answer') if isinstance(body, dict) else None
    for answer in answers or []:
        data = answer.get('data') if isinstance(answer, dict) else None
        # TXT data arrives quoted, e.g. '"did=did:plc:..."'; skip any other records.
        if isinstance(data, str) and data.startswith('"did=did:') and data.endswith('"'):
            return data[5:-1]

    try:
        res = await client.get(f'https://{handle}/.well-known/atproto-did')
    except Exception as e:
        print(e)
        raise CannotResolveHandleException from e

    did = res.text.strip()
    if not res.is_success or not did.startswith('did:'):
        raise CannotResolveHandleException
    return did
class InvalidDIDException(Exception):
    """Raised when a DID cannot be resolved to a usable atproto DID document."""
async def get_signing_key_handle_pds(did):
    """Fetch a DID document and pull out the atproto signing key, handle and PDS.

    ``did:plc`` documents are fetched from plc.directory, and ``did:web``
    documents from ``https://<host>/.well-known/did.json``.

    :param did: a ``did:plc:`` or ``did:web:`` identifier.
    :returns: ``(signing_key, handle, pds)``.  ``signing_key`` is the
        ``publicKeyMultibase`` of the ``#atproto`` Multikey method.
        ``handle`` and ``pds`` are None when the document lacks them.
    :raises InvalidDIDException: for an unsupported DID method, a document
        that cannot be fetched or parsed, or one without an atproto signing key.
    """
    if did.startswith('did:plc:'):
        url = f'https://plc.directory/{did}'
    elif did.startswith('did:web:'):
        url = f'https://{did[8:]}/.well-known/did.json'
    else:
        raise InvalidDIDException

    try:
        doc = (await client.get(url)).json()
    except (httpx.HTTPError, ValueError) as err:
        # Network failures and non-JSON bodies are just another unusable DID.
        raise InvalidDIDException from err
    if not isinstance(doc, dict):
        raise InvalidDIDException

    def is_entry(entry, fragment):
        # Entry ids may be absolute ("<did>#fragment") or relative ("#fragment").
        return isinstance(entry, dict) and entry.get('id') in (fragment, did + fragment)

    handle = None
    for aka in doc.get('alsoKnownAs') or []:
        if isinstance(aka, str) and aka.startswith('at://'):
            handle = aka[5:]
            break

    pds = None
    for service in doc.get('service') or []:
        if is_entry(service, '#atproto_pds') and service.get('type') == 'AtprotoPersonalDataServer':
            pds = service.get('serviceEndpoint')
            break

    signing_key = None
    for method in doc.get('verificationMethod') or []:
        if (
            is_entry(method, '#atproto')
            and method.get('type') == 'Multikey'
            and method.get('controller') == did
        ):
            signing_key = method.get('publicKeyMultibase')
            break

    if not signing_key:
        raise InvalidDIDException
    return signing_key, handle, pds
# Open subscribeLabels websockets; new_labels broadcasts every new frame to them.
connections = []


@app.websocket("/xrpc/com.atproto.label.subscribeLabels")
async def websocket_endpoint(websocket: WebSocket, cursor: int | None=None):
    """Stream label frames to a subscriber, optionally replaying history first.

    When ``cursor`` is given, every stored frame with ``seq > cursor`` is
    replayed.  If no row with ``seq >= cursor`` exists, a FutureCursor error
    frame is sent and the socket is closed.  After that the socket stays
    registered in ``connections`` until the client goes away.  Incoming
    messages are read and discarded.
    """
    await websocket.accept()
    connections.append(websocket)
    try:
        if cursor is not None:
            # Replay on a dedicated cursor.  app.cur is shared with new_labels,
            # whose inserts would otherwise replace this result set mid-iteration.
            async with app.con.execute(
                'select seq, data from label where seq >= ? order by seq asc', (cursor,)
            ) as rows:
                empty = True
                while batch := await rows.fetchmany():
                    for seq, data in batch:
                        empty = False
                        if seq == cursor:  # the client already has this one
                            continue
                        # Encode {'seq', 'labels': 0} and drop the trailing 0 byte,
                        # so the stored CBOR label array becomes the 'labels' value.
                        await websocket.send_bytes(
                            LABELS_HEADER + cbrrr.encode_dag_cbor({'seq': seq, 'labels': 0})[:-1] + data
                        )
            if empty:
                await websocket.send_bytes(ERROR_MSG)
                await websocket.close()
                return
        while True:
            await websocket.receive_bytes()
    except WebSocketDisconnect:
        pass
    finally:
        # Unregister on every exit path (disconnect, failed send, our own close),
        # otherwise new_labels keeps broadcasting to a dead socket.
        if websocket in connections:
            connections.remove(websocket)
# Order of the secp256k1 group; used to normalise signatures to low-S form.
_SECP256K1_N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141


def _cbor_array_header(count):
    """Return the CBOR array (major type 4) header for ``count`` items, up to 65535.

    Multi-byte CBOR lengths are big-endian.
    """
    if count <= 0x17:
        return bytes([0x80 + count])  # length packed into the initial byte
    if count <= 0xff:
        return b'\x98' + count.to_bytes(1, 'big')
    if count <= 0xffff:
        return b'\x99' + count.to_bytes(2, 'big')
    raise Exception('please send up to 65535 labels at once')


async def new_labels(labels):
    """Sign, store and broadcast a batch of labels as one subscribeLabels frame.

    Each label (a ``Label``, or a dict of ``Label`` fields) is reduced to its
    non-None fields and DAG-CBOR encoded without ``sig``.  That encoding is
    signed with ``priv_key`` (ECDSA/SHA-256, low-S, 64-byte r||s), and the
    label is re-encoded with the signature.  The labels are stored as one CBOR
    array in a new ``label`` row and pushed to every open subscriber.

    :returns: the ``seq`` of the inserted row.
    :raises Exception: if more than 65535 labels are given.
    """
    out = bytearray(_cbor_array_header(len(labels)))
    for label in labels:
        if not isinstance(label, Label):
            label = Label(**label)
        # Exclude any pre-set sig: the signature must cover only the unsigned label.
        unsigned = {k: v for k, v in dict(label).items() if v is not None and k != 'sig'}
        r, s = decode_dss_signature(priv_key.sign(cbrrr.encode_dag_cbor(unsigned), ec.ECDSA(hashes.SHA256())))
        if s > _SECP256K1_N // 2:
            s = _SECP256K1_N - s
        signature = r.to_bytes(32, 'big') + s.to_bytes(32, 'big')
        out += cbrrr.encode_dag_cbor(unsigned | {'sig': signature})

    await app.cur.execute('insert into label (data) values (?)', (out,))
    # Capture seq right away: app.cur is shared, so lastrowid can change once we
    # start awaiting the commit and the sends below.
    seq = app.cur.lastrowid
    await app.con.commit()

    # Encode {'seq', 'labels': 0} and drop the trailing 0 byte so `out` becomes the labels value.
    frame = LABELS_HEADER + cbrrr.encode_dag_cbor({'seq': seq, 'labels': 0})[:-1] + out
    # Iterate a snapshot: subscribers unregister themselves while we await.
    for ws in list(connections):
        try:
            await ws.send_bytes(frame)
        except Exception as e:
            # One dead subscriber must not abort the broadcast; the label is
            # already committed and the others still need it.
            print(e)
    return seq
def timestamp(): | |
return datetime.datetime.now(datetime.UTC).isoformat()[:-9]+'Z' | |
if __name__ == '__main__':
    # Serve via uvicorn's import string, which assumes this file is importable
    # as the module "labeler"; reload=True enables uvicorn's auto-reload.
    uvicorn.run('labeler:app', reload=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment