Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
TextGeneratorをPython3系で動かす
# -*- coding: utf-8 -*-
u"""
マルコフ連鎖を用いて適当な文章を自動生成するファイル
"""
import os.path
import sqlite3
import random
from PrepareChain import PrepareChain
class GenerateText(object):
u"""
文章生成用クラス
"""
def __init__(self, n=5):
u"""
初期化メソッド
@param n いくつの文章を生成するか
"""
self.n = n
def generate(self):
u"""
実際に生成する
@return 生成された文章
"""
# DBが存在しないときは例外をあげる
if not os.path.exists(PrepareChain.DB_PATH):
raise IOError(u"DBファイルが存在しません")
# DBオープン
con = sqlite3.connect(PrepareChain.DB_PATH)
con.row_factory = sqlite3.Row
# 最終的にできる文章
generated_text = u""
# 指定の数だけ作成する
# for i in xrange(self.n):
for i in range(self.n):
text = self._generate_sentence(con)
generated_text += text
# DBクローズ
con.close()
return generated_text
def _generate_sentence(self, con):
u"""
ランダムに一文を生成する
@param con DBコネクション
@return 生成された1つの文章
"""
# 生成文章のリスト
morphemes = []
# はじまりを取得
first_triplet = self._get_first_triplet(con)
morphemes.append(first_triplet[1])
morphemes.append(first_triplet[2])
# 文章を紡いでいく
while morphemes[-1] != PrepareChain.END:
prefix1 = morphemes[-2]
prefix2 = morphemes[-1]
triplet = self._get_triplet(con, prefix1, prefix2)
morphemes.append(triplet[2])
# 連結
result = "".join(morphemes[:-1])
return result
def _get_chain_from_DB(self, con, prefixes):
u"""
チェーンの情報をDBから取得する
@param con DBコネクション
@param prefixes チェーンを取得するprefixの条件 tupleかlist
@return チェーンの情報の配列
"""
# ベースとなるSQL
sql = u"select prefix1, prefix2, suffix, freq from chain_freqs where prefix1 = ?"
# prefixが2つなら条件に加える
if len(prefixes) == 2:
sql += u" and prefix2 = ?"
# 結果
result = []
# DBから取得
cursor = con.execute(sql, prefixes)
for row in cursor:
result.append(dict(row))
return result
def _get_first_triplet(self, con):
u"""
文章のはじまりの3つ組をランダムに取得する
@param con DBコネクション
@return 文章のはじまりの3つ組のタプル
"""
# BEGINをprefix1としてチェーンを取得
prefixes = (PrepareChain.BEGIN,)
# チェーン情報を取得
chains = self._get_chain_from_DB(con, prefixes)
# 取得したチェーンから、確率的に1つ選ぶ
triplet = self._get_probable_triplet(chains)
return (triplet["prefix1"], triplet["prefix2"], triplet["suffix"])
def _get_triplet(self, con, prefix1, prefix2):
u"""
prefix1とprefix2からsuffixをランダムに取得する
@param con DBコネクション
@param prefix1 1つ目のprefix
@param prefix2 2つ目のprefix
@return 3つ組のタプル
"""
# BEGINをprefix1としてチェーンを取得
prefixes = (prefix1, prefix2)
# チェーン情報を取得
chains = self._get_chain_from_DB(con, prefixes)
# 取得したチェーンから、確率的に1つ選ぶ
triplet = self._get_probable_triplet(chains)
return (triplet["prefix1"], triplet["prefix2"], triplet["suffix"])
def _get_probable_triplet(self, chains):
u"""
チェーンの配列の中から確率的に1つを返す
@param chains チェーンの配列
@return 確率的に選んだ3つ組
"""
# 確率配列
probability = []
# 確率に合うように、インデックスを入れる
for (index, chain) in enumerate(chains):
# for j in xrange(chain["freq"]):
for j in range(chain["freq"]):
probability.append(index)
# ランダムに1つを選ぶ
chain_index = random.choice(probability)
return chains[chain_index]
if __name__ == '__main__':
generator = GenerateText()
print(generator.generate())
# -*- coding: utf-8 -*-
u"""
与えられた文書からマルコフ連鎖のためのチェーン(連鎖)を作成して、DBに保存するファイル
"""
import unittest
import re
import MeCab
import sqlite3
from collections import defaultdict
class PrepareChain(object):
u"""
チェーンを作成してDBに保存するクラス
"""
BEGIN = u"__BEGIN_SENTENCE__"
END = u"__END_SENTENCE__"
DB_PATH = "chain.db"
DB_SCHEMA_PATH = "schema.sql"
def __init__(self, text):
u"""
初期化メソッド
@param text チェーンを生成するための文章
"""
# if isinstance(text, str):
# text = text.decode("utf-8")
self.text = text
# 形態素解析用タガー
self.tagger = MeCab.Tagger('-Ochasen')
def make_triplet_freqs(self):
u"""
形態素解析から3つ組の出現回数まで
@return 3つ組とその出現回数の辞書 key: 3つ組(タプル) val: 出現回数
"""
# 長い文章をセンテンス毎に分割
sentences = self._divide(self.text)
# 3つ組の出現回数
triplet_freqs = defaultdict(int)
# センテンス毎に3つ組にする
for sentence in sentences:
# 形態素解析
morphemes = self._morphological_analysis(sentence)
# 3つ組をつくる
triplets = self._make_triplet(morphemes)
# 出現回数を加算
for (triplet, n) in triplets.items():
triplet_freqs[triplet] += n
return triplet_freqs
def _divide(self, text):
u"""
「。」や改行などで区切られた長い文章を一文ずつに分ける
@param text 分割前の文章
@return 一文ずつの配列
"""
# 改行文字以外の分割文字(正規表現表記)
delimiter = u"。|.|\."
# 全ての分割文字を改行文字に置換(splitしたときに「。」などの情報を無くさないため)
text = re.sub(r"({0})".format(delimiter), r"\1\n", text)
# 改行文字で分割
sentences = text.splitlines()
# 前後の空白文字を削除
sentences = [sentence.strip() for sentence in sentences]
return sentences
def _morphological_analysis(self, sentence):
u"""
一文を形態素解析する
@param sentence 一文
@return 形態素で分割された配列
"""
morphemes = []
# sentence = sentence.encode("utf-8")
node = self.tagger.parseToNode(sentence)
while node:
if node.posid != 0:
# morpheme = node.surface.decode("utf-8")
morpheme = node.surface
morphemes.append(morpheme)
node = node.next
return morphemes
def _make_triplet(self, morphemes):
u"""
形態素解析で分割された配列を、形態素毎に3つ組にしてその出現回数を数える
@param morphemes 形態素配列
@return 3つ組とその出現回数の辞書 key: 3つ組(タプル) val: 出現回数
"""
# 3つ組をつくれない場合は終える
if len(morphemes) < 3:
return {}
# 出現回数の辞書
triplet_freqs = defaultdict(int)
# 繰り返し
# for i in xrange(len(morphemes)-2):
for i in range(len(morphemes)-2):
triplet = tuple(morphemes[i:i+3])
triplet_freqs[triplet] += 1
# beginを追加
triplet = (PrepareChain.BEGIN, morphemes[0], morphemes[1])
triplet_freqs[triplet] = 1
# endを追加
triplet = (morphemes[-2], morphemes[-1], PrepareChain.END)
triplet_freqs[triplet] = 1
return triplet_freqs
def save(self, triplet_freqs, init=False):
u"""
3つ組毎に出現回数をDBに保存
@param triplet_freqs 3つ組とその出現回数の辞書 key: 3つ組(タプル) val: 出現回数
"""
# DBオープン
con = sqlite3.connect(PrepareChain.DB_PATH)
# 初期化から始める場合
if init:
# DBの初期化
with open(PrepareChain.DB_SCHEMA_PATH, "r") as f:
schema = f.read()
con.executescript(schema)
# データ整形
datas = [(triplet[0], triplet[1], triplet[2], freq) for (triplet, freq) in triplet_freqs.items()]
# データ挿入
p_statement = u"insert into chain_freqs (prefix1, prefix2, suffix, freq) values (?, ?, ?, ?)"
con.executemany(p_statement, datas)
# コミットしてクローズ
con.commit()
con.close()
def show(self, triplet_freqs):
u"""
3つ組毎の出現回数を出力する
@param triplet_freqs 3つ組とその出現回数の辞書 key: 3つ組(タプル) val: 出現回数
"""
for triplet in triplet_freqs:
print("|".join(triplet), "\t", triplet_freqs[triplet])
class TestFunctions(unittest.TestCase):
u"""
テスト用クラス
"""
def setUp(self):
u"""
テストが実行される前に実行される
"""
self.text = u"こんにちは。 今日は、楽しい運動会です。hello world.我輩は猫である\n 名前はまだない。我輩は犬である\r\n名前は決まってるよ"
self.chain = PrepareChain(self.text)
def test_make_triplet_freqs(self):
u"""
全体のテスト
"""
triplet_freqs = self.chain.make_triplet_freqs()
answer = {(u"__BEGIN_SENTENCE__", u"今日", u"は"): 1, (u"今日", u"は", u"、"): 1, (u"は", u"、", u"楽しい"): 1, (u"、", u"楽しい", u"運動会"): 1, (u"楽しい", u"運動会", u"です"): 1, (u"運動会", u"です", u"。"): 1, (u"です", u"。", u"__END_SENTENCE__"): 1, (u"__BEGIN_SENTENCE__", u"hello", u"world"): 1, (u"hello", u"world", u"."): 1, (u"world", u".", u"__END_SENTENCE__"): 1, (u"__BEGIN_SENTENCE__", u"我輩", u"は"): 2, (u"我輩", u"は", u"猫"): 1, (u"は", u"猫", u"で"): 1, (u"猫", u"で", u"ある"): 1, (u"で", u"ある", u"__END_SENTENCE__"): 2, (u"__BEGIN_SENTENCE__", u"名前", u"は"): 2, (u"名前", u"は", u"まだ"): 1, (u"は", u"まだ", u"ない"): 1, (u"まだ", u"ない", u"。"): 1, (u"ない", u"。", u"__END_SENTENCE__"): 1, (u"我輩", u"は", u"犬"): 1, (u"は", u"犬", u"で"): 1, (u"犬", u"で", u"ある"): 1, (u"名前", u"は", u"決まっ"): 1, (u"は", u"決まっ", u"てる"): 1, (u"決まっ", u"てる", u"よ"): 1, (u"てる", u"よ", u"__END_SENTENCE__"): 1}
self.assertEqual(triplet_freqs, answer)
def test_divide(self):
u"""
一文ずつに分割するテスト
"""
sentences = self.chain._divide(self.text)
answer = [u"こんにちは。", u"今日は、楽しい運動会です。", u"hello world.", u"我輩は猫である", u"名前はまだない。", u"我輩は犬である", u"名前は決まってるよ"]
self.assertEqual(sentences.sort(), answer.sort())
def test_morphological_analysis(self):
u"""
形態素解析用のテスト
"""
sentence = u"今日は、楽しい運動会です。"
morphemes = self.chain._morphological_analysis(sentence)
answer = [u"今日", u"は", u"、", u"楽しい", u"運動会", u"です", u"。"]
self.assertEqual(morphemes.sort(), answer.sort())
def test_make_triplet(self):
u"""
形態素毎に3つ組にしてその出現回数を数えるテスト
"""
morphemes = [u"今日", u"は", u"、", u"楽しい", u"運動会", u"です", u"。"]
triplet_freqs = self.chain._make_triplet(morphemes)
answer = {(u"__BEGIN_SENTENCE__", u"今日", u"は"): 1, (u"今日", u"は", u"、"): 1, (u"は", u"、", u"楽しい"): 1, (u"、", u"楽しい", u"運動会"): 1, (u"楽しい", u"運動会", u"です"): 1, (u"運動会", u"です", u"。"): 1, (u"です", u"。", u"__END_SENTENCE__"): 1}
self.assertEqual(triplet_freqs, answer)
def test_make_triplet_too_short(self):
u"""
形態素毎に3つ組にしてその出現回数を数えるテスト
ただし、形態素が少なすぎる場合
"""
morphemes = [u"こんにちは", u"。"]
triplet_freqs = self.chain._make_triplet(morphemes)
answer = {}
self.assertEqual(triplet_freqs, answer)
def test_make_triplet_3morphemes(self):
u"""
形態素毎に3つ組にしてその出現回数を数えるテスト
ただし、形態素がちょうど3つの場合
"""
morphemes = [u"hello", u"world", u"."]
triplet_freqs = self.chain._make_triplet(morphemes)
answer = {(u"__BEGIN_SENTENCE__", u"hello", u"world"): 1, (u"hello", u"world", u"."): 1, (u"world", u".", u"__END_SENTENCE__"): 1}
self.assertEqual(triplet_freqs, answer)
def tearDown(self):
u"""
テストが実行された後に実行される
"""
pass
if __name__ == '__main__':
# unittest.main()
# テキストをファイルから読み込む
ero_file = open('../ero.txt', encoding='utf-8')
text = ero_file.read()
chain = PrepareChain(text)
triplet_freqs = chain.make_triplet_freqs()
chain.save(triplet_freqs, True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment