
local segments = require("lac.mrudp.segments")

--- Invoke `fn(...)` and require that it raises an error.
-- The raised error value is printed so the test log shows why the call
-- was rejected. If `fn` returns normally, this function raises instead.
-- @param fn function expected to throw
-- @param ... arguments forwarded to `fn`
local function fail(fn, ...)
   local succeeded, raised = pcall(fn, ...)
   assert(succeeded == false, "got success when expecting failure")
   print(raised)
end

--- Invoke `fn(...)` and require that it fails softly.
-- "Softly" means it does not raise, but returns a falsy first value; the
-- second returned value (the error message, in the `nil, err` convention)
-- is printed. If `fn` raises, its error is re-raised. If it returns a
-- truthy first value, an assertion fails.
-- @param fn function expected to return `nil, err`
-- @param ... arguments forwarded to `fn`
local function getnil(fn, ...)
   local results = { pcall(fn, ...) }
   -- results[1]: pcall status; results[2]: first value (or raised error);
   -- results[3]: second value returned by fn
   assert(results[1], results[2])
   assert(not results[2], "got success when expecting failure")
   print(results[3])
end

--- Round-trip a segment through its byte encoding and verify the result.
-- The segment is printed and encoded (printing its hex form). The bytes are
-- then parsed back, and the parsed copy is printed and re-encoded the same
-- way. Raises if parsing fails, if hex conversion fails, or if the copy's
-- encoding differs from the original's.
-- @param seg segment value as produced by `segments.new`
local function test_serialization(seg)
   -- Print a segment and its hex dump; return its raw byte encoding.
   local function show(s)
      print(segments.tostring(s))
      local encoded = segments.tobytes(s)
      print(assert(segments.ashex(encoded)))
      return encoded
   end

   local original = show(seg)
   local copy, parse_err = segments.parse(original)
   assert(copy, parse_err)
   local reencoded = show(copy)
   assert(reencoded == original, "serialized representation of copy doesn't match")
end

--- Test cases for the segments module.
-- Each entry is a self-contained function that raises on failure. The first
-- group builds valid segments and round-trips them through serialization.
-- The second group checks that `segments.new` rejects bad arguments by
-- raising. The last group checks that `segments.parse` rejects malformed
-- byte strings by returning `nil, err`.
local tests = {
   function()
      print("creating SYN segment")
      local seg = segments.new("SYN", 1, 10, 128, 0.2, 0.2, 0.2, 60, 70, 80, 90)
      assert(seg)
      test_serialization(seg)
   end,
   function()
      print("creating NUL segment")
      local seg = segments.new("NUL", 2)
      assert(seg)
      test_serialization(seg)
   end,
   function()
      print("creating UID segment")
      local uuid = "1234567812345678"
      local seg = segments.new("UID", 3, uuid)
      assert(seg)
      assert(seg.uuid == uuid)
      test_serialization(seg)
   end,
   function()
      print("creating EAK segment")
      local acks = {1, 2, 3, 4, 5}
      local seg = segments.new("EAK", 4, 9, acks)
      assert(seg)
      -- Same length first, so extra trailing acks are caught too.
      assert(#seg.acks == #acks, "ack list length doesn't match")
      for i, a in ipairs(acks) do
         assert(seg.acks[i] == a)
      end
      test_serialization(seg)
   end,
   function()
      print("creating RST segment")
      local seg = segments.new("RST", 5)
      assert(seg)
      test_serialization(seg)
   end,
   function()
      print("creating FIN segment")
      local seg = segments.new("FIN", 6)
      assert(seg)
      test_serialization(seg)
   end,
   function()
      print("creating ACK segment")
      local ackn = 12
      local seg = segments.new("ACK", 7, ackn)
      assert(seg)
      assert(seg.ackn == ackn)
      test_serialization(seg)
   end,
   function()
      print("creating DAT segment")
      local data = "hello"
      local seg = segments.new("DAT", 8, 9, data)
      assert(seg)
      assert(seg.data == data)
      test_serialization(seg)
   end,
   function()
      print("attempting invalid segment")
      fail(segments.new, "INV")
   end,
   function()
      print("attempting invalid SYN segment")
      fail(segments.new, "SYN", 1, 10, 20, 0.2, 0.2, 0.2, 60, 70, 80, 90)
   end,
   function()
      print("attempting invalid SYN segment")
      fail(segments.new, "SYN", 1, 10, 20, 30, 200, 50, 60, 70, 80, 90)
   end,
   function()
      print("attempting invalid SYN segment")
      fail(segments.new, "SYN", 1, 10, 20, 30, 40, 50, 60, 70, 80, 90)
   end,
   function()
      print("attempting UID segment without UUID")
      fail(segments.new, "UID", 3)
   end,
   function()
      print("attempting UID segment with invalid UUID")
      fail(segments.new, "UID", 3, "hello")
   end,
   function()
      print("attempting EAK segment without ack list")
      fail(segments.new, "EAK", 4, 9)
   end,
   function()
      print("attempting invalid EAK without ackn")
      fail(segments.new, "EAK", 4)
   end,
   function()
      print("attempting invalid ACK without ackn")
      fail(segments.new, "ACK", 4)
   end,
   function()
      print("attempting invalid ACK with out-of-range ackn")
      fail(segments.new, "ACK", 4, 9000)
   end,
   function()
      print("attempting invalid DAT without data")
      fail(segments.new, "DAT", 4, 9)
   end,
   function()
      print("attempting to parse a short invalid segment")
      getnil(segments.parse, "oi")
   end,
   function()
      print("attempting to parse an invalid segment")
      getnil(segments.parse, string.char(0x04, 0x07, 0x07, 0x0c, 0x00, 0x00, 0x00))
   end,
   function()
      print("attempting to parse a segment of invalid length")
      getnil(segments.parse, string.char(0x40, 0x07, 0x07, 0x0c, 0x00, 0x00))
   end,
   function()
      print("attempting to parse a segment of invalid length")
      getnil(segments.parse, string.char(0x02, 0x07, 0x07, 0x0c, 0x00, 0x00, 0x00))
   end,
   function()
      print("attempting to read invalid version in SYN segment")
      -- Encode a valid SYN after overwriting its version field; parsing the
      -- resulting bytes must fail softly.
      local seg = segments.new("SYN", 1, 10, 128, 0.2, 0.2, 0.2, 60, 70, 80, 90)
      seg.version = 2
      local bytes = segments.tobytes(seg)
      getnil(segments.parse, bytes)
   end,
}

-- Run every test case in isolation. A failing case is reported and counted,
-- but does not stop the remaining cases from running.
local passed, failed = 0, 0
for i = 1, #tests do
   local ok, err = pcall(tests[i])
   if ok then
      print("OK")
      passed = passed + 1
   else
      print("fail:", err)
      failed = failed + 1
   end
   print("----------------------------------------")
end
print()

-- Returns the plural suffix for a count ("" for exactly one, "s" otherwise).
local function plural(n)
   return n == 1 and "" or "s"
end

print(string.format("Total: %d test%s passed. %d test%s failed.",
   passed, plural(passed), failed, plural(failed)))

-- Report failure through the process exit status so that shells, make and
-- CI can detect a failing run instead of always seeing success.
if failed > 0 then
   os.exit(1)
end

