Skip to content

Instantly share code, notes, and snippets.

@rsomani95
Last active May 27, 2021 07:32
Show Gist options
  • Save rsomani95/44d1db1741c947aa0858a748c9b599f0 to your computer and use it in GitHub Desktop.
Save rsomani95/44d1db1741c947aa0858a748c9b599f0 to your computer and use it in GitHub Desktop.
Applying LUTs with PIL
from PIL import Image, ImageFile
# Let Pillow load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

img_path = "image.jpg"
lut_path = "lut.cube"

img = Image.open(img_path)
# Image.filter refuses palette ("P") images, and a 3D colour LUT needs an
# image with at least three bands; convert anything that is not RGB/RGBA.
if img.mode not in ("RGB", "RGBA"):
    img = img.convert("RGB")

# read_lut is the .cube reader defined further down in this gist.
lut = read_lut(lut_path)

# Image.filter returns a NEW image (img is left untouched), so keep it.
img_lut = img.filter(lut)
import os
from PIL import Image, ImageFilter
from typing import List, Union
from pathlib import Path
import random
# lut = LUT(lut_path)
# img = PIL.Image.open(img_path)
# lut(img) | lut.apply(img)
class LUT:
    """Load a 3D colour LUT from a ``.cube`` file and apply it to PIL images.

    Only rows made of exactly ``num_channels`` numbers become table entries;
    header/metadata lines (``TITLE``, ``DOMAIN_MIN``/``DOMAIN_MAX``, ``#``
    comments, blank lines) are ignored. The cube edge length is taken from the
    file's ``LUT_3D_SIZE`` line when present, otherwise it is inferred from the
    number of table rows.

    NOTE(review): DOMAIN_MIN/DOMAIN_MAX values are ignored; confirm the LUTs
    used here have the default 0..1 input domain.

    Example::

        lut = LUT("lut_file.cube")
        lut.apply(img)  # img is a PIL Image
        lut(img)        # equivalent; both return a new, filtered image
    """

    def __init__(self, path: os.PathLike, num_channels: int = 3):
        """
        Args:
            path: path to the ``.cube`` file.
            num_channels: values per table row and output channels of the
                resulting ``Color3DLUT``.
        """
        self.path = path
        self.num_channels = num_channels
        # Parsed eagerly so a malformed file fails at construction time.
        self.lut = self.read(self.path)

    def apply(self, img: "Image.Image"):
        """Return a new image with the LUT applied (same as calling the object)."""
        return self(img)

    def __call__(self, img: "Image.Image"):
        return img.filter(self.lut)

    def __repr__(self):
        return "\n".join([f"{self.lut}", f"Path: {self.path}"])

    def read(self, path_lut: Union[str, Path]):
        """Read ``path_lut`` and return a ``PIL.ImageFilter.Color3DLUT``.

        Raises:
            ValueError: if the file contains no table rows (Pillow's
                ``Color3DLUT`` additionally validates size vs. table length).
        """
        with open(path_lut) as f:
            lines = f.read().splitlines()
        size, table = self._parse_cube(lines, self.num_channels)
        return ImageFilter.Color3DLUT(size, table, self.num_channels)

    @classmethod
    def _parse_cube(cls, lines: List[str], num_channels: int = 3):
        """Return ``(size, table)`` parsed from the text lines of a ``.cube`` file.

        ``table`` is a list of ``num_channels``-tuples of floats in file order.
        ``size`` is the declared ``LUT_3D_SIZE`` when present, otherwise the
        rounded cube root of the number of table rows.

        Raises:
            ValueError: if no table rows are found.
        """
        size = None
        table = []
        for line in lines:
            fields = line.split()
            if len(fields) == 2 and fields[0] == "LUT_3D_SIZE":
                size = int(fields[1])
            elif cls._is_3dlut_row(line, num_channels):
                table.append(tuple(float(v) for v in fields))
        if not table:
            raise ValueError("No LUT table rows found")
        if size is None:
            # A 3D LUT has size**3 rows, so recover the edge from the cube root.
            size = round(len(table) ** (1 / 3))
        return size, table

    @staticmethod
    def _is_3dlut_row(row: str, num_values: int = 3) -> bool:
        """Return True if ``row`` has exactly ``num_values`` numeric fields.

        Fields may be separated by any run of whitespace, so tabs, repeated
        spaces and trailing whitespace are tolerated.
        """
        fields = row.split()
        if len(fields) != num_values:
            return False
        try:
            for field in fields:
                float(field)
        except ValueError:
            return False
        return True
from PIL import ImageFilter
from typing import List
import os
def is_3dlut_row(row: List, num_values: int = 3) -> bool:
    """Return True if ``row`` holds exactly ``num_values`` float-convertible items.

    ``row`` is one line of a ``.cube`` file that the caller has already split
    into fields (e.g. ``line.split()``). Header keywords such as
    ``LUT_3D_SIZE`` or ``TITLE`` fail the float conversion, so header lines
    are rejected.
    """
    values = list(row)
    if len(values) != num_values:
        return False
    try:
        for value in values:
            float(value)
    except (TypeError, ValueError, OverflowError):
        # ValueError: non-numeric text ("TITLE", "", ...); TypeError: items
        # that are not strings/numbers (e.g. None); OverflowError: ints too
        # large for a float. Anything else (e.g. KeyboardInterrupt) propagates.
        return False
    return True
def read_lut(path_lut: Union[str, os.PathLike], num_channels: int = 3):
    """Read a ``.cube`` file and return a ``PIL.ImageFilter.Color3DLUT``.

    Each line is split on any whitespace; lines accepted by ``is_3dlut_row``
    become table entries and everything else (TITLE, DOMAIN_MIN/MAX, comments,
    blank lines) is ignored. The cube edge length comes from the file's
    ``LUT_3D_SIZE`` line when present, otherwise it is inferred from the
    number of table rows rather than from the total line count.

    Args:
        path_lut: path to the ``.cube`` file.
        num_channels: output channels passed to ``Color3DLUT``.

    Raises:
        ValueError: if the file contains no table rows (Pillow's
            ``Color3DLUT`` additionally validates size vs. table length).
    """
    with open(path_lut) as f:
        lines = f.read().splitlines()

    size = None
    table = []
    for line in lines:
        # split() (not split(" ")) tolerates tabs, repeated and trailing spaces.
        fields = line.split()
        if len(fields) == 2 and fields[0] == "LUT_3D_SIZE":
            size = int(fields[1])
        elif is_3dlut_row(fields):
            table.append(tuple(float(v) for v in fields))

    if not table:
        raise ValueError(f"No LUT table rows found in {path_lut!r}")
    if size is None:
        # A 3D LUT has size**3 rows, so recover the edge from the cube root.
        size = round(len(table) ** (1 / 3))
    return ImageFilter.Color3DLUT(size, table, num_channels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment