Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.clients.producer;
package org.apache.kafka.common;

import kafka.network.SocketServer;
import kafka.server.KafkaBroker;

import org.apache.kafka.common.memory.MemoryPool;
import org.apache.kafka.common.message.ProduceRequestData;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.RequestHeader;
import org.apache.kafka.common.requests.RequestUtils;
import org.apache.kafka.common.test.ClusterInstance;
import org.apache.kafka.common.test.api.ClusterConfigProperty;
import org.apache.kafka.common.test.api.ClusterTest;
import org.apache.kafka.common.test.api.ClusterTestDefaults;
import org.apache.kafka.common.test.api.TestKitDefaults;
import org.apache.kafka.common.test.api.Type;
import org.apache.kafka.network.SocketServerConfigs;
import org.apache.kafka.server.IntegrationTestUtils;

import java.io.EOFException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand All @@ -50,42 +51,31 @@
})
public class SocketServerMemoryPoolTest {
@ClusterTest
public void testProduceRequestWithUnsupportedVersion(ClusterInstance clusterInstance) throws Exception {
short unsupportedVersion = Short.MAX_VALUE;
byte[] rawRequestBytes = buildRawRequest(
ApiKeys.PRODUCE.id,
unsupportedVersion,
/* correlationId */ 1,
/* clientId */ "test-unsupported-version",
new byte[10000]
);
public void testRequestWithUnsupportedVersion(ClusterInstance clusterInstance) throws Exception {
RequestHeader header = IntegrationTestUtils.nextRequestHeader(ApiKeys.PRODUCE, Short.MAX_VALUE);
ByteBuffer buffer = RequestUtils.serialize(header.data(), header.headerVersion(), new ProduceRequestData(), header.apiVersion());
byte[] rawRequestBytes = buffer.array();

sendAndAssert(clusterInstance, rawRequestBytes);
}

@ClusterTest
public void testProduceRequestWithCorruptBody(ClusterInstance clusterInstance) throws Exception {
short validVersion = 3;
byte[] corruptBody = new byte[10000];
for (int i = 0; i < corruptBody.length; i++) {
corruptBody[i] = (byte) 0xFF; // The corrupt body (0xFF ... 0xFF) makes Schema.read() throw SchemaException.
public void testRequestWithCorruptBody(ClusterInstance clusterInstance) throws Exception {
RequestHeader header = IntegrationTestUtils.nextRequestHeader(ApiKeys.PRODUCE, ApiKeys.PRODUCE.latestVersion());
ByteBuffer buffer = RequestUtils.serialize(header.data(), header.headerVersion(), new ProduceRequestData(), header.apiVersion());
byte[] rawRequestBytes = buffer.array();

// corrupt body but leave header valid
for (int i = header.size(); i < rawRequestBytes.length; i++) {
rawRequestBytes[i] = (byte) 0xFF;
}

byte[] rawRequestBytes = buildRawRequest(
ApiKeys.PRODUCE.id,
validVersion,
/* correlationId */ 2,
/* clientId */ "test-corrupt-body",
corruptBody
);

sendAndAssert(clusterInstance, rawRequestBytes);
}

private void sendAndAssert(ClusterInstance clusterInstance, byte[] rawRequestBytes) throws Exception {
long initialMemoryPoolAvailable = getMemoryPoolAvailable(clusterInstance);

try (Socket socket = IntegrationTestUtils.connect(clusterInstance.brokerBoundPorts().get(0))) {
try (Socket socket = IntegrationTestUtils.connect(getBrokerBoundPort(clusterInstance))) {
socket.setSoTimeout(/* readTimeoutMs */ 5_000);
IntegrationTestUtils.sendRequest(socket, rawRequestBytes);
assertTrue(drainUntilClosed(socket.getInputStream()), "expected connection closed");
Expand All @@ -96,48 +86,16 @@ private void sendAndAssert(ClusterInstance clusterInstance, byte[] rawRequestByt
assertEquals(initialMemoryPoolAvailable, finalMemoryPoolAvailable);
}

// This test uses reflection to read the SocketServer memoryPool availableMemory.
// The metric "MemoryPoolAvailable" from Yammer Metrics default registry
// can be overwritten in a @ClusterTest as the registry is a singleton.
long getMemoryPoolAvailable(ClusterInstance clusterInstance) throws Exception {
KafkaBroker broker = clusterInstance.aliveBrokers().values().iterator().next();
SocketServer socketServer = broker.socketServer();
Field memoryPoolField = socketServer.getClass().getDeclaredField("memoryPool");
memoryPoolField.setAccessible(true);
MemoryPool memoryPool = (MemoryPool) memoryPoolField.get(socketServer);
return memoryPool.availableMemory();
private SocketServer getSocketServer(ClusterInstance clusterInstance) {
return clusterInstance.brokers().get(TestKitDefaults.BROKER_ID_OFFSET).socketServer();
}

/**
* Builds a raw Kafka request excluding the frame length
*
* <p>Wire layout:
* <pre>
* 4 bytes – frame length (payload size, not including these 4 bytes)
*
* 2 bytes – api_key
* 2 bytes – api_version
* 4 bytes – correlation_id
* 2 bytes – client_id string length
* N bytes – client_id (UTF-8)
* X bytes - request body
* </pre>
*/
private static byte[] buildRawRequest(short apiKey, short apiVersion, int correlationId, String clientId, byte[] body) {
byte[] clientIdBytes = clientId.getBytes(StandardCharsets.UTF_8);

// Header: api_key(2) + api_version(2) + correlation_id(4) + client_id_len(2) + client_id
int headerSize = 2 + 2 + 4 + 2 + clientIdBytes.length;
int payloadSize = headerSize + body.length;
private int getBrokerBoundPort(ClusterInstance clusterInstance) {
return getSocketServer(clusterInstance).boundPort(ListenerName.normalised(TestKitDefaults.DEFAULT_BROKER_LISTENER_NAME));
}

ByteBuffer buf = ByteBuffer.allocate(payloadSize);
buf.putShort(apiKey); // api_key
buf.putShort(apiVersion); // api_version
buf.putInt(correlationId); // correlation_id
buf.putShort((short) clientIdBytes.length); // client_id string length
buf.put(clientIdBytes); // client_id bytes
buf.put(body); // request body (possibly empty / corrupt)
return buf.array();
private long getMemoryPoolAvailable(ClusterInstance clusterInstance) {
return getSocketServer(clusterInstance).memoryPool().availableMemory();
}

/*
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/kafka/network/SocketServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ class SocketServer(
private val memoryPoolDepletedPercentMetricName = metrics.metricName("MemoryPoolAvgDepletedPercent", JSocketServer.METRICS_GROUP)
private val memoryPoolDepletedTimeMetricName = metrics.metricName("MemoryPoolDepletedTimeTotal", JSocketServer.METRICS_GROUP)
memoryPoolSensor.add(new Meter(TimeUnit.MILLISECONDS, memoryPoolDepletedPercentMetricName, memoryPoolDepletedTimeMetricName))
private val memoryPool = if (config.queuedMaxBytes > 0) new SimpleMemoryPool(config.queuedMaxBytes, config.socketRequestMaxBytes, false, memoryPoolSensor) else MemoryPool.NONE
// accessible for testing
val memoryPool = if (config.queuedMaxBytes > 0) new SimpleMemoryPool(config.queuedMaxBytes, config.socketRequestMaxBytes, false, memoryPoolSensor) else MemoryPool.NONE
// data-plane
private[network] val dataPlaneAcceptors = new ConcurrentHashMap[Endpoint, DataPlaneAcceptor]()
val dataPlaneRequestChannel = new RequestChannel(maxQueuedRequests, time, apiVersionManager.newRequestMetrics)
Expand Down
Loading