
-- Public module table; returned at the end of the file and exposes new().
local dispatcher = {}

-- Method implementations; dispatcher.new() copies them onto each instance.
local Dispatcher = {}

local socket = require("socket")
local utils = require("lac.utils")
require("compat52")

--[[
local colors = require("ansicolors")
local myprint = print
local function print(...)
   local t = {}
   for i=1,select("#", ...) do
      t[i] = tostring(select(i, ...))
   end
   myprint(colors("%{bright green}"..table.concat(t, "\t")))
end
]]

--- Per-object scheduler record (shared method table for entries).
-- Fields, as initialised by get_entry and changed by the methods below:
--   obj     the object this entry is keyed by
--   msgq    FIFO of queued messages; each message is a table of values
--   state   "idle" | "wait" | "timeout" | "notifier" | "dead"
--   timeout absolute deadline in socket.gettime() seconds, or nil
--   coro    coroutine to resume when the entry becomes ready, or nil
local Entry = {}

--- True when the scheduler should resume this entry's coroutine: it timed
-- out, it yielded as a notifier, or it is waiting and has a queued message.
function Entry:is_ready()
   local s = self.state
   if s == "timeout" or s == "notifier" then
      return true
   end
   return s == "wait" and #self.msgq > 0
end

--- Park the running coroutine as a notifier. The entry must be idle.
function Entry:set_notifier()
   assert(self.state == "idle")
   self.state = "notifier"
   self.timeout = nil
   self.coro = coroutine.running()
end

--- Move a waiting entry to "timeout" and clear its deadline.
function Entry:set_timeout()
   assert(self.state == "wait")
   self.state = "timeout"
   self.timeout = nil
end

--- Append one message (all arguments, packed in a table) to the queue.
function Entry:enq_data(...)
   -- TODO
   local q = self.msgq
   q[#q + 1] = { ... }
end

--- Pop the oldest queued message, or nil when the queue is empty.
function Entry:deq_data()
   -- TODO
   local q = self.msgq
   return table.remove(q, 1)
end

--- Discard any queued messages and keep only this one.
function Entry:replace_data(...)
   local message = { ... }
   self.msgq = { message }
end

--- Start waiting on this entry from the running coroutine.
-- @param timeout optional relative timeout in seconds; stored as an
--   absolute deadline based on socket.gettime().
function Entry:go_wait(timeout)
   assert(self.state == "idle")
   self.state = "wait"
   if timeout then
      self.timeout = socket.gettime() + timeout
   else
      self.timeout = timeout
   end
   self.coro = coroutine.running()
end

--- Return a ready entry to "idle" just before its coroutine is resumed.
function Entry:go_idle()
   assert(self:is_ready())
   self.state = "idle"
   self.timeout = nil
   self.coro = nil
end

--- Mark the entry dead (used when its coroutine has finished).
function Entry:go_dead()
   self.state = "dead"
   self.timeout = nil
   self.coro = nil
end

--- Return the scheduler entry for `obj`, creating and registering it on
-- first use. A new entry starts "idle" with an empty message queue, is
-- appended to self.pool and is indexed by obj in self.objects.
-- @param self dispatcher instance (uses self.objects and self.pool)
-- @param obj any value usable as a table key
-- @return the entry (the same table for every later call with this obj)
local function get_entry(self, obj)
   local existing = self.objects[obj]
   if not existing then
      existing = setmetatable({
         obj = obj,
         msgq = {},
         state = "idle",
         coro = nil,
      }, { __index = Entry })
      self.pool[#self.pool + 1] = existing
      self.objects[obj] = existing
   end
   return existing
end

--- Scan the pool once, expire overdue waits, then poll sockets.
-- Every entry in state "wait" whose deadline has passed and which has no
-- queued data is moved to "timeout". Waited-on objects that are not threads
-- or functions are passed to socket.select. Each socket reported readable is
-- read with one receivefrom() call, and the results are queued on that
-- socket's entry.
-- @param self dispatcher instance (uses self.pool and self:get_entry)
-- @param instant unused; kept for signature compatibility
-- @return true after polling, or nil, "terminate" when no entry is ready
--   and none is waiting (nothing left to schedule)
local function check_sockets(self, instant)
   local now = socket.gettime()
   local recvt = {}
   local earliest      -- nearest deadline among waits that are still pending
   local any_ready = false
   local any_wait = false

   for _, entry in ipairs(self.pool) do
      if entry.state == "wait" then
         any_wait = true
         local objtype = type(entry.obj)
         if objtype ~= "thread" and objtype ~= "function" then
            recvt[#recvt + 1] = entry.obj
         end
         -- Expire every overdue waiter in this pass, not only the one with
         -- the earliest deadline. A waiter that already has queued data is
         -- left alone so it receives that data instead of a timeout.
         local deadline = entry.timeout
         if deadline and #entry.msgq == 0 then
            if deadline <= now then
               entry:set_timeout()
            elseif earliest == nil or deadline < earliest then
               earliest = deadline
            end
         end
      end
      if entry:is_ready() then
         any_ready = true
      end
   end

   local timeout
   if any_ready then
      timeout = 0            -- someone can run now: only poll
   elseif any_wait then
      -- block until the nearest deadline, or indefinitely (-1) if none
      timeout = earliest and (earliest - now) or -1
   else
      return nil, "terminate"
   end

   local readable = socket.select(recvt, nil, timeout)
   for _, skt in ipairs(readable) do
      local data, host, port = skt:receivefrom()
      self:get_entry(skt):enq_data(data, host, port)
   end
   return true
end

--- Deliver a message to a waiter, then yield so other coroutines can run.
-- If some coroutine is currently waiting on `obj`, its queued data is
-- replaced by the given values; otherwise the values are discarded. The
-- caller is then parked as a "notifier" and resumed by the scheduler later.
-- @param obj the object to notify
-- @param ... values handed to the waiter (returned by its wait())
-- @return true
function Dispatcher.notify(self, obj, ...)
   local target = self:get_entry(obj)
   if target.state == "wait" then
      target:replace_data(...)
   end
   self:get_entry(coroutine.running()):set_notifier()
   assert(coroutine.yield() == "notifier")
   return true
end

--- Queue a message on `obj`, then yield so other coroutines can run.
-- Unlike notify(), the message is always appended to obj's queue, whether
-- or not anyone is waiting on it yet. The caller is parked as a "notifier"
-- and resumed by the scheduler later.
-- @param obj the object to queue the message on
-- @param ... values of the message
-- @return true
function Dispatcher.enqueue(self, obj, ...)
   self:get_entry(obj):enq_data(...)
   local caller = self:get_entry(coroutine.running())
   caller:set_notifier()
   local resumed_as = coroutine.yield()
   assert(resumed_as == "notifier")
   return true
end

--- Suspend the calling coroutine until `obj` has a message or time runs out.
-- @param obj any value usable as a table key; objects that are not threads
--   or functions are also polled with socket.select by the scheduler
-- @param timeout optional relative timeout in seconds (nil = wait forever)
-- @return true followed by the message values, or nil, "timeout"
function Dispatcher.wait(self, obj, timeout)
   local entry = self:get_entry(obj)
   entry:go_wait(timeout)
   local state = coroutine.yield()
   assert(state == "wait" or state == "timeout")
   if state ~= "wait" then
      return nil, "timeout"
   end
   local data = entry:deq_data()
   -- `#data` is unspecified when a message has nil holes (e.g. the
   -- nil, errmsg, nil queued after a failed receivefrom), so find the
   -- highest integer index explicitly before unpacking.
   local n = 0
   for k in pairs(data) do
      if type(k) == "number" and k > n then
         n = k
      end
   end
   return true, table.unpack(data, 1, n)
end

--- Resume `coro` with the given arguments. If the coroutine raises, re-raise
-- its error with the coroutine's own stack traceback attached (string
-- errors only), reported at the caller of Dispatcher.start.
local function resume_or_raise(coro, ...)
   local ok, err = coroutine.resume(coro, ...)
   if not ok then
      if type(err) == "string" then
         err = debug.traceback(coro, err)
      end
      -- level 3: resume_or_raise -> Dispatcher.start -> start's caller
      error(err, 3)
   end
end

--- Run `fn(self)` in a coroutine named "main" and schedule until done.
-- The pool is scanned round-robin. Each ready entry is set idle and its
-- coroutine is resumed with the entry's previous state. After each full
-- pass, check_sockets expires timeouts and polls sockets. The loop ends
-- when the pool is empty or check_sockets reports "terminate".
-- Errors raised inside any scheduled coroutine propagate out of start().
-- @param fn function receiving the dispatcher as its only argument
function Dispatcher.start(self, fn)
   local main_coro = coroutine.create(fn)
   self:set_thread_name("main", main_coro)
   resume_or_raise(main_coro, self)

   local at = 1
   local limit = #self.pool
   while #self.pool > 0 do
      local entry = self.pool[at]
      local coro = entry.coro
      local state = entry.state

      --print("dispatcher", at, entry.obj, self.threads[coro], state, #entry.msgq)

      if coro and coroutine.status(coro) == "dead" then
         entry:go_dead()
      elseif entry:is_ready() then
         entry:go_idle()
         resume_or_raise(coro, state)
      end

      if at == limit then
         -- End of a pass: expire timeouts, poll sockets, pick up new entries.
         local _, err = check_sockets(self)
         if err == "terminate" then
            break
         end
         limit = #self.pool
         at = 1
      else
         at = at + 1
      end
   end
end

--- Like wait(), but with an absolute deadline.
-- @param obj object to wait on
-- @param abstime deadline in socket.gettime() seconds. nil, false or a
--   negative value means wait with no timeout.
-- @return as wait(); nil, "timeout" immediately, without yielding, when the
--   deadline has already passed
function Dispatcher.wait_until(self, obj, abstime)
   local unbounded = (not abstime) or abstime < 0
   if unbounded then
      return self:wait(obj)
   end
   local remaining = abstime - socket.gettime()
   if remaining <= 0 then
      return nil, "timeout"
   end
   return self:wait(obj, remaining)
end

--- Suspend the calling coroutine for `time` seconds.
-- get_entry keeps every key it sees in self.objects and self.pool forever.
-- Waiting on a fresh `{}` per call therefore grew both without bound and
-- slowed every scheduler pass. Instead, each coroutine reuses one private
-- key, kept in a weak-keyed table. No other code knows that key, so nothing
-- but the timeout can wake the sleeper, and its entry is idle again
-- whenever sleep() returns.
-- @param time seconds to sleep
-- @return nil, "timeout" (as returned by wait())
function Dispatcher.sleep(self, time)
   local keys = self.sleep_keys
   if not keys then
      keys = setmetatable({}, { __mode = "k" })
      self.sleep_keys = keys
   end
   local co = coroutine.running()
   local key = keys[co]
   if not key then
      key = {}
      keys[co] = key
   end
   return self:wait(key, time)
end

--- Current time in seconds, as reported by socket.gettime().
function Dispatcher.time(self)
   local now = socket.gettime()
   return now
end

--- Associate a name with a coroutine in self.threads.
-- @param name the name to record
-- @param coro coroutine to name; defaults to the running coroutine
function Dispatcher.set_thread_name(self, name, coro)
   local target = coro
   if not target then
      target = coroutine.running()
   end
   self.threads[target] = name
end

--- Create a new dispatcher instance.
-- The returned table holds its own bookkeeping tables plus a copy of every
-- public method, so it is used with method syntax (disp:wait(...), etc.).
function dispatcher.new()
   local disp = {
      threads = {},   -- coroutine -> name (see set_thread_name)
      objects = {},   -- object -> entry (see get_entry)
      pool = {},      -- entries in creation order, scanned by start()
      get_entry = get_entry,
   }
   for _, method in ipairs({
      "set_thread_name", "wait", "wait_until", "notify",
      "enqueue", "start", "sleep", "time",
   }) do
      disp[method] = Dispatcher[method]
   end
   return disp
end

return dispatcher
