import uuid

import streamlit as st
from langchain_core.messages import HumanMessage

from app.core.rag.graph import init_rag_app
from app.utils import get_logger

# Compiled RAG workflow (LangGraph-style app exposing .get_state()/.stream());
# built once at import time so Streamlit reruns reuse the same instance.
rag_app = init_rag_app()
logger = get_logger(__name__)

st.set_page_config(page_title="Bhasika AI", page_icon="🤖", layout="wide")

# Persist a conversation thread id across Streamlit reruns.
# NOTE(review): the fixed default id means every visitor shares one
# conversation until "Clear Chat" assigns a fresh uuid — confirm this is
# intended (demo/testing only).
if "thread_id" not in st.session_state:
    st.session_state.thread_id = "test_user_conversation"
# Per-thread config passed to every rag_app call below.
config = {"configurable": {"thread_id": st.session_state.thread_id}}
# Checkpointed graph state for this thread (holds "messages" and "profile").
snapshot = rag_app.get_state(config)

# Hydrate the on-screen transcript from the checkpointed graph state on
# first load of this session.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": msg.type, "content": msg.content}
        for msg in snapshot.values.get("messages", [])
    ]

# Profile is kept in session state purely so the sidebar can display it.
if "profile" not in st.session_state:
    st.session_state.profile = snapshot.values.get("profile", {})

with st.sidebar:
    st.title("Settings")
    st.info(f"Thread ID: {st.session_state.thread_id}")
    # Show the profile extracted so far for this conversation.
    st.write(st.session_state.profile)
    if st.button("Clear Chat"):
        st.session_state.messages = []
        # Bug fix: also drop the old thread's profile. Previously only the
        # messages were cleared, so the sidebar kept displaying stale profile
        # data that belongs to the abandoned thread.
        st.session_state.profile = {}
        # Fresh thread id so the graph's checkpointer starts a new history
        # instead of replaying the old one on the next rerun.
        st.session_state.thread_id = str(uuid.uuid4())
        st.rerun()

st.title("Bhasika AI")

# Replay the stored transcript so the chat survives Streamlit reruns.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
        # Messages may carry workflow metadata; surface it in a collapsible
        # panel rather than inline.
        if "metadata" in entry:
            with st.expander("Process Details"):
                st.json(entry["metadata"])

if prompt := st.chat_input("Ask me something..."):
    # Echo the user's message immediately and record it in the transcript.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        status_placeholder = st.empty()
        response_placeholder = st.empty()

        full_response = ""

        inputs = {
            "user_input": prompt,
            "messages": [HumanMessage(prompt)],
            "retrieved_docs": [],  # reset for next run
        }

        try:
            # Stream node-by-node updates so the UI can show workflow progress
            # while the graph runs.
            for chunk in rag_app.stream(inputs, config=config, stream_mode="updates"):
                for node_name, state_update in chunk.items():
                    status_placeholder.status(f"Running: **{node_name}**...")

                    # The "generate" node emits the user-facing answer.
                    if node_name == "generate" and "response" in state_update:
                        full_response = state_update["response"]
                        response_placeholder.markdown(full_response)
                    # The "extract" node may refresh the sidebar profile.
                    if node_name == "extract" and "profile" in state_update:
                        st.session_state.profile = state_update["profile"]

            st.session_state.messages.append(
                {
                    "role": "assistant",
                    "content": full_response,
                }
            )

        except Exception as e:
            logger.exception("Error processing workflow")
            st.error(f"Error processing workflow: {str(e)}")
        finally:
            # Bug fix: clear the status indicator even when the stream raises.
            # Previously .empty() only ran on success, so a failed run left a
            # stale "Running: ..." status on screen.
            status_placeholder.empty()
