diff --git a/src/Apache.IoTDB/SessionPool.Builder.cs b/src/Apache.IoTDB/SessionPool.Builder.cs index f943d81..9de2874 100644 --- a/src/Apache.IoTDB/SessionPool.Builder.cs +++ b/src/Apache.IoTDB/SessionPool.Builder.cs @@ -17,7 +17,6 @@ * under the License. */ -using System; using System.Collections.Generic; namespace Apache.IoTDB; @@ -35,6 +34,8 @@ public class Builder private int _poolSize = 8; private bool _enableRpcCompression = false; private int _connectionTimeoutInMs = 500; + private bool _useSsl = false; + private string _certificatePath = null; private string _sqlDialect = IoTDBConstant.TREE_SQL_DIALECT; private string _database = ""; private List _nodeUrls = new List(); @@ -93,6 +94,18 @@ public Builder SetConnectionTimeoutInMs(int timeout) return this; } + public Builder SetUseSsl(bool useSsl) + { + _useSsl = useSsl; + return this; + } + + public Builder SetCertificatePath(string certificatePath) + { + _certificatePath = certificatePath; + return this; + } + public Builder SetNodeUrl(List nodeUrls) { _nodeUrls = nodeUrls; @@ -122,6 +135,8 @@ public Builder() _poolSize = 8; _enableRpcCompression = false; _connectionTimeoutInMs = 500; + _useSsl = false; + _certificatePath = null; _sqlDialect = IoTDBConstant.TREE_SQL_DIALECT; _database = ""; } @@ -131,9 +146,9 @@ public SessionPool Build() // if nodeUrls is not empty, use nodeUrls to create session pool if (_nodeUrls.Count > 0) { - return new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database); + return new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database); } - return new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database); + return new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database); } } } diff --git a/src/Apache.IoTDB/SessionPool.cs b/src/Apache.IoTDB/SessionPool.cs index 135199b..fc2eca0 100644 --- a/src/Apache.IoTDB/SessionPool.cs +++ b/src/Apache.IoTDB/SessionPool.cs @@ -19,13 +19,12 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; -using System.Net.Sockets; -using System.Numerics; using System.Threading; using System.Threading.Tasks; +using System.Security.Cryptography.X509Certificates; using Apache.IoTDB.DataStructure; -using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; using Thrift; using Thrift.Protocol; @@ -47,6 +46,8 @@ public partial class SessionPool : IDisposable private readonly List _endPoints = new(); private readonly string _host; private readonly int _port; + private readonly bool _useSsl; + private readonly string _certificatePath; private readonly int _fetchSize; /// /// _timeout is the amount of time a Session will wait for a send operation to complete successfully. @@ -86,10 +87,10 @@ public SessionPool(string host, int port) : this(host, port, "root", "root", 102 { } public SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout) - : this(host, port, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, IoTDBConstant.TREE_SQL_DIALECT, "") + : this(host, port, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, false, null, IoTDBConstant.TREE_SQL_DIALECT, "") { } - protected internal SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, string sqlDialect, string database) + protected internal SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, bool useSsl, string certificatePath, string sqlDialect, string database) { _host = host; _port = port; @@ -101,6 +102,8 @@ protected internal SessionPool(string host, int port, string username, string pa _poolSize = poolSize; _enableRpcCompression = enableRpcCompression; _timeout = timeout; + _useSsl = useSsl; + _certificatePath = certificatePath; _sqlDialect = sqlDialect; _database = database; } @@ -126,11 +129,11 @@ public SessionPool(List nodeUrls, string username, string password, int { } public SessionPool(List nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout) - : this(nodeUrls, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, IoTDBConstant.TREE_SQL_DIALECT, "") + : this(nodeUrls, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, false, null, IoTDBConstant.TREE_SQL_DIALECT, "") { } - protected internal SessionPool(List nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, string sqlDialect, string database) + protected internal SessionPool(List nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, bool useSsl, string certificatePath, string sqlDialect, string database) { if (nodeUrls.Count == 0) { @@ -146,6 +149,8 @@ protected internal SessionPool(List nodeUrls, string username, string pa _poolSize = poolSize; _enableRpcCompression = enableRpcCompression; _timeout = timeout; + _useSsl = useSsl; + _certificatePath = certificatePath; _sqlDialect = sqlDialect; _database = database; } @@ -241,7 +246,7 @@ public async Task Open(CancellationToken cancellationToken = default) { try { - _clients.Add(await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken)); + _clients.Add(await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _useSsl, _certificatePath, _sqlDialect, _database, cancellationToken)); } catch (Exception e) { @@ -264,7 +269,7 @@ public async Task Open(CancellationToken cancellationToken = default) var endPoint = _endPoints[endPointIndex]; try { - var client = await CreateAndOpen(endPoint.Ip, endPoint.Port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken); + var client = await CreateAndOpen(endPoint.Ip, endPoint.Port, _enableRpcCompression, _timeout, _useSsl, _certificatePath, _sqlDialect, _database, cancellationToken); _clients.Add(client); isConnected = true; startIndex = (endPointIndex + 1) % _endPoints.Count; @@ -303,7 +308,7 @@ public async Task Reconnect(Client originalClient = null, CancellationTo { try { - var client = await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken); + var client = await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _useSsl, _certificatePath, _sqlDialect, _database, cancellationToken); return client; } catch (Exception e) @@ -330,7 +335,7 @@ public async Task Reconnect(Client originalClient = null, CancellationTo int j = (startIndex + i) % _endPoints.Count; try { - var client = await CreateAndOpen(_endPoints[j].Ip, _endPoints[j].Port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken); + var client = await CreateAndOpen(_endPoints[j].Ip, _endPoints[j].Port, _enableRpcCompression, _timeout, _useSsl, _certificatePath, _sqlDialect, _database, cancellationToken); return client; } catch (Exception e) @@ -423,12 +428,14 @@ public async Task GetTimeZone() } } - private async Task CreateAndOpen(string host, int port, bool enableRpcCompression, int timeout, string sqlDialect, string database, CancellationToken cancellationToken = default) + private async Task CreateAndOpen(string host, int port, bool enableRpcCompression, int timeout, bool useSsl, string cert, string sqlDialect, string database, CancellationToken cancellationToken = default) { - var tcpClient = new TcpClient(host, port); - tcpClient.SendTimeout = timeout; - tcpClient.ReceiveTimeout = timeout; - var transport = new TFramedTransport(new TSocketTransport(tcpClient, null)); + + TTransport socket = useSsl ? + new TTlsSocketTransport(host, port, null, timeout, new X509Certificate2(File.ReadAllBytes(cert))) : + new TSocketTransport(host, port, null, timeout); + + var transport = new TFramedTransport(socket); if (!transport.IsOpen) { diff --git a/src/Apache.IoTDB/TableSessionPool.Builder.cs b/src/Apache.IoTDB/TableSessionPool.Builder.cs index 07387b5..10e24c8 100644 --- a/src/Apache.IoTDB/TableSessionPool.Builder.cs +++ b/src/Apache.IoTDB/TableSessionPool.Builder.cs @@ -37,6 +37,8 @@ public class Builder private int _poolSize = 8; private bool _enableRpcCompression = false; private int _connectionTimeoutInMs = 500; + private bool _useSsl = false; + private string _certificatePath = null; private string _sqlDialect = IoTDBConstant.TREE_SQL_DIALECT; private string _database = ""; private List _nodeUrls = new List(); @@ -95,6 +97,18 @@ public Builder SetConnectionTimeoutInMs(int timeout) return this; } + public Builder SetUseSsl(bool useSsl) + { + _useSsl = useSsl; + return this; + } + + public Builder SetCertificatePath(string certificatePath) + { + _certificatePath = certificatePath; + return this; + } + public Builder SetNodeUrls(List nodeUrls) { _nodeUrls = nodeUrls; @@ -134,11 +148,11 @@ public TableSessionPool Build() // if nodeUrls is not empty, use nodeUrls to create session pool if (_nodeUrls.Count > 0) { - sessionPool = new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database); + sessionPool = new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database); } else { - sessionPool = new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database); + sessionPool = new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database); } return new TableSessionPool(sessionPool); }