diff --git a/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/producer/SocketServerMemoryPoolTest.java b/clients/clients-integration-tests/src/test/java/org/apache/kafka/common/SocketServerMemoryPoolTest.java similarity index 53% rename from clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/producer/SocketServerMemoryPoolTest.java rename to clients/clients-integration-tests/src/test/java/org/apache/kafka/common/SocketServerMemoryPoolTest.java index 2278718a1cc11..8b5ee4297deb4 100644 --- a/clients/clients-integration-tests/src/test/java/org/apache/kafka/clients/producer/SocketServerMemoryPoolTest.java +++ b/clients/clients-integration-tests/src/test/java/org/apache/kafka/common/SocketServerMemoryPoolTest.java @@ -14,28 +14,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.kafka.clients.producer; +package org.apache.kafka.common; import kafka.network.SocketServer; -import kafka.server.KafkaBroker; -import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.requests.RequestUtils; import org.apache.kafka.common.test.ClusterInstance; import org.apache.kafka.common.test.api.ClusterConfigProperty; import org.apache.kafka.common.test.api.ClusterTest; import org.apache.kafka.common.test.api.ClusterTestDefaults; +import org.apache.kafka.common.test.api.TestKitDefaults; import org.apache.kafka.common.test.api.Type; import org.apache.kafka.network.SocketServerConfigs; import org.apache.kafka.server.IntegrationTestUtils; import java.io.EOFException; import java.io.InputStream; -import java.lang.reflect.Field; import java.net.Socket; import java.net.SocketTimeoutException; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -50,42 +51,31 @@ }) public class SocketServerMemoryPoolTest { @ClusterTest - public void testProduceRequestWithUnsupportedVersion(ClusterInstance clusterInstance) throws Exception { - short unsupportedVersion = Short.MAX_VALUE; - byte[] rawRequestBytes = buildRawRequest( - ApiKeys.PRODUCE.id, - unsupportedVersion, - /* correlationId */ 1, - /* clientId */ "test-unsupported-version", - new byte[10000] - ); + public void testRequestWithUnsupportedVersion(ClusterInstance clusterInstance) throws Exception { + RequestHeader header = IntegrationTestUtils.nextRequestHeader(ApiKeys.PRODUCE, Short.MAX_VALUE); + ByteBuffer buffer = RequestUtils.serialize(header.data(), header.headerVersion(), new ProduceRequestData(), header.apiVersion()); + byte[] rawRequestBytes = buffer.array(); sendAndAssert(clusterInstance, rawRequestBytes); } @ClusterTest - public void testProduceRequestWithCorruptBody(ClusterInstance clusterInstance) throws Exception { - short validVersion = 3; - byte[] corruptBody = new byte[10000]; - for (int i = 0; i < corruptBody.length; i++) { - corruptBody[i] = (byte) 0xFF; // The corrupt body (0xFF ... 0xFF) makes Schema.read() throw SchemaException. + public void testRequestWithCorruptBody(ClusterInstance clusterInstance) throws Exception { + RequestHeader header = IntegrationTestUtils.nextRequestHeader(ApiKeys.PRODUCE, ApiKeys.PRODUCE.latestVersion()); + ByteBuffer buffer = RequestUtils.serialize(header.data(), header.headerVersion(), new ProduceRequestData(), header.apiVersion()); + byte[] rawRequestBytes = buffer.array(); + + // corrupt body but leave header valid + for (int i = header.size(); i < rawRequestBytes.length; i++) { + rawRequestBytes[i] = (byte) 0xFF; } - - byte[] rawRequestBytes = buildRawRequest( - ApiKeys.PRODUCE.id, - validVersion, - /* correlationId */ 2, - /* clientId */ "test-corrupt-body", - corruptBody - ); - sendAndAssert(clusterInstance, rawRequestBytes); } private void sendAndAssert(ClusterInstance clusterInstance, byte[] rawRequestBytes) throws Exception { long initialMemoryPoolAvailable = getMemoryPoolAvailable(clusterInstance); - try (Socket socket = IntegrationTestUtils.connect(clusterInstance.brokerBoundPorts().get(0))) { + try (Socket socket = IntegrationTestUtils.connect(getBrokerBoundPort(clusterInstance))) { socket.setSoTimeout(/* readTimeoutMs */ 5_000); IntegrationTestUtils.sendRequest(socket, rawRequestBytes); assertTrue(drainUntilClosed(socket.getInputStream()), "expected connection closed"); @@ -96,48 +86,16 @@ private void sendAndAssert(ClusterInstance clusterInstance, byte[] rawRequestByt assertEquals(initialMemoryPoolAvailable, finalMemoryPoolAvailable); } - // This test uses reflection to read the SocketServer memoryPool availableMemory. - // The metric "MemoryPoolAvailable" from Yammer Metrics default registry - // can be overwritten in a @ClusterTest as the registry is a singleton. - long getMemoryPoolAvailable(ClusterInstance clusterInstance) throws Exception { - KafkaBroker broker = clusterInstance.aliveBrokers().values().iterator().next(); - SocketServer socketServer = broker.socketServer(); - Field memoryPoolField = socketServer.getClass().getDeclaredField("memoryPool"); - memoryPoolField.setAccessible(true); - MemoryPool memoryPool = (MemoryPool) memoryPoolField.get(socketServer); - return memoryPool.availableMemory(); + private SocketServer getSocketServer(ClusterInstance clusterInstance) { + return clusterInstance.brokers().get(TestKitDefaults.BROKER_ID_OFFSET).socketServer(); } - /** - * Builds a raw Kafka request excluding the frame length - * - *
Wire layout: - *
- * 4 bytes – frame length (payload size, not including these 4 bytes) - * - * 2 bytes – api_key - * 2 bytes – api_version - * 4 bytes – correlation_id - * 2 bytes – client_id string length - * N bytes – client_id (UTF-8) - * X bytes - request body - *- */ - private static byte[] buildRawRequest(short apiKey, short apiVersion, int correlationId, String clientId, byte[] body) { - byte[] clientIdBytes = clientId.getBytes(StandardCharsets.UTF_8); - - // Header: api_key(2) + api_version(2) + correlation_id(4) + client_id_len(2) + client_id - int headerSize = 2 + 2 + 4 + 2 + clientIdBytes.length; - int payloadSize = headerSize + body.length; + private int getBrokerBoundPort(ClusterInstance clusterInstance) { + return getSocketServer(clusterInstance).boundPort(ListenerName.normalised(TestKitDefaults.DEFAULT_BROKER_LISTENER_NAME)); + } - ByteBuffer buf = ByteBuffer.allocate(payloadSize); - buf.putShort(apiKey); // api_key - buf.putShort(apiVersion); // api_version - buf.putInt(correlationId); // correlation_id - buf.putShort((short) clientIdBytes.length); // client_id string length - buf.put(clientIdBytes); // client_id bytes - buf.put(body); // request body (possibly empty / corrupt) - return buf.array(); + private long getMemoryPoolAvailable(ClusterInstance clusterInstance) { + return getSocketServer(clusterInstance).memoryPool().availableMemory(); } /* diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index cb7dc5ad92a11..664874d54e5c1 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -94,7 +94,8 @@ class SocketServer( private val memoryPoolDepletedPercentMetricName = metrics.metricName("MemoryPoolAvgDepletedPercent", JSocketServer.METRICS_GROUP) private val memoryPoolDepletedTimeMetricName = metrics.metricName("MemoryPoolDepletedTimeTotal", JSocketServer.METRICS_GROUP) memoryPoolSensor.add(new Meter(TimeUnit.MILLISECONDS, memoryPoolDepletedPercentMetricName, memoryPoolDepletedTimeMetricName)) - private val memoryPool = if (config.queuedMaxBytes > 0) new SimpleMemoryPool(config.queuedMaxBytes, config.socketRequestMaxBytes, false, memoryPoolSensor) else MemoryPool.NONE + // accessible for testing + val memoryPool = if (config.queuedMaxBytes > 0) new SimpleMemoryPool(config.queuedMaxBytes, config.socketRequestMaxBytes, false, memoryPoolSensor) else MemoryPool.NONE // data-plane private[network] val dataPlaneAcceptors = new ConcurrentHashMap[Endpoint, DataPlaneAcceptor]() val dataPlaneRequestChannel = new RequestChannel(maxQueuedRequests, time, apiVersionManager.newRequestMetrics)