diff --git a/System.Net.WebSockets.Client.Managed/SSPIWrapper.cs b/System.Net.WebSockets.Client.Managed/SSPIWrapper.cs new file mode 100644 index 0000000..bf2eff0 --- /dev/null +++ b/System.Net.WebSockets.Client.Managed/SSPIWrapper.cs @@ -0,0 +1,397 @@ +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace ManagedSSPI +{ + [Flags] + public enum CredentialsUse + { + Inbound = 0x1, + Outbound = 0x2, + InboundAndOutbound = Inbound | Outbound + } + + [Flags] + public enum SecurityCapabilities + { + SupportsIntegrity = 0x1, + SupportsPrivacy = 0x2, + SupportsTokenOnly = 0x4, + SupportsDatagram = 0x8, + SupportsConnections = 0x10, + MultipleLegsRequired = 0x20, + ClientOnly = 0x40, + ExtendedErrorSupport = 0x80, + SupportsImpersonation = 0x100, + AccepsWin32Names = 0x200, + SupportsStreams = 0x400, + Negotiable = 0x800, + GSSAPICompatible = 0x1000, + SupportsLogon = 0x2000, + BuffersAreASCII = 0x4000, + SupportsTokenFragmentation = 0x8000, + SupportsMutualAuthentication = 0x10000, + SupportsDelegation = 0x20000, + SupportsChecksumOnly = 0x40000, + SupportsRestrictedTokens = 0x80000, + ExtendsNegotiate = 0x100000, + NegotiableByExtendedNegotiate = 0x200000, + AppContainerPassThrough = 0x400000, + AppContainerChecks = 0x800000, + CredentialIsolationEnabled = 0x1000000 + } + + [Flags] + public enum ContextRequirements + { + Delegation = 0x00000001, + MutualAuthentication = 0x00000002, + ReplayDetection = 0x00000004, + SequenceDetection = 0x00000008, + Confidentiality = 0x00000010, + UseSessionKey = 0x00000020, + PromptForCredentials = 0x00000040, + UseSuppliedCredentials = 0x00000080, + AllocateMemory = 0x00000100, + UseDceStyle = 0x00000200, + DatagramCommunications = 0x00000400, + ConnectionCommunications = 0x00000800, + CallLevel = 0x00001000, + FragmentSupplied = 0x00002000, + ExtendedError = 0x00004000, + StreamCommunications = 0x00008000, + Integrity = 0x00010000, + Identity = 0x00020000, + NullSession = 0x00040000, + ManualCredValidation = 0x00080000, + Reserved = 0x00100000, + FragmentToFit = 0x00200000, + ForwardCredentials = 0x00400000, + NoIntegrity = 0x00800000, + UseHttpStyle = 0x01000000, + UnverifiedTargetName = 0x20000000, + ConfidentialityOnly = 0x40000000 + } + + public enum SecBufferType + { + Empty = 0, + Data = 1, + Token = 2, + PackageParameters = 3, + MissingBuffer = 4, + ExtraData = 5, + StreamTrailer = 6, + StreamHeader = 7, + NegotiationInfo = 8, + Padding = 9, + Stream = 10, + ObjectIdList = 11, + OidListSignature = 12, + Target = 13, + ChannelBindings = 14, + ChangePassResp = 15, + TargetHost = 16, + Alert = 17, + AppProtocolIds = 18, + StrpProtProfiles = 19, + StrpMasterKeyId = 20, + TokenBinding = 21, + PresharedKey = 22, + PresharedKeyId = 23, + DtlsMtu = 24 + } + + [StructLayout(LayoutKind.Sequential)] + public struct SecurityHandle + { + public IntPtr LowPart; + public IntPtr HighPart; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SecurityPackageInfo + { + [MarshalAs(UnmanagedType.U4)] + public SecurityCapabilities Capabilities; + public UInt16 Version; + public UInt16 RpcId; + public UInt32 MaxTokenSize; + [MarshalAs(UnmanagedType.LPWStr)] + public string Name; + [MarshalAs(UnmanagedType.LPWStr)] + public string Comment; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SecBufferDescription + { + public UInt32 version; + public UInt32 numOfBuffers; + public IntPtr buffersPtr; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SecBuffer + { + public UInt32 size; + [MarshalAs(UnmanagedType.U4)] + public SecBufferType type; + public IntPtr bufferPtr; + } + + public class SSPIClient : IDisposable + { + private const int NoError = 0; + private const int ContinueNeeded = 0x90312; + private const int NativeDataRepresentation = 0x10; + + private SecurityHandle _credHandle; + private SecurityHandle _contextHandle; + private DateTime _credExpiration; + private DateTime _contextExpiration; + + public SSPIClient(string packageName) + { + _credHandle = new SecurityHandle() + { + HighPart = IntPtr.Zero, + LowPart = IntPtr.Zero + }; + _contextHandle = new SecurityHandle() + { + HighPart = IntPtr.Zero, + LowPart = IntPtr.Zero + }; + _contextExpiration = DateTime.MinValue; + _credExpiration = DateTime.MinValue; + + UInt64 expiration = 0; + var retCode = AcquireCredentialsHandle( + null, + packageName, + CredentialsUse.Outbound, + IntPtr.Zero, + IntPtr.Zero, + IntPtr.Zero, + IntPtr.Zero, + ref _credHandle, + ref expiration + ); + try + { + _credExpiration = DateTime.FromFileTime((Int64)expiration); + } + catch(ArgumentException) + { + // no expiration + _credExpiration = DateTime.MaxValue; + } + } + + public byte[] GetClientToken(byte[] serverToken) + { + var pinnedServerToken = GCHandle.Alloc(serverToken, GCHandleType.Pinned); + var serverBuffer = new SecBuffer() + { + type = SecBufferType.Token, + size = null == serverToken ? 0 : (UInt32)serverToken.Length, + bufferPtr = pinnedServerToken.AddrOfPinnedObject() + }; + var pinnedServerBuffer = GCHandle.Alloc(serverBuffer, GCHandleType.Pinned); + var clientBuffer = new SecBuffer() + { + type = SecBufferType.Token, + size = 0, + bufferPtr = IntPtr.Zero + }; + var pinnedClientBuffer = GCHandle.Alloc(clientBuffer, GCHandleType.Pinned); + byte[] clientToken = null; + + try + { + var inBuffDesc = new SecBufferDescription() + { + version = 0, + numOfBuffers = 1, + buffersPtr = pinnedServerBuffer.AddrOfPinnedObject() + }; + + var outBuffDesc = new SecBufferDescription() + { + version = 0, + numOfBuffers = 1, + buffersPtr = pinnedClientBuffer.AddrOfPinnedObject() + }; + + UInt64 expiration = 0; + ContextRequirements availableCapabilities = default(ContextRequirements); + int retCode = NoError; + + if (serverToken == null) + { + // first leg - no server token + retCode = InitializeSecurityContext( + ref _credHandle, + IntPtr.Zero, + null, + ContextRequirements.AllocateMemory | ContextRequirements.ConnectionCommunications, + 0, + NativeDataRepresentation, + IntPtr.Zero, + 0, + ref _contextHandle, + ref outBuffDesc, + ref availableCapabilities, + ref expiration + ); + } + else + { + retCode = InitializeSecurityContext( + ref _credHandle, + ref _contextHandle, + null, + ContextRequirements.AllocateMemory | ContextRequirements.ConnectionCommunications, + 0, + NativeDataRepresentation, + ref inBuffDesc, + 0, + ref _contextHandle, + ref outBuffDesc, + ref availableCapabilities, + ref expiration + ); + } + + if (retCode != NoError && retCode != ContinueNeeded) + throw new Win32Exception(retCode); + + var newClientBuff = (SecBuffer)Marshal.PtrToStructure(outBuffDesc.buffersPtr, typeof(SecBuffer)); + clientToken = new byte[newClientBuff.size]; + Marshal.Copy(newClientBuff.bufferPtr, clientToken, 0, (int)newClientBuff.size); + FreeContextBuffer(newClientBuff.bufferPtr); + + try + { + _contextExpiration = DateTime.FromFileTimeUtc((Int64)expiration); + } + catch (ArgumentException) + { + // no expiration + _contextExpiration = DateTime.MaxValue; + } + + } + finally + { + pinnedClientBuffer.Free(); + pinnedServerBuffer.Free(); + pinnedServerToken.Free(); + } + return clientToken; + } + + public static SecurityPackageInfo[] EnumerateSecurityPackages() + { + UInt32 numOfPackges = 0; + IntPtr packgeInfosPtr = IntPtr.Zero; + int retCode = EnumerateSecurityPackagesW(ref numOfPackges, ref packgeInfosPtr); + if (retCode != NoError) + { + throw new Win32Exception(retCode); + } + try + { + var infos = new SecurityPackageInfo[numOfPackges]; + var infoSize = Marshal.SizeOf(typeof(SecurityPackageInfo)); + var currentPtr = packgeInfosPtr; + for(int i = 0; i < numOfPackges; i++) + { + infos[i] = (SecurityPackageInfo)Marshal.PtrToStructure(currentPtr, typeof(SecurityPackageInfo)); + currentPtr = IntPtr.Add(currentPtr, infoSize); + } + return infos; + } + finally + { + FreeContextBuffer(packgeInfosPtr); + } + } + + public DateTime TokenExpiration => _contextExpiration; + + [DllImport("secur32", CharSet = CharSet.Unicode)] + private static extern int AcquireCredentialsHandle( + string principal, + string package, + [MarshalAs(UnmanagedType.U4)] + CredentialsUse credentialUse, + IntPtr authenticationID, + IntPtr authData, + IntPtr getKeyFn, + IntPtr getKeyArgument, + ref SecurityHandle credential, + ref UInt64 expiration + ); + + [DllImport("secur32", CharSet = CharSet.Unicode)] + private static extern int InitializeSecurityContext( + ref SecurityHandle credential, + ref SecurityHandle context, + string pszTargetName, + [MarshalAs(UnmanagedType.U4)] + ContextRequirements requirements, + int Reserved1, + int TargetDataRep, + ref SecBufferDescription inBuffDesc, + int Reserved2, + ref SecurityHandle newContext, + ref SecBufferDescription outBuffDesc, + ref ContextRequirements contextAttributes, + ref UInt64 expiration + ); + + [DllImport("secur32", CharSet = CharSet.Unicode)] + private static extern int InitializeSecurityContext( + ref SecurityHandle credential, + IntPtr context, + string pszTargetName, + [MarshalAs(UnmanagedType.U4)] + ContextRequirements requirements, + int Reserved1, + int TargetDataRep, + IntPtr inBuffDesc, + int Reserved2, + ref SecurityHandle newContext, + ref SecBufferDescription outBuffDesc, + ref ContextRequirements contextAttributes, + ref UInt64 expiration + ); + + [DllImport("secur32", CharSet = CharSet.Unicode)] + private static extern int FreeCredentialsHandle(ref SecurityHandle credential); + + [DllImport("secur32", CharSet = CharSet.Unicode)] + private static extern int DeleteSecurityContext(ref SecurityHandle context); + + [DllImport("secur32", CharSet = CharSet.Unicode)] + private static extern int FreeContextBuffer(IntPtr buffer); + + [DllImport("secur32", CharSet = CharSet.Unicode)] + private static extern int EnumerateSecurityPackagesW( + ref UInt32 numOfPackages, + ref IntPtr packageInfosPtr + ); + + public void Dispose() + { + FreeCredentialsHandle(ref _credHandle); + DeleteSecurityContext(ref _contextHandle); + } + } + + + +} diff --git a/System.Net.WebSockets.Client.Managed/WebSocketHandle.Managed.cs b/System.Net.WebSockets.Client.Managed/WebSocketHandle.Managed.cs index 5f48621..82fb806 100644 --- a/System.Net.WebSockets.Client.Managed/WebSocketHandle.Managed.cs +++ b/System.Net.WebSockets.Client.Managed/WebSocketHandle.Managed.cs @@ -2,11 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using ManagedSSPI; using System.Collections.Generic; using System.Diagnostics; using System.IO; -using System.Net.Http; -using System.Net.Http.Headers; +using System.Linq; using System.Net.Security; using System.Net.Sockets; using System.Runtime.ExceptionServices; @@ -19,6 +19,15 @@ namespace System.Net.WebSockets.Managed { internal sealed class WebSocketHandle { + /// Per-thread cached StringBuilder for building of strings to send on the connection. + [ThreadStatic] + private static StringBuilder t_cachedStringBuilder; + + /// Default encoding for HTTP requests. Latin alphabeta no 1, ISO/IEC 8859-1. + private static readonly Encoding s_defaultHttpEncoding = Encoding.GetEncoding(28591); + + /// Size of the receive buffer to use. + private const int DefaultReceiveBufferSize = 0x1000; /// GUID appended by the server as part of the security key response. Defined in the RFC. private const string WSServerGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -64,166 +73,291 @@ public Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescriptio public Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) => _webSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); - private sealed class DirectManagedHttpClientHandler : HttpClientHandler + public async Task ConnectAsyncCore(Uri uri, CancellationToken cancellationToken, ClientWebSocketOptions options) { - private const string ManagedHandlerEnvVar = "COMPlus_UseManagedHttpClientHandler"; - private static readonly LocalDataStoreSlot s_managedHandlerSlot = GetSlot(); - private static readonly object s_true = true; + // Establish connection to the server + CancellationTokenRegistration registration = cancellationToken.Register(s => ((WebSocketHandle)s).Abort(), this); + try + { + var httpUri = new UriBuilder(uri) { Scheme = (uri.Scheme == UriScheme.Ws) ? UriScheme.Http : UriScheme.Https }.Uri; + var connectUri = httpUri; + bool useProxy = false; + if(options.Proxy != null && !options.Proxy.IsBypassed(httpUri)) + { + useProxy = true; + connectUri = options.Proxy.GetProxy(httpUri); + } + + // Connect to the remote server + Socket connectedSocket = await ConnectSocketAsync(connectUri.Host, connectUri.Port, cancellationToken).ConfigureAwait(false); + Stream stream = new NetworkStream(connectedSocket, ownsSocket: true); + + // establish a tunnel if needed + if(useProxy) + { + stream = await EstablishTunnelTrhoughWebProxy(stream, httpUri, connectUri, cancellationToken).ConfigureAwait(false); + } - private static LocalDataStoreSlot GetSlot() + // Upgrade to SSL if needed + if (httpUri.Scheme == UriScheme.Https) + { + var sslStream = new SslStream(stream); + await sslStream.AuthenticateAsClientAsync( + httpUri.Host, + options.ClientCertificates, + SecurityProtocol.AllowedSecurityProtocols, + checkCertificateRevocation: false).ConfigureAwait(false); + stream = sslStream; + } + + // Create the security key and expected response, then build all of the request headers + KeyValuePair secKeyAndSecWebSocketAccept = CreateSecKeyAndSecWebSocketAccept(); + byte[] requestHeader = BuildRequestHeader(uri, options, secKeyAndSecWebSocketAccept.Key); + + // Write out the header to the connection + await stream.WriteAsync(requestHeader, 0, requestHeader.Length, cancellationToken).ConfigureAwait(false); + + // Parse the response and store our state for the remainder of the connection + string subprotocol = await ParseAndValidateConnectResponseAsync(stream, options, secKeyAndSecWebSocketAccept.Value, cancellationToken).ConfigureAwait(false); + + _webSocket = WebSocketUtil.CreateClientWebSocket( + stream, subprotocol, options.ReceiveBufferSize, options.SendBufferSize, options.KeepAliveInterval, false, options.Buffer.GetValueOrDefault()); + + // If a concurrent Abort or Dispose came in before we set _webSocket, make sure to update it appropriately + if (_state == WebSocketState.Aborted) + { + _webSocket.Abort(); + } + else if (_state == WebSocketState.Closed) + { + _webSocket.Dispose(); + } + } + catch (Exception exc) { - LocalDataStoreSlot slot = Thread.GetNamedDataSlot(ManagedHandlerEnvVar); - if (slot != null) + if (_state < WebSocketState.Closed) { - return slot; + _state = WebSocketState.Closed; } - try + Abort(); + + if (exc is WebSocketException) { - return Thread.AllocateNamedDataSlot(ManagedHandlerEnvVar); + throw; } - catch (ArgumentException) // in case of a race condition where multiple threads all try to allocate the slot concurrently + throw new WebSocketException(SR.net_webstatus_ConnectFailure, exc); + } + finally + { + registration.Dispose(); + } + } + + private async Task EstablishTunnelTrhoughWebProxy(Stream stream, Uri httpUri, Uri proxyUri, CancellationToken cancellationToken) + { + var proxyHeder = s_defaultHttpEncoding.GetBytes($"CONNECT {httpUri.Host}:{httpUri.Port} HTTP/1.1\r\nHost: {httpUri.Host}:{httpUri.Port}\r\n\r\n"); + await stream.WriteAsync(proxyHeder, 0, proxyHeder.Length, cancellationToken).ConfigureAwait(false); + string statusline = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false); + if (statusline.StartsWith("HTTP/1.1 407")) + { + var authPackages = new List(); + string line; + while (!string.IsNullOrEmpty(line = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false))) { - return Thread.GetNamedDataSlot(ManagedHandlerEnvVar); + if (line.ToLowerInvariant().StartsWith("proxy-authenticate: ")) + { + authPackages.Add(line.ToLowerInvariant().Substring("proxy-authenticate: ".Length)); + } } + // close annonymous connection + stream.Close(); + // re-establish a new one and authenticate it. + var connectedSocket = await ConnectSocketAsync(proxyUri.Host, proxyUri.Port, cancellationToken).ConfigureAwait(false); + stream = new NetworkStream(connectedSocket, ownsSocket: true); + await AuthenticateProxyStream(stream, authPackages, proxyUri.Host, httpUri, cancellationToken); } + else if(statusline.StartsWith("HTTP/1.1 200")) + { + // no authentication needed read the rest of the proxy response + string line; + while (!string.IsNullOrEmpty(line = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false))) { } + } + return stream; + } - public static DirectManagedHttpClientHandler CreateHandler() + private async Task AuthenticateProxyStream(Stream stream, List proxyAuthPackages, string proxyHost, Uri targetUrl, CancellationToken cancellationToken) + { + var localAuthPackages = SSPIClient.EnumerateSecurityPackages(); + var packageToUse = "NTLM"; // a good safe default + foreach(var package in proxyAuthPackages) { - Thread.SetData(s_managedHandlerSlot, s_true); try { - return new DirectManagedHttpClientHandler(); + var localPackage = localAuthPackages.FirstOrDefault(p => p.Name.ToLowerInvariant() == package); + var requiredCapabilities = SecurityCapabilities.AccepsWin32Names | SecurityCapabilities.SupportsConnections; + if ((localPackage.Capabilities & requiredCapabilities) != 0) + { + packageToUse = localPackage.Name; + break; // found our package + } + } + catch(Exception) + { + // cat't use this particular package + } + } + using (var sspi = new SSPIClient(packageToUse)) + { + byte[] serverToken = null; + bool authSucceeded = false; + var clientToken = Convert.ToBase64String(sspi.GetClientToken(serverToken)); + while (!authSucceeded) + { + var authHeader = $"Proxy-Authorization: {packageToUse} {clientToken}"; + var proxyHeder = s_defaultHttpEncoding.GetBytes($"CONNECT {targetUrl.Host}:{targetUrl.Port} HTTP/1.1\r\nHost: {targetUrl.Host}:{targetUrl.Port}\r\n{authHeader}\r\n\r\n"); + await stream.WriteAsync(proxyHeder, 0, proxyHeder.Length, cancellationToken).ConfigureAwait(false); + string statusline = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false); + string line; + while (!string.IsNullOrEmpty(line = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false))) + { + if (line.ToLowerInvariant().StartsWith($"proxy-authenticate: {packageToUse.ToLowerInvariant()}")) + { + serverToken = Convert.FromBase64String(line.Substring($"proxy-authenticate: {packageToUse.ToLowerInvariant()}".Length)); + } + } + if (statusline.StartsWith("HTTP/1.1 200")) + authSucceeded = true; + else + clientToken = Convert.ToBase64String(sspi.GetClientToken(serverToken)); } - finally { Thread.SetData(s_managedHandlerSlot, null); } } - - public new Task SendAsync( - HttpRequestMessage request, CancellationToken cancellationToken) => - base.SendAsync(request, cancellationToken); } - public async Task ConnectAsyncCore(Uri uri, CancellationToken cancellationToken, ClientWebSocketOptions options) + /// Connects a socket to the specified host and port, subject to cancellation and aborting. + /// The host to which to connect. + /// The port to which to connect on the host. + /// The CancellationToken to use to cancel the websocket. + /// The connected Socket. + private async Task ConnectSocketAsync(string host, int port, CancellationToken cancellationToken) { - HttpResponseMessage response = null; - try + IPAddress[] addresses = await Dns.GetHostAddressesAsync(host).ConfigureAwait(false); + + ExceptionDispatchInfo lastException = null; + foreach (IPAddress address in addresses) { - // Create the request message, including a uri with ws{s} switched to http{s}. - uri = new UriBuilder(uri) { Scheme = (uri.Scheme == UriScheme.Ws) ? UriScheme.Http : UriScheme.Https }.Uri; - var request = new HttpRequestMessage(HttpMethod.Get, uri); - if (options._requestHeaders?.Count > 0) // use field to avoid lazily initializing the collection + var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + try { - foreach (string key in options.RequestHeaders) + using (cancellationToken.Register(s => ((Socket)s).Dispose(), socket)) + using (_abortSource.Token.Register(s => ((Socket)s).Dispose(), socket)) { - request.Headers.Add(key, options.RequestHeaders[key]); + try + { + await socket.ConnectAsync(address, port).ConfigureAwait(false); + } + catch (ObjectDisposedException ode) + { + // If the socket was disposed because cancellation was requested, translate the exception + // into a new OperationCanceledException. Otherwise, let the original ObjectDisposedexception propagate. + CancellationToken token = cancellationToken.IsCancellationRequested ? cancellationToken : _abortSource.Token; + if (token.IsCancellationRequested) + { + throw new OperationCanceledException(new OperationCanceledException().Message, ode, token); + } + } } + cancellationToken.ThrowIfCancellationRequested(); // in case of a race and socket was disposed after the await + _abortSource.Token.ThrowIfCancellationRequested(); + return socket; } - - // Create the security key and expected response, then build all of the request headers - KeyValuePair secKeyAndSecWebSocketAccept = CreateSecKeyAndSecWebSocketAccept(); - AddWebSocketHeaders(request, secKeyAndSecWebSocketAccept.Key, options); - - // Create the handler for this request and populate it with all of the options. - DirectManagedHttpClientHandler handler = DirectManagedHttpClientHandler.CreateHandler(); - handler.UseDefaultCredentials = options.UseDefaultCredentials; - handler.Credentials = options.Credentials; - handler.Proxy = options.Proxy; - handler.CookieContainer = options.Cookies; - if (options._clientCertificates?.Count > 0) // use field to avoid lazily initializing the collection + catch (Exception exc) { - throw new NotImplementedException(); - handler.ClientCertificateOptions = ClientCertificateOption.Manual; - //handler.ClientCertificates.AddRange(options.ClientCertificates); + socket.Dispose(); + lastException = ExceptionDispatchInfo.Capture(exc); } - CancellationTokenSource linkedCancellation, externalAndAbortCancellation; - if (cancellationToken.CanBeCanceled) // avoid allocating linked source if external token is not cancelable + } + + lastException?.Throw(); + + Debug.Fail("We should never get here. We should have already returned or an exception should have been thrown."); + throw new WebSocketException(SR.net_webstatus_ConnectFailure); + } + + /// Creates a byte[] containing the headers to send to the server. + /// The Uri of the server. + /// The options used to configure the websocket. + /// The generated security key to send in the Sec-WebSocket-Key header. + /// The byte[] containing the encoded headers ready to send to the network. + private static byte[] BuildRequestHeader(Uri uri, ClientWebSocketOptions options, string secKey) + { + StringBuilder builder = t_cachedStringBuilder ?? (t_cachedStringBuilder = new StringBuilder()); + Debug.Assert(builder.Length == 0, $"Expected builder to be empty, got one of length {builder.Length}"); + try + { + builder.Append("GET ").Append(uri.PathAndQuery).Append(" HTTP/1.1\r\n"); + + // Add all of the required headers, honoring Host header if set. + string hostHeader = options.RequestHeaders[HttpKnownHeaderNames.Host]; + builder.Append("Host: "); + if (string.IsNullOrEmpty(hostHeader)) { - linkedCancellation = - externalAndAbortCancellation = - CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _abortSource.Token); + builder.Append(uri.GetIdnHost()).Append(':').Append(uri.Port).Append("\r\n"); } else { - linkedCancellation = null; - externalAndAbortCancellation = _abortSource; + builder.Append(hostHeader).Append("\r\n"); } - using (linkedCancellation) - { - response = await handler.SendAsync(request, externalAndAbortCancellation.Token).ConfigureAwait(false); - externalAndAbortCancellation.Token.ThrowIfCancellationRequested(); - } + builder.Append("Connection: Upgrade\r\n"); + builder.Append("Upgrade: websocket\r\n"); + builder.Append("Sec-WebSocket-Version: 13\r\n"); + builder.Append("Sec-WebSocket-Key: ").Append(secKey).Append("\r\n"); - // Issue the request. The response must be status code 101. - if (response.StatusCode != HttpStatusCode.SwitchingProtocols) - { - throw new WebSocketException(SR.net_webstatus_ConnectFailure); - } + // Add all of the additionally requested headers + foreach (string key in options.RequestHeaders.AllKeys) + { + if (string.Equals(key, HttpKnownHeaderNames.Host, StringComparison.OrdinalIgnoreCase)) + { + // Host header handled above + continue; + } - // The Connection, Upgrade, and SecWebSocketAccept headers are required and with specific values. - ValidateHeader(response.Headers, HttpKnownHeaderNames.Connection, "Upgrade"); - ValidateHeader(response.Headers, HttpKnownHeaderNames.Upgrade, "websocket"); - ValidateHeader(response.Headers, HttpKnownHeaderNames.SecWebSocketAccept, secKeyAndSecWebSocketAccept.Value); + builder.Append(key).Append(": ").Append(options.RequestHeaders[key]).Append("\r\n"); + } - // The SecWebSocketProtocol header is optional. We should only get it with a non-empty value if we requested subprotocols, - // and then it must only be one of the ones we requested. If we got a subprotocol other than one we requested (or if we - // already got one in a previous header), fail. Otherwise, track which one we got. - string subprotocol = null; - IEnumerable subprotocolEnumerableValues; - if (response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketProtocol, out subprotocolEnumerableValues)) + // Add the optional subprotocols header + if (options.RequestedSubProtocols.Count > 0) { - Debug.Assert(subprotocolEnumerableValues is string[]); - string[] subprotocolArray = (string[])subprotocolEnumerableValues; - if (subprotocolArray.Length != 1 || - (subprotocol = options.RequestedSubProtocols.Find(requested => string.Equals(requested, subprotocolArray[0], StringComparison.OrdinalIgnoreCase))) == null) + builder.Append(HttpKnownHeaderNames.SecWebSocketProtocol).Append(": "); + builder.Append(options.RequestedSubProtocols[0]); + for (int i = 1; i < options.RequestedSubProtocols.Count; i++) { - throw new WebSocketException( - WebSocketError.UnsupportedProtocol, - SR.Format(SR.net_WebSockets_AcceptUnsupportedProtocol, string.Join(", ", options.RequestedSubProtocols), subprotocol)); + builder.Append(", ").Append(options.RequestedSubProtocols[i]); } + builder.Append("\r\n"); } - // Get the response stream and wrap it in a web socket. - Stream connectedStream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); - Debug.Assert(connectedStream.CanWrite); - Debug.Assert(connectedStream.CanRead); - _webSocket = WebSocket.CreateClientWebSocket( // TODO https://github.com/dotnet/corefx/issues/21537: Use new API when available - connectedStream, - subprotocol, - options.ReceiveBufferSize, - options.SendBufferSize, - options.KeepAliveInterval, - useZeroMaskingKey: false, - internalBuffer: options.Buffer.GetValueOrDefault()); - } - catch (Exception exc) - { - if (_state < WebSocketState.Closed) + // Add an optional cookies header + if (options.Cookies != null) { - _state = WebSocketState.Closed; + string header = options.Cookies.GetCookieHeader(uri); + if (!string.IsNullOrWhiteSpace(header)) + { + builder.Append(HttpKnownHeaderNames.Cookie).Append(": ").Append(header).Append("\r\n"); + } } - Abort(); - response?.Dispose(); + // End the headers + builder.Append("\r\n"); - if (exc is WebSocketException) - { - throw; - } - throw new WebSocketException(SR.net_webstatus_ConnectFailure, exc); + // Return the bytes for the built up header + return s_defaultHttpEncoding.GetBytes(builder.ToString()); } - } - - /// The generated security key to send in the Sec-WebSocket-Key header. - private static void AddWebSocketHeaders(HttpRequestMessage request, string secKey, ClientWebSocketOptions options) - { - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.Connection, HttpKnownHeaderNames.Upgrade); - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.Upgrade, "websocket"); - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketVersion, "13"); - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketKey, secKey); - if (options._requestedSubProtocols?.Count > 0) + finally { - request.Headers.Add(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols)); + // Make sure we clear the builder + builder.Clear(); } } @@ -244,21 +378,174 @@ private static KeyValuePair CreateSecKeyAndSecWebSocketAccept() } } - private static void ValidateHeader(HttpHeaders headers, string name, string expectedValue) + /// Read and validate the connect response headers from the server. + /// The stream from which to read the response headers. + /// The options used to configure the websocket. + /// The expected value of the Sec-WebSocket-Accept header. + /// The CancellationToken to use to cancel the websocket. + /// The agreed upon subprotocol with the server, or null if there was none. + private async Task ParseAndValidateConnectResponseAsync( + Stream stream, ClientWebSocketOptions options, string expectedSecWebSocketAccept, CancellationToken cancellationToken) { - if (!headers.TryGetValues(name, out IEnumerable values)) + // Read the first line of the response + string statusLine = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false); + + // Depending on the underlying sockets implementation and timing, connecting to a server that then + // immediately closes the connection may either result in an exception getting thrown from the connect + // earlier, or it may result in getting to here but reading 0 bytes. If we read 0 bytes and thus have + // an empty status line, treat it as a connect failure. + if (string.IsNullOrEmpty(statusLine)) + { + throw new WebSocketException(SR.Format(SR.net_webstatus_ConnectFailure)); + } + + const string ExpectedStatusStart = "HTTP/1.1 "; + const string ExpectedStatusStatWithCode = "HTTP/1.1 101"; // 101 == SwitchingProtocols + + // If the status line doesn't begin with "HTTP/1.1" or isn't long enough to contain a status code, fail. + if (!statusLine.StartsWith(ExpectedStatusStart, StringComparison.Ordinal) || statusLine.Length < ExpectedStatusStatWithCode.Length) { - ThrowConnectFailure(); + throw new WebSocketException(WebSocketError.HeaderError); } - Debug.Assert(values is string[]); - string[] array = (string[])values; - if (array.Length != 1 || !string.Equals(array[0], expectedValue, StringComparison.OrdinalIgnoreCase)) + // If the status line doesn't contain a status code 101, or if it's long enough to have a status description + // but doesn't contain whitespace after the 101, fail. + if (!statusLine.StartsWith(ExpectedStatusStatWithCode, StringComparison.Ordinal) || + (statusLine.Length > ExpectedStatusStatWithCode.Length && !char.IsWhiteSpace(statusLine[ExpectedStatusStatWithCode.Length]))) { - throw new WebSocketException(SR.Format(SR.net_WebSockets_InvalidResponseHeader, name, string.Join(", ", array))); + throw new WebSocketException(SR.net_webstatus_ConnectFailure); + } + + // Read each response header. Be liberal in parsing the response header, treating + // everything to the left of the colon as the key and everything to the right as the value, trimming both. + // For each header, validate that we got the expected value. + bool foundUpgrade = false, foundConnection = false, foundSecWebSocketAccept = false; + string subprotocol = null; + string line; + while (!string.IsNullOrEmpty(line = await ReadResponseHeaderLineAsync(stream, cancellationToken).ConfigureAwait(false))) + { + int colonIndex = line.IndexOf(':'); + if (colonIndex == -1) + { + throw new WebSocketException(WebSocketError.HeaderError); + } + + string headerName = line.SubstringTrim(0, colonIndex); + string headerValue = line.SubstringTrim(colonIndex + 1); + + // The Connection, Upgrade, and SecWebSocketAccept headers are required and with specific values. + ValidateAndTrackHeader(HttpKnownHeaderNames.Connection, "Upgrade", headerName, headerValue, ref foundConnection); + ValidateAndTrackHeader(HttpKnownHeaderNames.Upgrade, "websocket", headerName, headerValue, ref foundUpgrade); + ValidateAndTrackHeader(HttpKnownHeaderNames.SecWebSocketAccept, expectedSecWebSocketAccept, headerName, headerValue, ref foundSecWebSocketAccept); + + // The SecWebSocketProtocol header is optional. We should only get it with a non-empty value if we requested subprotocols, + // and then it must only be one of the ones we requested. If we got a subprotocol other than one we requested (or if we + // already got one in a previous header), fail. Otherwise, track which one we got. + if (string.Equals(HttpKnownHeaderNames.SecWebSocketProtocol, headerName, StringComparison.OrdinalIgnoreCase) && + !string.IsNullOrWhiteSpace(headerValue)) + { + string newSubprotocol = options.RequestedSubProtocols.Find(requested => string.Equals(requested, headerValue, StringComparison.OrdinalIgnoreCase)); + if (newSubprotocol == null || subprotocol != null) + { + throw new WebSocketException( + WebSocketError.UnsupportedProtocol, + SR.Format(SR.net_WebSockets_AcceptUnsupportedProtocol, string.Join(", ", options.RequestedSubProtocols), subprotocol)); + } + subprotocol = newSubprotocol; + } + } + if (!foundUpgrade || !foundConnection || !foundSecWebSocketAccept) + { + throw new WebSocketException(SR.net_webstatus_ConnectFailure); + } + + return subprotocol; + } + + /// Validates a received header against expected values and tracks that we've received it. + /// The header name against which we're comparing. + /// The header value against which we're comparing. + /// The actual header name received. + /// The actual header value received. + /// A bool tracking whether this header has been seen. + private static void ValidateAndTrackHeader( + string targetHeaderName, string targetHeaderValue, + string foundHeaderName, string foundHeaderValue, + ref bool foundHeader) + { + bool isTargetHeader = string.Equals(targetHeaderName, foundHeaderName, StringComparison.OrdinalIgnoreCase); + if (!foundHeader) + { + if (isTargetHeader) + { + if (!string.Equals(targetHeaderValue, foundHeaderValue, StringComparison.OrdinalIgnoreCase)) + { + throw new WebSocketException(SR.Format(SR.net_WebSockets_InvalidResponseHeader, targetHeaderName, foundHeaderValue)); + } + foundHeader = true; + } + } + else + { + if (isTargetHeader) + { + throw new WebSocketException(SR.Format(SR.net_webstatus_ConnectFailure)); + } } } - private static void ThrowConnectFailure() => throw new WebSocketException(SR.net_webstatus_ConnectFailure); + /// Reads a line from the stream. + /// The stream from which to read. + /// The CancellationToken used to cancel the websocket. + /// The read line, or null if none could be read. + private static async Task ReadResponseHeaderLineAsync(Stream stream, CancellationToken cancellationToken) + { + StringBuilder sb = t_cachedStringBuilder; + if (sb != null) + { + t_cachedStringBuilder = null; + Debug.Assert(sb.Length == 0, $"Expected empty StringBuilder"); + } + else + { + sb = new StringBuilder(); + } + + var arr = new byte[1]; + char prevChar = '\0'; + try + { + // TODO: Reading one byte is extremely inefficient. The problem, however, + // is that if we read multiple bytes, we could end up reading bytes post-headers + // that are part of messages meant to be read by the managed websocket after + // the connection. The likely solution here is to wrap the stream in a BufferedStream, + // though a) that comes at the expense of an extra set of virtual calls, b) + // it adds a buffer when the managed websocket will already be using a buffer, and + // c) it's not exposed on the version of the System.IO contract we're currently using. + while (await stream.ReadAsync(arr, 0, 1, cancellationToken).ConfigureAwait(false) == 1) + { + // Process the next char + char curChar = (char)arr[0]; + if (prevChar == '\r' && curChar == '\n') + { + break; + } + sb.Append(curChar); + prevChar = curChar; + } + + if (sb.Length > 0 && sb[sb.Length - 1] == '\r') + { + sb.Length = sb.Length - 1; + } + + return sb.ToString(); + } + finally + { + sb.Clear(); + t_cachedStringBuilder = sb; + } + } } } \ No newline at end of file