Skip to content

Instantly share code, notes, and snippets.

@mahmoudimus
Last active April 7, 2024 04:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mahmoudimus/909490152f1dc192faa73aa9db8d8981 to your computer and use it in GitHub Desktop.
Save mahmoudimus/909490152f1dc192faa73aa9db8d8981 to your computer and use it in GitHub Desktop.
A stupidly simple approach to text/non-text separation in printed document images.
import argparse
import os
import torch
import pandas as pd
from PIL import Image
def LTPN_rotation_cal(mat, th):
    """
    Calculates the lower (negative) Local Ternary Pattern code of a 3x3 patch,
    made rotation invariant by taking the minimum over all 8 bit rotations.

    A neighbour contributes a 1 bit when ``neighbour + th < centre``.

    Args:
        mat (torch.Tensor): A 3x3 patch; ``mat[1, 1]`` is the centre pixel.
        th (float): The threshold value for LTP calculation.

    Returns:
        torch.Tensor: A 0-dim integer tensor holding the minimal rotated code.

    Raises:
        ValueError: If ``mat`` is not a 3x3 tensor.
    """
    if tuple(mat.shape) != (3, 3):
        raise ValueError("mat must be a 3x3 tensor")
    centre = mat[1, 1]
    # Walk the 8 neighbours clockwise starting at the top-left corner. The
    # elements are stacked explicitly because torch.cat rejects 0-dim tensors
    # and torch slicing does not support negative steps.
    ring = torch.stack(
        (
            mat[0, 0],
            mat[0, 1],
            mat[0, 2],
            mat[1, 2],
            mat[2, 2],
            mat[2, 1],
            mat[2, 0],
            mat[1, 0],
        )
    )
    # The first neighbour becomes the most significant bit of the code.
    bits = "".join("1" if flag else "0" for flag in (ring + th < centre).tolist())
    # Rotation invariance: keep the smallest value among all circular shifts.
    code = min(int(bits[shift:] + bits[:shift], 2) for shift in range(8))
    return torch.tensor(code)
def LTPP_rotation_cal(mat, th):
    """
    Calculates the upper (positive) Local Ternary Pattern code of a 3x3 patch,
    made rotation invariant by taking the minimum over all 8 bit rotations.

    A neighbour contributes a 1 bit when ``neighbour >= centre + th``.

    Args:
        mat (torch.Tensor): A 3x3 patch; ``mat[1, 1]`` is the centre pixel.
        th (float): The threshold value for LTPP calculation.

    Returns:
        torch.Tensor: A 0-dim integer tensor holding the minimal rotated code.

    Raises:
        ValueError: If ``mat`` is not a 3x3 tensor.
    """
    if tuple(mat.shape) != (3, 3):
        raise ValueError("mat must be a 3x3 tensor")
    upper_bound = mat[1, 1] + th
    # Walk the 8 neighbours clockwise starting at the top-left corner. The
    # elements are stacked explicitly because torch.cat rejects 0-dim tensors
    # and torch slicing does not support negative steps.
    ring = torch.stack(
        (
            mat[0, 0],
            mat[0, 1],
            mat[0, 2],
            mat[1, 2],
            mat[2, 2],
            mat[2, 1],
            mat[2, 0],
            mat[1, 0],
        )
    )
    # The first neighbour becomes the most significant bit of the code.
    bits = "".join("1" if flag else "0" for flag in (ring >= upper_bound).tolist())
    # Rotation invariance: keep the smallest value among all circular shifts.
    code = min(int(bits[shift:] + bits[:shift], 2) for shift in range(8))
    return torch.tensor(code)
def write_features_to_csv(features, filename="data.csv"):
    """
    Writes PyTorch tensor features to a CSV file and prints a confirmation.

    Args:
        features (torch.Tensor): The feature tensor to be written.
        filename (str, optional): The name of the output CSV file. Defaults to "data.csv".
    """
    # Tensor.numpy() refuses tensors that require grad or live on another
    # device, so detach and move to the CPU before converting.
    table = pd.DataFrame(features.detach().cpu().numpy())
    # The row index is not written; only the column header and the values are.
    table.to_csv(filename, index=False)
    print("Feature set successfully generated")
def filter_features(feature_tensor, feature_selector):
    """
    Keeps only the feature columns flagged by a selector.

    Args:
        feature_tensor (torch.Tensor): Features, one column per feature.
        feature_selector (torch.Tensor): One entry per column; non-zero entries mark columns to keep.

    Returns:
        torch.Tensor: ``feature_tensor`` restricted to the selected columns.

    Raises:
        ValueError: If the selector length differs from the number of columns.
    """
    column_count = feature_tensor.shape[1]
    if column_count == len(feature_selector):
        # Positions of the non-zero selector entries, flattened to 1-D.
        selected_columns = torch.nonzero(feature_selector).squeeze(1)
        return feature_tensor[:, selected_columns]
    raise ValueError("Feature dimensions do not match. Please check input tensors.")
def hist_LTP(img, th):
    """
    Calculates the LTP histogram of an image.

    NOTE: a second ``hist_LTP`` is defined later in this file; at import time
    that later definition replaces this one.

    Args:
        img (torch.Tensor): The input image as a 2-D tensor.
        th (float): The threshold value for LTP calculation.

    Returns:
        torch.Tensor: The LTP histogram (72 elements): 36 bins for the upper
        pattern followed by 36 bins for the lower pattern.
    """
    h, l = img.shape  # Get image dimensions
    trans_img = torch.zeros((h + 2, l + 2))  # Zero-padded image
    trans_img[1:-1, 1:-1] = img

    # Rotation-invariant 8-bit codes take only 36 distinct values in 0-255.
    # A raw code therefore cannot index a 36-bin histogram directly; map each
    # canonical code to its rank among the sorted canonical codes instead.
    canonical_codes = sorted(
        {
            min(((code << s) | (code >> (8 - s))) & 0xFF for s in range(8))
            for code in range(256)
        }
    )
    bin_of = {code: slot for slot, code in enumerate(canonical_codes)}

    pos_counts = [0] * 36
    neg_counts = [0] * 36
    for i in range(1, h + 1):
        for j in range(1, l + 1):
            patch = trans_img[i - 1 : i + 2, j - 1 : j + 2]
            pos_counts[bin_of[int(LTPP_rotation_cal(patch, th))]] += 1
            neg_counts[bin_of[int(LTPN_rotation_cal(patch, th))]] += 1

    histogram = torch.zeros(72)
    histogram[:36] = torch.tensor(pos_counts)
    histogram[36:] = torch.tensor(neg_counts)
    return histogram
def hist_LTP(img, th):
    """
    Calculates the rotation-invariant LTP histogram of an image.

    Every pixel is compared with its 8 neighbours (the image is zero padded by
    one pixel on each side). Two 8-bit codes are formed per pixel:
    the upper code sets a bit where ``neighbour >= centre + th`` and the lower
    code sets a bit where ``neighbour + th < centre``. Each code is reduced to
    its minimum over all bit rotations, which leaves 36 possible values, and
    the occurrences of those values are counted.

    Args:
        img (torch.Tensor): The input image as a 2-D tensor.
        th (float): The threshold value for LTP calculation.

    Returns:
        torch.Tensor: The LTP histogram (72 elements): 36 bins for the upper
        pattern followed by 36 bins for the lower pattern. Bin ``k`` counts the
        k-th smallest rotation-invariant code.

    Raises:
        ValueError: If ``img`` is not 2-dimensional.
    """
    if img.dim() != 2:
        raise ValueError("img must be a 2-D tensor")
    h, l = img.shape
    padded = torch.zeros((h + 2, l + 2))
    padded[1:-1, 1:-1] = img
    centre = padded[1:-1, 1:-1]

    # Neighbour offsets into the padded image, clockwise from the top-left
    # corner; the first neighbour is the most significant bit of the code.
    ring_offsets = ((0, 0), (0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0))
    pos_code = torch.zeros((h, l), dtype=torch.long)
    neg_code = torch.zeros((h, l), dtype=torch.long)
    for position, (dy, dx) in enumerate(ring_offsets):
        neighbour = padded[dy : dy + h, dx : dx + l]
        weight = 1 << (7 - position)
        pos_code += (neighbour >= centre + th).long() * weight
        neg_code += (neighbour + th < centre).long() * weight

    # Lookup table: raw 8-bit code -> histogram bin. The bin is the rank of the
    # code's minimal rotation among the 36 distinct minimal rotations, so every
    # index stays inside the 36-bin histograms.
    minimal = [
        min(((code << s) | (code >> (8 - s))) & 0xFF for s in range(8))
        for code in range(256)
    ]
    rank = {code: slot for slot, code in enumerate(sorted(set(minimal)))}
    bin_lookup = torch.tensor([rank[code] for code in minimal], dtype=torch.long)

    histogram = torch.zeros(72)
    histogram[:36] = torch.bincount(bin_lookup[pos_code].reshape(-1), minlength=36)
    histogram[36:] = torch.bincount(bin_lookup[neg_code].reshape(-1), minlength=36)
    return histogram
def rotate(x):
    """
    Returns the smallest value reachable by circularly shifting the 8-bit
    binary representation of an LTP code.

    Args:
        x (int): The LTP value as an integer.

    Returns:
        int: The rotation-invariant minimum LTP value.
    """
    bits = format(x, "08b")  # zero-padded 8-character binary string
    # Start with x itself, then add every non-trivial circular shift.
    candidates = [x]
    for shift in range(1, 8):
        candidates.append(int(bits[shift:] + bits[:shift], 2))
    return min(candidates)
def unique_vals():
    """
    Collects the distinct rotation-invariant codes over all 8-bit values.

    Returns:
        torch.Tensor: A 1D integer tensor with the unique values that
        ``rotate`` produces for the inputs 0-255.
    """
    minimal_codes = torch.tensor([rotate(code) for code in range(256)], dtype=torch.int)
    return torch.unique(minimal_codes)
def main():
    """
    Calculates LTP features for images and saves them to a CSV file.

    Reads ``.bmp`` images from the sub-directories ``1`` and ``2`` of the
    data directory, converts each to a 128x128 grayscale image, computes its
    72-bin LTP histogram and writes one CSV row per image: 72 feature columns
    followed by the class id.
    """
    # Imported locally because the module header does not import numpy.
    import numpy as np

    parser = argparse.ArgumentParser(description="LTP Feature Extraction")
    parser.add_argument(
        "data_dir", help="Path to the directory containing image classes."
    )
    parser.add_argument("output_csv", help="Path to the output CSV file.")
    parser.add_argument(
        "--threshold", type=float, default=5.0, help="LTP calculation threshold."
    )
    args = parser.parse_args()

    # One tensor per image: 72 histogram values followed by the class label.
    rows = []
    # Iterate through classes (folders)
    for class_id in range(1, 3):  # Assuming two classes: 1 and 2
        class_dir = os.path.join(args.data_dir, str(class_id))
        if not os.path.isdir(class_dir):
            parser.error("class directory not found: {}".format(class_dir))
        # Sorted so that the row order of the CSV is reproducible.
        for filename in sorted(os.listdir(class_dir)):
            if not filename.lower().endswith(".bmp"):
                continue
            img_path = os.path.join(class_dir, filename)
            with Image.open(img_path) as handle:
                gray = handle.convert("L").resize((128, 128))
            # torch.tensor cannot infer a dtype from a PIL image, so convert
            # through a NumPy array first.
            pixels = torch.tensor(np.array(gray), dtype=torch.float32)
            features = hist_LTP(pixels, args.threshold)
            label = torch.tensor([class_id], dtype=features.dtype)
            rows.append(torch.cat((features, label)))

    if not rows:
        # torch.stack would fail on an empty list with an unhelpful message.
        parser.error("no .bmp images found under {}".format(args.data_dir))

    # Create DataFrame and save to CSV
    df = pd.DataFrame(torch.stack(rows).numpy())
    df.to_csv(
        args.output_csv,
        index=False,
        header=["attr_{}".format(i + 1) for i in range(72)] + ["class"],
    )
    print("Feature set successfully generated and saved to", args.output_csv)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment