This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the network (create_model is defined elsewhere in the project).
model = create_model()
print("\n[INFO] Ready to train. Training is starting!\n")
# Stop once the *validation* loss has not improved for 3 consecutive epochs,
# and roll the model back to the best weights seen. Monitoring the training
# 'loss' would ignore val_generator entirely, so it could never catch
# overfitting (training loss usually keeps decreasing).
callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)
hist = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=30,  # upper bound; early stopping may end training sooner
    callbacks=[callback],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# create a model | |
def create_model(): | |
METRICS = [ | |
'accuracy', | |
tf.keras.metrics.Precision(name='precision'), | |
tf.keras.metrics.Recall(name='recall'), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# create training and validation data generators
# NOTE(review): presumably the (height, width) target size the generators
# resize images to — confirm against train_val_generators.
IMG_SIZE = (224, 224)
BATCH_SIZE = 8
def train_val_generators(train_dir, val_dir): | |
train_gen = ImageDataGenerator(rescale=1/255., | |
horizontal_flip=True, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# paths to training and validation data
# Root directory the BACH dataset was saved to — replace this placeholder
# with the real location.
BACH_DATA_DIR = "path/to/save/bach_data"
# Pass only relative components: os.path.join drops every earlier argument
# when it meets an absolute one, so "/bach-train/training" would silently
# discard BACH_DATA_DIR.
train_dir = os.path.join(BACH_DATA_DIR, "bach-train", "training")
val_dir = os.path.join(BACH_DATA_DIR, "bach-train", "validation")
print(train_dir)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import class weights function
from sklearn.utils.class_weight import compute_class_weight

# Calculate "balanced" class weights so that rarer classes in y_train get
# larger weights. scikit-learn >= 1.0 makes `classes` and `y` keyword-only;
# passing them positionally raises a TypeError.
# NOTE(review): y_train is defined elsewhere (not shown here).
unique_classes = np.unique(y_train)
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=unique_classes,
    y=y_train,
)
# Create a dictionary to map class indices to their respective weights
class_weight_dict = dict(zip(unique_classes, class_weights))
# Define and train a classifier with class weights
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
import os | |
from tensorflow.keras.preprocessing import image | |
import matplotlib.pyplot as plt | |
# Set the path to your dataset directory (placeholder — replace before running)
dataset_directory = r"Path to your folder"
# Create an ImageDataGenerator with augmentation parameters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Picsellia CUDA 11.7.1 / cuDNN 8 base image on Ubuntu 20.04.
FROM picsellia/cuda:11.7.1-cudnn8-ubuntu20.04

# Copy requirements before the source tree so the dependency layers below are
# not rebuilt on every code change.
COPY ./requirements.txt .
# NOTE(review): REBUILD_ALL looks like a cache-busting build arg — passing a
# new value via `--build-arg REBUILD_ALL=...` should force the following RUN
# layers to re-execute. Confirm this is the intent.
ARG REBUILD_ALL
RUN python3.10 -m pip install -r requirements.txt --no-cache-dir
# Same pattern, scoped to upgrading only the picsellia SDK.
ARG REBUILD_PICSELLIA
RUN python3.10 -m pip install picsellia --upgrade

WORKDIR /picsellia
COPY . ./

# NOTE(review): `run` is not installed by this Dockerfile; it is assumed to be
# a launcher shipped with the picsellia base image. If it is not, use
# ["python3.10", "train.py"] instead.
ENTRYPOINT ["run", "train.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Get the model that will receive this new version.
# NOTE(review): `client`, `experiment` and `model_name` must already be
# defined earlier in the script (not shown here).
my_model = client.get_model(
    name=model_name,
)
# Export the experiment into that existing model.
experiment.export_in_existing_model(my_model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
model_name = "detr-resnet-50_"
model_description = "Finetuned DETR model"

# The API token is mandatory: fail fast with a clear message when the
# container is started without it. RuntimeError subclasses Exception, so any
# caller catching Exception still works.
try:
    api_token = os.environ["api_token"]
except KeyError:
    raise RuntimeError("You must set an api_token to run this image") from None

# Optional platform host; falls back to the Picsellia trial instance.
host = os.environ.get("host", "https://trial.picsellia.com")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Request evaluation-metric computation for this experiment as an
# object-detection task (`experiment` and `InferenceType` come from elsewhere
# in the script).
experiment.compute_evaluations_metrics(inference_type=InferenceType.OBJECT_DETECTION)
NewerOlder