import produce from "immer";

import { ChatMessage } from "../types";

/** State kept for a single thread: the chat messages it contains. */
type ThreadState = {
  messages: Array<ChatMessage>;
};

/**
 * Actions understood by `threadsReducer`.
 *
 * `threadId` selects the thread an action applies to; where it is optional,
 * omitting it targets the temporary thread. `reset_temp_thread` always acts
 * on the temporary thread, so it forbids a `threadId` altogether.
 */
type ThreadAction =
  // Replace the temporary thread with an empty one.
  | {
      type: "reset_temp_thread";
      threadId?: never;
    }
  // Replace the thread's entire message list with `payload`.
  | {
      type: "set_messages";
      threadId: string;
      payload: Array<ChatMessage>;
    }
  // Append `payload` to the end of the thread's message list.
  | {
      type: "add_messages";
      threadId?: string;
      payload: Array<ChatMessage>;
    }
  // Shallow-merge `payload` into the last message whose id is `messageId`.
  | {
      type: "update_message";
      threadId: string;
      messageId: string;
      payload: Partial<ChatMessage>;
    }
  // Append the text `payload` to the `content` of the last message whose id is `messageId`.
  | {
      type: "append_message_chunk";
      threadId: string;
      messageId: string;
      payload: string;
    }
  // Drop every message whose id appears in `payload`.
  | {
      type: "remove_messages";
      threadId?: string;
      payload: Array<string>;
    };

/**
 * Empty state used for a thread that has no entry yet, and when the
 * temporary thread is reset. This one object is shared by every such
 * thread, so it must never be mutated in place.
 */
const defaultThreadState: ThreadState = { messages: [] };

/**
 * Applies one action to a single thread's state and returns the next state.
 *
 * The input is never mutated. List-level changes build new arrays, and
 * per-message edits go through immer's `produce`.
 *
 * When an action changes nothing, the state the reducer started from is
 * returned as-is: the existing object, or `defaultThreadState` for a thread
 * with no entry yet. This covers:
 * - an empty `add_messages` payload,
 * - a `remove_messages` that matches no message,
 * - an unknown `messageId`,
 * - an unhandled action type.
 * Callers can therefore detect "no change" by reference comparison.
 *
 * @param stateArg - Current state of the thread; `undefined` is treated as an
 *   empty thread (`defaultThreadState`).
 * @param action - Action to apply, with `threadId` already resolved.
 * @returns The thread's next state.
 */
function localMessagesReducer(
  stateArg: ThreadState | undefined,
  action: ThreadAction & { threadId: string },
): ThreadState {
  const state = stateArg ?? defaultThreadState;
  switch (action.type) {
    case "set_messages": {
      return { ...state, messages: action.payload };
    }
    case "add_messages": {
      // Nothing to append: keep the existing state object.
      if (action.payload.length === 0) {
        return state;
      }
      return { ...state, messages: state.messages.concat(action.payload) };
    }
    case "update_message": {
      return produce(state, (draft) => {
        // Scan from the end so that, if ids repeat, the most recent message wins.
        for (let i = draft.messages.length - 1; i >= 0; i--) {
          if (draft.messages[i].id === action.messageId) {
            draft.messages[i] = { ...draft.messages[i], ...action.payload };
            break;
          }
        }
      });
    }
    case "append_message_chunk": {
      return produce(state, (draft) => {
        // Scan from the end so that, if ids repeat, the most recent message wins.
        for (let i = draft.messages.length - 1; i >= 0; i--) {
          if (draft.messages[i].id === action.messageId) {
            draft.messages[i].content += action.payload;
            break;
          }
        }
      });
    }
    case "remove_messages": {
      // A Set gives O(1) membership checks instead of an Array#includes scan per message.
      const idsToRemove = new Set(action.payload);
      const remaining = state.messages.filter((m) => !idsToRemove.has(m.id));
      // Nothing matched: keep the existing state object.
      if (remaining.length === state.messages.length) {
        return state;
      }
      return { ...state, messages: remaining };
    }
    default: {
      return state;
    }
  }
}

/** Key of the temporary thread: used for `reset_temp_thread` and for actions that omit `threadId`. */
const TEMP_THREAD_ID = "temp_thread_id" as const;

/**
 * Top-level reducer holding the message state of every thread, keyed by id.
 *
 * `reset_temp_thread` replaces the temporary thread (`TEMP_THREAD_ID`) with
 * an empty one. Every other action is routed to the thread named by
 * `action.threadId`, falling back to the temporary thread when no id is given.
 *
 * If the targeted thread's state comes back as the same reference, the
 * incoming `state` record is returned unchanged. Consumers that compare
 * references therefore do not see a spurious update.
 *
 * @param state - Map from thread id to that thread's state.
 * @param action - The action to apply.
 * @returns The next map of thread states.
 */
function threadsReducer(
  state: Record<string, ThreadState | undefined>,
  action: ThreadAction,
): typeof state {
  if (action.type === "reset_temp_thread") {
    return {
      ...state,
      [TEMP_THREAD_ID]: defaultThreadState,
    };
  }
  const threadId = action.threadId ?? TEMP_THREAD_ID;
  const previous = state[threadId];
  const next = localMessagesReducer(previous, {
    ...action,
    threadId,
  } as ThreadAction & { threadId: string });
  // Thread untouched (e.g. unknown messageId): keep the same top-level object.
  if (next === previous) {
    return state;
  }
  return {
    ...state,
    [threadId]: next,
  };
}

export { threadsReducer, TEMP_THREAD_ID };
