Преглед изворни кода

connect to socket with access token

Luke Pulverenti пре 10 година
родитељ
комит
ccb2dda358

+ 5 - 0
MediaBrowser.Controller/Net/IHttpServer.cs

@@ -44,6 +44,11 @@ namespace MediaBrowser.Controller.Net
         /// </summary>
         event EventHandler<WebSocketConnectEventArgs> WebSocketConnected;
 
+        /// <summary>
+        /// Occurs when [web socket connecting].
+        /// </summary>
+        event EventHandler<WebSocketConnectingEventArgs> WebSocketConnecting;
+
         /// <summary>
         /// Inits this instance.
         /// </summary>

+ 6 - 1
MediaBrowser.Controller/Net/IServerManager.cs

@@ -1,4 +1,4 @@
-using MediaBrowser.Common.Net;
+using MediaBrowser.Model.Events;
 using System;
 using System.Collections.Generic;
 using System.Threading;
@@ -58,5 +58,10 @@ namespace MediaBrowser.Controller.Net
         /// </summary>
         /// <value>The web socket connections.</value>
         IEnumerable<IWebSocketConnection> WebSocketConnections { get; }
+
+        /// <summary>
+        /// Occurs when [web socket connected].
+        /// </summary>
+        event EventHandler<GenericEventArgs<IWebSocketConnection>> WebSocketConnected;
     }
 }

+ 13 - 1
MediaBrowser.Controller/Net/IWebSocketConnection.cs

@@ -1,5 +1,6 @@
 using MediaBrowser.Model.Net;
 using System;
+using System.Collections.Specialized;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -11,7 +12,7 @@ namespace MediaBrowser.Controller.Net
         /// Occurs when [closed].
         /// </summary>
         event EventHandler<EventArgs> Closed;
-        
+
         /// <summary>
         /// Gets the id.
         /// </summary>
@@ -24,6 +25,17 @@ namespace MediaBrowser.Controller.Net
         /// <value>The last activity date.</value>
         DateTime LastActivityDate { get; }
 
+        /// <summary>
+        /// Gets or sets the URL.
+        /// </summary>
+        /// <value>The URL.</value>
+        string Url { get; set; }
+        /// <summary>
+        /// Gets or sets the query string.
+        /// </summary>
+        /// <value>The query string.</value>
+        NameValueCollection QueryString { get; set; }
+
         /// <summary>
         /// Gets or sets the receive action.
         /// </summary>

+ 42 - 0
MediaBrowser.Controller/Net/WebSocketConnectEventArgs.cs

@@ -1,4 +1,5 @@
 using System;
+using System.Collections.Specialized;
 
 namespace MediaBrowser.Controller.Net
 {
@@ -7,6 +8,16 @@ namespace MediaBrowser.Controller.Net
     /// </summary>
     public class WebSocketConnectEventArgs : EventArgs
     {
+        /// <summary>
+        /// Gets or sets the URL.
+        /// </summary>
+        /// <value>The URL.</value>
+        public string Url { get; set; }
+        /// <summary>
+        /// Gets or sets the query string.
+        /// </summary>
+        /// <value>The query string.</value>
+        public NameValueCollection QueryString { get; set; }
         /// <summary>
         /// Gets or sets the web socket.
         /// </summary>
@@ -18,4 +29,35 @@ namespace MediaBrowser.Controller.Net
         /// <value>The endpoint.</value>
         public string Endpoint { get; set; }
     }
+
+    public class WebSocketConnectingEventArgs : EventArgs
+    {
+        /// <summary>
+        /// Gets or sets the URL.
+        /// </summary>
+        /// <value>The URL.</value>
+        public string Url { get; set; }
+        /// <summary>
+        /// Gets or sets the endpoint.
+        /// </summary>
+        /// <value>The endpoint.</value>
+        public string Endpoint { get; set; }
+        /// <summary>
+        /// Gets or sets the query string.
+        /// </summary>
+        /// <value>The query string.</value>
+        public NameValueCollection QueryString { get; set; }
+        /// <summary>
+        /// Gets or sets a value indicating whether [allow connection].
+        /// </summary>
+        /// <value><c>true</c> if [allow connection]; otherwise, <c>false</c>.</value>
+        public bool AllowConnection { get; set; }
+
+        public WebSocketConnectingEventArgs()
+        {
+            QueryString = new NameValueCollection();
+            AllowConnection = true;
+        }
+    }
+
 }

+ 7 - 0
MediaBrowser.Controller/Session/ISessionManager.cs

@@ -278,6 +278,13 @@ namespace MediaBrowser.Controller.Session
         /// <returns>SessionInfo.</returns>
         SessionInfo GetSession(string deviceId, string client, string version);
 
+        /// <summary>
+        /// Gets the session by authentication token.
+        /// </summary>
+        /// <param name="token">The token.</param>
+        /// <returns>SessionInfo.</returns>
+        SessionInfo GetSessionByAuthenticationToken(string token);
+
         /// <summary>
         /// Logouts the specified access token.
         /// </summary>

+ 12 - 2
MediaBrowser.Server.Implementations/HttpServer/HttpListenerHost.cs

@@ -35,6 +35,7 @@ namespace MediaBrowser.Server.Implementations.HttpServer
         private readonly ContainerAdapter _containerAdapter;
 
         public event EventHandler<WebSocketConnectEventArgs> WebSocketConnected;
+        public event EventHandler<WebSocketConnectingEventArgs> WebSocketConnecting;
 
         private readonly List<string> _localEndpoints = new List<string>();
 
@@ -196,7 +197,8 @@ namespace MediaBrowser.Server.Implementations.HttpServer
 
             _listener = GetListener();
 
-            _listener.WebSocketHandler = WebSocketHandler;
+            _listener.WebSocketConnected = OnWebSocketConnected;
+            _listener.WebSocketConnecting = OnWebSocketConnecting;
             _listener.ErrorHandler = ErrorHandler;
             _listener.RequestHandler = RequestHandler;
 
@@ -208,7 +210,15 @@ namespace MediaBrowser.Server.Implementations.HttpServer
             return new WebSocketSharpListener(_logger, OnRequestReceived, CertificatePath);
         }
 
-        private void WebSocketHandler(WebSocketConnectEventArgs args)
+        private void OnWebSocketConnecting(WebSocketConnectingEventArgs args)
+        {
+            if (WebSocketConnecting != null)
+            {
+                WebSocketConnecting(this, args);
+            }
+        }
+
+        private void OnWebSocketConnected(WebSocketConnectEventArgs args)
         {
             if (WebSocketConnected != null)
             {

+ 7 - 1
MediaBrowser.Server.Implementations/HttpServer/IHttpListener.cs

@@ -24,8 +24,14 @@ namespace MediaBrowser.Server.Implementations.HttpServer
         /// Gets or sets the web socket handler.
         /// </summary>
         /// <value>The web socket handler.</value>
-        Action<WebSocketConnectEventArgs> WebSocketHandler { get; set; }
+        Action<WebSocketConnectEventArgs> WebSocketConnected { get; set; }
 
+        /// <summary>
+        /// Gets or sets the web socket connecting.
+        /// </summary>
+        /// <value>The web socket connecting.</value>
+        Action<WebSocketConnectingEventArgs> WebSocketConnecting { get; set; }
+        
         /// <summary>
         /// Starts this instance.
         /// </summary>

+ 4 - 0
MediaBrowser.Server.Implementations/HttpServer/Security/SessionContext.cs

@@ -23,6 +23,10 @@ namespace MediaBrowser.Server.Implementations.HttpServer.Security
         {
             var authorization = _authContext.GetAuthorizationInfo(requestContext);
 
+            if (!string.IsNullOrWhiteSpace(authorization.Token))
+            {
+                return _sessionManager.GetSessionByAuthenticationToken(authorization.Token);
+            }
             return _sessionManager.GetSession(authorization.DeviceId, authorization.Client, authorization.Version);
         }
 

+ 41 - 10
MediaBrowser.Server.Implementations/HttpServer/SocketSharp/WebSocketSharpListener.cs

@@ -1,4 +1,5 @@
-using MediaBrowser.Controller.Net;
+using System.Collections.Specialized;
+using MediaBrowser.Controller.Net;
 using MediaBrowser.Model.Logging;
 using MediaBrowser.Server.Implementations.Logging;
 using ServiceStack;
@@ -18,9 +19,9 @@ namespace MediaBrowser.Server.Implementations.HttpServer.SocketSharp
 
         private readonly ILogger _logger;
         private readonly Action<string> _endpointListener;
-        private readonly string  _certificatePath ;
+        private readonly string _certificatePath;
 
-        public WebSocketSharpListener(ILogger logger, Action<string> endpointListener, 
+        public WebSocketSharpListener(ILogger logger, Action<string> endpointListener,
             string certificatePath)
         {
             _logger = logger;
@@ -32,7 +33,9 @@ namespace MediaBrowser.Server.Implementations.HttpServer.SocketSharp
 
         public Func<IHttpRequest, Uri, Task> RequestHandler { get; set; }
 
-        public Action<WebSocketConnectEventArgs> WebSocketHandler { get; set; }
+        public Action<WebSocketConnectingEventArgs> WebSocketConnecting { get; set; }
+
+        public Action<WebSocketConnectEventArgs> WebSocketConnected { get; set; }
 
         public void Start(IEnumerable<string> urlPrefixes)
         {
@@ -115,15 +118,43 @@ namespace MediaBrowser.Server.Implementations.HttpServer.SocketSharp
         {
             try
             {
-                var webSocketContext = ctx.AcceptWebSocket(null);
+                var endpoint = ctx.Request.RemoteEndPoint.ToString();
+                var url = ctx.Request.RawUrl;
+                var queryString = new NameValueCollection(ctx.Request.QueryString);
+
+                var connectingArgs = new WebSocketConnectingEventArgs
+                {
+                    Url = url,
+                    QueryString = queryString,
+                    Endpoint = endpoint
+                };
+
+                if (WebSocketConnecting != null)
+                {
+                    WebSocketConnecting(connectingArgs);
+                }
 
-                if (WebSocketHandler != null)
+                if (connectingArgs.AllowConnection)
                 {
-                    WebSocketHandler(new WebSocketConnectEventArgs
+                    _logger.Debug("Web socket connection allowed");
+
+                    var webSocketContext = ctx.AcceptWebSocket(null);
+
+                    if (WebSocketConnected != null)
                     {
-                        WebSocket = new SharpWebSocket(webSocketContext.WebSocket, _logger),
-                        Endpoint = ctx.Request.RemoteEndPoint.ToString()
-                    });
+                        WebSocketConnected(new WebSocketConnectEventArgs
+                        {
+                            Url = url,
+                            QueryString = queryString,
+                            WebSocket = new SharpWebSocket(webSocketContext.WebSocket, _logger),
+                            Endpoint = endpoint
+                        });
+                    }
+                }
+                else
+                {
+                    _logger.Warn("Web socket connection not allowed");
+                    ctx.Response.Close();
                 }
             }
             catch (Exception ex)

+ 14 - 2
MediaBrowser.Server.Implementations/ServerManager/ServerManager.cs

@@ -1,11 +1,14 @@
-using MediaBrowser.Controller;
+using MediaBrowser.Common.Events;
+using MediaBrowser.Controller;
 using MediaBrowser.Controller.Configuration;
 using MediaBrowser.Controller.Net;
+using MediaBrowser.Model.Events;
 using MediaBrowser.Model.Logging;
 using MediaBrowser.Model.Net;
 using MediaBrowser.Model.Serialization;
 using System;
 using System.Collections.Generic;
+using System.Collections.Specialized;
 using System.Linq;
 using System.Net.Sockets;
 using System.Threading;
@@ -44,6 +47,8 @@ namespace MediaBrowser.Server.Implementations.ServerManager
             get { return _webSocketConnections; }
         }
 
+        public event EventHandler<GenericEventArgs<IWebSocketConnection>> WebSocketConnected;
+
         /// <summary>
         /// The _logger
         /// </summary>
@@ -140,10 +145,17 @@ namespace MediaBrowser.Server.Implementations.ServerManager
         {
             var connection = new WebSocketConnection(e.WebSocket, e.Endpoint, _jsonSerializer, _logger)
             {
-                OnReceive = ProcessWebSocketMessageReceived
+                OnReceive = ProcessWebSocketMessageReceived,
+                Url = e.Url,
+                QueryString = new NameValueCollection(e.QueryString)
             };
 
             _webSocketConnections.Add(connection);
+
+            if (WebSocketConnected != null)
+            {
+                EventHelper.FireEventIfNotNull(WebSocketConnected, this, new GenericEventArgs<IWebSocketConnection> (connection), _logger);
+            }
         }
 
         /// <summary>

+ 13 - 2
MediaBrowser.Server.Implementations/ServerManager/WebSocketConnection.cs

@@ -1,10 +1,10 @@
-using System.Text;
-using MediaBrowser.Common.Events;
+using MediaBrowser.Common.Events;
 using MediaBrowser.Controller.Net;
 using MediaBrowser.Model.Logging;
 using MediaBrowser.Model.Net;
 using MediaBrowser.Model.Serialization;
 using System;
+using System.Collections.Specialized;
 using System.IO;
 using System.Threading;
 using System.Threading.Tasks;
@@ -66,6 +66,17 @@ namespace MediaBrowser.Server.Implementations.ServerManager
         /// <value>The id.</value>
         public Guid Id { get; private set; }
 
+        /// <summary>
+        /// Gets or sets the URL.
+        /// </summary>
+        /// <value>The URL.</value>
+        public string Url { get; set; }
+        /// <summary>
+        /// Gets or sets the query string.
+        /// </summary>
+        /// <value>The query string.</value>
+        public NameValueCollection QueryString { get; set; }
+        
         /// <summary>
         /// Initializes a new instance of the <see cref="WebSocketConnection" /> class.
         /// </summary>

+ 20 - 0
MediaBrowser.Server.Implementations/Session/SessionManager.cs

@@ -1640,6 +1640,26 @@ namespace MediaBrowser.Server.Implementations.Session
                 string.Equals(i.Client, client));
         }
 
+        public SessionInfo GetSessionByAuthenticationToken(string token)
+        {
+            var result = _authRepo.Get(new AuthenticationInfoQuery
+            {
+                AccessToken = token
+            });
+
+            if (result.Items.Length == 0)
+            {
+                return null;
+            }
+
+            var info = result.Items[0];
+
+            // TODO: Make Token part of SessionInfo and get result that way
+            // This can't be done until all apps are updated to new authentication.
+            return Sessions.FirstOrDefault(i => string.Equals(i.DeviceId, info.DeviceId) &&
+                string.Equals(i.Client, info.AppName));
+        }
+
         public Task SendMessageToUserSessions<T>(string userId, string name, T data,
             CancellationToken cancellationToken)
         {

+ 63 - 2
MediaBrowser.Server.Implementations/Session/SessionWebSocketListener.cs

@@ -1,9 +1,12 @@
 using MediaBrowser.Controller.Net;
+using MediaBrowser.Controller.Security;
 using MediaBrowser.Controller.Session;
+using MediaBrowser.Model.Events;
 using MediaBrowser.Model.Logging;
 using MediaBrowser.Model.Serialization;
 using MediaBrowser.Model.Session;
 using System;
+using System.Collections.Specialized;
 using System.Globalization;
 using System.Linq;
 using System.Threading.Tasks;
@@ -13,7 +16,7 @@ namespace MediaBrowser.Server.Implementations.Session
     /// <summary>
     /// Class SessionWebSocketListener
     /// </summary>
-    public class SessionWebSocketListener : IWebSocketListener
+    public class SessionWebSocketListener : IWebSocketListener, IDisposable
     {
         /// <summary>
         /// The _true task result
@@ -35,17 +38,75 @@ namespace MediaBrowser.Server.Implementations.Session
         /// </summary>
         private readonly IJsonSerializer _json;
 
+        private readonly IHttpServer _httpServer;
+        private readonly IServerManager _serverManager;
+
+
         /// <summary>
         /// Initializes a new instance of the <see cref="SessionWebSocketListener" /> class.
         /// </summary>
         /// <param name="sessionManager">The session manager.</param>
         /// <param name="logManager">The log manager.</param>
         /// <param name="json">The json.</param>
-        public SessionWebSocketListener(ISessionManager sessionManager, ILogManager logManager, IJsonSerializer json)
+        /// <param name="httpServer">The HTTP server.</param>
+        /// <param name="serverManager">The server manager.</param>
+        public SessionWebSocketListener(ISessionManager sessionManager, ILogManager logManager, IJsonSerializer json, IHttpServer httpServer, IServerManager serverManager)
         {
             _sessionManager = sessionManager;
             _logger = logManager.GetLogger(GetType().Name);
             _json = json;
+            _httpServer = httpServer;
+            _serverManager = serverManager;
+            httpServer.WebSocketConnecting += _httpServer_WebSocketConnecting;
+            serverManager.WebSocketConnected += _serverManager_WebSocketConnected;
+        }
+
+        void _serverManager_WebSocketConnected(object sender, GenericEventArgs<IWebSocketConnection> e)
+        {
+            var session = GetSession(e.Argument.QueryString);
+
+            if (session != null)
+            {
+                var controller = session.SessionController as WebSocketController;
+
+                if (controller == null)
+                {
+                    controller = new WebSocketController(session, _logger, _sessionManager);
+                }
+
+                controller.AddWebSocket(e.Argument);
+
+                session.SessionController = controller;
+            }
+            else
+            {
+                _logger.Warn("Unable to determine session based on url: {0}", e.Argument.Url);
+            }
+        }
+
+        void _httpServer_WebSocketConnecting(object sender, WebSocketConnectingEventArgs e)
+        {
+            if (e.QueryString.AllKeys.Contains("api_key", StringComparer.OrdinalIgnoreCase))
+            {
+                var session = GetSession(e.QueryString);
+
+                if (session == null)
+                {
+                    e.AllowConnection = false;
+                }
+            }
+        }
+
+        private SessionInfo GetSession(NameValueCollection queryString)
+        {
+            var token = queryString["api_key"];
+            return _sessionManager.GetSessionByAuthenticationToken(token);
+        }
+
+        public void Dispose()
+        {
+            _httpServer.WebSocketConnecting -= _httpServer_WebSocketConnecting;
+            _serverManager.WebSocketConnected -= _serverManager_WebSocketConnected;
         }
 
         /// <summary>