Skip to content

Instantly share code, notes, and snippets.

def resume_train(self, model):
if self.args.resume:
logger.info("resume training")
if self.args.ckpt is None:
ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth")
else:
ckpt_file = self.args.ckpt
ckpt = torch.load(ckpt_file, map_location=self.device)
# resume the model/optimizer state dict
def _cache_images(self):
logger.warning("\n********************************************************************************\n"
"You are using cached images in RAM to accelerate training.\n"
"This requires large system RAM.\n"
"Make sure you have 200G+ RAM and 136G available disk space for training COCO.\n"
"********************************************************************************\n")
max_h = self.img_size[0]
max_w = self.img_size[1]
cache_file = self.data_dir + "/img_resized_cache_" + self.name + ".array"
if not os.path.exists(cache_file):
"""
https://github.com/z-bingo/FastDVDNet/tree/master/arch
Reimplementation of 4 channel FastDVDNet in PyTorch
"""
import torch
import torch.nn as nn
import numpy as np
from thop import profile
from backtesting import Backtest, Strategy
from backtesting.lib import crossover
from FinMind.data import DataLoader
import pandas as pd
import talib
from talib import abstract
## 取得資料
import pandas as pd
from twstock import Stock
import argparse
def parse():
parser = argparse.ArgumentParser()
parser.add_argument(
"--etf_code", type=str, default="00733",
)
"""
RetinaNet model with the MobileNetV3 backbone from
Torchvision classification models.
Reference: https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py#L377-L405
"""
import torchvision
import torch
from torchvision.models.detection import RetinaNet
import numpy as np
import cv2
import matplotlib.image
def gamma_compression(image, *, gamma=2.2, eps=1e-8):
    """Convert an image from linear space to gamma (display) space.

    Applies the power-law encoding ``max(image, eps) ** (1 / gamma)``
    element-wise.

    Args:
        image: Array-like of linear intensities, processed element-wise by
            NumPy, so any shape is accepted. Values are presumably in
            [0, 1] -- TODO confirm against callers.
        gamma: Display gamma. The default 2.2 reproduces the previously
            hard-coded exponent.
        eps: Floor applied before exponentiation. The default 1e-8 matches
            the previous hard-coded value.

    Returns:
        NumPy array (or NumPy scalar for scalar input) with the gamma
        encoding applied.
    """
    # Clamp before the fractional power: a negative base would give NaN,
    # and the floor keeps zeros/negatives at a tiny positive value instead.
    return np.maximum(image, eps) ** (1.0 / gamma)
def tonemap(image):
"""Simple S-curved global tonemap"""
"""
有时我们需要载入之前训练好的模型来训练当前网络,但之前模型和现在网络结构又存在一些不同(相同就更简单了,直接载入就行)
可使用如下代码来迁移模型参数:
https://zhuanlan.zhihu.com/p/393586665
"""
def transfer_model(pretrain_file, model):
pretrain_dict = torch.load(pretrain_file)
model_dict = model.state_dict()
pretrain_dict = transfer_state_dict(pretrain_dict, model_dict)
@e96031413
e96031413 / TensorFlowISP.py
Last active June 13, 2023 07:08
Convert rawRGB to sRGB with TensorFlow
# https://github.com/mv-lab/AISP/blob/main/mai22-learnedisp/mai-learned-isp-dev.ipynb
import numpy as np
import pandas as pd
import gc
import time
from glob import glob
from tqdm import tqdm
from matplotlib import pyplot as plt
import pickle
import imageio
import tensorflow as tf
from tensorboard.backend.event_processing import event_accumulator
# Specify the path to the existing event file
original_event_file = "old_events.out.tfevents.1686288913.51ec3d9beacb.141764.0"
# Specify the path for the new event file
new_event_file = "new_events.out.tfevents.1686288913.51ec3d9beacb.141764.0"
# Load the existing event file using EventAccumulator
event_acc = event_accumulator.EventAccumulator(original_event_file)