Skip to content

Instantly share code, notes, and snippets.

@paulrobello
Created January 29, 2024 22:02
Show Gist options
  • Save paulrobello/0e5b73bd9c2438436ff7d16eaafa7903 to your computer and use it in GitHub Desktop.
Save paulrobello/0e5b73bd9c2438436ff7d16eaafa7903 to your computer and use it in GitHub Desktop.
Websocket connection manager with message subscription handling
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using UnityEngine;
using NativeWebSocket;
using PAR;
using Sirenix.OdinInspector;
using Sirenix.Utilities;
using Websocket.AWS;
namespace Websocket
{
    /// <summary>
    /// Singleton component that owns one <see cref="WebSocket"/> connection, queues outgoing
    /// messages while the socket is not open, optionally reconnects after the socket closes, and
    /// dispatches incoming JSON messages to subscribed <see cref="WsMessageFilter"/> instances.
    /// </summary>
    public class WsConnection : Singleton<WsConnection>
    {
        // Captures the value of a leading "msgType" property. Only matches when msgType is the very
        // first key of the JSON object and no whitespace appears before it.
        private static readonly Regex MessageTypeRegex = new Regex("^{\"msgType\":\"([^\"]+)\"", RegexOptions.Compiled);

        [SerializeField] private bool connectOnStart = true;
        [SerializeField] private bool autoReconnect = true;
        [Min(0)] [SerializeField] private float reconnectDelay = 1f;
        [Min(0)] [SerializeField] private int maxReconnectAttempts = 50;
        [SerializeField] private string url = "ws://localhost:3000";
        [SerializeField] private string apiKey;

        /// <summary>Raised after the socket opens and queued messages have been handed off for sending.</summary>
        public event WebSocketOpenEventHandler OnOpen = delegate { };

        /// <summary>Raised with the error text reported by the underlying socket.</summary>
        public event WebSocketErrorEventHandler OnError = delegate { };

        /// <summary>Raised every time the underlying socket reports that it closed.</summary>
        public event WebSocketCloseEventHandler OnClose = delegate { };

        // Messages submitted while the socket was not open; flushed in HandleOpen.
        private readonly Queue<IWsMessage> _pendingMessage = new();

        // Closes observed since the last successful open; reset in HandleOpen.
        private int _reconnectAttempts;

        // Maps a "msgType" string (a CLR type's short Name) to the type it is deserialized into.
        private readonly Dictionary<string, Type> _messageTypes = new();

        // NOTE(review): static, so subscriptions outlive this component instance — confirm intended.
        private static readonly List<WsMessageFilter> Subscribers = new();

        private WebSocket _websocket;

        /// <summary>
        /// Registers the built-in message types, creates the socket (adding an <c>x-api-key</c>
        /// header when an API key is configured), wires the socket callbacks, and installs a debug
        /// echo subscriber plus an initial "Echo" message. The socket has not been connected yet at
        /// this point, so that message goes into the pending queue.
        /// </summary>
        protected override void Awake()
        {
            base.Awake();
            RegisterMessageTypes();

            var headers = new Dictionary<string, string>();
            // Unity serializes an unset string field as "" rather than null; checking only for null
            // would send an empty x-api-key header.
            if (!string.IsNullOrEmpty(apiKey))
            {
                headers.Add("x-api-key", apiKey);
            }
            else
            {
                Debug.LogWarning("No API key set for websocket connection");
            }

            _websocket = new WebSocket(url, headers);
            _websocket.OnOpen += HandleOpen;
            _websocket.OnError += HandleError;
            _websocket.OnClose += HandleClose;
            _websocket.OnMessage += HandleMessage;

            Subscribe(new WsMessageFilter().SetCallback<WsMessage>(msg => { Debug.Log(msg.Message); }));
            SendWebSocketMessage(new WsMessage { MsgType = "Echo", Topic = "echo", Message = "hello world" });
        }

        /// <summary>Registers every message type this connection knows how to deserialize.</summary>
        private void RegisterMessageTypes()
        {
            RegisterMessageType(typeof(DockerGetLogsResult));
            RegisterMessageType(typeof(CreateBucketResult));
            RegisterMessageType(typeof(ListBucketsResult));
            RegisterMessageType(typeof(DeleteBucketResult));
            RegisterMessageType(typeof(WsMessage));
        }

        /// <summary>
        /// Determines which type an incoming JSON message should be deserialized into.
        /// </summary>
        /// <param name="jsonMessage">Raw JSON text received from the server.</param>
        /// <returns>
        /// <c>null</c> when the text does not start with a <c>msgType</c> property; the registered
        /// type for that <c>msgType</c> when one exists; otherwise <see cref="WsMessage"/> (a warning
        /// is logged).
        /// </returns>
        private Type IdentifyMessageType(string jsonMessage)
        {
            var match = MessageTypeRegex.Match(jsonMessage);
            if (!match.Success) return null;

            var type = match.Groups[1].Value;
            if (_messageTypes.TryGetValue(type, out var messageType)) return messageType;

            Debug.LogWarning("Unknown message type " + type + " falling back to WsMessage");
            return typeof(WsMessage);
        }

        /// <summary>
        /// Registers <paramref name="type"/> under its short name so incoming messages whose
        /// <c>msgType</c> equals that name are deserialized into it. Duplicate names are ignored
        /// with a warning.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        public void RegisterMessageType(Type type)
        {
            if (type == null) throw new ArgumentNullException(nameof(type));

            string typeName = type.Name;
            if (_messageTypes.ContainsKey(typeName))
            {
                Debug.LogWarning("Message type " + typeName + " already registered");
                return;
            }

            Debug.Log("Registering message type " + typeName);
            _messageTypes.Add(typeName, type);
        }

        /// <summary>Removes the registration for <paramref name="typeName"/>, if any.</summary>
        public void UnregisterMessageType(string typeName)
        {
            _messageTypes.Remove(typeName);
        }

        /// <summary>Removes the registration made for <paramref name="type"/>, if any.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        public void UnregisterMessageType(Type type)
        {
            if (type == null) throw new ArgumentNullException(nameof(type));

            // Must use the runtime type name (the registration key); nameof(type) is the literal "type".
            UnregisterMessageType(type.Name);
        }

        /// <summary>
        /// Decodes an incoming frame as UTF-8 JSON, deserializes it into its registered type, and
        /// offers it to every subscriber. Subscribers that handled it and are marked
        /// <c>Once</c> are removed. Exceptions are logged rather than propagated to the socket.
        /// </summary>
        private void HandleMessage(byte[] bytes)
        {
            try
            {
                var message = System.Text.Encoding.UTF8.GetString(bytes);
                Debug.Log("Received OnMessage! (" + bytes.Length + " bytes) " + message);

                var messageType = IdentifyMessageType(message);
                if (messageType == null) return;

                var messageObject = JsonUtility.FromJson(message, messageType);
                Debug.Log("Got a ws messageType of: " + messageObject);

                var wsMessage = messageObject as WsMessage;
                // Iterate over a snapshot: removing one-shot subscribers (or a callback that
                // subscribes/unsubscribes) while enumerating the live list would throw
                // InvalidOperationException and skip the remaining subscribers.
                foreach (var subscriber in Subscribers.ToArray())
                {
                    if (subscriber.InvokeIfMatch(wsMessage) && subscriber.Once)
                    {
                        Unsubscribe(subscriber);
                    }
                }
            }
            catch (Exception e)
            {
                Debug.LogException(e);
            }
        }

        /// <summary>
        /// Handles a socket close: schedules a reconnect when enabled and the attempt budget is not
        /// exhausted, then raises <see cref="OnClose"/> (always, even if no reconnect is scheduled).
        /// </summary>
        private void HandleClose(WebSocketCloseCode closeCode)
        {
            Debug.Log("Connection closed!");
            if (autoReconnect) ScheduleReconnect();
            OnClose.Invoke(closeCode);
        }

        /// <summary>
        /// Counts a reconnect attempt and schedules <see cref="Connect"/> after
        /// <c>reconnectDelay</c> seconds, allowing up to <c>maxReconnectAttempts</c> attempts
        /// between successful opens.
        /// </summary>
        private void ScheduleReconnect()
        {
            _reconnectAttempts++;
            if (_reconnectAttempts > maxReconnectAttempts)
            {
                Debug.LogError("Max reconnect attempts reached");
                return;
            }

            Invoke(nameof(Connect), reconnectDelay);
        }

        /// <summary>
        /// Resets the reconnect counter, hands every queued message to
        /// <see cref="SendWebSocketMessage"/>, then raises <see cref="OnOpen"/>.
        /// </summary>
        private void HandleOpen()
        {
            Debug.Log("Connection open!");
            _reconnectAttempts = 0;

            var pendingCount = _pendingMessage.Count;
            if (pendingCount > 0)
            {
                Debug.Log("Sending pending messages");
            }

            // Drain a fixed count: if the socket stops reporting Open mid-flush, SendWebSocketMessage
            // re-enqueues the message, and a "while (Count > 0)" loop would never terminate.
            for (var i = 0; i < pendingCount; i++)
            {
                SendWebSocketMessage(_pendingMessage.Dequeue());
            }

            OnOpen.Invoke();
        }

        /// <summary>Opens the socket unless it is already open.</summary>
        /// <returns>The connect task, or a completed task when already open (never null).</returns>
        [Button]
        public Task Connect()
        {
            Debug.Log("Connect called");
            return _websocket.State == WebSocketState.Open ? Task.CompletedTask : _websocket.Connect();
        }

        /// <summary>Closes the socket if it is currently open.</summary>
        /// <returns>The close task, or a completed task when the socket is not open (never null).</returns>
        [Button]
        public Task Disconnect()
        {
            Debug.Log("Disconnect called");
            return _websocket.State == WebSocketState.Open ? _websocket.Close() : Task.CompletedTask;
        }

        private void Start()
        {
            if (connectOnStart) Connect();
        }

        private void Update()
        {
            // On non-WebGL platforms (and in the editor) received messages must be dispatched
            // from the main thread each frame.
#if !UNITY_WEBGL || UNITY_EDITOR
            _websocket.DispatchMessageQueue();
#endif
        }

        /// <summary>
        /// Disables auto-reconnect, detaches all socket callbacks, and closes the socket if it is open.
        /// </summary>
        private void OnDestroy()
        {
            Debug.Log("Cleaning up websocket");
            autoReconnect = false;
            if (_websocket == null) return;

            _websocket.OnMessage -= HandleMessage;
            _websocket.OnClose -= HandleClose;
            _websocket.OnOpen -= HandleOpen;
            _websocket.OnError -= HandleError;
            Disconnect();
        }

        /// <summary>Logs a socket error and forwards it to <see cref="OnError"/> listeners.</summary>
        private void HandleError(string err)
        {
            Debug.LogError("Error! " + err);
            OnError.Invoke(err);
        }

        /// <summary>
        /// Sends <paramref name="msg"/> as JSON if the socket is open; otherwise queues it for
        /// <see cref="HandleOpen"/>. Fills in <c>MsgType</c> with the runtime type name when unset.
        /// Failures are logged, not thrown (this is fire-and-forget).
        /// </summary>
        public async void SendWebSocketMessage(IWsMessage msg)
        {
            try
            {
                if (_websocket?.State == WebSocketState.Open)
                {
                    Debug.Log("Sending message to server");
                    msg.MsgType ??= msg.GetType().Name;
                    var jsonMsg = JsonUtility.ToJson(msg);
                    Debug.Log(jsonMsg);
                    await _websocket.SendText(jsonMsg);
                }
                else
                {
                    Debug.Log("Queueing message to server");
                    _pendingMessage.Enqueue(msg);
                }
            }
            catch (Exception e)
            {
                // Log the whole exception so the stack trace is not lost.
                Debug.LogException(e);
            }
        }

        /// <summary>Adds <paramref name="filter"/> to the subscribers offered every incoming message.</summary>
        public void Subscribe(WsMessageFilter filter)
        {
            Subscribers.Add(filter);
        }

        /// <summary>Removes <paramref name="filter"/> from the subscriber list, if present.</summary>
        public void Unsubscribe(WsMessageFilter filter)
        {
            Subscribers.Remove(filter);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment