-
-
Save andrewharp/95763565b8797779773fd2d8e352c81f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs=None):
    """Run the requested output nodes of ``prompt`` plus everything they need.

    The method works in these stages:

    1. Decide which output nodes to run. ``extra_data["nodes_requested"]``,
       when present, replaces ``execute_outputs`` (only ids that exist in
       ``prompt`` are kept).
    2. Prune ``self.object_storage`` entries whose node is gone from
       ``prompt`` or whose ``class_type`` no longer matches.
    3. Re-hash every cached output whose node is still in ``prompt`` and
       drop the cache entries (``outputs``, ``output_hashes``,
       ``outputs_ui``) whose hash changed. Cached outputs for nodes that are
       absent from ``prompt`` are left untouched by this step.
    4. Collect the nodes to execute with ``recursive_will_execute``.
    5. Repeatedly pick a node that has no pending dependency, preferring
       nodes by workflow group/position, and run it with ``execute_node``.

    Args:
        prompt: Mapping of node id (str) to a node description; this method
            reads each node's ``'class_type'`` and ``'inputs'`` entries.
        prompt_id: Identifier echoed back in every status message.
        extra_data: Optional dict. Keys read here: ``"nodes_requested"``,
            ``"client_id"`` and ``"extra_pnginfo"`` (for
            ``workflow.nodes`` / ``workflow.groups`` layout data). It is
            also passed through to ``execute_node``. Defaults to a fresh
            empty dict per call.
        execute_outputs: Optional list of output node ids to execute.
            Defaults to a fresh empty list per call.

    Returns:
        None. Results are reported through ``self.add_message`` and stored
        on ``self`` (``outputs``, ``outputs_ui``, ``old_prompt``, ...).
        When a node fails, ``self.handle_execution_error`` is called and
        the method returns early without the post-run bookkeeping.
    """
    # Fresh containers per call: a shared ``{}`` default would leak state
    # between calls if anything downstream mutated it.
    if extra_data is None:
        extra_data = {}
    if execute_outputs is None:
        execute_outputs = []

    title = f" Executing prompt {Style.BRIGHT}{prompt_id:59}{Style.NORMAL}"
    logging.info(Fore.GREEN + "=" * 80)
    logging.info(Fore.GREEN + f"={title}=")
    logging.info(Fore.GREEN + "=" * 80)

    nodes.interrupt_processing(False)

    logging.info(f"Extra data {extra_data.keys()}")

    if "nodes_requested" in extra_data:
        requested_nodes = extra_data["nodes_requested"]
        logging.info(f"Nodes requested: {requested_nodes}")
        assert isinstance(requested_nodes, list)
        # Ids are normalised to str; ids missing from the prompt are ignored.
        execute_outputs = [str(node) for node in requested_nodes if str(node) in prompt]
        for node in execute_outputs:
            # NOTE(review): informational dump logged at ERROR level - confirm intent.
            logging.error(f"Node {prompt[node]}")

    logging.info(f"Number of outputs to execute: {len(execute_outputs)}")

    if "client_id" in extra_data:
        self.server.client_id = extra_data["client_id"]
    else:
        self.server.client_id = None

    self.status_messages = []
    self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=True)

    with torch.inference_mode():
        # Delete stored objects if their node no longer exists or changed class.
        stale_objects = []
        for key in self.object_storage:
            if key[0] not in prompt:
                stale_objects.append(key)
                continue
            current_class = prompt[key[0]]['class_type']
            if key[1] != current_class:
                logging.error(f"Deleting {key} because its class type is different. {key[1]} != {current_class}")
                stale_objects.append(key)
        for key in stale_objects:
            del self.object_storage[key]

        # Invalidate cached outputs whose hash no longer matches. Iterate a
        # snapshot of the keys because entries are deleted along the way.
        nodes_changed = []
        for x in list(self.outputs.keys()):
            assert x in self.output_hashes, f"Output {x} not in output_hashes"
            if x not in prompt:
                continue
            new_hash = compute_node_hash(x, prompt, self.outputs)
            old_hash = self.output_hashes[x]
            if new_hash != old_hash:
                nodes_changed.append(f" {x}: {old_hash} -> {new_hash}")
                del self.outputs[x]
                del self.output_hashes[x]
                if x in self.outputs_ui:
                    del self.outputs_ui[x]

        if nodes_changed:
            logging.info("Nodes changed:")
            for node_info in nodes_changed:
                logging.info(node_info)

        # Snapshot of what is cached before anything runs; also handed to the
        # error handler below.
        current_outputs = set(self.outputs.keys())
        for x in list(self.outputs_ui.keys()):
            if x not in current_outputs:
                del self.outputs_ui[x]

        comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)

        self.all_nodes_to_execute.clear()

        from easy_nodes import config_service
        early_abort = config_service.get_config_value("curie.early_abort", False)
        logging.info(f"Early abort enabled: {early_abort}")

        for output_node_id in execute_outputs:
            logging.debug(f"Computing requirements for output node: {output_node_id}")
            recursive_will_execute(prompt, self.outputs, self.output_hashes,
                                   output_node_id, self.all_nodes_to_execute,
                                   early_abort=early_abort)

        ram_usage.clear()

        def get_node_dependencies(node_id):
            """Return the input node ids of ``node_id`` that are still waiting to run."""
            dependencies = set()
            for input_data in prompt[node_id]['inputs'].values():
                # A two-element list is treated as a link to another node;
                # its first element is the upstream node id.
                if isinstance(input_data, list) and len(input_data) == 2:
                    input_node_id = input_data[0]
                    if input_node_id in self.all_nodes_to_execute:
                        dependencies.add(input_node_id)
            return dependencies

        def get_node_position(node_id):
            """Return the node's ``pos`` from the workflow metadata, or [inf, inf] if unknown."""
            workflow = extra_data.get('extra_pnginfo', {}).get('workflow', {})
            for workflow_node in workflow.get('nodes', []):
                if str(workflow_node['id']) == node_id:
                    potential_pos = workflow_node.get('pos', [float('inf'), float('inf')])
                    # rgthree's ImageComparer node stores its position as a dict.
                    if isinstance(potential_pos, dict):
                        potential_pos = [potential_pos["0"], potential_pos["1"]]
                    return potential_pos
            return [float('inf'), float('inf')]

        def get_group_for_node(node_id):
            """Return (group index, anchor position, title) of the first group containing the node.

            For an ungrouped node the result is (None, node position, "Ungrouped").
            """
            workflow = extra_data.get('extra_pnginfo', {}).get('workflow', {})
            groups = workflow.get('groups', [])
            node_pos = get_node_position(node_id)
            for idx, group in enumerate(groups):
                # ``bounding`` is read as [x, y, width, height].
                bounding = group['bounding']
                if (bounding[0] <= node_pos[0] <= bounding[0] + bounding[2] and
                        bounding[1] <= node_pos[1] <= bounding[1] + bounding[3]):
                    return idx, bounding[:2], group.get('title', f"Group {idx}")
            return None, node_pos, "Ungrouped"

        def execution_priority(node_id):
            """Ordering key: group anchor, then group index (ungrouped last), then node position."""
            group_idx, anchor, _ = get_group_for_node(node_id)
            return (
                anchor,
                group_idx if group_idx is not None else float('inf'),
                get_node_position(node_id),
            )

        executed = set()

        while self.all_nodes_to_execute:
            ready_nodes = [
                node_id for node_id in self.all_nodes_to_execute
                if not get_node_dependencies(node_id)
            ]

            if not ready_nodes:
                logging.error("Circular dependency detected. Cannot continue execution.")
                # Record the abort so ``self.success`` does not keep a stale
                # value from the last node that ran.
                self.success = False
                break

            # ``min`` returns the first minimal element, i.e. the same node a
            # stable sort would place at index 0, without sorting the rest.
            self.node_to_execute = min(ready_nodes, key=execution_priority)

            group_idx, _, group_title = get_group_for_node(self.node_to_execute)
            group_str = group_title if group_idx is not None else "ungrouped"
            node_type = prompt[self.node_to_execute]["class_type"]

            self.add_message("executing", {"node": self.node_to_execute,
                                           "prompt_id": prompt_id,
                                           "node_type": node_type,
                                           "node_group": group_str}, None)

            self.success, error, ex = execute_node(self.server,
                                                   prompt,
                                                   self.outputs,
                                                   self.output_hashes,
                                                   self.node_to_execute,
                                                   extra_data,
                                                   executed,
                                                   prompt_id,
                                                   self.outputs_ui,
                                                   self.object_storage,
                                                   group_str)

            self.add_message("executed", {"node": self.node_to_execute,
                                          "output": self.outputs_ui.get(self.node_to_execute, None),
                                          "prompt_id": prompt_id,
                                          "node_type": node_type,
                                          "node_group": group_str}, None)

            # The node is marked as executed and removed from the work list
            # whether or not it succeeded.
            executed.add(self.node_to_execute)
            self.all_nodes_to_execute.remove(self.node_to_execute)

            reclaim_memory(self.outputs, prompt, self.all_nodes_to_execute)

            if self.success is not True:
                self.node_to_execute = None
                self.all_nodes_to_execute.clear()
                self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                return

            self.node_to_execute = None
        else:
            # Only reached when the while-loop ends without break, i.e. every
            # scheduled node ran successfully.
            self.add_message("execution_success", {"prompt_id": prompt_id}, broadcast=True)

        for x in executed:
            self.old_prompt[x] = copy.deepcopy(prompt[x])
        self.server.last_node_id = None
        if comfy.model_management.DISABLE_SMART_MEMORY:
            comfy.model_management.unload_all_models()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment