
14_text2sql_agent

October 8, 2024

1 Text2SQL Agent to Interact with CSV Data


1.1 System Architecture
Think of it as an agent equipped with a set of tools: search_cache(), generate_sql_query(), and run_sql_query().
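
As a rough sketch (this wrapper is hypothetical, not a cell from the notebook), the tools compose like this; the real orchestration is implemented later in handle_user_question():

# Sketch only (hypothetical wrapper): mirrors the flow implemented later in
# handle_user_question(): embed the question, check the semantic cache,
# otherwise generate SQL with the LLM and execute it against SQLite.
def answer_question(user_question):
    question_embedding = get_embeddings(user_question)   # embed the question
    hit = search_cache(question_embedding)               # semantic cache lookup
    if hit:
        cached_sql, cached_response = hit
        return cached_response
    sql_query = generate_sql_query(user_question)        # LLM text-to-SQL
    return run_sql_query(db_name, sql_query)             # execute against SQLite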

1.2 Data Ingestion Pipeline


1. Read the CSV
2. Create the database schema
3. Create the table
4. Load the table with the CSV data

[1]: import pandas as pd
import sqlite3

def csv_to_sqlite(csv_file, db_name, table_name):
    # Read the CSV file into a pandas DataFrame
    df = pd.read_csv(csv_file)

    # Connect to the SQLite database (it will create the database file if it doesn't exist)
    conn = sqlite3.connect(db_name)
    cursor = conn.cursor()

    # Infer the schema based on the DataFrame columns and data types
    def create_table_from_df(df, table_name):
        # Get column names and types
        col_types = []
        for col in df.columns:
            dtype = df[col].dtype
            if dtype == 'int64':
                col_type = 'INTEGER'
            elif dtype == 'float64':
                col_type = 'REAL'
            else:
                col_type = 'TEXT'
            col_types.append(f'"{col}" {col_type}')

        # Create the table schema
        col_definitions = ", ".join(col_types)
        create_table_query = f'CREATE TABLE IF NOT EXISTS {table_name} ({col_definitions});'
        # print(create_table_query)

        # Execute the table creation query
        cursor.execute(create_table_query)
        print(f"Table '{table_name}' created with schema: {col_definitions}")

    # Create table schema
    create_table_from_df(df, table_name)

    # Insert CSV data into the SQLite table
    df.to_sql(table_name, conn, if_exists='replace', index=False)

    # Commit and close the connection
    conn.commit()
    conn.close()
    print(f"Data loaded into '{table_name}' table in '{db_name}' SQLite database.")

csv_file = "movies.csv"
db_name = "movies_db.db"
table_name = "movies"
csv_to_sqlite(csv_file, db_name, table_name)

Table 'movies' created with schema: "Movie" TEXT, "LeadStudio" TEXT,
"RottenTomatoes" REAL, "AudienceScore" REAL, "Story" TEXT, "Genre" TEXT,
"TheatersOpenWeek" REAL, "OpeningWeekend" REAL, "BOAvgOpenWeekend" REAL,
"DomesticGross" REAL, "ForeignGross" REAL, "WorldGross" REAL, "Budget" REAL,
"Profitability" REAL, "OpenProfit" REAL, "Year" INTEGER
Data loaded into 'movies' table in 'movies_db.db' SQLite database.

[2]: def run_sql_query(db_name, query):
    """
    Executes a SQL query on a SQLite database and returns the results.

    Args:
        db_name (str): The name of the SQLite database file.
        query (str): The SQL query to run.

    Returns:
        list: Query result as a list of tuples, or an empty list if no results or error occurred.
    """
    try:
        # Connect to the SQLite database
        conn = sqlite3.connect(db_name)
        cursor = conn.cursor()

        # Execute the SQL query
        cursor.execute(query)

        # Fetch all results
        results = cursor.fetchall()

        # Close the connection
        conn.close()

        # Return results or an empty list if no results were found
        return results if results else []

    except sqlite3.Error as e:
        print(f"An error occurred while executing the query: {e}")
        return []

[3]: query = f"SELECT count(*) FROM {table_name};"
results = run_sql_query(db_name, query)

if results:
    for row in results:
        print(row)

(970,)

1.3 Ask Natural Language Questions


[24]: import openai
import faiss
import numpy as np
import os
from openai import OpenAI
from litellm import completion
from IPython.display import Markdown, display

[5]: OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)

# Initialize the FAISS index
dimension = 1536  # Dimension size for OpenAI embeddings (may vary by model)
index = faiss.IndexFlatL2(dimension)  # L2 distance index

# Cache will hold (user_question, sql_query, response)
cache = []

[6]: # Helper function to get embeddings from OpenAI or any embedding model
def get_embeddings(text):
    """
    Converts a text string into a vector embedding using OpenAI embeddings.

    Args:
        text (str): The text string to convert.

    Returns:
        np.array: A vector representation of the text.
    """
    response = client.embeddings.create(input=text, model="text-embedding-3-small")
    embedding = np.array(response.data[0].embedding)
    return embedding
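
A quick sanity check (a sketch, assuming a valid OPENAI_API_KEY is set): the vector returned here must match the dimension the FAISS index above was created with.

# Sanity check (requires a valid OPENAI_API_KEY): text-embedding-3-small
# returns 1536-dimensional vectors, matching the index created above.
emb = get_embeddings("how many movies are in the database?")
print(emb.shape)  # expected: (1536,)
assert emb.shape[0] == dimension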

[31]: def search_cache(question_embedding, threshold=0.1):
    """
    Searches the FAISS index for a similar question.

    Args:
        question_embedding (np.array): The embedding of the user's question.
        threshold (float): The similarity threshold for considering a hit.

    Returns:
        tuple: (sql_query, response) if a hit is found, otherwise None.
    """
    if index.ntotal > 0:
        distances, indices = index.search(np.array([question_embedding]), k=1)
        # print(distances)
        # print(indices)
        # Check if the closest distance is below the threshold
        if distances[0][0] < threshold:
            cache_index = indices[0][0]
            return cache[cache_index][1], cache[cache_index][2]
    return None
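
A note on threshold=0.1: IndexFlatL2 returns squared L2 distances, and OpenAI embeddings are normalized to unit length, so the distance relates to cosine similarity as d² = 2(1 − cos). The illustrative values below (not taken from the notebook run) show that 0.1 corresponds to a cosine similarity of roughly 0.95 or higher.

# Illustration (not from the notebook run): for unit-length embeddings,
# squared L2 distance d2 and cosine similarity cos satisfy d2 = 2 * (1 - cos),
# so threshold=0.1 admits cache hits only at cosine similarity >= ~0.95.
for cos_sim in (0.99, 0.95, 0.90):
    print(f"cos={cos_sim}  ->  squared L2 distance={2 * (1 - cos_sim):.2f}")
# cos=0.99  ->  squared L2 distance=0.02
# cos=0.95  ->  squared L2 distance=0.10
# cos=0.9   ->  squared L2 distance=0.20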

[16]: def get_table_schema(db_name, table_name):
    """
    Retrieves the schema (columns and data types) for a given table in the SQLite database.

    Args:
        db_name (str): The name of the SQLite database file.
        table_name (str): The name of the table.

    Returns:
        list: A list of tuples with column name, data type, and other info.
    """
    conn = sqlite3.connect(db_name)
    cursor = conn.cursor()

    # Use PRAGMA to get the table schema
    cursor.execute(f"PRAGMA table_info({table_name});")
    schema = cursor.fetchall()

    conn.close()
    return schema

table_name = 'movies'
schema = get_table_schema(db_name, table_name)
print(f"Schema for {table_name}:")
for col in schema:
    print(col)

Schema for movies:

(0, 'Movie', 'TEXT', 0, None, 0)
(1, 'LeadStudio', 'TEXT', 0, None, 0)
(2, 'RottenTomatoes', 'REAL', 0, None, 0)
(3, 'AudienceScore', 'REAL', 0, None, 0)
(4, 'Story', 'TEXT', 0, None, 0)
(5, 'Genre', 'TEXT', 0, None, 0)
(6, 'TheatersOpenWeek', 'REAL', 0, None, 0)
(7, 'OpeningWeekend', 'REAL', 0, None, 0)
(8, 'BOAvgOpenWeekend', 'REAL', 0, None, 0)
(9, 'DomesticGross', 'REAL', 0, None, 0)
(10, 'ForeignGross', 'REAL', 0, None, 0)
(11, 'WorldGross', 'REAL', 0, None, 0)
(12, 'Budget', 'REAL', 0, None, 0)
(13, 'Profitability', 'REAL', 0, None, 0)
(14, 'OpenProfit', 'REAL', 0, None, 0)
(15, 'Year', 'INTEGER', 0, None, 0)

[25]: def generate_llm_prompt(table_name, table_schema):
    """
    Generates a prompt to provide context about a table's schema for LLM to convert natural language to SQL.

    Args:
        table_name (str): The name of the table.
        table_schema (list): A list of tuples where each tuple contains information about the columns in the table.

    Returns:
        str: The generated prompt to be used by the LLM.
    """
    prompt = f"""You are an expert in writing SQL queries for relational databases.

You will be provided with a database schema and a natural
language question, and your task is to generate an accurate SQL query.

The database has a table named '{table_name}' with the following schema:\n\n"""

    prompt += "Columns:\n"
    for col in table_schema:
        column_name = col[1]
        column_type = col[2]
        prompt += f"- {column_name} ({column_type})\n"

    prompt += "\nPlease generate a SQL query based on the following natural language question. ONLY return the SQL query."

    return prompt

table_name = "movies"
schema = get_table_schema(db_name, table_name)
# Generate the prompt
llm_prompt = generate_llm_prompt(table_name, schema)
print(llm_prompt)

You are an expert in writing SQL queries for relational databases.

You will be provided with a database schema and a natural
language question, and your task is to generate an accurate SQL query.

The database has a table named 'movies' with the following schema:

Columns:
- Movie (TEXT)
- LeadStudio (TEXT)
- RottenTomatoes (REAL)
- AudienceScore (REAL)
- Story (TEXT)
- Genre (TEXT)
- TheatersOpenWeek (REAL)
- OpeningWeekend (REAL)
- BOAvgOpenWeekend (REAL)
- DomesticGross (REAL)
- ForeignGross (REAL)
- WorldGross (REAL)
- Budget (REAL)
- Profitability (REAL)
- OpenProfit (REAL)
- Year (INTEGER)

Please generate a SQL query based on the following natural language question.
ONLY return the SQL query.

[26]: def handle_user_question(user_question):
    """
    Handles the user's question by first searching the cache, and if there's no hit, generating a SQL query and response.

    Args:
        user_question (str): The user's natural language question.

    Returns:
        list: The response to the user's question.
    """
    # Convert the user's question to an embedding
    question_embedding = get_embeddings(user_question)

    # Step 1: Search cache for similar questions
    cache_hit = search_cache(question_embedding)
    if cache_hit:
        sql_query, response = cache_hit
        print(f"Cache hit! SQL Query: {sql_query}")
        return response

    # Step 2: No hit, go to LLM for SQL generation
    print("Cache miss! Generating SQL from LLM...")
    sql_query = generate_sql_query(user_question)

    # Step 3: Run the SQL query on the database
    response = run_sql_query(db_name, sql_query)

    # Step 4: Store question, SQL, and response in cache
    cache.append((user_question, sql_query, response))
    index.add(np.array([question_embedding]))  # Add question embedding to FAISS index

    return response

[27]: def generate_sql_query(question):
    table_name = 'movies'
    db_name = 'movies_db.db'
    table_schema = get_table_schema(db_name, table_name)
    llm_prompt = generate_llm_prompt(table_name, table_schema)
    user_prompt = """Question: {question}"""
    response = completion(
        api_key=OPENAI_API_KEY,
        model="gpt-4o-mini",
        messages=[
            {"content": llm_prompt.format(table_name=table_name), "role": "system"},
            {"content": user_prompt.format(question=question), "role": "user"}],
        max_tokens=1000
    )
    answer = response.choices[0].message.content
    display(Markdown(answer))
    query = answer.replace("```sql", "").replace("```", "")
    query = query.strip()
    return query
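
The string replace above assumes the model wraps its answer in a ```sql fence. A slightly more defensive extraction (just a sketch, not what this notebook uses) could take the first fenced block when present and fall back to the raw answer otherwise:

# Alternative extraction (sketch, not used above): grab the first fenced code
# block if the model returned one; otherwise fall back to the raw answer.
import re

def extract_sql(answer: str) -> str:
    match = re.search(r"```(?:sql)?\s*(.*?)```", answer, flags=re.DOTALL | re.IGNORECASE)
    return (match.group(1) if match else answer).strip()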

[37]: # question = "total number of movies are made by Warner Bros company in year 2008?"
# question = "how many movies have RottenTomatoes scores lower than 85?"
question = "how many movies with action genre are in the database"
handle_user_question(question)

Cache miss! Generating SQL from LLM...

SELECT COUNT(*) AS ActionMovieCount
FROM movies
WHERE Genre = 'Action';

[37]: [(166,)]

[38]: cache

[38]: [('total number of movies are made by Warner Bros company in year 2008?',
  "SELECT COUNT(*) \nFROM movies \nWHERE LeadStudio = 'Warner Bros' AND Year = 2008;",
  [(21,)]),
 ('how many movies have RottenTomatoes scores greater than 85?',
  'SELECT COUNT(*) \nFROM movies \nWHERE RottenTomatoes > 85;',
  [(120,)]),
 ('how many movies have RottenTomatoes scores lower than 85?',
  'SELECT COUNT(*) \nFROM movies \nWHERE RottenTomatoes < 85;',
  [(782,)]),
 ('how many movies with action genre are in the database',
  "SELECT COUNT(*) AS ActionMovieCount\nFROM movies\nWHERE Genre = 'Action';",
  [(166,)])]

