//
//              Copyright 2004 (C) by UCAR
//
// Description:
//
//

#include <atdISFF/FieldDataSocket.h>
#include <atdUtil/Logger.h>
#include <atdISFF/Time.h>

#include <unistd.h>
#include <fcntl.h>
#include <string.h>

#include <errno.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <resolv.h>
#include <arpa/inet.h>
#include <netdb.h>

#include <iostream>

using namespace atdISFF;
using namespace atdUtil;
using namespace std;

/**
 * Create an unconnected socket description for hostname:port.
 * Nothing is connected here; call open() to connect.
 *
 * @param hostname host name or dotted address resolved by open().
 * @param port     TCP port to connect to.
 */
FieldDataSocket::FieldDataSocket(const string& hostname,int port): 
  _fd(-1), _hostname(hostname),_port(port),_inbuf(0),_outbuf(1),
  _statisticsPeriodMsec(300000),_statisticsTime(0),
  _currStatsIndex(0),_prevStatsIndex(1),
  _socketWriteErrors(0),_socketTempUnavailable(0),
  // incremented by write() on EINTR; previously never initialized
  _socketInterrupted(0),
  _maxMilliSeconds(0),_minMilliSeconds(0),
  connected(false)
{
  initBuffers();
  // Give the statistics counters, _lastWrite and _startupTime defined
  // values so the statistics getters are safe before open(); open()
  // zeroes them again when it connects.
  zeroStatistics();
}

/**
 * Destructor: closes the descriptor (if open) and frees the output
 * buffers via close().
 */
FieldDataSocket::~FieldDataSocket()
{
  this->close();
}

/**
 * Connect to _hostname:_port, replacing any existing connection.
 *
 * Statistics are zeroed, the host name is resolved with gethostbyname()
 * and each returned address is tried in turn until connect() succeeds.
 * On success the peer's dotted address is stored in _hostStringAddr,
 * fresh output buffers are allocated and connected is set to true.
 *
 * @throws UnknownHostException if gethostbyname() cannot resolve the host.
 * @throws IOException if a socket cannot be created or no address accepts
 *   the connection (in both cases _fd is -1 and connected is false), or if
 *   createBuffers() fails.
 */
void FieldDataSocket::open() throw (UnknownHostException,IOException) {

  // Drop any previous connection first, so that a failure below cannot
  // leave "connected" true alongside a closed or unconnected descriptor.
  connected = false;
  if (_fd >= 0) ::close(_fd);
  _fd = -1;

  zeroStatistics();
  // At first report partial statistics
  _prevStatsIndex = _currStatsIndex;

  memset((char *) &_hostAddr, 0,sizeof(_hostAddr));
  _hostAddr.sin_family = AF_INET;
  _hostAddr.sin_port = htons(_port);

  struct hostent *svr = gethostbyname(_hostname.c_str());

  if (!svr) throw UnknownHostException(_hostname);

  /*
   * Open an ARPA Internet stream socket for each attempt. POSIX leaves the
   * state of a socket unspecified after a failed connect(), so the failed
   * socket is closed and a fresh one is used for the next address.
   */
  int connectErrno = EHOSTUNREACH;  // only reported if no address was listed
  for (int i = 0; svr->h_addr_list[i] != 0; i++) {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0)
      throw IOException(getName(),"open",errno);

    memcpy(&_hostAddr.sin_addr.s_addr,
	  svr->h_addr_list[i],sizeof(_hostAddr.sin_addr.s_addr));
    if (::connect(fd, (struct sockaddr *) &_hostAddr, sizeof(_hostAddr)) == 0) {
      _fd = fd;
      break;
    }
    connectErrno = errno;   // save it before ::close() can overwrite errno
    ::close(fd);
  }

  if (_fd < 0)
    throw IOException(getName(), "connect",connectErrno);

  inet_ntop(AF_INET,&_hostAddr.sin_addr, _hostStringAddr,sizeof _hostStringAddr);
  createBuffers();
  connected = true;
}

/**
 * Mark the socket disconnected, close the descriptor if one is open and
 * free the output buffers. Calling it again when already closed is
 * harmless: a negative descriptor is never passed to ::close().
 */
void FieldDataSocket::close() throw(IOException) {
  connected = false;
  const int oldFd = _fd;
  _fd = -1;
  if (oldFd >= 0) ::close(oldFd);
  deleteBuffers();
}

/**
 * Human-readable description used in log and exception messages,
 * of the form "socket to <hostname> port <port>".
 */
std::string FieldDataSocket::getName() {
  char portText[12];    // enough for any int in decimal, plus the NUL
  sprintf(portText,"%d",_port);
  std::string name("socket to ");
  name += _hostname;
  name += " port ";
  name += portText;
  return name;
}

/**
 * Reset both I/O buffer slots to "no buffer": null data and cursor
 * pointers, zero byte counts. Frees nothing; see deleteBuffers().
 */
void FieldDataSocket::initBuffers() {
  for (int slot = 0; slot < 2; slot++) {
    _bufs[slot] = 0;
    _bufPtrs[slot] = 0;
    _bufN[slot] = 0;
  }
}

/**
 * Free both I/O buffers (delete[] of a null pointer is a no-op) and
 * reset the slots via initBuffers().
 */
void FieldDataSocket::deleteBuffers() {
  for (int slot = 0; slot < 2; slot++)
    delete [] _bufs[slot];
  initBuffers();
}

void FieldDataSocket::createBuffers() throw (IOException) {

  // setting buffer sizes on sockets seems to be a bad idea.
  // They are already pretty big, and setting them to
  // something like 16384 or 32768 just seems to slow things down.
#ifdef NDAQ_SET_SOCKET_BUF_SIZE
  setSendBufferSize(16384);
#endif
  Logger::getInstance()->log(LOG_INFO,"%s: getSendBufferSize()=%d",
  	getName().c_str(),getSendBufferSize());
  _writelen = 16384;
  _bufsize = _writelen * 2;
  delete [] _bufs[0];
  delete [] _bufs[1];

  _bufs[0] = new char[_bufsize];
  _bufs[1] = new char[_bufsize];
  _bufPtrs[0] = _bufs[0];
  _bufPtrs[1] = _bufs[1];
  _bufN[0] = 0;
  _bufN[1] = 0;
}

/**
 * Put the open descriptor into O_NONBLOCK mode, so network jams don't
 * hang us up (write() then sees EAGAIN instead of blocking).
 * Does nothing if no descriptor is open.
 *
 * @throws IOException if fcntl() fails.
 */
void FieldDataSocket::setNonBlocking() throw(IOException) {
  if (_fd < 0) return;

  const int current = fcntl(_fd, F_GETFL, 0);
  if (current < 0)
    throw IOException(getName(),"fcntl(...,F_GETFL,...)",errno);

  if (fcntl(_fd, F_SETFL, current | O_NONBLOCK) < 0)
    throw IOException(getName(),"fcntl(...,F_SETFL,O_NONBLOCK)",errno);
}

/**
 * Set the socket's SO_SNDBUF size. For sizes below 512 bytes, also set
 * TCP_NODELAY (disables Nagle coalescing of small segments).
 *
 * @param size requested send buffer size in bytes.
 * @throws IOException if either setsockopt() call fails.
 */
void FieldDataSocket::setSendBufferSize(int size) throw (IOException) {
  Logger::getInstance()->log(LOG_INFO,"%s: doing setSendBufferSize()=%d",
	getName().c_str(),size);
  socklen_t sizeLen = sizeof size;
  if (setsockopt(_fd,SOL_SOCKET,SO_SNDBUF,(char *)&size,sizeLen) < 0)
    throw IOException(getName(),"setsockopt",errno);
  if (size < 512) {
      Logger::getInstance()->log(LOG_INFO,"%s: doing TCP_NODELAY",
	    getName().c_str());
      int opt = 1;
      socklen_t len = sizeof(opt);
      // IPPROTO_TCP is the POSIX option level for TCP_NODELAY; SOL_TCP is
      // a Linux-only alias with the same value.
      if (setsockopt(_fd,IPPROTO_TCP,TCP_NODELAY,(char *)&opt,len) < 0)
	throw IOException(getName(),"setsockopt",errno);
  }
}

/**
 * Query the socket's current SO_SNDBUF size.
 *
 * @return the value reported by getsockopt().
 * @throws IOException if getsockopt() fails.
 */
int FieldDataSocket::getSendBufferSize() throw (IOException) {
  int bufferBytes;
  socklen_t optLen = sizeof bufferBytes;
  const int rc = getsockopt(_fd, SOL_SOCKET, SO_SNDBUF,
      reinterpret_cast<char *>(&bufferBytes), &optLen);
  if (rc < 0) throw IOException(getName(),"getsockopt",errno);
  return bufferBytes;
}

/**
 * write() swaps buffers and sends the pending input once this many ms
 * have passed since the last socket write, even if less than _writelen
 * bytes are buffered (provided the output buffer is empty).
 */
void FieldDataSocket::setMaxMilliSecondsBetweenWrites(int msecs)
{
  _maxMilliSeconds = msecs;
}
/**
 * write() does not retry sending an unfinished output buffer until at
 * least this many ms have passed since the last socket write.
 */
void FieldDataSocket::setMinMilliSecondsBetweenWrites(int msecs)
{
  _minMilliSeconds = msecs;
}
  
/**
 * Buffer a record and/or write buffered data to the socket.
 *
 * The nbuf pieces buf[0..nbuf-1], with lengths len[0..nbuf-1], form one
 * record: either all of them are copied into the input buffer or all are
 * discarded, so the output stream never contains half a sample.
 *
 * Two buffers of _bufsize bytes are used. New records are appended to
 * _bufs[_inbuf]; _bufs[_outbuf] holds data being sent, and
 * _bufPtrs[_outbuf] marks how far it has been sent. On each call:
 *   1. if the output buffer still has unsent bytes and at least
 *      _minMilliSeconds have passed since the last write, send more of it;
 *   2. if the output buffer is then empty, the input buffer is non-empty,
 *      and either the new record would push the input past _writelen or
 *      _maxMilliSeconds have passed, swap the buffers and send up to
 *      _writelen bytes of the new output buffer;
 *   3. copy the record into the input buffer if it fits.
 * EAGAIN and EINTR from ::write() are counted and treated as a zero-byte
 * write; any other write error is counted and thrown.
 *
 * Return true: OK, record buffered
 *	false: not enough room left in the input buffer, record discarded
 *	       (e.g. because earlier output has not been sent yet)
 * @throws IOException on a ::write() error other than EAGAIN/EINTR.
 */
bool FieldDataSocket::write(const char * const*buf,int *len,int nbuf) throw (IOException) {

  int i,l;
  bool status = true;

  /* compute total size of write
   * We want to either write all these buffers or toss them all.
   * It is all or nothing so that the output stream doesn't
   * contain half samples.
   */
    
  int blen = 0;
  for (i = 0; i < nbuf; i++) blen += len[i];

  isff_sys_time_t tnow = getCurrentTimeInMillis();

  // Once the current statistics period has ended, rotate to the other slot.
  if (tnow > _statisticsTime) resetStatistics(tnow);
  // Milliseconds since the last ::write() attempt.
  int tdiff = (int)(tnow - _lastWrite);

  /* are we not yet finished writing output buffer? */
  // olen = bytes of the output buffer not yet sent (0 when fully drained).
  int olen;
  if ((olen = (_bufs[_outbuf] + _bufN[_outbuf] - _bufPtrs[_outbuf])) > 0) {
    // Step 1: retry the remainder, but not more often than _minMilliSeconds.
    if (tdiff >= _minMilliSeconds) {
      _partialWrites[_currStatsIndex]++;
#ifdef DEBUG
	cerr << " unfinished write, len=" << olen << endl;
#endif
      if ((l = ::write(_fd,_bufPtrs[_outbuf],olen)) < 0) {
        // EAGAIN (non-blocking socket is full) and EINTR are not fatal:
        // count them and treat this attempt as a zero-byte write.
        if (errno == EAGAIN) {
	  _socketTempUnavailable++;
	  l = 0;
	}
        else if (errno == EINTR) {
	  _socketInterrupted++;
	  l = 0;
	}
	else {
	  _socketWriteErrors++;
	  throw IOException(getName(),"write",errno);
	}
      }
      _socketBytes[_currStatsIndex] += l;
      _bufPtrs[_outbuf] += l;
      if (l < _minWriteLength[_currStatsIndex])
	  _minWriteLength[_currStatsIndex] = l;
      if (l > _maxWriteLength[_currStatsIndex])
	  _maxWriteLength[_currStatsIndex] = l;
      // olen is now the unsent remainder; 0 allows the buffer swap below.
      olen -= l;
      _lastWrite = tnow;
      // A write was just attempted: the _maxMilliSeconds test below now
      // passes only if _maxMilliSeconds <= 0.
      tdiff = 0;
    }
  }
  int wlen = _bufN[_inbuf];
  // if this write will exceed _writelen, then switch buffers and write
  // Step 2: also flush when _maxMilliSeconds have passed since the last write.
  if (wlen + blen > _writelen || tdiff >= _maxMilliSeconds) {
    if (wlen > 0 && olen == 0) {
      /* switch buffers if input filled and output empty */
      i = _inbuf; _inbuf = _outbuf; _outbuf = i;

      /* reset input buffer */
      _bufPtrs[_inbuf] = _bufs[_inbuf];
      _bufN[_inbuf] = 0;

      // The former input buffer now holds wlen bytes pending output;
      // start sending from its beginning.
      _bufPtrs[_outbuf] = _bufs[_outbuf];
      _bufN[_outbuf] = wlen;

      // Send at most _writelen bytes now; any remainder is sent by the
      // unfinished-write branch above on later calls.
      if (wlen > _writelen) wlen = _writelen;
#ifdef DEBUG
      cerr << "writing " << wlen << " bytes" << endl;
#endif
      if ((l = ::write(_fd,_bufPtrs[_outbuf],wlen)) < 0) {
        // Same non-fatal EAGAIN/EINTR handling as above.
        if (errno == EAGAIN) {
	  _socketTempUnavailable++;
	  l = 0;
	}
        else if (errno == EINTR) {
	  _socketInterrupted++;
	  l = 0;
	}
	else {
	  _socketWriteErrors++;
	  throw IOException(getName(),"write",errno);
	}
      }
      _socketBytes[_currStatsIndex] += l;
      if (l < _minWriteLength[_currStatsIndex])
	  _minWriteLength[_currStatsIndex] = l;
      if (l > _maxWriteLength[_currStatsIndex])
	  _maxWriteLength[_currStatsIndex] = l;
      _bufPtrs[_outbuf] += l;
      _lastWrite = tnow;
    }
  }

#ifdef DEBUG
  cerr << "_inbuf=" << _inbuf << " _outbuf=" << _outbuf <<
	" _bufsize=" << _bufsize << endl;
  cerr << "_bufN[_inbuf]=" << _bufN[_inbuf] <<
	" _bufPtrs[_inbuf]-_bufs[_inbuf]=" << _bufPtrs[_inbuf]-_bufs[_inbuf] << endl;
#endif

  /* put sample in input buffer if there is room for it */
  // Step 3: all-or-nothing copy of the record's pieces.
  if ((_bufsize - _bufN[_inbuf]) >= blen) {
    for (i = 0; i < nbuf; i++) {
      l = len[i];
      memcpy(_bufPtrs[_inbuf],buf[i],l);
      _bufPtrs[_inbuf] += l;
      _bufN[_inbuf] += l;
    }
    // _numWrites counts records accepted, not ::write() calls.
    _numWrites[_currStatsIndex]++;
  }
  /* if there isn't room we have to toss it */
  else {
    status = false;  /* lost it */
  }

#ifdef DEBUG
  cerr << "_bufN[_outbuf]=" << _bufN[_outbuf] <<
	" _bufPtrs[_outbuf]-_bufs[_outbuf]=" << _bufPtrs[_outbuf]-_bufs[_outbuf] << endl;
#endif

  return status;
}
void  FieldDataSocket::zeroStatistics() {
  time(&_startupTime);
  _lastWrite = getCurrentTimeInMillis();
  _statisticsTime = ((_lastWrite / _statisticsPeriodMsec) + 1) * _statisticsPeriodMsec;
  _maxWriteLength[0] = _maxWriteLength[1] = 0;
  _minWriteLength[0] = _minWriteLength[1] = 999999999;
  _socketBytes[0] = _socketBytes[1] = 0;
  _numWrites[0] = _numWrites[1] = 0;
  _partialWrites[0] = _partialWrites[1] = 0;
}
/**
 * Close the current statistics period at time tnow (ms).
 *
 * The slot just filled becomes the "previous" slot read by the rate
 * getters; the other slot is cleared and becomes current. The boundary
 * advances by one period, or jumps to the next multiple of the period
 * after tnow if more than one period has been skipped.
 */
void  FieldDataSocket::resetStatistics(isff_sys_time_t tnow) {
  _prevStatsIndex = _currStatsIndex;
  _currStatsIndex = (_prevStatsIndex + 1) % 2;

  _statisticsTime += _statisticsPeriodMsec;
  if (_statisticsTime < tnow) {
    const isff_sys_time_t periodsElapsed = tnow / _statisticsPeriodMsec;
    _statisticsTime = (periodsElapsed + 1) * _statisticsPeriodMsec;
  }

  const int slot = _currStatsIndex;
  _maxWriteLength[slot] = 0;
  _minWriteLength[slot] = 999999999;
  _numWrites[slot] = 0;
  _socketBytes[slot] = 0;
  _partialWrites[slot] = 0;
}

/**
 * Socket throughput in bytes per second.
 *
 * While the previous and current statistics indices are equal (as open()
 * sets them, until the first period rolls over), the bytes counted so far
 * are divided by the whole seconds elapsed since zeroStatistics().
 * Afterwards, the bytes of the last completed period are divided by the
 * period length.
 *
 * @return bytes per second, truncated to int; 0 if the period is not
 *   positive.
 */
int FieldDataSocket::getBytesPerSec() const {
  if (_prevStatsIndex == _currStatsIndex) { // partial statistics
    time_t tnow;
    time(&tnow);
    int nsec = (int)(tnow - _startupTime);
    // Avoid dividing by zero right after startup, and by a negative count
    // if the wall clock was stepped back since zeroStatistics().
    if (nsec <= 0) nsec = 1;
    return _socketBytes[_prevStatsIndex] / nsec;
  }
  // Scale by the period in milliseconds (as getWritesPerSec() does) rather
  // than by whole seconds, so periods that are not a multiple of 1000 ms
  // are handled exactly and periods under one second cannot cause an
  // integer division by zero.
  if (_statisticsPeriodMsec <= 0) return 0;
  return (int)(_socketBytes[_prevStatsIndex] * 1000.0 / _statisticsPeriodMsec);
}

/**
 * return number of FieldDataSocket::writes (records accepted into the
 * buffer, not number of socket writes) per second.
 *
 * Uses the same partial/complete-period split as getBytesPerSec().
 *
 * @return records per second; 0 if the statistics period is not positive.
 */
float FieldDataSocket::getWritesPerSec() const {
  if (_prevStatsIndex == _currStatsIndex) { // partial statistics
    time_t tnow;
    time(&tnow);
    int nsec = (int)(tnow - _startupTime);
    // Avoid dividing by zero right after startup, and by a negative count
    // if the wall clock was stepped back since zeroStatistics().
    if (nsec <= 0) nsec = 1;
    return (float)_numWrites[_prevStatsIndex] / nsec;
  }
  if (_statisticsPeriodMsec <= 0) return 0.0f;
  return (float)_numWrites[_prevStatsIndex] * 1000. / _statisticsPeriodMsec;
}
