diff --git a/streams/src/main/java/org/apache/kafka/streams/Topology.java b/streams/src/main/java/org/apache/kafka/streams/Topology.java
index e032abc346fd4..830e342735521 100644
--- a/streams/src/main/java/org/apache/kafka/streams/Topology.java
+++ b/streams/src/main/java/org/apache/kafka/streams/Topology.java
@@ -829,6 +829,10 @@ public synchronized <KIn, VIn> Topology addReadOnlyStateStore(
 
         storeBuilder.withLoggingDisabled();
         internalTopologyBuilder.connectSourceStoreAndTopic(storeBuilder.name(), topic);
+
+        // register reprocess factory so that restoration also goes through the custom processor
+        internalTopologyBuilder.registerReadOnlyStoreReprocessFactory(
+            storeBuilder.name(), stateUpdateSupplier, keyDeserializer, valueDeserializer, processorName);
 
         return this;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index 1d5b7fbf7ed01..bb57945d62adf 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -153,7 +153,8 @@ public Collection<Task> createTasks(final Consumer<byte[], byte[]> consumer,
                 stateDirectory,
                 topology.storeToChangelogTopic(),
                 partitions,
-                upgradeFrom);
+                upgradeFrom,
+                topology.storeNameToReprocessOnRestore());
 
             final InternalProcessorContext context = new ProcessorContextImpl(
                 taskId,
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java
index 86f1ebbe7a60d..1c125f3620687 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java
@@ -209,7 +209,8 @@ public static class ReprocessFactory<KIn, VIn> {
         private final Deserializer<VIn> valueDeserializer;
         private final String processorName;
 
-        private ReprocessFactory(final ProcessorSupplier<KIn, VIn, Void, Void> processorSupplier,
+        // package-private for testing
+        ReprocessFactory(final ProcessorSupplier<KIn, VIn, Void, Void> processorSupplier,
                                  final Deserializer<KIn> key,
                                  final Deserializer<VIn> value,
                                  final String processorName) {
@@ -751,6 +752,15 @@ public String storeForChangelogTopic(final String topicName) {
         return changelogTopicToStore.get(topicName);
     }
 
+    public <KIn, VIn> void registerReadOnlyStoreReprocessFactory(final String storeName,
+                                                                 final ProcessorSupplier<KIn, VIn, Void, Void> processorSupplier,
+                                                                 final Deserializer<KIn> keyDeserializer,
+                                                                 final Deserializer<VIn> valueDeserializer,
+                                                                 final String processorName) {
+        storeNameToReprocessOnRestore.put(storeName,
+            Optional.of(new ReprocessFactory<>(processorSupplier, keyDeserializer, valueDeserializer, processorName)));
+    }
+
     public void connectSourceStoreAndTopic(final String sourceStoreName,
                                            final String topic) {
         if (storeToChangelogTopic.containsKey(sourceStoreName)) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index db03332079596..79d4aed9adc76 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -34,6 +34,9 @@
 import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.state.internals.CachedStateStore;
 import org.apache.kafka.streams.state.internals.LegacyCheckpointingStateStore;
+import org.apache.kafka.streams.processor.api.Processor;
+import org.apache.kafka.streams.processor.api.ProcessorContext;
+import org.apache.kafka.streams.processor.api.Record;
 import org.apache.kafka.streams.state.internals.RecordConverter;
 import org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBuffer;
@@ -47,6 +50,7 @@
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -183,10 +187,13 @@ public String toString() {
     private final StateDirectory stateDirectory;
     private final File baseDir;
     private final UpgradeFromValues upgradeFrom;
+    private final Map<String, Optional<InternalTopologyBuilder.ReprocessFactory<?, ?>>> storeNameToReprocessOnRestore;
 
     private TaskType taskType;
     private Logger log;
     private Task.State taskState;
+    private InternalProcessorContext processorContext;
+    private final Map<String, Processor<Object, Object, Void, Void>> reprocessorCache = new HashMap<>();
 
     public static String storeChangelogTopic(final String prefix, final String storeName, final String namedTopology) {
         if (namedTopology == null) {
@@ -207,6 +214,19 @@ public ProcessorStateManager(final TaskId taskId,
                                  final Map<String, String> storeToChangelogTopic,
                                  final Collection<TopicPartition> sourcePartitions,
                                  final UpgradeFromValues upgradeFrom) throws ProcessorStateException {
+        this(taskId, taskType, eosEnabled, logContext, stateDirectory, storeToChangelogTopic,
+            sourcePartitions, upgradeFrom, Collections.emptyMap());
+    }
+
+    public ProcessorStateManager(final TaskId taskId,
+                                 final TaskType taskType,
+                                 final boolean eosEnabled,
+                                 final LogContext logContext,
+                                 final StateDirectory stateDirectory,
+                                 final Map<String, String> storeToChangelogTopic,
+                                 final Collection<TopicPartition> sourcePartitions,
+                                 final UpgradeFromValues upgradeFrom,
+                                 final Map<String, Optional<InternalTopologyBuilder.ReprocessFactory<?, ?>>> storeNameToReprocessOnRestore) throws ProcessorStateException {
         this.storeToChangelogTopic = storeToChangelogTopic;
         this.log = logContext.logger(ProcessorStateManager.class);
         this.logPrefix = logContext.logPrefix();
@@ -215,6 +235,7 @@ public ProcessorStateManager(final TaskId taskId,
         this.eosEnabled = eosEnabled;
         this.sourcePartitions = sourcePartitions;
         this.upgradeFrom = upgradeFrom;
+        this.storeNameToReprocessOnRestore = storeNameToReprocessOnRestore;
 
         this.baseDir = stateDirectory.getOrCreateDirectoryForTask(taskId);
         this.stateDirectory = stateDirectory;
@@ -252,6 +273,7 @@ static ProcessorStateManager createStartupTaskStateManager(final TaskId taskId,
     }
 
     void registerStateStores(final List<StateStore> allStores, final InternalProcessorContext processorContext) {
+        this.processorContext = processorContext;
         processorContext.uninitialize();
         final Map<String, StateStore> storesToMigrate = new HashMap<>(stores.size());
         for (final StateStore store : allStores) {
@@ -460,6 +482,7 @@ StateStoreMetadata storeMetadata(final TopicPartition partition) {
     }
 
     // used by the changelog reader only
+    @SuppressWarnings({"rawtypes", "unchecked"})
     void restore(final StateStoreMetadata storeMetadata, final List<ConsumerRecord<byte[], byte[]>> restoreRecords, final OptionalLong optionalLag) {
         if (!stores.containsValue(storeMetadata)) {
             throw new IllegalStateException("Restoring " + storeMetadata + " which is not registered in this state manager, " +
@@ -469,18 +492,27 @@ void restore(final StateStoreMetadata storeMetadata, final List<ConsumerRecord<byte[], byte[]>> restoreRecords, final OptionalLong optionalLag) {
         }
 
         final long batchEndOffset = restoreRecords.get(restoreRecords.size() - 1).offset();
 
-        final RecordBatchingStateRestoreCallback restoreCallback = adapt(storeMetadata.restoreCallback);
-        final List<ConsumerRecord<byte[], byte[]>> convertedRecords = restoreRecords.stream()
-            .map(storeMetadata.recordConverter::convert)
-            .collect(Collectors.toList());
-        try {
-            restoreCallback.restoreBatch(convertedRecords);
-        } catch (final RuntimeException e) {
-            throw new ProcessorStateException(
-                format("%sException caught while trying to restore state from %s", logPrefix, storeMetadata.changelogPartition),
-                e
-            );
+        final String storeName = storeMetadata.store().name();
+        final Optional<InternalTopologyBuilder.ReprocessFactory<?, ?>> reprocessFactory =
+            storeNameToReprocessOnRestore.getOrDefault(storeName, Optional.empty());
+
+        if (reprocessFactory.isPresent() && processorContext != null) {
+            reprocessRestore(storeMetadata, restoreRecords, reprocessFactory.get());
+        } else {
+            final RecordBatchingStateRestoreCallback restoreCallback = adapt(storeMetadata.restoreCallback);
+            final List<ConsumerRecord<byte[], byte[]>> convertedRecords = restoreRecords.stream()
+                .map(storeMetadata.recordConverter::convert)
+                .collect(Collectors.toList());
+
+            try {
+                restoreCallback.restoreBatch(convertedRecords);
+            } catch (final RuntimeException e) {
+                throw new ProcessorStateException(
+                    format("%sException caught while trying to restore state from %s", logPrefix, storeMetadata.changelogPartition),
+                    e
+                );
+            }
         }
 
         storeMetadata.setOffset(batchEndOffset);
@@ -494,6 +526,44 @@ void restore(final StateStoreMetadata storeMetadata, final List<ConsumerRecord<byte[], byte[]>> restoreRecords, final OptionalLong optionalLag) {
         }
     }
 
+
+    private void reprocessRestore(final StateStoreMetadata storeMetadata,
+                                  final List<ConsumerRecord<byte[], byte[]>> restoreRecords,
+                                  final InternalTopologyBuilder.ReprocessFactory<?, ?> reprocessFactory) {
+        final String storeName = storeMetadata.store().name();
+        final Processor processor = reprocessorCache.computeIfAbsent(storeName, k -> {
+            final Processor p = reprocessFactory.processorSupplier().get();
+            p.init((ProcessorContext) processorContext);
+            return p;
+        });
+
+        for (final ConsumerRecord<byte[], byte[]> record : restoreRecords) {
+            final ConsumerRecord<byte[], byte[]> converted = storeMetadata.recordConverter.convert(record);
+            if (converted.key() != null) {
+                final ProcessorRecordContext recordContext = new ProcessorRecordContext(
+                    converted.timestamp(),
+                    converted.offset(),
+                    converted.partition(),
+                    converted.topic(),
+                    converted.headers());
+                processorContext.setRecordContext(recordContext);
+
+                try {
+                    final Object key = reprocessFactory.keyDeserializer().deserialize(converted.topic(), converted.key());
+                    final Object value = reprocessFactory.valueDeserializer().deserialize(converted.topic(), converted.value());
+                    final long timestamp = Math.max(0L, converted.timestamp());
+                    processor.process(new Record<>(key, value, timestamp, converted.headers()));
+                } catch (final RuntimeException e) {
+                    throw new ProcessorStateException(
+                        format("%sException caught while trying to reprocess-restore state from %s",
+                            logPrefix, storeMetadata.changelogPartition),
+                        e
+                    );
+                }
+            }
+        }
+    }
+
     /**
      * @throws TaskMigratedException recoverable error sending changelog records that would cause the task to be removed
      * @throws StreamsException fatal error when committing the state store, for example sending changelog records failed
@@ -607,6 +677,16 @@ public void flushCache() {
     public void close() throws ProcessorStateException {
         log.debug("Closing its state manager and all the registered state stores: {}", stores);
 
+        // close any cached reprocess processors
+        for (final Processor processor : reprocessorCache.values()) {
+            try {
+                processor.close();
+            } catch (final RuntimeException e) {
+                log.warn("Failed to close reprocess processor: ", e);
+            }
+        }
+        reprocessorCache.clear();
+
         final Map<TopicPartition, Long> allOffsets = new HashMap<>();
         RuntimeException firstException = null;
         // attempting to close the stores, just in case they
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
index 139efbd63de9f..78e0e336a3317 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -86,7 +86,8 @@ Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
                 stateDirectory,
                 topology.storeToChangelogTopic(),
                 partitions,
-                upgradeFrom);
+                upgradeFrom,
+                topology.storeNameToReprocessOnRestore());
 
             final InternalProcessorContext context = new ProcessorContextImpl(
                 taskId,
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/ReadOnlyStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/ReadOnlyStoreTest.java
index 3b6393399024b..0a92a82f92e5d 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/ReadOnlyStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/ReadOnlyStoreTest.java
@@ -130,4 +130,111 @@ public void process(final Record<Integer, String> record) {
             assertThat(output.readKeyValuesToList(), equalTo(expectedResult));
         }
     }
+
+    @Test
+    public void shouldUseCustomProcessorDuringRestorationWithTransformation() {
+        final java.util.concurrent.atomic.AtomicInteger processCallCount = new java.util.concurrent.atomic.AtomicInteger(0);
+
+        final Topology topology = new Topology();
+        topology.addReadOnlyStateStore(
+            Stores.keyValueStoreBuilder(
+                Stores.inMemoryKeyValueStore("readOnlyStore"),
+                new Serdes.IntegerSerde(),
+                new Serdes.StringSerde()
+            ),
+            "readOnlySource",
+            new IntegerDeserializer(),
+            new StringDeserializer(),
+            "storeTopic",
+            "readOnlyProcessor",
+            () -> new Processor<>() {
+                KeyValueStore<Integer, String> store;
+
+                @Override
+                public void init(final ProcessorContext<Void, Void> context) {
+                    store = context.getStateStore("readOnlyStore");
+                }
+                @Override
+                public void process(final Record<Integer, String> record) {
+                    processCallCount.incrementAndGet();
+                    // Custom transformation: prepend "processed-" to the value
+                    store.put(record.key(), "processed-" + record.value());
+                }
+            }
+        );
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(topology)) {
+            final TestInputTopic<Integer, String> readOnlyStoreTopic =
+                driver.createInputTopic("storeTopic", new IntegerSerializer(), new StringSerializer());
+
+            readOnlyStoreTopic.pipeInput(1, "foo");
+            readOnlyStoreTopic.pipeInput(2, "bar");
+
+            final KeyValueStore<Integer, String> store = driver.getKeyValueStore("readOnlyStore");
+
+            try (final KeyValueIterator<Integer, String> it = store.all()) {
+                final List<KeyValue<Integer, String>> storeContent = new LinkedList<>();
+                it.forEachRemaining(storeContent::add);
+
+                // Values should have the "processed-" prefix from the custom processor
+                final List<KeyValue<Integer, String>> expectedResult = new LinkedList<>();
+                expectedResult.add(KeyValue.pair(1, "processed-foo"));
+                expectedResult.add(KeyValue.pair(2, "processed-bar"));
+
+                assertThat(storeContent, equalTo(expectedResult));
+            }
+
+            // Verify the processor was actually called
+            assertThat(processCallCount.get(), equalTo(2));
+        }
+    }
+
+    @Test
+    public void shouldHandleNullKeyRecordsDuringReprocessRestore() {
+        final java.util.concurrent.atomic.AtomicInteger processCallCount = new java.util.concurrent.atomic.AtomicInteger(0);
+
+        final Topology topology = new Topology();
+        topology.addReadOnlyStateStore(
+            Stores.keyValueStoreBuilder(
+                Stores.inMemoryKeyValueStore("readOnlyStore"),
+                new Serdes.IntegerSerde(),
+                new Serdes.StringSerde()
+            ),
+            "readOnlySource",
+            new IntegerDeserializer(),
+            new StringDeserializer(),
+            "storeTopic",
+            "readOnlyProcessor",
+            () -> new Processor<>() {
+                KeyValueStore<Integer, String> store;
+
+                @Override
+                public void init(final ProcessorContext<Void, Void> context) {
+                    store = context.getStateStore("readOnlyStore");
+                }
+                @Override
+                public void process(final Record<Integer, String> record) {
+                    processCallCount.incrementAndGet();
+                    store.put(record.key(), record.value());
+                }
+            }
+        );
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(topology)) {
+            final TestInputTopic<Integer, String> readOnlyStoreTopic =
+                driver.createInputTopic("storeTopic", new IntegerSerializer(), new StringSerializer());
+
+            readOnlyStoreTopic.pipeInput(1, "value1");
+
+            final KeyValueStore<Integer, String> store = driver.getKeyValueStore("readOnlyStore");
+
+            try (final KeyValueIterator<Integer, String> it = store.all()) {
+                final List<KeyValue<Integer, String>> storeContent = new LinkedList<>();
+                it.forEachRemaining(storeContent::add);
+
+                assertThat(storeContent.size(), equalTo(1));
+                assertThat(storeContent.get(0), equalTo(KeyValue.pair(1, "value1")));
+            }
+        }
+    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
index 76a10cf192a8f..116ebd93f8bd4 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
@@ -868,6 +868,109 @@ public void shouldThrowIfRestoreCallbackThrows() {
         }
     }
 
+    @Test
+    public void shouldRestoreViaReprocessFactoryWhenPresent() {
+        final java.util.concurrent.atomic.AtomicInteger processedCount = new java.util.concurrent.atomic.AtomicInteger(0);
+        final java.util.List<String> processedKeys = new java.util.ArrayList<>();
+        final MockKeyValueStore store = new MockKeyValueStore(persistentStoreName, true);
+
+        final org.apache.kafka.streams.processor.api.ProcessorSupplier<String, String, Void, Void> processorSupplier =
+            () -> new org.apache.kafka.streams.processor.api.Processor<>() {
+                @Override
+                public void init(final org.apache.kafka.streams.processor.api.ProcessorContext<Void, Void> context) {
+                    // no-op: we'll write to the store directly
+                }
+
+                @Override
+                public void process(final org.apache.kafka.streams.processor.api.Record<String, String> record) {
+                    processedCount.incrementAndGet();
+                    processedKeys.add(record.key());
+                }
+            };
+
+        final org.apache.kafka.common.serialization.StringDeserializer stringDeserializer =
+            new org.apache.kafka.common.serialization.StringDeserializer();
+
+        final InternalTopologyBuilder.ReprocessFactory<?, ?> reprocessFactory =
+            new InternalTopologyBuilder.ReprocessFactory<>(processorSupplier, stringDeserializer, stringDeserializer, "testProcessor");
+
+        final ProcessorStateManager stateMgr = new ProcessorStateManager(
+            taskId,
+            Task.TaskType.ACTIVE,
+            false,
+            logContext,
+            stateDirectory,
+            mkMap(
+                mkEntry(persistentStoreName, persistentStoreTopicName),
+                mkEntry(persistentStoreTwoName, persistentStoreTwoTopicName),
+                mkEntry(nonPersistentStoreName, nonPersistentStoreTopicName)
+            ),
+            emptySet(),
+            null,
+            mkMap(mkEntry(persistentStoreName, java.util.Optional.of(reprocessFactory)))
+        );
+
+        try {
+            // Register store directly (like other tests) and set context
+            stateMgr.registerStore(store, store.stateRestoreCallback, null);
+            // set the processorContext so that reprocessRestore can use it
+            stateMgr.registerStateStores(java.util.Collections.emptyList(), context);
+
+            final StateStoreMetadata storeMetadataObj = stateMgr.storeMetadata(persistentStorePartition);
+            assertThat(storeMetadataObj, notNullValue());
+
+            final byte[] testKey = "myKey".getBytes(StandardCharsets.UTF_8);
+            final byte[] testValue = "myValue".getBytes(StandardCharsets.UTF_8);
+            final ConsumerRecord<byte[], byte[]> record =
+                new ConsumerRecord<>(persistentStoreTopicName, 1, 100L, 1000L,
+                    org.apache.kafka.common.record.TimestampType.CREATE_TIME,
+                    testKey.length, testValue.length, testKey, testValue,
+                    new org.apache.kafka.common.header.internals.RecordHeaders(),
+                    java.util.Optional.empty());
+
+            stateMgr.restore(storeMetadataObj, singletonList(record), OptionalLong.of(2L));
+
+            // verify the processor was called instead of the callback
+            assertEquals(1, processedCount.get());
+            assertEquals("myKey", processedKeys.get(0));
+        } finally {
+            stateMgr.close();
+        }
+    }
+
+    @Test
+    public void shouldFallbackToCallbackWhenNoReprocessFactory() {
+        final MockRestoreCallback restoreCallback = new MockRestoreCallback();
+        final ProcessorStateManager stateMgr = new ProcessorStateManager(
+            taskId,
+            Task.TaskType.ACTIVE,
+            false,
+            logContext,
+            stateDirectory,
+            mkMap(
+                mkEntry(persistentStoreName, persistentStoreTopicName),
+                mkEntry(persistentStoreTwoName, persistentStoreTwoTopicName),
+                mkEntry(nonPersistentStoreName, nonPersistentStoreTopicName)
+            ),
+            emptySet(),
+            null,
+            java.util.Collections.emptyMap()
+        );
+
+        try {
+            stateMgr.registerStore(persistentStore, restoreCallback, null);
+            final StateStoreMetadata storeMetadataObj = stateMgr.storeMetadata(persistentStorePartition);
+            assertThat(storeMetadataObj, notNullValue());
+
+            stateMgr.restore(storeMetadataObj, singletonList(consumerRecord), OptionalLong.of(2L));
+
+            // verify the restore callback was used (not the processor)
+            assertThat(restoreCallback.restored.size(), is(1));
+        } finally {
+            stateMgr.close();
+        }
+    }
+
     @Test
     public void shouldCommitGoodStoresEvenSomeThrowsException() {
         final AtomicBoolean committedStore = new AtomicBoolean(false);