
--- RUDP segment encoding, decoding and pretty-printing.
-- Segments are plain tables. `segments.new` builds one from arguments and
-- `segments.parse` decodes one from a byte string; on success both return
-- the table with the helper methods installed by `segments.make_seg`.
local segments = {}

-- NOTE(review): presumably loaded to provide Lua 5.2 APIs (e.g.
-- `table.unpack`) on older interpreters — confirm.
require("compat52")

-- Used to encode a UID segment's UUID as a serialized Java object.
local javaserialize = require("javaserialize")

-- Short aliases for the bit32 operations used to pack/unpack header fields.
local FLAG = bit32.btest
local AND  = bit32.band
local OR   = bit32.bor
local SHL  = bit32.lshift
local SHR  = bit32.rshift

-- RUDP_VERSION: protocol version stored in the high nibble of a SYN
--   segment's fifth byte; SYN segments with another version are rejected
--   when parsed.
-- RUDP_HEADER_LEN: size of the fixed header every segment starts with
--   (flags, hlen, seqn, ackn, two-byte checksum); exported for callers.
local RUDP_VERSION, RUDP_HEADER_LEN = 1, 6
segments.RUDP_HEADER_LEN = RUDP_HEADER_LEN

-- Flag bits (first header byte) per segment type, listed by bit value.
-- Two pairs share a bit and are distinguished by total length when a
-- segment is parsed: NUL/UID (UID carries a payload) and ACK/DAT (DAT
-- carries data).
local seg_flag = {
   SYN = 0x80,
   ACK = 0x40,
   DAT = 0x40, -- shares the ACK bit
   EAK = 0x20,
   RST = 0x10,
   NUL = 0x08,
   UID = 0x08, -- shares the NUL bit
   FIN = 0x02,
}

-- Fixed total length (bytes) per segment type. It is used as the default
-- hlen of new segments and to validate parsed ones. EAK and DAT have no
-- entry because their length depends on the ack list / payload.
-- NOTE(review): UID is sized as header + 16 raw UUID bytes, but UID
-- segments are currently encoded with a (longer) serialized Java UUID
-- object — confirm which length the peer expects.
local header_len = {
   NUL = RUDP_HEADER_LEN,
   RST = RUDP_HEADER_LEN,
   FIN = RUDP_HEADER_LEN,
   ACK = RUDP_HEADER_LEN,
   SYN = RUDP_HEADER_LEN + 16,
   UID = RUDP_HEADER_LEN + 16,
}

--- Check that `var` is a number inside the closed range [min, max].
-- Non-numbers (e.g. a missing argument) yield false instead of raising a
-- comparison error, so the caller's assert message is what gets reported.
-- @param var value to check
-- @tparam number min inclusive lower bound
-- @tparam number max inclusive upper bound
-- @treturn boolean
local function check_value(var, min, max)
   return type(var) == "number" and var >= min and var <= max
end

-- True when `n` is a usable acknowledgement number: 0..254, since 0xFF is
-- reserved for "no ack". Non-numbers yield false so that the caller's
-- assert message is raised instead of a comparison error.
local function valid_ackn(n)
   return type(n) == "number" and n >= 0 and n < 255
end

-- Per-type initialisers called by segments.new with the constructor's
-- extra arguments. Each one validates its arguments (raising via assert)
-- and stores them on `seg`. NUL, RST and FIN have no initialiser and take
-- no extra arguments.
local init_seg = {
   -- SYN: connection parameters. Timeouts are given in seconds and are
   -- encoded as 16-bit millisecond values, hence the 65.535 upper bound.
   SYN = function(seg, max_ostand, max_seg_size, retx_to, cack_to,
            nilseg_to, max_retx, max_cack, max_outseq, max_autorst)
            assert(check_value(max_ostand, 1, 255), "invalid max outstanding segments")
            assert(check_value(max_seg_size, 22, 7990), "invalid max segment size")
            assert(check_value(retx_to, 0.1, 65.535), "invalid retransmission timeout")
            assert(check_value(cack_to, 0.1, 65.535), "invalid cumulative ack timeout")
            assert(check_value(nilseg_to, 0, 65.535), "invalid null segment timeout")
            assert(check_value(max_retx, 0, 255), "invalid max retransmissions")
            assert(check_value(max_cack, 0, 255), "invalid max cumulative acks")
            assert(check_value(max_outseq, 0, 255), "invalid max out of sequence")
            assert(check_value(max_autorst, 0, 255), "invalid max auto reset")
            seg.version = RUDP_VERSION
            seg.max_ostand = max_ostand
            seg.opt_flags = 0x01
            seg.max_seg_size = max_seg_size
            seg.retx_to = retx_to
            seg.cack_to = cack_to
            seg.nilseg_to = nilseg_to
            seg.max_retx = max_retx
            seg.max_cack = max_cack
            seg.max_outseq = max_outseq
            seg.max_autorst = max_autorst
         end,
-- NUL = none,
   -- UID: a UUID given as a 16-byte string.
   UID = function(seg, uuid)
            assert(uuid, "missing UUID")
            -- the encoder slices the UUID with string methods, so it must be a string
            assert(type(uuid) == "string" and #uuid == 16, "invalid UUID")
            seg.uuid = uuid
         end,
   -- EAK: ack number plus a list of additional acks (copied); the list
   -- length is reflected in hlen.
   EAK = function(seg, ackn, acks)
            assert(valid_ackn(ackn), "invalid ack number")
            assert(type(acks) == "table", "missing ack list")
            -- hlen is written as a single header byte, so it must fit in 0..255
            assert(RUDP_HEADER_LEN + #acks <= 255, "too many acks")
            seg.hlen = RUDP_HEADER_LEN + #acks
            seg.ackn = ackn
            seg.acks = {}
            for i = 1, #acks do
               seg.acks[i] = acks[i]
            end
         end,
-- RST = none,
-- FIN = none,
   -- ACK: the acknowledged sequence number.
   ACK = function(seg, ackn)
            assert(valid_ackn(ackn), "invalid ack number")
            seg.ackn = ackn
         end,
   -- DAT: ack number plus the payload string (sent after the 6-byte header).
   DAT = function(seg, ackn, data)
            assert(valid_ackn(ackn), "invalid ack number")
            assert(type(data) == "string", "missing data")
            seg.ackn = ackn
            seg.hlen = RUDP_HEADER_LEN
            seg.data = data
         end,
}

--- Build a new outgoing segment.
-- @tparam string segtype segment type: a key of seg_flag ("SYN", "NUL",
--   "UID", "EAK", "RST", "FIN", "ACK" or "DAT")
-- @tparam number seqn sequence number
-- @param ... type-specific arguments forwarded to the matching init_seg
--   entry (none for NUL, RST and FIN)
-- @treturn table the segment, with helper methods attached
-- @raise "invalid segment type" for an unknown type, or an init_seg
--   validation error
function segments.new(segtype, seqn, ...)
   local flags = seg_flag[segtype]
   assert(flags, "invalid segment type")
   local seg = {
      type = segtype,
      seqn = seqn,
      flags = flags,
      hlen = header_len[segtype],
      ackn = 0xFF, -- 0xFF means "no ack"
      nretx = 0,
   }
   local init = init_seg[segtype]
   if init ~= nil then
      init(seg, ...)
   end
   return segments.make_seg(seg)
end

-- Class descriptor passed to javaserialize when a UID segment's UUID is
-- encoded as a serialized `java.util.UUID` object. The field order matters:
-- it determines the byte layout of the serialized object.
local java_util_UUID = {
   __name = "java.util.UUID",
   -- serialVersionUID as 8 big-endian bytes: BC 99 03 F7 98 6D 85 2F
   __serialVersionUID = "\188\153\003\247\152\109\133\047",
   { "long", "leastSigBits" },
   { "long", "mostSigBits" },
}

-- Convert a duration in seconds to whole milliseconds. Rounding explicitly
-- (rather than handing a float to bit32, whose conversion of non-integers
-- is unspecified) keeps e.g. 1.005 s (1004.999... ms as a double) from
-- being encoded as 1004 ms.
local function to_ms(seconds)
   return math.floor(seconds * 1000 + 0.5)
end

-- Per-type encoders used by segments.tobytes. Each receives `buf`, an array
-- of byte values whose first four entries (flags, hlen, seqn, ackn) are
-- already set, and appends the type-specific bytes, including the two-byte
-- checksum field, which is always written as zero.
local seg_to_buf = {
   SYN = function(buf, seg)
            local retx_ms   = to_ms(seg.retx_to)
            local cack_ms   = to_ms(seg.cack_to)
            local nilseg_ms = to_ms(seg.nilseg_to)
            buf[5] = SHL(seg.version, 4) -- version in the high nibble
            buf[6] = seg.max_ostand
            buf[7] = seg.opt_flags
            buf[8] = 0 -- spare
            -- 16-bit fields, big-endian
            buf[9],  buf[10] = SHR(seg.max_seg_size, 8), AND(seg.max_seg_size, 0xFF)
            buf[11], buf[12] = SHR(retx_ms, 8),          AND(retx_ms, 0xFF)
            buf[13], buf[14] = SHR(cack_ms, 8),          AND(cack_ms, 0xFF)
            buf[15], buf[16] = SHR(nilseg_ms, 8),        AND(nilseg_ms, 0xFF)
            buf[17] = seg.max_retx
            buf[18] = seg.max_cack
            buf[19] = seg.max_outseq
            buf[20] = seg.max_autorst
            buf[21], buf[22] = 0, 0 -- checksum
         end,
-- NUL = none,
   UID = function(buf, seg)
            buf[5], buf[6] = 0, 0 -- checksum

            -- this serializes the UUID as a Java object like the Java implementation expects
            local msb = seg.uuid:sub(1,8)
            local lsb = seg.uuid:sub(9,16)
            local bytes = javaserialize.serialize({mostSigBits = msb, leastSigBits = lsb}, java_util_UUID)

            -- but in reality the Java implementation should be revised and it should be like this:
            -- local bytes = { string.byte(seg.uuid, 1, 16) }

            for i = 1, #bytes do
               buf[6+i] = string.byte(bytes:sub(i,i))
            end
         end,
   EAK = function(buf, seg)
            -- ack list goes right after the 4 leading header bytes
            for i = 1, #seg.acks do
               buf[4+i] = seg.acks[i]
            end
            buf[#buf+1] = 0
            buf[#buf+1] = 0 -- checksum
         end,
-- RST = none,
-- FIN = none,
-- ACK = none,
   DAT = function(buf, seg)
            buf[5], buf[6] = 0, 0 -- checksum
            local len = #seg.data
            local bytes = { string.byte(seg.data, 1, len) }
            for i = 1, len do
               buf[6+i] = bytes[i]
            end
         end,
}

--- Encode a segment into its wire representation.
-- @tparam table seg segment as built by segments.new or segments.parse
-- @treturn string the raw segment bytes
function segments.tobytes(seg)
   local buf = { seg.flags, seg.hlen, seg.seqn, seg.ackn }
   local encode = seg_to_buf[seg.type]
   if encode == nil then
      -- header-only segment: only the (zeroed) checksum follows
      buf[5] = 0
      buf[6] = 0
   else
      encode(buf, seg)
   end
   return string.char(table.unpack(buf))
end

-- Minimum length of a UID segment: the 6-byte header followed by the
-- serialized java.util.UUID object, whose two 64-bit fields end at byte 86.
local UID_SEGMENT_LEN = 86

-- Per-type decoders used by segments.parse. Each receives the partially
-- filled segment and `buf`, the segment bytes as an array of numbers. Each
-- returns the completed segment, or nil plus an error message.
local buf_to_seg = {
   SYN = function(seg, buf)
            seg.version = SHR(buf[5], 4) -- high nibble
            if seg.version ~= RUDP_VERSION then return nil, "invalid RUDP version" end
            seg.max_ostand = buf[6]
            seg.opt_flags = buf[7]
            -- spare = buf[8]
            seg.max_seg_size = OR(SHL(buf[9],  8), buf[10])
            seg.retx_to =      OR(SHL(buf[11], 8), buf[12]) / 1000
            seg.cack_to =      OR(SHL(buf[13], 8), buf[14]) / 1000
            seg.nilseg_to =    OR(SHL(buf[15], 8), buf[16]) / 1000
            seg.max_retx =     buf[17]
            seg.max_cack =     buf[18]
            seg.max_outseq =   buf[19]
            seg.max_autorst =  buf[20]
            seg.checksum =     OR(SHL(buf[21], 8), buf[22])
            return seg
         end,
-- NUL = none,
   UID = function(seg, buf)
            -- The fixed offsets below need the whole serialized object;
            -- without this check a short segment makes string.char fail
            -- on missing bytes.
            if #buf < UID_SEGMENT_LEN then
               return nil, "invalid UID segment length"
            end
            seg.checksum = OR(SHL(buf[5], 8), buf[6])

            -- FIXME this extracts by forceps the UUID from the serialized Java class:
            -- bytes 79-86 hold mostSigBits and 71-78 leastSigBits; the uuid
            -- string is stored most-significant half first.
            seg.uuid = string.char(table.unpack(buf, 79, 86)) .. string.char(table.unpack(buf, 71, 78))

            -- this would be a little nicer:
            -- javaserialize.deserialize(string.char(table.unpack(buf, 7, #buf)))

            -- but in reality the Java implementation should be revised and it should be like this:
            -- seg.uuid = string.char(table.unpack(buf, 7, 22))
            return seg
         end,
   EAK = function(seg, buf)
            -- acks fill the space between the 4 leading bytes and the
            -- trailing two-byte checksum
            seg.acks = {}
            for i = 1, seg.hlen - RUDP_HEADER_LEN do
               seg.acks[i] = buf[4+i]
            end
            seg.checksum = OR(SHL(buf[seg.hlen-1], 8), buf[seg.hlen])
            return seg
         end,
-- RST = none,
-- FIN = none,
-- ACK = none,
   DAT = function(seg, buf)
            seg.checksum = OR(SHL(buf[5], 8), buf[6])
            seg.data = string.char(table.unpack(buf, 7, #buf))
            return seg
         end,
}

--- Decode a segment from its wire representation.
-- The type is derived from the flag byte. Bits are tested in a fixed
-- precedence order (SYN, NUL, EAK, RST, FIN, and ACK last). NUL/UID and
-- ACK/DAT share a bit, so each pair is told apart by whether anything
-- follows the 6-byte header.
-- Except for DAT, the total length must equal the reported hlen and, for
-- types listed in header_len, that fixed size as well.
-- NOTE(review): segments.tobytes writes hlen = header_len.UID (22) for UID
-- segments while appending a longer serialized UUID, so such segments
-- appear to be rejected by the length checks here — confirm against the peer.
-- @tparam string data raw segment bytes
-- @treturn[1] table decoded segment with helper methods attached
-- @treturn[2] nil
-- @treturn[2] string error message
function segments.parse(data)
   local total = #data
   if total < RUDP_HEADER_LEN then
      return nil, "invalid segment"
   end
   local bytes = { data:byte(1, total) }
   local seg = {
      nretx = 0,
      flags = bytes[1],
      hlen = bytes[2],
      seqn = bytes[3],
      ackn = bytes[4],
   }

   local bare = (total == RUDP_HEADER_LEN)
   local kind
   if FLAG(seg.flags, seg_flag.SYN) then
      kind = "SYN"
   elseif FLAG(seg.flags, seg_flag.NUL) then
      kind = bare and "NUL" or "UID"
   elseif FLAG(seg.flags, seg_flag.EAK) then
      kind = "EAK"
   elseif FLAG(seg.flags, seg_flag.RST) then
      kind = "RST"
   elseif FLAG(seg.flags, seg_flag.FIN) then
      kind = "FIN"
   elseif FLAG(seg.flags, seg_flag.ACK) then
      -- ACK is tested last because other segment types may carry it too
      kind = bare and "ACK" or "DAT"
   else
      return nil, "invalid segment type"
   end

   if kind ~= "DAT" then
      if total ~= seg.hlen then
         return nil, "invalid reported length in segment: reported "..seg.hlen..", got "..total
      end
      local expected = header_len[kind]
      if expected and expected ~= total then
         return nil, "invalid segment length for "..kind..": expected "..expected..", got "..total
      end
   end
   seg.type = kind

   local decode = buf_to_seg[kind]
   if decode == nil then
      -- header-only segment: just the checksum after the 4 leading bytes
      seg.checksum = OR(SHL(bytes[5], 8), bytes[6])
      return segments.make_seg(seg)
   end
   return segments.make_seg(decode(seg, bytes))
end

-- Per-type formatters used by segments.tostring. Each one appends
-- human-readable "name:value" fields describing the type-specific part of
-- the segment to the `out` array.
local seg_to_string = {
   SYN = function(out, seg)
            out[#out+1] = "v:"..seg.version
            out[#out+1] = "max_ostand:"..seg.max_ostand
            out[#out+1] = "opt_flg:"..seg.opt_flags
            out[#out+1] = "max_seg_size:"..seg.max_seg_size
            out[#out+1] = "retx_to:"..seg.retx_to
            out[#out+1] = "cack_to:"..seg.cack_to
            out[#out+1] = "nilseg_to:"..seg.nilseg_to
            out[#out+1] = "max_retx:"..seg.max_retx
            out[#out+1] = "max_cack:"..seg.max_cack
            out[#out+1] = "max_outseq:"..seg.max_outseq
            out[#out+1] = "max_autorst:"..seg.max_autorst
         end,
-- NUL = none,
   UID = function(out, seg)
            local bytes = { string.byte(seg.uuid, 1, 16) }
            local outuuid = { "uuid:" }
            -- slot 1 holds the label, so the hex bytes start at slot 2
            for i = 1, 16 do
               outuuid[i + 1] = string.format("%02x", bytes[i])
            end
            out[#out+1] = table.concat(outuuid)
         end,
   EAK = function(out, seg)
            out[#out+1] = "acks:"
            for i = 1, #seg.acks do
               out[#out+1] = tostring(seg.acks[i])
            end
         end,
-- RST = none,
-- FIN = none,
-- ACK = none,
   DAT = function(out, seg)
            local len = #seg.data
            local bytes = { string.byte(seg.data, 1, len) }
            local outdata = { "data:" }
            -- slot 1 holds the label, so the hex bytes start at slot 2
            for i = 1, len do
               outdata[i + 1] = string.format("%02x", bytes[i])
            end
            out[#out+1] = table.concat(outdata)
         end,
}

--- Render a segment (or a raw byte string) as lowercase hex.
-- @param seg segment table (encoded through its tobytes method) or a byte string
-- @treturn string two hex digits per byte, without separators
function segments.ashex(seg)
   local raw = seg
   if type(raw) == "table" then
      raw = raw:tobytes()
   end
   assert(type(raw) == "string")
   local hex = raw:gsub(".", function(ch)
      return string.format("%02x", ch:byte())
   end)
   return hex
end

--- Decode a segment from its hex representation (the inverse of ashex).
-- Malformed input (odd length or non-hex characters) is rejected instead of
-- silently dropping or mis-decoding bytes.
-- @tparam string hex an even number of hex digits, either case
-- @treturn[1] table the parsed segment
-- @treturn[2] nil
-- @treturn[2] string error message (from this function or segments.parse)
function segments.fromhex(hex)
   if type(hex) ~= "string" or #hex % 2 ~= 0 or hex:find("[^%x]") then
      return nil, "invalid hex string"
   end
   local t = {}
   for i = 1, #hex, 2 do
      t[#t+1] = tonumber(hex:sub(i, i+1), 16)
   end
   return segments.parse(string.char(table.unpack(t)))
end

--- Human-readable dump of a segment: its hex bytes followed by a bracketed
-- list of header fields, type-specific fields and the retransmission count
-- (only when non-zero).
-- @tparam table seg segment with helper methods attached
-- @treturn string
function segments.tostring(seg)
   local parts = { seg:ashex(), " [", seg.type, "seq:", seg.seqn }
   local function add(text)
      parts[#parts + 1] = text
   end
   add("flag:" .. seg.flags)
   add("hlen:" .. seg.hlen)
   add("ackn:" .. seg.ackn)
   local describe = seg_to_string[seg.type]
   if describe ~= nil then
      describe(parts, seg)
   end
   if seg.nretx > 0 then
      add(" RETX:" .. seg.nretx)
   end
   add("]")
   return table.concat(parts, " ")
end

--- Piggyback an acknowledgement on a segment by setting the ACK flag bit
-- and recording the acknowledged sequence number (modifies `seg` in place).
-- @tparam table seg segment to modify
-- @tparam number ackn sequence number being acknowledged
function segments.set_ack(seg, ackn)
   local with_ack = OR(seg.flags, seg_flag.ACK)
   seg.flags, seg.ackn = with_ack, ackn
end

--- Attach the segment helper methods to a segment table.
-- Accepts a decoder's full result list so a `nil, err` failure passes
-- straight through: if the first value is nil/false, every argument is
-- returned unchanged.
-- @param ... a segment table, or nil followed by an error message
-- @return the segment with methods installed, or the original arguments
function segments.make_seg(...)
   local seg = ...
   if not seg then
      return ...
   end
   seg.set_ack  = segments.set_ack
   seg.tobytes  = segments.tobytes
   seg.ashex    = segments.ashex
   seg.tostring = segments.tostring
   return seg
end

return segments
