Skip to content

Instantly share code, notes, and snippets.

@aferust
Created January 26, 2021 14:43
Show Gist options
  • Save aferust/55bb70359fdd3148c7e920b02907084a to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
@author: Ferhat Kurtulmuş
"""
from PyQt5.QtWidgets import (QWidget, QPushButton, QLineEdit, QFileDialog,
QInputDialog, QApplication, QMessageBox, QLabel,
QHBoxLayout, QVBoxLayout, QSpinBox, QDoubleSpinBox )
import sys, os
from os.path import dirname, abspath
from shutil import copyfile
from sklearn.model_selection import train_test_split
class TTSplit(QWidget):
    """Small Qt window that splits a folder of class subfolders into train/validation copies.

    The user picks a root folder whose immediate subfolders each hold the
    files of one class.  Pressing "create" copies every file into a sibling
    folder named ``<root>_generated`` that contains one ``<class>_train`` and
    one ``<class>_valid`` folder per class.  The partition is produced by
    scikit-learn's ``train_test_split`` using the seed and test size chosen
    in the window.
    """

    def __init__(self):
        super().__init__()
        self.initUI()

    def initUI(self):
        """Build the widgets, connect their signals and show the window."""
        self.folder = ""
        self.rndSeed = 0
        self.testRate = 0.3

        rootLayout = QVBoxLayout()

        # Row 1: folder chooser button and a line edit echoing the chosen path.
        self.folderBtn = QPushButton('Choose folder', self)
        self.folderBtn.clicked.connect(self.showDialog)
        self.le = QLineEdit(self)
        row1Layout = QHBoxLayout()
        row1Layout.addWidget(self.folderBtn)
        row1Layout.addWidget(self.le)
        rootLayout.addLayout(row1Layout)

        # Row 2: the button that triggers the split.
        self.doItBtn = QPushButton('create', self)
        self.doItBtn.clicked.connect(self.getSubfolders)
        rootLayout.addWidget(self.doItBtn)

        # Row 3: random seed.
        row3Layout = QHBoxLayout()
        l1 = QLabel("Random seed:")
        row3Layout.addWidget(l1)
        self.sp = QSpinBox()
        self.sp.setMinimum(0)
        # QSpinBox tops out at 99 unless told otherwise; allow any
        # non-negative 32-bit seed instead.
        self.sp.setMaximum(2 ** 31 - 1)
        self.sp.setValue(0)
        # The original called the getter singleStep(), which has no effect.
        self.sp.setSingleStep(1)
        self.sp.valueChanged.connect(self.seedChange)
        row3Layout.addWidget(self.sp)
        rootLayout.addLayout(row3Layout)

        # Row 4: fraction of the files that goes to the validation set.
        row4Layout = QHBoxLayout()
        l2 = QLabel("Test size:")
        row4Layout.addWidget(l2)
        self.spf = QDoubleSpinBox()
        self.spf.setMinimum(0.05)
        self.spf.setMaximum(0.5)
        self.spf.setValue(0.3)
        self.spf.setSingleStep(0.05)
        self.spf.valueChanged.connect(self.splitRateChange)
        row4Layout.addWidget(self.spf)
        rootLayout.addLayout(row4Layout)

        self.setLayout(rootLayout)
        self.setGeometry(300, 300, 450, 350)
        self.setWindowTitle('TT Split')
        self.show()

    def seedChange(self):
        """Store the seed currently shown in the seed spin box."""
        self.rndSeed = self.sp.value()

    def splitRateChange(self):
        """Store the test fraction currently shown in the test-size spin box."""
        self.testRate = self.spf.value()

    def showDialog(self):
        """Ask the user for the root folder and echo it in the line edit."""
        self.folder = str(QFileDialog.getExistingDirectory(self, "Select Directory"))
        self.le.setText(str(self.folder))

    def _showMessage(self, text, informativeText, title):
        """Show a modal information box with the given texts and window title."""
        msg = QMessageBox(self)
        msg.setIcon(QMessageBox.Information)
        msg.setText(text)
        msg.setInformativeText(informativeText)
        msg.setWindowTitle(title)
        msg.exec_()

    def _copySplit(self, paths, classOfFile, newFolderPath, suffix):
        """Copy *paths* into their ``<class><suffix>`` folders below *newFolderPath*.

        *classOfFile* maps each source path to the name of the class
        subfolder it was found under.  May raise ``OSError`` from the
        underlying directory creation or file copy.
        """
        for src in paths:
            sf = classOfFile[src]
            # Keep the file's position relative to its class folder so that
            # equally named files from different nested directories do not
            # overwrite each other.
            relative = os.path.relpath(src, os.path.join(self.folder, sf))
            dst = os.path.join(newFolderPath, sf + suffix, relative)
            os.makedirs(dirname(dst), exist_ok=True)
            copyfile(src, dst)

    def getSubfolders(self):
        """Split the files of every class subfolder and copy them to the output tree.

        Reads ``self.folder``, ``self.testRate`` and ``self.rndSeed``.  Every
        problem (no folder chosen, no subfolders, not enough files, output
        folder already present, copy failure) is reported in a message box
        and ends the operation instead of raising out of the Qt slot.
        """
        if not self.folder or not os.path.isdir(self.folder):
            self._showMessage("Unexpected input", "No valid folder selected", "Error")
            return

        try:
            # os.listdir returns entries in arbitrary order; sorting makes the
            # split reproducible for a given seed.
            subfolders = sorted(
                entry for entry in os.listdir(self.folder)
                if os.path.isdir(os.path.join(self.folder, entry))
            )
        except OSError as err:
            self._showMessage("Unexpected input", str(err), "Error")
            return

        if not subfolders:
            self._showMessage("Unexpected input", "Empty folder selected", "Error")
            return

        filesWithClasses = []
        classOfFile = {}
        for sf in subfolders:
            for root, dirs, files in os.walk(os.path.join(self.folder, sf)):
                dirs.sort()  # in-place, so os.walk descends in a stable order
                for filename in sorted(files):
                    # Join with ``root`` (not the class folder) so files in
                    # nested directories get their real path.
                    path = os.path.join(root, filename)
                    filesWithClasses.append(path)
                    classOfFile[path] = sf

        try:
            # NOTE: the split is not stratified by class.
            X_train, X_test = train_test_split(
                filesWithClasses,
                test_size=self.testRate, random_state=self.rndSeed)
        except ValueError as err:
            # Raised by scikit-learn when there are no (or too few) files.
            self._showMessage("Unexpected input", str(err), "Error")
            return

        # The output tree is created next to the chosen folder.
        eroot = abspath(self.folder)
        newFolderPath = os.path.join(dirname(eroot),
                                     os.path.basename(eroot) + "_generated")
        if os.path.exists(newFolderPath):
            self._showMessage("Folder exist",
                              "Please remove existing folder before continue",
                              "Error")
            return

        try:
            os.makedirs(newFolderPath)
            for sf in subfolders:
                os.makedirs(os.path.join(newFolderPath, sf + "_train"))
                os.makedirs(os.path.join(newFolderPath, sf + "_valid"))
            self._copySplit(X_train, classOfFile, newFolderPath, "_train")
            self._copySplit(X_test, classOfFile, newFolderPath, "_valid")
        except OSError as err:
            self._showMessage("Copy failed", str(err), "Error")
            return

        self._showMessage("Done!", "Done!", "Success")
def terminate(ex):
    """Leave the interpreter, using *ex* as the exit status."""
    # Raising SystemExit directly is exactly what sys.exit() does.
    raise SystemExit(ex)
def main():
    """Create the Qt application and the TTSplit window, then run the event loop."""
    qtApp = QApplication(sys.argv)
    # Bound to a local name so the window stays referenced while the event
    # loop is running.
    window = TTSplit()
    exitCode = qtApp.exec_()
    terminate(exitCode)
# Start the GUI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment