import { useMemo, useReducer } from "react";
import { v4 as uuid } from "uuid";

import { ChatRole, useChatMessagesQuery } from "@/apollo/types";

import { TEMP_THREAD_ID, threadsReducer } from "./store/threadsReducer";
import { useChatStream } from "./stream-completion";
import { ChatMessage, CreateChatMessageInput } from "./types";

/** Successful submission: carries the thread id and the ids of both messages. */
type SubmitMessageSuccess = {
  data: {
    threadId: string;
    messageId: string;
    responseMessageId: string;
  };
  error: undefined;
};

/** Failed submission: carries no data, only the error. */
type SubmitMessageFailure = {
  data: undefined;
  error: Error;
};

/**
 * Result of submitting a chat message. Exactly one of `data` / `error` is
 * defined, so checking either field narrows the other.
 */
export type SubmitMessageResult = SubmitMessageSuccess | SubmitMessageFailure;

/**
 * React hook that exposes the messages of a chat thread and a function for
 * submitting a new message whose response is streamed back in chunks.
 *
 * The returned `messages` list is the messages loaded by the chat-messages
 * query (only when a `threadId` is given) followed by the messages held in the
 * local reducer state for that thread. Without a `threadId`, local messages
 * are read from the `TEMP_THREAD_ID` entry.
 *
 * @param threadId - Id of the thread to show, or `undefined` when no thread
 *   exists yet; the messages query is skipped in that case.
 * @param options - Optional callbacks. `onThreadCreated` is invoked with the
 *   new thread id when a message is created while `threadId` is `undefined`.
 * @returns The thread id, the combined message list, the query loading flag,
 *   the streaming flag and the `submitMessage` function.
 */
export function useChat(
  threadId: string | undefined,
  options: {
    onThreadCreated?: (threadId: string) => void;
  } = {},
) {
  const [state, dispatch] = useReducer(threadsReducer, {});

  // Messages held only in local state for the current (or temporary) thread.
  const localMessages = state[threadId ?? TEMP_THREAD_ID]?.messages;

  const {
    data,
    refetch: refetchMessages,
    loading: isMessagesLoading,
  } = useChatMessagesQuery({
    fetchPolicy: "cache-and-network",
    variables: {
      threadId: threadId ?? "",
    },
    skip: !threadId,
    onCompleted(data) {
      // Nothing to reconcile when there is no thread, no local messages
      // (missing or empty), or no server messages. Treating a missing local
      // list like an empty one avoids a redundant `set_messages` dispatch.
      if (
        !threadId ||
        !localMessages?.length ||
        data.chatMessages.length === 0
      ) {
        return;
      }
      // Drop local messages that the server result already contains so they
      // are not listed twice in `messages`.
      const idSet = new Set(data.chatMessages.map((m) => m.id));
      const filteredLocalMessages = localMessages.filter(
        (m) => !idSet.has(m.id),
      );
      dispatch({
        type: "set_messages",
        threadId,
        payload: filteredLocalMessages,
      });
    },
  });

  const serverMessages = data?.chatMessages;

  // Server messages first (ignored without a thread id), then local ones.
  const messages = useMemo(() => {
    return (threadId ? (serverMessages ?? []) : [])
      .map(
        (m) =>
          ({
            id: m.id,
            role: m.role as ChatRole.User | ChatRole.Assistant,
            content: m.content,
          }) as ChatMessage,
      )
      .concat(localMessages ?? []);
  }, [threadId, serverMessages, localMessages]);

  const [postChatMessage, { isLoading: isStreamingResponse }] = useChatStream({
    onMessageCreated(context) {
      // First message of a brand-new thread: notify the caller about the new
      // thread id and clear the temporary thread state.
      if (threadId === undefined && context.threadId) {
        options.onThreadCreated?.(context.threadId);
        dispatch({ type: "reset_temp_thread" });
      }
    },
  });

  /**
   * Adds the user message and an empty assistant placeholder to local state,
   * posts the message and appends streamed chunks to the placeholder.
   *
   * @returns `undefined` when a response is already streaming (the call is
   *   ignored); otherwise a result holding either the data resolved by the
   *   stream call or the error that was caught.
   */
  const submitMessage = async (options: {
    threadId?: string;
    input: string;
    modelContext?: CreateChatMessageInput["modelContext"];
    tables?: string[];
  }): Promise<SubmitMessageResult | undefined> => {
    if (isStreamingResponse) return;

    // Client-generated ids; replaced by server ids in the callbacks below.
    const newMessage: ChatMessage = {
      id: uuid(),
      role: ChatRole.User,
      content: options.input,
    };
    const responseMessage: ChatMessage = {
      id: uuid(),
      role: ChatRole.Assistant,
      content: "",
    };

    // NOTE(review): local state is keyed by the hook's `threadId`, while the
    // request below uses `options.threadId` — presumably callers pass the
    // same value; confirm.
    dispatch({
      type: "add_messages",
      threadId: threadId,
      payload: [newMessage, responseMessage],
    });

    const input: CreateChatMessageInput = {
      threadId: options.threadId,
      content: newMessage.content,
      modelContext: options.modelContext,
      tables: options.tables,
    };
    try {
      const data = await postChatMessage({
        input,
        onThreadCreated(response) {
          // Create a new state in the store for the new thread and add the initial messages
          dispatch({
            type: "set_messages",
            threadId: response.threadId,
            payload: [
              { ...newMessage, id: response.messageId },
              { ...responseMessage, id: response.responseMessageId },
            ],
          });
        },
        onMessageCreated(response) {
          // Update the initial messages with new ids received from the server
          dispatch({
            type: "update_message",
            threadId: response.threadId,
            messageId: newMessage.id,
            payload: { id: response.messageId },
          });
          dispatch({
            type: "update_message",
            threadId: response.threadId,
            messageId: responseMessage.id,
            payload: { id: response.responseMessageId },
          });
        },
        onChunk(chunk, { threadId, responseMessageId }) {
          // Append the streamed chunk to the response message.
          dispatch({
            type: "append_message_chunk",
            threadId,
            messageId: responseMessageId,
            payload: chunk,
          });
        },
      });
      return {
        data: data,
        error: undefined,
      };
    } catch (e) {
      // The messages query is skipped without a thread id, so a refetch is
      // only meaningful when one exists.
      if (threadId) {
        // Best-effort re-sync. A refetch failure is deliberately ignored so it
        // cannot surface as an unhandled rejection; the submit error returned
        // below is what the caller is told about.
        void refetchMessages().catch(() => undefined);
      }
      // NOTE(review): if `onMessageCreated` already ran, the messages were
      // presumably re-keyed to server ids, so removing by the client ids may
      // not match anything — confirm against `threadsReducer`.
      dispatch({
        type: "remove_messages",
        threadId,
        payload: [newMessage.id, responseMessage.id],
      });
      return {
        data: undefined,
        error: e instanceof Error ? e : new Error(String(e)),
      };
    }
  };

  return {
    threadId,
    messages,
    isMessagesLoading,
    isStreamingResponse,
    submitMessage,
  };
}
