diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerde.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerde.java index 0523ed2fe79cf..1d9860c72cee5 100644 --- a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerde.java +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerde.java @@ -17,6 +17,8 @@ package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.common.serialization.Serializer; @@ -64,13 +66,18 @@ public void setIfUnset(final SerdeGetter getter) { @Override public byte[] serialize(final String topic, final SubscriptionResponseWrapper data) { + return serialize(topic, new RecordHeaders(), data); + } + + @Override + public byte[] serialize(final String topic, final Headers headers, final SubscriptionResponseWrapper data) { //{1-bit-isHashNull}{7-bits-version}{Optional-16-byte-Hash}{n-bytes serialized data} if (data.version() < 0) { throw new UnsupportedVersionException("SubscriptionResponseWrapper version cannot be negative"); } - final byte[] serializedData = data.foreignValue() == null ? null : serializer.serialize(topic, data.foreignValue()); + final byte[] serializedData = data.foreignValue() == null ? null : serializer.serialize(topic, headers, data.foreignValue()); final int serializedDataLength = serializedData == null ? 0 : serializedData.length; final long[] originalHash = data.originalValueHash(); final int hashLength = originalHash == null ? 0 : 2 * Long.BYTES; @@ -111,6 +118,11 @@ public void setIfUnset(final SerdeGetter getter) { @Override public SubscriptionResponseWrapper deserialize(final String topic, final byte[] data) { + return deserialize(topic, new RecordHeaders(), data); + } + + @Override + public SubscriptionResponseWrapper deserialize(final String topic, final Headers headers, final byte[] data) { //{1-bit-isHashNull}{7-bits-version}{Optional-16-byte-Hash}{n-bytes serialized data} final ByteBuffer buf = ByteBuffer.wrap(data); @@ -134,7 +146,7 @@ public SubscriptionResponseWrapper deserialize(final String topic, final byte final byte[] serializedValue; serializedValue = new byte[data.length - lengthSum]; buf.get(serializedValue, 0, serializedValue.length); - value = deserializer.deserialize(topic, serializedValue); + value = deserializer.deserialize(topic, headers, serializedValue); } else { value = null; } diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializer.java index df45bc683ddbe..acf2458c1dea3 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializer.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializer.java @@ -16,6 +16,8 @@ */ package org.apache.kafka.streams.state.internals; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.streams.kstream.internals.WrappingNullableDeserializer; import org.apache.kafka.streams.processor.internals.SerdeGetter; @@ -57,13 +59,18 @@ public void configure(final Map configs, @Override public LeftOrRightValue deserialize(final String topic, final byte[] data) { + return deserialize(topic, new RecordHeaders(), data); + } + + @Override + public LeftOrRightValue deserialize(final String topic, final Headers headers, final byte[] data) { if (data == null || data.length == 0) { return null; } return (data[0] == 1) - ? LeftOrRightValue.makeLeftValue(leftDeserializer.deserialize(topic, rawValue(data))) - : LeftOrRightValue.makeRightValue(rightDeserializer.deserialize(topic, rawValue(data))); + ? LeftOrRightValue.makeLeftValue(leftDeserializer.deserialize(topic, headers, rawValue(data))) + : LeftOrRightValue.makeRightValue(rightDeserializer.deserialize(topic, headers, rawValue(data))); } private byte[] rawValue(final byte[] data) { diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializer.java index 1c64c29fd5ac8..03ec857278a0c 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializer.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializer.java @@ -16,6 +16,8 @@ */ package org.apache.kafka.streams.state.internals; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.Serializer; import org.apache.kafka.streams.kstream.internals.WrappingNullableSerializer; import org.apache.kafka.streams.processor.internals.SerdeGetter; @@ -61,13 +63,18 @@ public void configure(final Map configs, final boolean isKey) { @Override public byte[] serialize(final String topic, final LeftOrRightValue data) { + return serialize(topic, new RecordHeaders(), data); + } + + @Override + public byte[] serialize(final String topic, final Headers headers, final LeftOrRightValue data) { if (data == null) { return null; } final byte[] rawValue = (data.leftValue() != null) - ? leftSerializer.serialize(topic, data.leftValue()) - : rightSerializer.serialize(topic, data.rightValue()); + ? leftSerializer.serialize(topic, headers, data.leftValue()) + : rightSerializer.serialize(topic, headers, data.rightValue()); if (rawValue == null) { return null; diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializer.java index 4285dd7bf166b..4955c8c864902 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializer.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializer.java @@ -16,6 +16,8 @@ */ package org.apache.kafka.streams.state.internals; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.LongDeserializer; import org.apache.kafka.streams.kstream.internals.WrappingNullableDeserializer; @@ -56,9 +58,14 @@ public void configure(final Map configs, final boolean isKey) { @Override public TimestampedKeyAndJoinSide deserialize(final String topic, final byte[] data) { + return deserialize(topic, new RecordHeaders(), data); + } + + @Override + public TimestampedKeyAndJoinSide deserialize(final String topic, final Headers headers, final byte[] data) { final boolean isLeft = data[StateSerdes.TIMESTAMP_SIZE] == 1; - final K key = keyDeserializer.deserialize(topic, rawKey(data)); - final long timestamp = timestampDeserializer.deserialize(topic, rawTimestamp(data)); + final K key = keyDeserializer.deserialize(topic, headers, rawKey(data)); + final long timestamp = timestampDeserializer.deserialize(topic, headers, rawTimestamp(data)); return isLeft ? TimestampedKeyAndJoinSide.makeLeft(key, timestamp) : TimestampedKeyAndJoinSide.makeRight(key, timestamp); diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializer.java index d94cc486357bb..6e488d87cd606 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializer.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializer.java @@ -16,6 +16,8 @@ */ package org.apache.kafka.streams.state.internals; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.LongSerializer; import org.apache.kafka.common.serialization.Serializer; import org.apache.kafka.streams.kstream.internals.WrappingNullableSerializer; @@ -56,9 +58,14 @@ public void configure(final Map configs, final boolean isKey) { @Override public byte[] serialize(final String topic, final TimestampedKeyAndJoinSide data) { + return serialize(topic, new RecordHeaders(), data); + } + + @Override + public byte[] serialize(final String topic, final Headers headers, final TimestampedKeyAndJoinSide data) { final byte boolByte = (byte) (data.isLeftSide() ? 1 : 0); - final byte[] keyBytes = keySerializer.serialize(topic, data.key()); - final byte[] timestampBytes = timestampSerializer.serialize(topic, data.timestamp()); + final byte[] keyBytes = keySerializer.serialize(topic, headers, data.key()); + final byte[] timestampBytes = timestampSerializer.serialize(topic, headers, data.timestamp()); return ByteBuffer .allocate(timestampBytes.length + 1 + keyBytes.length) diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerdeTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerdeTest.java index 276600fd106e2..854c3bcb6fbc0 100644 --- a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerdeTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerdeTest.java @@ -17,10 +17,14 @@ package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; import org.apache.kafka.streams.state.internals.Murmur3; import org.junit.jupiter.api.Test; @@ -32,6 +36,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class SubscriptionResponseWrapperSerdeTest { private static final class NonNullableSerde implements Serde, Serializer, Deserializer { @@ -147,6 +155,47 @@ public void shouldThrowExceptionOnSerializeWhenDataVersionUnknown() { } } + @Test + public void shouldPassHeadersToUnderlyingSerializer() { + final Serializer mockSerializer = mock(StringSerializer.class); + final Serde mockSerde = mock(Serdes.StringSerde.class); + when(mockSerde.serializer()).thenReturn(mockSerializer); + + final String topic = "dummy"; + final String foreignValue = "foreignValue"; + final Headers headers = new RecordHeaders().add("key", "value".getBytes()); + final SubscriptionResponseWrapper data = new SubscriptionResponseWrapper<>(null, foreignValue, 1); + + final SubscriptionResponseWrapperSerde testSerde = new SubscriptionResponseWrapperSerde<>(mockSerde); + + testSerde.serializer().serialize(topic, headers, data); + + verify(mockSerializer).serialize(topic, headers, foreignValue); + verify(mockSerializer, never()).serialize(topic, foreignValue); + } + + @Test + public void shouldPassHeadersToUnderlyingDeserializer() { + final Deserializer mockDeserializer = mock(StringDeserializer.class); + final Serde mockSerde = mock(Serdes.StringSerde.class); + when(mockSerde.deserializer()).thenReturn(mockDeserializer); + when(mockSerde.serializer()).thenReturn(Serdes.String().serializer()); + + final String topic = "dummy"; + final String foreignValue = "foreignValue"; + final Headers headers = new RecordHeaders().add("key", "value".getBytes()); + final SubscriptionResponseWrapper data = new SubscriptionResponseWrapper<>(null, foreignValue, 1); + + final SubscriptionResponseWrapperSerde testSerde = new SubscriptionResponseWrapperSerde<>(mockSerde); + + final byte[] serializedData = testSerde.serializer().serialize(topic, headers, data); + + testSerde.deserializer().deserialize(topic, headers, serializedData); + + verify(mockDeserializer).deserialize(topic, headers, foreignValue.getBytes()); + verify(mockDeserializer, never()).deserialize(topic, foreignValue.getBytes()); + } + public static class InvalidSubscriptionResponseWrapper extends SubscriptionResponseWrapper { public InvalidSubscriptionResponseWrapper(final long[] originalValueHash, diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializerTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializerTest.java new file mode 100644 index 0000000000000..bc9eca2545f2b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializerTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; + +import org.junit.jupiter.api.Test; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class LeftOrRightValueDeserializerTest { + + @Test + public void shouldPassHeadersToUnderlyingDeserializer() { + final Deserializer mockDeserializer = mock(StringDeserializer.class); + + final String topic = "dummy"; + final String value = "some-string"; + final Headers headers = new RecordHeaders().add("key", "value".getBytes()); + final LeftOrRightValue data = LeftOrRightValue.makeLeftValue(value); + final byte[] serializedBytes = new LeftOrRightValueSerializer<>(Serdes.String().serializer(), null).serialize(topic, headers, data); + + when(mockDeserializer.deserialize(topic, headers, value.getBytes())).thenReturn("dummy-value"); + + final LeftOrRightValueDeserializer testDeserializer = new LeftOrRightValueDeserializer<>(mockDeserializer, null); + + testDeserializer.deserialize(topic, headers, serializedBytes); + + verify(mockDeserializer).deserialize(topic, headers, value.getBytes()); + verify(mockDeserializer, never()).deserialize(topic, value.getBytes()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializerTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializerTest.java index 2a5aa5c891c67..fc177d2c7a0ba 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializerTest.java @@ -16,7 +16,11 @@ */ package org.apache.kafka.streams.state.internals; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; import org.junit.jupiter.api.Test; @@ -24,12 +28,14 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; public class LeftOrRightValueSerializerTest { private static final String TOPIC = "some-topic"; - private static final LeftOrRightValueSerde STRING_OR_INTEGER_SERDE = - new LeftOrRightValueSerde<>(Serdes.String(), Serdes.Integer()); + private static final LeftOrRightValueSerde STRING_OR_INTEGER_SERDE = new LeftOrRightValueSerde<>(Serdes.String(), Serdes.Integer()); @Test public void shouldSerializeStringValue() { @@ -37,13 +43,11 @@ public void shouldSerializeStringValue() { final LeftOrRightValue leftOrRightValue = LeftOrRightValue.makeLeftValue(value); - final byte[] serialized = - STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, leftOrRightValue); + final byte[] serialized = STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, leftOrRightValue); assertThat(serialized, is(notNullValue())); - final LeftOrRightValue deserialized = - STRING_OR_INTEGER_SERDE.deserializer().deserialize(TOPIC, serialized); + final LeftOrRightValue deserialized = STRING_OR_INTEGER_SERDE.deserializer().deserialize(TOPIC, serialized); assertThat(deserialized, is(leftOrRightValue)); } @@ -54,13 +58,11 @@ public void shouldSerializeIntegerValue() { final LeftOrRightValue leftOrRightValue = LeftOrRightValue.makeRightValue(value); - final byte[] serialized = - STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, leftOrRightValue); + final byte[] serialized = STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, leftOrRightValue); assertThat(serialized, is(notNullValue())); - final LeftOrRightValue deserialized = - STRING_OR_INTEGER_SERDE.deserializer().deserialize(TOPIC, serialized); + final LeftOrRightValue deserialized = STRING_OR_INTEGER_SERDE.deserializer().deserialize(TOPIC, serialized); assertThat(deserialized, is(leftOrRightValue)); } @@ -76,4 +78,21 @@ public void shouldThrowIfSerializeOtherValueAsNull() { assertThrows(NullPointerException.class, () -> STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, LeftOrRightValue.makeRightValue(null))); } + + @Test + public void shouldPassHeadersToUnderlyingSerializer() { + final Serializer mockSerializer = mock(StringSerializer.class); + + final String topic = "dummy"; + final String value = "some-string"; + final Headers headers = new RecordHeaders().add("key", "value".getBytes()); + final LeftOrRightValue data = LeftOrRightValue.makeLeftValue(value); + + final LeftOrRightValueSerializer testSerializer = new LeftOrRightValueSerializer<>(mockSerializer, null); + + testSerializer.serialize(topic, headers, data); + + verify(mockSerializer).serialize(topic, headers, value); + verify(mockSerializer, never()).serialize(topic, value); + } } diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializerTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializerTest.java new file mode 100644 index 0000000000000..059d482f8ad9a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializerTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; + +import org.junit.jupiter.api.Test; +import org.mockito.MockedConstruction; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TimestampedKeyAndJoinSideDeserializerTest { + + @Test + public void shouldPassHeadersToUnderlyingDeserializer() { + try (MockedConstruction timestampSerializer = mockConstruction(LongDeserializer.class)) { + final Deserializer mockDeserializer = mock(StringDeserializer.class); + final TimestampedKeyAndJoinSideDeserializer testDeserializer = new TimestampedKeyAndJoinSideDeserializer<>(mockDeserializer); + final Deserializer innerTimestampDeserializer = timestampSerializer.constructed().get(0); + + final String topic = "dummy"; + final String key = "some-key"; + final long timestamp = 10; + final Headers headers = new RecordHeaders().add("key", "value".getBytes()); + final TimestampedKeyAndJoinSide data = TimestampedKeyAndJoinSide.makeLeft(key, timestamp); + final byte[] serializedValue = new TimestampedKeyAndJoinSideSerializer<>(Serdes.String().serializer()).serialize(topic, headers, data); + + when(mockDeserializer.deserialize(topic, headers, key.getBytes())).thenReturn(key); + when(innerTimestampDeserializer.deserialize(eq(topic), eq(headers), any(byte[].class))).thenReturn(timestamp); + + testDeserializer.deserialize(topic, headers, serializedValue); + + verify(mockDeserializer).deserialize(topic, headers, key.getBytes()); + verify(mockDeserializer, never()).deserialize(topic, key.getBytes()); + + verify(innerTimestampDeserializer).deserialize(eq(topic), eq(headers), any(byte[].class)); + verify(innerTimestampDeserializer, never()).deserialize(eq(topic), any(byte[].class)); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializerTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializerTest.java index 81d5736015ac8..f544f0a04a396 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializerTest.java @@ -16,20 +16,30 @@ */ package org.apache.kafka.streams.state.internals; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.LongSerializer; import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; import org.junit.jupiter.api.Test; +import org.mockito.MockedConstruction; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class TimestampedKeyAndJoinSideSerializerTest { private static final String TOPIC = "some-topic"; - private static final TimestampedKeyAndJoinSideSerde STRING_SERDE = - new TimestampedKeyAndJoinSideSerde<>(Serdes.String()); + private static final TimestampedKeyAndJoinSideSerde STRING_SERDE = new TimestampedKeyAndJoinSideSerde<>(Serdes.String()); @Test public void shouldSerializeKeyWithJoinSideAsTrue() { @@ -37,13 +47,11 @@ public void shouldSerializeKeyWithJoinSideAsTrue() { final TimestampedKeyAndJoinSide timestampedKeyAndJoinSide = TimestampedKeyAndJoinSide.makeLeft(value, 10); - final byte[] serialized = - STRING_SERDE.serializer().serialize(TOPIC, timestampedKeyAndJoinSide); + final byte[] serialized = STRING_SERDE.serializer().serialize(TOPIC, timestampedKeyAndJoinSide); assertThat(serialized, is(notNullValue())); - final TimestampedKeyAndJoinSide deserialized = - STRING_SERDE.deserializer().deserialize(TOPIC, serialized); + final TimestampedKeyAndJoinSide deserialized = STRING_SERDE.deserializer().deserialize(TOPIC, serialized); assertThat(deserialized, is(timestampedKeyAndJoinSide)); } @@ -54,13 +62,11 @@ public void shouldSerializeKeyWithJoinSideAsFalse() { final TimestampedKeyAndJoinSide timestampedKeyAndJoinSide = TimestampedKeyAndJoinSide.makeRight(value, 20); - final byte[] serialized = - STRING_SERDE.serializer().serialize(TOPIC, timestampedKeyAndJoinSide); + final byte[] serialized = STRING_SERDE.serializer().serialize(TOPIC, timestampedKeyAndJoinSide); assertThat(serialized, is(notNullValue())); - final TimestampedKeyAndJoinSide deserialized = - STRING_SERDE.deserializer().deserialize(TOPIC, serialized); + final TimestampedKeyAndJoinSide deserialized = STRING_SERDE.deserializer().deserialize(TOPIC, serialized); assertThat(deserialized, is(timestampedKeyAndJoinSide)); } @@ -70,4 +76,30 @@ public void shouldThrowIfSerializeNullData() { assertThrows(NullPointerException.class, () -> STRING_SERDE.serializer().serialize(TOPIC, TimestampedKeyAndJoinSide.makeLeft(null, 0))); } + + @Test + public void shouldPassHeadersToUnderlyingSerializer() { + try (MockedConstruction timestampSerializer = mockConstruction(LongSerializer.class)) { + final Serializer mockSerializer = mock(StringSerializer.class); + final TimestampedKeyAndJoinSideSerializer testSerializer = new TimestampedKeyAndJoinSideSerializer<>(mockSerializer); + final Serializer innerTimestampSerializer = timestampSerializer.constructed().get(0); + + final String topic = "dummy"; + final String key = "some-key"; + final long timestamp = 10; + final Headers headers = new RecordHeaders().add("key", "value".getBytes()); + final TimestampedKeyAndJoinSide data = TimestampedKeyAndJoinSide.makeLeft(key, timestamp); + + when(mockSerializer.serialize(topic, headers, data.key())).thenReturn(key.getBytes()); + when(innerTimestampSerializer.serialize(topic, headers, data.timestamp())).thenReturn(new byte[]{Byte.MAX_VALUE}); + + testSerializer.serialize(topic, headers, data); + + verify(mockSerializer).serialize(topic, headers, key); + verify(mockSerializer, never()).serialize(topic, key); + + verify(innerTimestampSerializer).serialize(topic, headers, timestamp); + verify(innerTimestampSerializer, never()).serialize(topic, timestamp); + } + } }