Skip to content

Instantly share code, notes, and snippets.

@xiaodaigh
Created June 4, 2021 12:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xiaodaigh/5e5710d6e4233c9d2334aa7773b83b0e to your computer and use it in GitHub Desktop.
Save xiaodaigh/5e5710d6e4233c9d2334aa7773b83b0e to your computer and use it in GitHub Desktop.
Tang Dynasty poetry
using Gumbo, Cascadia, HTTP
using Serialization
# Listing pages to scrape: page 1 has no suffix, pages 2 through 16 follow the
# `_<page>_0__0` naming pattern.
urls = [
    "https://www.shicimingju.com/shicimark/tangshisanbaishou.html";
    ["https://www.shicimingju.com/shicimark/tangshisanbaishou_$(i)_0__0.html" for i in 2:16]
]
"""
    get_chars(poem)

Return the set of distinct characters that occur in any stanza of `poem`,
a vector of strings. An empty `poem` yields an empty `Set{Char}` instead of
throwing, as the previous `mapreduce` without `init` did.
"""
function get_chars(poem::AbstractVector{<:AbstractString})::Set{Char}
    chars = Set{Char}()
    for stanza in poem
        # Add in place rather than building a new Set per stanza with `union`.
        union!(chars, stanza)
    end
    return chars
end
"""
    download_poems(url, i; dir="c:/data/poems")

Fetch the page at `url` and extract the text of every `div.shici_content` element.
Strip each text and split it into stanzas on punctuation, space and newline.
Serialize the resulting vector of stanza vectors to the file `<dir>/<i>.jls`,
creating `dir` if needed. Returns `nothing`.
"""
function download_poems(url, i; dir="c:/data/poems")
    response = HTTP.get(url)
    # the body is the html content
    parsed_html = parsehtml(String(response.body))
    poems = nodeText.(collect(eachmatch(sel"div.shici_content", parsed_html.root)))
    # Split on both ASCII and full-width (CJK) punctuation. Chinese text usually uses
    # the full-width forms, and any delimiter missing here would stay inside the
    # stanzas and end up in the character vocabulary.
    delimiters = [',', '，', '。', '!', '！', ';', '；', '?', '？', ' ', '\n']
    poems_cleaned = split.(strip.(poems), Ref(delimiters))
    mkpath(dir)
    serialize("$dir/$i.jls", poems_cleaned)
    return nothing
end
# Download and serialize every listing page; file i corresponds to urls[i].
@time foreach(enumerate(urls)) do (i, url)
    download_poems(url, i)
end
"""
    get_chars_from_serialized_poems(i; dir="c:/data/poems")

Load the poems serialized to `<dir>/<i>.jls` (by `download_poems`) and return the
set of distinct characters across all of them. A file holding no poems yields an
empty `Set{Char}` instead of a reduction error.
"""
function get_chars_from_serialized_poems(i; dir="c:/data/poems")
    poems_cleaned = deserialize("$dir/$i.jls")
    chars = Set{Char}()
    for poem in poems_cleaned
        union!(chars, get_chars(poem))
    end
    return chars
end
# Sorted vocabulary of every character seen in the 16 serialized pages;
# a character's position in this vector is its integer id.
const UNIQUE_CHARS = let all_chars = mapreduce(get_chars_from_serialized_poems, union, 1:16)
    sort!(collect(all_chars))
end
serialize("UNIQUE_CHARS", UNIQUE_CHARS)
"""
    make_stanza_training(stanza, vocab=UNIQUE_CHARS)

Encode `stanza` as a `Vector{Int16}` of character ids, where each id is the
character's index in `vocab`.

# Throws
- `ArgumentError` if a character of `stanza` is not in `vocab`. Previously this
  surfaced as an opaque `MethodError` from `Int16(nothing)`.
"""
function make_stanza_training(stanza, vocab=UNIQUE_CHARS)
    chars = collect(stanza)
    ids = indexin(chars, vocab)
    bad = findfirst(isnothing, ids)
    if !isnothing(bad)
        throw(ArgumentError(
            "character $(repr(chars[bad])) in stanza $(repr(stanza)) is not in the vocabulary"))
    end
    return Int16.(ids)
