浏览代码

Use IAuthorizationContext for websocket

Cody Robibero 3 年之前
父节点
当前提交
0765fd568f

+ 1 - 10
Emby.Server.Implementations/HttpServer/WebSocketConnection.cs

@@ -42,17 +42,14 @@ namespace Emby.Server.Implementations.HttpServer
         /// <param name="logger">The logger.</param>
         /// <param name="logger">The logger.</param>
         /// <param name="socket">The socket.</param>
         /// <param name="socket">The socket.</param>
         /// <param name="remoteEndPoint">The remote end point.</param>
         /// <param name="remoteEndPoint">The remote end point.</param>
-        /// <param name="query">The query.</param>
         public WebSocketConnection(
         public WebSocketConnection(
             ILogger<WebSocketConnection> logger,
             ILogger<WebSocketConnection> logger,
             WebSocket socket,
             WebSocket socket,
-            IPAddress? remoteEndPoint,
-            IQueryCollection query)
+            IPAddress? remoteEndPoint)
         {
         {
             _logger = logger;
             _logger = logger;
             _socket = socket;
             _socket = socket;
             RemoteEndPoint = remoteEndPoint;
             RemoteEndPoint = remoteEndPoint;
-            QueryString = query;
 
 
             _jsonOptions = JsonDefaults.Options;
             _jsonOptions = JsonDefaults.Options;
             LastActivityDate = DateTime.Now;
             LastActivityDate = DateTime.Now;
@@ -81,12 +78,6 @@ namespace Emby.Server.Implementations.HttpServer
         /// <inheritdoc />
         /// <inheritdoc />
         public DateTime LastKeepAliveDate { get; set; }
         public DateTime LastKeepAliveDate { get; set; }
 
 
-        /// <summary>
-        /// Gets the query string.
-        /// </summary>
-        /// <value>The query string.</value>
-        public IQueryCollection QueryString { get; }
-
         /// <summary>
         /// <summary>
         /// Gets the state.
         /// Gets the state.
         /// </summary>
         /// </summary>

+ 3 - 3
Emby.Server.Implementations/HttpServer/WebSocketManager.cs

@@ -7,6 +7,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Linq;
 using System.Net.WebSockets;
 using System.Net.WebSockets;
 using System.Threading.Tasks;
 using System.Threading.Tasks;
+using MediaBrowser.Common.Extensions;
 using MediaBrowser.Controller.Net;
 using MediaBrowser.Controller.Net;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Http;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging;
@@ -50,8 +51,7 @@ namespace Emby.Server.Implementations.HttpServer
                 using var connection = new WebSocketConnection(
                 using var connection = new WebSocketConnection(
                     _loggerFactory.CreateLogger<WebSocketConnection>(),
                     _loggerFactory.CreateLogger<WebSocketConnection>(),
                     webSocket,
                     webSocket,
-                    context.Connection.RemoteIpAddress,
-                    context.Request.Query)
+                    context.GetNormalizedRemoteIp())
                 {
                 {
                     OnReceive = ProcessWebSocketMessageReceived
                     OnReceive = ProcessWebSocketMessageReceived
                 };
                 };
@@ -59,7 +59,7 @@ namespace Emby.Server.Implementations.HttpServer
                 var tasks = new Task[_webSocketListeners.Length];
                 var tasks = new Task[_webSocketListeners.Length];
                 for (var i = 0; i < _webSocketListeners.Length; ++i)
                 for (var i = 0; i < _webSocketListeners.Length; ++i)
                 {
                 {
-                    tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection);
+                    tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection, context);
                 }
                 }
 
 
                 await Task.WhenAll(tasks).ConfigureAwait(false);
                 await Task.WhenAll(tasks).ConfigureAwait(false);

+ 19 - 18
Emby.Server.Implementations/Session/SessionWebSocketListener.cs

@@ -6,6 +6,7 @@ using System.Linq;
 using System.Net.WebSockets;
 using System.Net.WebSockets;
 using System.Threading;
 using System.Threading;
 using System.Threading.Tasks;
 using System.Threading.Tasks;
+using MediaBrowser.Common.Extensions;
 using MediaBrowser.Controller.Net;
 using MediaBrowser.Controller.Net;
 using MediaBrowser.Controller.Session;
 using MediaBrowser.Controller.Session;
 using MediaBrowser.Model.Net;
 using MediaBrowser.Model.Net;
@@ -50,16 +51,10 @@ namespace Emby.Server.Implementations.Session
         /// </summary>
         /// </summary>
         private readonly object _webSocketsLock = new object();
         private readonly object _webSocketsLock = new object();
 
 
-        /// <summary>
-        /// The _session manager.
-        /// </summary>
         private readonly ISessionManager _sessionManager;
         private readonly ISessionManager _sessionManager;
-
-        /// <summary>
-        /// The _logger.
-        /// </summary>
         private readonly ILogger<SessionWebSocketListener> _logger;
         private readonly ILogger<SessionWebSocketListener> _logger;
         private readonly ILoggerFactory _loggerFactory;
         private readonly ILoggerFactory _loggerFactory;
+        private readonly IAuthorizationContext _authorizationContext;
 
 
         /// <summary>
         /// <summary>
         /// The KeepAlive cancellation token.
         /// The KeepAlive cancellation token.
@@ -72,14 +67,17 @@ namespace Emby.Server.Implementations.Session
         /// <param name="logger">The logger.</param>
         /// <param name="logger">The logger.</param>
         /// <param name="sessionManager">The session manager.</param>
         /// <param name="sessionManager">The session manager.</param>
         /// <param name="loggerFactory">The logger factory.</param>
         /// <param name="loggerFactory">The logger factory.</param>
+        /// <param name="authorizationContext">The authorization context.</param>
         public SessionWebSocketListener(
         public SessionWebSocketListener(
             ILogger<SessionWebSocketListener> logger,
             ILogger<SessionWebSocketListener> logger,
             ISessionManager sessionManager,
             ISessionManager sessionManager,
-            ILoggerFactory loggerFactory)
+            ILoggerFactory loggerFactory,
+            IAuthorizationContext authorizationContext)
         {
         {
             _logger = logger;
             _logger = logger;
             _sessionManager = sessionManager;
             _sessionManager = sessionManager;
             _loggerFactory = loggerFactory;
             _loggerFactory = loggerFactory;
+            _authorizationContext = authorizationContext;
         }
         }
 
 
         /// <inheritdoc />
         /// <inheritdoc />
@@ -97,9 +95,9 @@ namespace Emby.Server.Implementations.Session
             => Task.CompletedTask;
             => Task.CompletedTask;
 
 
         /// <inheritdoc />
         /// <inheritdoc />
-        public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection)
+        public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext)
         {
         {
-            var session = await GetSession(connection.QueryString, connection.RemoteEndPoint.ToString()).ConfigureAwait(false);
+            var session = await GetSession(httpContext, connection.RemoteEndPoint?.ToString()).ConfigureAwait(false);
             if (session != null)
             if (session != null)
             {
             {
                 EnsureController(session, connection);
                 EnsureController(session, connection);
@@ -107,25 +105,28 @@ namespace Emby.Server.Implementations.Session
             }
             }
             else
             else
             {
             {
-                _logger.LogWarning("Unable to determine session based on query string: {0}", connection.QueryString);
+                _logger.LogWarning("Unable to determine session based on query string: {0}", httpContext.Request.QueryString);
             }
             }
         }
         }
 
 
-        private Task<SessionInfo> GetSession(IQueryCollection queryString, string remoteEndpoint)
+        private async Task<SessionInfo> GetSession(HttpContext httpContext, string remoteEndpoint)
         {
         {
-            if (queryString == null)
+            var authorizationInfo = await _authorizationContext.GetAuthorizationInfo(httpContext)
+                .ConfigureAwait(false);
+
+            if (!authorizationInfo.IsAuthenticated)
             {
             {
                 return null;
                 return null;
             }
             }
 
 
-            var token = queryString["api_key"];
-            if (string.IsNullOrWhiteSpace(token))
+            var deviceId = authorizationInfo.DeviceId;
+            if (httpContext.Request.Query.TryGetValue("deviceId", out var queryDeviceId))
             {
             {
-                return null;
+                deviceId = queryDeviceId;
             }
             }
 
 
-            var deviceId = queryString["deviceId"];
-            return _sessionManager.GetSessionByAuthenticationToken(token, deviceId, remoteEndpoint);
+            return await _sessionManager.GetSessionByAuthenticationToken(authorizationInfo.Token, deviceId, remoteEndpoint)
+                .ConfigureAwait(false);
         }
         }
 
 
         private void EnsureController(SessionInfo session, IWebSocketConnection connection)
         private void EnsureController(SessionInfo session, IWebSocketConnection connection)

+ 2 - 1
MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs

@@ -11,6 +11,7 @@ using System.Threading;
 using System.Threading.Tasks;
 using System.Threading.Tasks;
 using MediaBrowser.Model.Net;
 using MediaBrowser.Model.Net;
 using MediaBrowser.Model.Session;
 using MediaBrowser.Model.Session;
+using Microsoft.AspNetCore.Http;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging;
 
 
 namespace MediaBrowser.Controller.Net
 namespace MediaBrowser.Controller.Net
@@ -95,7 +96,7 @@ namespace MediaBrowser.Controller.Net
         }
         }
 
 
         /// <inheritdoc />
         /// <inheritdoc />
-        public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection) => Task.CompletedTask;
+        public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext) => Task.CompletedTask;
 
 
         /// <summary>
         /// <summary>
         /// Starts sending messages over a web socket.
         /// Starts sending messages over a web socket.

+ 0 - 6
MediaBrowser.Controller/Net/IWebSocketConnection.cs

@@ -29,12 +29,6 @@ namespace MediaBrowser.Controller.Net
         /// <value>The date of last Keeplive received.</value>
         /// <value>The date of last Keeplive received.</value>
         DateTime LastKeepAliveDate { get; set; }
         DateTime LastKeepAliveDate { get; set; }
 
 
-        /// <summary>
-        /// Gets the query string.
-        /// </summary>
-        /// <value>The query string.</value>
-        IQueryCollection QueryString { get; }
-
         /// <summary>
         /// <summary>
         /// Gets or sets the receive action.
         /// Gets or sets the receive action.
         /// </summary>
         /// </summary>

+ 3 - 1
MediaBrowser.Controller/Net/IWebSocketListener.cs

@@ -1,4 +1,5 @@
 using System.Threading.Tasks;
 using System.Threading.Tasks;
+using Microsoft.AspNetCore.Http;
 
 
 namespace MediaBrowser.Controller.Net
 namespace MediaBrowser.Controller.Net
 {
 {
@@ -18,7 +19,8 @@ namespace MediaBrowser.Controller.Net
         /// Processes a new web socket connection.
         /// Processes a new web socket connection.
         /// </summary>
         /// </summary>
         /// <param name="connection">An instance of the <see cref="IWebSocketConnection"/> interface.</param>
         /// <param name="connection">An instance of the <see cref="IWebSocketConnection"/> interface.</param>
+        /// <param name="httpContext">The current http context.</param>
         /// <returns>Task.</returns>
         /// <returns>Task.</returns>
-        Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection);
+        Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext);
     }
     }
 }
 }

+ 4 - 4
tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs

@@ -13,7 +13,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
         [Fact]
         [Fact]
         public void DeserializeWebSocketMessage_SingleSegment_Success()
         public void DeserializeWebSocketMessage_SingleSegment_Success()
         {
         {
-            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
+            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
             var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
             var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
             con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
             con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
             Assert.Equal(109, bytesConsumed);
             Assert.Equal(109, bytesConsumed);
@@ -23,7 +23,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
         public void DeserializeWebSocketMessage_MultipleSegments_Success()
         public void DeserializeWebSocketMessage_MultipleSegments_Success()
         {
         {
             const int SplitPos = 64;
             const int SplitPos = 64;
-            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
+            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
             var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
             var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
             var seg1 = new BufferSegment(new Memory<byte>(bytes, 0, SplitPos));
             var seg1 = new BufferSegment(new Memory<byte>(bytes, 0, SplitPos));
             var seg2 = seg1.Append(new Memory<byte>(bytes, SplitPos, bytes.Length - SplitPos));
             var seg2 = seg1.Append(new Memory<byte>(bytes, SplitPos, bytes.Length - SplitPos));
@@ -34,7 +34,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
         [Fact]
         [Fact]
         public void DeserializeWebSocketMessage_ValidPartial_Success()
         public void DeserializeWebSocketMessage_ValidPartial_Success()
         {
         {
-            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
+            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
             var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json");
             var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json");
             con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
             con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
             Assert.Equal(109, bytesConsumed);
             Assert.Equal(109, bytesConsumed);
@@ -43,7 +43,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
         [Fact]
         [Fact]
         public void DeserializeWebSocketMessage_Partial_ThrowJsonException()
         public void DeserializeWebSocketMessage_Partial_ThrowJsonException()
         {
         {
-            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
+            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
             var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json");
             var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json");
             Assert.Throws<JsonException>(() => con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed));
             Assert.Throws<JsonException>(() => con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed));
         }
         }