# Four text prompts. prompt_1/prompt_2 are interpolated along one axis and
# prompt_3/prompt_4 along the other, so together they act as the corners of
# the interpolation grid built below.
prompt_1 = "A Balrog in Moria high defination image"
prompt_2 = "A still life DSLR photo of Sauron in Mount Doom"
prompt_3 = "high quality painting of Gondor with fire raining from sky"
prompt_4 = "Elves from Rivendell with Hobbits high quality image"

# The grid has interpolation_steps x interpolation_steps cells; images are
# generated batch_size at a time, giving `batches` generation calls in total.
interpolation_steps = 6
batch_size = 3
batches = (interpolation_steps**2) // batch_size

# Encode each prompt in order and drop size-1 dimensions from the result.
encoding_1, encoding_2, encoding_3, encoding_4 = (
    tf.squeeze(model.encode_text(text))
    for text in (prompt_1, prompt_2, prompt_3, prompt_4)
)

# Bilinear interpolation between the four prompt encodings: the inner
# linspace calls walk encoding_1 -> encoding_2 and encoding_3 -> encoding_4,
# and the outer linspace blends between those two rows. The result has shape
# (interpolation_steps, interpolation_steps, *encoding_1.shape).
interpolated_encodings = tf.linspace(
    tf.linspace(encoding_1, encoding_2, interpolation_steps),
    tf.linspace(encoding_3, encoding_4, interpolation_steps),
    interpolation_steps,
)

# Flatten the 2-D grid into one list of encodings. The per-encoding shape is
# taken from the encoder output rather than hard-coded (previously (77, 768)),
# so this keeps working if the text encoder's output shape differs.
grid_cells = interpolation_steps**2
interpolated_encodings = tf.reshape(
    interpolated_encodings, (grid_cells, *tuple(encoding_1.shape))
)

# tf.split with an integer count requires an exact division; fail early with
# a clear message instead of an opaque TensorFlow error.
if grid_cells % batch_size != 0:
    raise ValueError(
        f"interpolation_steps**2 ({grid_cells}) must be divisible by "
        f"batch_size ({batch_size})"
    )
batched_encodings = tf.split(interpolated_encodings, batches)

# Generate one batch of images per chunk of interpolated encodings and
# collect the per-batch arrays in `outputs`.
# (Previously this appended to `images`, which is not defined yet at this
# point, so the loop raised NameError and `outputs` stayed empty.)
outputs = []
for encodings in batched_encodings:
    outputs.append(
        model.generate_image(
            encodings,
            batch_size=batch_size,
        )
    )

# Stack all batches along the first axis, in grid order, and plot them as an
# interpolation_steps-wide grid saved to "lotr.jpg".
images = np.concatenate(outputs)
plot_grid(images, "lotr.jpg", interpolation_steps)