Torch.rbを使ってみた


Ruby

PyTorchのRuby版にあたるTorch.rbを使って、関連記事を推測できる仕組みを試しに作ってみた

train.rb

# frozen_string_literal: true

require "torch"
require "natto"

# Keeps the article database, the shared vocabulary and the previously
# learned embedding weights, all persisted together in one Torch checkpoint.
#
# Checkpoint layout (as written by #save):
#   "db"           => { id (String) => { "title", "content", "tokens", ... } }
#                     ("vector" and "related" are added by #update_relations)
#   "vocab"        => { word (String) => index (Integer) }
#   "embed_weight" => the tensor handed to #save
class DocumentRepository
  # Upper bound on the number of distinct words that receive a vocabulary index.
  VOCAB_LIMIT = 50_000

  # Matches one parser output line whose surface form is followed by a tab and
  # a feature string starting with 名詞 and one of the listed sub-categories.
  # Capture 1 is the surface form.
  # NOTE(review): the second alternative ([^\t\n]{2,}) already accepts every
  # surface the first one accepts, so the effective rule is "2+ characters".
  NOUN_LINE = /\A([[:alnum:]]{4,}|[^\t\n]{2,})\t名詞,(?:固有名詞|一般|サ変接続),[^\n]*/
  private_constant :NOUN_LINE

  attr_reader :db, :vocab, :embed_weight

  # @param model_file [String] checkpoint path used for both loading and #save
  def initialize(model_file)
    @model_file   = model_file
    checkpoint    = load_checkpoint(model_file)
    @db           = checkpoint["db"] || {}
    @vocab        = checkpoint["vocab"] || {}
    @embed_weight = checkpoint["embed_weight"]
  end

  # Tokenizes +body+, registers unseen words in the vocabulary and stores the
  # article under a fresh id (largest existing numeric id + 1).
  #
  # @param title [String]
  # @param body [String]
  # @return [String] the id assigned to the new article
  def add_article(title, body)
    tokens = tokenize(body)
    update_vocab(tokens)

    new_id = ((@db.keys.map(&:to_i).max || 0) + 1).to_s
    @db[new_id] = { "title" => title, "content" => body, "tokens" => tokens }
    new_id
  end

  # Builds the symmetric word-by-word Jaccard matrix over the stored articles:
  # entry [i][j] is |docs(i) & docs(j)| / |docs(i) | docs(j)|, and 0.0 when
  # neither word occurs in any article.
  #
  # Only word pairs that actually share an article are visited; every other
  # off-diagonal entry keeps its 0.0 initial value.
  #
  # @return [Array<Array<Float>>] square matrix with max(vocab.size, 1) rows
  def build_jaccard_matrix
    v_size    = [vocab.size, 1].max
    doc_freq  = Array.new(v_size, 0) # number of articles containing each word
    co_counts = Hash.new(0)          # [i, j] (i < j) => articles containing both

    @db.each_value do |data|
      word_ids = data["tokens"].filter_map {|w| @vocab[w] }.uniq.sort
      word_ids.each {|i| doc_freq[i] += 1 }
      word_ids.combination(2) {|i, j| co_counts[[i, j]] += 1 }
    end

    matrix = Array.new(v_size) { Array.new(v_size, 0.0) }
    # A word that occurs at all is perfectly similar to itself.
    doc_freq.each_with_index {|freq, i| matrix[i][i] = 1.0 if freq.positive? }

    co_counts.each do |(i, j), both|
      # |A | B| = |A| + |B| - |A & B|
      jaccard = both.to_f / (doc_freq[i] + doc_freq[j] - both)
      matrix[i][j] = jaccard
      matrix[j][i] = jaccard
    end
    matrix
  end

  # Recomputes each article's "vector" (mean of the embedding rows of its
  # in-vocabulary tokens, or nil when it has none) and then fills "related"
  # with the dot product between every pair of articles that both have a vector.
  #
  # @param normalized_weight [Torch::Tensor] embedding matrix indexed by
  #   vocabulary index — presumably row-normalized by the caller; verify there
  # @return [void]
  def update_relations(normalized_weight)
    @db.each_value do |data|
      token_ids = data["tokens"].filter_map {|w| @vocab[w] }
      data["vector"] = token_ids.empty? ? nil : normalized_weight[Torch.tensor(token_ids)].mean(0)
      data["related"] = {}
    end

    # combination(2) visits each unordered pair once, in key order.
    @db.keys.combination(2) do |post_id, other_id|
      vec1 = @db[post_id]["vector"]
      vec2 = @db[other_id]["vector"]
      next unless vec1 && vec2

      score = Torch.dot(vec1, vec2).item
      @db[post_id]["related"][other_id] = score
      @db[other_id]["related"][post_id] = score
    end
  end

  # Writes db, vocab and the given weights to the checkpoint file.
  #
  # @param final_embed_weight [Torch::Tensor] weights to store as "embed_weight"
  def save(final_embed_weight)
    Torch.save({ "db" => @db, "vocab" => @vocab, "embed_weight" => final_embed_weight }, @model_file)
  end

  private

  # Loads the checkpoint, falling back to an empty one when the file is
  # missing or cannot be read. A load failure is reported on stderr instead of
  # being swallowed, because a later #save will overwrite the unreadable file.
  #
  # @return [Hash]
  def load_checkpoint(model_file)
    return {} unless File.exist?(model_file)

    Torch.load(model_file)
  rescue StandardError => e
    warn "Warning: could not load #{model_file} (#{e.class}: #{e.message}); starting from an empty model"
    {}
  end

  # Runs the text through MeCab and collects the down-cased surface form of
  # every line matching NOUN_LINE, stopping at the "EOS" marker line.
  #
  # @return [Array<String>]
  def tokenize(body)
    tokens = []
    Natto::MeCab.new.parse(body).each_line do |line|
      break if line.strip == "EOS"

      surface = line[NOUN_LINE, 1]
      tokens << surface.downcase if surface
    end
    tokens
  end

  # Assigns the next free index to each unseen word until VOCAB_LIMIT is
  # reached; words beyond the limit stay out of the vocabulary.
  def update_vocab(tokens)
    tokens.each do |word|
      @vocab[word] = @vocab.size if @vocab[word].nil? && @vocab.size < VOCAB_LIMIT
    end
  end
