Skip to content

Instantly share code, notes, and snippets.

@tlhakhan
Created November 15, 2021 15:00
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tlhakhan/8939d617d8fe659c1095f6c3f4f41df4 to your computer and use it in GitHub Desktop.
Save tlhakhan/8939d617d8fe659c1095f6c3f4f41df4 to your computer and use it in GitHub Desktop.
task.go
package task
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os/exec"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/mux"
)
// allowedCmds is the allow-list of programs a client may run through
// /task/add. init sorts it in place because isValidCmd binary-searches it.
//
// NOTE(review): "cat" and "tree" take arbitrary paths from the query string,
// so a client can read any file this process can read.
var allowedCmds = []string{
	"ls",
	"iostat",
	"mpstat",
	"vmstat",
	"tree",
	"cat",
	"echo",
}
// StatusMessage is the JSON body, {"message": "..."}, that the HTTP handlers
// send for status and error replies.
type StatusMessage struct {
// Message is the human-readable status text.
Message string `json:"message"`
}
// Task is a single command submitted through the HTTP API, together with its
// lifecycle state and result.
//
// Field notes:
//   - State moves "" -> "queued" -> "running" and ends in "done", "errored"
//     or "canceled"; a queued task may also go straight to "canceled".
//   - Priority is the submission time; the scheduler runs tasks in
//     submission order and does not otherwise consult it.
//   - Done is the time Do finished, whatever the outcome.
//   - Output is the command's stdout, followed by its stderr when the command
//     exits with an error.
//   - Id is the task's index in TaskScheduler.Tasks.
//   - Context/Cancel control the child process: canceling the context kills it.
//   - MessageBus points at the scheduler's channel; every state change made
//     through SetState publishes a short code there.
//   - The embedded Mutex guards State, Id, Output and Done.
type Task struct {
	Command    string             `json:"command"`
	Args       []string           `json:"args"`
	State      string             `json:"state"`
	Priority   time.Time          `json:"priority"`
	Done       time.Time          `json:"done"`
	Output     string             `json:"output"`
	Id         int                `json:"id"`
	Context    context.Context    `json:"-"`
	Cancel     context.CancelFunc `json:"-"`
	MessageBus *chan string       `json:"-"`

	sync.Mutex `json:"-"`
}

// SetState moves the task to nextState if that is a legal transition, logs
// it, and publishes the matching code on MessageBus:
//
//	"" -> queued          stq
//	queued -> running     qtr
//	queued -> canceled    qtc
//	running -> canceled   rtc
//	running -> done       rtf
//	running -> errored    rte
//
// Setting the current state again is a silent no-op. Any other transition is
// logged and ignored. In particular, a canceled task stays canceled even if
// its command finishes afterwards, so the scheduler's counters never count
// one task twice.
//
// The code is sent while the task lock is held, so the codes for one task
// reach the listener in the order the transitions happened. The send blocks
// while the channel is full (it is unbuffered in newTaskScheduler).
func (t *Task) SetState(nextState string) {
	t.Lock()
	defer t.Unlock()

	prevState := t.State
	if prevState == nextState {
		return
	}

	var code, desc string
	switch {
	case prevState == "" && nextState == "queued":
		code, desc = "stq", "start to queued"
	case prevState == "queued" && nextState == "running":
		code, desc = "qtr", "queued to running"
	case prevState == "queued" && nextState == "canceled":
		code, desc = "qtc", "queued to canceled"
	case prevState == "running" && nextState == "canceled":
		code, desc = "rtc", "running to canceled"
	case prevState == "running" && nextState == "done":
		code, desc = "rtf", "running to done"
	case prevState == "running" && nextState == "errored":
		code, desc = "rte", "running to error"
	default:
		log.Printf("task %d: ignoring invalid state change %q -> %q\n", t.Id, prevState, nextState)
		return
	}

	t.State = nextState
	log.Println(desc)
	*t.MessageBus <- code
}

// SetId records the task's index in the scheduler's task list.
func (t *Task) SetId(id int) {
	t.Lock()
	defer t.Unlock()
	t.Id = id
}

// Do runs the task's command to completion and records the outcome.
//
// Output (stdout, plus stderr on a non-zero exit) and Done are written under
// the task lock before the final state is published. A reader that sees
// "done" therefore also sees the output. The final state is "done" on
// success, "canceled" if the failure was caused by canceling t.Context, and
// "errored" otherwise.
func (t *Task) Do() {
	if t.Cancel != nil {
		// Release the context's resources once the command has exited.
		defer t.Cancel()
	}
	t.SetState("running")

	out, err := exec.CommandContext(t.Context, t.Command, t.Args...).Output()
	// Evaluate this before the deferred Cancel runs; afterwards Err() is
	// always Canceled.
	canceled := err != nil && t.Context.Err() == context.Canceled

	t.Lock()
	t.Output = string(out)
	if exitErr, ok := err.(*exec.ExitError); ok && !canceled {
		// Output() fills ExitError.Stderr; keep it so failures are diagnosable.
		t.Output += string(exitErr.Stderr)
	}
	t.Done = time.Now()
	t.Unlock()

	switch {
	case err == nil:
		t.SetState("done")
	case canceled:
		t.SetState("canceled")
	default:
		t.SetState("errored")
	}
}
// TaskScheduler owns every submitted Task and the per-state counters. The
// struct is marshaled as-is by listHandler, so the field order and JSON tags
// below define the /task/list response.
type TaskScheduler struct {
// Tasks holds every submitted task; a task's Id is its index in this slice.
Tasks []*Task `json:"tasks"`
// Per-state task counts, adjusted with sync/atomic by messageListener.
Running int64 `json:"running"`
Queued int64 `json:"queued"`
Errored int64 `json:"errored"`
Done int64 `json:"done"`
Canceled int64 `json:"canceled"`
// MessageBus receives state-change codes from Task.SetState and cancelHandler.
MessageBus chan string `json:"-"`
// TaskBus carries new tasks from addHandler to the scheduler loop.
TaskBus chan *Task `json:"-"`
// Taken by the HTTP handlers while they read or modify Tasks.
sync.Mutex `json:"-"`
}
// init sorts allowedCmds once at start-up; isValidCmd relies on that order
// for its binary search.
func init() {
	sort.StringSlice(allowedCmds).Sort()
}
// newTaskScheduler builds an empty scheduler with unbuffered message and task
// channels. It starts the two background goroutines that serve it: the
// counter listener, and the scheduler loop with a running-task limit of 2.
func newTaskScheduler() *TaskScheduler {
	sched := &TaskScheduler{
		MessageBus: make(chan string),
		TaskBus:    make(chan *Task),
	}
	go sched.messageListener()
	go sched.startScheduler(2)
	return sched
}
// Register mounts the task API under the "/task" prefix of r and creates the
// scheduler that backs it. The routes are /task/cmds, /task/list, /task/add,
// /task/get and /task/cancel.
func Register(r *mux.Router) {
	sub := r.PathPrefix("/task").Subrouter()
	log.Printf("allowed command set: %s\n", allowedCmds)
	sched := newTaskScheduler()

	routes := []struct {
		path    string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"/cmds", sched.cmdListHandler},
		{"/list", sched.listHandler},
		{"/add", sched.addHandler},
		{"/get", sched.getHandler},
		{"/cancel", sched.cancelHandler},
	}
	for _, route := range routes {
		sub.HandleFunc(route.path, route.handler)
	}
}
// isValidCmd reports whether cmd is exactly one of allowedCmds. It uses a
// binary search, so allowedCmds must already be sorted (init does that).
func isValidCmd(cmd string) bool {
	pos := sort.SearchStrings(allowedCmds, cmd)
	return pos < len(allowedCmds) && allowedCmds[pos] == cmd
}
// messageListener keeps the scheduler's counters in step with task state
// changes, and never returns. It receives the codes published on MessageBus:
//
//	stq  start -> queued       qtr  queued -> running
//	qtc  queued -> canceled    rtc  running -> canceled
//	rtf  running -> done       rte  running -> errored
//
// For each code it moves one unit from the old state's counter to the new
// one's. A nil source means the task had no counter before. Unknown codes
// are ignored.
func (t *TaskScheduler) messageListener() {
	moves := map[string][2]*int64{
		"qtr": {&t.Queued, &t.Running},
		"rtf": {&t.Running, &t.Done},
		"stq": {nil, &t.Queued},
		"qtc": {&t.Queued, &t.Canceled},
		"rtc": {&t.Running, &t.Canceled},
		"rte": {&t.Running, &t.Errored},
	}
	for {
		msg := <-t.MessageBus
		move, known := moves[msg]
		if !known {
			continue
		}
		if from := move[0]; from != nil {
			atomic.AddInt64(from, -1)
		}
		atomic.AddInt64(move[1], 1)
	}
}
// startScheduler is the scheduler loop; it never returns. It does two things:
//
//   - It accepts new tasks from t.TaskBus, marks each one "queued", gives it
//     its index in t.Tasks as its id and appends it.
//   - Once per second, if fewer than limit tasks are running, it starts the
//     oldest queued task in its own goroutine. At most one task is started
//     per tick.
//
// The running count comes from the counters maintained by messageListener.
// This goroutine is the only writer of t.Tasks, so getTask (called here) may
// scan the slice without t's lock. The append itself still takes the lock,
// because the HTTP handlers read t.Tasks under it.
func (t *TaskScheduler) startScheduler(limit int64) {
	const interval = time.Second
	timer := time.NewTimer(interval)
	for {
		select {
		case task := <-t.TaskBus:
			// Fully initialise the task before it becomes visible to handlers,
			// so a listed task never shows a placeholder id or empty state.
			task.SetState("queued")
			t.Lock()
			task.SetId(len(t.Tasks))
			t.Tasks = append(t.Tasks, task)
			t.Unlock()
		case <-timer.C:
			r := atomic.LoadInt64(&t.Running)
			q := atomic.LoadInt64(&t.Queued)
			c := atomic.LoadInt64(&t.Canceled)
			if r >= limit {
				// At capacity: skip scanning for a runnable task.
				log.Printf("running: %d, queued: %d, canceled: %d\n", r, q, c)
			} else if id, task := t.getTask(); task != nil {
				log.Printf("running: %d, queued: %d, canceled: %d\n", r, q, c)
				log.Printf("running task id: %d\n", id)
				// Mark it running here, not only inside Do, so a later tick
				// cannot pick the same task while its goroutine is starting.
				task.SetState("running")
				go task.Do()
			}
			timer.Reset(interval)
		}
	}
}
// getTask returns the first task in t.Tasks (submission order) whose state is
// "queued", together with its index. If nothing is waiting, it logs
// "no tasks found" and returns (-1, nil).
//
// It reads t.Tasks without t's lock. It must only be called from the
// goroutine that appends to t.Tasks (startScheduler).
func (t *TaskScheduler) getTask() (int, *Task) {
	for id, task := range t.Tasks {
		// Task.SetState writes State from other goroutines; read it under the
		// task's own lock.
		task.Lock()
		queued := task.State == "queued"
		task.Unlock()
		if queued {
			return id, task
		}
	}
	log.Println("no tasks found")
	return -1, nil
}
// cmdListHandler serves /task/cmds: a JSON object {"commands": [...]} listing
// the allowed commands.
func (t *TaskScheduler) cmdListHandler(w http.ResponseWriter, r *http.Request) {
	var resp struct {
		Commands []string `json:"commands"`
	}
	resp.Commands = allowedCmds
	// Marshal cannot fail for a struct holding only a string slice.
	body, _ := json.Marshal(resp)
	w.Write(body)
}
// listHandler serves /task/list: a JSON snapshot of every task and of the
// state counters, shaped like
// {"tasks": [...], "running": n, "queued": n, "errored": n, "done": n, "canceled": n}
// (the same shape as marshaling the TaskScheduler itself).
//
// Other goroutines update tasks and counters concurrently. Each task is
// therefore marshaled under its own lock, and the counters are read with
// sync/atomic. On failure it replies 500.
func (t *TaskScheduler) listHandler(w http.ResponseWriter, r *http.Request) {
	t.Lock()
	defer t.Unlock()

	var tasks []json.RawMessage // stays nil, and so encodes as null, when empty
	for _, task := range t.Tasks {
		task.Lock()
		tj, err := json.Marshal(task)
		task.Unlock()
		if err != nil {
			http.Error(w, fmt.Sprintf("error: %s", err), http.StatusInternalServerError)
			return
		}
		tasks = append(tasks, tj)
	}

	snapshot := struct {
		Tasks    []json.RawMessage `json:"tasks"`
		Running  int64             `json:"running"`
		Queued   int64             `json:"queued"`
		Errored  int64             `json:"errored"`
		Done     int64             `json:"done"`
		Canceled int64             `json:"canceled"`
	}{
		Tasks:    tasks,
		Running:  atomic.LoadInt64(&t.Running),
		Queued:   atomic.LoadInt64(&t.Queued),
		Errored:  atomic.LoadInt64(&t.Errored),
		Done:     atomic.LoadInt64(&t.Done),
		Canceled: atomic.LoadInt64(&t.Canceled),
	}
	j, err := json.Marshal(snapshot)
	if err != nil {
		http.Error(w, fmt.Sprintf("error: %s", err), http.StatusInternalServerError)
		return
	}
	// Write the bytes verbatim. Used as a format string, a "%" in a task's
	// output would be mangled.
	w.Write(j)
}
// addHandler serves /task/add: it builds a Task from the query string and
// hands it to the scheduler loop. Examples:
//
//	/task/add?cmd=vmstat&1&5      runs  vmstat 1 5
//	/task/add?cmd=ls&-l&/tmp      runs  ls -l /tmp
//	/task/add?cmd=echo&-n=hi      runs  echo -n hi
//
// "cmd" names the program and must be in allowedCmds. Every other parameter
// becomes an argument, in the order it appears in the URL: a bare key adds
// the key, and key=value adds the key followed by the value.
//
// It replies 400 with a StatusMessage when cmd is missing or not allowed, or
// when the query string cannot be decoded. The send on TaskBus blocks until
// the scheduler loop receives the task.
func (t *TaskScheduler) addHandler(w http.ResponseWriter, r *http.Request) {
	cmd := r.URL.Query().Get("cmd")
	if cmd == "" {
		log.Println("command wasn't provided")
		j, _ := json.Marshal(StatusMessage{Message: "command wasn't provided"})
		http.Error(w, string(j), http.StatusBadRequest)
		return
	}
	if !isValidCmd(cmd) {
		log.Printf("command %q not in allowed commands\n", cmd)
		j, _ := json.Marshal(StatusMessage{Message: "command wasn't in allowed list"})
		http.Error(w, string(j), http.StatusBadRequest)
		return
	}

	// r.URL.Query() is a map, and ranging over a map visits keys in random
	// order, which reshuffled arguments (e.g. "vmstat 1 5" could run as
	// "vmstat 5 1"). Parse the raw query to keep the client's order.
	args, err := queryArgs(r.URL.RawQuery)
	if err != nil {
		log.Printf("malformed query %q: %s\n", r.URL.RawQuery, err)
		j, _ := json.Marshal(StatusMessage{Message: "query string is malformed"})
		http.Error(w, string(j), http.StatusBadRequest)
		return
	}

	task := &Task{Command: cmd, Args: args, Priority: time.Now(), MessageBus: &t.MessageBus}
	task.Context, task.Cancel = context.WithCancel(context.Background())
	t.TaskBus <- task

	j, _ := json.Marshal(StatusMessage{Message: "added command to task list"})
	fmt.Fprint(w, string(j))
}

// queryArgs converts a raw URL query into command arguments, preserving the
// order in which parameters appear and skipping "cmd". Keys and values are
// decoded with url.QueryUnescape; a decoding error is returned unchanged.
func queryArgs(rawQuery string) ([]string, error) {
	var args []string
	for _, pair := range strings.Split(rawQuery, "&") {
		if pair == "" {
			continue
		}
		rawKey, rawValue := pair, ""
		if i := strings.IndexByte(pair, '='); i >= 0 {
			rawKey, rawValue = pair[:i], pair[i+1:]
		}
		key, err := url.QueryUnescape(rawKey)
		if err != nil {
			return nil, err
		}
		value, err := url.QueryUnescape(rawValue)
		if err != nil {
			return nil, err
		}
		if key == "cmd" {
			continue
		}
		if value == "" {
			args = append(args, key)
		} else {
			args = append(args, key, value)
		}
	}
	return args, nil
}
// getHandler serves /task/get?id=N: the JSON of the task with that id.
// It replies 400 with a StatusMessage when id is not an integer ("id not
// valid") or is outside [0, len(Tasks)) ("id not in range").
func (t *TaskScheduler) getHandler(w http.ResponseWriter, r *http.Request) {
	t.Lock()
	defer t.Unlock()

	id, err := strconv.Atoi(r.URL.Query().Get("id"))
	switch {
	case err != nil:
		j, _ := json.Marshal(StatusMessage{Message: "id not valid"})
		http.Error(w, string(j), http.StatusBadRequest)
		log.Println("param:id not valid")
	case id < 0 || id >= len(t.Tasks):
		// Negative ids used to slip past the upper-bound check and panic.
		j, _ := json.Marshal(StatusMessage{Message: "id not in range"})
		http.Error(w, string(j), http.StatusBadRequest)
		log.Println("param:id not in range")
	default:
		task := t.Tasks[id]
		// Task.Do and Task.SetState write task fields from other goroutines.
		task.Lock()
		j, err := json.Marshal(task)
		task.Unlock()
		if err != nil {
			http.Error(w, fmt.Sprintf("error: %s", err), http.StatusInternalServerError)
			return
		}
		// Write the bytes verbatim. Used as a format string, a "%" in the
		// task's output would be mangled.
		w.Write(j)
	}
}
// cancelHandler serves /task/cancel?id=N. A queued or running task is marked
// "canceled", the matching code ("qtc"/"rtc") is published on MessageBus, and
// its context is canceled, which kills a running command. Tasks already in a
// final state are left alone.
//
// Every outcome is reported as a StatusMessage with status 200.
func (t *TaskScheduler) cancelHandler(w http.ResponseWriter, r *http.Request) {
	t.Lock()
	defer t.Unlock()

	id, err := strconv.Atoi(r.URL.Query().Get("id"))
	if err != nil {
		log.Println("valid id not given")
		j, _ := json.Marshal(StatusMessage{Message: "task id not a number."})
		fmt.Fprint(w, string(j))
		return
	}
	if id < 0 || id >= len(t.Tasks) {
		// Negative ids used to slip past the upper-bound check and panic.
		log.Println("id not found")
		j, _ := json.Marshal(StatusMessage{Message: "task id not found in tasks."})
		fmt.Fprint(w, string(j))
		return
	}

	task := t.Tasks[id]
	// Check and change the state under the task's lock. Otherwise a concurrent
	// Task.SetState (e.g. running -> done) could slip in between, and the task
	// would be counted both as finished and as canceled.
	task.Lock()
	var code string
	switch task.State {
	case "running":
		code = "rtc"
	case "queued":
		code = "qtc"
	}
	if code != "" {
		task.State = "canceled"
		t.MessageBus <- code
	}
	task.Unlock()

	if code == "" {
		j, _ := json.Marshal(StatusMessage{Message: "task is already completed."})
		fmt.Fprint(w, string(j))
		return
	}
	task.Cancel()
	j, _ := json.Marshal(StatusMessage{Message: "task is canceled."})
	fmt.Fprint(w, string(j))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment