Torch.rbを使ってみた
PyTorchのRuby版のようなもの(LibTorchのRubyバインディング)であるTorch.rbを使って、関連記事を推定できる仕組みを試しに作ってみた。
train.rb
# frozen_string_literal: true
require "torch"
require "natto"
# Stores articles, the shared vocabulary and the learned embedding weights,
# all persisted together in one Torch checkpoint file.
#
# Checkpoint layout (string keys):
#   "db"           => { post_id => { "title", "content", "tokens", "vector", "related" } }
#   "vocab"        => { word => index }
#   "embed_weight" => the value handed to #save
class DocumentRepository
  # Upper bound on the number of distinct words kept in the vocabulary.
  VOCAB_LIMIT = 50_000

  attr_reader :db, :vocab, :embed_weight

  # @param model_file [String] checkpoint path; read here and overwritten by #save
  def initialize(model_file)
    @model_file = model_file
    checkpoint = load_checkpoint(model_file)
    @db = checkpoint["db"] || {}
    @vocab = checkpoint["vocab"] || {}
    @embed_weight = checkpoint["embed_weight"]
  end

  # Tokenizes +body+, registers unseen words in the vocabulary and stores the
  # article under the next numeric id (largest existing id + 1, starting at "1").
  #
  # @param title [String]
  # @param body [String]
  # @return [String] id assigned to the new article
  def add_article(title, body)
    tokens = tokenize(body)
    update_vocab(tokens)
    new_id = ((@db.keys.map(&:to_i).max || 0) + 1).to_s
    @db[new_id] = { "title" => title, "content" => body, "tokens" => tokens }
    new_id
  end

  # Builds a symmetric matrix whose [i][j] entry is the Jaccard index of the
  # sets of articles containing vocabulary words i and j.
  #
  # Co-occurrences are counted article by article, so only word pairs that
  # actually share an article cost more than a zero check.
  #
  # @return [Array<Array<Float>>] square matrix with max(vocab.size, 1) rows;
  #   words that appear in no article keep an all-zero row and column
  def build_jaccard_matrix
    v_size = [@vocab.size, 1].max
    matrix = Array.new(v_size) { Array.new(v_size, 0.0) }
    doc_freq = Array.new(v_size, 0)

    @db.each_value do |data|
      # Distinct in-vocabulary word indexes of this article, sorted so that
      # every pair is visited as (low, high) and lands in the upper triangle.
      word_ids = data["tokens"].filter_map { |word| @vocab[word] }.uniq.sort
      word_ids.each { |idx| doc_freq[idx] += 1 }
      word_ids.combination(2) { |low, high| matrix[low][high] += 1.0 }
    end

    # Turn the raw co-occurrence counts into Jaccard scores and mirror them.
    v_size.times do |low|
      next if doc_freq[low].zero?

      matrix[low][low] = 1.0
      ((low + 1)...v_size).each do |high|
        shared = matrix[low][high]
        next if shared.zero?

        # |A ∪ B| = |A| + |B| - |A ∩ B|
        score = shared / (doc_freq[low] + doc_freq[high] - shared)
        matrix[low][high] = score
        matrix[high][low] = score
      end
    end
    matrix
  end

  # Recomputes every article's "vector" (mean of the weight rows of its
  # in-vocabulary tokens, or nil when it has none) and rebuilds the symmetric
  # "related" score table between all pairs of articles that have a vector.
  #
  # NOTE(review): the averaged vector is not re-normalized, so the score is the
  # dot product of two mean vectors; it equals a cosine similarity only if
  # those means happen to have unit length.
  #
  # @param normalized_weight [Torch::Tensor] presumably one unit-length row per
  #   vocabulary index (see the caller) -- confirm
  def update_relations(normalized_weight)
    @db.each_value do |data|
      token_ids = data["tokens"].filter_map { |word| @vocab[word] }
      data["vector"] = token_ids.empty? ? nil : normalized_weight[Torch.tensor(token_ids)].mean(0)
      data["related"] = {}
    end

    @db.keys.combination(2) do |id_a, id_b|
      vec_a = @db[id_a]["vector"]
      vec_b = @db[id_b]["vector"]
      next unless vec_a && vec_b

      score = Torch.dot(vec_a, vec_b).item
      @db[id_a]["related"][id_b] = score
      @db[id_b]["related"][id_a] = score
    end
  end

  # Writes articles, vocabulary and the given weights to the checkpoint file.
  #
  # @param final_embed_weight [Torch::Tensor] stored under "embed_weight"
  def save(final_embed_weight)
    Torch.save({ "db" => @db, "vocab" => @vocab, "embed_weight" => final_embed_weight }, @model_file)
  end

  private

  # Reads the checkpoint, falling back to an empty one when the file is missing
  # or cannot be loaded. A load failure is reported on stderr because a later
  # #save overwrites the file with whatever is in memory.
  def load_checkpoint(model_file)
    return {} unless File.exist?(model_file)

    Torch.load(model_file)
  rescue StandardError => e
    warn "Warning: failed to load #{model_file} (#{e.class}: #{e.message}); starting with an empty model"
    {}
  end

  # Runs MeCab over +body+ and returns the downcased surface forms of nouns
  # whose sub-category is 固有名詞, 一般 or サ変接続. Parsing stops at the "EOS" line.
  #
  # NOTE(review): the first regex alternative ([[:alnum:]]{4,}) is a subset of
  # the second ([^\t\n]{2,}), so any surface form of 2+ characters is accepted.
  # Confirm whether alphanumeric words were meant to need 4+ characters.
  def tokenize(body)
    tokens = []
    Natto::MeCab.new.parse(body).each_line do |line|
      break if line.strip == "EOS"

      tokens << Regexp.last_match(1).downcase if line =~ /\A([[:alnum:]]{4,}|[^\t\n]{2,})\t名詞,(?:固有名詞|一般|サ変接続),[^\n]*/
    end
    tokens
  end

  # Assigns the next free index to each unseen word until VOCAB_LIMIT is hit;
  # words arriving after that stay out of the vocabulary.
  def update_vocab(tokens)
    tokens.each do |word|
      @vocab[word] = @vocab.size if @vocab[word].nil? && @vocab.size < VOCAB_LIMIT
    end
  end
end
# Learns one DIM-dimensional embedding per vocabulary word so that the dot
# product of two row-normalized embeddings approximates a target matrix.
class EmbeddingTrainer # rubocop:disable Style/OneClassPerFile
  # Number of embedding dimensions per word.
  DIM = 32
  # Added to row norms so an all-zero row does not divide by zero.
  EPSILON = 1e-8

  # @param v_size [Integer] number of embedding rows
  # @param old_weight [Torch::Tensor, nil] previously trained weights to resume
  #   from; extended with random rows when it has fewer than +v_size+ rows
  def initialize(v_size, old_weight = nil)
    @v_size = v_size
    @weight_tensor = prepare_weight_tensor(old_weight)
    @embed = Torch::NN::Embedding.new(@v_size, DIM)
    @embed.load_state_dict({ "weight" => @weight_tensor })
  end

  # Runs +epochs+ Adam steps (lr 0.05) minimizing the mean squared error
  # between normalized * normalized^T and the target matrix.
  #
  # @param jaccard_matrix [Array<Array<Float>>] square target matrix; presumably
  #   v_size x v_size to match the prediction -- confirm against the caller
  # @param epochs [Integer] number of optimization steps
  def train(jaccard_matrix, epochs: 50)
    target_tensor = Torch.tensor(jaccard_matrix, dtype: :float32)
    optimizer = Torch::Optim::Adam.new(@embed.parameters, lr: 0.05)
    criterion = Torch::NN::MSELoss.new
    epochs.times do
      normalized = unit_rows(@embed.weight)
      pred = Torch.mm(normalized, normalized.t)
      loss = criterion.call(pred, target_tensor)
      optimizer.zero_grad
      loss.backward
      optimizer.step
    end
  end

  # @return [Torch::Tensor] raw embedding weights (the parameter's +data+)
  def embed_weight
    @embed.weight.data
  end

  # Row-normalized embedding weights for inference.
  #
  # Computed with gradient tracking disabled: the result is only used to derive
  # and persist article vectors, so it should not drag the autograd graph along.
  #
  # @return [Torch::Tensor] weights divided by their row-wise L2 norm
  def normalized_weight
    Torch.no_grad { unit_rows(@embed.weight) }
  end

  private

  # Divides every row by its L2 norm (norm over dim 1, keepdim) plus EPSILON.
  def unit_rows(weights)
    weights / (weights.norm(2, 1, true) + EPSILON)
  end

  # Picks the initial weights: small random values when nothing was trained
  # before, the old weights when they already cover +@v_size+ rows, otherwise
  # the old weights with random rows appended for the new vocabulary entries.
  def prepare_weight_tensor(old_weight)
    return Torch.randn(@v_size, DIM) * 0.1 unless old_weight

    old_v_size = old_weight.size(0)
    return old_weight if @v_size <= old_v_size

    new_weight = Torch.randn(@v_size - old_v_size, DIM) * 0.1
    Torch.cat([old_weight, new_weight], 0)
  end
end
# Usage: ruby train.rb TITLE BODY
# Adds one article, retrains the word embeddings and rewrites model.pth.
title = ARGV.fetch(0, "")
body = ARGV[1]
raise "Not found: data" unless body

repository = DocumentRepository.new("model.pth")
new_id = repository.add_article(title, body)

# The target matrix and the embedding table must have the same number of rows.
targets = repository.build_jaccard_matrix
row_count = [repository.vocab.size, 1].max

embedding_trainer = EmbeddingTrainer.new(row_count, repository.embed_weight)
embedding_trainer.train(targets, epochs: 50)

repository.update_relations(embedding_trainer.normalized_weight)
repository.save(embedding_trainer.embed_weight)

puts "最新追加ID: #{new_id}(タイトル: #{title})"
本文を形態素解析して単語を取り出し、単語同士のJaccard係数による相関行列を生成。Torchで単語の埋め込みを学習した後、記事ごとの平均ベクトルの内積(コサイン類似度に相当する値)を関連度として計算し、model.pthに保存する。
てな感じで
read.rb
train.rbで作ったmodel.pthを読み込み、指定した記事IDの関連記事を取得して表示するだけ。
#!/usr/bin/env ruby
# frozen_string_literal: true
require "torch"
# Read-only view over the "db" section of a saved checkpoint that answers
# "which posts are most related to this one?".
class Recommender
  # @param model_file [String] path of the checkpoint to read
  # @raise [RuntimeError] when +model_file+ does not exist
  def initialize(model_file)
    raise "Not found: #{model_file}" unless File.exist?(model_file)

    checkpoint = Torch.load(model_file)
    @db = checkpoint["db"] || {}
  end

  # @param id [#to_s] post id; converted to a string before lookup
  # @return [Hash] stored post data
  # @raise [RuntimeError] when no post has that id
  def find_post(id)
    @db[id.to_s] || raise("Error: Not found #{id}")
  end

  # Lists the posts with the highest relatedness score for +id+.
  # Related ids that no longer exist in the database are skipped.
  #
  # @param id [#to_s] post id
  # @param limit [Integer] maximum number of results (defaults to 10)
  # @return [Array<Hash>] entries of the form { id:, score:, title: },
  #   ordered from highest to lowest score
  # @raise [RuntimeError] when no post has that id
  def related_posts(id, limit: 10)
    related = find_post(id)["related"] || {}
    candidates = related.filter_map do |post_id, score|
      post = @db[post_id]
      next if post.nil?

      { id: post_id, score: score, title: post["title"] }
    end
    candidates.max_by(limit) { |entry| entry[:score] }.sort_by { |entry| -entry[:score] }
  end
end
# Usage: ruby read.rb ID
# Prints the requested post's title followed by its related posts.
target_id = ARGV[0]
raise "Not found: id" unless target_id

recommender = Recommender.new("model.pth")
target = recommender.find_post(target_id)
puts "ID:#{target_id} #{target['title']}\n\n"

recommender.related_posts(target_id).each do |res|
  puts "[ID:#{res[:id]}] (#{res[:score]})\t#{res[:title]}"
end
結果

という感じでTorchの機械学習を使ってこういうような事もできるよってことでw