Last active
August 29, 2015 14:26
-
-
Save jeevanel/089e3829408963ccc617 to your computer and use it in GitHub Desktop.
Modified version of sklearn's load_files that accepts an additional parameter, 'ignore_files', listing file names to omit from loading. This is especially useful on OS X and other POSIX systems, where files whose names begin with '.' are hidden but are still returned by directory listings. For instance, on OS X load_files would otherwise load even the '.DS_Store' file, which is not desirable.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from os.path import dirname | |
from os.path import join | |
from os.path import exists | |
from os.path import expanduser | |
from os.path import isdir | |
from os import listdir | |
from os import makedirs | |
from sklearn.datasets.base import Bunch | |
from sklearn.utils import check_random_state | |
def load_files(container_path, description=None, categories=None,
               load_content=True, shuffle=True, encoding=None,
               decode_error='strict', random_state=0, ignore_files=None):
    """Load files whose category is given by the name of their subfolder.

    Variant of ``sklearn.datasets.load_files`` with an extra
    ``ignore_files`` argument listing base file names to skip in every
    category folder (for example ``['.DS_Store']`` on OS X).

    Expected layout::

        container_path/
            category_1/
                file_1
                file_2
            category_2/
                file_3

    Parameters
    ----------
    container_path : str
        Folder whose immediate subdirectories are the categories.
    description : str or None
        Free-text description stored in the returned ``DESCR`` field.
    categories : collection of str or None
        If given, only subfolders whose names are in this collection are
        loaded; otherwise every subfolder is loaded.
    load_content : bool
        If True, read each file's raw bytes into the ``data`` field.
    shuffle : bool
        If True, shuffle ``filenames`` and ``target`` with the same
        permutation.
    encoding : str or None
        If given, decode every file's bytes with this codec; otherwise
        ``data`` holds raw ``bytes``.
    decode_error : str
        ``errors`` argument forwarded to ``bytes.decode`` when ``encoding``
        is set.
    random_state : int, RandomState instance or None
        Forwarded to ``check_random_state`` to drive the shuffle.
    ignore_files : str, iterable of str, or None
        Base file names (not paths) to exclude from each category folder.
        A single string is treated as one file name. ``None`` (the default)
        loads every file.

    Returns
    -------
    Bunch
        Fields ``filenames`` (array of paths), ``target`` (array of integer
        labels indexing ``target_names``), ``target_names`` (list of folder
        names), ``DESCR`` and, when ``load_content`` is True, ``data``.
    """
    # numpy is used below but is not imported at module level.
    import numpy as np

    # Normalize the ignore list once. A set gives O(1) membership tests,
    # and wrapping a bare string stops `in` from doing substring matching.
    if ignore_files is None:
        ignored = frozenset()
    elif isinstance(ignore_files, str):
        ignored = frozenset([ignore_files])
    else:
        ignored = frozenset(ignore_files)

    target = []
    target_names = []
    filenames = []

    # Each immediate subdirectory is one category. Sorting keeps label
    # numbering stable across runs, since listdir order is arbitrary.
    folders = [f for f in sorted(listdir(container_path))
               if isdir(join(container_path, f))]

    if categories is not None:
        folders = [f for f in folders if f in categories]

    for label, folder in enumerate(folders):
        target_names.append(folder)
        folder_path = join(container_path, folder)
        # Keep every entry unless its name is explicitly ignored. When
        # ignore_files is None the set is empty, so everything is kept.
        documents = [join(folder_path, d)
                     for d in sorted(listdir(folder_path))
                     if d not in ignored]
        target.extend(len(documents) * [label])
        filenames.extend(documents)

    # convert to array for fancy indexing
    filenames = np.array(filenames)
    target = np.array(target)

    if shuffle:
        random_state = check_random_state(random_state)
        indices = np.arange(filenames.shape[0])
        random_state.shuffle(indices)
        # Apply the same permutation to both arrays so labels stay aligned.
        filenames = filenames[indices]
        target = target[indices]

    if load_content:
        data = []
        for filename in filenames:
            with open(filename, 'rb') as f:
                data.append(f.read())
        if encoding is not None:
            data = [d.decode(encoding, decode_error) for d in data]
        return Bunch(data=data,
                     filenames=filenames,
                     target_names=target_names,
                     target=target,
                     DESCR=description)

    return Bunch(filenames=filenames,
                 target_names=target_names,
                 target=target,
                 DESCR=description)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment