# Gist JuhaKiili/c5c2603c672379809b18ad7666f516bd, created June 8, 2021 10:57.
# Save it to your computer and use it in GitHub Desktop.
# Note: this file contains bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an
# editor that reveals hidden Unicode characters. Learn more about bidirectional
# Unicode characters.
"""Convert Darknet YOLOv3 weights into a TensorFlow weights checkpoint.

Valohai step ``weights``:

* builds a ``YoloV3`` model with ``weights_num_classes`` classes (default 80),
* loads the Darknet weights file supplied as the ``weights`` input (default:
  https://pjreddie.com/media/files/yolov3.weights),
* runs one forward pass on a random image as a sanity check, and
* saves the model weights to the ``model`` output as ``model.tf``.
"""
import numpy as np
import tensorflow as tf
import valohai

from yolov3_tf2.models import YoloV3
from yolov3_tf2.utils import load_darknet_weights

params = {
    "weights_num_classes": 80,
}
inputs = {
    "weights": "https://pjreddie.com/media/files/yolov3.weights",
}
valohai.prepare(step="weights", default_parameters=params, default_inputs=inputs)

# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front. TensorFlow requires this setting to be identical on every visible GPU
# (configuring only the first one fails on multi-GPU hosts), and it must be
# applied before the devices are initialized, i.e. before the model is built.
for gpu in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

yolo = YoloV3(classes=valohai.parameters("weights_num_classes").value)
# NOTE(review): the third positional argument presumably selects the "tiny"
# YOLOv3 variant (False = full model) — confirm against yolov3_tf2.utils.
load_darknet_weights(yolo, valohai.inputs("weights").path(), False)

# Sanity check: one forward pass on a random 1x320x320x3 float32 batch, so a
# model that cannot run fails the step before anything is saved. The
# predictions themselves are not used.
img = np.random.random((1, 320, 320, 3)).astype(np.float32)
yolo(img)

path = valohai.outputs("model").path("model.tf")
yolo.save_weights(path)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.