end

# Learns one DIM-dimensional embedding row per vocabulary word so that the
# dot product of two row-normalized rows approximates the target matrix
# handed to #train.
class EmbeddingTrainer # rubocop:disable Style/OneClassPerFile
  # Width of each embedding row.
  DIM = 32
  # Added to row norms so that an all-zero row does not divide by zero.
  EPSILON = 1e-8

  # @param v_size [Integer] number of embedding rows (vocabulary size)
  # @param old_weight [Torch::Tensor, nil] weights from a previous run; when
  #   given they are reused and extended with random rows for new words
  def initialize(v_size, old_weight = nil)
    @v_size = v_size
    @weight_tensor = prepare_weight_tensor(old_weight)
    @embed = Torch::NN::Embedding.new(@v_size, DIM)
    @embed.load_state_dict({ "weight" => @weight_tensor })
  end

  # Fits the embedding with Adam and MSE loss against +jaccard_matrix+.
  #
  # @param jaccard_matrix [Array<Array<Float>>] square target matrix —
  #   presumably v_size x v_size; verify against the caller
  # @param epochs [Integer] number of full-batch optimization steps
  # @param lr [Float] Adam learning rate
  # @return [Integer] +epochs+ (the value of Integer#times)
  def train(jaccard_matrix, epochs: 50, lr: 0.05)
    target_tensor = Torch.tensor(jaccard_matrix, dtype: :float32)
    optimizer     = Torch::Optim::Adam.new(@embed.parameters, lr: lr)
    criterion     = Torch::NN::MSELoss.new

    epochs.times do
      normalized = l2_normalize_rows(@embed.weight)
      # Pairwise dot products of the normalized rows.
      pred       = Torch.mm(normalized, normalized.t)
      loss       = criterion.call(pred, target_tensor)

      optimizer.zero_grad
      loss.backward
      optimizer.step
    end
  end

  # @return [Torch::Tensor] the raw embedding weights via Tensor#data
  def embed_weight
    @embed.weight.data
  end

  # Row-normalized copy of the embedding weights. It is computed from
  # #embed_weight rather than the live parameter so that the result (and
  # anything derived from it and later serialized) is not tied to the
  # training graph.
  #
  # @return [Torch::Tensor]
  def normalized_weight
    l2_normalize_rows(embed_weight)
  end

  private

  # Divides every row by its L2 norm (plus EPSILON).
  def l2_normalize_rows(weights)
    weights / (weights.norm(2, 1, true) + EPSILON)
  end

  # Chooses the initial weight matrix: small random values when there is no
  # previous run, the old weights as-is when they already have enough rows,
  # otherwise the old weights with small random rows appended.
  # NOTE(review): when @v_size is smaller than the old row count the old
  # tensor is returned unchanged, so its row count differs from @v_size —
  # confirm that callers never shrink the vocabulary.
  def prepare_weight_tensor(old_weight)
    return Torch.randn(@v_size, DIM) * 0.1 unless old_weight

    old_v_size = old_weight.size(0)
    return old_weight if @v_size <= old_v_size

    new_rows = Torch.randn(@v_size - old_v_size, DIM) * 0.1
    Torch.cat([old_weight, new_rows], 0)
  end
end

# Command-line entry point: ruby train.rb [TITLE] BODY
# Adds the article to model.pth, retrains the embedding and refreshes the
# stored relations between all articles.
title = ARGV.fetch(0, "")
body  = ARGV.fetch(1) { raise "Not found: data" }

repo   = DocumentRepository.new("model.pth")
new_id = repo.add_article(title, body)

# The trainer needs at least one embedding row even when the vocabulary is empty.
target_matrix = repo.build_jaccard_matrix
row_count     = [repo.vocab.size, 1].max

trainer = EmbeddingTrainer.new(row_count, repo.embed_weight)
trainer.train(target_matrix, epochs: 50)

repo.update_relations(trainer.normalized_weight)
repo.save(trainer.embed_weight)

puts "最新追加ID: #{new_id}(タイトル: #{title})"

記事本文を形態素解析して単語同士のJaccard係数行列を生成し、Torchで単語の埋め込みを学習した後、記事ベクトル同士の内積(コサイン類似度に近い値)で関連度を求めて保存する

てな感じで

read.rb

train.rbで作ったmodel.pthを読み込み、指定したIDの記事に対する関連記事を取得して表示するだけ

#!/usr/bin/env ruby
# frozen_string_literal: true

require "torch"

# Read-only view over a saved checkpoint that answers "which posts are most
# related to this one?" from the precomputed "related" scores.
class Recommender
  # Number of related posts returned when the caller does not ask for another.
  DEFAULT_LIMIT = 10

  # @param model_file [String] path of the checkpoint to read
  # @raise [RuntimeError] when the file does not exist
  def initialize(model_file)
    raise "Not found: #{model_file}" unless File.exist?(model_file)

    checkpoint = Torch.load(model_file)
    @db = checkpoint["db"] || {}
  end

  # @param id [#to_s] post id; looked up by its string form
  # @return [Hash] the stored post record
  # @raise [RuntimeError] when no post has that id
  def find_post(id)
    @db[id.to_s] or raise "Error: Not found #{id}"
  end

  # Returns the highest-scoring related posts, best first. Related ids that
  # are no longer present in the database are skipped.
  #
  # @param id [#to_s] post id
  # @param limit [Integer] maximum number of posts to return
  # @return [Array<Hash>] entries of the form { id:, score:, title: }
  # @raise [RuntimeError] when no post has that id
  def related_posts(id, limit: DEFAULT_LIMIT)
    scores = find_post(id)["related"] || {}

    candidates = scores.filter_map do |post_id, score|
      post = @db[post_id]
      next unless post

      { id: post_id, score: score, title: post["title"] }
    end

    # Pick the top entries first, then order them explicitly by descending score.
    candidates.max_by(limit) {|entry| entry[:score] }.sort_by {|entry| -entry[:score] }
  end
end

# Command-line entry point: ruby read.rb ID
# Prints the requested post's title followed by its related posts, best first.
target_id = ARGV.fetch(0) { raise "Not found: id" }

engine = Recommender.new("model.pth")
target = engine.find_post(target_id)

puts "ID:#{target_id} #{target['title']}\n\n"

engine.related_posts(target_id).each do |res|
  puts "[ID:#{res[:id]}] (#{res[:score]})\t#{res[:title]}"
end

結果

image

という感じでTorchの機械学習を使ってこういうような事もできるよってことでw

Table of Contents
train.rb read.rb 結果