Created
October 28, 2025 14:08
-
-
Save michaelgorkow/92fa7c0f38e8f4ed32c23aed1e0f97e6 to your computer and use it in GitHub Desktop.
snowflake-analyst-sis-app
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##### CHANGE THIS TO YOUR LOCATION OF SEMANTIC VIEW #####
SEMANTIC_VIEW_DB = "HACKATHON"
SEMANTIC_VIEW_SCHEMA = "CALL_CENTER"
SEMANTIC_VIEW_NAME = "CALL_CENTER_DATA"
#########################################################
# Fully qualified name <DATABASE>.<SCHEMA>.<VIEW>, sent as "semantic_view" in
# the Cortex Analyst request body. It must end with the view name itself: a
# trailing "." would make the identifier invalid.
SEMANTIC_VIEW = f"{SEMANTIC_VIEW_DB}.{SEMANTIC_VIEW_SCHEMA}.{SEMANTIC_VIEW_NAME}"
| """ | |
| Cortex Analyst App | |
| ==================== | |
| This app allows users to interact with their data using natural language. | |
| """ | |
| import json # To handle JSON data | |
| import time | |
| from datetime import datetime | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import _snowflake # For interacting with Snowflake-specific APIs | |
| import pandas as pd | |
| import streamlit as st # Streamlit library for building the web app | |
| from snowflake.snowpark.context import ( | |
| get_active_session, | |
| ) # To interact with Snowflake sessions | |
| from snowflake.snowpark.exceptions import SnowparkSQLException | |
| # Cortex Analyst REST API paths, passed as the request path to _snowflake.send_snow_api_request | |
| API_ENDPOINT = "/api/v2/cortex/analyst/message" | |
| FEEDBACK_API_ENDPOINT = "/api/v2/cortex/analyst/feedback" | |
| API_TIMEOUT = 50000 # in milliseconds | |
| # Initialize a Snowpark session for executing queries | |
| session = get_active_session() | |
def main():
    """Run one pass of the app: prepare state, render the chat, handle input."""
    # No conversation state yet: create the session keys used below.
    if "messages" not in st.session_state:
        reset_session_state()

    show_header_and_sidebar()

    # An empty history is seeded with an opening question.
    if not st.session_state.messages:
        process_user_input("What questions can I ask?")

    # Remaining render steps, executed in this fixed order.
    render_steps = (
        display_conversation,
        handle_user_inputs,
        handle_error_notifications,
        display_warnings,
    )
    for step in render_steps:
        step()
def reset_session_state():
    """Put the chat-related session state back to its initial, empty values."""
    initial_values = {
        "messages": [],  # conversation messages
        "active_suggestion": None,  # currently selected suggestion
        "warnings": [],  # warnings to show to the user
        "form_submitted": {},  # feedback submission per request
    }
    for key, value in initial_values.items():
        st.session_state[key] = value
def show_header_and_sidebar():
    """Render the page title, the intro text and the sidebar reset button."""
    st.title("Cortex Analyst")
    intro_text = (
        "Welcome to Cortex Analyst! Type your questions below to interact with your data. "
    )
    st.markdown(intro_text)

    with st.sidebar:
        # Three columns; the wide middle one centers the button.
        sidebar_columns = st.columns([2, 6, 2])
        middle_column = sidebar_columns[1]
        clear_clicked = middle_column.button(
            "Clear Chat History", use_container_width=True
        )
        if clear_clicked:
            reset_session_state()
def handle_user_inputs():
    """Process a typed question or, failing that, a clicked suggestion."""
    typed_question = st.chat_input("What is your question?")
    if typed_question:
        process_user_input(typed_question)
        return

    # No typed input: fall back to a suggestion selected via its button.
    pending_suggestion = st.session_state.active_suggestion
    if pending_suggestion is None:
        return

    # Clear the selection before processing so it is handled only once.
    st.session_state.active_suggestion = None
    process_user_input(pending_suggestion)
def handle_error_notifications():
    """Show a toast once when the "fire_API_error_notify" flag is set.

    The flag is read from ``st.session_state`` and cleared after the toast
    is shown, so the notification does not repeat on later reruns.
    """
    if st.session_state.get("fire_API_error_notify"):
        # The icon is a single emoji (the source had mis-decoded bytes here).
        st.toast("An API error has occurred!", icon="🚨")
        st.session_state["fire_API_error_notify"] = False
def process_user_input(prompt: str):
    """
    Process user input and update the conversation history.

    Appends the prompt as a user message and renders it, requests the
    analyst's reply, stores that reply (or a formatted error message) in the
    history, and finally triggers a script rerun.

    Args:
        prompt (str): The user's input.
    """
    # Clear previous warnings at the start of a new request
    st.session_state.warnings = []

    # Create a new message, append to history and display immediately
    new_user_message = {
        "role": "user",
        "content": [{"type": "text", "text": prompt}],
    }
    st.session_state.messages.append(new_user_message)
    with st.chat_message("user"):
        user_msg_index = len(st.session_state.messages) - 1
        display_message(new_user_message["content"], user_msg_index)

    # Show progress indicator inside analyst chat message while waiting for response
    with st.chat_message("analyst"):
        with st.spinner("Waiting for Analyst's response..."):
            time.sleep(1)
            response, error_msg = get_analyst_response(st.session_state.messages)
            if error_msg is None:
                content = response["message"]["content"]
            else:
                content = [{"type": "text", "text": error_msg}]
                # Picked up by handle_error_notifications() on the next run.
                st.session_state["fire_API_error_notify"] = True

            analyst_message = {
                "role": "analyst",
                "content": content,
                # .get(): a payload without a request id must not crash the
                # app; downstream display code accepts None as "no id".
                "request_id": response.get("request_id"),
            }

            if "warnings" in response:
                st.session_state.warnings = response["warnings"]

            st.session_state.messages.append(analyst_message)
            st.rerun()
def display_warnings():
    """
    Display warnings to the user.

    Reads the list stored in ``st.session_state.warnings`` and renders the
    "message" entry of each item as a Streamlit warning box.
    """
    for warning in st.session_state.warnings:
        # The icon is a single emoji (the source had mis-decoded bytes here).
        st.warning(warning["message"], icon="⚠️")
def get_analyst_response(messages: List[Dict]) -> Tuple[Dict, Optional[str]]:
    """
    Send chat history to the Cortex Analyst API and return the response.

    Args:
        messages (List[Dict]): The conversation history.

    Returns:
        Tuple[Dict, Optional[str]]: The parsed response body and an error
        message. The error message is None when the call succeeded. On
        failure the returned dict always has a "request_id" key (possibly
        None) so callers can index it safely.
    """
    # Prepare the request body with the conversation and the semantic view
    request_body = {
        "messages": messages,
        "semantic_view": SEMANTIC_VIEW,
    }

    # Send a POST request to the Cortex Analyst API endpoint
    # (positional arguments, as in the original call)
    resp = _snowflake.send_snow_api_request(
        "POST",  # method
        API_ENDPOINT,  # path
        {},  # headers
        {},  # params
        request_body,  # body
        None,  # request_guid
        API_TIMEOUT,  # timeout in milliseconds
    )

    # Content is expected to be a string with a serialized JSON object, but a
    # failed call may return something else; never let parsing crash the app.
    raw_content = resp["content"]
    try:
        parsed_content = json.loads(raw_content)
    except (TypeError, ValueError):
        parsed_content = None
    parse_ok = isinstance(parsed_content, dict)

    if parse_ok and resp["status"] < 400:
        return parsed_content, None

    if not parse_ok:
        parsed_content = {"message": str(raw_content)}
    parsed_content.setdefault("request_id", None)

    # Craft readable error message (markdown)
    error_msg = "\n".join(
        [
            "",
            "🚨 An Analyst API error has occurred 🚨",
            "",
            f"* response code: `{resp['status']}`",
            f"* request-id: `{parsed_content['request_id']}`",
            f"* error code: `{parsed_content.get('error_code')}`",
            "",
            "Message:",
            "```",
            f"{parsed_content.get('message')}",
            "```",
            "",
        ]
    )
    return parsed_content, error_msg
def display_conversation():
    """
    Render every stored message inside a chat bubble for its role.
    """
    history = st.session_state.messages
    for position, entry in enumerate(history):
        speaker = entry["role"]
        # Only analyst messages carry a request id; others use the default.
        request_id = entry["request_id"] if speaker == "analyst" else None
        with st.chat_message(speaker):
            display_message(entry["content"], position, request_id)
def display_message(
    content: List[Dict[str, Union[str, Dict]]],
    message_index: int,
    request_id: Union[str, None] = None,
):
    """
    Display a single message content.

    Args:
        content (List[Dict[str, str]]): The message content items. Items of
            type "text", "suggestions" and "sql" are rendered; any other
            type is ignored.
        message_index (int): The index of the message, used to build unique
            widget keys.
        request_id (Union[str, None]): Request id of the analyst response,
            forwarded to the SQL display. None when not available.
    """
    for item in content:
        item_type = item["type"]
        if item_type == "text":
            st.markdown(item["text"])
        elif item_type == "suggestions":
            # Display suggestions as buttons; a click is stored in session
            # state for handle_user_inputs() to process.
            for suggestion_index, suggestion in enumerate(item["suggestions"]):
                if st.button(
                    suggestion, key=f"suggestion_{message_index}_{suggestion_index}"
                ):
                    st.session_state.active_suggestion = suggestion
        elif item_type == "sql":
            # Display the SQL query and results. "confidence" is read with
            # .get(): display_sql_confidence() already accepts None, so a
            # missing field must not raise here.
            display_sql_query(
                item["statement"],
                message_index,
                item.get("confidence"),
                request_id,
            )
        # Other content types are intentionally ignored.
@st.cache_data(show_spinner=False)
def get_query_exec_result(query: str) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """
    Run a SQL statement on the module-level Snowpark session.

    Args:
        query (str): The SQL query.

    Returns:
        Tuple[Optional[pd.DataFrame], Optional[str]]: ``(frame, None)`` when
        the query ran, ``(None, error_text)`` when Snowpark raised a SQL
        exception.
    """
    try:
        result_frame = session.sql(query).to_pandas()
    except SnowparkSQLException as sql_error:
        return None, str(sql_error)
    else:
        return result_frame, None
def display_sql_confidence(confidence: dict):
    """Show a popover describing the verified query behind the SQL, if any.

    Args:
        confidence (dict): Confidence information with a
            "verified_query_used" entry. Nothing is rendered when None.
    """
    if confidence is None:
        return

    verified = confidence["verified_query_used"]
    popover_help = "The verified query from [Verified Query Repository](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-analyst/verified-query-repository), used to generate the SQL"
    with st.popover("Verified Query Used", help=popover_help):
        with st.container():
            if verified is None:
                st.text(
                    "There is no query from the Verified Query Repository used to generate this SQL answer"
                )
            else:
                verified_at = datetime.fromtimestamp(verified["verified_at"])
                detail_lines = (
                    f"Name: {verified['name']}",
                    f"Question: {verified['question']}",
                    f"Verified by: {verified['verified_by']}",
                    f"Verified at: {verified_at}",
                    "SQL query:",
                )
                for line in detail_lines:
                    st.text(line)
                st.code(verified["sql"], language="sql", wrap_lines=True)
def display_sql_query(
    sql: str, message_index: int, confidence: dict, request_id: Union[str, None] = None
):
    """
    Executes the SQL query and displays the results in form of data frame and charts.

    Args:
        sql (str): The SQL query.
        message_index (int): The index of the message.
        confidence (dict): The confidence information of SQL query generation
        request_id (str): Request id from user request. The feedback section
            is only shown when it is truthy.
    """
    # Display the SQL query together with its verified-query information
    with st.expander("SQL Query", expanded=False):
        st.code(sql, language="sql")
        display_sql_confidence(confidence)

    # Display the results of the SQL query
    with st.expander("Results", expanded=True):
        with st.spinner("Running SQL..."):
            df, err_msg = get_query_exec_result(sql)
            if df is None:
                st.error(f"Could not execute generated SQL query. Error: {err_msg}")
            elif df.empty:
                st.write("Query returned no data")
            else:
                # Show query results in two tabs (labels restored from
                # mis-decoded emoji in the source)
                data_tab, chart_tab = st.tabs(["Data 📄", "Chart 📉"])
                with data_tab:
                    st.dataframe(df, use_container_width=True)
                with chart_tab:
                    display_charts_tab(df, message_index)

    if request_id:
        display_feedback_section(request_id)
def display_charts_tab(df: pd.DataFrame, message_index: int) -> None:
    """
    Display the charts tab.

    Lets the user pick an X column, a different Y column and a chart type,
    then draws the Y values indexed by the X column.

    Args:
        df (pd.DataFrame): The query results.
        message_index (int): The index of the message, used to build unique
            widget keys.
    """
    # Distinct column labels in DataFrame order. A list (not a set) keeps the
    # option order stable, and de-duplicating first guarantees the Y selector
    # is never left without options.
    columns = list(dict.fromkeys(df.columns))

    # There should be at least 2 columns to draw charts
    if len(columns) < 2:
        st.write("At least 2 columns are required")
        return

    col1, col2 = st.columns(2)
    x_col = col1.selectbox("X axis", columns, key=f"x_col_select_{message_index}")
    y_col = col2.selectbox(
        "Y axis",
        [column for column in columns if column != x_col],
        key=f"y_col_select_{message_index}",
    )

    line_label = "Line Chart 📈"
    bar_label = "Bar Chart 📊"
    chart_type = st.selectbox(
        "Select chart type",
        options=[line_label, bar_label],
        key=f"chart_type_{message_index}",
    )

    if chart_type == line_label:
        st.line_chart(df.set_index(x_col)[y_col])
    elif chart_type == bar_label:
        st.bar_chart(df.set_index(x_col)[y_col])
def display_feedback_section(request_id: str):
    """Render the feedback popover for one analyst response.

    Until feedback has been submitted for ``request_id`` a rating form is
    shown. Afterwards the popover shows either a success note or the error
    text recorded for that submission.

    Args:
        request_id (str): Request id the feedback refers to; also the key
            into ``st.session_state.form_submitted``.
    """
    thumbs_up, thumbs_down = "👍", "👎"

    with st.popover("📝 Query Feedback"):
        submissions = st.session_state.form_submitted
        if request_id not in submissions:
            with st.form(f"feedback_form_{request_id}", clear_on_submit=True):
                # The two options must be distinct strings, otherwise the
                # rating is always treated as positive.
                rating = st.radio(
                    "Rate the generated SQL",
                    options=[thumbs_up, thumbs_down],
                    horizontal=True,
                )
                positive = rating == thumbs_up
                feedback_message = st.text_input("Optional feedback message")
                submitted = st.form_submit_button("Submit")
                if submitted:
                    err_msg = submit_feedback(request_id, positive, feedback_message)
                    st.session_state.form_submitted[request_id] = {"error": err_msg}
                    st.session_state.popover_open = False
                    st.rerun()
        elif submissions[request_id]["error"] is None:
            st.success("Feedback submitted", icon="✅")
        else:
            st.error(submissions[request_id]["error"])
def submit_feedback(
    request_id: str, positive: bool, feedback_message: str
) -> Optional[str]:
    """Send feedback for an analyst response to the feedback endpoint.

    Args:
        request_id (str): Request id of the response being rated.
        positive (bool): True for a positive rating, False for a negative one.
        feedback_message (str): Free-text comment entered by the user.

    Returns:
        Optional[str]: None when the call succeeded, otherwise a
        markdown-formatted error message.
    """
    request_body = {
        "request_id": request_id,
        "positive": positive,
        "feedback_message": feedback_message,
    }
    resp = _snowflake.send_snow_api_request(
        "POST",  # method
        FEEDBACK_API_ENDPOINT,  # path
        {},  # headers
        {},  # params
        request_body,  # body
        None,  # request_guid
        API_TIMEOUT,  # timeout in milliseconds
    )

    # Same success criterion as the message endpoint call in this file.
    if resp["status"] < 400:
        return None

    # The error body is expected to be JSON, but parsing must not crash.
    raw_content = resp.get("content")
    try:
        parsed_content = json.loads(raw_content)
    except (TypeError, ValueError):
        parsed_content = None
    if not isinstance(parsed_content, dict):
        parsed_content = {"message": str(raw_content)}

    # Craft readable error message (markdown)
    err_msg = "\n".join(
        [
            "",
            "🚨 An Analyst API error has occurred 🚨",
            "",
            f"* response code: `{resp['status']}`",
            f"* request-id: `{parsed_content.get('request_id')}`",
            f"* error code: `{parsed_content.get('error_code')}`",
            "",
            "Message:",
            "```",
            f"{parsed_content.get('message')}",
            "```",
            "",
        ]
    )
    return err_msg
| # Run the app when this file is executed as a script | |
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment