Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions streams/src/main/java/org/apache/kafka/streams/Topology.java
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,10 @@ public synchronized <K, V, S extends StateStore> Topology addReadOnlyStateStore(
storeBuilder.withLoggingDisabled();
internalTopologyBuilder.connectSourceStoreAndTopic(storeBuilder.name(), topic);

// register reprocess factory so that restoration also goes through the custom processor
internalTopologyBuilder.registerReadOnlyStoreReprocessFactory(
storeBuilder.name(), stateUpdateSupplier, keyDeserializer, valueDeserializer, processorName);

return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ public Collection<StreamTask> createTasks(final Consumer<byte[], byte[]> consume
stateDirectory,
topology.storeToChangelogTopic(),
partitions,
upgradeFrom);
upgradeFrom,
topology.storeNameToReprocessOnRestore());

final InternalProcessorContext<Object, Object> context = new ProcessorContextImpl(
taskId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ public static class ReprocessFactory<KIn, VIn, KOut, VOut> {
private final Deserializer<VIn> valueDeserializer;
private final String processorName;

private ReprocessFactory(final ProcessorSupplier<KIn, VIn, KOut, VOut> processorSupplier,
// package-private for testing
ReprocessFactory(final ProcessorSupplier<KIn, VIn, KOut, VOut> processorSupplier,
final Deserializer<KIn> key,
final Deserializer<VIn> value,
final String processorName) {
Expand Down Expand Up @@ -751,6 +752,15 @@ public String storeForChangelogTopic(final String topicName) {
return changelogTopicToStore.get(topicName);
}

/**
 * Registers a {@link ReprocessFactory} for the given read-only state store so that
 * changelog restoration replays records through the user-supplied processor rather
 * than writing the raw changelog bytes straight into the store.
 *
 * @param storeName         name of the read-only state store
 * @param processorSupplier supplier of the processor used to (re)process records
 * @param keyDeserializer   deserializer for changelog record keys
 * @param valueDeserializer deserializer for changelog record values
 * @param processorName     name of the processor node attached to the store
 */
public <KIn, VIn> void registerReadOnlyStoreReprocessFactory(final String storeName,
                                                             final ProcessorSupplier<KIn, VIn, Void, Void> processorSupplier,
                                                             final Deserializer<KIn> keyDeserializer,
                                                             final Deserializer<VIn> valueDeserializer,
                                                             final String processorName) {
    final ReprocessFactory<KIn, VIn, Void, Void> factory =
        new ReprocessFactory<>(processorSupplier, keyDeserializer, valueDeserializer, processorName);
    storeNameToReprocessOnRestore.put(storeName, Optional.of(factory));
}

public void connectSourceStoreAndTopic(final String sourceStoreName,
final String topic) {
if (storeToChangelogTopic.containsKey(sourceStoreName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import org.apache.kafka.streams.processor.internals.Task.TaskType;
import org.apache.kafka.streams.state.internals.CachedStateStore;
import org.apache.kafka.streams.state.internals.LegacyCheckpointingStateStore;
import org.apache.kafka.streams.processor.api.Processor;
import org.apache.kafka.streams.processor.api.ProcessorContext;
import org.apache.kafka.streams.processor.api.Record;
import org.apache.kafka.streams.state.internals.RecordConverter;
import org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBuffer;

Expand All @@ -47,6 +50,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -183,10 +187,13 @@ public String toString() {
private final StateDirectory stateDirectory;
private final File baseDir;
private final UpgradeFromValues upgradeFrom;
private final Map<String, Optional<InternalTopologyBuilder.ReprocessFactory<?, ?, ?, ?>>> storeNameToReprocessOnRestore;

private TaskType taskType;
private Logger log;
private Task.State taskState;
private InternalProcessorContext<?, ?> processorContext;
private final Map<String, Processor<?, ?, ?, ?>> reprocessorCache = new HashMap<>();

public static String storeChangelogTopic(final String prefix, final String storeName, final String namedTopology) {
if (namedTopology == null) {
Expand All @@ -207,6 +214,19 @@ public ProcessorStateManager(final TaskId taskId,
final Map<String, String> storeToChangelogTopic,
final Collection<TopicPartition> sourcePartitions,
final UpgradeFromValues upgradeFrom) throws ProcessorStateException {
this(taskId, taskType, eosEnabled, logContext, stateDirectory, storeToChangelogTopic,
sourcePartitions, upgradeFrom, Collections.emptyMap());
}

public ProcessorStateManager(final TaskId taskId,
final TaskType taskType,
final boolean eosEnabled,
final LogContext logContext,
final StateDirectory stateDirectory,
final Map<String, String> storeToChangelogTopic,
final Collection<TopicPartition> sourcePartitions,
final UpgradeFromValues upgradeFrom,
final Map<String, Optional<InternalTopologyBuilder.ReprocessFactory<?, ?, ?, ?>>> storeNameToReprocessOnRestore) throws ProcessorStateException {
this.storeToChangelogTopic = storeToChangelogTopic;
this.log = logContext.logger(ProcessorStateManager.class);
this.logPrefix = logContext.logPrefix();
Expand All @@ -215,6 +235,7 @@ public ProcessorStateManager(final TaskId taskId,
this.eosEnabled = eosEnabled;
this.sourcePartitions = sourcePartitions;
this.upgradeFrom = upgradeFrom;
this.storeNameToReprocessOnRestore = storeNameToReprocessOnRestore;

this.baseDir = stateDirectory.getOrCreateDirectoryForTask(taskId);
this.stateDirectory = stateDirectory;
Expand Down Expand Up @@ -252,6 +273,7 @@ static ProcessorStateManager createStartupTaskStateManager(final TaskId taskId,
}

void registerStateStores(final List<StateStore> allStores, final InternalProcessorContext<?, ?> processorContext) {
this.processorContext = processorContext;
processorContext.uninitialize();
final Map<TopicPartition, StateStore> storesToMigrate = new HashMap<>(stores.size());
for (final StateStore store : allStores) {
Expand Down Expand Up @@ -460,6 +482,7 @@ StateStoreMetadata storeMetadata(final TopicPartition partition) {
}

// used by the changelog reader only
@SuppressWarnings({"rawtypes", "unchecked"})
void restore(final StateStoreMetadata storeMetadata, final List<ConsumerRecord<byte[], byte[]>> restoreRecords, final OptionalLong optionalLag) {
if (!stores.containsValue(storeMetadata)) {
throw new IllegalStateException("Restoring " + storeMetadata + " which is not registered in this state manager, " +
Expand All @@ -469,18 +492,27 @@ void restore(final StateStoreMetadata storeMetadata, final List<ConsumerRecord<b
if (!restoreRecords.isEmpty()) {
// restore states from changelog records and update the snapshot offset as the batch end record's offset
final Long batchEndOffset = restoreRecords.get(restoreRecords.size() - 1).offset();
final RecordBatchingStateRestoreCallback restoreCallback = adapt(storeMetadata.restoreCallback);
final List<ConsumerRecord<byte[], byte[]>> convertedRecords = restoreRecords.stream()
.map(storeMetadata.recordConverter::convert)
.collect(Collectors.toList());

try {
restoreCallback.restoreBatch(convertedRecords);
} catch (final RuntimeException e) {
throw new ProcessorStateException(
format("%sException caught while trying to restore state from %s", logPrefix, storeMetadata.changelogPartition),
e
);
final String storeName = storeMetadata.store().name();
final Optional<InternalTopologyBuilder.ReprocessFactory<?, ?, ?, ?>> reprocessFactory =
storeNameToReprocessOnRestore.getOrDefault(storeName, Optional.empty());

if (reprocessFactory.isPresent() && processorContext != null) {
reprocessRestore(storeMetadata, restoreRecords, reprocessFactory.get());
} else {
final RecordBatchingStateRestoreCallback restoreCallback = adapt(storeMetadata.restoreCallback);
final List<ConsumerRecord<byte[], byte[]>> convertedRecords = restoreRecords.stream()
.map(storeMetadata.recordConverter::convert)
.collect(Collectors.toList());

try {
restoreCallback.restoreBatch(convertedRecords);
} catch (final RuntimeException e) {
throw new ProcessorStateException(
format("%sException caught while trying to restore state from %s", logPrefix, storeMetadata.changelogPartition),
e
);
}
}

storeMetadata.setOffset(batchEndOffset);
Expand All @@ -494,6 +526,44 @@ void restore(final StateStoreMetadata storeMetadata, final List<ConsumerRecord<b
}
}

/**
 * Restores a batch of changelog records for a read-only store by replaying each record
 * through the user-registered processor (from the {@code ReprocessFactory}) instead of
 * applying the raw bytes via the store's restore callback.
 *
 * <p>The processor instance is created lazily per store name, initialized once with this
 * manager's {@code processorContext}, and cached in {@code reprocessorCache} so that
 * subsequent restore batches reuse it.
 *
 * @throws ProcessorStateException if deserialization or processing of a record fails
 */
@SuppressWarnings({"rawtypes", "unchecked"})
private void reprocessRestore(final StateStoreMetadata storeMetadata,
final List<ConsumerRecord<byte[], byte[]>> restoreRecords,
final InternalTopologyBuilder.ReprocessFactory reprocessFactory) {
final String storeName = storeMetadata.store().name();
// Lazily create and init the reprocess processor once per store; reused across batches.
final Processor processor = reprocessorCache.computeIfAbsent(storeName, k -> {
final Processor p = reprocessFactory.processorSupplier().get();
p.init((ProcessorContext) processorContext);
return p;
});

for (final ConsumerRecord<byte[], byte[]> record : restoreRecords) {
// Apply the store's record converter first, same as the plain restore path.
final ConsumerRecord<byte[], byte[]> converted = storeMetadata.recordConverter.convert(record);
// Null-key records (e.g. tombstone-like entries with no key) are skipped entirely.
if (converted.key() != null) {
// Expose the changelog record's metadata to the processor via the context,
// so context.recordMetadata()/timestamp reflect the record being restored.
final ProcessorRecordContext recordContext = new ProcessorRecordContext(
converted.timestamp(),
converted.offset(),
converted.partition(),
converted.topic(),
converted.headers());
processorContext.setRecordContext(recordContext);

try {
final Object key = reprocessFactory.keyDeserializer().deserialize(converted.topic(), converted.key());
final Object value = reprocessFactory.valueDeserializer().deserialize(converted.topic(), converted.value());
// Clamp negative timestamps (e.g. unknown/-1) to 0 before handing to the processor.
final long timestamp = Math.max(0L, converted.timestamp());
processor.process(new Record<>(key, value, timestamp, converted.headers()));
} catch (final RuntimeException e) {
// Wrap per-record failures with partition context, mirroring restoreBatch error handling.
throw new ProcessorStateException(
format("%sException caught while trying to reprocess-restore state from %s",
logPrefix, storeMetadata.changelogPartition),
e
);
}
}
}
}

/**
* @throws TaskMigratedException recoverable error sending changelog records that would cause the task to be removed
* @throws StreamsException fatal error when committing the state store, for example sending changelog records failed
Expand Down Expand Up @@ -607,6 +677,16 @@ public void flushCache() {
public void close() throws ProcessorStateException {
log.debug("Closing its state manager and all the registered state stores: {}", stores);

// close any cached reprocess processors
for (final Processor<?, ?, ?, ?> processor : reprocessorCache.values()) {
try {
processor.close();
} catch (final RuntimeException e) {
log.warn("Failed to close reprocess processor: ", e);
}
}
reprocessorCache.clear();

final Map<TopicPartition, Long> allOffsets = new HashMap<>();
RuntimeException firstException = null;
// attempting to close the stores, just in case they
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ Collection<StandbyTask> createTasks(final Map<TaskId, Set<TopicPartition>> tasks
stateDirectory,
topology.storeToChangelogTopic(),
partitions,
upgradeFrom);
upgradeFrom,
topology.storeNameToReprocessOnRestore());

final InternalProcessorContext<?, ?> context = new ProcessorContextImpl(
taskId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,111 @@ public void process(final Record<Integer, String> record) {
assertThat(output.readKeyValuesToList(), equalTo(expectedResult));
}
}

@Test
public void shouldUseCustomProcessorDuringRestorationWithTransformation() {
    // Counts how many records flow through the custom processor.
    final java.util.concurrent.atomic.AtomicInteger invocations = new java.util.concurrent.atomic.AtomicInteger(0);

    final Topology topology = new Topology();
    topology.addReadOnlyStateStore(
        Stores.keyValueStoreBuilder(
            Stores.inMemoryKeyValueStore("readOnlyStore"),
            new Serdes.IntegerSerde(),
            new Serdes.StringSerde()
        ),
        "readOnlySource",
        new IntegerDeserializer(),
        new StringDeserializer(),
        "storeTopic",
        "readOnlyProcessor",
        () -> new Processor<>() {
            KeyValueStore<Integer, String> stateStore;

            @Override
            public void init(final ProcessorContext<Void, Void> context) {
                stateStore = context.getStateStore("readOnlyStore");
            }

            @Override
            public void process(final Record<Integer, String> record) {
                invocations.incrementAndGet();
                // Custom transformation: prepend "processed-" to the value
                stateStore.put(record.key(), "processed-" + record.value());
            }
        }
    );

    try (final TopologyTestDriver driver = new TopologyTestDriver(topology)) {
        final TestInputTopic<Integer, String> input =
            driver.createInputTopic("storeTopic", new IntegerSerializer(), new StringSerializer());

        input.pipeInput(1, "foo");
        input.pipeInput(2, "bar");

        final KeyValueStore<Integer, String> store = driver.getKeyValueStore("readOnlyStore");

        try (final KeyValueIterator<Integer, String> it = store.all()) {
            final List<KeyValue<Integer, String>> actual = new LinkedList<>();
            it.forEachRemaining(actual::add);

            // Values should have the "processed-" prefix from the custom processor
            final List<KeyValue<Integer, String>> expected = List.of(
                KeyValue.pair(1, "processed-foo"),
                KeyValue.pair(2, "processed-bar"));

            assertThat(actual, equalTo(expected));
        }

        // Verify the processor was actually called
        assertThat(invocations.get(), equalTo(2));
    }
}

// NOTE(review): despite its name, this test never pipes a null-key record — only (1, "value1")
// is produced, so the null-key skip branch in ProcessorStateManager#reprocessRestore is not
// actually exercised here. Consider adding a pipeInput with a null key (TODO: confirm how the
// TopologyTestDriver's non-restore path handles null keys for read-only store sources first),
// or renaming the test to reflect what it verifies.
@Test
public void shouldHandleNullKeyRecordsDuringReprocessRestore() {
// Counts invocations of the custom processor during piping/restoration.
final java.util.concurrent.atomic.AtomicInteger processCallCount = new java.util.concurrent.atomic.AtomicInteger(0);

final Topology topology = new Topology();
topology.addReadOnlyStateStore(
Stores.keyValueStoreBuilder(
Stores.inMemoryKeyValueStore("readOnlyStore"),
new Serdes.IntegerSerde(),
new Serdes.StringSerde()
),
"readOnlySource",
new IntegerDeserializer(),
new StringDeserializer(),
"storeTopic",
"readOnlyProcessor",
() -> new Processor<>() {
KeyValueStore<Integer, String> store;

@Override
public void init(final ProcessorContext<Void, Void> context) {
store = context.getStateStore("readOnlyStore");
}
@Override
public void process(final Record<Integer, String> record) {
processCallCount.incrementAndGet();
// Pass-through: store the record as-is (no transformation).
store.put(record.key(), record.value());
}
}
);

try (final TopologyTestDriver driver = new TopologyTestDriver(topology)) {
final TestInputTopic<Integer, String> readOnlyStoreTopic =
driver.createInputTopic("storeTopic", new IntegerSerializer(), new StringSerializer());

readOnlyStoreTopic.pipeInput(1, "value1");

final KeyValueStore<Integer, String> store = driver.getKeyValueStore("readOnlyStore");

try (final KeyValueIterator<Integer, String> it = store.all()) {
final List<KeyValue<Integer, String>> storeContent = new LinkedList<>();
it.forEachRemaining(storeContent::add);

// Only the single non-null-key record should be present in the store.
assertThat(storeContent.size(), equalTo(1));
assertThat(storeContent.get(0), equalTo(KeyValue.pair(1, "value1")));
}
}
}
}
Loading