end
"""
    make_poem_training(poem)

Build next-character training pairs from one poem (a vector of stanza strings).

Each stanza is encoded with `make_stanza_training`, and stanzas that encode to
nothing are dropped. Within every stanza, each character except the last is an
input and the character that follows it is the matching target. Returns
`(inputs, targets)`, the concatenations of those sequences over all stanzas.
A poem with no non-empty stanzas yields two empty `Int16` vectors instead of a
reduction error.
"""
function make_poem_training(poem)
    stanzas = filter(!isempty, map(make_stanza_training, poem))
    inputs = mapreduce(s -> s[1:end-1], vcat, stanzas; init=Int16[])
    targets = mapreduce(s -> s[2:end], vcat, stanzas; init=Int16[])
    return inputs, targets
end
"""
    make_poems_training(poems)

Apply `make_poem_training` to every poem and concatenate the results into one
`(inputs, targets)` pair. An empty `poems` yields two empty `Int16` vectors
instead of a reduction error.
"""
function make_poems_training(poems)
    per_poem = map(make_poem_training, poems)
    inputs = mapreduce(first, vcat, per_poem; init=Int16[])
    targets = mapreduce(last, vcat, per_poem; init=Int16[])
    return inputs, targets
end
"""
    make_data(i; dir="c:/data/poems")

Load the poems serialized to `<dir>/<i>.jls` and return their
`(inputs, targets)` training pair via `make_poems_training`.
"""
function make_data(i; dir="c:/data/poems")
    poems = deserialize("$dir/$i.jls")
    return make_poems_training(poems)
end
# One (inputs, targets) pair per page; concatenate them across pages and persist.
tmp = map(make_data, 1:16)
x = reduce(vcat, first.(tmp))
y = reduce(vcat, last.(tmp))
serialize("x", x)
serialize("y", y)
# Stray fragment of a training callback. As live code its unbalanced parentheses
# made the file fail to parse, so it is kept only as a comment:
# ()->println("training $(loss(xmc, ymc))"), 10))
using Serialization
using Flux
using Flux: logitbinarycrossentropy, throttle, binarycrossentropy
using CUDA
# Make scalar indexing of GPU arrays an error, so accidental element-wise
# loops over CuArrays fail loudly instead of running slowly.
CUDA.allowscalar(false)
# Training pairs from the preprocessing step: x[j] is an input character id and
# y[j] is the id of the character that follows it.
x = deserialize("x")
y = deserialize("y")
using SparseArrays
# One-hot encode: column j has a single nonzero at row x[j] (resp. y[j]).
# NOTE(review): both matrices are length(x) × length(x). The row count is the
# number of samples, not the vocabulary size (length(UNIQUE_CHARS)), so rows
# beyond the vocabulary are always zero. This also fixes the model's input and
# output width below. Confirm this is intended.
xm = sparse(x, 1:length(x), 1.0, length(x), length(x));
ym = sparse(y, 1:length(y), Int32(1), length(y), length(y));
# Densify and copy to the GPU.
# NOTE(review): `collect` materializes length(x)^2 entries, which grows
# quadratically with the number of training samples.
xmc=cu(xm |> collect)
ymc=cu(ym |> collect)
# Two stacked Dense layers, length(x) -> 32 -> length(x), moved to the GPU.
# NOTE(review): no activation is passed to either layer (Flux's default is
# identity). The stack therefore computes a single low-rank affine map; confirm
# a nonlinearity was not intended.
model = Chain(
Dense(length(x), 32),
Dense(32, length(x)),
) |> gpu
# One forward pass over the full training matrix; the result is discarded
# (presumably a compile/warm-up check).
model(xmc)
# Loss on raw logits. logitbinarycrossentropy scores each output unit as an
# independent binary target.
# NOTE(review): targets are one-hot next-character columns, and write_a_stanza
# samples from a softmax over the outputs; logitcrossentropy would match that
# usage. Confirm which is intended.
loss(xmc, ymc) = logitbinarycrossentropy(model(xmc), ymc)
# Time a single loss evaluation over the whole data set.
CUDA.@time meh = loss(xmc, ymc)
opt = ADAM()
using Flux.Data: DataLoader
# Mini-batches of 256 training columns, with shuffling enabled.
dl = DataLoader((xmc, ymc), batchsize=256, shuffle=true)
# Train for 888 epochs. At most once every 10 seconds, print the full-data loss
# on its own line (the previous `print` ran successive values together).
@time Flux.@epochs 888 Flux.train!(loss, params(model), dl, opt,
    cb = throttle(() -> println(loss(xmc, ymc)), 10))
# Save a CPU copy of the weights so the file does not contain GPU arrays and can be
# deserialized without a GPU.
serialize("model", cpu(model))
# Reload the sorted character vocabulary saved during preprocessing.
# NOTE(review): this assignment only works in a fresh session. In the same session
# as the earlier `const UNIQUE_CHARS` it would be a constant redefinition.
UNIQUE_CHARS = deserialize("UNIQUE_CHARS")
using CSV, DataFrames
# Export the vocabulary as a single-column CSV (column `ok`) for inspection.
let vocab_table = DataFrame(ok = UNIQUE_CHARS)
    CSV.write("ok.csv", vocab_table)
end
# Probe vector with the model's input width (the length of the training vector `x`)
# and a 1.0 at index 2. Its length defines `L` below.
x = zeros(Float64, length(x))
x[2] = 1.0
# CPU copy of the trained model, used by write_a_stanza for sampling.
cmodel = cpu(model)
# A stray `findmax(p)[2]` was removed here. `p` exists only inside write_a_stanza,
# so at top level it raised UndefVarError.
const L = length(x)
using StatsBase
"""
    write_a_stanza(char::Char, upto=1, jue=7)

Print a generated line that starts with `char` and then print a newline.

Each following character is sampled from the model given the previous one,
until positions `upto` through `jue` have been printed. That is
`jue - upto + 1` characters, or just `char` when `upto >= jue`. Returns
`nothing`.

Reads the globals `UNIQUE_CHARS`, `L` and `cmodel`.

# Throws
- `ArgumentError` if a character that must be extended is not in `UNIQUE_CHARS`.
"""
function write_a_stanza(char::Char, upto=1, jue=7)
    print(char)
    # Iterative rather than recursive. Testing `<` instead of `==` also terminates
    # when called with `upto > jue`, which previously recursed without end.
    while upto < jue
        char = _next_char(char)
        print(char)
        upto += 1
    end
    println()
    return nothing
end

# Sample the character that follows `char` from the model's output distribution.
function _next_char(char::Char)
    id = findfirst(==(char), UNIQUE_CHARS)
    isnothing(id) &&
        throw(ArgumentError("character $(repr(char)) is not in UNIQUE_CHARS"))
    onehot = zeros(L)
    onehot[id] = 1.0
    logits = cmodel(onehot)
    # Subtract the maximum before exponentiating. The softmax probabilities are
    # unchanged, but large logits can no longer overflow to Inf (and Inf/Inf to NaN).
    weights = exp.(logits .- maximum(logits))
    # The model's output width L can exceed the vocabulary size, and ids above
    # length(UNIQUE_CHARS) have no character, so sample only among valid ids.
    # `Weights` does not need to sum to 1.
    nvocab = min(L, length(UNIQUE_CHARS))
    next_id = sample(1:nvocab, Weights(weights[1:nvocab]))
    return UNIQUE_CHARS[next_id]
end
# Generate one line from each seed character in turn, using the default length
# (jue = 7 characters per line). The bare `write_a_stanza(char)` call that used to
# precede this was removed: no global `char` is defined in this script, so it
# raised UndefVarError.
begin
    write_a_stanza('老')
    write_a_stanza('坡')
    write_a_stanza('真')
    write_a_stanza('好')
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment