Last active
December 30, 2022 00:37
-
-
Save yashbonde/a730be5d10c21204d6508feeed5abf4e to your computer and use it in GitHub Desktop.
A simple example of a Go client for a Jupyter kernel.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Jupyter Development Team, NimbleBox.ai | |
# Distributed under the terms of the Modified BSD License. | |
# this script runs a notebook | |
import os | |
from functools import partial | |
# from saturn.converters import path_to_notebook | |
from pprint import pprint as pp | |
import asyncio | |
import aiohttp | |
from time import sleep | |
from uuid import uuid4 | |
from heapq import heappush | |
from json import dumps as json_encode, loads as json_decode | |
from urllib.parse import quote as url_escape | |
async def main(target_fp: str, lang: str = "python", kernel_id: str = ""):
    """Run a fixed list of code snippets on a Jupyter kernel through the gateway.

    A kernel is created via ``POST {base_url}/api/kernels`` unless ``kernel_id``
    is given. The kernel's ``/channels`` websocket is then used to send one
    ``execute_request`` per snippet. For each request, the first ``stream`` or
    ``execute_reply`` message whose parent is that request (or any ``error``
    message) is printed before moving on to the next snippet.

    Args:
        target_fp: Notebook path. Currently only printed; the snippets are
            hard-coded (a notebook loader is sketched in comments below).
        lang: Kernel spec name posted when a new kernel is created.
        kernel_id: Existing kernel to reuse; an empty string creates a new one.

    Raises:
        aiohttp.ClientResponseError: if the gateway answers the kernel-creation
            request with an error status.
    """
    # Non-blocking pause: time.sleep() here would stall the whole event loop.
    await asyncio.sleep(1)
    base_url = os.getenv('BASE_GATEWAY_HTTP_URL', 'http://localhost:8891')
    base_ws_url = os.getenv('BASE_GATEWAY_WS_URL', 'ws://localhost:8891')
    print(" base_url:", base_url)
    print("base_ws_url:", base_ws_url)
    print(" lang:", lang)
    print(" target_fp:", target_fp)
    print(" kernel_id:", kernel_id)

    # Hard-coded snippets used for testing.
    code_snippets = [
        'x = 1 + 2;print(x)',
        'y = x * x',
        'import time',
        'time.sleep(5)',
        'print(y)',
    ]
    # In case of Jobs: load target_fp and execute its code cells in order, e.g.
    #   f, c, o = path_to_notebook(target_fp)
    #   code_snippets = ["".join(cell.content.source) for cell in c
    #                    if cell.content.cell_type == "code"]
    # In case of serving: pull snippets from another WS connection (talking to dione).

    # One session is enough for both the REST call and the websocket.
    async with aiohttp.ClientSession() as session:
        if not kernel_id:
            async with session.post(
                f'{base_url}/api/kernels',
                auth=aiohttp.BasicAuth('fakeuser', 'fakepass'),
                data=json_encode({'name': lang}),
            ) as resp:
                # Fail loudly instead of hitting a KeyError on an error payload.
                resp.raise_for_status()
                kernel = json_decode(await resp.text())
            kernel_id = kernel['id']
            print(f"Created kernel: {kernel_id}")

        async with session.ws_connect(
            f'{base_ws_url}/api/kernels/{url_escape(kernel_id)}/channels',
        ) as ws:
            for code in code_snippets:
                msg_id = uuid4().hex
                await ws.send_str(json_encode({
                    'header': {
                        'username': '',
                        'version': '5.0',
                        'session': '',
                        'msg_id': msg_id,
                        'msg_type': 'execute_request',
                    },
                    'parent_header': {},
                    'channel': 'shell',
                    'content': {
                        'code': code,
                        'silent': False,
                        'store_history': False,
                        'user_expressions': {},
                        'allow_stdin': False,
                    },
                    'metadata': {},
                    # The messaging protocol defines buffers as a list.
                    'buffers': [],
                }))

                # Wait for output belonging to this request.
                while True:
                    msg = json_decode(await ws.receive_str())
                    msg_type = msg['msg_type']
                    # Any error ends the wait for the current snippet.
                    if msg_type == 'error':
                        print('ERROR')
                        pp(msg)
                        break
                    # Some messages (e.g. kernel status broadcasts) may arrive
                    # with an empty parent_header; .get() avoids a KeyError.
                    parent_msg_id = msg.get('parent_header', {}).get('msg_id')
                    if parent_msg_id != msg_id:
                        continue
                    if msg_type == 'stream':
                        print('STREAM')
                        pp(msg)
                        break
                    if msg_type == 'execute_reply':
                        print('EXECUTE REPLY')
                        pp(msg)
                        break
def run_notebook(target_fp: str, lang: str = "python", kernel_id: str = ""):
    """Run the ``main`` coroutine to completion for one notebook.

    Args:
        target_fp: Notebook path forwarded to ``main``.
        lang: Kernel spec name forwarded to ``main``.
        kernel_id: Existing kernel to reuse; an empty string (the default)
            lets ``main`` create a new kernel.
    """
    # main() only builds a coroutine object; asyncio.run() creates an event
    # loop, drives the coroutine to completion and closes the loop again.
    asyncio.run(main(target_fp=target_fp, lang=lang, kernel_id=kernel_id))
if __name__ == "__main__":
    # Script entry point: run the sample notebook path.
    run_notebook("test/models.ipynb")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// This is a copy of saturn/runner.py translated to Go | |
package main | |
import ( | |
"encoding/json" | |
"fmt" | |
"io/ioutil" | |
"os" | |
"net/http" | |
"github.com/gorilla/websocket" | |
) | |
func run_notebook() { | |
println("Starting the Saturn Runner...") | |
base_url := os.Getenv("BASE_GATEWAY_HTTP_URL") | |
if base_url == "" { | |
base_url = "http://localhost:8888" | |
} | |
base_ws_url := os.Getenv("BASE_GATEWAY_WS_URL") | |
if base_ws_url == "" { | |
base_ws_url = "ws://localhost:8888" | |
} | |
http_session := &http.Client{} | |
req, err := http.NewRequest("POST", base_url+"/api/kernels", nil) | |
if err != nil { | |
println(err) | |
} | |
req.SetBasicAuth("fakeuser", "fakepass") | |
resp, err := http_session.Do(req) | |
if err != nil { | |
println(err) | |
} | |
body, err := ioutil.ReadAll(resp.Body) | |
if err != nil { | |
// Handle error | |
println(err) | |
} | |
var data map[string]interface{} | |
err = json.Unmarshal(body, &data) | |
if err != nil { | |
println(err) | |
} | |
kernel_id := data["id"].(string) | |
fmt.Println("Kernel ID:", kernel_id) | |
// now we define the websocket session | |
ws_url := base_ws_url + "/api/kernels/" + kernel_id + "/channels" | |
ws, _, err := websocket.DefaultDialer.Dial(ws_url, nil) | |
if err != nil { | |
println(err) | |
} | |
defer ws.Close() | |
// now we send the code snippets | |
code_snippets := []string{ | |
"x = 1 + 2;print(x)", | |
"y = x * x", | |
"import time", | |
"time.sleep(5)", | |
"print(y)", | |
} | |
for i, code := range code_snippets { | |
// send the code | |
msg_id := fmt.Sprintf("%d", i) | |
msg := map[string]interface{}{ | |
"header": map[string]interface{}{ | |
"username": "", | |
"version": "5.0", | |
"session": "", | |
"msg_id": msg_id, | |
"msg_type": "execute_request", | |
}, | |
"parent_header": map[string]interface{}{}, | |
"channel": "shell", | |
"content": map[string]interface{}{ | |
"code": code, | |
"silent": false, | |
"store_history": false, | |
"user_expressions": map[string]interface{}{}, | |
"allow_stdin": false, | |
}, | |
"metadata": map[string]interface{}{}, | |
"buffers": map[string]interface{}{}, | |
} | |
fmt.Println("Sending message:", msg_id) | |
ws.WriteJSON(msg) | |
for { | |
// while 1 | |
_, msg, err := ws.ReadMessage() | |
if err != nil { | |
println(err) | |
} | |
var data map[string]interface{} | |
err = json.Unmarshal(msg, &data) | |
if err != nil { | |
println(err) | |
} | |
// fmt.Println("Received message:", data) | |
msg_type := data["msg_type"].(string) | |
if msg_type == "error" { | |
println("Error:", data["content"]) | |
break | |
} | |
parent_msg_id := data["parent_header"].(map[string]interface{})["msg_id"].(string) | |
if parent_msg_id == msg_id { | |
if msg_type == "stream" { | |
println("Stream:", data["content"]) | |
} else if msg_type == "execute_result" { | |
println("Result:", data["content"]) | |
} | |
break | |
} | |
} | |
} | |
} | |
func main() { | |
// var wg sync.WaitGroup | |
// for i := 0; i < 5; i++ { | |
// wg.Add(1) | |
// go func(i int) { | |
// time.Sleep(time.Duration(i)) | |
// run_notebook() | |
// wg.Done() | |
// }(i) | |
// } | |
// wg.Wait() | |
run_notebook() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment