/////////////////////////////////////////////////////////////////////
//
// See mod9-asr.h for comments.
//

#include "mod9-asr.h"

#include <mod9-io.h>

// It should be fairly easy to drop in another json library by
// changing the next two lines and changing the various json::
// calls in the code (e.g. json::parse, json::parse_error).

#include <json.hpp>
using json = nlohmann::json;

#include <cstdio>

#include <arpa/inet.h>
#include <cerrno>
#include <netdb.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

// Local functions

static std::string clean_nlohmann_parse_exception_message(std::string msg);


// Pointer-to-implementation (pimpl) pattern that hides various details
// of the implementation. This would allow us to e.g. change to Boost::asio
// internally without changing the API.

struct mod9_asr::Connection::impl {
  // Resolve hostname/port with getaddrinfo() and connect a stream socket
  // to the first address that accepts the connection.
  //
  // Throws mod9_asr::ConnectError if the address lookup fails or if no
  // returned address can be connected to.
  impl(std::string hostname, std::string port) :
      hostname_(std::move(hostname)),
      port_(std::move(port)),
      sockfd_(-1),
      // The next read will be the first read. This is used to detect
      // if we're connected to a non-Engine host/port.
      first_read_(true) {

    // Value-initialization zeroes every field, which getaddrinfo()
    // requires for the hint fields that are not set explicitly.
    struct addrinfo hints{};
    hints.ai_family = AF_UNSPEC; // IPv4 or IPv6
    hints.ai_socktype = SOCK_STREAM;

    struct addrinfo* candidates = nullptr;
    const int err = ::getaddrinfo(hostname_.c_str(), port_.c_str(), &hints, &candidates);
    if (err != 0) {
      throw mod9_asr::ConnectError(
          std::string("Unable to get address information: ") + ::gai_strerror(err),
          hostname_, port_);
    }

    // getaddrinfo can return more than one addrinfo. Use the first
    // successful. sockfd_ is only assigned once connect() succeeded, so
    // it stays -1 if every candidate fails.
    for (const struct addrinfo* ai = candidates; ai != nullptr; ai = ai->ai_next) {
      const int fd = ::socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
      if (fd == -1) {
        // Unable to create socket with given addrinfo. Try the next.
        continue;
      }
      if (::connect(fd, ai->ai_addr, ai->ai_addrlen) != -1) {
        // Success.
        sockfd_ = fd;
        break;
      }

      // connect() failed but socket() succeeded. Need to clean up.
      ::close(fd);
    }

    // No longer need the addrinfos. Nothing below may look at the list
    // (or at pointers into it) again.
    ::freeaddrinfo(candidates);

    if (sockfd_ == -1) {
      throw mod9_asr::ConnectError("Unable to connect socket", hostname_, port_);
    }

    // Pass the socket to a mod9_asr::IO object. This allows use
    // of getline(). NOTE: Do not attempt to read from sockfd_ after
    // this call, as mod9_asr::IO does buffering and things will get
    // desynchronized. Only write to sockfd_, not to mod9_io_.
    mod9_io_.setfd(sockfd_);
  }  // impl()

  // Close the socket, if one is held.
  virtual ~impl() {
    if (sockfd_ != -1) {
      ::close(sockfd_);
    }
  }

  // Low level abort. Just shuts down the socket. The caller
  // must handle the Engine protocol.
  //
  // Returns the result of shutdown(): 0 on success, -1 on error with
  // errno set.
  int abort() {
    return ::shutdown(sockfd_, SHUT_RDWR);
  }

  // Return true if the socket has hung up (specifically, if half- or
  // full- hangup). This will occur if the server disconnected.
  //
  // Throws mod9_asr::ConnectionError if poll() itself fails.
  //
  // NOTE: This function will generally only be able to detect a
  // disconnect on Linux systems. See comment below for details.
  bool server_disconnected() {
    // Detecting a half-disconnect with POLLRDHUP is not POSIX
    // compliant. To the best of my knowledge, there's no way to do it
    // portably. If we need to support detecting half-disconnect on
    // other platforms, it will at a minimum require additional #ifdefs,
    // and it may not be possible. Since a full disconnect won't happen
    // until the client disconnects, server_disconnected() isn't too
    // useful on non-Linux.
#ifdef POLLRDHUP
    const int events = POLLRDHUP | POLLHUP;
#else
    const int events = POLLHUP;
#endif

    struct pollfd pfd{};
    pfd.fd = sockfd_;
    pfd.events = static_cast<short>(events);

    // A timeout of 0 makes this a non-blocking status query.
    if (::poll(&pfd, 1, 0) == -1) {
      throw mod9_asr::ConnectionError("Error checking status of the server",
                                      hostname_, port_);
    }

    return (pfd.revents & events) != 0;
  } // server_disconnected()

private:

  std::string hostname_;
  std::string port_;
  int sockfd_;

  // The mod9_asr::IO class wraps the socket and allows getline with timeout on it.
  mod9_asr::IO mod9_io_;

  // The very first read is required to be valid json with no error.
  // This boolean is true before the first read, and false after.
  bool first_read_;

  friend class mod9_asr::Connection::Connection;
};  // struct mod9_asr::Connection::impl

// Default constructor. Does not connect to anything; use open() to
// establish a connection afterwards.
mod9_asr::Connection::Connection()  = default;

// Construct and immediately connect by delegating to open(). Any
// exception thrown by open() propagates to the caller.
mod9_asr::Connection::Connection(const std::string& hostname, const std::string& port) {
  open(hostname, port);
}  // mod9_asr::Connection::Connection()

// Note: RAII handles closing the socket and freeing impl_.
mod9_asr::Connection::~Connection()  = default;

void mod9_asr::Connection::abort() {
  std::lock_guard<std::mutex> lock(abort_mutex_);

  if (!is_open_) {
    throw mod9_asr::AbortError("Called abort() on a Connection that is not open",
                               "unknown", "unknown");
  }

  try {
    write_data("END-OF-FILE");
  } catch (const mod9_asr::WriteError &e) {
    throw mod9_asr::AbortError(
        std::string("Error writing end of file marker to Engine while calling abort(): ") + e.what(),
        impl_->hostname_, impl_->port_);
  }

  errno = 0;
  if (impl_->abort() != 0) {
    int err = errno; // errno can be a macro, so copy it.
    throw mod9_asr::AbortError(std::string("Error while calling abort(): ") + ::strerror(err),
                               impl_->hostname_, impl_->port_);
  }
  is_aborted_ = true;
} // mod9_asr::Connection::abort()

void mod9_asr::Connection::open(const std::string& hostname, const std::string& port) {
  // Connect this Connection to hostname:port.
  //
  // Throws mod9_asr::ConnectError if the Connection is already open.
  // Exceptions thrown while constructing the implementation object
  // propagate, and in that case is_open_ is left unset.
  if (is_open_) {
    throw mod9_asr::ConnectError("Attempted to open an already open connection",
                                 hostname, port);
  }

  // The impl constructor does the actual address lookup and connect.
  impl_.reset(new mod9_asr::Connection::impl(hostname, port));
  is_open_ = true;
} // mod9_asr::Connection::open()

// Utility routines for read methods.

// Precondition check to run before every read. Throws
// mod9_asr::ReadError if the Connection is not open, or
// mod9_asr::ReadAbort if the Connection is open but has been aborted.
void mod9_asr::Connection::preread() {
  if (is_open_) {
    if (is_aborted_) {
      throw mod9_asr::ReadAbort(impl_->hostname_, impl_->port_);
    }
    return;
  }

  // Not open: there is no implementation object to take the
  // hostname/port from.
  throw mod9_asr::ReadError("Attempted to read from a Connection that is not open",
                            "unknown", "unknown");
} //  mod9_asr::Connection::preread()

// Run this when a read hit EOF. Throws mod9_asr::ReadAbort if the
// Connection has been aborted; otherwise throws mod9_asr::ReadError if
// the EOF arrived before any reply had been read. Returns normally in
// every other case.

void mod9_asr::Connection::after_read_eof() {
  if (!is_aborted_) {
    if (impl_->first_read_) {
      throw mod9_asr::ReadError("Engine failed to return expected first reply",
                                impl_->hostname_, impl_->port_);
    }
    return;
  }

  throw mod9_asr::ReadAbort(impl_->hostname_, impl_->port_);
} // mod9_asr::Connection::after_read_eof()

// This should be called after a successful read (that is, not EOF).

void mod9_asr::Connection::postread(std::string* reply) {
  // Something is seriously wrong if the Engine returns a string that
  // can't be parsed as json. So we attempt to parse here, and throw
  // an exception if it fails. This is slightly wasteful, since the
  // caller will almost certainly immediately parse the string, but
  // allows the caller to use any json library.
  json replyjson;
  try {
    replyjson = json::parse(*reply);
  } catch (const json::parse_error& e) {
    // Should this return more details? It's pretty exceptional.
    throw mod9_asr::ReadError("Engine returned JSON that failed to parse",
                              impl_->hostname_, impl_->port_);
  }
  if (!replyjson.is_object()) {
    throw mod9_asr::ReadError("Engine returned a non-object JSON",
                              impl_->hostname_, impl_->port_);
  }
  impl_->first_read_ = false;
} // mod9_asr::Connection::postread()

// User level read() methods.

// Read one line into *reply, with no timeout. Returns true if a line
// was read and passed validation, false on EOF. Exceptions from
// preread(), after_read_eof() and postread() propagate.
bool mod9_asr::Connection::read(std::string* reply) {
  preread();

  // This blocks until a newline is seen or the socket closes.
  if (impl_->mod9_io_.getline(*reply)) {
    postread(reply);
    return true;
  }

  // EOF. This throws if the EOF is not an acceptable outcome.
  after_read_eof();
  return false;
}  // mod9_asr::Connection::read(std::string*)

// Read with timeout, throws exception.
bool mod9_asr::Connection::read(std::string* reply, float timeout) {
  if (timeout < 0) {
    throw mod9_asr::ReadError("Timeout cannot be negative.", impl_->hostname_, impl_->port_);
  }

  preread();

  try {
    if (!impl_->mod9_io_.getline(*reply, static_cast<int>(timeout*1000))) {
      after_read_eof();
      return false;
    }
  } catch (const mod9_asr::IO::TimeoutException& e) {
    throw mod9_asr::ReadTimeout("Call to read() timed out.", impl_->hostname_, impl_->port_);
  }
  postread(reply);
  return true;
}  // mod9_asr::Connection::read(std::string*, float)

// Read one line into *reply, waiting at most `timeout` seconds. Rather
// than throwing on expiry, the outcome is reported through *timedout,
// which is cleared on entry and then passed to getline() by reference.
// Returns true if getline() produced a line that passed validation,
// false on EOF.
bool mod9_asr::Connection::read(std::string* reply, float timeout, bool* timedout) {
  *timedout = false;

  preread();

  // The timeout is handed to getline() in milliseconds.
  const int timeout_ms = static_cast<int>(timeout*1000);
  if (impl_->mod9_io_.getline(*reply, timeout_ms, *timedout)) {
    postread(reply);
    return true;
  }

  after_read_eof();
  return false;
}  // mod9_asr::Connection::read(std::string*, float, bool*)

// Write data to the connection. Note that write_json() and the other
// overload of write_data() call this method.
//
// Writes exactly `count` bytes starting at `buf` to the socket.
//
// Throws mod9_asr::WriteError if the Connection is not open, if the
// server has disconnected, or if not all bytes could be written.
// Throws mod9_asr::WriteAbort if the Connection has been aborted,
// either before or during the write.
void mod9_asr::Connection::write_data(const void* buf, std::size_t count) {
  if (!is_open_) {
    throw mod9_asr::WriteError("Attempted to write to a Connection that is not open",
                               "unknown", "unknown");
  }
  if (is_aborted_) {
    throw mod9_asr::WriteAbort(impl_->hostname_, impl_->port_);
  }
  if (impl_->server_disconnected()) {
    throw mod9_asr::WriteError("The server disconnected during writing",
                               impl_->hostname_, impl_->port_);
  }

  // Do the actual write. write() may accept fewer bytes than requested
  // or be interrupted by a signal; neither is a failure, so keep going
  // until everything is sent or a real error occurs.

  const char* cursor = static_cast<const char*>(buf);
  std::size_t remaining = count;
  {
    // Held across the whole loop so that the bytes of one call are not
    // interleaved with those of another call.
    std::lock_guard<std::mutex> lock(write_mutex_);
    while (remaining > 0 && !is_aborted_) {
      const ssize_t nwritten = ::write(impl_->sockfd_, cursor, remaining);
      if (nwritten > 0) {
        cursor += nwritten;
        remaining -= static_cast<std::size_t>(nwritten);
      } else if (nwritten == -1 && errno == EINTR) {
        // Interrupted before anything was written. Retry.
        continue;
      } else {
        // Genuine error (or no progress). Reported below.
        break;
      }
    }
  }

  // An abort that raced with the write takes precedence over a plain
  // write failure.
  if (is_aborted_) {
    throw mod9_asr::WriteAbort(impl_->hostname_, impl_->port_);
  }
  if (remaining != 0) {
    throw mod9_asr::WriteError(std::string("Writing ") + std::to_string(count) + " bytes failed",
                               impl_->hostname_, impl_->port_);
  }
} // mod9_asr::Connection::write_data(const void* buf, std::size_t count)

void mod9_asr::Connection::write_data(const std::string& str) {
  // Convenience overload: send the string's bytes, without a trailing
  // NUL, through the buffer overload.
  const char* const bytes = str.data();
  const std::size_t nbytes = str.size();
  write_data(bytes, nbytes);
}  // mod9_asr::Connection::write_data(const std::string& str)

void mod9_asr::Connection::write_json(const std::string& req) {
  // Send `req` to the Engine as a single line of JSON.
  //
  // Throws mod9_asr::WriteError if `req` fails to parse as JSON.
  // Exceptions from write_data() propagate.

  // Convert to json and then back to a string. This ensures that the
  // string will be one line of valid json. Also, report if passed a
  // json string that fails to parse.
  json reqjson;
  try {
    reqjson = json::parse(req);
  } catch (const json::parse_error &e) {
    // Parsing happens before write_data() checks that the Connection is
    // open, so impl_ must not be dereferenced unless it is.
    const std::string host = is_open_ ? impl_->hostname_ : std::string("unknown");
    const std::string port = is_open_ ? impl_->port_ : std::string("unknown");
    throw mod9_asr::WriteError(clean_nlohmann_parse_exception_message(e.what()),
                               host, port);
  }

  // Add a newline
  const std::string buf = reqjson.dump() + "\n";

  write_data(buf.c_str(), buf.length());
}  // mod9_asr::Connection::write_json()

// Local utility function to rewrite the parse exception from nlohmann
// json. The message nlohmann json returns starts something like:
//
// [json.exception.parse_error.101] parse error at ...
//
// I don't like the first part, as it's too specific to nlohmann
// json. Remove everything up to and including the first space and
// uppercase the first letter so the message will start with:
//
// Parse error at ...
//
// A message without any space keeps all of its text and only has its
// first character uppercased. An empty result is returned as is.

std::string clean_nlohmann_parse_exception_message(std::string msg) {
  const std::string::size_type first_space = msg.find(' ');
  if (first_space != std::string::npos) {
    msg.erase(0, first_space + 1);
  }
  if (!msg.empty()) {
    // std::toupper is only defined for values representable as
    // unsigned char (or EOF), so convert before calling it.
    msg[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(msg[0])));
  }
  return msg;
}
