Skip to content

Instantly share code, notes, and snippets.

@goeo-

goeo-/labeler.py Secret

Last active July 8, 2024 20:09
Show Gist options
  • Save goeo-/58b75fa8661e54278a7b6274ad021160 to your computer and use it in GitHub Desktop.
Save goeo-/58b75fa8661e54278a7b6274ad021160 to your computer and use it in GitHub Desktop.
bsky labeler impl
import binascii
import datetime
import os
import re
import time
from base64 import urlsafe_b64decode
from contextlib import asynccontextmanager
from typing import Annotated, Literal, Union

import aiosqlite
import cbrrr
import httpx
import uvicorn
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature, decode_dss_signature
from fastapi import FastAPI, Depends, WebSocket, WebSocketDisconnect
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from multiformats import multibase, multicodec
from pydantic import BaseModel, Field, ValidationError
from pydantic.functional_validators import AfterValidator
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the label database for the lifetime of the FastAPI app.

    On startup, opens ``labels.db`` in WAL mode and stores the connection on
    ``app.con``. It also stores a shared cursor on ``app.cur``, which the request
    handlers use for inserts, and creates the ``label`` table if it is missing.

    The connection is closed on shutdown. It is also closed if startup fails
    part-way or the app exits with an error.
    """
    app.con = await aiosqlite.connect("labels.db")
    try:
        await app.con.execute('pragma journal_mode=wal')
        app.cur = await app.con.cursor()
        # seq is the rowid alias and serves as the subscribeLabels event sequence
        # number. data holds one pre-encoded DAG-CBOR array of signed labels
        # (written by new_labels).
        await app.cur.execute("""
            CREATE TABLE IF NOT EXISTS label(
                seq INTEGER PRIMARY KEY NOT NULL,
                data BLOB NOT NULL
            )
        """)
        # Explicit commit so the schema is durable regardless of the sqlite3
        # transaction-control mode in effect.
        await app.con.commit()
        yield
    finally:
        await app.con.close()
# FastAPI application. `lifespan` opens the SQLite connection on startup and closes it on shutdown.
app = FastAPI(lifespan=lifespan)
# Extracts the `Authorization: Bearer <token>` credentials consumed by get_current_user.
security = HTTPBearer()
# Shared HTTP client for DID documents, handle resolution and profile fetches.
client = httpx.AsyncClient()
# NOTE(review): not referenced anywhere in the code shown here. Presumably meant for
# report notifications; confirm or remove.
WEBHOOK = 'https://discord.com/api/webhooks/AAAAAAAA'
# The only DID allowed to create labels through createReport.
OWNER_DID = 'did:web:genco.me'
# This labeler's DID. It is used as every Label's `src` and must match the JWT `aud` claim.
LABELER_DID = 'did:plc:gcbmhqcuvuoz7jgmlanabiuv'
# secp256k1 private scalar used to sign every emitted label.
# Set LABELER_PRIVATE_KEY in the environment (hex, with or without a 0x prefix)
# to keep the secret out of source control; the literal below is only a fallback.
# To generate a fresh key:
#   ec.generate_private_key(ec.SECP256K1()).private_numbers().private_value
_LABELER_PRIVATE_KEY_ENV = os.environ.get('LABELER_PRIVATE_KEY')
LABELER_PRIVATE_KEY = (
    int(_LABELER_PRIVATE_KEY_ENV, 16)
    if _LABELER_PRIVATE_KEY_ENV
    else 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
)
priv_key = ec.derive_private_key(LABELER_PRIVATE_KEY, ec.SECP256K1())
# Compressed-point public key. Its did:key form is printed so the operator can
# check it against the key published for LABELER_DID.
compressed_public_bytes = priv_key.public_key().public_bytes(
    serialization.Encoding.X962,
    serialization.PublicFormat.CompressedPoint,
)
print("running as did:key:" + multibase.encode(
    multicodec.wrap("secp256k1-pub", compressed_public_bytes),
    "base58btc",
))
# did:plc:<24 chars of [a-z2-7]> or did:web:<dotted hostname> (no port or path allowed).
DID = Annotated[str, Field(pattern=r'^did\:(?:plc\:[a-z2-7]{24}|web\:(?:[A-Za-z0-9\-]+\.)+[A-Za-z0-9\-]+)$')]
# at://<DID as above>/<segment>/<segment>: exactly two non-empty path segments without
# '/' or spaces (presumably collection NSID and record key).
ATURI = Annotated[str, Field(pattern=r'^at\:\/\/did\:(?:plc\:[a-z2-7]{24}|web\:(?:[A-Za-z0-9\-]+\.)+[A-Za-z0-9\-]+)\/[^\/ ]+\/[^\/ ]+$')]
# 'baf' + 'y'/'k' + 'rei' + 52 base32 chars. This looks like base32 CIDv1
# (dag-cbor or raw) with sha2-256; other CID forms are rejected. TODO confirm that is intended.
CID = Annotated[str, Field(pattern=r'^baf[yk]rei[a-z2-7]{52}$')]
# YYYY-MM-DDTHH:MM:SS, an optional 1-15 digit fraction, then 'Z' or a +/-HH:MM offset.
DATETIME = Annotated[str, Field(pattern=r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d(\.\d{1,15})?(Z|[+-]\d\d:\d\d)$')]
# Pre-encoded event-stream header for '#labels' messages (op 1). Each frame sent to
# subscribers is this header followed by the DAG-CBOR body.
LABELS_HEADER = cbrrr.encode_dag_cbor({'t': '#labels', 'op': 1})
# Complete error frame: header with op -1, then a FutureCursor error body. Sent by the
# subscribeLabels endpoint when the requested cursor is ahead of the stored events.
ERROR_MSG = cbrrr.encode_dag_cbor({'op': -1}) + cbrrr.encode_dag_cbor({'error': 'FutureCursor', 'message': 'Cursor in the future.'})
class User(BaseModel):
    """Caller identity produced by get_current_user after JWT verification."""

    did: DID  # the verified JWT issuer
    username: str  # display name and/or handle, or an "invalid handle" marker
    avatar_url: str | None  # key is always set; the value is None when no avatar was found
class JWTHeader(BaseModel):
    """Decoded header of the bearer JWT checked by get_current_user.

    Only JWT-typed tokens signed with ES256 (P-256 key) or ES256K (secp256k1 key)
    are accepted.
    """
    typ: Literal['JWT']
    alg: Literal['ES256', 'ES256K']
def check_current(v: int) -> int:
    """Pydantic after-validator: accept ``v`` only if it lies in the future.

    ``v`` is compared against ``time.time()``, so it is a Unix timestamp in seconds
    (the JWT ``exp`` claim).

    Raises:
        ValueError: if ``v`` is not after the current time. Pydantic turns this into a
            ValidationError. A plain ``assert`` would be stripped under ``python -O``
            and silently accept expired tokens.
    """
    if v <= time.time():
        raise ValueError(f'{v} is too old')
    return v
class JWTPayload(BaseModel):
    """Claims of the bearer JWT that get_current_user relies on.

    Claims not declared here are ignored (pydantic's default for extra fields).
    """

    # Expiry; rejected unless it is in the future.
    exp: Annotated[int, AfterValidator(check_current)]
    # The audience must be exactly this labeler's DID. re.escape stops characters such
    # as the '.' in a did:web hostname from acting as regex wildcards.
    aud: Annotated[str, Field(pattern='^%s$' % re.escape(LABELER_DID))]
    # Issuer; its DID document supplies the key that must have signed the token.
    iss: DID
class Label(BaseModel):
    """One label as signed, stored and broadcast by this service.

    Fields that are None are left out of the DAG-CBOR encoding before signing
    (see new_labels).
    """

    ver: Literal[1]  # label format version; only 1 is accepted
    src: DID  # issuing labeler (create_report always uses LABELER_DID)
    uri: Union[ATURI, DID]  # subject: a record AT-URI or an account DID
    cid: Union[CID, None] = None  # optional record-version pin for AT-URI subjects
    val: Annotated[str, Field(max_length=128)]  # the label value itself
    neg: Union[bool, None] = None  # True marks a negation (create_report: '!'-prefixed reason)
    cts: DATETIME  # creation timestamp
    exp: Union[DATETIME, None] = None  # optional expiry timestamp
    sig: Union[bytes, None] = None  # 64-byte r||s signature, attached by new_labels
def _b64url_decode(segment):
    """Decode one base64url JWT segment, restoring the '=' padding JWTs omit."""
    # -len % 4 is exactly the number of '=' needed to reach a multiple of 4 (0..3).
    return urlsafe_b64decode(segment + '=' * (-len(segment) % 4))


def _load_signing_key(alg, multikey):
    """Decode a DID-document ``publicKeyMultibase`` into an EC public key for ``alg``.

    ``alg`` is 'ES256' or 'ES256K' (guaranteed by JWTHeader).

    Raises:
        HTTPException(400): if the key's multicodec does not match ``alg``, or the
            key cannot be decoded into a curve point.
    """
    expected_codec, curve = {
        'ES256': ('p256-pub', ec.SECP256R1),
        'ES256K': ('secp256k1-pub', ec.SECP256K1),
    }[alg]
    try:
        codec, key_bytes = multicodec.unwrap(multibase.decode(multikey))
    except (KeyError, ValueError, TypeError):
        # NOTE(review): assumes multiformats reports malformed input with
        # KeyError/ValueError subclasses; confirm against the installed version.
        raise HTTPException(status_code=400, detail='Bad JWT: could not decode signing key') from None
    if codec.name != expected_codec:
        raise HTTPException(status_code=400, detail="Bad JWT: signing key doesn't match JWT alg")
    try:
        return ec.EllipticCurvePublicKey.from_encoded_point(curve(), key_bytes)
    except ValueError:
        raise HTTPException(status_code=400, detail='Bad JWT: could not decode signing key') from None


async def _fetch_profile(pds, did):
    """Fetch ``did``'s app.bsky.actor.profile/self record value from ``pds``, best-effort.

    Returns the record's ``value`` dict, or {} if the request fails or the response
    does not have the expected shape.
    """
    try:
        res = await client.get(
            f'{pds}/xrpc/com.atproto.repo.getRecord',
            params={'repo': did, 'collection': 'app.bsky.actor.profile', 'rkey': 'self'},
        )
        profile = res.json()
    except (httpx.HTTPError, httpx.InvalidURL, ValueError) as e:
        print(e)
        return {}
    value = profile.get('value') if isinstance(profile, dict) else None
    return value if isinstance(value, dict) else {}


async def get_current_user(bearer: Annotated[HTTPAuthorizationCredentials, Depends(security)]):
    """FastAPI dependency: verify the bearer JWT and describe its issuer.

    Checks the header/payload schema (including ``aud`` and ``exp``). It then resolves
    the issuer DID to its signing key and verifies the ES256/ES256K signature over
    ``header.payload``.

    If the DID document's handle resolves back to the issuer, the handle counts as
    verified and the issuer's profile is fetched best-effort to build a friendlier
    username and an avatar URL.

    Returns:
        User for the issuer DID.

    Raises:
        HTTPException(400): for a malformed token, bad signature, unresolvable issuer
            DID, or undecodable signing key.
    """
    auth = bearer.credentials.split('.')
    if len(auth) != 3:
        raise HTTPException(status_code=400, detail='Bad Bearer token')
    try:
        header, payload, signature = [_b64url_decode(x) for x in auth]
    except ValueError:  # binascii.Error subclasses ValueError; non-ASCII input raises ValueError
        raise HTTPException(status_code=400, detail='Bad JWT: could not base64 decode')
    try:
        header = JWTHeader.model_validate_json(header)
    except ValidationError:
        raise HTTPException(status_code=400, detail='Bad JWT header')
    try:
        payload = JWTPayload.model_validate_json(payload)
    except ValidationError:
        raise HTTPException(status_code=400, detail='Bad JWT payload')
    # ES256/ES256K JWS signatures are raw r||s, 32 bytes each. Rejecting other lengths
    # here also avoids resolving the DID for an obviously bad token.
    if len(signature) != 64:
        raise HTTPException(status_code=400, detail='Bad JWT signature')

    try:
        multikey, handle, pds = await get_signing_key_handle_pds(payload.iss)
    except InvalidDIDException:
        raise HTTPException(status_code=400, detail='Bad iss did')
    signing_key = _load_signing_key(header.alg, multikey)
    try:
        signing_key.verify(
            signature=encode_dss_signature(
                int.from_bytes(signature[:32], 'big'),
                int.from_bytes(signature[32:], 'big'),
            ),
            data='.'.join(auth[:2]).encode(),
            signature_algorithm=ec.ECDSA(hashes.SHA256()),
        )
    except InvalidSignature:
        raise HTTPException(status_code=400, detail='Bad JWT signature')

    # The handle is only trusted if it resolves back to the issuer DID. Failure to
    # resolve is not an auth error; the user is just reported as unverified.
    verified_handle = False
    if handle:
        try:
            verified_handle = await resolve_handle(handle) == payload.iss
        except CannotResolveHandleException:
            verified_handle = False

    avatar_url = None
    display_name = None
    if pds and verified_handle:
        profile = await _fetch_profile(pds, payload.iss)
        avatar = profile.get('avatar')
        ref = avatar.get('ref') if isinstance(avatar, dict) else None
        avatar_blob = ref.get('$link') if isinstance(ref, dict) else None
        if avatar_blob:
            avatar_url = f'{pds}/xrpc/com.atproto.sync.getBlob?did={payload.iss}&cid={avatar_blob}'
        display_name = profile.get('displayName')

    if display_name:
        username = f'{display_name} ({handle})'
    elif verified_handle:
        username = handle
    else:
        username = f'invalid handle: {handle} ({payload.iss})'
    return User(did=payload.iss, username=username, avatar_url=avatar_url)
class StrongRef(BaseModel):
    """createReport subject that points at a specific version of a record."""

    type: Literal['com.atproto.repo.strongRef'] = Field(alias='$type')
    uri: ATURI
    # Validated with the same CID pattern as Label.cid, so an unsupported CID is
    # rejected during request validation instead of failing later when the Label
    # is built.
    cid: CID


class RepoRef(BaseModel):
    """createReport subject that points at a whole account."""

    type: Literal['com.atproto.admin.defs#repoRef'] = Field(alias='$type')
    did: DID


class Report(BaseModel):
    """Request body of com.atproto.moderation.createReport."""

    # Dots are escaped so they match literally instead of any character.
    reason_type: Annotated[str, Field(alias='reasonType', pattern=r'^com\.atproto\.moderation\.defs#reason(Spam|Violation|Misleading|Sexual|Rude|Other|Appeal)$')]
    subject: Union[StrongRef, RepoRef]
    # Clients may omit the reason entirely; it then defaults to None.
    reason: Union[str, None] = None
@app.post('/xrpc/com.atproto.moderation.createReport')
async def create_report(current_user: Annotated[User, Depends(get_current_user)], report: Report):
    """Turn an owner-submitted moderation report into a signed label.

    Only OWNER_DID may call this. The report's ``reason`` becomes the label value. A
    leading '!' emits a negation label (``neg=True``) for the rest of the value. The
    subject becomes the label's ``uri``, plus ``cid`` for record references.

    Returns:
        A createReport-shaped JSON response on success.
        401 for callers other than OWNER_DID.
        400 when the label value is empty or fails Label validation.
    """
    if current_user.did != OWNER_DID:
        return JSONResponse({"message":"unauthorized"}, 401)

    reason = report.reason or ''
    neg = None
    if reason.startswith('!'):
        neg = True
        reason = reason[1:]
    if not reason:
        return JSONResponse({"message": "missing label value"}, 400)
    # The response echoes the value actually used (without the '!' prefix).
    report.reason = reason

    if isinstance(report.subject, StrongRef):
        uri, cid = report.subject.uri, report.subject.cid
    else:  # RepoRef: the label applies to the whole account
        uri, cid = report.subject.did, None

    now = timestamp()
    try:
        label = Label(ver=1, src=LABELER_DID, uri=uri, cid=cid, val=reason, neg=neg, cts=now)
    except ValidationError as e:
        return JSONResponse({"message": "invalid label", "detail": str(e)}, 400)
    await new_labels([label])

    return JSONResponse(content={
        # NOTE(review): read from the shared app.cur. A label inserted by a concurrent
        # request while new_labels was awaiting could be reported here instead.
        "id": app.cur.lastrowid,
        "reasonType": report.reason_type,
        "reason": report.reason,
        "subject": report.subject.model_dump(by_alias=True),
        "reportedBy": current_user.did,
        "createdAt": now,
    })
class CannotResolveHandleException(Exception):
pass
async def resolve_handle(handle):
res = await client.get(f'https://1.1.1.1/dns-query?name=_atproto.{handle}&type=TXT',
headers={'Accept': 'application/dns-json'})
res = res.json()
if 'Answer' in res and len(res['Answer']) > 0:
data = res['Answer'][0]['data']
assert data.startswith('"did=did:') and data.endswith('"')
return data[5:-1]
try:
res = await client.get(f'https://{handle}/.well-known/atproto-did')
except Exception as e:
print(e)
raise CannotResolveHandleException
assert res.text.startswith('did:')
return res.text.strip()
class InvalidDIDException(Exception):
pass
async def get_signing_key_handle_pds(did):
if did.startswith('did:plc:'):
doc = (await client.get(f'https://plc.directory/{did}')).json()
elif did.startswith('did:web:'):
doc = (await client.get(f'https://{did[8:]}/.well-known/did.json')).json()
else:
raise InvalidDIDException
handle = None
for aka in doc.get('alsoKnownAs', []):
if aka.startswith('at://'):
handle = aka[5:]
break
pds = None
for service in doc.get('service', []):
if service.get('id') == '#atproto_pds' and service.get('type') == 'AtprotoPersonalDataServer':
pds = service.get('serviceEndpoint')
break
signing_key = None
for verification_method in doc.get('verificationMethod', []):
if (
verification_method.get('id') == f'{did}#atproto' and
verification_method.get('type') == 'Multikey' and
verification_method.get('controller') == did
):
signing_key = verification_method.get('publicKeyMultibase')
break
if not signing_key:
raise InvalidDIDException
return signing_key, handle, pds
# WebSockets currently subscribed to subscribeLabels; new_labels sends each new event to them.
connections = []


@app.websocket("/xrpc/com.atproto.label.subscribeLabels")
async def websocket_endpoint(websocket: WebSocket, cursor: int | None=None):
    """Serve com.atproto.label.subscribeLabels.

    With ``cursor``, first replays every stored event whose seq is greater than
    ``cursor``. A cursor beyond the newest stored seq (0 when nothing is stored) gets
    a FutureCursor error frame and the socket is closed.

    The socket then stays registered in ``connections`` to receive live events until
    the client disconnects. Frames sent by the client are ignored. The socket is
    always unregistered when this handler exits.
    """
    await websocket.accept()
    # Registered before the replay so events committed meanwhile are not missed.
    # NOTE(review): such live events may interleave with, or duplicate, replayed ones.
    connections.append(websocket)
    try:
        if cursor is not None:
            # Dedicated cursors: the shared app.cur is re-executed by concurrent inserts,
            # which would clobber a result set still being iterated here.
            async with app.con.execute('select coalesce(max(seq), 0) from label') as cur:
                (latest,) = await cur.fetchone()
            if cursor > latest:
                await websocket.send_bytes(ERROR_MSG)
                await websocket.close()
                return
            async with app.con.execute(
                'select seq, data from label where seq > ? order by seq asc', (cursor,)
            ) as cur:
                while rows := await cur.fetchmany():
                    for seq, data in rows:
                        # DAG-CBOR orders 'labels' after 'seq', so dropping the final byte
                        # (the encoded 0) leaves the map ready for the stored labels array.
                        await websocket.send_bytes(
                            LABELS_HEADER
                            + cbrrr.encode_dag_cbor({'seq': seq, 'labels': 0})[:-1]
                            + data
                        )
        # Keep the connection open until the client leaves. receive() tolerates both
        # text and binary frames, unlike receive_bytes().
        while True:
            message = await websocket.receive()
            if message['type'] == 'websocket.disconnect':
                break
    except WebSocketDisconnect:
        pass  # client went away during the replay; cleanup happens below
    finally:
        if websocket in connections:
            connections.remove(websocket)
async def new_labels(labels):
out = bytearray()
label_count = len(labels)
if label_count <= 0x17:
out.append(0x80 + label_count)
elif label_count <= 0xff:
out.append(0x98)
out += label_count.to_bytes(1, 'little')
elif label_count <= 0xffff:
out.append(0x99)
out += label_count.to_bytes(2, 'little')
else:
raise Exception('please up to 65535 labels at once')
for label in labels:
if not isinstance(label, Label):
label = Label(**label)
label = {a:b for a, b in dict(label).items() if b is not None}
cbor = cbrrr.encode_dag_cbor(label)
r,s = decode_dss_signature(priv_key.sign(cbor, ec.ECDSA(hashes.SHA256())))
SECP256K1_N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
if s > SECP256K1_N // 2:
s = SECP256K1_N - s
signature = r.to_bytes(32) + s.to_bytes(32)
label = label | {'sig': signature}
cbor = cbrrr.encode_dag_cbor(label)
out += cbor
await app.cur.execute('insert into label (data) values (?)', (out,))
await app.con.commit()
for x in connections:
await x.send_bytes(LABELS_HEADER + cbrrr.encode_dag_cbor({'seq': app.cur.lastrowid, 'labels':0})[:-1] + out)
def timestamp():
return datetime.datetime.now(datetime.UTC).isoformat()[:-9]+'Z'
if __name__ == '__main__':
    # Launch under uvicorn's auto-reloader. Reload mode needs the "module:attribute"
    # import string rather than the app object itself.
    uvicorn.run('labeler:app', reload=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment