Created March 3, 2024 06:26
-
-
Save asbisen/70c7c030c092558d461beeb3f25c1cf5 to your computer and use it in GitHub Desktop.
recursive text splitter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    split_recursive(data::AbstractString; delimiters=["\\n\\n", "\\n", " ", " "], chunk_size=4096)

Split `data` into chunks of at most `chunk_size` characters.

The text is first split on `delimiters[1]`, and empty pieces are dropped. Any piece that is
still longer than `chunk_size` is split recursively with the remaining delimiters. Once the
delimiters are exhausted, oversized pieces are cut into consecutive slices of `chunk_size`
characters. Finally, consecutive chunks are greedily re-joined with a single space `" "` as
long as the joined chunk, separator included, still fits within `chunk_size`.

Lengths are measured in characters (`length`), not bytes, so multi-byte UTF-8 text is never
cut in the middle of a character. The original delimiter between merged chunks is not
preserved: merged chunks are always joined with `" "`.

# Arguments
- `data::AbstractString`: the text to split.

# Keywords
- `delimiters`: ordered separators, tried from first to last.
- `chunk_size::Integer`: maximum chunk length in characters; must be positive.

# Returns
A `Vector{String}` of chunks in input order (empty if `data` yields no non-empty pieces).

# Throws
- `ArgumentError` if `chunk_size` is not positive.
"""
function split_recursive(data::AbstractString;
                         delimiters=["\n\n", "\n", " ", " "],
                         chunk_size=4096)
    # NOTE(review): the default list contains " " twice; the second pass is a no-op on
    # already space-split pieces. A different separator may have been intended — confirm.
    chunk_size > 0 ||
        throw(ArgumentError("chunk_size must be positive, got $chunk_size"))

    # With no delimiters to split on, the only option is fixed-length slicing.
    isempty(delimiters) &&
        return _merge_chunks(_split_by_length(data, chunk_size), chunk_size)

    delim = first(delimiters)
    rest = delimiters[2:end]
    pieces = String[]
    for piece in split(data, delim; keepempty=false)
        if length(piece) <= chunk_size
            push!(pieces, String(piece))
        elseif !isempty(rest)
            # Too long: retry with the next, finer-grained delimiter.
            append!(pieces, split_recursive(piece; delimiters=rest, chunk_size=chunk_size))
        else
            # Too long and no delimiters left: hard split on character count.
            append!(pieces, _split_by_length(piece, chunk_size))
        end
    end
    return _merge_chunks(pieces, chunk_size)
end

# Cut `s` into consecutive pieces of at most `n` characters (never splitting a character).
function _split_by_length(s::AbstractString, n::Integer)
    pieces = String[]
    i, stop = firstindex(s), lastindex(s)
    while i <= stop
        # `nextind(s, i, n - 1)` is the start index of the n-th character from `i`;
        # it may run past the end, so clamp to the last valid index.
        j = min(nextind(s, i, n - 1), stop)
        push!(pieces, String(SubString(s, i, j)))
        i = nextind(s, j)
    end
    return pieces
end

# Greedily join consecutive chunks with " " while the result, separator included,
# still fits within `chunk_size` characters.
function _merge_chunks(chunks::AbstractVector{<:AbstractString}, chunk_size::Integer)
    merged = String[]
    for chunk in chunks
        if !isempty(merged) && length(merged[end]) + 1 + length(chunk) <= chunk_size
            merged[end] = string(merged[end], " ", chunk)
        else
            push!(merged, String(chunk))
        end
    end
    return merged
end
# Demo: download a public-domain text, split it, and print every chunk.
# `Base.download` is deprecated since Julia 1.6; use the Downloads stdlib directly.
import Downloads

# Download some text
Downloads.download("https://gutenberg.net.au/ebooks01/0100021.txt", "1984.txt")
data = read("1984.txt", String)

# Split the text
r = split_recursive(data, chunk_size=1024)
for (idx, d) in enumerate(r)
    println("\n\nChunk: $idx Length: $(length(d))")
    println(d)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.