diff --git a/System.Net.WebSockets.Client.Managed/SSPIWrapper.cs b/System.Net.WebSockets.Client.Managed/SSPIWrapper.cs
new file mode 100644
index 0000000..bf2eff0
--- /dev/null
+++ b/System.Net.WebSockets.Client.Managed/SSPIWrapper.cs
@@ -0,0 +1,397 @@
+using System;
+using System.ComponentModel;
+using System.Runtime.InteropServices;
+
+namespace ManagedSSPI
+{
+ [Flags]
+ public enum CredentialsUse
+ {
+ Inbound = 0x1,
+ Outbound = 0x2,
+ InboundAndOutbound = Inbound | Outbound
+ }
+
+ [Flags]
+ public enum SecurityCapabilities
+ {
+ SupportsIntegrity = 0x1,
+ SupportsPrivacy = 0x2,
+ SupportsTokenOnly = 0x4,
+ SupportsDatagram = 0x8,
+ SupportsConnections = 0x10,
+ MultipleLegsRequired = 0x20,
+ ClientOnly = 0x40,
+ ExtendedErrorSupport = 0x80,
+ SupportsImpersonation = 0x100,
+ AccepsWin32Names = 0x200,
+ SupportsStreams = 0x400,
+ Negotiable = 0x800,
+ GSSAPICompatible = 0x1000,
+ SupportsLogon = 0x2000,
+ BuffersAreASCII = 0x4000,
+ SupportsTokenFragmentation = 0x8000,
+ SupportsMutualAuthentication = 0x10000,
+ SupportsDelegation = 0x20000,
+ SupportsChecksumOnly = 0x40000,
+ SupportsRestrictedTokens = 0x80000,
+ ExtendsNegotiate = 0x100000,
+ NegotiableByExtendedNegotiate = 0x200000,
+ AppContainerPassThrough = 0x400000,
+ AppContainerChecks = 0x800000,
+ CredentialIsolationEnabled = 0x1000000
+ }
+
+ [Flags]
+ public enum ContextRequirements
+ {
+ Delegation = 0x00000001,
+ MutualAuthentication = 0x00000002,
+ ReplayDetection = 0x00000004,
+ SequenceDetection = 0x00000008,
+ Confidentiality = 0x00000010,
+ UseSessionKey = 0x00000020,
+ PromptForCredentials = 0x00000040,
+ UseSuppliedCredentials = 0x00000080,
+ AllocateMemory = 0x00000100,
+ UseDceStyle = 0x00000200,
+ DatagramCommunications = 0x00000400,
+ ConnectionCommunications = 0x00000800,
+ CallLevel = 0x00001000,
+ FragmentSupplied = 0x00002000,
+ ExtendedError = 0x00004000,
+ StreamCommunications = 0x00008000,
+ Integrity = 0x00010000,
+ Identity = 0x00020000,
+ NullSession = 0x00040000,
+ ManualCredValidation = 0x00080000,
+ Reserved = 0x00100000,
+ FragmentToFit = 0x00200000,
+ ForwardCredentials = 0x00400000,
+ NoIntegrity = 0x00800000,
+ UseHttpStyle = 0x01000000,
+ UnverifiedTargetName = 0x20000000,
+ ConfidentialityOnly = 0x40000000
+ }
+
+ public enum SecBufferType
+ {
+ Empty = 0,
+ Data = 1,
+ Token = 2,
+ PackageParameters = 3,
+ MissingBuffer = 4,
+ ExtraData = 5,
+ StreamTrailer = 6,
+ StreamHeader = 7,
+ NegotiationInfo = 8,
+ Padding = 9,
+ Stream = 10,
+ ObjectIdList = 11,
+ OidListSignature = 12,
+ Target = 13,
+ ChannelBindings = 14,
+ ChangePassResp = 15,
+ TargetHost = 16,
+ Alert = 17,
+ AppProtocolIds = 18,
+ StrpProtProfiles = 19,
+ StrpMasterKeyId = 20,
+ TokenBinding = 21,
+ PresharedKey = 22,
+ PresharedKeyId = 23,
+ DtlsMtu = 24
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SecurityHandle
+ {
+ public IntPtr LowPart;
+ public IntPtr HighPart;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SecurityPackageInfo
+ {
+ [MarshalAs(UnmanagedType.U4)]
+ public SecurityCapabilities Capabilities;
+ public UInt16 Version;
+ public UInt16 RpcId;
+ public UInt32 MaxTokenSize;
+ [MarshalAs(UnmanagedType.LPWStr)]
+ public string Name;
+ [MarshalAs(UnmanagedType.LPWStr)]
+ public string Comment;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SecBufferDescription
+ {
+ public UInt32 version;
+ public UInt32 numOfBuffers;
+ public IntPtr buffersPtr;
+ }
+
+ [StructLayout(LayoutKind.Sequential)]
+ public struct SecBuffer
+ {
+ public UInt32 size;
+ [MarshalAs(UnmanagedType.U4)]
+ public SecBufferType type;
+ public IntPtr bufferPtr;
+ }
+
+ public class SSPIClient : IDisposable
+ {
+ private const int NoError = 0;
+ private const int ContinueNeeded = 0x90312;
+ private const int NativeDataRepresentation = 0x10;
+
+ private SecurityHandle _credHandle;
+ private SecurityHandle _contextHandle;
+ private DateTime _credExpiration;
+ private DateTime _contextExpiration;
+
+ public SSPIClient(string packageName)
+ {
+ _credHandle = new SecurityHandle()
+ {
+ HighPart = IntPtr.Zero,
+ LowPart = IntPtr.Zero
+ };
+ _contextHandle = new SecurityHandle()
+ {
+ HighPart = IntPtr.Zero,
+ LowPart = IntPtr.Zero
+ };
+ _contextExpiration = DateTime.MinValue;
+ _credExpiration = DateTime.MinValue;
+
+ UInt64 expiration = 0;
+ var retCode = AcquireCredentialsHandle(
+ null,
+ packageName,
+ CredentialsUse.Outbound,
+ IntPtr.Zero,
+ IntPtr.Zero,
+ IntPtr.Zero,
+ IntPtr.Zero,
+ ref _credHandle,
+ ref expiration
+ );
+ try
+ {
+ _credExpiration = DateTime.FromFileTime((Int64)expiration);
+ }
+ catch(ArgumentException)
+ {
+ // no expiration
+ _credExpiration = DateTime.MaxValue;
+ }
+ }
+
+ public byte[] GetClientToken(byte[] serverToken)
+ {
+ var pinnedServerToken = GCHandle.Alloc(serverToken, GCHandleType.Pinned);
+ var serverBuffer = new SecBuffer()
+ {
+ type = SecBufferType.Token,
+ size = null == serverToken ? 0 : (UInt32)serverToken.Length,
+ bufferPtr = pinnedServerToken.AddrOfPinnedObject()
+ };
+ var pinnedServerBuffer = GCHandle.Alloc(serverBuffer, GCHandleType.Pinned);
+ var clientBuffer = new SecBuffer()
+ {
+ type = SecBufferType.Token,
+ size = 0,
+ bufferPtr = IntPtr.Zero
+ };
+ var pinnedClientBuffer = GCHandle.Alloc(clientBuffer, GCHandleType.Pinned);
+ byte[] clientToken = null;
+
+ try
+ {
+ var inBuffDesc = new SecBufferDescription()
+ {
+ version = 0,
+ numOfBuffers = 1,
+ buffersPtr = pinnedServerBuffer.AddrOfPinnedObject()
+ };
+
+ var outBuffDesc = new SecBufferDescription()
+ {
+ version = 0,
+ numOfBuffers = 1,
+ buffersPtr = pinnedClientBuffer.AddrOfPinnedObject()
+ };
+
+ UInt64 expiration = 0;
+ ContextRequirements availableCapabilities = default(ContextRequirements);
+ int retCode = NoError;
+
+ if (serverToken == null)
+ {
+ // first leg - no server token
+ retCode = InitializeSecurityContext(
+ ref _credHandle,
+ IntPtr.Zero,
+ null,
+ ContextRequirements.AllocateMemory | ContextRequirements.ConnectionCommunications,
+ 0,
+ NativeDataRepresentation,
+ IntPtr.Zero,
+ 0,
+ ref _contextHandle,
+ ref outBuffDesc,
+ ref availableCapabilities,
+ ref expiration
+ );
+ }
+ else
+ {
+ retCode = InitializeSecurityContext(
+ ref _credHandle,
+ ref _contextHandle,
+ null,
+ ContextRequirements.AllocateMemory | ContextRequirements.ConnectionCommunications,
+ 0,
+ NativeDataRepresentation,
+ ref inBuffDesc,
+ 0,
+ ref _contextHandle,
+ ref outBuffDesc,
+ ref availableCapabilities,
+ ref expiration
+ );
+ }
+
+ if (retCode != NoError && retCode != ContinueNeeded)
+ throw new Win32Exception(retCode);
+
+ var newClientBuff = (SecBuffer)Marshal.PtrToStructure(outBuffDesc.buffersPtr, typeof(SecBuffer));
+ clientToken = new byte[newClientBuff.size];
+ Marshal.Copy(newClientBuff.bufferPtr, clientToken, 0, (int)newClientBuff.size);
+ FreeContextBuffer(newClientBuff.bufferPtr);
+
+ try
+ {
+ _contextExpiration = DateTime.FromFileTimeUtc((Int64)expiration);
+ }
+ catch (ArgumentException)
+ {
+ // no expiration
+ _contextExpiration = DateTime.MaxValue;
+ }
+
+ }
+ finally
+ {
+ pinnedClientBuffer.Free();
+ pinnedServerBuffer.Free();
+ pinnedServerToken.Free();
+ }
+ return clientToken;
+ }
+
+ public static SecurityPackageInfo[] EnumerateSecurityPackages()
+ {
+ UInt32 numOfPackges = 0;
+ IntPtr packgeInfosPtr = IntPtr.Zero;
+ int retCode = EnumerateSecurityPackagesW(ref numOfPackges, ref packgeInfosPtr);
+ if (retCode != NoError)
+ {
+ throw new Win32Exception(retCode);
+ }
+ try
+ {
+ var infos = new SecurityPackageInfo[numOfPackges];
+ var infoSize = Marshal.SizeOf(typeof(SecurityPackageInfo));
+ var currentPtr = packgeInfosPtr;
+ for(int i = 0; i < numOfPackges; i++)
+ {
+ infos[i] = (SecurityPackageInfo)Marshal.PtrToStructure(currentPtr, typeof(SecurityPackageInfo));
+ currentPtr = IntPtr.Add(currentPtr, infoSize);
+ }
+ return infos;
+ }
+ finally
+ {
+ FreeContextBuffer(packgeInfosPtr);
+ }
+ }
+
+ public DateTime TokenExpiration => _contextExpiration;
+
+ [DllImport("secur32", CharSet = CharSet.Unicode)]
+ private static extern int AcquireCredentialsHandle(
+ string principal,
+ string package,
+ [MarshalAs(UnmanagedType.U4)]
+ CredentialsUse credentialUse,
+ IntPtr authenticationID,
+ IntPtr authData,
+ IntPtr getKeyFn,
+ IntPtr getKeyArgument,
+ ref SecurityHandle credential,
+ ref UInt64 expiration
+ );
+
+ [DllImport("secur32", CharSet = CharSet.Unicode)]
+ private static extern int InitializeSecurityContext(
+ ref SecurityHandle credential,
+ ref SecurityHandle context,
+ string pszTargetName,
+ [MarshalAs(UnmanagedType.U4)]
+ ContextRequirements requirements,
+ int Reserved1,
+ int TargetDataRep,
+ ref SecBufferDescription inBuffDesc,
+ int Reserved2,
+ ref SecurityHandle newContext,
+ ref SecBufferDescription outBuffDesc,
+ ref ContextRequirements contextAttributes,
+ ref UInt64 expiration
+ );
+
+ [DllImport("secur32", CharSet = CharSet.Unicode)]
+ private static extern int InitializeSecurityContext(
+ ref SecurityHandle credential,
+ IntPtr context,
+ string pszTargetName,
+ [MarshalAs(UnmanagedType.U4)]
+ ContextRequirements requirements,
+ int Reserved1,
+ int TargetDataRep,
+ IntPtr inBuffDesc,
+ int Reserved2,
+ ref SecurityHandle newContext,
+ ref SecBufferDescription outBuffDesc,
+ ref ContextRequirements contextAttributes,
+ ref UInt64 expiration
+ );
+
+ [DllImport("secur32", CharSet = CharSet.Unicode)]
+ private static extern int FreeCredentialsHandle(ref SecurityHandle credential);
+
+ [DllImport("secur32", CharSet = CharSet.Unicode)]
+ private static extern int DeleteSecurityContext(ref SecurityHandle context);
+
+ [DllImport("secur32", CharSet = CharSet.Unicode)]
+ private static extern int FreeContextBuffer(IntPtr buffer);
+
+ [DllImport("secur32", CharSet = CharSet.Unicode)]
+ private static extern int EnumerateSecurityPackagesW(
+ ref UInt32 numOfPackages,
+ ref IntPtr packageInfosPtr
+ );
+
+ public void Dispose()
+ {
+ FreeCredentialsHandle(ref _credHandle);
+ DeleteSecurityContext(ref _contextHandle);
+ }
+ }
+
+
+
+}
diff --git a/System.Net.WebSockets.Client.Managed/WebSocketHandle.Managed.cs b/System.Net.WebSockets.Client.Managed/WebSocketHandle.Managed.cs
index 5f48621..82fb806 100644
--- a/System.Net.WebSockets.Client.Managed/WebSocketHandle.Managed.cs
+++ b/System.Net.WebSockets.Client.Managed/WebSocketHandle.Managed.cs
@@ -2,11 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using ManagedSSPI;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
-using System.Net.Http;
-using System.Net.Http.Headers;
+using System.Linq;
using System.Net.Security;
using System.Net.Sockets;
using System.Runtime.ExceptionServices;
@@ -19,6 +19,15 @@ namespace System.Net.WebSockets.Managed
{
internal sealed class WebSocketHandle
{
+ /// Per-thread cached StringBuilder for building of strings to send on the connection.
+ [ThreadStatic]
+ private static StringBuilder t_cachedStringBuilder;
+
+ /// Default encoding for HTTP requests. Latin alphabeta no 1, ISO/IEC 8859-1.
+ private static readonly Encoding s_defaultHttpEncoding = Encoding.GetEncoding(28591);
+
+ /// Size of the receive buffer to use.
+ private const int DefaultReceiveBufferSize = 0x1000;
/// GUID appended by the server as part of the security key response. Defined in the RFC.
private const string WSServerGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@@ -64,166 +73,291 @@ public Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescriptio
public Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) =>
_webSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
- private sealed class DirectManagedHttpClientHandler : HttpClientHandler
+ public async Task ConnectAsyncCore(Uri uri, CancellationToken cancellationToken, ClientWebSocketOptions options)
{
- private const string ManagedHandlerEnvVar = "COMPlus_UseManagedHttpClientHandler";
- private static readonly LocalDataStoreSlot s_managedHandlerSlot = GetSlot();
- private static readonly object s_true = true;
+ // Establish connection to the server
+ CancellationTokenRegistration registration = cancellationToken.Register(s => ((WebSocketHandle)s).Abort(), this);
+ try
+ {
+ var httpUri = new UriBuilder(uri) { Scheme = (uri.Scheme == UriScheme.Ws) ? UriScheme.Http : UriScheme.Https }.Uri;
+ var connectUri = httpUri;
+ bool useProxy = false;
+ if(options.Proxy != null && !options.Proxy.IsBypassed(httpUri))
+ {
+ useProxy = true;
+ connectUri = options.Proxy.GetProxy(httpUri);
+ }
+
+ // Connect to the remote server
+ Socket connectedSocket = await ConnectSocketAsync(connectUri.Host, connectUri.Port, cancellationToken).ConfigureAwait(false);
+ Stream stream = new NetworkStream(connectedSocket, ownsSocket: true);
+
+ // establish a tunnel if needed
+ if(useProxy)
+ {
+ stream = await EstablishTunnelTrhoughWebProxy(stream, httpUri, connectUri, cancellationToken).ConfigureAwait(false);
+ }
- private static LocalDataStoreSlot GetSlot()
+ // Upgrade to SSL if needed
+ if (httpUri.Scheme == UriScheme.Https)
+ {
+ var sslStream = new SslStream(stream);
+ await sslStream.AuthenticateAsClientAsync(
+ httpUri.Host,
+ options.ClientCertificates,
+ SecurityProtocol.AllowedSecurityProtocols,
+ checkCertificateRevocation: false).ConfigureAwait(false);
+ stream = sslStream;
+ }
+
+ // Create the security key and expected response, then build all of the request headers
+ KeyValuePair secKeyAndSecWebSocketAccept = CreateSecKeyAndSecWebSocketAccept();
+ byte[] requestHeader = BuildRequestHeader(uri, options, secKeyAndSecWebSocketAccept.Key);
+
+ // Write out the header to the connection
+ await stream.WriteAsync(requestHeader, 0, requestHeader.Length, cancellationToken).ConfigureAwait(false);
+
+ // Parse the response and store our state for the remainder of the connection
+ string subprotocol = await ParseAndValidateConnectResponseAsync(stream, options, secKeyAndSecWebSocketAccept.Value, cancellationToken).ConfigureAwait(false);
+
+ _webSocket = WebSocketUtil.CreateClientWebSocket(
+ stream, subprotocol, options.ReceiveBufferSize, options.SendBufferSize, options.KeepAliveInterval, false, options.Buffer.GetValueOrDefault());
+
+ // If a concurrent Abort or Dispose came in before we set _webSocket, make sure to update it appropriately
+ if (_state == WebSocketState.Aborted)
+ {
+ _webSocket.Abort();
+ }
+ else if (_state == WebSocketState.Closed)
+ {
+ _webSocket.Dispose();
+ }
+ }
+ catch (Exception exc)
{
- LocalDataStoreSlot slot = Thread.GetNamedDataSlot(ManagedHandlerEnvVar);
- if (slot != null)
+ if (_state < WebSocketState.Closed)
{
- return slot;
+ _state = WebSocketState.Closed;
}
- try
+ Abort();
+
+ if (exc is WebSocketException)
{
- return Thread.AllocateNamedDataSlot(ManagedHandlerEnvVar);
+ throw;
}
- catch (ArgumentException) // in case of a race condition where multiple threads all try to allocate the slot concurrently
+ throw new WebSocketException(SR.net_webstatus_ConnectFailure, exc);
+ }
+ finally
+ {
+ registration.Dispose();
+ }
+ }
+
+ private async Task EstablishTunnelTrhoughWebProxy(Stream stream, Uri httpUri, Uri proxyUri, CancellationToken cancellationToken)
+ {
+ var proxyHeder = s_defaultHttpEncoding.GetBytes($"CONNECT {httpUri.Host}:{httpUri.Port} HTTP/1.1\r\nHost: {httpUri.Host}:{httpUri.Port}\r\n\r\n");
+ await stream.WriteAsync(proxyHeder, 0, proxyHeder.Length, cancellationToken).ConfigureAwait(false);
+ string statusline = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false);
+ if (statusline.StartsWith("HTTP/1.1 407"))
+ {
+ var authPackages = new List();
+ string line;
+ while (!string.IsNullOrEmpty(line = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false)))
{
- return Thread.GetNamedDataSlot(ManagedHandlerEnvVar);
+ if (line.ToLowerInvariant().StartsWith("proxy-authenticate: "))
+ {
+ authPackages.Add(line.ToLowerInvariant().Substring("proxy-authenticate: ".Length));
+ }
}
+ // close annonymous connection
+ stream.Close();
+ // re-establish a new one and authenticate it.
+ var connectedSocket = await ConnectSocketAsync(proxyUri.Host, proxyUri.Port, cancellationToken).ConfigureAwait(false);
+ stream = new NetworkStream(connectedSocket, ownsSocket: true);
+ await AuthenticateProxyStream(stream, authPackages, proxyUri.Host, httpUri, cancellationToken);
}
+ else if(statusline.StartsWith("HTTP/1.1 200"))
+ {
+ // no authentication needed read the rest of the proxy response
+ string line;
+ while (!string.IsNullOrEmpty(line = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false))) { }
+ }
+ return stream;
+ }
- public static DirectManagedHttpClientHandler CreateHandler()
+ private async Task AuthenticateProxyStream(Stream stream, List proxyAuthPackages, string proxyHost, Uri targetUrl, CancellationToken cancellationToken)
+ {
+ var localAuthPackages = SSPIClient.EnumerateSecurityPackages();
+ var packageToUse = "NTLM"; // a good safe default
+ foreach(var package in proxyAuthPackages)
{
- Thread.SetData(s_managedHandlerSlot, s_true);
try
{
- return new DirectManagedHttpClientHandler();
+ var localPackage = localAuthPackages.FirstOrDefault(p => p.Name.ToLowerInvariant() == package);
+ var requiredCapabilities = SecurityCapabilities.AccepsWin32Names | SecurityCapabilities.SupportsConnections;
+ if ((localPackage.Capabilities & requiredCapabilities) != 0)
+ {
+ packageToUse = localPackage.Name;
+ break; // found our package
+ }
+ }
+ catch(Exception)
+ {
+ // cat't use this particular package
+ }
+ }
+ using (var sspi = new SSPIClient(packageToUse))
+ {
+ byte[] serverToken = null;
+ bool authSucceeded = false;
+ var clientToken = Convert.ToBase64String(sspi.GetClientToken(serverToken));
+ while (!authSucceeded)
+ {
+ var authHeader = $"Proxy-Authorization: {packageToUse} {clientToken}";
+ var proxyHeder = s_defaultHttpEncoding.GetBytes($"CONNECT {targetUrl.Host}:{targetUrl.Port} HTTP/1.1\r\nHost: {targetUrl.Host}:{targetUrl.Port}\r\n{authHeader}\r\n\r\n");
+ await stream.WriteAsync(proxyHeder, 0, proxyHeder.Length, cancellationToken).ConfigureAwait(false);
+ string statusline = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false);
+ string line;
+ while (!string.IsNullOrEmpty(line = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false)))
+ {
+ if (line.ToLowerInvariant().StartsWith($"proxy-authenticate: {packageToUse.ToLowerInvariant()}"))
+ {
+ serverToken = Convert.FromBase64String(line.Substring($"proxy-authenticate: {packageToUse.ToLowerInvariant()}".Length));
+ }
+ }
+ if (statusline.StartsWith("HTTP/1.1 200"))
+ authSucceeded = true;
+ else
+ clientToken = Convert.ToBase64String(sspi.GetClientToken(serverToken));
}
- finally { Thread.SetData(s_managedHandlerSlot, null); }
}
-
- public new Task SendAsync(
- HttpRequestMessage request, CancellationToken cancellationToken) =>
- base.SendAsync(request, cancellationToken);
}
- public async Task ConnectAsyncCore(Uri uri, CancellationToken cancellationToken, ClientWebSocketOptions options)
+ /// Connects a socket to the specified host and port, subject to cancellation and aborting.
+ /// The host to which to connect.
+ /// The port to which to connect on the host.
+ /// The CancellationToken to use to cancel the websocket.
+ /// The connected Socket.
+ private async Task ConnectSocketAsync(string host, int port, CancellationToken cancellationToken)
{
- HttpResponseMessage response = null;
- try
+ IPAddress[] addresses = await Dns.GetHostAddressesAsync(host).ConfigureAwait(false);
+
+ ExceptionDispatchInfo lastException = null;
+ foreach (IPAddress address in addresses)
{
- // Create the request message, including a uri with ws{s} switched to http{s}.
- uri = new UriBuilder(uri) { Scheme = (uri.Scheme == UriScheme.Ws) ? UriScheme.Http : UriScheme.Https }.Uri;
- var request = new HttpRequestMessage(HttpMethod.Get, uri);
- if (options._requestHeaders?.Count > 0) // use field to avoid lazily initializing the collection
+ var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+ try
{
- foreach (string key in options.RequestHeaders)
+ using (cancellationToken.Register(s => ((Socket)s).Dispose(), socket))
+ using (_abortSource.Token.Register(s => ((Socket)s).Dispose(), socket))
{
- request.Headers.Add(key, options.RequestHeaders[key]);
+ try
+ {
+ await socket.ConnectAsync(address, port).ConfigureAwait(false);
+ }
+ catch (ObjectDisposedException ode)
+ {
+ // If the socket was disposed because cancellation was requested, translate the exception
+ // into a new OperationCanceledException. Otherwise, let the original ObjectDisposedexception propagate.
+ CancellationToken token = cancellationToken.IsCancellationRequested ? cancellationToken : _abortSource.Token;
+ if (token.IsCancellationRequested)
+ {
+ throw new OperationCanceledException(new OperationCanceledException().Message, ode, token);
+ }
+ }
}
+ cancellationToken.ThrowIfCancellationRequested(); // in case of a race and socket was disposed after the await
+ _abortSource.Token.ThrowIfCancellationRequested();
+ return socket;
}
-
- // Create the security key and expected response, then build all of the request headers
- KeyValuePair secKeyAndSecWebSocketAccept = CreateSecKeyAndSecWebSocketAccept();
- AddWebSocketHeaders(request, secKeyAndSecWebSocketAccept.Key, options);
-
- // Create the handler for this request and populate it with all of the options.
- DirectManagedHttpClientHandler handler = DirectManagedHttpClientHandler.CreateHandler();
- handler.UseDefaultCredentials = options.UseDefaultCredentials;
- handler.Credentials = options.Credentials;
- handler.Proxy = options.Proxy;
- handler.CookieContainer = options.Cookies;
- if (options._clientCertificates?.Count > 0) // use field to avoid lazily initializing the collection
+ catch (Exception exc)
{
- throw new NotImplementedException();
- handler.ClientCertificateOptions = ClientCertificateOption.Manual;
- //handler.ClientCertificates.AddRange(options.ClientCertificates);
+ socket.Dispose();
+ lastException = ExceptionDispatchInfo.Capture(exc);
}
- CancellationTokenSource linkedCancellation, externalAndAbortCancellation;
- if (cancellationToken.CanBeCanceled) // avoid allocating linked source if external token is not cancelable
+ }
+
+ lastException?.Throw();
+
+ Debug.Fail("We should never get here. We should have already returned or an exception should have been thrown.");
+ throw new WebSocketException(SR.net_webstatus_ConnectFailure);
+ }
+
+ /// Creates a byte[] containing the headers to send to the server.
+ /// The Uri of the server.
+ /// The options used to configure the websocket.
+ /// The generated security key to send in the Sec-WebSocket-Key header.
+ /// The byte[] containing the encoded headers ready to send to the network.
+ private static byte[] BuildRequestHeader(Uri uri, ClientWebSocketOptions options, string secKey)
+ {
+ StringBuilder builder = t_cachedStringBuilder ?? (t_cachedStringBuilder = new StringBuilder());
+ Debug.Assert(builder.Length == 0, $"Expected builder to be empty, got one of length {builder.Length}");
+ try
+ {
+ builder.Append("GET ").Append(uri.PathAndQuery).Append(" HTTP/1.1\r\n");
+
+ // Add all of the required headers, honoring Host header if set.
+ string hostHeader = options.RequestHeaders[HttpKnownHeaderNames.Host];
+ builder.Append("Host: ");
+ if (string.IsNullOrEmpty(hostHeader))
{
- linkedCancellation =
- externalAndAbortCancellation =
- CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _abortSource.Token);
+ builder.Append(uri.GetIdnHost()).Append(':').Append(uri.Port).Append("\r\n");
}
else
{
- linkedCancellation = null;
- externalAndAbortCancellation = _abortSource;
+ builder.Append(hostHeader).Append("\r\n");
}
- using (linkedCancellation)
- {
- response = await handler.SendAsync(request, externalAndAbortCancellation.Token).ConfigureAwait(false);
- externalAndAbortCancellation.Token.ThrowIfCancellationRequested();
- }
+ builder.Append("Connection: Upgrade\r\n");
+ builder.Append("Upgrade: websocket\r\n");
+ builder.Append("Sec-WebSocket-Version: 13\r\n");
+ builder.Append("Sec-WebSocket-Key: ").Append(secKey).Append("\r\n");
- // Issue the request. The response must be status code 101.
- if (response.StatusCode != HttpStatusCode.SwitchingProtocols)
- {
- throw new WebSocketException(SR.net_webstatus_ConnectFailure);
- }
+ // Add all of the additionally requested headers
+ foreach (string key in options.RequestHeaders.AllKeys)
+ {
+ if (string.Equals(key, HttpKnownHeaderNames.Host, StringComparison.OrdinalIgnoreCase))
+ {
+ // Host header handled above
+ continue;
+ }
- // The Connection, Upgrade, and SecWebSocketAccept headers are required and with specific values.
- ValidateHeader(response.Headers, HttpKnownHeaderNames.Connection, "Upgrade");
- ValidateHeader(response.Headers, HttpKnownHeaderNames.Upgrade, "websocket");
- ValidateHeader(response.Headers, HttpKnownHeaderNames.SecWebSocketAccept, secKeyAndSecWebSocketAccept.Value);
+ builder.Append(key).Append(": ").Append(options.RequestHeaders[key]).Append("\r\n");
+ }
- // The SecWebSocketProtocol header is optional. We should only get it with a non-empty value if we requested subprotocols,
- // and then it must only be one of the ones we requested. If we got a subprotocol other than one we requested (or if we
- // already got one in a previous header), fail. Otherwise, track which one we got.
- string subprotocol = null;
- IEnumerable subprotocolEnumerableValues;
- if (response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketProtocol, out subprotocolEnumerableValues))
+ // Add the optional subprotocols header
+ if (options.RequestedSubProtocols.Count > 0)
{
- Debug.Assert(subprotocolEnumerableValues is string[]);
- string[] subprotocolArray = (string[])subprotocolEnumerableValues;
- if (subprotocolArray.Length != 1 ||
- (subprotocol = options.RequestedSubProtocols.Find(requested => string.Equals(requested, subprotocolArray[0], StringComparison.OrdinalIgnoreCase))) == null)
+ builder.Append(HttpKnownHeaderNames.SecWebSocketProtocol).Append(": ");
+ builder.Append(options.RequestedSubProtocols[0]);
+ for (int i = 1; i < options.RequestedSubProtocols.Count; i++)
{
- throw new WebSocketException(
- WebSocketError.UnsupportedProtocol,
- SR.Format(SR.net_WebSockets_AcceptUnsupportedProtocol, string.Join(", ", options.RequestedSubProtocols), subprotocol));
+ builder.Append(", ").Append(options.RequestedSubProtocols[i]);
}
+ builder.Append("\r\n");
}
- // Get the response stream and wrap it in a web socket.
- Stream connectedStream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
- Debug.Assert(connectedStream.CanWrite);
- Debug.Assert(connectedStream.CanRead);
- _webSocket = WebSocket.CreateClientWebSocket( // TODO https://github.com/dotnet/corefx/issues/21537: Use new API when available
- connectedStream,
- subprotocol,
- options.ReceiveBufferSize,
- options.SendBufferSize,
- options.KeepAliveInterval,
- useZeroMaskingKey: false,
- internalBuffer: options.Buffer.GetValueOrDefault());
- }
- catch (Exception exc)
- {
- if (_state < WebSocketState.Closed)
+ // Add an optional cookies header
+ if (options.Cookies != null)
{
- _state = WebSocketState.Closed;
+ string header = options.Cookies.GetCookieHeader(uri);
+ if (!string.IsNullOrWhiteSpace(header))
+ {
+ builder.Append(HttpKnownHeaderNames.Cookie).Append(": ").Append(header).Append("\r\n");
+ }
}
- Abort();
- response?.Dispose();
+ // End the headers
+ builder.Append("\r\n");
- if (exc is WebSocketException)
- {
- throw;
- }
- throw new WebSocketException(SR.net_webstatus_ConnectFailure, exc);
+ // Return the bytes for the built up header
+ return s_defaultHttpEncoding.GetBytes(builder.ToString());
}
- }
-
- /// The generated security key to send in the Sec-WebSocket-Key header.
- private static void AddWebSocketHeaders(HttpRequestMessage request, string secKey, ClientWebSocketOptions options)
- {
- request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.Connection, HttpKnownHeaderNames.Upgrade);
- request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.Upgrade, "websocket");
- request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketVersion, "13");
- request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketKey, secKey);
- if (options._requestedSubProtocols?.Count > 0)
+ finally
{
- request.Headers.Add(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols));
+ // Make sure we clear the builder
+ builder.Clear();
}
}
@@ -244,21 +378,174 @@ private static KeyValuePair CreateSecKeyAndSecWebSocketAccept()
}
}
- private static void ValidateHeader(HttpHeaders headers, string name, string expectedValue)
+ /// Read and validate the connect response headers from the server.
+ /// The stream from which to read the response headers.
+ /// The options used to configure the websocket.
+ /// The expected value of the Sec-WebSocket-Accept header.
+ /// The CancellationToken to use to cancel the websocket.
+ /// The agreed upon subprotocol with the server, or null if there was none.
+ private async Task ParseAndValidateConnectResponseAsync(
+ Stream stream, ClientWebSocketOptions options, string expectedSecWebSocketAccept, CancellationToken cancellationToken)
{
- if (!headers.TryGetValues(name, out IEnumerable values))
+ // Read the first line of the response
+ string statusLine = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false);
+
+ // Depending on the underlying sockets implementation and timing, connecting to a server that then
+ // immediately closes the connection may either result in an exception getting thrown from the connect
+ // earlier, or it may result in getting to here but reading 0 bytes. If we read 0 bytes and thus have
+ // an empty status line, treat it as a connect failure.
+ if (string.IsNullOrEmpty(statusLine))
+ {
+ throw new WebSocketException(SR.Format(SR.net_webstatus_ConnectFailure));
+ }
+
+ const string ExpectedStatusStart = "HTTP/1.1 ";
+ const string ExpectedStatusStatWithCode = "HTTP/1.1 101"; // 101 == SwitchingProtocols
+
+ // If the status line doesn't begin with "HTTP/1.1" or isn't long enough to contain a status code, fail.
+ if (!statusLine.StartsWith(ExpectedStatusStart, StringComparison.Ordinal) || statusLine.Length < ExpectedStatusStatWithCode.Length)
{
- ThrowConnectFailure();
+ throw new WebSocketException(WebSocketError.HeaderError);
}
- Debug.Assert(values is string[]);
- string[] array = (string[])values;
- if (array.Length != 1 || !string.Equals(array[0], expectedValue, StringComparison.OrdinalIgnoreCase))
+ // If the status line doesn't contain a status code 101, or if it's long enough to have a status description
+ // but doesn't contain whitespace after the 101, fail.
+ if (!statusLine.StartsWith(ExpectedStatusStatWithCode, StringComparison.Ordinal) ||
+ (statusLine.Length > ExpectedStatusStatWithCode.Length && !char.IsWhiteSpace(statusLine[ExpectedStatusStatWithCode.Length])))
{
- throw new WebSocketException(SR.Format(SR.net_WebSockets_InvalidResponseHeader, name, string.Join(", ", array)));
+ throw new WebSocketException(SR.net_webstatus_ConnectFailure);
+ }
+
+ // Read each response header. Be liberal in parsing the response header, treating
+ // everything to the left of the colon as the key and everything to the right as the value, trimming both.
+ // For each header, validate that we got the expected value.
+ bool foundUpgrade = false, foundConnection = false, foundSecWebSocketAccept = false;
+ string subprotocol = null;
+ string line;
+ while (!string.IsNullOrEmpty(line = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false)))
+ {
+ int colonIndex = line.IndexOf(':');
+ if (colonIndex == -1)
+ {
+ throw new WebSocketException(WebSocketError.HeaderError);
+ }
+
+ string headerName = line.SubstringTrim(0, colonIndex);
+ string headerValue = line.SubstringTrim(colonIndex + 1);
+
+ // The Connection, Upgrade, and SecWebSocketAccept headers are required and with specific values.
+ ValidateAndTrackHeader(HttpKnownHeaderNames.Connection, "Upgrade", headerName, headerValue, ref foundConnection);
+ ValidateAndTrackHeader(HttpKnownHeaderNames.Upgrade, "websocket", headerName, headerValue, ref foundUpgrade);
+ ValidateAndTrackHeader(HttpKnownHeaderNames.SecWebSocketAccept, expectedSecWebSocketAccept, headerName, headerValue, ref foundSecWebSocketAccept);
+
+ // The SecWebSocketProtocol header is optional. We should only get it with a non-empty value if we requested subprotocols,
+ // and then it must only be one of the ones we requested. If we got a subprotocol other than one we requested (or if we
+ // already got one in a previous header), fail. Otherwise, track which one we got.
+ if (string.Equals(HttpKnownHeaderNames.SecWebSocketProtocol, headerName, StringComparison.OrdinalIgnoreCase) &&
+ !string.IsNullOrWhiteSpace(headerValue))
+ {
+ string newSubprotocol = options.RequestedSubProtocols.Find(requested => string.Equals(requested, headerValue, StringComparison.OrdinalIgnoreCase));
+ if (newSubprotocol == null || subprotocol != null)
+ {
+ throw new WebSocketException(
+ WebSocketError.UnsupportedProtocol,
+ SR.Format(SR.net_WebSockets_AcceptUnsupportedProtocol, string.Join(", ", options.RequestedSubProtocols), subprotocol));
+ }
+ subprotocol = newSubprotocol;
+ }
+ }
+ if (!foundUpgrade || !foundConnection || !foundSecWebSocketAccept)
+ {
+ throw new WebSocketException(SR.net_webstatus_ConnectFailure);
+ }
+
+ return subprotocol;
+ }
+
+ /// Validates a received header against expected values and tracks that we've received it.
+ /// The header name against which we're comparing.
+ /// The header value against which we're comparing.
+ /// The actual header name received.
+ /// The actual header value received.
+ /// A bool tracking whether this header has been seen.
+ private static void ValidateAndTrackHeader(
+ string targetHeaderName, string targetHeaderValue,
+ string foundHeaderName, string foundHeaderValue,
+ ref bool foundHeader)
+ {
+ bool isTargetHeader = string.Equals(targetHeaderName, foundHeaderName, StringComparison.OrdinalIgnoreCase);
+ if (!foundHeader)
+ {
+ if (isTargetHeader)
+ {
+ if (!string.Equals(targetHeaderValue, foundHeaderValue, StringComparison.OrdinalIgnoreCase))
+ {
+ throw new WebSocketException(SR.Format(SR.net_WebSockets_InvalidResponseHeader, targetHeaderName, foundHeaderValue));
+ }
+ foundHeader = true;
+ }
+ }
+ else
+ {
+ if (isTargetHeader)
+ {
+ throw new WebSocketException(SR.Format(SR.net_webstatus_ConnectFailure));
+ }
}
}
- private static void ThrowConnectFailure() => throw new WebSocketException(SR.net_webstatus_ConnectFailure);
+ /// Reads a line from the stream.
+ /// The stream from which to read.
+ /// The CancellationToken used to cancel the websocket.
+ /// The read line, or null if none could be read.
+ private static async Task ReadResponseHeaderLineAsync(Stream stream, CancellationToken cancellationToken)
+ {
+ StringBuilder sb = t_cachedStringBuilder;
+ if (sb != null)
+ {
+ t_cachedStringBuilder = null;
+ Debug.Assert(sb.Length == 0, $"Expected empty StringBuilder");
+ }
+ else
+ {
+ sb = new StringBuilder();
+ }
+
+ var arr = new byte[1];
+ char prevChar = '\0';
+ try
+ {
+ // TODO: Reading one byte is extremely inefficient. The problem, however,
+ // is that if we read multiple bytes, we could end up reading bytes post-headers
+ // that are part of messages meant to be read by the managed websocket after
+ // the connection. The likely solution here is to wrap the stream in a BufferedStream,
+ // though a) that comes at the expense of an extra set of virtual calls, b)
+ // it adds a buffer when the managed websocket will already be using a buffer, and
+ // c) it's not exposed on the version of the System.IO contract we're currently using.
+ while (await stream.ReadAsync(arr, 0, 1, cancellationToken).ConfigureAwait(false) == 1)
+ {
+ // Process the next char
+ char curChar = (char)arr[0];
+ if (prevChar == '\r' && curChar == '\n')
+ {
+ break;
+ }
+ sb.Append(curChar);
+ prevChar = curChar;
+ }
+
+ if (sb.Length > 0 && sb[sb.Length - 1] == '\r')
+ {
+ sb.Length = sb.Length - 1;
+ }
+
+ return sb.ToString();
+ }
+ finally
+ {
+ sb.Clear();
+ t_cachedStringBuilder = sb;
+ }
+ }
}
}
\ No newline at end of file