Explorar o código

Merge pull request #5012 from jellyfin/ws

Improve WebSocket Message Deserialization
dkanada %!s(int64=4) %!d(string=hai) anos
pai
achega
7f1243978c

+ 22 - 30
Emby.Server.Implementations/HttpServer/WebSocketConnection.cs

@@ -5,6 +5,7 @@ using System.Buffers;
 using System.IO.Pipelines;
 using System.Net;
 using System.Net.WebSockets;
+using System.Text;
 using System.Text.Json;
 using System.Threading;
 using System.Threading.Tasks;
@@ -138,7 +139,7 @@ namespace Emby.Server.Implementations.HttpServer
                 writer.Advance(bytesRead);
 
                 // Make the data available to the PipeReader
-                FlushResult flushResult = await writer.FlushAsync().ConfigureAwait(false);
+                FlushResult flushResult = await writer.FlushAsync(cancellationToken).ConfigureAwait(false);
                 if (flushResult.IsCompleted)
                 {
                     // The PipeReader stopped reading
@@ -181,32 +182,16 @@ namespace Emby.Server.Implementations.HttpServer
             }
 
             WebSocketMessage<object>? stub;
+            long bytesConsumed = 0;
             try
             {
-
-                if (buffer.IsSingleSegment)
-                {
-                    stub = JsonSerializer.Deserialize<WebSocketMessage<object>>(buffer.FirstSpan, _jsonOptions);
-                }
-                else
-                {
-                    var buf = ArrayPool<byte>.Shared.Rent(Convert.ToInt32(buffer.Length));
-                    try
-                    {
-                        buffer.CopyTo(buf);
-                        stub = JsonSerializer.Deserialize<WebSocketMessage<object>>(buf, _jsonOptions);
-                    }
-                    finally
-                    {
-                        ArrayPool<byte>.Shared.Return(buf);
-                    }
-                }
+                stub = DeserializeWebSocketMessage(buffer, out bytesConsumed);
             }
             catch (JsonException ex)
             {
                 // Tell the PipeReader how much of the buffer we have consumed
                 reader.AdvanceTo(buffer.End);
-                _logger.LogError(ex, "Error processing web socket message");
+                _logger.LogError(ex, "Error processing web socket message: {Data}", Encoding.UTF8.GetString(buffer));
                 return;
             }
 
@@ -217,27 +202,34 @@ namespace Emby.Server.Implementations.HttpServer
             }
 
             // Tell the PipeReader how much of the buffer we have consumed
-            reader.AdvanceTo(buffer.End);
+            reader.AdvanceTo(buffer.GetPosition(bytesConsumed));
 
             _logger.LogDebug("WS {IP} received message: {@Message}", RemoteEndPoint, stub);
 
-            var info = new WebSocketMessageInfo
-            {
-                MessageType = stub.MessageType,
-                Data = stub.Data?.ToString(), // Data can be null
-                Connection = this
-            };
-
-            if (info.MessageType == SessionMessageType.KeepAlive)
+            if (stub.MessageType == SessionMessageType.KeepAlive)
             {
                 await SendKeepAliveResponse().ConfigureAwait(false);
             }
             else
             {
-                await OnReceive(info).ConfigureAwait(false);
+                await OnReceive(
+                    new WebSocketMessageInfo
+                    {
+                        MessageType = stub.MessageType,
+                        Data = stub.Data?.ToString(), // Data can be null
+                        Connection = this
+                    }).ConfigureAwait(false);
             }
         }
 
+        internal WebSocketMessage<object>? DeserializeWebSocketMessage(ReadOnlySequence<byte> bytes, out long bytesConsumed)
+        {
+            var jsonReader = new Utf8JsonReader(bytes);
+            var ret = JsonSerializer.Deserialize<WebSocketMessage<object>>(ref jsonReader, _jsonOptions);
+            bytesConsumed = jsonReader.BytesConsumed;
+            return ret;
+        }
+
         private Task SendKeepAliveResponse()
         {
             LastKeepAliveDate = DateTime.UtcNow;

+ 69 - 0
tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs

@@ -0,0 +1,69 @@
+using System;
+using System.Buffers;
+using System.IO;
+using System.Text.Json;
+using Emby.Server.Implementations.HttpServer;
+using Microsoft.Extensions.Logging.Abstractions;
+using Xunit;
+
+namespace Jellyfin.Server.Implementations.Tests.HttpServer
+{
+    public class WebSocketConnectionTests
+    {
+        [Fact]
+        public void DeserializeWebSocketMessage_SingleSegment_Success()
+        {
+            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
+            var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
+            con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
+            Assert.Equal(109, bytesConsumed);
+        }
+
+        [Fact]
+        public void DeserializeWebSocketMessage_MultipleSegments_Success()
+        {
+            const int SplitPos = 64;
+            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
+            var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
+            var seg1 = new BufferSegment(new Memory<byte>(bytes, 0, SplitPos));
+            var seg2 = seg1.Append(new Memory<byte>(bytes, SplitPos, bytes.Length - SplitPos));
+            con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(seg1, 0, seg2, seg2.Memory.Length - 1), out var bytesConsumed);
+            Assert.Equal(109, bytesConsumed);
+        }
+
+        [Fact]
+        public void DeserializeWebSocketMessage_ValidPartial_Success()
+        {
+            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
+            var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json");
+            con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
+            Assert.Equal(109, bytesConsumed);
+        }
+
+        [Fact]
+        public void DeserializeWebSocketMessage_Partial_ThrowJsonException()
+        {
+            var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
+            var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json");
+            Assert.Throws<JsonException>(() => con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed));
+        }
+
+        internal class BufferSegment : ReadOnlySequenceSegment<byte>
+        {
+            public BufferSegment(Memory<byte> memory)
+            {
+                Memory = memory;
+            }
+
+            public BufferSegment Append(Memory<byte> memory)
+            {
+                var segment = new BufferSegment(memory)
+                {
+                    RunningIndex = RunningIndex + Memory.Length
+                };
+                Next = segment;
+                return segment;
+            }
+        }
+    }
+}

+ 6 - 5
tests/Jellyfin.Server.Implementations.Tests/Jellyfin.Server.Implementations.Tests.csproj

@@ -13,6 +13,12 @@
     <RootNamespace>Jellyfin.Server.Implementations.Tests</RootNamespace>
   </PropertyGroup>
 
+  <ItemGroup>
+    <None Include="Test Data\**\*.*">
+      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+    </None>
+  </ItemGroup>
+
   <ItemGroup>
     <PackageReference Include="AutoFixture" Version="4.15.0" />
     <PackageReference Include="AutoFixture.AutoMoq" Version="4.15.0" />
@@ -35,11 +41,6 @@
     <ProjectReference Include="..\..\Emby.Server.Implementations\Emby.Server.Implementations.csproj" />
   </ItemGroup>
 
-  <ItemGroup>
-    <EmbeddedResource Include="LiveTv\discover.json" />
-    <EmbeddedResource Include="LiveTv\lineup.json" />
-  </ItemGroup>
-
   <PropertyGroup Condition=" '$(Configuration)' == 'Debug' ">
     <CodeAnalysisRuleSet>../jellyfin-tests.ruleset</CodeAnalysisRuleSet>
   </PropertyGroup>

+ 2 - 10
tests/Jellyfin.Server.Implementations.Tests/LiveTv/HdHomerunHostTests.cs

@@ -1,4 +1,5 @@
 using System;
+using System.IO;
 using System.Net.Http;
 using System.Threading;
 using System.Threading.Tasks;
@@ -21,24 +22,15 @@ namespace Jellyfin.Server.Implementations.Tests.LiveTv
 
         public HdHomerunHostTests()
         {
-            const string BaseResourcePath = "Jellyfin.Server.Implementations.Tests.LiveTv.";
-
             var messageHandler = new Mock<HttpMessageHandler>();
             messageHandler.Protected()
                 .Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
                 .Returns<HttpRequestMessage, CancellationToken>(
                     (m, _) =>
                     {
-                        var resource = BaseResourcePath + m.RequestUri?.Segments[^1];
-                        var stream = typeof(HdHomerunHostTests).Assembly.GetManifestResourceStream(resource);
-                        if (stream == null)
-                        {
-                            throw new NullReferenceException("Resource doesn't exist: " + resource);
-                        }
-
                         return Task.FromResult(new HttpResponseMessage()
                         {
-                            Content = new StreamContent(stream)
+                            Content = new StreamContent(File.OpenRead("Test Data/LiveTv/" + m.RequestUri?.Segments[^1]))
                         });
                     });
 

+ 1 - 0
tests/Jellyfin.Server.Implementations.Tests/Test Data/HttpServer/ForceKeepAlive.json

@@ -0,0 +1 @@
+{"MessageType":"ForceKeepAlive","MessageId":"00000000-0000-0000-0000-000000000000","ServerId":null,"Data":60}

+ 1 - 0
tests/Jellyfin.Server.Implementations.Tests/Test Data/HttpServer/Partial.json

@@ -0,0 +1 @@
+{"MessageType":"KeepAlive","MessageId":"d29ef449-6965-4000

+ 1 - 0
tests/Jellyfin.Server.Implementations.Tests/Test Data/HttpServer/ValidPartial.json

@@ -0,0 +1 @@
+{"MessageType":"ForceKeepAlive","MessageId":"00000000-0000-0000-0000-000000000000","ServerId":null,"Data":60}{"MessageType":"KeepAlive","MessageId":"d29ef449-6965-4000

+ 0 - 0
tests/Jellyfin.Server.Implementations.Tests/LiveTv/discover.json → tests/Jellyfin.Server.Implementations.Tests/Test Data/LiveTv/discover.json


+ 0 - 0
tests/Jellyfin.Server.Implementations.Tests/LiveTv/lineup.json → tests/Jellyfin.Server.Implementations.Tests/Test Data/LiveTv/lineup.json