Created
February 6, 2020 12:57
-
-
Save pei223/d2738e3b646472e70613c8bb07c81e32 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Dict | |
from wordcloud import WordCloud | |
import matplotlib.pylab as plt | |
import math | |
class WordCloudGenerator:
    """Build one word cloud from word weights and render it through pyplot.

    Typical use chains the calls, since ``to_word_cloud`` returns ``self``::

        WordCloudGenerator().to_word_cloud(weights).save_img("cloud.png")
    """

    # Font used when the caller does not supply one. The path is relative,
    # so it is resolved against the process's current working directory.
    DEFAULT_FONT_PATH = "./NotoSansCJKjp-Medium.otf"

    def __init__(self, font_path: str = DEFAULT_FONT_PATH):
        """Create a generator with no word cloud yet.

        Args:
            font_path: Font file handed to ``WordCloud``. Defaults to the
                previously hard-coded ``./NotoSansCJKjp-Medium.otf``.
        """
        self.font_path = font_path
        self.word_cloud = None

    def to_word_cloud(self, word_weight_dict: dict, size=(300, 300)):
        """Generate the word cloud and store it on ``self.word_cloud``.

        Args:
            word_weight_dict: Mapping passed unchanged to
                ``WordCloud.generate_from_frequencies``.
            size: ``(width, height)`` of the generated cloud.

        Returns:
            ``self``, so rendering calls can be chained.
        """
        width, height = size[0], size[1]
        self.word_cloud = WordCloud(
            font_path=self.font_path,
            background_color='black',
            max_words=4000,
            width=width, height=height,
            random_state=0,  # fixed seed keeps the layout reproducible
        ).generate_from_frequencies(word_weight_dict)
        return self

    def _require_word_cloud(self):
        """Return the generated cloud, or fail clearly if none exists yet."""
        if self.word_cloud is None:
            raise RuntimeError(
                "No word cloud has been generated; call to_word_cloud() first."
            )
        return self.word_cloud

    def save_img(self, output_filepath):
        """Draw the cloud on the current pyplot figure and save that figure.

        Args:
            output_filepath: Destination passed to ``plt.savefig``.

        Raises:
            RuntimeError: If ``to_word_cloud`` has not been called.
        """
        # Validate before touching pyplot so a misuse leaves no stray figure.
        cloud = self._require_word_cloud()
        plt.imshow(cloud)
        plt.savefig(output_filepath)

    def show_img(self):
        """Draw the cloud on the current pyplot figure and display it.

        Raises:
            RuntimeError: If ``to_word_cloud`` has not been called.
        """
        cloud = self._require_word_cloud()
        plt.imshow(cloud)
        plt.show()
class MultipleWordCloudGenerator:
    """Build several word clouds and render them together in a grid figure.

    Typical use chains the calls, since ``to_word_clouds`` returns ``self``::

        MultipleWordCloudGenerator().to_word_clouds(dicts).save_img("grid.png")
    """

    # Font used when the caller does not supply one. The path is relative,
    # so it is resolved against the process's current working directory.
    DEFAULT_FONT_PATH = "./NotoSansCJKjp-Medium.otf"

    def __init__(self, font_path: str = DEFAULT_FONT_PATH):
        """Create a generator with an empty list of word clouds.

        Args:
            font_path: Font file handed to every ``WordCloud``. Defaults to
                the previously hard-coded ``./NotoSansCJKjp-Medium.otf``.
        """
        self.font_path = font_path
        self.word_cloud_list = []

    def to_word_clouds(self, word_weight_dict_list: List[Dict[str, float]], one_img_size=(500, 500)):
        """Generate one word cloud per weight dict, replacing earlier results.

        Args:
            word_weight_dict_list: Weight mappings; each is passed unchanged
                to ``WordCloud.generate_from_frequencies``.
            one_img_size: ``(width, height)`` of each individual cloud.

        Returns:
            ``self``, so rendering calls can be chained.
        """
        width, height = one_img_size[0], one_img_size[1]
        self.word_cloud_list = [
            WordCloud(
                font_path=self.font_path,
                background_color='black',
                max_words=4000,
                width=width, height=height,
                random_state=0,  # fixed seed keeps each layout reproducible
            ).generate_from_frequencies(word_weight_dict)
            for word_weight_dict in word_weight_dict_list
        ]
        return self

    @staticmethod
    def _grid_rows(n_clouds: int, cols: int) -> int:
        """Return how many grid rows are needed for ``n_clouds`` in ``cols`` columns.

        Raises:
            ValueError: If ``cols`` is below 1 or there is nothing to draw.
        """
        if cols < 1:
            raise ValueError("cols must be at least 1, got {!r}".format(cols))
        if n_clouds < 1:
            raise ValueError(
                "No word clouds to draw; call to_word_clouds() with at least one dict."
            )
        # Divide by the requested column count (the old code always divided
        # by 2, which under-allocated rows for cols=1 and over-allocated
        # them for cols>2).
        return int(math.ceil(n_clouds / cols))

    def _draw_grid(self, cols: int):
        """Create the grid figure, draw every cloud on it and return the figure."""
        rows = self._grid_rows(len(self.word_cloud_list), cols)
        # squeeze=False always yields a 2-D array of axes; without it a 1x1
        # grid returns a bare Axes object that has no flatten().
        fig, axs = plt.subplots(ncols=cols, nrows=rows, figsize=(16, 20), squeeze=False)
        axs = axs.flatten()
        # Turn every cell off up front so cells beyond the last cloud do not
        # show an empty frame with ticks.
        for ax in axs:
            ax.axis('off')
        for i, word_cloud in enumerate(self.word_cloud_list):
            axs[i].imshow(word_cloud.recolor(colormap='Paired_r', random_state=244), alpha=0.98)
            axs[i].set_title('Data ' + str(i + 1))
        plt.tight_layout()
        return fig

    def save_img(self, output_filepath, cols: int = 2):
        """Render all clouds in a ``cols``-wide grid and save the figure.

        Args:
            output_filepath: Destination passed to ``savefig``.
            cols: Number of grid columns.

        Raises:
            ValueError: If ``cols`` is below 1 or no clouds were generated.
        """
        fig = self._draw_grid(cols)
        # NOTE: the figure is intentionally left open, as before, so callers
        # that display it afterwards keep working; close it with plt.close()
        # when saving many grids in one process.
        fig.savefig(output_filepath)

    def show_img(self, cols: int = 2):
        """Render all clouds in a ``cols``-wide grid and display the figure.

        Args:
            cols: Number of grid columns.

        Raises:
            ValueError: If ``cols`` is below 1 or no clouds were generated.
        """
        self._draw_grid(cols)
        plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment