Skip to content

Instantly share code, notes, and snippets.

@elandau
Created October 30, 2014 23:22
Show Gist options
  • Save elandau/38a28ffab5ad6566f166 to your computer and use it in GitHub Desktop.
Save elandau/38a28ffab5ad6566f166 to your computer and use it in GitHub Desktop.
Rx-based state machine
package com.netflix.experiments.rx;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Observable.OnSubscribe;
import rx.Subscriber;
import rx.functions.Action1;
import rx.functions.Action2;
import rx.subjects.PublishSubject;
/**
 * A minimal event-driven state machine built on RxJava.
 *
 * <p>Events are pushed in through {@link #call(Object)} and applied once {@link #connect()} has
 * been subscribed to. Each event is looked up in the current state's transition table. If a
 * target state is registered, the current state's exit callback runs, the machine moves to the
 * target, and the target's enter callback runs. Events with no registered transition are logged
 * and otherwise ignored.
 *
 * @param <T> type of the context object handed to every enter/exit callback
 * @param <E> type of the events that drive transitions
 */
public class StateMachine<T, E> implements Action1<E> {
    private static final Logger LOG = LoggerFactory.getLogger(StateMachine.class);

    /**
     * A named state with optional enter/exit callbacks and a table mapping events to next states.
     * The configuration methods return {@code this} so a state can be built fluently.
     *
     * @param <T> type of the context object handed to the callbacks
     * @param <E> type of the events that drive transitions
     */
    public static class State<T, E> {
        private final String name;
        private Action2<T, State<T, E>> enter;
        private Action2<T, State<T, E>> exit;
        private final Map<E, State<T, E>> transitions = new HashMap<E, State<T, E>>();

        public State(String name) {
            this.name = name;
        }

        /** Sets the callback run when the machine enters this state, replacing any previous one. */
        public State<T, E> onEnter(Action2<T, State<T, E>> func) {
            this.enter = func;
            return this;
        }

        /** Sets the callback run when the machine leaves this state, replacing any previous one. */
        public State<T, E> onExit(Action2<T, State<T, E>> func) {
            this.exit = func;
            return this;
        }

        /** Runs the enter callback, if one was registered. */
        public void enter(T context) {
            // Callbacks are optional: without this check a state built without onEnter()
            // throws NullPointerException as soon as the machine enters it.
            if (enter != null) {
                enter.call(context, this);
            }
        }

        /** Runs the exit callback, if one was registered. */
        public void exit(T context) {
            if (exit != null) {
                exit.call(context, this);
            }
        }

        /**
         * Registers {@code state} as the target of {@code event}. A later registration for the
         * same event replaces the earlier one.
         */
        public State<T, E> transition(E event, State<T, E> state) {
            transitions.put(event, state);
            return this;
        }

        /** @return the state {@code event} leads to, or {@code null} if the event is not valid here */
        public State<T, E> next(E event) {
            return transitions.get(event);
        }

        @Override
        public String toString() {
            return name;
        }
    }

    // Written by the event listener registered in connect() and read by getState().
    private volatile State<T, E> state;
    private final T context;
    private final PublishSubject<E> events = PublishSubject.create();

    protected StateMachine(T context, State<T, E> initial) {
        this.state = initial;
        this.context = context;
    }

    /**
     * Enters the current state and starts applying events pushed via {@link #call(Object)}.
     *
     * <p>The returned observable never emits. Unsubscribing from it stops event processing.
     * Every subscription re-runs the current state's enter callback and adds another listener
     * to the shared event stream, so subscribe to it only once.
     */
    public Observable<Void> connect() {
        return Observable.create(new OnSubscribe<Void>() {
            @Override
            public void call(Subscriber<? super Void> sub) {
                state.enter(context);
                // collect() serves only as a per-event sink; the collected context is never used.
                // NOTE(review): newer RxJava 1.x releases take a Func0 state factory instead of a
                // plain value here; confirm against the RxJava version in use.
                sub.add(events.collect(context, new Action2<T, E>() {
                    @Override
                    public void call(T context, E event) {
                        final State<T, E> next = state.next(event);
                        if (next != null) {
                            state.exit(context);
                            state = next;
                            next.enter(context);
                        }
                        else {
                            LOG.info("Invalid event : " + event);
                        }
                    }
                })
                .subscribe());
            }
        });
    }

    /**
     * Pushes an event into the machine. Events pushed while {@link #connect()} has no subscriber
     * are dropped by the underlying {@code PublishSubject}.
     */
    @Override
    public void call(E event) {
        events.onNext(event);
    }

    /** @return the current state */
    public State<T, E> getState() {
        return state;
    }
}
package com.netflix.experiments.rx;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.functions.Action2;
import com.netflix.experiments.rx.StateMachine.State;
public class StateMachineTest {
private static final Logger LOG = LoggerFactory.getLogger(StateMachineTest.class);
public static enum Event {
IDLE,
CONNECT,
CONNECTED,
FAILED,
UNQUARANTINE,
REMOVE
}
public static Action2<SomeContext, State<SomeContext, Event>> log(final String text) {
return new Action2<SomeContext, State<SomeContext, Event>>() {
@Override
public void call(SomeContext t1, State<SomeContext, Event> state) {
LOG.info("" + t1 + ":" + state + ":" + text);
}
};
}
public static class SomeContext {
@Override
public String toString() {
return "Foo []";
}
}
public static State<SomeContext, Event> IDLE = new State<SomeContext, Event>("IDLE");
public static State<SomeContext, Event> CONNECTING = new State<SomeContext, Event>("CONNECTING");
public static State<SomeContext, Event> CONNECTED = new State<SomeContext, Event>("CONNECTED");
public static State<SomeContext, Event> QUARANTINED = new State<SomeContext, Event>("QUARANTINED");
public static State<SomeContext, Event> REMOVED = new State<SomeContext, Event>("REMOVED");
@BeforeClass
public static void beforeClass() {
IDLE
.onEnter(log("enter"))
.onExit(log("exit"))
.transition(Event.CONNECT, CONNECTING)
.transition(Event.REMOVE, REMOVED);
CONNECTING
.onEnter(log("enter"))
.onExit(log("exit"))
.transition(Event.CONNECTED, CONNECTED)
.transition(Event.FAILED, QUARANTINED)
.transition(Event.REMOVE, REMOVED);
CONNECTED
.onEnter(log("enter"))
.onExit(log("exit"))
.transition(Event.IDLE, IDLE)
.transition(Event.FAILED, QUARANTINED)
.transition(Event.REMOVE, REMOVED);
QUARANTINED
.onEnter(log("enter"))
.onExit(log("exit"))
.transition(Event.IDLE, IDLE)
.transition(Event.REMOVE, REMOVED);
REMOVED
.onEnter(log("enter"))
.onExit(log("exit"))
.transition(Event.CONNECT, CONNECTING);
}
@Test
public void test() {
StateMachine<SomeContext, Event> sm = new StateMachine<SomeContext, Event>(new SomeContext(), IDLE);
sm.connect().subscribe();
sm.call(Event.CONNECT);
sm.call(Event.CONNECTED);
sm.call(Event.FAILED);
sm.call(Event.REMOVE);
}
}
@eskim
Copy link

eskim commented Feb 17, 2016

Awesome, this is what I was looking for.

@EliyahuStern
Copy link

Converted to RxJava 2 and Java 8 (with log4j):

import io.reactivex.Observable;
import io.reactivex.functions.BiConsumer;
import io.reactivex.functions.Consumer;
import io.reactivex.subjects.PublishSubject;
import org.apache.log4j.Logger;

import java.util.HashMap;
import java.util.Map;

/**
 * RxJava 2 version of an event-driven state machine.
 *
 * <p>Events passed to {@link #accept(Object)} are applied once {@link #connect()} has been
 * subscribed to. An event with a registered transition runs the current state's exit callback,
 * switches state, then runs the new state's enter callback. Other events are logged and ignored.
 *
 * @param <T> type of the context object handed to every enter/exit callback
 * @param <E> type of the events that drive transitions
 */
public class RXStateMachine<T, E> implements Consumer<E> {
    private static final Logger LOG = Logger.getLogger(RXStateMachine.class);

    /**
     * A named state with optional enter/exit callbacks and a table mapping events to next states.
     * The configuration methods return {@code this} so a state can be built fluently.
     */
    public static class State<T, E> {
        private final String name;
        private BiConsumer<T, State<T, E>> enter;
        private BiConsumer<T, State<T, E>> exit;
        private final Map<E, State<T, E>> transitions = new HashMap<>();

        public State(String name) {
            this.name = name;
        }

        /** Sets the callback run when the machine enters this state, replacing any previous one. */
        public State<T, E> onEnter(BiConsumer<T, State<T, E>> func) {
            this.enter = func;
            return this;
        }

        /** Sets the callback run when the machine leaves this state, replacing any previous one. */
        public State<T, E> onExit(BiConsumer<T, State<T, E>> func) {
            this.exit = func;
            return this;
        }

        /** Runs the enter callback, if any; exceptions it throws are logged, not propagated. */
        public void enter(T context) {
            invoke(enter, "enter", context);
        }

        /** Runs the exit callback, if any; exceptions it throws are logged, not propagated. */
        public void exit(T context) {
            invoke(exit, "exit", context);
        }

        /**
         * Runs an optional callback as best-effort. A missing callback is a no-op; previously it
         * raised a NullPointerException that was then logged as a warning. A failing callback is
         * logged together with its stack trace.
         */
        private void invoke(BiConsumer<T, State<T, E>> callback, String phase, T context) {
            if (callback == null) {
                return;
            }
            try {
                callback.accept(context, this);
            } catch (Exception e) {
                // warn(Object) would print only e.toString(); pass the throwable to keep the trace.
                LOG.warn("Exception in " + phase + " callback of state " + name, e);
            }
        }

        /**
         * Registers {@code state} as the target of {@code event}. A later registration for the
         * same event replaces the earlier one.
         */
        public State<T, E> transition(E event, State<T, E> state) {
            transitions.put(event, state);
            return this;
        }

        /** @return the state {@code event} leads to, or {@code null} if the event is not valid here */
        public State<T, E> next(E event) {
            return transitions.get(event);
        }

        @Override
        public String toString() {
            return name;
        }
    }

    // Written by the event listener registered in connect() and read by getState().
    private volatile State<T, E> state;
    private final T context;
    private final PublishSubject<E> events = PublishSubject.create();

    protected RXStateMachine(T context, State<T, E> initial) {
        this.state = initial;
        this.context = context;
    }

    /**
     * Enters the current state and starts applying events pushed via {@link #accept(Object)}.
     *
     * <p>The returned observable never emits. Disposing it stops event processing. Every
     * subscription re-runs the current state's enter callback and adds another listener to the
     * shared event stream, so subscribe to it only once.
     */
    public Observable<Void> connect() {
        return Observable.create(sub -> {
            state.enter(context);

            // collect() serves only as a per-event sink; the collected context is never used.
            sub.setDisposable(events.collect(() -> context, (ctx, event) -> {
                final State<T, E> next = state.next(event);
                if (next != null) {
                    state.exit(ctx);
                    state = next;
                    next.enter(ctx);
                } else {
                    LOG.info("Invalid event : " + event);
                }
            }).subscribe());
        });
    }

    /**
     * Pushes an event into the machine. Events pushed while {@link #connect()} has no subscriber
     * are dropped by the underlying {@code PublishSubject}.
     */
    @Override
    public void accept(E event) {
        events.onNext(event);
    }

    /** @return the current state */
    public State<T, E> getState() {
        return state;
    }

}

@antonpus
Copy link

Thanks a lot, very impressive.

@aliab
Copy link

aliab commented Mar 7, 2020

Converted to RxJava 2 and Kotlin:

/**
 * A named state with optional enter/exit callbacks and a table mapping events to the next state.
 *
 * The configuration methods return the receiver, so a state can be built fluently,
 * e.g. `idle.onEnter(...).transition(event, other)`.
 *
 * @param T type of the context object passed to the callbacks
 * @param E type of the events that drive transitions
 */
class State<T, E>(val name: String) {
    private var enter: BiConsumer<T, State<T, E>>? = null
    private var exit: BiConsumer<T, State<T, E>>? = null

    // The map itself is never replaced, only filled through transition().
    private val transitions = mutableMapOf<E, State<T, E>>()

    /** Sets the callback run on entry, replacing any previous one. */
    fun onEnter(func: BiConsumer<T, State<T, E>>): State<T, E> = apply { enter = func }

    /** Sets the callback run on exit, replacing any previous one. */
    fun onExit(func: BiConsumer<T, State<T, E>>): State<T, E> = apply { exit = func }

    /** Runs the enter callback; does nothing if none was registered. */
    fun enter(context: T) {
        enter?.accept(context, this)
    }

    /** Runs the exit callback; does nothing if none was registered. */
    fun exit(context: T) {
        exit?.accept(context, this)
    }

    /** Makes [event] lead to [state]. A later call for the same event replaces the earlier target. */
    fun transition(event: E, state: State<T, E>): State<T, E> = apply { transitions[event] = state }

    /** Returns the state [event] leads to, or `null` if the event is not valid from this state. */
    fun next(event: E): State<T, E>? = transitions[event]

    override fun toString(): String = name
}

/**
 * Event-driven state machine on RxJava 2.
 *
 * Events passed to [accept] are applied once [connect] has been subscribed to. An event with a
 * registered transition runs the current state's exit callback, switches [state], then runs the
 * new state's enter callback. Other events are logged and ignored.
 *
 * @param context object handed to every enter/exit callback
 * @param initialState state the machine starts in
 */
class RxStateMachine<T, E>(val context: T, initialState: State<T, E>) :
    Consumer<E> {

    private val events = PublishSubject.create<E>()

    /** Current state, written by the event listener registered in [connect]. */
    @Volatile
    var state: State<T, E> = initialState

    /**
     * Enters the current state and starts applying events pushed through [accept].
     *
     * The returned observable never emits. Disposing it stops event processing. Every
     * subscription re-runs the current state's enter callback and adds another listener to the
     * shared event stream, so subscribe to it only once.
     */
    fun connect(): Observable<Unit> {
        return Observable.create<Unit> { emitter ->
            state.enter(context)
            // collect() serves only as a per-event sink; the collected context is never used.
            emitter.setDisposable(
                events.collect({ context }, { ctx: T, event: E -> applyEvent(ctx, event) }).subscribe()
            )
        }
    }

    /** Moves to the state [event] leads to, running exit/enter callbacks, or logs it as invalid. */
    private fun applyEvent(ctx: T, event: E) {
        val next = state.next(event)
        if (next == null) {
            Log.e("STATE", "Invalid Event: $event")
            return
        }
        state.exit(ctx)
        state = next
        next.enter(ctx)
    }

    /** Pushes an event; events sent while [connect] has no subscriber are dropped by the subject. */
    override fun accept(t: E) {
        events.onNext(t)
    }

}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment