Skip to content

Instantly share code, notes, and snippets.

@cfm
Created March 22, 2023 23:41
Show Gist options
  • Save cfm/8303223a496791c559595b1bdcdc373a to your computer and use it in GitHub Desktop.
diff --git a/alembic/versions/f394059c0898_add_reply_state.py b/alembic/versions/f394059c0898_add_reply_state.py
new file mode 100644
index 00000000..41fb7472
--- /dev/null
+++ b/alembic/versions/f394059c0898_add_reply_state.py
@@ -0,0 +1,33 @@
+"""add Reply.state
+
+Revision ID: f394059c0898
+Revises: 414627c04463
+Create Date: 2023-03-07 17:22:39.406291
+
+"""
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "f394059c0898"
+down_revision = "414627c04463"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # SQLite refuses to add a NOT NULL column without a non-NULL default,
+    # and existing replies need a value anyway, so backfill the state
+    # machine's initial state. NOTE(review): confirm "Pending" suits old rows.
+    op.add_column(
+        "replies",
+        sa.Column("state", sa.String(length=100), nullable=False, server_default="Pending"),
+    )
+
+
+def downgrade():
+    # SQLite before 3.35 has no ALTER TABLE ... DROP COLUMN; batch mode
+    # copies the table without the column instead.
+    with op.batch_alter_table("replies", schema=None) as batch_op:
+        batch_op.drop_column("state")
diff --git a/requirements/requirements.in b/requirements/requirements.in
index d843971a..6250f389 100644
--- a/requirements/requirements.in
+++ b/requirements/requirements.in
@@ -14,3 +14,4 @@ six==1.11.0
sqlalchemy==1.3.3
urllib3>=1.26.5
jinja2==3.0.2 # per freedomofpress/securedrop#4829953
+python-statemachine # TODO: pin an exact version (as with sqlalchemy==1.3.3) before merging
diff --git a/securedrop_client/db.py b/securedrop_client/db.py
index 6e53de41..bcc68480 100644
--- a/securedrop_client/db.py
+++ b/securedrop_client/db.py
@@ -21,6 +21,9 @@ from sqlalchemy import (
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, scoped_session, sessionmaker
+from statemachine.mixins import MachineMixin
+
+from securedrop_client.statemachines import ReplyStateMachine # noqa: F401
convention = {
"ix": "ix_%(column_0_label)s",
@@ -424,13 +427,19 @@ class File(Base):
return False
-class Reply(Base):
+class Reply(MachineMixin, Base):
__tablename__ = "replies"
__table_args__ = (
UniqueConstraint("source_id", "file_counter", name="uq_messages_source_id_file_counter"),
)
+ state_machine_name = "ReplyStateMachine"  # presumably resolved by name, hence the side-effect import above
+ state_machine_attr = "sm"  # NOTE(review): SQLAlchemy skips __init__ for rows loaded from the DB -- verify loaded Replies get `sm` too
+ state_field_name = "state"  # the column below holds the current state
+
+ state = Column(String(100), nullable=False)  # NOTE(review): NOT NULL with no default -- existing rows need a backfill
+
id = Column(Integer, primary_key=True)
uuid = Column(String(36), unique=True, nullable=False)
source_id = Column(Integer, ForeignKey("sources.id"), nullable=False)
diff --git a/securedrop_client/logic.py b/securedrop_client/logic.py
index 5de96e3b..5ec336c1 100644
--- a/securedrop_client/logic.py
+++ b/securedrop_client/logic.py
@@ -61,6 +61,7 @@ from securedrop_client.api_jobs.uploads import (
SendReplyJobTimeoutError,
)
from securedrop_client.crypto import GpgHelper
+from securedrop_client.db import Reply
from securedrop_client.queue import ApiJobQueue
from securedrop_client.sync import ApiSync
from securedrop_client.utils import check_dir_permissions
@@ -378,6 +379,14 @@ class Controller(QObject):
self.session_maker = session_maker
self.session = session_maker()
+ r1 = Reply(filename="1-foo")  # NOTE(review): prototype demo -- throwaway Replies and print() during Controller setup; remove before merging
+ r2 = Reply(filename="2-foo")
+ print(r1.sm, r1.state)
+ print(r2.sm, r2.state)
+ r1.sm.send("send_pending")  # expected: only r1 changes state, r2 keeps its own
+ print(r1.sm, r1.state)
+ print(r2.sm, r2.state)
+
# Queue that handles running API job
self.api_job_queue = ApiJobQueue(
self.api, self.session_maker, self.main_queue_thread, self.file_download_queue_thread
diff --git a/securedrop_client/statemachines/__init__.py b/securedrop_client/statemachines/__init__.py
new file mode 100644
index 00000000..3b106c4a
--- /dev/null
+++ b/securedrop_client/statemachines/__init__.py
@@ -0,0 +1,5 @@
+"""State machine definitions, re-exported for ``securedrop_client.db``."""
+
+from .reply import ReplyStateMachine
+
+__all__ = ["ReplyStateMachine"]
diff --git a/securedrop_client/statemachines/reply.py b/securedrop_client/statemachines/reply.py
new file mode 100644
index 00000000..5185abe7
--- /dev/null
+++ b/securedrop_client/statemachines/reply.py
@@ -0,0 +1,8 @@
+from statemachine import State, StateMachine
+
+
+class ReplyStateMachine(StateMachine):  # lifecycle of a Reply, attached via MachineMixin in db.py
+ Pending = State("Pending", initial=True)  # every new machine starts here
+ SendPending = State("SendPending")
+
+ send_pending = Pending.to(SendPending)  # event: Pending -> SendPending, e.g. sm.send("send_pending")
diff --git a/alembic/versions/f394059c0898_add_reply_state.py b/alembic/versions/f394059c0898_add_reply_state.py
new file mode 100644
index 00000000..41fb7472
--- /dev/null
+++ b/alembic/versions/f394059c0898_add_reply_state.py
@@ -0,0 +1,33 @@
+"""add Reply.state
+
+Revision ID: f394059c0898
+Revises: 414627c04463
+Create Date: 2023-03-07 17:22:39.406291
+
+"""
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "f394059c0898"
+down_revision = "414627c04463"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # SQLite refuses to add a NOT NULL column without a non-NULL default,
+    # and existing replies need a value anyway, so backfill the state
+    # machine's initial state. NOTE(review): confirm "Pending" suits old rows.
+    op.add_column(
+        "replies",
+        sa.Column("state", sa.String(length=100), nullable=False, server_default="Pending"),
+    )
+
+
+def downgrade():
+    # SQLite before 3.35 has no ALTER TABLE ... DROP COLUMN; batch mode
+    # copies the table without the column instead.
+    with op.batch_alter_table("replies", schema=None) as batch_op:
+        batch_op.drop_column("state")
diff --git a/requirements/requirements.in b/requirements/requirements.in
index d843971a..125c2339 100644
--- a/requirements/requirements.in
+++ b/requirements/requirements.in
@@ -14,3 +14,4 @@ six==1.11.0
sqlalchemy==1.3.3
urllib3>=1.26.5
jinja2==3.0.2 # per freedomofpress/securedrop#4829953
+transitions # TODO: pin an exact version (as with sqlalchemy==1.3.3) before merging
diff --git a/securedrop_client/db.py b/securedrop_client/db.py
index 6e53de41..c7ce389f 100644
--- a/securedrop_client/db.py
+++ b/securedrop_client/db.py
@@ -5,6 +5,7 @@ from pathlib import Path
from typing import Any, Dict, List, Union # noqa: F401
from uuid import uuid4
+import sqlalchemy as sa
from sqlalchemy import (
Boolean,
CheckConstraint,
@@ -22,6 +23,8 @@ from sqlalchemy import (
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, scoped_session, sessionmaker
+from securedrop_client.statemachines import StateConfig, StateMixin
+
convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
@@ -424,13 +427,28 @@ class File(Base):
return False
-class Reply(Base):
+PENDING = "Pending"  # state values persisted in Reply.state
+SEND_PENDING = "SendPending"
+
+
+class Reply(Base, StateMixin):  # NOTE(review): StateMixin is a @dataclass -- check Reply does not inherit its generated __eq__/__hash__
__tablename__ = "replies"
__table_args__ = (
UniqueConstraint("source_id", "file_counter", name="uq_messages_source_id_file_counter"),
)
+ state_config = StateConfig(
+ initial=PENDING,
+ states=[PENDING, SEND_PENDING],
+ transitions=[
+ ["send_pending", PENDING, SEND_PENDING],  # [trigger, source, dest]
+ ],
+ status_attribute="state",  # the `state` column below holds the current state
+ )
+
+ state = Column(String(100), nullable=False)
+
id = Column(Integer, primary_key=True)
uuid = Column(String(36), unique=True, nullable=False)
source_id = Column(Integer, ForeignKey("sources.id"), nullable=False)
@@ -536,6 +554,9 @@ class Reply(Base):
usernames[seen_reply.journalist.username] = seen_reply.journalist
return usernames
+sa.event.listen(Reply, "init", Reply.init_state_machine)  # newly constructed Replies
+sa.event.listen(Reply, "load", Reply.init_state_machine)  # Replies loaded from the DB, which bypass __init__
+
class DownloadErrorCodes(Enum):
"""
diff --git a/securedrop_client/logic.py b/securedrop_client/logic.py
index 5de96e3b..c8024786 100644
--- a/securedrop_client/logic.py
+++ b/securedrop_client/logic.py
@@ -61,6 +61,7 @@ from securedrop_client.api_jobs.uploads import (
SendReplyJobTimeoutError,
)
from securedrop_client.crypto import GpgHelper
+from securedrop_client.db import Reply
from securedrop_client.queue import ApiJobQueue
from securedrop_client.sync import ApiSync
from securedrop_client.utils import check_dir_permissions
@@ -378,6 +379,14 @@ class Controller(QObject):
self.session_maker = session_maker
self.session = session_maker()
+ r1 = Reply(filename="1-foo")  # NOTE(review): prototype demo -- throwaway Replies and print() during Controller setup; remove before merging
+ r2 = Reply(filename="2-foo")
+ print(r1.machine, r1.state)
+ print(r2.machine, r2.state)
+ r1.send_pending()  # expected: only r1 changes state, r2 keeps its own
+ print(r1.machine, r1.state)
+ print(r2.machine, r2.state)
+
# Queue that handles running API job
self.api_job_queue = ApiJobQueue(
self.api, self.session_maker, self.main_queue_thread, self.file_download_queue_thread
diff --git a/securedrop_client/statemachines.py b/securedrop_client/statemachines.py
new file mode 100644
index 00000000..e9c03c74
--- /dev/null
+++ b/securedrop_client/statemachines.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2021-present bigbag/sqlalchemy-state-machine authors and
+# contributors.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import typing as t
+from dataclasses import dataclass
+
+from transitions import Machine
+
+
+@dataclass
+class StateConfig:
+    """Declarative description of a model's state machine."""
+
+    initial: t.Any
+    states: t.List[t.Any]
+    transitions: t.Optional[t.List[t.List[t.Any]]]
+    state_attribute: t.Optional[str] = "state"
+    status_attribute: t.Optional[str] = "status"
+    machine_name: t.Optional[str] = "machine"
+    after_state_change: t.Optional[t.Any] = None
+
+
+class StateMixin:
+    """Give each model instance its own ``transitions.Machine``.
+
+    Not a dataclass on purpose: ``@dataclass`` would generate an ``__eq__``
+    comparing only the class-level ``state_config`` (so every instance of a
+    model would equal every other) and set ``__hash__`` to ``None`` (so
+    instances would be unhashable); models using this mixin would inherit
+    both unless they define their own.
+
+    Models set ``state_config`` and register :meth:`init_state_machine` as
+    a listener for SQLAlchemy's ``init`` and ``load`` events.
+    """
+
+    state_config: t.ClassVar[StateConfig]
+
+    @property
+    def state(self):
+        """Current state, proxied to the ``status_attribute`` attribute."""
+        return getattr(self, self.state_config.status_attribute)
+
+    @state.setter
+    def state(self, value):
+        setattr(self, self.state_config.status_attribute, value)
+
+    @classmethod
+    def init_state_machine(cls, obj, *args, **kwargs):
+        """Build a ``Machine`` for ``obj`` and store it as ``machine_name``.
+
+        Meant to be an SQLAlchemy event listener; the event's extra
+        arguments are ignored. A state already stored on ``obj`` (e.g. a
+        row loaded from the database) takes precedence over ``initial``.
+        """
+        config = cls.state_config
+        machine = Machine(
+            model=obj,
+            states=config.states,
+            transitions=config.transitions,
+            initial=getattr(obj, config.status_attribute, None) or config.initial,
+            model_attribute=config.state_attribute,
+            after_state_change=config.after_state_change,
+        )
+        setattr(obj, config.machine_name, machine)
diff --git a/alembic/versions/f394059c0898_add_reply_state.py b/alembic/versions/f394059c0898_add_reply_state.py
new file mode 100644
index 00000000..41fb7472
--- /dev/null
+++ b/alembic/versions/f394059c0898_add_reply_state.py
@@ -0,0 +1,33 @@
+"""add Reply.state
+
+Revision ID: f394059c0898
+Revises: 414627c04463
+Create Date: 2023-03-07 17:22:39.406291
+
+"""
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "f394059c0898"
+down_revision = "414627c04463"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # SQLite refuses to add a NOT NULL column without a non-NULL default,
+    # and existing replies need a value anyway, so backfill the state
+    # machine's initial state. NOTE(review): confirm "Pending" suits old rows.
+    op.add_column(
+        "replies",
+        sa.Column("state", sa.String(length=100), nullable=False, server_default="Pending"),
+    )
+
+
+def downgrade():
+    # SQLite before 3.35 has no ALTER TABLE ... DROP COLUMN; batch mode
+    # copies the table without the column instead.
+    with op.batch_alter_table("replies", schema=None) as batch_op:
+        batch_op.drop_column("state")
diff --git a/requirements/requirements.in b/requirements/requirements.in
index d843971a..125c2339 100644
--- a/requirements/requirements.in
+++ b/requirements/requirements.in
@@ -14,3 +14,4 @@ six==1.11.0
sqlalchemy==1.3.3
urllib3>=1.26.5
jinja2==3.0.2 # per freedomofpress/securedrop#4829953
+transitions # TODO: pin an exact version (as with sqlalchemy==1.3.3) before merging
diff --git a/securedrop_client/db.py b/securedrop_client/db.py
index 6e53de41..77dcd1ac 100644
--- a/securedrop_client/db.py
+++ b/securedrop_client/db.py
@@ -21,6 +21,9 @@ from sqlalchemy import (
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, scoped_session, sessionmaker
+from transitions import Machine
+
+from securedrop_client.statemachine import StateMachineMixin
convention = {
"ix": "ix_%(column_0_label)s",
@@ -424,13 +427,28 @@ class File(Base):
return False
-class Reply(Base):
+PENDING = "Pending"
+SEND_PENDING = "SendPending"
+
+
+class Reply(Base, StateMachineMixin):
__tablename__ = "replies"
__table_args__ = (
UniqueConstraint("source_id", "file_counter", name="uq_messages_source_id_file_counter"),
)
+ machine = Machine(  # one class-level machine shared by every Reply; each Reply's own state lives in the `state` column
+ model=None,  # no bound model: StateMachineMixin passes the Reply itself to each event's trigger()
+ initial=None,  # new Replies get PENDING in __init__ instead
+ states=[PENDING, SEND_PENDING],
+ transitions=[
+ ["send_pending", PENDING, SEND_PENDING],
+ ],
+ )
+
+ state = Column(String(100), nullable=False)
+
id = Column(Integer, primary_key=True)
uuid = Column(String(36), unique=True, nullable=False)
source_id = Column(Integer, ForeignKey("sources.id"), nullable=False)
@@ -481,6 +499,7 @@ class Reply(Base):
raise TypeError("Cannot manually set file_counter")
filename = kwargs["filename"]
kwargs["file_counter"] = int(filename.split("-")[0])
+ self.state = PENDING  # initial state for new Replies; rows loaded from the DB skip __init__ and keep their stored value
super().__init__(**kwargs)
def __str__(self) -> str:
diff --git a/securedrop_client/logic.py b/securedrop_client/logic.py
index 5de96e3b..c8024786 100644
--- a/securedrop_client/logic.py
+++ b/securedrop_client/logic.py
@@ -61,6 +61,7 @@ from securedrop_client.api_jobs.uploads import (
SendReplyJobTimeoutError,
)
from securedrop_client.crypto import GpgHelper
+from securedrop_client.db import Reply
from securedrop_client.queue import ApiJobQueue
from securedrop_client.sync import ApiSync
from securedrop_client.utils import check_dir_permissions
@@ -378,6 +379,14 @@ class Controller(QObject):
self.session_maker = session_maker
self.session = session_maker()
+ r1 = Reply(filename="1-foo")  # NOTE(review): prototype demo -- throwaway Replies and print() during Controller setup; remove before merging
+ r2 = Reply(filename="2-foo")
+ print(r1.machine, r1.state)
+ print(r2.machine, r2.state)
+ r1.send_pending()  # expected: only r1 changes state even though the machine is shared
+ print(r1.machine, r1.state)
+ print(r2.machine, r2.state)
+
# Queue that handles running API job
self.api_job_queue = ApiJobQueue(
self.api, self.session_maker, self.main_queue_thread, self.file_download_queue_thread
diff --git a/securedrop_client/statemachine.py b/securedrop_client/statemachine.py
new file mode 100644
index 00000000..4f5ba023
--- /dev/null
+++ b/securedrop_client/statemachine.py
@@ -0,0 +1,21 @@
+from functools import partial
+
+
+class StateMachineMixin:
+    """Expose the events of a class-level ``machine`` as per-instance methods.
+
+    ``obj.send_pending()`` calls ``machine.events["send_pending"].trigger(obj)``,
+    so one shared ``transitions.Machine`` can drive many model objects.
+    """
+
+    def __getattr__(self, item):
+        # Python calls __getattr__ only after normal lookup has failed, so
+        # ordinary attribute access (e.g. SQLAlchemy columns) costs nothing
+        # extra -- unlike overriding __getattribute__, which runs every time.
+        # `machine` itself is skipped: if it were missing, looking it up here
+        # would recurse forever.
+        if item != "machine":
+            event = self.machine.events.get(item)
+            if event is not None:
+                return partial(event.trigger, self)
+        raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment