@GINK03 GINK03/seaborn_matplot.md
Last active Nov 18, 2019

seaborn-matplot

Clean and easy to read

software

imgcat

Japanese font support

Getting seaborn to render Japanese text is a hassle, and on Ubuntu this was the only method that worked: https://qiita.com/yutapon988/items/855748de65ac780c98d5
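
As a rough sketch of one common approach (not necessarily the method in the linked article): pick the first installed Japanese-capable font and pass it to `sns.set(font=...)`. The candidate font names and the helper `pick_font` are assumptions for illustration; which fonts exist depends on the machine.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this also runs without a display
from matplotlib import font_manager
import seaborn as sns

# Candidate font family names (assumptions; one of them has to be installed,
# for example through the distribution's font packages).
CANDIDATES = ["IPAexGothic", "Noto Sans CJK JP", "TakaoPGothic"]

def pick_font(candidates, available):
    """Return the first candidate present in `available`, or None."""
    available = set(available)
    for name in candidates:
        if name in available:
            return name
    return None

# Font family names that matplotlib currently knows about.
installed = [f.name for f in font_manager.fontManager.ttflist]
font = pick_font(CANDIDATES, installed)
if font is not None:
    # sns.set() writes font.family itself, so a font set through rcParams
    # beforehand would be overwritten; pass it here instead.
    sns.set(font=font, font_scale=2)
```

If `font` comes back as `None`, none of the candidates is installed and Japanese labels will likely render as empty boxes.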

violin plot

import pandas as pd
import seaborn as sns
from matplotlib import pyplot

def space_by_price_every_country():
    df = pd.read_csv('lexical_parsed.csv')
    # Collect the groups that have more than 500 rows.
    # Group by a plain column name so that each key is a scalar, not a tuple.
    countries = set()
    for country, subDf in df.groupby('country'):
        if len(subDf) > 500:
            countries.add(country)
    df = df.sort_values(by=['country'])
    # Floor area (menseki) per unit of rent (yachin).
    df['space_by_price'] = df['menseki'] / df['yachin']
    df = df[df['country'].isin(countries)]
    # Order the violins by ascending group mean.
    dfC = df[['country', 'space_by_price']].groupby(by='country').mean().reset_index().sort_values(by=['space_by_price'], ascending=True)
    sns.set(font_scale=2)
    ax = sns.violinplot(x="country", y='space_by_price', data=df, order=dfC.country)
    pyplot.ylim(0, 15)
    ax.set_xticklabels(labels=ax.get_xticklabels(), rotation=90)
    # Axis labels: "prefecture" and "m^2 / rent (10,000 yen)".
    ax.set(xlabel='都道府県', ylabel='m^2/家賃(万円)')

    ax.figure.savefig('space_by_price.png')
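
The two preparation steps in the function above (dropping small groups, then ordering the groups by their mean) can also be written without an explicit loop. This is only a sketch on made-up toy data; the helper names `keep_large_groups` and `order_by_group_mean` are hypothetical and not part of the original code.

```python
import pandas as pd

def keep_large_groups(df, key, min_rows):
    """Keep only rows whose `key` group has more than `min_rows` rows."""
    sizes = df.groupby(key)[key].transform("size")
    return df[sizes > min_rows]

def order_by_group_mean(df, key, value):
    """Return group labels sorted by ascending mean of `value`."""
    return df.groupby(key)[value].mean().sort_values().index.tolist()

# Toy data standing in for lexical_parsed.csv (values are made up).
toy = pd.DataFrame({
    "country": ["a", "a", "a", "b", "c", "c"],
    "space_by_price": [3.0, 5.0, 4.0, 9.0, 1.0, 2.0],
})
kept = keep_large_groups(toy, "country", min_rows=1)  # drops the single-row group "b"
order = order_by_group_mean(kept, "country", "space_by_price")
print(order)
```

The resulting list can be passed as `order=` to `sns.violinplot` in place of `dfC.country`.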

barplot

https://seaborn.pydata.org/generated/seaborn.barplot.html

import seaborn as sns
from matplotlib import pyplot
import pandas as pd

def feature_imps():
    # Set the theme first; font_scale applies to axes created afterwards.
    sns.set(font_scale=2)
    pyplot.figure(figsize=(30, 30))
    df = pd.read_csv('./imps.csv')
    # Drop the intercept row so that only real features are plotted.
    df = df[df['features'] != '__intercept__']
    ax = sns.barplot(x="coefs", y="features", data=df)
    # Axis labels: "importance" and "feature".
    ax.set(xlabel='重要度', ylabel='特徴量')
    ax.figure.savefig('all_imps.png')

correlation matrix

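# very heavy
import seaborn as sns
from matplotlib import pyplot as plt
%matplotlib inline

# dfTrain is assumed to be an existing DataFrame (it is not defined in this snippet).
fig, ax = plt.subplots(figsize=(15,15))
# calculate the correlation matrix on the first 1000 rows only
corr = dfTrain[:1000].corr()

# plot the heatmap
sns.heatmap(corr,
        xticklabels=corr.columns,
        yticklabels=corr.columns, ax=ax)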
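
The snippet above keeps the cost down by taking the first 1000 rows. If the rows are ordered in some way, a random sample may be more representative. A minimal sketch, assuming that only numeric columns matter; `sampled_corr` is a hypothetical helper and the toy frame merely stands in for `dfTrain`.

```python
import numpy as np
import pandas as pd

def sampled_corr(df, n_rows=1000, seed=0):
    """Correlation matrix of the numeric columns on at most `n_rows` random rows."""
    numeric = df.select_dtypes(include="number")
    if len(numeric) > n_rows:
        numeric = numeric.sample(n=n_rows, random_state=seed)
    return numeric.corr()

# Toy frame standing in for dfTrain (made-up data).
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "x": rng.normal(size=5000),
    "label": ["a"] * 5000,  # non-numeric column, dropped before corr()
})
toy["y"] = 2 * toy["x"]  # perfectly correlated with x
corr = sampled_corr(toy)
```

The returned `corr` can be handed to `sns.heatmap` exactly as in the snippet above.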

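line plot

Depending on the seaborn version, this (plt.setp) is sometimes the only way to rotate the xtick labels.

Also, after getting the fig instance from ax, call clf on it; otherwise later plots can keep being drawn on top of the same figure indefinitely.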

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

df = pd.read_csv('result.csv')
df['pd'] = pd.to_datetime(df.pd)
df['ad'] = pd.to_datetime(df.ad)

objs = []
for TH in [12]:
    for pd_, df_sub in df.groupby(by='pd'):
        if len(df_sub) <= 30:
            continue
        print(pd_)  # the group key, not the pandas module
        df_sub = df_sub.copy()  # avoid writing into a view of df
        df_sub['access_num'] = df_sub.access_num.cumsum()
        # Ignore cumulative counts of 50 or fewer; they may be tests or similar.
        df_sub = df_sub[df_sub.access_num > 50]
        if len(df_sub) <= TH:
            continue  # not enough rows left for iloc[TH]
        print(df_sub.head(20))
        # Plotting is disabled by the string literal below; remove the quotes to enable it.
        '''
        plt.figure(figsize=(30, 30))
        ax = sns.lineplot(x='ad', y='access_num', data=df_sub)
        #ax.set_xticklabels(labels=ax.get_xticklabels(), rotation=45)
        ax.set_title(f'{pd_}')
        plt.setp(ax.get_xticklabels(), rotation=45)
        fig = ax.get_figure()
        fig.savefig(f'imgs/{pd_}.png')
        fig.clf()
        '''
        obj = {'TH':TH, 'pd':pd_, 'access_num_sum':df_sub.iloc[TH].access_num}
        objs.append(obj)

df = pd.DataFrame(objs)
df.to_csv('output.csv', index=False)