#!/usr/bin/env ruby
# Run a container within a PTY and forward its input/output via UNIX socket.
#
# Usage: $0 <title> <socket> lxc-start...
#
# `socket` is a path where a UNIX server socket will be created. This wrapper
# accepts one client on the UNIX server, accepts commands and forwards data
# to/from the wrapped process.
#
# Data read from the client are expected to be in JSON, one command on every
# line.
#
#   {
#     "keys": base64 encoded input data,
#     "rows": terminal height,
#     "cols": terminal width
#   }
#
# A command may contain only `keys`, only `rows` and `cols` (these two must be
# given together), or all three keys.
#
# Data sent to the client are in a raw form, just as the wrapped process writes
# them.
require 'base64'
require 'json'
require 'pty'
require 'socket'
require 'timeout'

# Runs a command inside a PTY and bridges it to a single client connected over
# a UNIX server socket.
#
# Output of the wrapped process is forwarded to the client verbatim. The client
# sends newline-delimited JSON commands carrying base64-encoded input (`keys`)
# and/or a new terminal size (`rows` together with `cols`).
class Wrapper
  # @param args [Array<String>] container id, socket path and then the command
  #   to run with its arguments; the first two elements are shifted off the
  #   array
  def initialize(args)
    ctid = args.shift
    @socket = args.shift
    @cmd = args
    @current_rows = 25
    @current_cols = 80
    @cmd_buf = +''

    Process.setproctitle("osctld: CT #{ctid}")
  end

  # Create the server socket, spawn the command in a PTY and serve until
  # {#stop} is called. On a regular stop the child is waited for (and killed
  # if it does not exit in time).
  #
  # The server socket is closed and its file removed even when an exception
  # interrupts the loop; the exception is then re-raised to the caller.
  def run
    @server = UNIXServer.new(socket)

    begin
      r_pty, w_pty, @pid = PTY.spawn(*@cmd)

      catch(:stop) do
        loop { work(r_pty, w_pty) }
      end
    ensure
      # A socket file left behind would make the next UNIXServer.new on the
      # same path fail, so clean up no matter how the loop ended.
      @server.close
      remove_socket
    end

    terminate(pid)
  end

  protected

  attr_reader :server, :socket, :client, :pid, :cmd_buf, :current_rows, :current_cols

  # Wait until the server, the client or the PTY becomes readable and handle
  # every ready descriptor once.
  #
  # @param r_pty [IO] read end of the PTY
  # @param w_pty [IO] write end of the PTY
  def work(r_pty, w_pty)
    rs, = IO.select([server, client, r_pty].compact)

    rs.each do |r|
      case r
      when server then accept_client # new client connects
      when client then handle_client_input(r_pty, w_pty) # command from client
      when r_pty then forward_output(r_pty) # the wrapped process has written something
      end
    end
  end

  # Accept a pending connection. A valid client replaces the current one,
  # an invalid one is closed right away.
  def accept_client
    c = server.accept

    unless client_valid?(c)
      c.close
      return
    end

    client&.close
    @client = c

    # Whatever the previous client left unfinished must not be glued to the
    # first command of the new one.
    reset_cmd_buf
  end

  # Read available data from the client and execute every complete
  # (newline-terminated) command found in the buffer.
  def handle_client_input(r_pty, w_pty)
    str = read_from_client
    stop if str.nil?

    cmd_buf << str

    while (i = cmd_buf.index("\n"))
      process_command(cmd_buf.slice!(0..i), r_pty, w_pty)
    end
  end

  # Execute one command line. Lines that are not a JSON object, or whose
  # `keys` cannot be decoded, are ignored.
  def process_command(line, r_pty, w_pty)
    cmd = parse_command(line)
    return if cmd.nil?

    if cmd[:keys]
      keys = decode_keys(cmd[:keys])
      return if keys.nil?

      w_pty.write(keys)
      w_pty.flush
    end

    resize(r_pty, cmd[:rows], cmd[:cols]) if cmd[:rows] && cmd[:cols]
  end

  # @return [Hash, nil] parsed command with symbol keys, nil when the line is
  #   not valid JSON or is a JSON value other than an object
  def parse_command(line)
    cmd = JSON.parse(line, symbolize_names: true)
    cmd.is_a?(Hash) ? cmd : nil
  rescue JSON::ParserError
    nil
  end

  # @return [String, nil] decoded input data, nil when `encoded` is not a
  #   strictly valid base64 string
  def decode_keys(encoded)
    return unless encoded.is_a?(String)

    Base64.strict_decode64(encoded)
  rescue ArgumentError
    nil
  end

  # Apply a new terminal size and notify the wrapped process with SIGWINCH.
  # Nothing happens when the size is not positive or has not changed.
  def resize(r_pty, rows, cols)
    new_rows = to_dimension(rows)
    new_cols = to_dimension(cols)

    return unless new_rows > 0 && new_cols > 0
    return if new_rows == current_rows && new_cols == current_cols

    @current_rows = new_rows
    @current_cols = new_cols

    `stty -F #{r_pty.path} rows #{current_rows} cols #{current_cols}`
    Process.kill('WINCH', pid)
  end

  # Convert a JSON value to an integer. Values that cannot be converted
  # yield 0, which the caller rejects as an invalid size.
  def to_dimension(value)
    value.respond_to?(:to_i) ? value.to_i : 0
  rescue FloatDomainError
    0
  end

  # Pass data written by the wrapped process on to the client; stop the
  # wrapper when the PTY can no longer be read.
  def forward_output(r_pty)
    buf = read_nonblock(r_pty)
    stop if buf.nil?
    send_to_client(buf)
  end

  # Leave the main loop in {#run}.
  def stop
    throw(:stop)
  end

  # Wait up to 3 seconds for the process to exit, then kill it and reap it.
  #
  # @param pid [Integer]
  def terminate(pid)
    Timeout.timeout(3) { Process.wait2(pid) }
  rescue Timeout::Error
    Process.kill('KILL', pid)
    Process.wait2(pid)
  end

  # @return [String, nil] data that was read, an empty string when nothing is
  #   available right now, nil on Errno::EIO
  def read_nonblock(io)
    io.read_nonblock(4096)
  rescue IO::WaitReadable
    ''
  rescue Errno::EIO
    nil
  end

  # Only connections from root are accepted
  def client_valid?(sock)
    cred = sock.getsockopt(Socket::SOL_SOCKET, Socket::SO_PEERCRED)
    _pid, uid, _gid = cred.unpack('LLL')
    uid.zero?
  end

  # Read from the client; a read error or EOF drops the client and yields
  # an empty string.
  def read_from_client
    read_nonblock(client)
  rescue SystemCallError, EOFError
    # osctld has crashed/exited
    drop_client
    ''
  end

  # Send data to the client, if there is one. Empty data is never sent.
  def send_to_client(data)
    client.send(data, 0) if client && !data.empty?
  rescue SystemCallError
    # osctld has crashed/exited
    drop_client
  end

  # Close the current client and forget its unfinished command.
  def drop_client
    @client.close
    @client = nil
    reset_cmd_buf
  end

  def reset_cmd_buf
    @cmd_buf = +''
  end

  # Remove the socket file, tolerating that it is already gone.
  def remove_socket
    File.unlink(socket)
  rescue Errno::ENOENT
    nil
  end
end

# Entry point: wrap the command given on the command line. Any exception,
# including non-StandardError ones, is caught and its class, message and
# backtrace are written to /tmp/dafuq instead of being propagated.
begin
  Wrapper.new(ARGV).run
rescue Exception => e # rubocop:disable Lint/RescueException
  File.open('/tmp/dafuq', 'w') do |report|
    report.puts(e.class, e.message, e.backtrace.join("\n"))
  end
end
