浏览代码

Add some websocket manager boilerplate

Claus Vium 6 年之前
父节点
当前提交
6bdb5debd2

+ 4 - 1
Emby.Server.Implementations/Middleware/WebSocketMiddleware.cs

@@ -25,7 +25,10 @@ namespace Emby.Server.Implementations.Middleware
             if (httpContext.WebSockets.IsWebSocketRequest)
             {
                 var webSocketContext = await httpContext.WebSockets.AcceptWebSocketAsync(null).ConfigureAwait(false);
-                _webSocketManager.AddSocket(webSocketContext);
+                if (webSocketContext != null)
+                {
+                    await _webSocketManager.OnWebSocketConnected(webSocketContext);
+                }
             }
             else
             {

+ 10 - 0
Emby.Server.Implementations/WebSockets/WebSocketHandler.cs

@@ -0,0 +1,10 @@
+using System.Threading.Tasks;
+using MediaBrowser.Model.Net;
+
+namespace Emby.Server.Implementations.WebSockets
+{
+    public interface IWebSocketHandler
+    {
+        Task ProcessMessage(WebSocketMessage<object> message);
+    }
+}

+ 84 - 6
Emby.Server.Implementations/WebSockets/WebSocketManager.cs

@@ -1,22 +1,100 @@
 using System;
 using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Linq;
 using System.Net.WebSockets;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using MediaBrowser.Controller.Net;
+using MediaBrowser.Model.Net;
+using MediaBrowser.Model.Serialization;
+using Microsoft.Extensions.Logging;
+using UtfUnknown;
 
 namespace Emby.Server.Implementations.WebSockets
 {
     public class WebSocketManager
     {
-        private readonly ConcurrentDictionary<Guid, WebSocket> _activeWebSockets;
+        private readonly IWebSocketHandler[] _webSocketHandlers;
+        private readonly IJsonSerializer _jsonSerializer;
+        private readonly ILogger<WebSocketManager> _logger;
+        private const int BufferSize = 4096;
 
-        public WebSocketManager()
+        public WebSocketManager(IWebSocketHandler[] webSocketHandlers, IJsonSerializer jsonSerializer, ILogger<WebSocketManager> logger)
         {
-            _activeWebSockets = new ConcurrentDictionary<Guid, WebSocket>();
+            _webSocketHandlers = webSocketHandlers;
+            _jsonSerializer = jsonSerializer;
+            _logger = logger;
         }
 
-        public void AddSocket(WebSocket webSocket)
+        public async Task OnWebSocketConnected(WebSocket webSocket)
         {
-            var guid = Guid.NewGuid();
-            _activeWebSockets.TryAdd(guid, webSocket);
+            var taskCompletionSource = new TaskCompletionSource<bool>();
+            var cancellationToken = new CancellationTokenSource().Token;
+            WebSocketReceiveResult result;
+            var message = new List<byte>();
+
+            do
+            {
+                var buffer = WebSocket.CreateServerBuffer(BufferSize);
+                result = await webSocket.ReceiveAsync(buffer, cancellationToken);
+                message.AddRange(buffer.Array.Take(result.Count));
+
+                if (result.EndOfMessage)
+                {
+                    await ProcessMessage(message.ToArray(), taskCompletionSource);
+                    message.Clear();
+                }
+            } while (!taskCompletionSource.Task.IsCompleted &&
+                     webSocket.State == WebSocketState.Open &&
+                     result.MessageType != WebSocketMessageType.Close);
+
+            if (webSocket.State == WebSocketState.Open)
+            {
+                await webSocket.CloseAsync(result.CloseStatus ?? WebSocketCloseStatus.NormalClosure,
+                    result.CloseStatusDescription, cancellationToken);
+            }
+        }
+
+        public async Task ProcessMessage(byte[] messageBytes, TaskCompletionSource<bool> taskCompletionSource)
+        {
+            var charset = CharsetDetector.DetectFromBytes(messageBytes).Detected?.EncodingName;
+            var message = string.Equals(charset, "utf-8", StringComparison.OrdinalIgnoreCase)
+                ? Encoding.UTF8.GetString(messageBytes, 0, messageBytes.Length)
+                : Encoding.ASCII.GetString(messageBytes, 0, messageBytes.Length);
+
+            // All messages are expected to be json
+            if (!message.StartsWith("{", StringComparison.OrdinalIgnoreCase))
+            {
+                _logger.LogDebug("Received web socket message that is not a json structure: {Message}", message);
+                return;
+            }
+
+            try
+            {
+                var info = _jsonSerializer.DeserializeFromString<WebSocketMessage<object>>(message);
+
+                _logger.LogDebug("Websocket message received: {0}", info.MessageType);
+
+                var tasks = _webSocketHandlers.Select(handler => Task.Run(() =>
+                {
+                    try
+                    {
+                        handler.ProcessMessage(info).ConfigureAwait(false);
+                    }
+                    catch (Exception ex)
+                    {
+                        _logger.LogError(ex, "{0} failed processing WebSocket message {1}", handler.GetType().Name, info.MessageType ?? string.Empty);
+                    }
+                }));
+
+                await Task.WhenAll(tasks);
+            }
+            catch (Exception ex)
+            {
+                _logger.LogError(ex, "Error processing web socket message");
+            }
         }
     }
 }