Skip to content

Instantly share code, notes, and snippets.

@trojblue
Last active March 12, 2024 20:41
Show Gist options
  • Save trojblue/85b06a132d29757b9ef2032c29ae111e to your computer and use it in GitHub Desktop.
Save trojblue/85b06a132d29757b9ef2032c29ae111e to your computer and use it in GitHub Desktop.

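用于统计 DataFrame 中若干 tag 列每行的标签数量（tag count），并将这些数量的分布绘制成重叠直方图的代码（Code that counts the tags per row in several DataFrame tag columns and plots their distributions as overlaid histograms）：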

def safe_split_tag_str(tag_str, separator=","):
    """
    Split a separator-delimited tag string into a list of clean tags.

    Whitespace around each tag is stripped, and empty entries (produced by
    doubled or trailing separators) are dropped.

    Args:
        tag_str: The tag string. Any non-string value -- ``None``, a pandas
            missing value such as float ``NaN`` or ``pd.NA`` -- is treated as
            "no tags".
        separator: Delimiter between tags; defaults to ``","``.

    Returns:
        list[str]: The non-empty, stripped tags in their original order.

    >>> safe_split_tag_str(" a, b,,c ,")
    ['a', 'b', 'c']
    """
    # Missing cells in a DataFrame column arrive as float NaN, which is truthy,
    # so a plain `if not tag_str` guard would let them reach `.split` and raise
    # (and `not pd.NA` raises by itself). Reject non-strings up front instead.
    if not isinstance(tag_str, str):
        return []
    stripped = (tag.strip() for tag in tag_str.split(separator))
    return [tag for tag in stripped if tag]
    
def visualize(df, columns, range=(0, 100)):
    """
    Plot overlaid histograms of one or more integer-valued DataFrame columns.

    Each integer value in the half-open interval ``[start, end)`` gets its own
    unit-width bin centred on that value, so ``range=(0, 100)`` shows one bar
    per value 0, 1, ..., 99. Values outside that interval are not drawn.

    Args:
        df: DataFrame holding the columns to plot.
        columns: A single column name, or an iterable of column names.
        range: ``(start, end)`` interval of values to show. The parameter name
            shadows the builtin but is kept for caller compatibility.

    Raises:
        ValueError: If ``columns`` is empty.
    """
    # Accept a bare column name; iterating a str would plot one "column" per
    # character.
    if isinstance(columns, str):
        columns = [columns]
    columns = list(columns)
    if not columns:
        raise ValueError("visualize() needs at least one column to plot")

    start, end = range
    # Half-integer edges (start-0.5, ..., end-0.5) give exactly one bin per
    # integer value. Plain np.arange(start, end) edges would stop short of
    # `end` and merge the last two values into one closed bin.
    bins = np.arange(start, end + 1) - 0.5
    # Split the opacity across the overlaid histograms so all stay visible.
    alpha = 1 / len(columns)
    joined = ", ".join(columns)

    plt.figure(figsize=(14, 7))
    for column in columns:
        plt.hist(df[column], bins=bins, alpha=alpha, label=column)

    plt.legend()
    plt.title(f'Distribution of {len(columns)} columns: {joined}')
    # xlabel expects a string; passing the list would render as "['a', 'b']".
    plt.xlabel(joined)
    plt.ylabel('Frequency')
    plt.show()


def add_tag_counts(df):
    """
    Add per-row tag-count columns to ``df`` and return it.

    Every source column is counted with ``safe_split_tag_str``, so empty
    strings, ``None`` and stray separators (e.g. ``"a, b,"``) do not inflate
    the counts.

    Source column          -> new count column
        general_tags_all   -> adjusted_tag_count
        wd_general         -> wd_tag_count
        mld_tags           -> mld_tag_count
        tags               -> pixiv_tag_count
        tag_string_general -> danbooru_tag_count

    ``df`` is modified in place. The same object is returned so the call can
    be reassigned or chained.

    Raises:
        KeyError: If a source column is missing from ``df``.
    """
    # Previously `general_tags_all` used a raw `x.split(',')`. That counts an
    # empty string as 1 tag, counts trailing or doubled separators, and
    # crashes on None/NaN. Use the same helper as every other column.
    count_columns = {
        "general_tags_all": "adjusted_tag_count",
        "wd_general": "wd_tag_count",
        "mld_tags": "mld_tag_count",
        "tags": "pixiv_tag_count",
        "tag_string_general": "danbooru_tag_count",
    }
    for source, target in count_columns.items():
        df[target] = df[source].apply(lambda x: len(safe_split_tag_str(x)))

    return df


# Compare the adjusted tag counts with Danbooru's own general-tag counts,
# using only rows that actually carry Danbooru general tags. The .copy()
# keeps add_tag_counts from writing into a slice of df_sample.
vis_cols = ["adjusted_tag_count", "danbooru_tag_count"]
df_danbooru = add_tag_counts(
    df_sample.loc[df_sample["tag_string_general"].notna()].copy()
)
visualize(df_danbooru, vis_cols, range=(0, 100))


# vis_cols = ["adjusted_tag_count", "wd_tag_count", "mld_tag_count", "pixiv_tag_count", "danbooru_tag_count"]
# df_sample = add_tag_counts(df_sample)
# visualize(df_sample, vis_cols, range=(0, 100))

图: Imgur

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment