-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path__init__.py
More file actions
63 lines (46 loc) · 2.06 KB
/
__init__.py
File metadata and controls
63 lines (46 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
@title: EternalKernel PyTorch Nodes
@nickname: EternalKernel
@description: Comprehensive PyTorch nodes for ComfyUI - Neural network training, inference, and ML workflows with 35+ specialized nodes
"""
import sys
from types import ModuleType

# Node registries exported by this package. Both start out as two separate
# empty dicts; the package's own comment states that importing
# `pytorch_nodes` is what populates them.
NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = {}, {}

# Front-end asset directory advertised to the host.
# NOTE(review): presumably resolved relative to this package — confirm.
WEB_DIRECTORY = "./web"
def _drop_return_type_mismatches(result):
    """Return *result* with 'return_type_mismatch' errors filtered out.

    *result* is the tuple produced by the host's ``validate_inputs``:
    index 0 is the validity flag, index 1 the list of error entries, and
    index 2 is passed through unchanged.

    A successful result is returned as-is. For a failed result, every error
    entry that is a dict whose ``'type'`` is ``'return_type_mismatch'`` is
    removed. If nothing else remains, the result becomes a success
    ``(True, [], result[2])``; otherwise ``(False, remaining, result[2])``.
    Entries that are not dicts, or that have no ``'type'`` key, are kept as
    real errors rather than raising.
    """
    if result[0]:
        return result
    remaining = [
        error for error in result[1]
        if not (isinstance(error, dict)
                and error.get('type') == 'return_type_mismatch')
    ]
    if not remaining:
        # Only type mismatches were reported: treat the input as valid.
        return (True, [], result[2])
    return (False, remaining, result[2])


def patch_validate_inputs():
    """Wrap ``execution.validate_inputs`` so return-type mismatches don't fail.

    Looks up the already-imported ``execution`` module in ``sys.modules`` and
    replaces its ``validate_inputs`` attribute with a wrapper. The wrapper
    forwards all arguments unchanged and then drops 'return_type_mismatch'
    errors from the result (see ``_drop_return_type_mismatches``).

    If the module or the function is missing, a warning is printed and
    nothing is changed. Calling this more than once is a no-op after the
    first successful patch. If the current ``validate_inputs`` is a coroutine
    function, an async wrapper is installed so its result is awaited before
    filtering.

    NOTE(review): if the wrapped function caches its own results (e.g. in its
    ``validated`` argument), those cached entries still hold the unfiltered
    failure — confirm whether callers read them.
    """
    import functools
    import inspect

    execution_module = sys.modules.get('execution')
    if execution_module is None:
        print("Warning: 'execution' module not found. Patch not applied.")
        return

    original_validate_inputs = getattr(execution_module, 'validate_inputs', None)
    if original_validate_inputs is None:
        print("Warning: 'execution.validate_inputs' not found. Patch not applied.")
        return

    if getattr(original_validate_inputs, '_etk_patched', False):
        # Already wrapped (e.g. the package was imported twice); wrapping
        # again would only stack redundant layers.
        return

    if inspect.iscoroutinefunction(original_validate_inputs):
        @functools.wraps(original_validate_inputs)
        async def wrapper_validate_inputs(*args, **kwargs):
            result = await original_validate_inputs(*args, **kwargs)
            return _drop_return_type_mismatches(result)
    else:
        @functools.wraps(original_validate_inputs)
        def wrapper_validate_inputs(*args, **kwargs):
            result = original_validate_inputs(*args, **kwargs)
            return _drop_return_type_mismatches(result)

    wrapper_validate_inputs._etk_patched = True
    # Replace the module attribute; code that looks the name up through the
    # module (rather than holding an earlier reference) now gets the wrapper.
    execution_module.validate_inputs = wrapper_validate_inputs
    print("validate_inputs function has been patched by ETK extension.")
# Install the validation override at import time, before pulling in the
# node module below.
patch_validate_inputs()

# Importing the node module is what populates the node mapping dicts.
from . import pytorch_nodes

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]

# Alias this package as ``custom_nodes.EternalKernelLiteGraphNodes`` so tests
# can import it under that name. A bare placeholder ``custom_nodes`` module is
# registered only if nothing is registered under that name yet; an existing
# entry is reused as the parent.
sys.modules.setdefault('custom_nodes', ModuleType('custom_nodes'))
setattr(sys.modules['custom_nodes'], 'EternalKernelLiteGraphNodes', sys.modules[__name__])
sys.modules['custom_nodes.EternalKernelLiteGraphNodes'] = sys.modules[__name__]