Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions invokeai/app/invocations/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,28 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
title="Image Collection Primitive",
tags=["primitives", "image", "collection"],
category="primitives",
version="1.0.1",
version="1.0.2",
)
class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values"""

collection: list[ImageField] = InputField(description="The collection of image values")
collection: Optional[list[ImageField]] = InputField(
default=None,
description="An optional image collection to append to",
input=Input.Connection,
title="Collection",
ui_order=0,
)
images: Optional[list[ImageField]] = InputField(
default=None,
description="The images to append to the collection",
input=Input.Direct,
title="Images",
ui_order=1,
)

def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection)
return ImageCollectionOutput(collection=[*(self.collection or []), *(self.images or [])])


# endregion
Expand Down
33 changes: 28 additions & 5 deletions invokeai/frontend/web/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -32518,11 +32518,34 @@
}
],
"default": null,
"description": "The collection of image values",
"description": "An optional image collection to append to",
"field_kind": "input",
"input": "any",
"orig_required": true,
"title": "Collection"
"input": "connection",
"orig_default": null,
"orig_required": false,
"title": "Collection",
"ui_order": 0
},
"images": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/ImageField"
},
"type": "array"
},
{
"type": "null"
}
],
"default": null,
"description": "The images to append to the collection",
"field_kind": "input",
"input": "direct",
"orig_default": null,
"orig_required": false,
"title": "Images",
"ui_order": 1
},
"type": {
"const": "image_collection",
Expand All @@ -32536,7 +32559,7 @@
"tags": ["primitives", "image", "collection"],
"title": "Image Collection Primitive",
"type": "object",
"version": "1.0.1",
"version": "1.0.2",
"output": {
"$ref": "#/components/schemas/ImageCollectionOutput"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/storeHooks';
import { useGetNodesNeedUpdate } from 'features/nodes/hooks/useGetNodesNeedUpdate';
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
import { selectNodes } from 'features/nodes/store/selectors';
import { NodeUpdateError } from 'features/nodes/types/error';
import { selectEdges, selectNodes } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
import { getConnectedInputNames, getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
Expand All @@ -20,6 +19,7 @@ const useUpdateNodes = () => {

const updateNodes = useCallback(() => {
const nodes = selectNodes(store.getState());
const edges = selectEdges(store.getState());
const templates = $templates.get();

let unableToUpdateCount = 0;
Expand All @@ -35,17 +35,16 @@ const useUpdateNodes = () => {
return;
}
try {
const updatedNode = updateNode(node, template);
const connectedInputNames = getConnectedInputNames(node.id, edges);
const updatedNode = updateNode(node, template, { connectedInputNames });
store.dispatch(
nodesChanged([
{ type: 'remove', id: updatedNode.id },
{ type: 'add', item: updatedNode },
])
);
} catch (e) {
if (e instanceof NodeUpdateError) {
unableToUpdateCount++;
}
} catch {
unableToUpdateCount++;
}
});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { updateNode } from 'features/nodes/util/node/nodeUpdate';
import { describe, expect, it } from 'vitest';

const imageCollectionOutput = {
collection: {
fieldKind: 'output',
name: 'collection',
title: 'Collection',
description: 'The output images',
type: {
name: 'ImageField',
cardinality: 'COLLECTION',
batch: false,
},
ui_hidden: false,
},
} satisfies InvocationTemplate['outputs'];

const oldImageCollectionTemplate = {
title: 'Image Collection Primitive',
type: 'image_collection',
version: '1.0.1',
tags: ['primitives', 'image', 'collection'],
description: 'A collection of image primitive values',
outputType: 'image_collection_output',
inputs: {
collection: {
name: 'collection',
title: 'Collection',
required: false,
description: 'The collection of image values',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
type: {
name: 'ImageField',
cardinality: 'COLLECTION',
batch: false,
},
default: undefined,
},
},
outputs: imageCollectionOutput,
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'primitives',
} satisfies InvocationTemplate;

const oldestImageCollectionTemplate = {
...oldImageCollectionTemplate,
version: '1.0.0',
} satisfies InvocationTemplate;

const currentImageCollectionTemplate = {
...oldImageCollectionTemplate,
version: '1.0.2',
inputs: {
collection: {
name: 'collection',
title: 'Collection',
required: false,
description: 'An optional image collection to append to',
fieldKind: 'input',
input: 'connection',
ui_hidden: false,
type: {
name: 'ImageField',
cardinality: 'COLLECTION',
batch: false,
},
default: undefined,
},
images: {
name: 'images',
title: 'Images',
required: false,
description: 'The images to append to the collection',
fieldKind: 'input',
input: 'direct',
ui_hidden: false,
type: {
name: 'ImageField',
cardinality: 'COLLECTION',
batch: false,
},
default: undefined,
},
},
} satisfies InvocationTemplate;

describe('updateNode', () => {
it('moves old image_collection direct collection values to the new images field', () => {
const node = buildInvocationNode({ x: 0, y: 0 }, oldImageCollectionTemplate);
const images = [{ image_name: 'first' }, { image_name: 'second' }];
const collectionInput = node.data.inputs.collection;
if (!collectionInput) {
throw new Error('Expected collection input');
}
collectionInput.value = images;

const updated = updateNode(node, currentImageCollectionTemplate, { connectedInputNames: new Set() });

expect(updated.data.version).toBe('1.0.2');
expect(updated.data.inputs.images?.value).toEqual(images);
expect(updated.data.inputs.collection?.value).toEqual([]);
});

it('moves 1.0.0 image_collection direct collection values to the new images field', () => {
const node = buildInvocationNode({ x: 0, y: 0 }, oldestImageCollectionTemplate);
const images = [{ image_name: 'first' }];
const collectionInput = node.data.inputs.collection;
if (!collectionInput) {
throw new Error('Expected collection input');
}
collectionInput.value = images;

const updated = updateNode(node, currentImageCollectionTemplate, { connectedInputNames: new Set() });

expect(updated.data.version).toBe('1.0.2');
expect(updated.data.inputs.images?.value).toEqual(images);
expect(updated.data.inputs.collection?.value).toEqual([]);
});

it('preserves old image_collection direct collection values when collection is connected', () => {
const node = buildInvocationNode({ x: 0, y: 0 }, oldImageCollectionTemplate);
const images = [{ image_name: 'stale' }];
const collectionInput = node.data.inputs.collection;
if (!collectionInput) {
throw new Error('Expected collection input');
}
collectionInput.value = images;

const updated = updateNode(node, currentImageCollectionTemplate, {
connectedInputNames: new Set(['collection']),
});

expect(updated.data.inputs.images?.value).toBeUndefined();
expect(updated.data.inputs.collection?.value).toEqual(images);
});
});
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
import { deepClone } from 'common/util/deepClone';
import { satisfies } from 'compare-versions';
import { compare, satisfies } from 'compare-versions';
import { defaultsDeep, keys, pick } from 'es-toolkit/compat';
import { NodeUpdateError } from 'features/nodes/types/error';
import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
import { zParsedSemver } from 'features/nodes/types/semver';

import { buildInvocationNode } from './buildInvocationNode';

type ConnectedInputEdge = { type?: string; target: string; targetHandle?: string | null };

type UpdateNodeOptions = {
connectedInputNames: Set<string>;
};

export const getConnectedInputNames = (nodeId: string, edges: ConnectedInputEdge[]): Set<string> =>
new Set(
edges.flatMap((edge) =>
edge.type === 'default' && edge.target === nodeId && edge.targetHandle ? [edge.targetHandle] : []
)
);

export const getUpdatedFieldName = (node: InvocationNode, fieldName: string): string => {
if (node.data.type === 'image_collection' && fieldName === 'collection' && node.data.inputs.images) {
return 'images';
}
return fieldName;
};

export const getNeedsUpdate = (data: InvocationNodeData, template: InvocationTemplate): boolean => {
if (data.type !== template.type) {
return true;
Expand All @@ -29,6 +49,34 @@ const getMayUpdateNode = (node: InvocationNode, template: InvocationTemplate): b
return satisfies(node.data.version, `^${templateMajor}`);
};

export const migrateImageCollectionInputValues = (
node: InvocationNode,
options: UpdateNodeOptions & { sourceVersion?: string }
) => {
if (node.data.type !== 'image_collection') {
return;
}
if (options.sourceVersion && compare(options.sourceVersion, '1.0.2', '>=')) {
return;
}

const collection = node.data.inputs.collection;
const images = node.data.inputs.images;
if (!collection || !images || !Array.isArray(collection.value)) {
return;
}
if (Array.isArray(images.value) && images.value.length > 0) {
return;
}

if (options.connectedInputNames.has('collection')) {
return;
}

images.value = collection.value;
collection.value = [];
};

/**
* Updates a node to the latest version of its template:
* - Create a new node data object with the latest version of the template.
Expand All @@ -40,7 +88,11 @@ const getMayUpdateNode = (node: InvocationNode, template: InvocationTemplate): b
* @param template The invocation template to update to.
* @throws {NodeUpdateError} If the node is not an invocation node.
*/
export const updateNode = (node: InvocationNode, template: InvocationTemplate): InvocationNode => {
export const updateNode = (
node: InvocationNode,
template: InvocationTemplate,
options: UpdateNodeOptions
): InvocationNode => {
const mayUpdate = getMayUpdateNode(node, template);

if (!mayUpdate || node.data.type !== template.type) {
Expand All @@ -54,8 +106,10 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
// being valid. We rely on the template's major version to be majorly incremented if this kind of
// merge would result in an invalid node.
const clone = deepClone(node);
const sourceVersion = clone.data.version;
clone.data.version = template.version;
defaultsDeep(clone, defaults); // mutates!
migrateImageCollectionInputValues(clone, { ...options, sourceVersion });

// Remove any fields that are not in the template
clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs));
Expand Down
Loading
Loading