def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
    title = f" Executing prompt {Style.BRIGHT}{prompt_id:59}{Style.NORMAL}"
    logging.info(Fore.GREEN + "=" * 80)
    logging.info(Fore.GREEN + f"={title}=")
    logging.info(Fore.GREEN + "=" * 80)
    nodes.interrupt_processing(False)
    logging.info(f"Extra data {extra_data.keys()}")
if "nodes_requested" in extra_data:
requested_nodes = extra_data["nodes_requested"]
logging.info(f"Nodes requested: {requested_nodes}")
assert isinstance(requested_nodes, list)
execute_outputs = [str(node) for node in requested_nodes if str(node) in prompt]
for node in execute_outputs:
logging.error(f"Node {prompt[node]}")
logging.info(f"Number of outputs to execute: {len(execute_outputs)}")
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
else:
self.server.client_id = None
self.status_messages = []
self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=True)
    with torch.inference_mode():
        # Delete cached outputs if nodes don't exist for them
        to_delete = []
        for o in self.object_storage:
            if o[0] not in prompt:
                to_delete.append(o)
            else:
                p = prompt[o[0]]
                if o[1] != p['class_type']:
                    logging.error(f"Deleting {o} because its class type is different. {o[1]} != {p['class_type']}")
                    to_delete.append(o)
        for o in to_delete:
            del self.object_storage[o]
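        # Invalidate cached outputs whose node hash has changed since the last run.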
        output_list = list(self.outputs.keys())
        nodes_changed = []
        for x in output_list:
            assert x in self.output_hashes, f"Output {x} not in output_hashes"
            if x in prompt:
                is_changed = compute_node_hash(x, prompt, self.outputs)
                old_is_changed = self.output_hashes[x]
                if is_changed != old_is_changed:
                    nodes_changed.append(f" {x}: {old_is_changed} -> {is_changed}")
                    del self.outputs[x]
                    del self.output_hashes[x]
                    if x in self.outputs_ui:
                        del self.outputs_ui[x]
        if nodes_changed:
            logging.info("Nodes changed:")
            for node_info in nodes_changed:
                logging.info(node_info)
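        # Drop UI outputs for nodes that no longer have cached outputs.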
        current_outputs = set(self.outputs.keys())
        for x in list(self.outputs_ui.keys()):
            if x not in current_outputs:
                del self.outputs_ui[x]
        comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)

        self.all_nodes_to_execute.clear()
        from easy_nodes import config_service
        early_abort = config_service.get_config_value("curie.early_abort", False)
        logging.info(f"Early abort enabled: {early_abort}")
        for output_node_id in execute_outputs:
            logging.debug(f"Computing requirements for output node: {output_node_id}")
            recursive_will_execute(prompt, self.outputs, self.output_hashes,
                                   output_node_id, self.all_nodes_to_execute,
                                   early_abort=early_abort)
        # logging.info(f"{len(self.all_nodes_to_execute)} total nodes to execute: {self.all_nodes_to_execute}")
        ram_usage.clear()
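        # Return the direct input nodes of node_id that are still waiting to execute.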
        def get_node_dependencies(node_id):
            dependencies = set()
            for input_name, input_data in prompt[node_id]['inputs'].items():
                if isinstance(input_data, list) and len(input_data) == 2:
                    input_node_id = input_data[0]
                    if input_node_id in self.all_nodes_to_execute:
                        dependencies.add(input_node_id)
            return dependencies
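        # Look up a node's [x, y] canvas position from the workflow metadata.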
        def get_node_position(node_id):
            workflow = extra_data.get('extra_pnginfo', {}).get('workflow', {})
            workflow_nodes = workflow.get('nodes', [])
            for node in workflow_nodes:
                if str(node['id']) == node_id:
                    potential_pos = node.get('pos', [float('inf'), float('inf')])
                    # rgthree's ImageComparer node stores its position as a dict.
                    if isinstance(potential_pos, dict):
                        potential_pos = [potential_pos["0"], potential_pos["1"]]
                    return potential_pos
            return [float('inf'), float('inf')]
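        # Find the workflow group whose bounding box contains the node. Returns
        # (group index, top-left corner, title), or (None, node position, "Ungrouped").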
        def get_group_for_node(node_id):
            workflow = extra_data.get('extra_pnginfo', {}).get('workflow', {})
            groups = workflow.get('groups', [])
            node_pos = get_node_position(node_id)
            for idx, group in enumerate(groups):
                bounding = group['bounding']
                if (bounding[0] <= node_pos[0] <= bounding[0] + bounding[2] and
                        bounding[1] <= node_pos[1] <= bounding[1] + bounding[3]):
                    return idx, bounding[:2], group.get('title', f"Group {idx}")
            return None, node_pos, "Ungrouped"
        executed = set()
        while self.all_nodes_to_execute:
            dependency_lists = {node_id: get_node_dependencies(node_id) for node_id in self.all_nodes_to_execute}
            executable_nodes = [node for node, deps in dependency_lists.items() if not deps]
            if not executable_nodes:
                logging.error("Circular dependency detected. Cannot continue execution.")
                break
            # Sort executable nodes based on their group and position
            executable_nodes.sort(key=lambda node: (
                get_group_for_node(node)[1],
                get_group_for_node(node)[0] if get_group_for_node(node)[0] is not None else float('inf'),
                get_node_position(node)
            ))
            self.node_to_execute = executable_nodes[0]
            group_info = get_group_for_node(self.node_to_execute)
            group_str = group_info[2] if group_info[0] is not None else "ungrouped"
            node_type = prompt[self.node_to_execute]["class_type"]
            self.add_message("executing", { "node": self.node_to_execute,
                                            "prompt_id": prompt_id,
                                            "node_type": node_type,
                                            "node_group": group_str }, None)
            self.success, error, ex = execute_node(self.server,
                                                   prompt,
                                                   self.outputs,
                                                   self.output_hashes,
                                                   self.node_to_execute,
                                                   extra_data,
                                                   executed,
                                                   prompt_id,
                                                   self.outputs_ui,
                                                   self.object_storage,
                                                   group_str)
self.add_message("executed", { "node": self.node_to_execute,
"output": self.outputs_ui.get(self.node_to_execute, None),
"prompt_id": prompt_id,
"node_type": node_type,
"node_group": group_str}, None)
executed.add(self.node_to_execute)
self.all_nodes_to_execute.remove(self.node_to_execute)
reclaim_memory(self.outputs, prompt, self.all_nodes_to_execute)
            if self.success is not True:
                self.node_to_execute = None
                self.all_nodes_to_execute.clear()
                self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                return
        else:
            # Execution completed successfully. The while-loop's else clause only
            # runs when the loop ends without break, so the circular-dependency
            # break above does not report success.
            self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=True)
        self.node_to_execute = None
        for x in executed:
            self.old_prompt[x] = copy.deepcopy(prompt[x])
        self.server.last_node_id = None
        if comfy.model_management.DISABLE_SMART_MEMORY:
            comfy.model_management.unload_all_models()