Created
January 26, 2021 14:43
-
-
Save aferust/55bb70359fdd3148c7e920b02907084a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
@author: Ferhat Kurtulmuş | |
""" | |
from PyQt5.QtWidgets import (QWidget, QPushButton, QLineEdit, QFileDialog, | |
QInputDialog, QApplication, QMessageBox, QLabel, | |
QHBoxLayout, QVBoxLayout, QSpinBox, QDoubleSpinBox ) | |
import sys, os | |
from os.path import dirname, abspath | |
from shutil import copyfile | |
from sklearn.model_selection import train_test_split | |
class TTSplit(QWidget):
    """Qt window that splits a class-per-subfolder dataset into train/valid copies.

    The user picks a root folder whose immediate subfolders are treated as
    classes. Every file found (recursively) under each class subfolder is
    assigned to either the train or the validation set with
    ``sklearn.model_selection.train_test_split`` and copied into a sibling
    folder named ``<root>_generated`` that contains ``<class>_train`` and
    ``<class>_valid`` directories.
    """

    def __init__(self):
        super().__init__()
        self.initUI()

    def initUI(self):
        """Create the widgets, wire up their signals and show the window."""
        self.folder = ""      # dataset root chosen by the user
        self.rndSeed = 0      # random_state passed to train_test_split
        self.testRate = 0.3   # fraction of files sent to the validation set

        rootLayout = QVBoxLayout()

        # Row 1: folder chooser button + editable path field.
        self.folderBtn = QPushButton('Choose folder', self)
        self.folderBtn.clicked.connect(self.showDialog)
        self.le = QLineEdit(self)
        row1Layout = QHBoxLayout()
        row1Layout.addWidget(self.folderBtn)
        row1Layout.addWidget(self.le)
        rootLayout.addLayout(row1Layout)

        # Row 2: the button that performs the split.
        self.doItBtn = QPushButton('create', self)
        self.doItBtn.clicked.connect(self.getSubfolders)
        rootLayout.addWidget(self.doItBtn)

        # Row 3: random seed. QSpinBox defaults to a 0..99 range, which would
        # needlessly restrict the seeds a user can choose.
        row3Layout = QHBoxLayout()
        l1 = QLabel("Random seed:")
        row3Layout.addWidget(l1)
        self.sp = QSpinBox()
        self.sp.setMinimum(0)
        self.sp.setMaximum(2**31 - 1)
        self.sp.setValue(0)
        self.sp.setSingleStep(1)
        self.sp.valueChanged.connect(self.seedChange)
        row3Layout.addWidget(self.sp)
        rootLayout.addLayout(row3Layout)

        # Row 4: validation fraction, limited to 5%..50%.
        row4Layout = QHBoxLayout()
        l2 = QLabel("Test size:")
        row4Layout.addWidget(l2)
        self.spf = QDoubleSpinBox()
        self.spf.setMinimum(0.05)
        self.spf.setMaximum(0.5)
        self.spf.setValue(0.3)
        self.spf.setSingleStep(0.05)
        self.spf.valueChanged.connect(self.splitRateChange)
        row4Layout.addWidget(self.spf)
        rootLayout.addLayout(row4Layout)

        self.setLayout(rootLayout)
        self.setGeometry(300, 300, 450, 350)
        self.setWindowTitle('TT Split')
        self.show()

    def seedChange(self):
        """Store the seed currently shown in the seed spin box."""
        self.rndSeed = self.sp.value()

    def splitRateChange(self):
        """Store the test fraction currently shown in the test-size spin box."""
        self.testRate = self.spf.value()

    def showDialog(self):
        """Let the user pick the dataset root; keep the previous choice on cancel."""
        chosen = str(QFileDialog.getExistingDirectory(self, "Select Directory"))
        if not chosen:
            # getExistingDirectory returns an empty string when cancelled.
            return
        self.folder = chosen
        self.le.setText(self.folder)

    def _showMessage(self, title, text, informativeText):
        """Show a modal information box and block until it is dismissed."""
        msg = QMessageBox()
        msg.setIcon(QMessageBox.Information)
        msg.setText(text)
        msg.setInformativeText(informativeText)
        msg.setWindowTitle(title)
        msg.exec_()

    def _collectFiles(self, subfolders):
        """Return parallel lists of file paths and their class subfolder names.

        Walks each ``self.folder/<sf>`` recursively in sorted order so that
        the same seed always produces the same split.
        """
        files, classes = [], []
        for sf in subfolders:
            for root, dirs, filenames in os.walk(os.path.join(self.folder, sf)):
                dirs.sort()  # os.walk honours in-place edits; makes order deterministic
                for filename in sorted(filenames):
                    # Join with the directory being walked (not the class
                    # folder) so files in nested directories resolve correctly.
                    files.append(os.path.join(root, filename))
                    classes.append(sf)
        return files, classes

    def _copySplit(self, files, classes, newFolderPath, suffix):
        """Copy each file into ``<newFolderPath>/<class><suffix>``, keeping sub-paths."""
        for src, sf in zip(files, classes):
            # Keep the path relative to the class folder so nested files land
            # in their class directory and same-named files do not collide.
            rel = os.path.relpath(src, os.path.join(self.folder, sf))
            dst = os.path.join(newFolderPath, sf + suffix, rel)
            os.makedirs(dirname(dst), exist_ok=True)
            copyfile(src, dst)

    def getSubfolders(self):
        """Split the files of every class subfolder and copy them into a new tree.

        The dataset root is taken from the (editable) path field. Every file
        under each immediate subfolder is split with ``self.testRate`` and
        ``self.rndSeed`` and copied to ``<root>_generated/<class>_train`` or
        ``<root>_generated/<class>_valid``. Problems are reported with a
        message box instead of raising.
        """
        # The path field is editable and mirrors the dialog choice, so it is
        # the source of truth for the folder to process.
        self.folder = self.le.text().strip()
        if not self.folder or not os.path.isdir(self.folder):
            # os.listdir("") / a missing path would raise and kill the app.
            self._showMessage("Error", "Unexpected input",
                              "Please choose an existing folder")
            return

        # Sorted because os.listdir order is arbitrary; a fixed seed should
        # give the same split on every machine.
        subfolders = sorted(
            dI for dI in os.listdir(self.folder)
            if os.path.isdir(os.path.join(self.folder, dI)))
        if not subfolders:
            self._showMessage("Error", "Unexpected input", "Empty folder selected")
            return

        files, classes = self._collectFiles(subfolders)
        if not files:
            self._showMessage("Error", "Unexpected input",
                              "No files found in the class subfolders")
            return

        eroot = abspath(self.folder)
        newFolderPath = os.path.join(dirname(eroot),
                                     os.path.basename(eroot) + "_generated")
        if os.path.exists(newFolderPath):
            self._showMessage("Error", "Folder exist",
                              "Please remove existing folder before continue")
            return

        try:
            # Labels are only carried along to know each file's class; the
            # split is not stratified.
            X_train, X_test, y_train, y_test = train_test_split(
                files, classes,
                test_size=self.testRate, random_state=self.rndSeed)
        except ValueError as err:
            # e.g. too few files for the requested test size.
            self._showMessage("Error", "Cannot split dataset", str(err))
            return

        try:
            # Create every class directory up front, even if a side is empty.
            for sf in subfolders:
                os.makedirs(os.path.join(newFolderPath, sf + "_train"))
                os.makedirs(os.path.join(newFolderPath, sf + "_valid"))
            self._copySplit(X_train, y_train, newFolderPath, "_train")
            self._copySplit(X_test, y_test, newFolderPath, "_valid")
        except OSError as err:
            self._showMessage("Error", "Copy failed", str(err))
            return

        self._showMessage("Success", "Done!", "Done!")
def terminate(ex):
    """Stop the interpreter by raising SystemExit with ``ex`` as its code.

    ``main`` passes the Qt event loop's return value here.
    """
    raise SystemExit(ex)
def main():
    """Open the TTSplit window, run the Qt event loop, then exit with its status."""
    application = QApplication(sys.argv)
    # Bind the window to a name: without a live reference it could be
    # garbage-collected and vanish while the event loop is running.
    window = TTSplit()  # noqa: F841
    status = application.exec_()
    terminate(status)


# Run the GUI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment