Created
June 24, 2024 06:24
-
-
Save alexisperrier/772eb91463f13758503774b042843b0f to your computer and use it in GitHub Desktop.
Embeddings d'images avec Keras et EfficientNetB0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ''' | |
| Calculer les embeddings des images à partir de EfficientNetB0 | |
| avec Tensorflow et Keras | |
| Code généré par Claude.ai 2024.06.23 | |
| ''' | |
| import tensorflow as tf | |
| from tensorflow.keras.applications import EfficientNetB0 | |
| from tensorflow.keras.preprocessing import image | |
| from tensorflow.keras.applications.efficientnet import preprocess_input | |
| import numpy as np | |
| import os | |
| import h5py | |
# Pre-trained EfficientNetB0 used as a feature extractor: ImageNet weights,
# no classification head (include_top=False), average pooling on the output.
base_model = EfficientNetB0(
    include_top=False,
    weights='imagenet',
    pooling='avg',
)
def image_generator(image_paths, batch_size=32):
    """Yield batches of preprocessed images, cycling over ``image_paths`` forever.

    Each image is loaded at 224x224, converted to an array and run through
    the EfficientNet ``preprocess_input`` before being stacked into a batch.
    After the last batch of a pass the generator starts again from the first
    path, so the consumer is responsible for limiting the number of batches
    it pulls.

    Args:
        image_paths: Sequence of image file paths (must support ``len`` and
            slicing).
        batch_size: Maximum number of images per batch; the final batch of a
            pass may be smaller.

    Yields:
        A 1-tuple ``(batch,)`` where ``batch`` is a NumPy array stacking the
        preprocessed images of that batch.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_images = len(image_paths)
    if num_images == 0:
        # Nothing to load: end the generator instead of spinning forever in
        # the endless loop below without ever yielding a batch.
        return

    while True:
        for start in range(0, num_images, batch_size):
            batch_images = [
                preprocess_input(
                    image.img_to_array(
                        image.load_img(path, target_size=(224, 224))
                    )
                )
                for path in image_paths[start:start + batch_size]
            ]
            # Yield inputs wrapped in a 1-tuple rather than a bare array
            # (presumably what the Keras version in use expects from a
            # generator -- TODO confirm).
            yield (np.array(batch_images),)
def get_embeddings(image_paths, batch_size=32):
    """Compute embeddings for the given images with the module-level ``base_model``.

    The images are streamed to ``base_model.predict`` in batches produced by
    ``image_generator``; the number of prediction steps is chosen so that
    every path is consumed exactly once.

    Args:
        image_paths: Sequence of image file paths.
        batch_size: Number of images per prediction batch.

    Returns:
        The array returned by ``base_model.predict``. For an empty
        ``image_paths`` an empty array with zero rows is returned without
        calling the model.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_images = len(image_paths)
    if num_images == 0:
        # With no images there are zero steps to run and the generator has
        # nothing to yield; skip predict() entirely. float32 is assumed to
        # match the model's output dtype -- TODO confirm.
        return np.zeros((0, base_model.output_shape[-1]), dtype=np.float32)

    steps = (num_images + batch_size - 1) // batch_size  # ceiling division
    gen = image_generator(image_paths, batch_size)
    return base_model.predict(gen, steps=steps, verbose=1)
# Example usage
batch_size = 32
image_dir = "./images/"
# os.listdir returns entries in arbitrary order, so sort the paths to make the
# row order of the embeddings reproducible. The extension check is
# case-insensitive so files such as "photo.JPG" are not silently skipped.
image_paths = sorted(
    os.path.join(image_dir, name)
    for name in os.listdir(image_dir)
    if name.lower().endswith(('.png', '.jpg', '.jpeg'))
)
embeddings = get_embeddings(image_paths, batch_size)
print(f"Embedding shape: {embeddings.shape}")

# Save
embeddings_file = './embeddings.h5'
with h5py.File(embeddings_file, 'w') as f:
    f.create_dataset('embeddings', data=embeddings)
    # Store the paths next to the embeddings so that row i of 'embeddings'
    # can be traced back to the image it was computed from.
    f.create_dataset('image_paths', data=image_paths, dtype=h5py.string_dtype())
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Le code utilise le modèle EfficientNetB0, qui allie rapidité et précision.