This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from botorch.utils.sampling import manual_seed | |
import warnings | |
from ax.utils.common.typeutils import not_none | |
from ax.modelbridge.modelbridge_utils import get_pending_observation_features | |
from ax.utils.common.logger import _round_floats_for_logging, get_logger | |
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
from ax.core.generator_run import GeneratorRun | |
from ax.core.types import TParameterization | |
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE | |
# 1. 5 dropouts inside the network | |
# 2. The `raw_data` passed to `ax.complete_trial` here consists of 10 accuracies + 1 average accuracy + 1 loss; in my actual runs I report 4 accuracies, 4 F1 scores, 4 unweighted F1 scores, 3 averages (of accuracy, F1 and unweighted F1) and 1 loss instead. | |
import torch | |
import numpy as np | |
from ax.service.ax_client import AxClient | |
from ax.plot.contour import plot_contour | |
from ax.plot.trace import optimization_trace_single_method |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def decode(serialized_example, sess): | |
''' | |
Given a serialized example in which the frames are stored as | |
compressed JPG images 'frames/0001', 'frames/0002' etc., this | |
function samples SEQ_NUM_FRAMES from the frame list, decodes them from | |
JPG into a tensor and packs them to obtain a tensor of shape (N,H,W,3). | |
Returns the the tuple (frames, class_label (tf.int64) | |
:param serialized_example: serialized example from tf.data.TFRecordDataset | |
:return: tuple: (frames (tf.uint8), class_label (tf.int64) |