#!/usr/bin/env ruby
require 'json'
require 'socket'

# Runs a single `receive ...` command by relaying it over the send/receive
# control socket.
#
# Instantiating the class executes the whole command: the command line is
# validated, the control socket is opened, the selected subcommand is run and
# the socket is closed again. Usage errors, error responses and malformed or
# missing responses are reported on standard error and terminate the process
# with a failure exit status.
class SendReceive
  # Subcommands accepted after the leading `receive` keyword.
  COMMANDS = %w[skel base incremental transfer cancel].freeze

  # Path of the UNIX socket the commands are sent to.
  SOCKET_PATH = '/run/osctl/send-receive/control.sock'

  # @param pool [String] forwarded with every command as `key_pool`
  # @param key_name [String] forwarded with every command as `key_name`
  # @param cmdline [String, nil] requested command line, e.g.
  #   "receive base <token> <dataset>"; nil prints the usage and exits
  # @param connection [Array<String>] only the first element is used, as the
  #   client IP forwarded by the `skel` subcommand
  def initialize(pool, key_name, cmdline, connection)
    @key_pool = pool
    @key_name = key_name
    @client_ip, = connection
    error! unless cmdline

    keyword, command, *rest = cmdline.split
    # A missing subcommand is nil here and therefore never part of COMMANDS
    error! unless keyword == 'receive' && COMMANDS.include?(command)

    @args = rest
    connect
    # Safe dynamic dispatch: command was checked against COMMANDS above
    __send__(command)
  ensure
    # Runs on normal completion as well as when a subcommand calls exit
    @client&.close
  end

  protected

  attr_reader :key_pool, :key_name, :client_ip, :args, :client

  # receive skel [pool|- [passphrase]]
  #
  # A missing pool or `-` is forwarded as nil. Standard input is handed over
  # to the peer once it answers with "continue".
  def skel
    pool, passphrase = args
    pool = nil if pool == '-'

    send_cmd(:receive_skel, pool:, passphrase:, client_ip:, key_pool:, key_name:)
    expect_continue!
    send_stdin

    # osctld will generate a unique token for this send/receive, which we pass
    # to the sending side for further identification
    puts recv_resp!
  end

  # receive base <token> <dataset> [snapshot]
  def base
    receive_stream(:receive_base)
  end

  # receive incremental <token> <dataset> [snapshot]
  def incremental
    receive_stream(:receive_incremental)
  end

  # receive transfer <token> [start]
  #
  # `start` is sent as true only when the second argument is literally "start".
  def transfer
    send_cmd(
      :receive_transfer,
      key_pool:,
      key_name:,
      token: parse_token,
      start: args[1] == 'start'
    )
    recv_resp!
  end

  # receive cancel <token>
  def cancel
    send_cmd(:receive_cancel, key_pool:, key_name:, token: parse_token)
    recv_resp!
  end

  # Shared implementation of `base` and `incremental`, which differ only in
  # the command name: announce the stream, wait for "continue", pass standard
  # input to the peer and return the final response.
  #
  # @param cmd [Symbol] command name to send
  # @return [Object] the `response` field of the final reply
  def receive_stream(cmd)
    error! if args.count < 2

    send_cmd(
      cmd,
      key_pool:,
      key_name:,
      token: parse_token,
      dataset: args[1],
      snapshot: args[2]
    )
    expect_continue!
    send_stdin
    recv_resp!
  end

  # Open the control socket.
  def connect
    @client = UNIXSocket.new(SOCKET_PATH)
  end

  # Send one command as a single line of JSON.
  #
  # @param cmd [Symbol] command name
  # @param opts [Hash] command options
  def send_cmd(cmd, opts = {})
    client.puts({ cmd:, opts: }.to_json)
  end

  # Pass our standard input to the peer as a file descriptor.
  def send_stdin
    client.send_io($stdin)
  end

  # Read one line from the socket and decode it as a JSON object.
  #
  # Exits with failure when the connection is closed before a reply arrives
  # or when the reply is not a JSON object, so that callers never have to
  # deal with a half-read or malformed message.
  #
  # @return [Hash] decoded reply with symbolized keys
  def recv_msg
    msg = JSON.parse(client.readline, symbolize_names: true)
    return msg if msg.is_a?(Hash)

    fail_with('invalid response')
  rescue EOFError
    fail_with('connection closed unexpectedly')
  rescue JSON::ParserError
    fail_with('invalid response')
  end

  # Read a reply and return its `response` field; exits with the reply's
  # `message` when its `status` is falsy.
  def recv_resp!
    msg = recv_msg
    fail_with(msg[:message]) unless msg[:status]

    msg[:response]
  end

  # Exit with failure unless the next reply's response is "continue".
  def expect_continue!
    fail_with('invalid response') unless recv_resp! == 'continue'
  end

  # @return [String] the token given as the first subcommand argument; prints
  #   the usage and exits when it is missing
  def parse_token
    error! unless args[0]
    args[0]
  end

  # Print the command synopsis on standard error.
  def usage
    warn <<~END
      Usage:
        receive skel [pool|- [passphrase]]
        receive base <token> <dataset> [snapshot]
        receive incremental <token> <dataset> [snapshot]
        receive transfer <token> [start]
        receive cancel <token>
    END
  end

  # Print the usage and exit with failure.
  def error!
    usage
    exit(false)
  end

  # Report an error on standard error and exit with failure.
  #
  # @param message [String] text printed after the "Error: " prefix
  def fail_with(message)
    warn "Error: #{message}"
    exit(false)
  end
end

# Entry point: expects the pool and key name as arguments. The command to run
# is read from SSH_ORIGINAL_COMMAND and the client address from the first field
# of SSH_CONNECTION.
if ARGV.length != 2
  # Double quotes so that the program name is actually interpolated
  warn "Usage: #{$PROGRAM_NAME} <pool> <key name>"
  exit(false)
end

ssh_connection = ENV.fetch('SSH_CONNECTION', nil)

# Without SSH_CONNECTION there is no client address to forward; fail with a
# clear message instead of a NoMethodError on nil
unless ssh_connection
  warn 'Error: SSH_CONNECTION is not set'
  exit(false)
end

SendReceive.new(
  ARGV[0],
  ARGV[1],
  ENV.fetch('SSH_ORIGINAL_COMMAND', nil),
  ssh_connection.split
)
