-
-
Save timwis/50fe4e90e052475e01d357e9b6142374 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import TIMESTAMP, Column, ForeignKey, Integer, String
from sqlalchemy.orm import declarative_base, relationship

# Declarative base shared by every model in this module; each subclass
# registers its table on ``Base.metadata``.
Base = declarative_base()
class User(Base):
    """ORM mapping for the ``users`` table in the ``auth`` schema.

    Owns many ``Connection`` rows and many ``JobDefinition`` rows, both linked
    back through their ``user_id`` foreign keys.
    """

    # A dotted ``__tablename__`` such as "auth.users" is NOT schema-qualified:
    # SQLAlchemy treats it as one identifier and quotes it ("auth.users"),
    # i.e. a table literally named ``auth.users`` in the default schema.
    # The schema has to be supplied separately. The metadata key is still
    # "auth.users", so ForeignKey("auth.users.id") elsewhere keeps resolving.
    __tablename__ = "users"
    __table_args__ = {"schema": "auth"}

    # NOTE(review): confirm auth.users.id is really an integer column in the DB.
    id = Column(Integer, primary_key=True)
    connections = relationship("Connection", back_populates="user")
    job_definitions = relationship("JobDefinition", back_populates="user")

    def __repr__(self):
        return f"User(id={self.id!r})"
class Connection(Base):
    """ORM mapping for rows of ``decrypted_connections``.

    Each connection belongs to one ``User`` and has many ``Account`` rows.

    Only the token attributes ``decrypted_access_token`` and
    ``decrypted_refresh_token`` are mapped. Assigning any other name on an
    instance (e.g. ``connection.access_token = ...``) just creates a plain
    Python attribute that the session never writes to the database.
    """

    __tablename__ = "decrypted_connections"

    id = Column(Integer, primary_key=True)
    created_at = Column(TIMESTAMP, nullable=False)
    user_id = Column(Integer, ForeignKey("auth.users.id"), nullable=False)
    decrypted_access_token = Column(String, nullable=False)
    # Nullable: code comparing expires_at with "now" must handle None.
    expires_at = Column(TIMESTAMP)
    decrypted_refresh_token = Column(String, nullable=False)
    user = relationship("User", back_populates="connections")
    accounts = relationship("Account", back_populates="connection")

    def __repr__(self):
        # !r for consistency with the other models' __repr__ methods.
        return f"Connection(id={self.id!r})"
class Account(Base):
    """ORM mapping for the ``accounts`` table.

    Every account belongs to exactly one ``Connection`` (non-null
    ``connection_id``) and may serve as the card account of many
    ``JobDefinition`` rows.
    """

    __tablename__ = "accounts"

    id = Column(Integer, primary_key=True)
    created_at = Column(TIMESTAMP, nullable=False)
    account_type = Column(String, nullable=False)
    truelayer_account_id = Column(String, nullable=False)
    truelayer_display_name = Column(String, nullable=False)
    connection_id = Column(
        Integer,
        ForeignKey("decrypted_connections.id"),
        nullable=False,
    )

    connection = relationship("Connection", back_populates="accounts")
    job_definitions = relationship("JobDefinition", back_populates="card_account")

    def __repr__(self):
        return "Account(id=%r, account_type=%r)" % (self.id, self.account_type)
class JobDefinition(Base):
    """ORM mapping for the ``job_definitions`` table.

    Links a ``User`` (``user_id``) to the ``Account`` used as its card account
    (``card_account_id``).
    """

    __tablename__ = "job_definitions"

    id = Column(Integer, primary_key=True)
    created_at = Column(TIMESTAMP, nullable=False)
    user_id = Column(Integer, ForeignKey("auth.users.id"), nullable=False)
    card_account_id = Column(Integer, ForeignKey("accounts.id"), nullable=False)
    # Must be an assignment: the previous ``last_synced_at: Column(...)`` was
    # only a type annotation, so the column was never mapped.
    last_synced_at = Column(TIMESTAMP)
    user = relationship("User", back_populates="job_definitions")
    card_account = relationship("Account", back_populates="job_definitions")

    def __repr__(self):
        return f"JobDefinition(id={self.id!r})"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime
from operator import itemgetter

import httpx
from prefect import flow, task
from prefect_sqlalchemy import SqlAlchemyConnector
from sqlalchemy.orm import Session, joinedload

from true_layer import TrueLayer
from models import JobDefinition, Account, Connection

# Prefect blocks are loaded once, when this module is imported, and are then
# shared by the tasks defined below.
database_block = SqlAlchemyConnector.load("db")
truelayer_block = TrueLayer.load("truelayer")
@task()
def get_job_definitions():
    """Return all JobDefinition rows with card_account -> connection eager-loaded.

    The session is closed before the caller reads
    ``jd.card_account.connection``, so those relationships are loaded up front
    with joined eager loading rather than lazily.
    """
    eager_path = joinedload(JobDefinition.card_account).joinedload(Account.connection)
    engine = database_block.get_engine()
    with Session(engine) as session:
        query = session.query(JobDefinition).options(eager_path)
        print(query)  # debug output: the generated SQL
        return query.all()
@task
def renew_token(connection: Connection):
    """Renew ``connection``'s tokens through TrueLayer and persist them.

    Args:
        connection: a Connection instance, typically loaded by another (already
            closed) session, e.g. by ``get_job_definitions``. It is re-attached
            to a fresh session here and updated in place.

    The instance keeps the committed values after this returns
    (``expire_on_commit=False``), so callers can keep reading it.
    """
    refresh_token = connection.decrypted_refresh_token
    response_data = truelayer_block.get_renewed_token(refresh_token)
    # Deliberately not printed/logged: the response carries live credentials.
    access_token, expires_in, refresh_token = itemgetter(
        "access_token", "expires_in", "refresh_token"
    )(response_data)
    expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
        seconds=expires_in
    )

    # expire_on_commit=False: otherwise commit() expires every attribute and the
    # instance, detached again once the session closes, raises on the next read.
    with Session(database_block.get_engine(), expire_on_commit=False) as session:
        # The instance came from a different, closed session, so it is detached.
        # Without add() this session tracks nothing and commit() writes nothing.
        session.add(connection)
        # Assign the *mapped* attributes. The old ``connection.access_token`` /
        # ``connection.refresh_token`` were plain Python attributes and were
        # never persisted.
        # NOTE(review): if decrypted_connections is a view whose decrypted_*
        # columns are computed, the UPDATE must instead target the underlying
        # writable token columns. Map those on Connection and assign them here.
        connection.decrypted_access_token = access_token
        connection.decrypted_refresh_token = refresh_token
        connection.expires_at = expires_at
        session.commit()
@flow()
def sync():
    """Renew the tokens of every job definition's connection that has expired.

    Returns:
        A list of results. Nothing is appended yet, so it is currently always
        empty; it is kept so the flow's return shape stays a list.
    """
    job_definitions = get_job_definitions()
    results = []
    for jd in job_definitions:
        connection = jd.card_account.connection
        expires_at = connection.expires_at
        # expires_at is a nullable column. Treat a missing expiry as expired
        # instead of crashing on ``None <= datetime``.
        # NOTE(review): this compares against an aware UTC datetime; if the DB
        # column is ``timestamp without time zone``, expires_at comes back naive
        # and the comparison raises TypeError. Confirm the column type.
        if expires_at is None or expires_at <= datetime.datetime.now(datetime.timezone.utc):
            renew_token(connection)
        print(connection)
    return results
if __name__ == "__main__":
    # Run the flow directly when executed as a script and show its return value.
    flow_result = sync()
    print(flow_result)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.