2
0
Эх сурвалжийг харах

Merge pull request #847 from Bond-009/async

Make websockets code async
Vasily 6 жил өмнө
parent
commit
8ef41020d9

+ 4 - 3
Jellyfin.Server/SocketSharp/SharpWebSocket.cs

@@ -44,10 +44,11 @@ namespace Jellyfin.Server.SocketSharp
             socket.OnMessage += OnSocketMessage;
             socket.OnClose += OnSocketClose;
             socket.OnError += OnSocketError;
-
-            WebSocket.ConnectAsServer();
         }
 
+        public Task ConnectAsServerAsync()
+            => WebSocket.ConnectAsServer();
+
         public Task StartReceive()
         {
             return _taskCompletionSource.Task;
@@ -133,7 +134,7 @@ namespace Jellyfin.Server.SocketSharp
 
                 _cancellationTokenSource.Cancel();
 
-                WebSocket.Close();
+                WebSocket.CloseAsync().GetAwaiter().GetResult();
             }
 
             _disposed = true;

+ 8 - 18
Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs

@@ -69,7 +69,7 @@ namespace Jellyfin.Server.SocketSharp
         {
             if (_listener == null)
             {
-                _listener = new HttpListener(_logger, _cryptoProvider, _socketFactory, _networkManager, _streamHelper, _fileSystem, _environment);
+                _listener = new HttpListener(_logger, _cryptoProvider, _socketFactory, _streamHelper, _fileSystem, _environment);
             }
 
             _listener.EnableDualMode = _enableDualMode;
@@ -79,22 +79,14 @@ namespace Jellyfin.Server.SocketSharp
                 _listener.LoadCert(_certificate);
             }
 
-            foreach (var prefix in urlPrefixes)
-            {
-                _logger.LogInformation("Adding HttpListener prefix " + prefix);
-                _listener.Prefixes.Add(prefix);
-            }
+            _logger.LogInformation("Adding HttpListener prefixes {Prefixes}", urlPrefixes);
+            _listener.Prefixes.AddRange(urlPrefixes);
 
-            _listener.OnContext = ProcessContext;
+            _listener.OnContext = async c => await InitTask(c, _disposeCancellationToken).ConfigureAwait(false);
 
             _listener.Start();
         }
 
-        private void ProcessContext(HttpListenerContext context)
-        {
-            _ = Task.Run(async () => await InitTask(context, _disposeCancellationToken).ConfigureAwait(false));
-        }
-
         private static void LogRequest(ILogger logger, HttpListenerRequest request)
         {
             var url = request.Url.ToString();
@@ -151,10 +143,7 @@ namespace Jellyfin.Server.SocketSharp
                     Endpoint = endpoint
                 };
 
-                if (WebSocketConnecting != null)
-                {
-                    WebSocketConnecting(connectingArgs);
-                }
+                WebSocketConnecting?.Invoke(connectingArgs);
 
                 if (connectingArgs.AllowConnection)
                 {
@@ -165,6 +154,7 @@ namespace Jellyfin.Server.SocketSharp
                     if (WebSocketConnected != null)
                     {
                         var socket = new SharpWebSocket(webSocketContext.WebSocket, _logger);
+                        await socket.ConnectAsServerAsync().ConfigureAwait(false);
 
                         WebSocketConnected(new WebSocketConnectEventArgs
                         {
@@ -174,7 +164,7 @@ namespace Jellyfin.Server.SocketSharp
                             Endpoint = endpoint
                         });
 
-                        await ReceiveWebSocket(ctx, socket).ConfigureAwait(false);
+                        await ReceiveWebSocketAsync(ctx, socket).ConfigureAwait(false);
                     }
                 }
                 else
@@ -192,7 +182,7 @@ namespace Jellyfin.Server.SocketSharp
             }
         }
 
-        private async Task ReceiveWebSocket(HttpListenerContext ctx, SharpWebSocket socket)
+        private async Task ReceiveWebSocketAsync(HttpListenerContext ctx, SharpWebSocket socket)
         {
             try
             {

+ 28 - 41
SocketHttpListener/Ext.cs

@@ -74,18 +74,20 @@ namespace SocketHttpListener
             }
         }
 
-        private static byte[] readBytes(this Stream stream, byte[] buffer, int offset, int length)
+        private static async Task<byte[]> ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length)
         {
-            var len = stream.Read(buffer, offset, length);
+            var len = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false);
             if (len < 1)
                 return buffer.SubArray(0, offset);
 
             var tmp = 0;
             while (len < length)
             {
-                tmp = stream.Read(buffer, offset + len, length - len);
+                tmp = await stream.ReadAsync(buffer, offset + len, length - len).ConfigureAwait(false);
                 if (tmp < 1)
+                {
                     break;
+                }
 
                 len += tmp;
             }
@@ -95,10 +97,9 @@ namespace SocketHttpListener
                    : buffer;
         }
 
-        private static bool readBytes(
-          this Stream stream, byte[] buffer, int offset, int length, Stream dest)
+        private static async Task<bool> ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length, Stream dest)
         {
-            var bytes = stream.readBytes(buffer, offset, length);
+            var bytes = await stream.ReadBytesAsync(buffer, offset, length).ConfigureAwait(false);
             var len = bytes.Length;
             dest.Write(bytes, 0, len);
 
@@ -109,16 +110,16 @@ namespace SocketHttpListener
 
         #region Internal Methods
 
-        internal static byte[] Append(this ushort code, string reason)
+        internal static async Task<byte[]> AppendAsync(this ushort code, string reason)
         {
             using (var buffer = new MemoryStream())
             {
                 var tmp = code.ToByteArrayInternally(ByteOrder.Big);
-                buffer.Write(tmp, 0, 2);
+                await buffer.WriteAsync(tmp, 0, 2).ConfigureAwait(false);
                 if (reason != null && reason.Length > 0)
                 {
                     tmp = Encoding.UTF8.GetBytes(reason);
-                    buffer.Write(tmp, 0, tmp.Length);
+                    await buffer.WriteAsync(tmp, 0, tmp.Length).ConfigureAwait(false);
                 }
 
                 return buffer.ToArray();
@@ -331,12 +332,10 @@ namespace SocketHttpListener
                    : string.Format("\"{0}\"", value.Replace("\"", "\\\""));
         }
 
-        internal static byte[] ReadBytes(this Stream stream, int length)
-        {
-            return stream.readBytes(new byte[length], 0, length);
-        }
+        internal static Task<byte[]> ReadBytesAsync(this Stream stream, int length)
+            => stream.ReadBytesAsync(new byte[length], 0, length);
 
-        internal static byte[] ReadBytes(this Stream stream, long length, int bufferLength)
+        internal static async Task<byte[]> ReadBytesAsync(this Stream stream, long length, int bufferLength)
         {
             using (var result = new MemoryStream())
             {
@@ -347,7 +346,7 @@ namespace SocketHttpListener
                 var end = false;
                 for (long i = 0; i < count; i++)
                 {
-                    if (!stream.readBytes(buffer, 0, bufferLength, result))
+                    if (!await stream.ReadBytesAsync(buffer, 0, bufferLength, result).ConfigureAwait(false))
                     {
                         end = true;
                         break;
@@ -355,26 +354,14 @@ namespace SocketHttpListener
                 }
 
                 if (!end && rem > 0)
-                    stream.readBytes(new byte[rem], 0, rem, result);
+                {
+                    await stream.ReadBytesAsync(new byte[rem], 0, rem, result).ConfigureAwait(false);
+                }
 
                 return result.ToArray();
             }
         }
 
-        internal static async Task<byte[]> ReadBytesAsync(this Stream stream, int length)
-        {
-            var buffer = new byte[length];
-
-            var len = await stream.ReadAsync(buffer, 0, length).ConfigureAwait(false);
-            var bytes = len < 1
-                ? new byte[0]
-                : len < length
-                  ? stream.readBytes(buffer, len, length - len)
-                  : buffer;
-
-            return bytes;
-        }
-
         internal static string RemovePrefix(this string value, params string[] prefixes)
         {
             var i = 0;
@@ -493,19 +480,16 @@ namespace SocketHttpListener
             return string.Format("{0}; {1}", m, parameters.ToString("; "));
         }
 
-        internal static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
-        {
-            return new List<TSource>(source);
-        }
-
         internal static ushort ToUInt16(this byte[] src, ByteOrder srcOrder)
         {
-            return BitConverter.ToUInt16(src.ToHostOrder(srcOrder), 0);
+            src.ToHostOrder(srcOrder);
+            return BitConverter.ToUInt16(src, 0);
         }
 
         internal static ulong ToUInt64(this byte[] src, ByteOrder srcOrder)
         {
-            return BitConverter.ToUInt64(src.ToHostOrder(srcOrder), 0);
+            src.ToHostOrder(srcOrder);
+            return BitConverter.ToUInt64(src, 0);
         }
 
         internal static string TrimEndSlash(this string value)
@@ -852,14 +836,17 @@ namespace SocketHttpListener
         /// <exception cref="ArgumentNullException">
         /// <paramref name="src"/> is <see langword="null"/>.
         /// </exception>
-        public static byte[] ToHostOrder(this byte[] src, ByteOrder srcOrder)
+        public static void ToHostOrder(this byte[] src, ByteOrder srcOrder)
         {
             if (src == null)
+            {
                 throw new ArgumentNullException(nameof(src));
+            }
 
-            return src.Length > 1 && !srcOrder.IsHostOrder()
-                   ? src.Reverse()
-                   : src;
+            if (src.Length > 1 && !srcOrder.IsHostOrder())
+            {
+                Array.Reverse(src);
+            }
         }
 
         /// <summary>

+ 25 - 18
SocketHttpListener/Net/HttpListener.cs

@@ -3,7 +3,6 @@ using System.Collections;
 using System.Collections.Generic;
 using System.Net;
 using System.Security.Cryptography.X509Certificates;
-using MediaBrowser.Common.Net;
 using MediaBrowser.Model.Cryptography;
 using MediaBrowser.Model.IO;
 using MediaBrowser.Model.Net;
@@ -18,47 +17,55 @@ namespace SocketHttpListener.Net
         internal ISocketFactory SocketFactory { get; private set; }
         internal IFileSystem FileSystem { get; private set; }
         internal IStreamHelper StreamHelper { get; private set; }
-        internal INetworkManager NetworkManager { get; private set; }
         internal IEnvironmentInfo EnvironmentInfo { get; private set; }
 
         public bool EnableDualMode { get; set; }
 
-        AuthenticationSchemes auth_schemes;
-        HttpListenerPrefixCollection prefixes;
-        AuthenticationSchemeSelector auth_selector;
-        string realm;
-        bool unsafe_ntlm_auth;
-        bool listening;
-        bool disposed;
+        private AuthenticationSchemes auth_schemes;
+        private HttpListenerPrefixCollection prefixes;
+        private AuthenticationSchemeSelector auth_selector;
+        private string realm;
+        private bool unsafe_ntlm_auth;
+        private bool listening;
+        private bool disposed;
 
-        Dictionary<HttpListenerContext, HttpListenerContext> registry;   // Dictionary<HttpListenerContext,HttpListenerContext>
-        Dictionary<HttpConnection, HttpConnection> connections;
+        private Dictionary<HttpListenerContext, HttpListenerContext> registry;
+        private Dictionary<HttpConnection, HttpConnection> connections;
         private ILogger _logger;
         private X509Certificate _certificate;
 
         public Action<HttpListenerContext> OnContext { get; set; }
 
-        public HttpListener(ILogger logger, ICryptoProvider cryptoProvider, ISocketFactory socketFactory,
-            INetworkManager networkManager, IStreamHelper streamHelper, IFileSystem fileSystem,
+        public HttpListener(
+            ILogger logger,
+            ICryptoProvider cryptoProvider,
+            ISocketFactory socketFactory,
+            IStreamHelper streamHelper,
+            IFileSystem fileSystem,
             IEnvironmentInfo environmentInfo)
         {
             _logger = logger;
             CryptoProvider = cryptoProvider;
             SocketFactory = socketFactory;
-            NetworkManager = networkManager;
             StreamHelper = streamHelper;
             FileSystem = fileSystem;
             EnvironmentInfo = environmentInfo;
+
             prefixes = new HttpListenerPrefixCollection(logger, this);
             registry = new Dictionary<HttpListenerContext, HttpListenerContext>();
             connections = new Dictionary<HttpConnection, HttpConnection>();
             auth_schemes = AuthenticationSchemes.Anonymous;
         }
 
-        public HttpListener(ILogger logger, X509Certificate certificate, ICryptoProvider cryptoProvider,
-            ISocketFactory socketFactory, INetworkManager networkManager, IStreamHelper streamHelper,
-            IFileSystem fileSystem, IEnvironmentInfo environmentInfo)
-            : this(logger, cryptoProvider, socketFactory, networkManager, streamHelper, fileSystem, environmentInfo)
+        public HttpListener(
+            ILogger logger,
+            X509Certificate certificate,
+            ICryptoProvider cryptoProvider,
+            ISocketFactory socketFactory,
+            IStreamHelper streamHelper,
+            IFileSystem fileSystem,
+            IEnvironmentInfo environmentInfo)
+            : this(logger, cryptoProvider, socketFactory, streamHelper, fileSystem, environmentInfo)
         {
             _certificate = certificate;
         }

+ 54 - 25
SocketHttpListener/Net/HttpListenerPrefixCollection.cs

@@ -7,18 +7,18 @@ namespace SocketHttpListener.Net
 {
     public class HttpListenerPrefixCollection : ICollection<string>, IEnumerable<string>, IEnumerable
     {
-        List<string> prefixes = new List<string>();
-        HttpListener listener;
+        private List<string> _prefixes = new List<string>();
+        private HttpListener _listener;
 
         private ILogger _logger;
 
         internal HttpListenerPrefixCollection(ILogger logger, HttpListener listener)
         {
             _logger = logger;
-            this.listener = listener;
+            _listener = listener;
         }
 
-        public int Count => prefixes.Count;
+        public int Count => _prefixes.Count;
 
         public bool IsReadOnly => false;
 
@@ -26,61 +26,90 @@ namespace SocketHttpListener.Net
 
         public void Add(string uriPrefix)
         {
-            listener.CheckDisposed();
+            _listener.CheckDisposed();
             //ListenerPrefix.CheckUri(uriPrefix);
-            if (prefixes.Contains(uriPrefix))
+            if (_prefixes.Contains(uriPrefix))
+            {
                 return;
+            }
 
-            prefixes.Add(uriPrefix);
-            if (listener.IsListening)
-                HttpEndPointManager.AddPrefix(_logger, uriPrefix, listener);
+            _prefixes.Add(uriPrefix);
+            if (_listener.IsListening)
+            {
+                HttpEndPointManager.AddPrefix(_logger, uriPrefix, _listener);
+            }
+        }
+
+        public void AddRange(IEnumerable<string> uriPrefixes)
+        {
+            _listener.CheckDisposed();
+
+            foreach (var uriPrefix in uriPrefixes)
+            {
+                if (_prefixes.Contains(uriPrefix))
+                {
+                    continue;
+                }
+
+                _prefixes.Add(uriPrefix);
+                if (_listener.IsListening)
+                {
+                    HttpEndPointManager.AddPrefix(_logger, uriPrefix, _listener);
+                }
+            }
         }
 
         public void Clear()
         {
-            listener.CheckDisposed();
-            prefixes.Clear();
-            if (listener.IsListening)
-                HttpEndPointManager.RemoveListener(_logger, listener);
+            _listener.CheckDisposed();
+            _prefixes.Clear();
+            if (_listener.IsListening)
+            {
+                HttpEndPointManager.RemoveListener(_logger, _listener);
+            }
         }
 
         public bool Contains(string uriPrefix)
         {
-            listener.CheckDisposed();
-            return prefixes.Contains(uriPrefix);
+            _listener.CheckDisposed();
+            return _prefixes.Contains(uriPrefix);
         }
 
         public void CopyTo(string[] array, int offset)
         {
-            listener.CheckDisposed();
-            prefixes.CopyTo(array, offset);
+            _listener.CheckDisposed();
+            _prefixes.CopyTo(array, offset);
         }
 
         public void CopyTo(Array array, int offset)
         {
-            listener.CheckDisposed();
-            ((ICollection)prefixes).CopyTo(array, offset);
+            _listener.CheckDisposed();
+            ((ICollection)_prefixes).CopyTo(array, offset);
         }
 
         public IEnumerator<string> GetEnumerator()
         {
-            return prefixes.GetEnumerator();
+            return _prefixes.GetEnumerator();
         }
 
         IEnumerator IEnumerable.GetEnumerator()
         {
-            return prefixes.GetEnumerator();
+            return _prefixes.GetEnumerator();
         }
 
         public bool Remove(string uriPrefix)
         {
-            listener.CheckDisposed();
+            _listener.CheckDisposed();
             if (uriPrefix == null)
+            {
                 throw new ArgumentNullException(nameof(uriPrefix));
+            }
 
-            bool result = prefixes.Remove(uriPrefix);
-            if (result && listener.IsListening)
-                HttpEndPointManager.RemovePrefix(_logger, uriPrefix, listener);
+            bool result = _prefixes.Remove(uriPrefix);
+            if (result && _listener.IsListening)
+            {
+                HttpEndPointManager.RemovePrefix(_logger, uriPrefix, _listener);
+            }
 
             return result;
         }

+ 148 - 146
SocketHttpListener/WebSocket.cs

@@ -30,9 +30,9 @@ namespace SocketHttpListener
         private CookieCollection _cookies;
         private AutoResetEvent _exitReceiving;
         private object _forConn;
-        private object _forEvent;
+        private readonly SemaphoreSlim _forEvent = new SemaphoreSlim(1, 1);
         private object _forMessageEventQueue;
-        private object _forSend;
+        private readonly SemaphoreSlim _forSend = new SemaphoreSlim(1, 1);
         private const string _guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
         private Queue<MessageEventArgs> _messageEventQueue;
         private string _protocol;
@@ -109,12 +109,15 @@ namespace SocketHttpListener
 
         #region Private Methods
 
-        private void close(CloseStatusCode code, string reason, bool wait)
+        private async Task CloseAsync(CloseStatusCode code, string reason, bool wait)
         {
-            close(new PayloadData(((ushort)code).Append(reason)), !code.IsReserved(), wait);
+            await CloseAsync(new PayloadData(
+                await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)),
+                !code.IsReserved(),
+                wait).ConfigureAwait(false);
         }
 
-        private void close(PayloadData payload, bool send, bool wait)
+        private async Task CloseAsync(PayloadData payload, bool send, bool wait)
         {
             lock (_forConn)
             {
@@ -126,11 +129,12 @@ namespace SocketHttpListener
                 _readyState = WebSocketState.CloseSent;
             }
 
-            var e = new CloseEventArgs(payload);
-            e.WasClean =
-              closeHandshake(
+            var e = new CloseEventArgs(payload)
+            {
+                WasClean = await CloseHandshakeAsync(
                   send ? WebSocketFrame.CreateCloseFrame(Mask.Unmask, payload).ToByteArray() : null,
-                  wait ? 1000 : 0);
+                  wait ? 1000 : 0).ConfigureAwait(false)
+            };
 
             _readyState = WebSocketState.Closed;
             try
@@ -143,9 +147,9 @@ namespace SocketHttpListener
             }
         }
 
-        private bool closeHandshake(byte[] frameAsBytes, int millisecondsTimeout)
+        private async Task<bool> CloseHandshakeAsync(byte[] frameAsBytes, int millisecondsTimeout)
         {
-            var sent = frameAsBytes != null && writeBytes(frameAsBytes);
+            var sent = frameAsBytes != null && await WriteBytesAsync(frameAsBytes).ConfigureAwait(false);
             var received =
               millisecondsTimeout == 0 ||
               (sent && _exitReceiving != null && _exitReceiving.WaitOne(millisecondsTimeout));
@@ -189,11 +193,11 @@ namespace SocketHttpListener
             _context = null;
         }
 
-        private bool concatenateFragmentsInto(Stream dest)
+        private async Task<bool> ConcatenateFragmentsIntoAsync(Stream dest)
         {
             while (true)
             {
-                var frame = WebSocketFrame.Read(_stream, true);
+                var frame = await WebSocketFrame.ReadAsync(_stream, true).ConfigureAwait(false);
                 if (frame.IsFinal)
                 {
                     /* FINAL */
@@ -221,7 +225,7 @@ namespace SocketHttpListener
 
                     // CLOSE
                     if (frame.IsClose)
-                        return processCloseFrame(frame);
+                        return await ProcessCloseFrameAsync(frame).ConfigureAwait(false);
                 }
                 else
                 {
@@ -236,10 +240,10 @@ namespace SocketHttpListener
                 }
 
                 // ?
-                return processUnsupportedFrame(
+                return await ProcessUnsupportedFrameAsync(
                   frame,
                   CloseStatusCode.IncorrectData,
-                  "An incorrect data has been received while receiving fragmented data.");
+                  "An incorrect data has been received while receiving fragmented data.").ConfigureAwait(false);
             }
 
             return true;
@@ -299,44 +303,42 @@ namespace SocketHttpListener
             _compression = CompressionMethod.None;
             _cookies = new CookieCollection();
             _forConn = new object();
-            _forEvent = new object();
-            _forSend = new object();
             _messageEventQueue = new Queue<MessageEventArgs>();
             _forMessageEventQueue = ((ICollection)_messageEventQueue).SyncRoot;
             _readyState = WebSocketState.Connecting;
         }
 
-        private void open()
+        private async Task OpenAsync()
         {
             try
             {
                 startReceiving();
 
-                lock (_forEvent)
-                {
-                    try
-                    {
-                        if (OnOpen != null)
-                        {
-                            OnOpen(this, EventArgs.Empty);
-                        }
-                    }
-                    catch (Exception ex)
-                    {
-                        processException(ex, "An exception has occurred while OnOpen.");
-                    }
-                }
             }
             catch (Exception ex)
             {
-                processException(ex, "An exception has occurred while opening.");
+                await ProcessExceptionAsync(ex, "An exception has occurred while opening.").ConfigureAwait(false);
+            }
+
+            await _forEvent.WaitAsync().ConfigureAwait(false);
+            try
+            {
+                OnOpen?.Invoke(this, EventArgs.Empty);
+            }
+            catch (Exception ex)
+            {
+                await ProcessExceptionAsync(ex, "An exception has occurred while OnOpen.").ConfigureAwait(false);
+            }
+            finally
+            {
+                _forEvent.Release();
             }
         }
 
-        private bool processCloseFrame(WebSocketFrame frame)
+        private async Task<bool> ProcessCloseFrameAsync(WebSocketFrame frame)
         {
             var payload = frame.PayloadData;
-            close(payload, !payload.ContainsReservedCloseStatusCode, false);
+            await CloseAsync(payload, !payload.ContainsReservedCloseStatusCode, false).ConfigureAwait(false);
 
             return false;
         }
@@ -352,7 +354,7 @@ namespace SocketHttpListener
             return true;
         }
 
-        private void processException(Exception exception, string message)
+        private async Task ProcessExceptionAsync(Exception exception, string message)
         {
             var code = CloseStatusCode.Abnormal;
             var reason = message;
@@ -365,25 +367,31 @@ namespace SocketHttpListener
 
             error(message ?? code.GetMessage(), exception);
             if (_readyState == WebSocketState.Connecting)
-                Close(HttpStatusCode.BadRequest);
+            {
+                await CloseAsync(HttpStatusCode.BadRequest).ConfigureAwait(false);
+            }
             else
-                close(code, reason ?? code.GetMessage(), false);
+            {
+                await CloseAsync(code, reason ?? code.GetMessage(), false).ConfigureAwait(false);
+            }
         }
 
-        private bool processFragmentedFrame(WebSocketFrame frame)
+        private Task<bool> ProcessFragmentedFrameAsync(WebSocketFrame frame)
         {
             return frame.IsContinuation // Not first fragment
-                   ? true
-                   : processFragments(frame);
+                   ? Task.FromResult(true)
+                   : ProcessFragmentsAsync(frame);
         }
 
-        private bool processFragments(WebSocketFrame first)
+        private async Task<bool> ProcessFragmentsAsync(WebSocketFrame first)
         {
             using (var buff = new MemoryStream())
             {
                 buff.WriteBytes(first.PayloadData.ApplicationData);
-                if (!concatenateFragmentsInto(buff))
+                if (!await ConcatenateFragmentsIntoAsync(buff).ConfigureAwait(false))
+                {
                     return false;
+                }
 
                 byte[] data;
                 if (_compression != CompressionMethod.None)
@@ -412,36 +420,38 @@ namespace SocketHttpListener
             return true;
         }
 
-        private bool processUnsupportedFrame(WebSocketFrame frame, CloseStatusCode code, string reason)
+        private async Task<bool> ProcessUnsupportedFrameAsync(WebSocketFrame frame, CloseStatusCode code, string reason)
         {
-            processException(new WebSocketException(code, reason), null);
+            await ProcessExceptionAsync(new WebSocketException(code, reason), null).ConfigureAwait(false);
 
             return false;
         }
 
-        private bool processWebSocketFrame(WebSocketFrame frame)
+        private Task<bool> ProcessWebSocketFrameAsync(WebSocketFrame frame)
         {
+            // TODO: @bond change to if/else chain
             return frame.IsCompressed && _compression == CompressionMethod.None
-                   ? processUnsupportedFrame(
+                   ? ProcessUnsupportedFrameAsync(
                        frame,
                        CloseStatusCode.IncorrectData,
                        "A compressed data has been received without available decompression method.")
                    : frame.IsFragmented
-                     ? processFragmentedFrame(frame)
+                     ? ProcessFragmentedFrameAsync(frame)
                      : frame.IsData
-                       ? processDataFrame(frame)
+                       ? Task.FromResult(processDataFrame(frame))
                        : frame.IsPing
-                         ? processPingFrame(frame)
+                         ? Task.FromResult(processPingFrame(frame))
                          : frame.IsPong
-                           ? processPongFrame(frame)
+                           ? Task.FromResult(processPongFrame(frame))
                            : frame.IsClose
-                             ? processCloseFrame(frame)
-                             : processUnsupportedFrame(frame, CloseStatusCode.PolicyViolation, null);
+                             ? ProcessCloseFrameAsync(frame)
+                             : ProcessUnsupportedFrameAsync(frame, CloseStatusCode.PolicyViolation, null);
         }
 
-        private bool send(Opcode opcode, Stream stream)
+        private async Task<bool> SendAsync(Opcode opcode, Stream stream)
         {
-            lock (_forSend)
+            await _forSend.WaitAsync().ConfigureAwait(false);
+            try
             {
                 var src = stream;
                 var compressed = false;
@@ -454,7 +464,7 @@ namespace SocketHttpListener
                         compressed = true;
                     }
 
-                    sent = send(opcode, Mask.Unmask, stream, compressed);
+                    sent = await SendAsync(opcode, Mask.Unmask, stream, compressed).ConfigureAwait(false);
                     if (!sent)
                         error("Sending a data has been interrupted.");
                 }
@@ -472,16 +482,20 @@ namespace SocketHttpListener
 
                 return sent;
             }
+            finally
+            {
+                _forSend.Release();
+            }
         }
 
-        private bool send(Opcode opcode, Mask mask, Stream stream, bool compressed)
+        private async Task<bool> SendAsync(Opcode opcode, Mask mask, Stream stream, bool compressed)
         {
             var len = stream.Length;
 
             /* Not fragmented */
 
             if (len == 0)
-                return send(Fin.Final, opcode, mask, new byte[0], compressed);
+                return await SendAsync(Fin.Final, opcode, mask, new byte[0], compressed).ConfigureAwait(false);
 
             var quo = len / FragmentLength;
             var rem = (int)(len % FragmentLength);
@@ -490,26 +504,26 @@ namespace SocketHttpListener
             if (quo == 0)
             {
                 buff = new byte[rem];
-                return stream.Read(buff, 0, rem) == rem &&
-                       send(Fin.Final, opcode, mask, buff, compressed);
+                return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem &&
+                       await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false);
             }
 
             buff = new byte[FragmentLength];
             if (quo == 1 && rem == 0)
-                return stream.Read(buff, 0, FragmentLength) == FragmentLength &&
-                       send(Fin.Final, opcode, mask, buff, compressed);
+                return await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) == FragmentLength &&
+                       await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false);
 
             /* Send fragmented */
 
             // Begin
-            if (stream.Read(buff, 0, FragmentLength) != FragmentLength ||
-                !send(Fin.More, opcode, mask, buff, compressed))
+            if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength ||
+                !await SendAsync(Fin.More, opcode, mask, buff, compressed).ConfigureAwait(false))
                 return false;
 
             var n = rem == 0 ? quo - 2 : quo - 1;
             for (long i = 0; i < n; i++)
-                if (stream.Read(buff, 0, FragmentLength) != FragmentLength ||
-                    !send(Fin.More, Opcode.Cont, mask, buff, compressed))
+                if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength ||
+                    !await SendAsync(Fin.More, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false))
                     return false;
 
             // End
@@ -518,98 +532,88 @@ namespace SocketHttpListener
             else
                 buff = new byte[rem];
 
-            return stream.Read(buff, 0, rem) == rem &&
-                   send(Fin.Final, Opcode.Cont, mask, buff, compressed);
+            return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem &&
+                   await SendAsync(Fin.Final, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false);
         }
 
-        private bool send(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed)
+        private Task<bool> SendAsync(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed)
         {
             lock (_forConn)
             {
                 if (_readyState != WebSocketState.Open)
                 {
-                    return false;
+                    return Task.FromResult(false);
                 }
 
-                return writeBytes(
+                return WriteBytesAsync(
                   WebSocketFrame.CreateWebSocketFrame(fin, opcode, mask, data, compressed).ToByteArray());
             }
         }
 
-        private Task sendAsync(Opcode opcode, Stream stream)
-        {
-            var completionSource = new TaskCompletionSource<bool>();
-            Task.Run(() =>
-           {
-               try
-               {
-                   send(opcode, stream);
-                   completionSource.TrySetResult(true);
-               }
-               catch (Exception ex)
-               {
-                   completionSource.TrySetException(ex);
-               }
-           });
-            return completionSource.Task;
-        }
-
         // As server
-        private bool sendHttpResponse(HttpResponse response)
-        {
-            return writeBytes(response.ToByteArray());
-        }
+        private Task<bool> SendHttpResponseAsync(HttpResponse response)
+            => WriteBytesAsync(response.ToByteArray());
 
         private void startReceiving()
         {
             if (_messageEventQueue.Count > 0)
+            {
                 _messageEventQueue.Clear();
+            }
 
             _exitReceiving = new AutoResetEvent(false);
             _receivePong = new AutoResetEvent(false);
 
             Action receive = null;
-            receive = () => WebSocketFrame.ReadAsync(
-              _stream,
-              true,
-              frame =>
-              {
-                  if (processWebSocketFrame(frame) && _readyState != WebSocketState.Closed)
-                  {
-                      receive();
-
-                      if (!frame.IsData)
-                          return;
-
-                      lock (_forEvent)
-                      {
-                          try
-                          {
-                              var e = dequeueFromMessageEventQueue();
-                              if (e != null && _readyState == WebSocketState.Open)
-                                  OnMessage.Emit(this, e);
-                          }
-                          catch (Exception ex)
-                          {
-                              processException(ex, "An exception has occurred while OnMessage.");
-                          }
-                      }
-                  }
-                  else if (_exitReceiving != null)
-                  {
-                      _exitReceiving.Set();
-                  }
-              },
-              ex => processException(ex, "An exception has occurred while receiving a message."));
+            receive = async () => await WebSocketFrame.ReadAsync(
+                _stream,
+                true,
+                async frame =>
+                {
+                    if (await ProcessWebSocketFrameAsync(frame).ConfigureAwait(false) && _readyState != WebSocketState.Closed)
+                    {
+                        receive();
+
+                        if (!frame.IsData)
+                        {
+                            return;
+                        }
+
+                        await _forEvent.WaitAsync().ConfigureAwait(false);
+
+                        try
+                        {
+                            var e = dequeueFromMessageEventQueue();
+                            if (e != null && _readyState == WebSocketState.Open)
+                            {
+                                OnMessage.Emit(this, e);
+                            }
+                        }
+                        catch (Exception ex)
+                        {
+                            await ProcessExceptionAsync(ex, "An exception has occurred while OnMessage.").ConfigureAwait(false);
+                        }
+                        finally
+                        {
+                            _forEvent.Release();
+                        }
+
+                    }
+                    else if (_exitReceiving != null)
+                    {
+                        _exitReceiving.Set();
+                    }
+                },
+                async ex => await ProcessExceptionAsync(ex, "An exception has occurred while receiving a message.")).ConfigureAwait(false);
 
             receive();
         }
 
-        private bool writeBytes(byte[] data)
+        private async Task<bool> WriteBytesAsync(byte[] data)
         {
             try
             {
-                _stream.Write(data, 0, data.Length);
+                await _stream.WriteAsync(data, 0, data.Length).ConfigureAwait(false);
                 return true;
             }
             catch (Exception)
@@ -623,10 +627,10 @@ namespace SocketHttpListener
         #region Internal Methods
 
         // As server
-        internal void Close(HttpResponse response)
+        internal async Task CloseAsync(HttpResponse response)
         {
             _readyState = WebSocketState.CloseSent;
-            sendHttpResponse(response);
+            await SendHttpResponseAsync(response).ConfigureAwait(false);
 
             closeServerResources();
 
@@ -634,22 +638,20 @@ namespace SocketHttpListener
         }
 
         // As server
-        internal void Close(HttpStatusCode code)
-        {
-            Close(createHandshakeCloseResponse(code));
-        }
+        internal Task CloseAsync(HttpStatusCode code)
+            => CloseAsync(createHandshakeCloseResponse(code));
 
         // As server
-        public void ConnectAsServer()
+        public async Task ConnectAsServer()
         {
             try
             {
                 _readyState = WebSocketState.Open;
-                open();
+                await OpenAsync().ConfigureAwait(false);
             }
             catch (Exception ex)
             {
-                processException(ex, "An exception has occurred while connecting.");
+                await ProcessExceptionAsync(ex, "An exception has occurred while connecting.").ConfigureAwait(false);
             }
         }
 
@@ -660,18 +662,18 @@ namespace SocketHttpListener
         /// <summary>
         /// Closes the WebSocket connection, and releases all associated resources.
         /// </summary>
-        public void Close()
+        public Task CloseAsync()
         {
             var msg = _readyState.CheckIfClosable();
             if (msg != null)
             {
                 error(msg);
 
-                return;
+                return Task.CompletedTask;
             }
 
             var send = _readyState == WebSocketState.Open;
-            close(new PayloadData(), send, send);
+            return CloseAsync(new PayloadData(), send, send);
         }
 
         /// <summary>
@@ -689,11 +691,11 @@ namespace SocketHttpListener
         /// <param name="reason">
         /// A <see cref="string"/> that represents the reason for the close.
         /// </param>
-        public void Close(CloseStatusCode code, string reason)
+        public async Task CloseAsync(CloseStatusCode code, string reason)
         {
             byte[] data = null;
             var msg = _readyState.CheckIfClosable() ??
-                      (data = ((ushort)code).Append(reason)).CheckIfValidControlData("reason");
+                      (data = await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)).CheckIfValidControlData("reason");
 
             if (msg != null)
             {
@@ -703,7 +705,7 @@ namespace SocketHttpListener
             }
 
             var send = _readyState == WebSocketState.Open && !code.IsReserved();
-            close(new PayloadData(data), send, send);
+            await CloseAsync(new PayloadData(data), send, send).ConfigureAwait(false);
         }
 
         /// <summary>
@@ -728,7 +730,7 @@ namespace SocketHttpListener
                 throw new Exception(msg);
             }
 
-            return sendAsync(Opcode.Binary, new MemoryStream(data));
+            return SendAsync(Opcode.Binary, new MemoryStream(data));
         }
 
         /// <summary>
@@ -753,7 +755,7 @@ namespace SocketHttpListener
                 throw new Exception(msg);
             }
 
-            return sendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data)));
+            return SendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data)));
         }
 
         #endregion
@@ -768,7 +770,7 @@ namespace SocketHttpListener
         /// </remarks>
         void IDisposable.Dispose()
         {
-            Close(CloseStatusCode.Away, null);
+            CloseAsync(CloseStatusCode.Away, null).GetAwaiter().GetResult();
         }
 
         #endregion

+ 23 - 24
SocketHttpListener/WebSocketFrame.cs

@@ -2,6 +2,7 @@ using System;
 using System.Collections;
 using System.Collections.Generic;
 using System.IO;
+using System.Threading.Tasks;
 
 namespace SocketHttpListener
 {
@@ -177,7 +178,7 @@ namespace SocketHttpListener
             return opcode == Opcode.Text || opcode == Opcode.Binary;
         }
 
-        private static WebSocketFrame read(byte[] header, Stream stream, bool unmask)
+        private static async Task<WebSocketFrame> ReadAsync(byte[] header, Stream stream, bool unmask)
         {
             /* Header */
 
@@ -229,7 +230,7 @@ namespace SocketHttpListener
                          ? 2
                          : 8;
 
-            var extPayloadLen = size > 0 ? stream.ReadBytes(size) : new byte[0];
+            var extPayloadLen = size > 0 ? await stream.ReadBytesAsync(size).ConfigureAwait(false) : Array.Empty<byte>();
             if (size > 0 && extPayloadLen.Length != size)
                 throw new WebSocketException(
                   "The 'Extended Payload Length' of a frame cannot be read from the data source.");
@@ -239,7 +240,7 @@ namespace SocketHttpListener
             /* Masking Key */
 
             var masked = mask == Mask.Mask;
-            var maskingKey = masked ? stream.ReadBytes(4) : new byte[0];
+            var maskingKey = masked ? await stream.ReadBytesAsync(4).ConfigureAwait(false) : Array.Empty<byte>();
             if (masked && maskingKey.Length != 4)
                 throw new WebSocketException(
                   "The 'Masking Key' of a frame cannot be read from the data source.");
@@ -264,8 +265,8 @@ namespace SocketHttpListener
                       "The length of 'Payload Data' of a frame is greater than the allowable length.");
 
                 data = payloadLen > 126
-                       ? stream.ReadBytes((long)len, 1024)
-                       : stream.ReadBytes((int)len);
+                       ? await stream.ReadBytesAsync((long)len, 1024).ConfigureAwait(false)
+                       : await stream.ReadBytesAsync((int)len).ConfigureAwait(false);
 
                 //if (data.LongLength != (long)len)
                 //    throw new WebSocketException(
@@ -273,7 +274,7 @@ namespace SocketHttpListener
             }
             else
             {
-                data = new byte[0];
+                data = Array.Empty<byte>();
             }
 
             var payload = new PayloadData(data, masked);
@@ -281,7 +282,7 @@ namespace SocketHttpListener
             {
                 payload.Mask(maskingKey);
                 frame._mask = Mask.Unmask;
-                frame._maskingKey = new byte[0];
+                frame._maskingKey = Array.Empty<byte>();
             }
 
             frame._payloadData = payload;
@@ -302,10 +303,10 @@ namespace SocketHttpListener
             return new WebSocketFrame(Opcode.Close, mask, payload);
         }
 
-        internal static WebSocketFrame CreateCloseFrame(Mask mask, CloseStatusCode code, string reason)
+        internal static async Task<WebSocketFrame> CreateCloseFrameAsync(Mask mask, CloseStatusCode code, string reason)
         {
             return new WebSocketFrame(
-              Opcode.Close, mask, new PayloadData(((ushort)code).Append(reason)));
+              Opcode.Close, mask, new PayloadData(await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)));
         }
 
         internal static WebSocketFrame CreatePingFrame(Mask mask)
@@ -329,41 +330,39 @@ namespace SocketHttpListener
             return new WebSocketFrame(fin, opcode, mask, new PayloadData(data), compressed);
         }
 
-        internal static WebSocketFrame Read(Stream stream)
-        {
-            return Read(stream, true);
-        }
+        internal static Task<WebSocketFrame> ReadAsync(Stream stream)
+            => ReadAsync(stream, true);
 
-        internal static WebSocketFrame Read(Stream stream, bool unmask)
+        internal static async Task<WebSocketFrame> ReadAsync(Stream stream, bool unmask)
         {
-            var header = stream.ReadBytes(2);
+            var header = await stream.ReadBytesAsync(2).ConfigureAwait(false);
             if (header.Length != 2)
+            {
                 throw new WebSocketException(
                   "The header part of a frame cannot be read from the data source.");
+            }
 
-            return read(header, stream, unmask);
+            return await ReadAsync(header, stream, unmask).ConfigureAwait(false);
         }
 
-        internal static async void ReadAsync(
+        internal static async Task ReadAsync(
           Stream stream, bool unmask, Action<WebSocketFrame> completed, Action<Exception> error)
         {
             try
             {
                 var header = await stream.ReadBytesAsync(2).ConfigureAwait(false);
                 if (header.Length != 2)
+                {
                     throw new WebSocketException(
                       "The header part of a frame cannot be read from the data source.");
+                }
 
-                var frame = read(header, stream, unmask);
-                if (completed != null)
-                    completed(frame);
+                var frame = await ReadAsync(header, stream, unmask).ConfigureAwait(false);
+                completed?.Invoke(frame);
             }
             catch (Exception ex)
             {
-                if (error != null)
-                {
-                    error(ex);
-                }
+                error.Invoke(ex);
             }
         }