Skip to content

Instantly share code, notes, and snippets.

@komiya-atsushi
Created September 22, 2012 09:43
Show Gist options
  • Save komiya-atsushi/3765693 to your computer and use it in GitHub Desktop.
Save komiya-atsushi/3765693 to your computer and use it in GitHub Desktop.
Double Array Trie の Java 実装を Map インタフェースでラップしたもの。
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
/**
 * A read-only {@link Map} facade over a Java implementation of a Double Array Trie.
 * <p>
 * The underlying trie is the Java port published at:<br />
 * http://nlp.ist.i.kyoto-u.ac.jp/member/murawaki/misc/index.html
 * </p>
 *
 * <p>
 * Of the {@code Map} interface, only the following methods are implemented;
 * every other method throws {@link UnsupportedOperationException}.
 * <ul>
 * <li><code>{@link Map#containsKey(Object)}</code></li>
 * <li><code>{@link Map#get(Object)}</code></li>
 * <li><code>{@link Map#isEmpty()}</code></li>
 * <li><code>{@link Map#size()}</code></li>
 * </ul>
 * </p>
 *
 * <p>
 * In addition, the following trie-specific operation is provided.
 * <ul>
 * <li><code>{@link #commonPrefixSearch(String)}</code></li>
 * </ul>
 * </p>
 *
 * @param <V>
 *            type of the values stored in the map
 * @author KOMIYA Atsushi
 */
public class DartsMap<V> implements Map<String, V> {
    /**
     * Trie that resolves a key to an index into {@link #values}. It is
     * {@code null} when the map was created without any key, because no trie
     * is built in that case.
     */
    private final DoubleArrayTrie trie;

    /** Values of the map; the trie stores the list index of each key's value. */
    private final List<V> values;

    /**
     * Pairs a key with the position it had in the caller's key list, so the
     * keys can be sorted for the trie builder without losing the link to
     * their values.
     */
    private static final class KeyEntry implements Comparable<KeyEntry> {
        final String key;
        final int index;

        KeyEntry(String key, int index) {
            this.key = key;
            this.index = index;
        }

        public int compareTo(KeyEntry other) {
            return key.compareTo(other.key);
        }
    }

    /**
     * Creates a {@code DartsMap} from parallel lists of keys and values.
     * <p>
     * The keys may be given in any order: they are sorted here (by
     * {@link String#compareTo(String)}) before being handed to the trie
     * builder, which rejects keys that are not in ascending order. The sort
     * is stable, so when a key occurs more than once the value of its first
     * occurrence is the one that is looked up. The value list is copied, so
     * later changes to the caller's list do not affect the map.
     * </p>
     *
     * @param keys
     *            list of key strings from which the trie is built
     * @param values
     *            values of the keys, in the same order as {@code keys}; must
     *            have the same size as {@code keys}
     * @return the created map
     * @throws IllegalArgumentException
     *             if an argument is {@code null}, the sizes differ, or
     *             {@code keys} contains a {@code null} element
     * @throws IllegalStateException
     *             if the trie builder reports an error
     */
    public static <V> DartsMap<V> create(List<String> keys, List<V> values) {
        if (keys == null) {
            throw new IllegalArgumentException("keys に null を指定することはできません。");
        }
        if (values == null) {
            throw new IllegalArgumentException("values に null を指定することはできません。");
        }
        if (keys.size() != values.size()) {
            throw new IllegalArgumentException(
                    "keys と values の長さは同じである必要があります。");
        }

        int count = keys.size();
        List<V> valueCopy = new ArrayList<V>(values);
        if (count == 0) {
            // Nothing to index: skip the trie entirely, lookups short-circuit.
            return new DartsMap<V>(null, valueCopy);
        }

        List<KeyEntry> entries = new ArrayList<KeyEntry>(count);
        for (int i = 0; i < count; i++) {
            String key = keys.get(i);
            if (key == null) {
                throw new IllegalArgumentException(
                        "keys に null 要素を含めることはできません。");
            }
            entries.add(new KeyEntry(key, i));
        }
        // The trie builder requires ascending keys; each entry remembers its
        // original index, so the value list itself needs no reordering.
        Collections.sort(entries);

        char[][] charsArray = new char[count][];
        int[] keyLengths = new int[count];
        int[] indexes = new int[count];
        for (int i = 0; i < count; i++) {
            KeyEntry entry = entries.get(i);
            charsArray[i] = entry.key.toCharArray();
            keyLengths[i] = entry.key.length();
            indexes[i] = entry.index;
        }

        DoubleArrayTrie trie = new DoubleArrayTrie();
        int error = trie.build(charsArray, keyLengths, indexes);
        if (error != 0) {
            throw new IllegalStateException(
                    "Trie の構築に失敗しました (エラーコード: " + error + ")。");
        }
        return new DartsMap<V>(trie, valueCopy);
    }

    private DartsMap(DoubleArrayTrie trie, List<V> values) {
        this.trie = trie;
        this.values = values;
    }

    /**
     * Runs a common-prefix search: collects the values of all keys of this
     * map that are a prefix of the given string.
     *
     * @param key
     *            the string whose prefixes are looked up
     * @return the values of the matching keys, in the order the trie reports
     *         them; an empty list if there is no match
     * @throws NullPointerException
     *             if {@code key} is {@code null}
     */
    public List<V> commonPrefixSearch(String key) {
        char[] chars = key.toCharArray();
        if (trie == null) {
            return Collections.emptyList();
        }
        // A string of length n has n + 1 prefixes (including the empty one),
        // so the result buffer never needs to be larger than that.
        int capacity = Math.min(values.size(), chars.length + 1);
        int[] indexes = new int[capacity];
        int found = trie.commonPrefixSearch(chars, indexes, capacity, 0);
        // The trie keeps counting hits beyond the buffer; only read what fits.
        int indexCount = Math.min(found, capacity);
        if (indexCount == 0) {
            return Collections.emptyList();
        }
        List<V> result = new ArrayList<V>(indexCount);
        for (int i = 0; i < indexCount; i++) {
            result.add(values.get(indexes[i]));
        }
        return result;
    }

    @Override
    public int size() {
        return values.size();
    }

    @Override
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * Returns whether the given key is stored in the trie. Keys that are
     * {@code null} or not a {@code String} are reported as absent.
     */
    @Override
    public boolean containsKey(Object key) {
        return indexOf(key) >= 0;
    }

    /**
     * Returns the value of the given key, or {@code null} if the key is not
     * stored in the trie (which includes {@code null} and non-{@code String}
     * keys).
     */
    @Override
    public V get(Object key) {
        int index = indexOf(key);
        if (index < 0) {
            return null;
        }
        return values.get(index);
    }

    /**
     * Resolves a key to the index of its value.
     *
     * @return the index into {@link #values}, or a negative number if the
     *         key is not present
     */
    private int indexOf(Object key) {
        if (trie == null || !(key instanceof String)) {
            return -1;
        }
        return trie.exactMatchSearch(((String) key).toCharArray(), 0);
    }

    // The methods below are not supported.

    @Override
    public V remove(Object key) {
        throw createUnsupportedOperationException();
    }

    @Override
    public void clear() {
        throw createUnsupportedOperationException();
    }

    @Override
    public Set<String> keySet() {
        throw createUnsupportedOperationException();
    }

    @Override
    public Collection<V> values() {
        throw createUnsupportedOperationException();
    }

    @Override
    public boolean containsValue(Object value) {
        throw createUnsupportedOperationException();
    }

    @Override
    public Set<Entry<String, V>> entrySet() {
        throw createUnsupportedOperationException();
    }

    @Override
    public V put(String key, V value) {
        throw createUnsupportedOperationException();
    }

    @Override
    public void putAll(Map<? extends String, ? extends V> m) {
        throw createUnsupportedOperationException();
    }

    /**
     * Creates the runtime exception thrown when a {@code Map} method that
     * {@code DartsMap} does not implement is called.
     * <p>
     * The message contains the name of the calling method, which is read
     * from the stack trace; this method must therefore be called directly
     * from the unsupported method.
     * </p>
     *
     * @return the exception to throw
     */
    private RuntimeException createUnsupportedOperationException() {
        StackTraceElement[] stack = new Exception().getStackTrace();
        // Element 0 is this helper, element 1 its caller. A JVM is allowed to
        // omit stack frames, so fall back to a placeholder if it did.
        String methodName = stack.length > 1 ? stack[1].getMethodName() : "?";
        String message = String.format("DartsMap#%s() は実装されていません。",
                methodName);
        return new UnsupportedOperationException(message);
    }
}
/**
 * DoubleArrayTrie: Java implementation of Darts (Double-ARray Trie System)
 *
 * <p>
 * Copyright(C) 2001-2007 Taku Kudo &lt;taku@chasen.org&gt;<br />
 * Copyright(C) 2009 MURAWAKI Yugo &lt;murawaki@nlp.kuee.kyoto-u.ac.jp&gt;
 * </p>
 *
 * <p>
 * The contents of this file may be used under the terms of either of
 * the GNU Lesser General Public License Version 2.1 or later (the
 * "LGPL"), or the BSD License (the "BSD").
 * </p>
 *
 * <p>
 * The trie is stored as an array of (base, check) units. From a node whose
 * base is {@code b}, the transition on character {@code c} leads to unit
 * {@code b + c + 1}, which belongs to that node only if its check equals
 * {@code b}. The end of a key is the child with code 0, i.e. unit {@code b}
 * itself; its base holds {@code -value - 1}.
 * </p>
 */
class DoubleArrayTrie {
    private final static int BUF_SIZE = 16384;
    private final static int UNIT_SIZE = 8; // size of int + int

    /** A trie node under construction: a character code and the key range [left, right) it covers. */
    private static class Node {
        int code;
        int depth;
        int left;
        int right;
    }

    /** One element of the double array. */
    private static class Unit {
        int base;
        int check;
    }

    private Unit array[];
    /** Non-zero at index {@code b} once {@code b} is taken as a base; only used while building. */
    private int used[];
    private int size;
    private int allocSize;
    private char key[][];
    private int keySize;
    private int length[];
    private int value[];
    private int progress;
    private int nextCheckPos;
    // boolean no_delete_;
    /** Result of the last build: 0 on success, -2 for a negative value, -3 for keys out of order. */
    int error_;

    // int (*progressfunc_) (size_t, size_t);

    /**
     * Reallocates {@link #array} and {@link #used} to {@code newSize}
     * elements, keeping the existing content and filling the new tail with
     * zeroed units.
     *
     * @return the new allocated size
     */
    private int resize(int newSize) {
        Unit[] grownArray = new Unit[newSize];
        int[] grownUsed = new int[newSize];
        int keep = Math.min(allocSize, newSize);
        if (keep > 0) {
            System.arraycopy(array, 0, grownArray, 0, keep);
            System.arraycopy(used, 0, grownUsed, 0, keep);
        }
        for (int i = keep; i < newSize; i++) {
            grownArray[i] = new Unit();
        }
        array = grownArray;
        used = grownUsed;
        allocSize = newSize;
        return allocSize;
    }

    /**
     * Collects the children of {@code parent} into {@code siblings}: one
     * node per distinct character found at position {@code parent.depth} of
     * the keys in the parent's range. A key that ends exactly at that depth
     * yields the child with code 0; any other character {@code c} yields code
     * {@code c + 1}.
     * <p>
     * Sets {@link #error_} to -3 and returns 0 if the codes are not in
     * ascending order, i.e. the keys are not sorted.
     * </p>
     *
     * @return the number of children
     */
    private int fetch(Node parent, List<Node> siblings) {
        if (error_ < 0)
            return 0;
        int prev = 0;
        for (int i = parent.left; i < parent.right; i++) {
            int keyLength = (length != null) ? length[i] : key[i].length;
            if (keyLength < parent.depth)
                continue;
            char tmp[] = key[i];
            int cur = 0;
            if (keyLength != parent.depth)
                cur = (int) tmp[parent.depth] + 1;
            if (prev > cur) {
                error_ = -3;
                return 0;
            }
            if (cur != prev || siblings.size() == 0) {
                Node tmp_node = new Node();
                tmp_node.depth = parent.depth + 1;
                tmp_node.code = cur;
                tmp_node.left = i;
                if (siblings.size() != 0)
                    siblings.get(siblings.size() - 1).right = i;
                siblings.add(tmp_node);
            }
            prev = cur;
        }
        if (siblings.size() != 0)
            siblings.get(siblings.size() - 1).right = parent.right;
        return siblings.size();
    }

    /**
     * Finds a base at which all {@code siblings} fit into free units, claims
     * those units, and recursively inserts the children of every sibling.
     *
     * @param siblings
     *            non-empty list of sibling nodes, ascending by code
     * @return the base chosen for the siblings, or 0 if an error is pending
     */
    private int insert(List<Node> siblings) {
        if (error_ < 0)
            return 0;
        int firstCode = siblings.get(0).code;
        int lastCode = siblings.get(siblings.size() - 1).code;
        int begin = 0;
        int pos = ((firstCode + 1 > nextCheckPos) ? firstCode + 1
                : nextCheckPos) - 1;
        int nonzero_num = 0;
        int first = 0;
        if (allocSize <= pos)
            resize(pos + 1);
        outer: while (true) {
            pos++;
            if (allocSize <= pos)
                resize(pos + 1);
            if (array[pos].check != 0) {
                nonzero_num++;
                continue;
            } else if (first == 0) {
                nextCheckPos = pos;
                first = 1;
            }
            begin = pos - firstCode;
            int lastIndex = begin + lastCode;
            if (allocSize <= lastIndex) {
                // progress can be zero
                double l = Math.max(1.05, 1.0 * keySize / (progress + 1));
                // With 16-bit character codes the proportional growth alone
                // can still be too small to reach the last sibling, so make
                // sure the array covers it.
                resize(Math.max((int) (allocSize * l), lastIndex + 1));
            }
            if (used[begin] != 0)
                continue;
            for (int i = 1; i < siblings.size(); i++)
                if (array[begin + siblings.get(i).code].check != 0)
                    continue outer;
            break;
        }
        // -- Simple heuristics --
        // if the percentage of non-empty contents in check between the
        // index
        // 'next_check_pos' and 'check' is greater than some constant value
        // (e.g. 0.9),
        // new 'next_check_pos' index is written by 'check'.
        if (1.0 * nonzero_num / (pos - nextCheckPos + 1) >= 0.95)
            nextCheckPos = pos;
        used[begin] = 1;
        size = Math.max(size, begin + lastCode + 1);
        for (int i = 0; i < siblings.size(); i++)
            array[begin + siblings.get(i).code].check = begin;
        for (int i = 0; i < siblings.size(); i++) {
            Node sibling = siblings.get(i);
            List<Node> new_siblings = new ArrayList<Node>();
            if (fetch(sibling, new_siblings) == 0) {
                // No children: this unit terminates a key and stores its
                // value as a negative base.
                int encoded = (value != null) ? (-value[sibling.left] - 1)
                        : (-sibling.left - 1);
                array[begin + sibling.code].base = encoded;
                if (value != null && encoded >= 0) {
                    // A negative value cannot be told apart from a base.
                    error_ = -2;
                    return 0;
                }
                progress++;
                // if (progress_func_) (*progress_func_) (progress,
                // keySize);
            } else {
                int h = insert(new_siblings);
                array[begin + sibling.code].base = h;
            }
        }
        return begin;
    }

    public DoubleArrayTrie() {
        array = null;
        used = null;
        size = 0;
        allocSize = 0;
        // no_delete_ = false;
        error_ = 0;
    }

    // no deconstructor
    // set_result omitted
    // the search methods returns (the list of) the value(s) instead
    // of (the list of) the pair(s) of value(s) and length(s)
    // set_array omitted
    // array omitted

    /** Drops the double array so that the trie is empty again. */
    void clear() {
        // if (! no_delete_)
        array = null;
        used = null;
        allocSize = 0;
        size = 0;
        // no_delete_ = false;
    }

    public int getUnitSize() {
        return UNIT_SIZE;
    }

    public int getSize() {
        return size;
    }

    public int getTotalSize() {
        return size * UNIT_SIZE;
    }

    /** Returns the number of units within the used size whose check is non-zero. */
    public int getNonzeroSize() {
        int result = 0;
        for (int i = 0; i < size; i++)
            if (array[i].check != 0)
                result++;
        return result;
    }

    /**
     * Builds the trie from all given keys; see
     * {@link #build(char[][], int[], int[], int)}.
     */
    public int build(char key[][], int length[], int value[]) {
        if (key == null)
            return 0;
        return build(key, length, value, key.length);
    }

    /**
     * Builds the trie from the first {@code _keySize} keys, replacing any
     * previously built content.
     *
     * @param _key
     *            the keys, in ascending order
     * @param _length
     *            length of each key, or {@code null} to use the array lengths
     * @param _value
     *            non-negative value of each key, or {@code null} to use the
     *            key's index as its value
     * @param _keySize
     *            number of keys to use
     * @return 0 on success (also when {@code _key} is {@code null} or shorter
     *         than {@code _keySize}, in which case nothing is built), -3 if
     *         the keys are not in ascending order, -2 if a value is negative
     */
    public int build(char _key[][], int _length[], int _value[], int _keySize) {
        // Test for null before touching _key.length.
        if (_key == null || _keySize > _key.length)
            return 0;
        // Start from a clean state so that the trie can be rebuilt.
        clear();
        error_ = 0;
        // progress_func_ = progress_func;
        key = _key;
        length = _length;
        keySize = _keySize;
        value = _value;
        progress = 0;
        resize(8192);
        array[0].base = 1;
        nextCheckPos = 0;
        Node root_node = new Node();
        root_node.left = 0;
        root_node.right = keySize;
        root_node.depth = 0;
        List<Node> siblings = new ArrayList<Node>();
        fetch(root_node, siblings);
        // Without any key there is nothing to place; the bare root already
        // answers every lookup with "not found".
        if (!siblings.isEmpty())
            insert(siblings);
        // size += (1 << 8 * 2) + 1; // ???
        // if (size >= allocSize) resize (size);
        used = null;
        return error_;
    }

    /**
     * Loads a double array that was written by {@link #save(String)}.
     *
     * @throws IOException
     *             if the file cannot be read
     */
    public void open(String fileName) throws IOException {
        File file = new File(fileName);
        // Divide before narrowing so that the cast applies to the unit count.
        size = (int) (file.length() / UNIT_SIZE);
        array = new Unit[size];
        DataInputStream is = null;
        try {
            is = new DataInputStream(new BufferedInputStream(
                    new FileInputStream(file), BUF_SIZE));
            for (int i = 0; i < array.length; i++) {
                Unit tmp = new Unit();
                tmp.base = is.readInt();
                tmp.check = is.readInt();
                array[i] = tmp;
            }
        } finally {
            if (is != null)
                is.close();
        }
    }

    /**
     * Writes the used part of the double array as pairs of (base, check)
     * ints.
     *
     * @throws IOException
     *             if the file cannot be written
     */
    public void save(String fileName) throws IOException {
        DataOutputStream out = null;
        try {
            out = new DataOutputStream(new BufferedOutputStream(
                    new FileOutputStream(fileName)));
            for (int i = 0; i < size; i++) {
                out.writeInt(array[i].base);
                out.writeInt(array[i].check);
            }
        } finally {
            if (out != null)
                out.close();
        }
    }

    /** Tells whether {@code index} addresses an existing unit. */
    private boolean inBounds(int index) {
        return array != null && index >= 0 && index < array.length;
    }

    public int exactMatchSearch(char key[], int pos) {
        return exactMatchSearch(key, pos, 0, 0);
    }

    /**
     * Looks up the key formed by {@code key[pos]} up to (excluding)
     * {@code key[len]}.
     *
     * @param len
     *            exclusive end index into {@code key}; values &lt;= 0 mean
     *            {@code key.length}
     * @param nodePos
     *            unit to start from; values &lt;= 0 mean the root
     * @return the value stored for the key, or -1 if the key is not in the
     *         trie (or the trie has not been built or loaded)
     */
    public int exactMatchSearch(char key[], int pos, int len, int nodePos) {
        if (len <= 0)
            len = key.length;
        if (nodePos <= 0)
            nodePos = 0;
        if (!inBounds(nodePos))
            return -1;
        int b = array[nodePos].base;
        for (int i = pos; i < len; i++) {
            int p = b + (int) (key[i]) + 1;
            // A character the trie has never seen may point past the end of
            // the array; that simply means there is no such transition.
            if (!inBounds(p) || b != array[p].check)
                return -1;
            b = array[p].base;
        }
        if (!inBounds(b))
            return -1;
        int n = array[b].base;
        if (b == array[b].check && n < 0)
            return -n - 1;
        return -1;
    }

    public int commonPrefixSearch(char key[], int result[], int resultLen,
            int pos) {
        return commonPrefixSearch(key, result, resultLen, pos, 0, 0);
    }

    /**
     * Finds all stored keys that are a prefix of the string formed by
     * {@code key[pos]} up to (excluding) {@code key[len]}.
     *
     * @param result
     *            receives the values of the matching keys, shortest key first
     * @param resultLen
     *            number of elements of {@code result} that may be written
     * @param len
     *            exclusive end index into {@code key}; values &lt;= 0 mean
     *            {@code key.length}
     * @param nodePos
     *            unit to start from; values &lt;= 0 mean the root
     * @return the number of matching keys, which may exceed {@code resultLen};
     *         only the first {@code resultLen} values are stored
     */
    public int commonPrefixSearch(char key[], int result[], int resultLen,
            int pos, int len, int nodePos) {
        if (len <= 0)
            len = key.length;
        if (nodePos <= 0)
            nodePos = 0;
        if (!inBounds(nodePos))
            return 0;
        int b = array[nodePos].base;
        int num = 0;
        for (int i = pos; i < len; i++) {
            num = collectValue(b, result, resultLen, num);
            int p = b + (int) (key[i]) + 1;
            // Same bounds guard as in exactMatchSearch.
            if (!inBounds(p) || b != array[p].check)
                return num;
            b = array[p].base;
        }
        return collectValue(b, result, resultLen, num);
    }

    /**
     * If a key ends at the node whose base is {@code b}, stores its value in
     * {@code result[num]} (when it fits) and counts it.
     *
     * @return the updated number of hits
     */
    private int collectValue(int b, int result[], int resultLen, int num) {
        if (!inBounds(b))
            return num;
        int n = array[b].base;
        if (b == array[b].check && n < 0) {
            if (num < resultLen)
                result[num] = -n - 1;
            num++;
        }
        return num;
    }

    // debug
    public void dump() {
        for (int i = 0; i < size; i++) {
            System.err.println("i: " + i + " [" + array[i].base + ", "
                    + array[i].check + "]");
        }
    }
}
import static org.junit.Assert.*;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
/**
 * Tests for the exact-match lookup and the common-prefix search of
 * {@code DartsMap}. Every value of the fixture is the lower-case form of its
 * key.
 */
public class DartsMapTest {
    private static final String[] KEYS = { "ALGOL", "ANSI", "ARCO", "ARPA",
            "ARPANET", "ASCII" };

    private static final String[] VALUES = lowerCased(KEYS);

    /** Returns a new array holding the lower-case form of every element. */
    private static String[] lowerCased(String[] source) {
        String[] lowered = new String[source.length];
        int i = 0;
        for (String text : source) {
            lowered[i++] = text.toLowerCase();
        }
        return lowered;
    }

    /** Builds a fresh map from the fixture keys and values. */
    private static DartsMap<String> newMap() {
        List<String> keyList = Arrays.asList(KEYS);
        List<String> valueList = Arrays.asList(VALUES);
        return DartsMap.<String> create(keyList, valueList);
    }

    /** get() with a string that is not a key. */
    @Test
    public void キーとして存在しない値を指定してgetしてみる() {
        DartsMap<String> map = newMap();

        assertNull(map.get("APPARE"));
    }

    /** get() with each of the existing keys. */
    @Test
    public void キーとして存在する値を指定してgetしてみる() {
        DartsMap<String> map = newMap();

        assertEquals("arpa", map.get("ARPA"));
        assertEquals("arpanet", map.get("ARPANET"));
        assertEquals("ascii", map.get("ASCII"));
        assertEquals("algol", map.get("ALGOL"));
        assertEquals("ansi", map.get("ANSI"));
        assertEquals("arco", map.get("ARCO"));
    }

    /** Common-prefix search with a string that has no key as prefix. */
    @Test
    public void 接頭辞として存在しない文字列で共通接頭辞検索してみる() {
        List<String> result = newMap().commonPrefixSearch("APPARE");

        assertEquals(0, result.size());
    }

    /** Common-prefix search with a string that has exactly one key as prefix. */
    @Test
    public void 接頭辞として1つ存在する文字列で共通接頭辞検索してみる() {
        List<String> result = newMap().commonPrefixSearch("ALGOLOGIC");

        assertEquals(1, result.size());
        assertEquals("algol", result.get(0));
    }

    /** Common-prefix search with a string that has two keys as prefix. */
    @Test
    public void 接頭辞として2つ存在する文字列で共通接頭辞検索してみる() {
        List<String> result = newMap().commonPrefixSearch("ARPANET, Internet");

        assertEquals(2, result.size());
        assertTrue(result.contains("arpa"));
        assertTrue(result.contains("arpanet"));
    }

    /** Common-prefix search with a string that is itself a key. */
    @Test
    public void キーに合致する文字列で共通接頭辞検索してみる() {
        List<String> result = newMap().commonPrefixSearch("ASCII");

        assertEquals(1, result.size());
        assertEquals("ascii", result.get(0));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment