Skip to content

Instantly share code, notes, and snippets.

@Geson-anko
Created December 16, 2023 08:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Geson-anko/eb67f7285c78b3e5f4c2d2268df50d80 to your computer and use it in GitHub Desktop.
Save Geson-anko/eb67f7285c78b3e5f4c2d2268df50d80 to your computer and use it in GitHub Desktop.
学習と推論を非同期に行う処理のモックアップです。(A mock-up of running training and inference asynchronously, in separate threads.)
import copy
import threading
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
class Inference:
    """Holder for the model that serves inference requests.

    The model can be replaced from another thread through the ``model``
    setter. Installing a model (assign, switch to eval mode, move to
    ``device``) and running a forward pass both happen while holding
    ``self.lock``, so a forward pass never sees a partially installed model.
    """

    def __init__(self, device: torch.device) -> None:
        # Device that installed models and incoming inputs are moved to.
        self.device = device
        # Reentrant lock guarding model installation and forward passes.
        self.lock = threading.RLock()
        self._model = None

    @property
    def model(self) -> nn.Module:
        """The currently installed model.

        Raises:
            RuntimeError: If no model has been attached yet.
        """
        current = self._model
        if current is None:
            raise RuntimeError("Not attached model.")
        return current

    @model.setter
    def model(self, m: nn.Module) -> None:
        with self.lock:
            # Configure inside the lock so infer() never runs a model that is
            # still in training mode or on the wrong device.
            self._model = m
            m.eval()
            m.to(self.device)

    def attach_model_from_neural_nets(self, neural_nets: dict[str, nn.Module]) -> None:
        """Install a private deep copy of ``neural_nets["encoder"]`` as the model."""
        encoder_copy = copy.deepcopy(neural_nets["encoder"])
        self.model = encoder_copy

    @torch.inference_mode()
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        """Run the installed model on ``x`` (moved to ``self.device``) without autograd."""
        with self.lock:
            net = self.model
            return net(x.to(self.device))
class AutoEncoder(nn.Module):
    """Chain an encoder and a decoder, computing ``decoder(encoder(x))``.

    Both parts are assigned as module attributes, so they are registered as
    submodules: ``parameters()`` covers both, and ``.to()`` / ``.train()`` /
    ``.eval()`` propagate to both.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction
class Trainer:
    """Train the encoder/decoder pair and publish the encoder to inference.

    Calling the instance runs one cycle:
    1. Rebuild the autoencoder and optimizer from the current ``neural_nets``.
    2. Train for one pass over the data loader.
    3. Exchange encoders with ``inference_model`` (see ``sync``).
    """

    def __init__(self, device: torch.device) -> None:
        # Both collaborators are attached after construction via their setters.
        self._inference_model = None
        self._neural_nets = None
        self.device = device

    def build(self, neural_nets: nn.ModuleDict) -> None:
        """Create the autoencoder, optimizer and data loader for one cycle.

        Args:
            neural_nets: Mapping holding ``"encoder"`` and ``"decoder"`` modules.

        This runs at the start of every cycle. ``sync`` replaces the encoder
        module, so an optimizer from a previous cycle would still reference
        parameters that now belong to the inference model. As a consequence,
        Adam's moment estimates restart from zero every cycle.
        """
        self.net = AutoEncoder(neural_nets["encoder"], neural_nets["decoder"])
        # The encoder taken back from inference may live on its device; move it here.
        self.net.to(self.device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.001)
        # Placeholder data: 16 all-zero 784-dim samples in batches of 8.
        self.data_loader = DataLoader(TensorDataset(torch.zeros(16, 28 * 28)), 8)

    @property
    def neural_nets(self) -> nn.ModuleDict:
        """The encoder/decoder modules being trained.

        Raises:
            RuntimeError: If no modules have been attached yet.
        """
        if self._neural_nets is None:
            raise RuntimeError("Neural Nets is not attached to trainer!")
        return self._neural_nets

    @neural_nets.setter
    def neural_nets(self, n: nn.ModuleDict) -> None:
        self._neural_nets = n

    @property
    def inference_model(self) -> "Inference":
        """The inference holder that receives trained encoders in ``sync``.

        Raises:
            RuntimeError: If no inference model has been attached yet.
        """
        if self._inference_model is None:
            raise RuntimeError("Inference model is not attached to trainer!")
        return self._inference_model

    @inference_model.setter
    def inference_model(self, m: "Inference") -> None:
        self._inference_model = m

    def train(self) -> None:
        """Run one pass over ``data_loader``, minimising reconstruction MSE.

        Prints the loss of the last batch. Prints nothing if the loader is empty.
        """
        # Pre-bound so an empty loader does not cause a NameError below.
        loss = None
        for data in self.data_loader:
            x = data[0].to(self.device)
            out = self.net(x)
            # F.mse_loss takes (input, target): the reconstruction is the input,
            # the original batch is the target.
            loss = F.mse_loss(out, x)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        if loss is not None:
            print("Loss:", loss.item())

    def sync(self) -> None:
        """Swap the freshly trained encoder with the one inference was using.

        Steps:
        1. The trained encoder is handed to ``inference_model`` as-is.
        2. The module inference was using until now is taken back.
        3. That module is overwritten with the trained weights and switched
           to training mode.
        4. It is installed as the trainer's new encoder.

        After the call both sides hold identical weights in different
        module objects.
        """
        trained_model = self.neural_nets["encoder"]
        trained_model.eval()
        # Weights to copy into the module taken back from inference.
        trained_param = trained_model.state_dict()
        untrained = self.inference_model.model
        # The setter swaps the model under the inference lock, so once it
        # returns, inference no longer runs `untrained` and it is safe to modify.
        self.inference_model.model = trained_model
        untrained.load_state_dict(trained_param)
        untrained.train()
        self.neural_nets["encoder"] = untrained

    def __call__(self) -> None:
        """Run one full cycle: build, train, then sync with the inference model."""
        self.build(self.neural_nets)
        self.train()
        self.sync()
class System:
    """Run inference and training concurrently on a 784 -> 8 -> 784 autoencoder.

    The inference thread serves requests with its own encoder instance while
    the training thread trains another one and periodically exchanges it
    through the trainer's sync step. Both loops run until one of these happens:
    ``stop()`` is called, either loop exits (for example by raising), or the
    wait in ``main`` is interrupted.
    """

    def __init__(self, device="cpu"):
        encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 8),
        )
        decoder = nn.Sequential(
            nn.Linear(8, 32),
            nn.ReLU(),
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
        )
        neural_nets = nn.ModuleDict(
            {
                "encoder": encoder,
                "decoder": decoder,
            }
        )
        self.inference_model = Inference(device)
        self.inference_model.attach_model_from_neural_nets(neural_nets)
        self.trainer = Trainer(device)
        self.trainer.inference_model = self.inference_model
        self.trainer.neural_nets = neural_nets
        # Set to make both worker loops return after their current iteration.
        self.stop_event = threading.Event()

    def stop(self) -> None:
        """Ask both worker loops to return after their current iteration."""
        self.stop_event.set()

    def inference_loop(self):
        """Run inference on an all-zero observation every 0.5 s until stopped."""
        while not self.stop_event.is_set():
            obs = torch.zeros(28 * 28)
            # Hold the (reentrant) inference lock across the call and the id
            # lookup. Otherwise the trainer's sync could swap the model in
            # between, and the printed id would not be the model that ran.
            with self.inference_model.lock:
                self.inference_model.infer(obs)
                model_id = id(self.inference_model.model)
            print("Infered on ", model_id)
            # Event.wait instead of time.sleep so stop() takes effect immediately.
            self.stop_event.wait(0.5)

    def training_loop(self):
        """Run one train-and-sync cycle every 2 s until stopped."""
        while not self.stop_event.is_set():
            self.trainer()
            print("Trained")
            self.stop_event.wait(2.0)

    def _run_worker(self, loop) -> None:
        """Run ``loop``; however it ends, signal the other worker to stop too."""
        try:
            loop()
        finally:
            self.stop_event.set()

    def main(self):
        """Start both loops in threads and block until both have finished.

        If waiting is interrupted (e.g. KeyboardInterrupt), both loops are
        asked to stop and are joined before the exception propagates. This
        keeps the process from hanging on still-running non-daemon threads.
        """
        inference_thread = threading.Thread(
            target=self._run_worker, args=(self.inference_loop,)
        )
        training_thread = threading.Thread(
            target=self._run_worker, args=(self.training_loop,)
        )
        inference_thread.start()
        training_thread.start()
        try:
            inference_thread.join()
            training_thread.join()
        finally:
            self.stop_event.set()
            inference_thread.join()
            training_thread.join()
if __name__ == "__main__":
    # MPS is only present on macOS builds of PyTorch with Metal support. Fall
    # back to the CPU elsewhere instead of failing at the first `.to("mps")`.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    system = System(device=device)
    system.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment