
-- Sentinel value standing for Java `null`. It is compared by identity,
-- so callers must use this exact table (exported as javaserialize.null).
local null = {}

-- Module table for the Java Object Serialization stream writer.
local javaserialize = { null = null }

--- Least significant byte of `n`.
-- Lua's `%` takes the sign of the divisor, so the result is always in
-- 0..255, also for negative inputs (two's-complement style).
local function lsb(n)
   local low = n % 256
   return low
end

--- Drop the low byte of `n`: floor(n / 256).
-- Despite its name this is an arithmetic shift RIGHT by 8 bits; negative
-- inputs stay negative (e.g. -1 -> -1).
local function shl(n)
   local shifted = math.floor(n / 256)
   return shifted
end

--- Encode `n` as a single byte (negative values wrap modulo 256).
-- Raises an assertion error when n >= 256.
local function tobyte(n)
   assert (n < 256)
   local byte = n % 256
   return string.char(byte)
end

--- Encode `n` as a big-endian 16-bit value (Java short/char, u2 lengths).
-- Negative values are written in two's complement, so any integer in
-- [-32768, 65535] is accepted. Previously the high byte was passed to
-- string.char unmasked, so negative shorts raised an error.
-- @tparam number n integer in [-32768, 65535]
-- @treturn string 2-byte string
local function toshort(n)
   assert (n >= -32768 and n < 65536)
   return string.char(math.floor(n / 256) % 256, n % 256)
end

--- Encode `n` as a big-endian 32-bit value (Java int, u4 lengths).
-- Negative numbers come out in two's complement; bits above the lowest
-- 32 are dropped.
local function toint(n)
   local bytes = {}
   for pos = 4, 1, -1 do
      bytes[pos] = n % 256
      n = math.floor(n / 256)
   end
   return string.char(bytes[1], bytes[2], bytes[3], bytes[4])
end
javaserialize["toint"] = toint

-- Eight-byte big-endian encoding of 1: the serialVersionUID written for
-- classes that do not provide __serialVersionUID.
local L1 = ("\0"):rep(7) .. "\1"

--- Encode `n` as a big-endian 64-bit value (Java long).
-- An already-encoded 8-byte string is returned unchanged, which lets
-- callers give exact 64-bit constants such as serialVersionUIDs.
-- Negative numbers come out in two's complement.
local function tolong(n)
   if type(n) == "string" and #n == 8 then
      return n
   end
   local bytes = {}
   local rest = n
   for pos = 8, 1, -1 do
      bytes[pos] = rest % 256
      rest = math.floor(rest / 256)
   end
   return string.char(bytes[1], bytes[2], bytes[3], bytes[4],
                      bytes[5], bytes[6], bytes[7], bytes[8])
end

-- Encode one UTF-16 code unit as the 3-byte pattern
-- 1110xxxx 10xxxxxx 10xxxxxx (used here for surrogate halves).
local function utf16_unit_3bytes(u)
   return string.char(0xE0 + math.floor(u / 4096),
                      0x80 + math.floor(u / 64) % 64,
                      0x80 + u % 64)
end

--- Convert a UTF-8 string to Java "modified UTF-8".
-- Two differences from standard UTF-8 are applied:
--  * NUL is written as the two bytes C0 80;
--  * a 4-byte sequence (supplementary character) is written as its UTF-16
--    surrogate pair, each half as a 3-byte sequence (Java's readUTF
--    rejects 4-byte sequences).
-- Other bytes, including ones that are not valid UTF-8, pass through.
-- See http://docs.oracle.com/javase/6/docs/api/java/io/DataInput.html#modified-utf-8
local function modified_utf8(s)
   s = s:gsub("\0", "\192\128")
   s = s:gsub("[\240-\244][\128-\191][\128-\191][\128-\191]", function(seq)
      local b1, b2, b3, b4 = seq:byte(1, 4)
      local cp = (b1 % 8) * 262144 + (b2 % 64) * 4096 + (b3 % 64) * 64 + b4 % 64
      if cp < 0x10000 or cp > 0x10FFFF then
         return nil -- overlong/out of range: leave the bytes untouched
      end
      cp = cp - 0x10000
      return utf16_unit_3bytes(0xD800 + math.floor(cp / 1024))
          .. utf16_unit_3bytes(0xDC00 + cp % 1024)
   end)
   return s
end

--- Encode `s` as a Java `utf` item: u2 byte length + modified UTF-8 bytes.
-- The length is that of the ENCODED form; toshort asserts when it does
-- not fit in 16 bits (NUL and supplementary characters grow the string).
local function utf(s)
   local encoded = modified_utf8(s)
   return toshort(#encoded) .. encoded
end

--- Encode `s` as a Java `long-utf` item (used after TC_LONGSTRING).
-- The length prefix is a 64-bit long, as ObjectOutputStream writes it.
-- It was previously written as a 4-byte int, which produced an invalid
-- stream.
local function long_utf(s)
   local encoded = modified_utf8(s)
   return tolong(#encoded) .. encoded
end

-- Java primitive type names mapped to their one-letter field-descriptor
-- type codes.
local primitive = {
   boolean = "Z",
   byte    = "B",
   char    = "C",
   short   = "S",
   int     = "I",
   long    = "J",
   float   = "F",
   double  = "D",
}

-- Short aliases accepted in field declarations, mapped to fully
-- qualified class names.
local typical = {
   Vector = "java.util.Vector",
   String = "java.lang.String",
}

--- Map a Java type name to its field-descriptor form.
-- Accepted names: primitives ("int"), aliases from `typical` ("String"),
-- unqualified java.lang names ("Object"), qualified names
-- ("java.util.Vector", "a.B$Inner"), and any of these followed by one or
-- more "[]".
-- @tparam string typename
-- @return for primitives only the one-letter code (e.g. "I"); for
--   reference types the obj_typecode ("L" or "[") and the full
--   descriptor ("Ljava/lang/String;", "[I", "[[Ljava/lang/Object;").
local to_field_descriptor
to_field_descriptor = function(typename)
   local elemtype = typename:match("^(.*)%[%]$")
   if elemtype then
      local code, desc = to_field_descriptor(elemtype)
      -- Primitives return only their code, which is then their whole
      -- descriptor; reference types return the descriptor second.
      return "[", "[" .. (desc or code)
   end
   local prim = primitive[typename]
   if prim then
      return prim
   end
   typename = typical[typename] or typename
   -- Names without a literal '.' are resolved in java.lang. (A plain
   -- pattern is required here: in a Lua pattern "." matches any char.)
   if not typename:find(".", 1, true) then
      typename = "java.lang." .. typename
   end
   return "L", "L" .. typename:gsub("%.", "/") .. ";"
end

--- Java class name (as written in a class descriptor) for a type name.
-- Reference types give the dotted binary name ("java.lang.String");
-- array types give their descriptor ("[I"); primitives give the type name
-- itself ("int"), which previously came back as nil.
-- @tparam string typename
-- @treturn string exactly one value
local function to_class_descriptor(typename)
   local code, fielddesc = to_field_descriptor(typename)
   if code == "L" then
      -- Strip the leading "L" and trailing ";". The parentheses keep
      -- gsub's replacement count from leaking out as a second result.
      return (fielddesc:sub(2, -2):gsub("/", "."))
   end
   return fielddesc or typename
end

-- Predicates over the `__flags` table of an object's class (obj.__class).
-- A class without `__flags` counts as serializable and nothing else.
local function serializable(obj)
   local flags = obj.__class.__flags
   if not flags then
      return true
   end
   return flags.serializable
end
local function externalizable(obj)
   local flags = obj.__class.__flags
   return flags and flags.externalizable
end
local function write_method(obj)
   local flags = obj.__class.__flags
   return flags and flags.write_method
end
local function block_data(obj)
   local flags = obj.__class.__flags
   return flags and flags.block_data
end

--- Create a fresh writer for the Java Object Serialization stream format.
-- The writer is a table of methods named after the productions of the
-- stream grammar, plus the state they share:
--   out       -- list of byte strings; table.concat(w.out) is the stream
--   handles   -- value -> wire handle, used for TC_REFERENCE back-references
--   fieldlist -- fields whose values are still to be written for the
--                object currently being serialized (scopes nest)
-- Typical use: `local w = javaserialize.writer(); w:stream(x)`.
-- @treturn table a new, independent writer
function javaserialize.writer()

   local STREAM_MAGIC = string.char(0xac, 0xed)
   local STREAM_VERSION = string.char(0x00, 0x05)
   local TC_NULL = string.char(0x70)
   local TC_REFERENCE = string.char(0x71)
   local TC_CLASSDESC = string.char(0x72)
   local TC_OBJECT = string.char(0x73)
   local TC_STRING = string.char(0x74)
   local TC_ARRAY = string.char(0x75)
   local TC_CLASS = string.char(0x76)
   local TC_BLOCKDATA = string.char(0x77)
   local TC_ENDBLOCKDATA = string.char(0x78)
   local TC_RESET = string.char(0x79)
   local TC_BLOCKDATALONG = string.char(0x7A)
   local TC_EXCEPTION = string.char(0x7B)
   local TC_LONGSTRING = string.char(0x7C)
   local TC_PROXYCLASSDESC = string.char(0x7D)
   local TC_ENUM = string.char(0x7E)

   local SC_WRITE_METHOD = 0x01
   local SC_BLOCK_DATA = 0x08
   local SC_SERIALIZABLE = 0x02
   local SC_EXTERNALIZABLE = 0x04
   local SC_ENUM = 0x10

   -- Normalize a superclass reference. Accept a class table (it has
   -- `__name`) or an object-like wrapper `{ __class = klass }`.
   local function as_class(x)
      if type(x) == "table" and x.__name == nil and x.__class ~= nil then
         return x.__class
      end
      return x
   end

   -- Fields of `klass` and of each superclass reachable through `__super`,
   -- highest superclass first, matching the order in which classdata
   -- values are written for a class hierarchy.
   local function class_fields(klass)
      local chain, seen = {}, {}
      local cd = klass
      while type(cd) == "table" and cd ~= javaserialize.null and not seen[cd] do
         seen[cd] = true
         table.insert(chain, 1, cd)
         cd = as_class(cd.__super)
      end
      local list = {}
      for _, c in ipairs(chain) do
         for _, field in ipairs(c) do
            list[#list + 1] = field
         end
      end
      return list
   end

   local writer = {

      -- fieldlist: the list of fields whose values remain to be written.
      -- newObject fills it from the class hierarchy; values() drains it.

      fieldlist = nil,
      begin_scope = function(w)
         w.fieldlist = { parent = w.fieldlist }
      end,
      end_scope = function(w)
         w.fieldlist = w.fieldlist.parent
      end,
      add_field = function(w, fieldtype, fieldname)
         table.insert(w.fieldlist, { type = fieldtype, name = fieldname })
      end,
      get_field = function(w, fieldtype, fieldname)
         return table.remove(w.fieldlist, 1)
      end,

      -- handles: the object serialization protocol mechanism for back references

      handles = {},
      baseWireHandle = 0x007e0000,
      curr_handle = 0x007e0000,

      new_handle = function(w, obj)
         w.handles[obj] = w.curr_handle
         w.curr_handle = w.curr_handle + 1
      end,

      -- output: the table containing strings to be concatenated as the end result

      out = {},
      put = function(w, data)
         w.out[#(w.out)+1] = data
      end,

      -- rules: based on grammar from
      -- http://docs.oracle.com/javase/6/docs/platform/serialization/spec/protocol.html#10258

      stream = function(w, ...)
         w:magic()
         w:version()
         w:contents(...)
      end,

      contents = function(w, ...)
         for i = 1, select("#", ...) do
            w:content(select(i, ...))
         end
      end,

      -- @return whatever object() returns (a class count for ClassDescs),
      -- so classAnnotation can add it up.
      content = function(w, obj)
         if type(obj) == "table" and string.lower(obj.__kind or "object") == "blockdata" then
            return w:blockdata(obj)
         end
         return w:object(obj)
      end,

      -- @return when given a ClassDesc, returns number of classes
      -- composing this ClassDesc; otherwise returns nil.
      object = function(w, obj)
         if obj == nil or obj == javaserialize.null then
            w:nullReference()
         elseif w.handles[obj] then
            w:prevObject(obj)
            return 1
         elseif type(obj) == "string" then
            w:newString(obj)
         elseif type(obj) == "table" then
            local kind = string.lower(obj.__kind or "object")
            if kind == "class" then
               w:newClass(obj)
            elseif kind == "classdesc" then
               return w:newClassDesc(obj)
            elseif kind == "enum" then
               w:newEnum(obj)
            elseif kind == "exception" then
               w:exception(obj)
            elseif obj[1] ~= nil then
               w:newArray(obj)
            else
               w:newObject(obj)
            end
         else
            assert(false, "Unrecognized object")
         end
      end,

      newClass = function(w, obj)
         assert(false, "Not yet implemented") -- TODO
      end,

      -- Write the class descriptor of an object (its `__class`).
      -- @return number of classes composing this ClassDesc, and the class
      -- name when a class (new or back-referenced) was written
      classDesc = function(w, obj)
         return w:classDescOf(obj and obj.__class)
      end,

      -- Write the class descriptor for the class table `klass` itself:
      -- TC_NULL for nil/null, TC_REFERENCE if already written, else a new one.
      -- @return number of classes composing this ClassDesc, and the class name
      classDescOf = function(w, klass)
         if klass == nil or klass == javaserialize.null then
            w:nullReference()
            return 0
         elseif w.handles[klass] then
            w:prevObject(klass)
            return 1, klass.__name
         else
            return w:newClassDesc(klass)
         end
      end,

      -- `super` is the superclass table (or nil/null). A wrapper of the form
      -- `{ __class = superclass }` is also accepted.
      -- @return number of classes composing this ClassDesc
      superClassDesc = function(w, super)
         return w:classDescOf(as_class(super))
      end,

      -- @return number of classes composing this ClassDesc, and its name
      newClassDesc = function(w, cd)
         -- TODO TC_PROXYCLASSDESC
         w:put(TC_CLASSDESC)
         w:new_handle(cd)
         w:put(utf(to_class_descriptor(cd.__name)))
         w:serialVersionUID(cd.__serialVersionUID)
         return w:classDescInfo(cd), cd.__name
      end,

      -- @return number of classes composing this ClassDesc
      classDescInfo = function(w, cd)
         local n = 1
         w:classDescFlags(cd.__flags)
         w:fields(cd)
         n = n + w:classAnnotation(cd.__annotations)
         n = n + w:superClassDesc(cd.__super)
         return n
      end,

      className = function(w, name)
         w:put(utf(name))
      end,

      serialVersionUID = function(w, uid)
         w:put(uid and tolong(uid) or L1)
      end,

      classDescFlags = function(w, flags)
         local n = 0
         if flags then
            if flags.write_method then n = n + SC_WRITE_METHOD end
            if flags.block_data then n = n + SC_BLOCK_DATA end
            if flags.serializable then n = n + SC_SERIALIZABLE end
            if flags.externalizable then n = n + SC_EXTERNALIZABLE end
            if flags.enum then n = n + SC_ENUM end
         else
            n = SC_SERIALIZABLE
         end
         w:put(tobyte(n))
      end,

      -- proxyClassDescInfo -- TODO

      -- proxyInterfaceName -- TODO

      fields = function(w, cd)
         w:put(toshort(#cd))
         for _, field in ipairs(cd) do
            w:fieldDesc(field)
         end
      end,

      -- Write one field declaration { type, name }. Field values are not
      -- queued here: newObject derives them from the class, so they are
      -- also written when the classdesc is only a back-reference.
      fieldDesc = function(w, field)
         local fieldtype, fieldname = field[1], field[2]
         local typename
         if type(fieldtype) == "string" then
            typename = fieldtype
         elseif type(fieldtype) == "table" then
            typename = fieldtype.__name
         end
         local code, klass = to_field_descriptor(typename)
         w:put(code)
         w:fieldName(fieldname)
         if klass then
            w:className1(klass)
         end
      end,

      -- primitiveDesc -- handled by fieldDesc

      -- objectDesc -- handled by fieldDesc

      fieldName = function(w, name)
         w:put(utf(name))
      end,

      className1 = function(w, name)
         if w.handles[name] then
            w:prevObject(name)
         else
            w:newString(name)
         end
      end,

      -- @return number of classes composing this ClassDesc
      classAnnotation = function(w, anns)
         local n = 0
         if anns then
            for _, content in ipairs(anns) do
               local r = w:content(content)
               if type(r) == "number" then
                  n = n + r
               end
            end
         end
         w:endBlockData()
         return n
      end,

      -- prim_typecode -- handled by to_field_descriptor

      -- obj_typecode -- handled by to_field_descriptor

      -- TC_ARRAY classDesc newHandle (int)<size> values.
      -- The element type is the array class name minus one trailing "[]".
      newArray = function(w, arr)
         local klass = arr.__class
         assert(type(klass) == "table" and klass.__name, "array has no __class")
         local elemtype = klass.__name:gsub("%[%]$", "")
         w:put(TC_ARRAY)
         w:classDesc(arr)
         -- arrays take a handle like any object; omitting it shifts every
         -- later handle number
         w:new_handle(arr)
         w:put(toint(#arr))
         for i = 1, #arr do
            -- FIXME multidimensional arrays: inner arrays need their own __class
            w:value(elemtype, arr[i])
         end
      end,

      -- TC_OBJECT classDesc newHandle classdata[].
      newObject = function(w, obj)
         local klass = obj.__class
         assert(klass and klass ~= javaserialize.null, "object has no __class")
         w:put(TC_OBJECT)
         w:classDesc(obj)
         w:new_handle(obj)
         w:begin_scope()
         for _, field in ipairs(class_fields(klass)) do
            w:add_field(field[1], field[2])
         end
         -- NOTE(review): flags, __annotations and __write are taken from
         -- obj/obj.__class only; per-class handling within a hierarchy is
         -- not modelled.
         w:classdata(obj)
         w:end_scope()
      end,

      classdata = function(w, obj)
         if serializable(obj) then
            if write_method(obj) then
               w:wrclass(obj)
               w:objectAnnotation(obj.__annotations, obj.__write)
            else
               w:nowrclass(obj)
            end
         end
         if externalizable(obj) then
            if block_data(obj) then
               w:objectAnnotation(obj.__annotations)
            else
               w:externalContents(obj)
            end
         end
      end,

      nowrclass = function(w, obj)
         w:values(obj)
      end,

      wrclass = function(w, obj)
         w:nowrclass(obj)
      end,

      objectAnnotation = function(w, anns, write)
         if anns then
            for _, content in ipairs(anns) do
               w:content(content)
            end
         end
         if write then
            write(w)
         end
         w:endBlockData()
      end,

      blockdata = function(w, obj)
         if #obj[1] < 256 then
            w:blockdatashort(obj)
         else
            w:blockdatalong(obj)
         end
      end,

      blockdatashort = function(w, obj)
         w:put(TC_BLOCKDATA)
         w:put(tobyte(#obj[1]))
         w:put(obj[1])
      end,

      blockdatalong = function(w, obj)
         w:put(TC_BLOCKDATALONG)
         w:put(toint(#obj[1]))
         w:put(obj[1])
      end,

      endBlockData = function(w)
         w:put(TC_ENDBLOCKDATA)
      end,

      externalContent = function(w, obj)
         assert(false, "Not yet implemented") -- TODO
      end,

      externalContents = function(w, obj)
         assert(false, "Not yet implemented") -- TODO
      end,

      newString = function(w, str)
         w:new_handle(str)
         if #str < 65536 then
            w:put(TC_STRING)
            w:put(utf(str))
         else
            w:put(TC_LONGSTRING)
            w:put(long_utf(str))
         end
      end,

      newEnum = function(w, obj)
         assert(false, "Not yet implemented") -- TODO
      end,

      -- enumConstantName

      prevObject = function(w, obj)
         w:put(TC_REFERENCE)
         w:put(toint(w.handles[obj]))
      end,

      nullReference = function(w)
         w:put(TC_NULL)
      end,

      exception = function(w, obj)
         assert(false, "Not yet implemented") -- TODO
      end,

      magic = function(w)
         w:put(STREAM_MAGIC)
      end,

      version = function(w)
         w:put(STREAM_VERSION)
      end,

      values = function(w, obj)
         while true do
            local field = w:get_field()
            if not field then break end
            w:value(field.type, obj[field.name])
         end
      end,

      -- Write one field/element value of declared type `fieldtype` (a type
      -- name string or a class table). Reference values that were already
      -- written become back-references; nil and null become TC_NULL.
      value = function(w, fieldtype, value)
         if fieldtype == "int" then
            w:put(toint(value))
         elseif fieldtype == "byte" then
            w:put(tobyte(value))
         elseif fieldtype == "short" then
            w:put(toshort(value))
         elseif fieldtype == "long" then
            w:put(tolong(value))
         elseif fieldtype == "boolean" then
            w:put(tobyte(value and 1 or 0))
         elseif fieldtype == "char" then
            -- a UTF-16 code unit given as a number
            w:put(toshort(value))
         elseif value == nil or value == javaserialize.null then
            w:nullReference()
         elseif w.handles[value] then
            w:prevObject(value)
         elseif type(value) == "string" then
            w:newString(value)
         elseif type(value) == "table" then
            -- an explicit runtime class on the value wins over the
            -- declared field type
            if value.__class == nil and type(fieldtype) == "table" then
               value.__class = fieldtype
            end
            assert(value.__class, "Not yet implemented")
            w:object(value)
         else
            assert(false, "Not yet implemented")
         end
      end

   }

   return writer
end

--- Serialize one Lua value as a complete Java Object Serialization stream.
-- @param obj value to write (a string, a table modelling a Java object, ...)
-- @param klass optional class table. When given and `obj` is a table, it
--   is stored in obj.__class, which mutates the caller's table.
-- @treturn string the stream bytes (magic, version, then the content)
function javaserialize.serialize(obj, klass)
   if klass and type(obj) == "table" then
      obj.__class = klass
   end
   local w = javaserialize.writer()
   w:stream(obj)
   return table.concat(w.out)
end

--- Debug helper: write `buf` to stdout as a hex dump, 16 bytes per line.
-- Each line has an 8-digit hex offset, the bytes in hex, and the bytes as
-- text with characters outside %w, %p and space shown as ".".
-- NOTE(review): this is a global (as is test()); kept global so existing
-- callers keep working.
function hex_print(buf)
   local offset = 0
   while offset < #buf do
      local chunk = buf:sub(offset + 1, offset + 16)
      local hex = {}
      for i = 1, #chunk do
         hex[i] = ("%02X "):format(chunk:byte(i))
      end
      local text = chunk:gsub("[^%w%p ]", ".")
      io.write(("%08X  "):format(offset), table.concat(hex),
               (" "):rep(3 * (16 - #chunk)), " ", text, "\n")
      offset = offset + 16
   end
end

--- Manual smoke test that hex-dumps two serialized values:
--  1. a plain string;
--  2. a lac.cnclib.helloworld.CustomData object whose `icon` field is a
--     javax.swing.ImageIcon with custom write data: a block of two ints
--     (height, width) followed by an int[] of height*width pixels.
-- NOTE(review): defined as a global, like hex_print; kept that way so
-- existing callers of test() keep working.
function test()

   local HEIGHT, WIDTH = 79, 171
   local PIXEL = 0xffffff00

   local IntArray = {
      __name = "int[]",
      __serialVersionUID = string.char(0x4D, 0xBA, 0x60, 0x26, 0x76, 0xEA, 0xB2, 0xA5),
      __flags = { serializable = true },
      __annotations = {},
   }

   local ImageIcon = {
      __name = "javax.swing.ImageIcon",
      __serialVersionUID = string.char(0xF2, 0xA6, 0x35, 0x6E, 0xDE, 0x0C, 0x0E, 0x32),
      __flags = { serializable = true, write_method = true },
      __annotations = {},
      { "int", "height" },
      { "int", "width" },
      { "javax.swing.ImageIcon$AccessibleImageIcon", "accessibleContext" },
      { "String", "description" },
      { "java.awt.image.ImageObserver", "imageObserver" },
   }

   local CustomData = {
      __name = "lac.cnclib.helloworld.CustomData",
      __serialVersionUID = string.char(0x54, 0x8F, 0x7D, 0x57, 0xF7, 0x2B, 0x88, 0xF6),
      __flags = { serializable = true },
      __annotations = {},
      { "String", "caption" },
      { ImageIcon, "icon" },
   }

   -- Custom write hook for the icon: emits the pixel array.
   local function write_pixels(w)
      local pixels = { __class = IntArray }
      for i = 1, HEIGHT * WIDTH do
         pixels[i] = PIXEL
      end
      w:newArray(pixels)
   end

   local icon = {
      height = HEIGHT,
      width = WIDTH,
      accessibleContext = javaserialize.null,
      description = "file:/Users/hisham/work/mr-udp/ContextNet/ClientLib/build/lac/cnclib/helloworld/logo.gif",
      imageObserver = javaserialize.null,
      __annotations = {
         { __kind = "blockdata", toint(HEIGHT) .. toint(WIDTH) },
      },
      __write = write_pixels,
   }

   local obj = { caption = "LAC", icon = icon }

   hex_print(javaserialize.serialize("Hello World Title :-)"))
   print()
   hex_print(javaserialize.serialize(obj, CustomData))

end

-- Module exports: null (the Java null sentinel), toint, writer, serialize.
return javaserialize


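--[[
Productions of the stream grammar from the Java Object Serialization
Specification ("Object Serialization Stream Protocol"), kept for reference.
Most have a same-named method on the table returned by
javaserialize.writer(); newHandle is implemented as new_handle.
stream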
contents
content
object
newClass
classDesc
superClassDesc
newClassDesc
classDescInfo
className
serialVersionUID
classDescFlags
proxyClassDescInfo
proxyInterfaceName
fields
fieldDesc
primitiveDesc
objectDesc
fieldName
className1
classAnnotation
prim_typecode
obj_typecode
newArray
newObject
classdata
nowrclass
wrclass
objectAnnotation
blockdata
blockdatashort
blockdatalong
endBlockData
externalContent
externalContents
newString
newEnum
enumConstantName
prevObject
nullReference
exception
magic
version
values
newHandle
reset
]]
