Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/Apache.IoTDB/SessionPool.Builder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
* under the License.
*/

using System;
using System.Collections.Generic;

namespace Apache.IoTDB;
Expand All @@ -35,6 +34,8 @@ public class Builder
private int _poolSize = 8;
private bool _enableRpcCompression = false;
private int _connectionTimeoutInMs = 500;
private bool _useSsl = false;
private string _certificatePath = null;
private string _sqlDialect = IoTDBConstant.TREE_SQL_DIALECT;
private string _database = "";
private List<string> _nodeUrls = new List<string>();
Expand Down Expand Up @@ -93,6 +94,18 @@ public Builder SetConnectionTimeoutInMs(int timeout)
return this;
}

public Builder SetUseSsl(bool useSsl)
{
_useSsl = useSsl;
return this;
}

public Builder SetCertificatePath(string certificatePath)
{
_certificatePath = certificatePath;
return this;
}

public Builder SetNodeUrl(List<string> nodeUrls)
{
_nodeUrls = nodeUrls;
Expand Down Expand Up @@ -122,6 +135,8 @@ public Builder()
_poolSize = 8;
_enableRpcCompression = false;
_connectionTimeoutInMs = 500;
_useSsl = false;
_certificatePath = null;
_sqlDialect = IoTDBConstant.TREE_SQL_DIALECT;
_database = "";
}
Expand All @@ -131,9 +146,9 @@ public SessionPool Build()
// if nodeUrls is not empty, use nodeUrls to create session pool
if (_nodeUrls.Count > 0)
{
return new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database);
return new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database);
}
return new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database);
return new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database);
}
}
}
39 changes: 23 additions & 16 deletions src/Apache.IoTDB/SessionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Sockets;
using System.Numerics;
using System.Threading;
using System.Threading.Tasks;
using System.Security.Cryptography.X509Certificates;
using Apache.IoTDB.DataStructure;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Thrift;
using Thrift.Protocol;
Expand All @@ -47,6 +46,8 @@ public partial class SessionPool : IDisposable
private readonly List<TEndPoint> _endPoints = new();
private readonly string _host;
private readonly int _port;
private readonly bool _useSsl;
private readonly string _certificatePath;
private readonly int _fetchSize;
/// <summary>
/// _timeout is the amount of time a Session will wait for a send operation to complete successfully.
Expand Down Expand Up @@ -86,10 +87,10 @@ public SessionPool(string host, int port) : this(host, port, "root", "root", 102
{
}
public SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout)
: this(host, port, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, IoTDBConstant.TREE_SQL_DIALECT, "")
: this(host, port, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, false, null, IoTDBConstant.TREE_SQL_DIALECT, "")
{
}
protected internal SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, string sqlDialect, string database)
protected internal SessionPool(string host, int port, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, bool useSsl, string certificatePath, string sqlDialect, string database)
{
_host = host;
_port = port;
Expand All @@ -101,6 +102,8 @@ protected internal SessionPool(string host, int port, string username, string pa
_poolSize = poolSize;
_enableRpcCompression = enableRpcCompression;
_timeout = timeout;
_useSsl = useSsl;
_certificatePath = certificatePath;
_sqlDialect = sqlDialect;
_database = database;
}
Expand All @@ -126,11 +129,11 @@ public SessionPool(List<string> nodeUrls, string username, string password, int
{
}
public SessionPool(List<string> nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout)
: this(nodeUrls, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, IoTDBConstant.TREE_SQL_DIALECT, "")
: this(nodeUrls, username, password, fetchSize, zoneId, poolSize, enableRpcCompression, timeout, false, null, IoTDBConstant.TREE_SQL_DIALECT, "")
{

}
protected internal SessionPool(List<string> nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, string sqlDialect, string database)
protected internal SessionPool(List<string> nodeUrls, string username, string password, int fetchSize, string zoneId, int poolSize, bool enableRpcCompression, int timeout, bool useSsl, string certificatePath, string sqlDialect, string database)
{
if (nodeUrls.Count == 0)
{
Expand All @@ -146,6 +149,8 @@ protected internal SessionPool(List<string> nodeUrls, string username, string pa
_poolSize = poolSize;
_enableRpcCompression = enableRpcCompression;
_timeout = timeout;
_useSsl = useSsl;
_certificatePath = certificatePath;
_sqlDialect = sqlDialect;
_database = database;
}
Expand Down Expand Up @@ -241,7 +246,7 @@ public async Task Open(CancellationToken cancellationToken = default)
{
try
{
_clients.Add(await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken));
_clients.Add(await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _useSsl, _certificatePath, _sqlDialect, _database, cancellationToken));
}
catch (Exception e)
{
Expand All @@ -264,7 +269,7 @@ public async Task Open(CancellationToken cancellationToken = default)
var endPoint = _endPoints[endPointIndex];
try
{
var client = await CreateAndOpen(endPoint.Ip, endPoint.Port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken);
var client = await CreateAndOpen(endPoint.Ip, endPoint.Port, _enableRpcCompression, _timeout, _useSsl, _certificatePath, _sqlDialect, _database, cancellationToken);
_clients.Add(client);
isConnected = true;
startIndex = (endPointIndex + 1) % _endPoints.Count;
Expand Down Expand Up @@ -303,7 +308,7 @@ public async Task<Client> Reconnect(Client originalClient = null, CancellationTo
{
try
{
var client = await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken);
var client = await CreateAndOpen(_host, _port, _enableRpcCompression, _timeout, _useSsl, _certificatePath, _sqlDialect, _database, cancellationToken);
return client;
}
catch (Exception e)
Expand All @@ -330,7 +335,7 @@ public async Task<Client> Reconnect(Client originalClient = null, CancellationTo
int j = (startIndex + i) % _endPoints.Count;
try
{
var client = await CreateAndOpen(_endPoints[j].Ip, _endPoints[j].Port, _enableRpcCompression, _timeout, _sqlDialect, _database, cancellationToken);
var client = await CreateAndOpen(_endPoints[j].Ip, _endPoints[j].Port, _enableRpcCompression, _timeout, _useSsl, _certificatePath, _sqlDialect, _database, cancellationToken);
return client;
}
catch (Exception e)
Expand Down Expand Up @@ -423,12 +428,14 @@ public async Task<string> GetTimeZone()
}
}

private async Task<Client> CreateAndOpen(string host, int port, bool enableRpcCompression, int timeout, string sqlDialect, string database, CancellationToken cancellationToken = default)
private async Task<Client> CreateAndOpen(string host, int port, bool enableRpcCompression, int timeout, bool useSsl, string cert, string sqlDialect, string database, CancellationToken cancellationToken = default)
{
var tcpClient = new TcpClient(host, port);
tcpClient.SendTimeout = timeout;
tcpClient.ReceiveTimeout = timeout;
var transport = new TFramedTransport(new TSocketTransport(tcpClient, null));

TTransport socket = useSsl ?
new TTlsSocketTransport(host, port, null, timeout, new X509Certificate2(File.ReadAllBytes(cert))) :
new TSocketTransport(host, port, null, timeout);
Comment on lines +434 to +436
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new SSL functionality lacks test coverage. Since the repository has test infrastructure in place, consider adding tests that verify SSL connections work correctly when SSL is enabled, and that appropriate error handling occurs when certificate paths are invalid or SSL configuration is incorrect.

Copilot uses AI. Check for mistakes.

Comment on lines +434 to +437
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The certificate file is loaded synchronously using File.ReadAllBytes without any validation. This could cause runtime exceptions if the certificate path is invalid, the file doesn't exist, is not accessible, or contains invalid certificate data. Consider validating that the certificate path is provided when SSL is enabled and that the file exists before attempting to load it.

Suggested change
TTransport socket = useSsl ?
new TTlsSocketTransport(host, port, null, timeout, new X509Certificate2(File.ReadAllBytes(cert))) :
new TSocketTransport(host, port, null, timeout);
TTransport socket;
if (useSsl)
{
if (string.IsNullOrWhiteSpace(cert))
{
throw new ArgumentException("Certificate path must be provided when SSL is enabled.", nameof(cert));
}
if (!File.Exists(cert))
{
throw new FileNotFoundException($"Certificate file not found at path '{cert}'.", cert);
}
X509Certificate2 certificate;
try
{
var certificateBytes = File.ReadAllBytes(cert);
certificate = new X509Certificate2(certificateBytes);
}
catch (Exception ex)
{
throw new InvalidOperationException($"Failed to load SSL certificate from path '{cert}'.", ex);
}
socket = new TTlsSocketTransport(host, port, null, timeout, certificate);
}
else
{
socket = new TSocketTransport(host, port, null, timeout);
}

Copilot uses AI. Check for mistakes.
var transport = new TFramedTransport(socket);

if (!transport.IsOpen)
{
Expand Down
18 changes: 16 additions & 2 deletions src/Apache.IoTDB/TableSessionPool.Builder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public class Builder
private int _poolSize = 8;
private bool _enableRpcCompression = false;
private int _connectionTimeoutInMs = 500;
private bool _useSsl = false;
private string _certificatePath = null;
private string _sqlDialect = IoTDBConstant.TREE_SQL_DIALECT;
private string _database = "";
private List<string> _nodeUrls = new List<string>();
Expand Down Expand Up @@ -95,6 +97,18 @@ public Builder SetConnectionTimeoutInMs(int timeout)
return this;
}

public Builder SetUseSsl(bool useSsl)
{
_useSsl = useSsl;
return this;
}

public Builder SetCertificatePath(string certificatePath)
{
_certificatePath = certificatePath;
return this;
}

public Builder SetNodeUrls(List<string> nodeUrls)
{
_nodeUrls = nodeUrls;
Expand Down Expand Up @@ -134,11 +148,11 @@ public TableSessionPool Build()
// if nodeUrls is not empty, use nodeUrls to create session pool
if (_nodeUrls.Count > 0)
{
sessionPool = new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database);
sessionPool = new SessionPool(_nodeUrls, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database);
}
else
{
sessionPool = new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _sqlDialect, _database);
sessionPool = new SessionPool(_host, _port, _username, _password, _fetchSize, _zoneId, _poolSize, _enableRpcCompression, _connectionTimeoutInMs, _useSsl, _certificatePath, _sqlDialect, _database);
}
return new TableSessionPool(sessionPool);
}
Expand Down
Loading