import os
from dotenv import load_dotenv
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
1 Introduction
SQL databases are frequently used to hold enterprise data. LLMs such as OpenAI’s GPT models (including ChatGPT) make it feasible to interact with SQL databases in natural language. LangChain provides SQL Chains and Agents for building and running SQL queries based on natural language prompts. These SQL Chains and Agents are compatible with any SQL dialect supported by SQLAlchemy (e.g., MySQL, PostgreSQL, Oracle SQL, Databricks, SQLite).
They enable use cases like:
- Creating queries that will be executed in response to natural language questions
- Developing chatbots that can answer queries based on database data
- Developing custom dashboards based on information that a user wants to analyse
In this article we will see different ways we can use LangChain and LLMs to ask questions about data in an SQL database.
2 Overview
LangChain provides tools to interact with SQL Databases:
- Build SQL queries based on natural language user questions
- Query a SQL database using chains for query creation and execution
- Interact with a SQL database using agents for robust and flexible querying
3 Import Libs & Setup
First, get required packages and set environment variables:
For this project we are also going to use LangSmith for logging and visualising runs; I wrote an article introducing LangSmith previously. Let’s set up and configure LangSmith.
from uuid import uuid4
unique_id = uuid4().hex[0:8]
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = f"Langchain SQL Demo - {unique_id}"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")  # Update to your API key
# Used by the agent in this post
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
We will also use an SQLite connection with the Chinook database. The Chinook database is a sample database that represents a digital media store, including tables for artists, albums, media tracks, invoices and customers.
There are 11 tables in the Chinook sample database:
- Employee table stores employee data such as employee id, last name, first name, etc. It also has a field named ReportsTo to specify who reports to whom.
- Customer table stores customer data.
- Invoice & InvoiceLine tables store invoice data. The Invoice table stores invoice header data and the InvoiceLine table stores the invoice line items data.
- Artist table stores artist data. It is a simple table that contains only the artist id and name.
- Album table stores data about a list of tracks. Each album belongs to one artist. However, one artist may have multiple albums.
- MediaType table stores media types such as MPEG audio and AAC audio files.
- Genre table stores music types such as rock, jazz, metal, etc.
- Track table stores the data of songs. Each track belongs to one album.
- Playlist & PlaylistTrack tables: the Playlist table stores data about playlists. Each playlist contains a list of tracks. Each track may belong to multiple playlists. The relationship between the Playlist table and the Track table is many-to-many. The PlaylistTrack table is used to reflect this relationship.
Follow these installation steps to create Chinook.db in the same directory as this notebook:
- Save this file to the directory as Chinook_Sqlite.sql
- Run sqlite3 Chinook.db
- Run .read Chinook_Sqlite.sql
- Test SELECT * FROM Artist LIMIT 10;
Now, Chinook.db is in our directory. (The code below loads it from docs/Chinook.db, so adjust the path to match where you created the file.)
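If you would rather not use the sqlite3 command-line shell, the same steps can be scripted with Python's built-in sqlite3 module. This is a minimal sketch, assuming Chinook_Sqlite.sql is a UTF-8 SQL script that executescript can run in one pass; build_sqlite_db is a hypothetical helper, not part of LangChain.

```python
import sqlite3
from pathlib import Path


def build_sqlite_db(sql_script: str, db_path: str) -> list:
    """Create db_path by running every statement in sql_script, then run a smoke-test query."""
    # utf-8-sig also accepts files that start with a byte-order mark
    script = Path(sql_script).read_text(encoding="utf-8-sig")
    conn = sqlite3.connect(db_path)
    try:
        conn.executescript(script)  # roughly what `.read Chinook_Sqlite.sql` does in the shell
        conn.commit()
        # Same check as the "Test" step above
        return conn.execute("SELECT * FROM Artist LIMIT 10;").fetchall()
    finally:
        conn.close()
```

For example, build_sqlite_db("Chinook_Sqlite.sql", "docs/Chinook.db") would create the file at the path used in the code below, assuming the docs folder already exists.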
Let’s create a SQLDatabaseChain to create and execute SQL queries.
from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain
db = SQLDatabase.from_uri("sqlite:///docs/Chinook.db")
llm = OpenAI(temperature=0, verbose=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("How many customers are there?")
> Entering new SQLDatabaseChain chain...
How many customers are there?
SQLQuery:SELECT COUNT(*) FROM Customer;
SQLResult: [(59,)]
Answer:There are 59 customers.
> Finished chain.
'There are 59 customers.'
Note that this both creates and executes the query. In the following sections, we will cover the 3 different use cases mentioned in the overview.
4 Case 1: Text-to-SQL query
from langchain.chat_models import ChatOpenAI
from langchain.chains import create_sql_query_chain
Let’s create the chain that will build the SQL Query:
chain = create_sql_query_chain(ChatOpenAI(temperature=0), db)
response = chain.invoke({"question": "How many customers are there"})
print(response)
SELECT COUNT(*) FROM Customer
After building the SQL query based on a user question, we can execute the query:
db.run(response)
'[(59,)]'
As we can see, the SQL Query Builder chain only created the query, and we handled the query execution separately.
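Keeping generation and execution separate means you can inspect, log or reject the SQL before it touches the database. Below is a minimal standard-library sketch of the execution half, assuming a plain sqlite3 connection; run_generated_sql is a hypothetical helper whose return value only imitates the string shape that db.run produced above ('[(59,)]'). It is not LangChain's implementation.

```python
import logging
import sqlite3

logger = logging.getLogger(__name__)


def run_generated_sql(conn: sqlite3.Connection, query: str) -> str:
    """Log the generated SQL, execute it and return the rows as a string."""
    logger.info("Executing generated SQL: %s", query)  # audit trail before execution
    rows = conn.execute(query).fetchall()
    return str(rows)
```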
4.1 Go deeper
Looking under the hood
We can look at the LangSmith trace to unpack this:
This is the full text of the prompt created from that query:
You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".
Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
CREATE TABLE "Album" (
"AlbumId" INTEGER NOT NULL,
"Title" NVARCHAR(160) NOT NULL,
"ArtistId" INTEGER NOT NULL,
PRIMARY KEY ("AlbumId"),
FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)
/*
3 rows from Album table:
AlbumId Title ArtistId
1 For Those About To Rock We Salute You 1
2 Balls to the Wall 2
3 Restless and Wild 2
*/
CREATE TABLE "Artist" (
"ArtistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("ArtistId")
)
/*
3 rows from Artist table:
ArtistId Name
1 AC/DC
2 Accept
3 Aerosmith
*/
CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)
/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/
CREATE TABLE "Employee" (
"EmployeeId" INTEGER NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"FirstName" NVARCHAR(20) NOT NULL,
"Title" NVARCHAR(30),
"ReportsTo" INTEGER,
"BirthDate" DATETIME,
"HireDate" DATETIME,
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60),
PRIMARY KEY ("EmployeeId"),
FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)
/*
3 rows from Employee table:
EmployeeId LastName FirstName Title ReportsTo BirthDate HireDate Address City State Country PostalCode Phone Fax Email
1 Adams Andrew General Manager None 1962-02-18 00:00:00 2002-08-14 00:00:00 11120 Jasper Ave NW Edmonton AB Canada T5K 2N1 +1 (780) 428-9482 +1 (780) 428-3457 andrew@chinookcorp.com
2 Edwards Nancy Sales Manager 1 1958-12-08 00:00:00 2002-05-01 00:00:00 825 8 Ave SW Calgary AB Canada T2P 2T3 +1 (403) 262-3443 +1 (403) 262-3322 nancy@chinookcorp.com
3 Peacock Jane Sales Support Agent 2 1973-08-29 00:00:00 2002-04-01 00:00:00 1111 6 Ave SW Calgary AB Canada T2P 5M5 +1 (403) 262-3443 +1 (403) 262-6712 jane@chinookcorp.com
*/
CREATE TABLE "Genre" (
"GenreId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("GenreId")
)
/*
3 rows from Genre table:
GenreId Name
1 Rock
2 Jazz
3 Metal
*/
CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)
/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2009-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2009-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
CREATE TABLE "InvoiceLine" (
"InvoiceLineId" INTEGER NOT NULL,
"InvoiceId" INTEGER NOT NULL,
"TrackId" INTEGER NOT NULL,
"UnitPrice" NUMERIC(10, 2) NOT NULL,
"Quantity" INTEGER NOT NULL,
PRIMARY KEY ("InvoiceLineId"),
FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
)
/*
3 rows from InvoiceLine table:
InvoiceLineId InvoiceId TrackId UnitPrice Quantity
1 1 2 0.99 1
2 1 4 0.99 1
3 2 6 0.99 1
*/
CREATE TABLE "MediaType" (
"MediaTypeId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("MediaTypeId")
)
/*
3 rows from MediaType table:
MediaTypeId Name
1 MPEG audio file
2 Protected AAC audio file
3 Protected MPEG-4 video file
*/
CREATE TABLE "Playlist" (
"PlaylistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("PlaylistId")
)
/*
3 rows from Playlist table:
PlaylistId Name
1 Music
2 Movies
3 TV Shows
*/
CREATE TABLE "PlaylistTrack" (
"PlaylistId" INTEGER NOT NULL,
"TrackId" INTEGER NOT NULL,
PRIMARY KEY ("PlaylistId", "TrackId"),
FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)
/*
3 rows from PlaylistTrack table:
PlaylistId TrackId
1 3402
1 3389
1 3390
*/
CREATE TABLE "Track" (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"AlbumId" INTEGER,
"MediaTypeId" INTEGER NOT NULL,
"GenreId" INTEGER,
"Composer" NVARCHAR(220),
"Milliseconds" INTEGER NOT NULL,
"Bytes" INTEGER,
"UnitPrice" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("TrackId"),
FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"),
FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"),
FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)
/*
3 rows from Track table:
TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice
1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99
2 Balls to the Wall 2 2 1 None 342562 5510424 0.99
3 Fast As a Shark 3 2 1 F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman 230619 3990994 0.99
*/
Question: How many customers are there
SQLQuery:
Some papers have reported good performance when prompting with:
- A CREATE TABLE description for each table, which includes column names, their types, etc.
- Followed by three example rows in a SELECT statement
create_sql_query_chain adopts this best practice (see more in this blog).
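To make this format concrete, the sketch below assembles the same kind of description directly from SQLite using only the standard library: each table's CREATE TABLE statement (read from sqlite_master) followed by a few sample rows inside a comment block, as in the prompt above. It approximates the layout seen here and is not LangChain's code; describe_tables is a hypothetical name.

```python
import sqlite3


def describe_tables(conn: sqlite3.Connection, sample_rows: int = 3) -> str:
    """Return CREATE TABLE statements plus a few sample rows for every user table."""
    tables = conn.execute(
        "SELECT name, sql FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    parts = []
    for name, create_sql in tables:
        # Quote the table name as an identifier; the row count is a bound parameter
        cursor = conn.execute(f'SELECT * FROM "{name}" LIMIT ?', (sample_rows,))
        columns = [col[0] for col in cursor.description]
        rows = cursor.fetchall()
        lines = [create_sql, "/*", f"{len(rows)} rows from {name} table:", "\t".join(columns)]
        lines += ["\t".join(str(value) for value in row) for row in rows]
        lines.append("*/")
        parts.append("\n".join(lines))
    return "\n\n".join(parts)
```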
Improvements
The query builder can be improved in a variety of ways, including (but not limited to):
- Tailoring the database description to your particular use case
- Using a vector database to provide dynamic examples that are relevant to the individual user question
- Hardcoding a few instances of questions and their related SQL query in the prompt
All of these examples involve changing the prompt for the chain. For example, we could include the following instances in our prompt:
from langchain.prompts import PromptTemplate
TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}.
Some examples of SQL queries that correspond to questions are:
{few_shot_examples}
Question: {input}"""
CUSTOM_PROMPT = PromptTemplate(
    input_variables=["input", "few_shot_examples", "table_info", "dialect"], template=TEMPLATE
)
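The few_shot_examples slot is just text, so one simple option is to render a handful of question/SQL pairs into a string before filling in the prompt. This is a minimal sketch that uses plain str.format on a shortened copy of the template, so it runs without LangChain; format_few_shots and the example pairs are illustrative assumptions. With the PromptTemplate, you would pass the same keyword values to CUSTOM_PROMPT.format(...).

```python
def format_few_shots(examples: dict) -> str:
    """Render question -> SQL pairs in the Question/SQLQuery layout the prompt uses."""
    blocks = [f"Question: {q}\nSQLQuery: {sql}" for q, sql in examples.items()]
    return "\n\n".join(blocks)


# Shortened stand-in for TEMPLATE (same placeholder names)
TEMPLATE_SKETCH = """Given an input question, first create a syntactically correct {dialect} query to run.
Only use the following tables:
{table_info}
Some examples of SQL queries that correspond to questions are:
{few_shot_examples}
Question: {input}"""

# Illustrative examples written against the Chinook schema
examples = {
    "How many customers are there?": "SELECT COUNT(*) FROM Customer;",
    "List all genres.": 'SELECT "Name" FROM Genre;',
}

prompt_text = TEMPLATE_SKETCH.format(
    dialect="SQLite",
    table_info='CREATE TABLE "Customer" (...)',
    few_shot_examples=format_few_shots(examples),
    input="How many employees are there?",
)
```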
5 Case 2: Text-to-SQL query and execution
We can use SQLDatabaseChain from langchain_experimental to create and run SQL queries.
from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain
llm = OpenAI(temperature=0, verbose=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("How many customers are there?")
> Entering new SQLDatabaseChain chain...
How many customers are there?
SQLQuery:SELECT COUNT(*) FROM Customer;
SQLResult: [(59,)]
Answer:There are 59 customers.
> Finished chain.
'There are 59 customers.'
As we can see, we get the same result as the previous case.
Here, the chain also handles the query execution and provides a final answer based on the user question and the query result.
Be careful while using this approach, as it is susceptible to SQL injection:
- The chain executes queries that are created by an LLM and have not been validated
- e.g. records may be created, modified or deleted unintentionally
This is why SQLDatabaseChain lives inside langchain_experimental.
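One practical mitigation, independent of LangChain, is to run generated SQL only over a read-only connection. SQLite supports this through a URI with mode=ro. The sketch below opens the database that way, so INSERT, UPDATE and DELETE statements fail instead of changing data. It also adds a cheap string pre-check; open_read_only and is_single_select are hypothetical helpers. SQLAlchemy's pysqlite dialect documents a similar URI form (?mode=ro&uri=true) that may work with SQLDatabase.from_uri, but treat that as an assumption to verify against your versions.

```python
import sqlite3
from pathlib import Path


def open_read_only(db_path: str) -> sqlite3.Connection:
    """Open an existing SQLite file read-only, so write statements raise an error."""
    # Path.as_uri() builds a file:///absolute/path URI; mode=ro is a standard SQLite URI option
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


def is_single_select(query: str) -> bool:
    """Rough pre-check: a single statement that starts with SELECT or WITH."""
    body = query.strip().rstrip(";").strip()
    words = body.split()
    # A semicolon left in the body suggests a second statement
    # (this also rejects ';' inside string literals, which is acceptable for a rough filter)
    return bool(words) and ";" not in body and words[0].upper() in {"SELECT", "WITH"}
```

The read-only connection is the real safeguard. SQLite also accepts WITH ... DELETE, so the string check is only an early filter.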
5.1 Go deeper
Looking under the hood
We can use the LangSmith trace to see what is happening under the hood:
- As discussed above, first we create the query:
text: ' SELECT COUNT(*) FROM "Customer";'
- Then, it executes the query and passes the results to an LLM for synthesis.
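The two steps in that trace can be sketched as a small pipeline. In this sketch, generate_sql and synthesize are hypothetical callables that stand in for the two LLM calls; they are not LangChain APIs. Only the middle step touches the database.

```python
import sqlite3
from typing import Callable


def answer_question(
    conn: sqlite3.Connection,
    question: str,
    generate_sql: Callable[[str], str],
    synthesize: Callable[[str, str, list], str],
) -> dict:
    """Generate SQL for a question, run it, then turn the rows into an answer."""
    sql = generate_sql(question)              # step 1: text-to-SQL (an LLM call in practice)
    rows = conn.execute(sql).fetchall()       # step 2: execute against the database
    answer = synthesize(question, sql, rows)  # step 3: an LLM writes the final answer
    # Returning the SQL and rows as well is similar in spirit to return_intermediate_steps=True
    return {"sql": sql, "rows": rows, "answer": answer}
```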
Improvements
The performance of the SQLDatabaseChain can be enhanced in several ways:
- Adding sample rows
- Specifying custom table information
- Using Query Checker: self-correct invalid SQL using the parameter use_query_checker=True
- Customizing the LLM Prompt: include specific instructions or relevant information using the parameter prompt=CUSTOM_PROMPT
- Get intermediate steps: access the SQL statement as well as the final result using the parameter return_intermediate_steps=True
- Limit the number of rows a query will return using the parameter top_k=5
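use_query_checker=True asks the LLM to review its own SQL. A cheaper, complementary check (a swapped-in technique, not what LangChain does) is to let SQLite compile the statement with EXPLAIN without running it. Syntax errors and unknown tables or columns then surface as error messages that could be fed back for correction. check_sql is a hypothetical helper.

```python
import sqlite3
from typing import Optional


def check_sql(conn: sqlite3.Connection, query: str) -> Optional[str]:
    """Return None if SQLite can compile the query, else the error message."""
    try:
        # EXPLAIN compiles the statement to bytecode and lists it, without executing the query
        conn.execute("EXPLAIN " + query)
        return None
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # Some Python versions report multi-statement input as sqlite3.Warning rather than Error
        return str(exc)
```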
You might find SQLDatabaseSequentialChain useful for cases in which the number of tables in the database is large.
This Sequential Chain handles the process of:
- Determining which tables to use based on the user question
- Calling the normal SQL database chain using only relevant tables
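The first step, narrowing the schema down to the relevant tables, is what keeps prompts small on large databases. In SQLDatabaseSequentialChain an LLM makes that choice. The sketch below swaps in a naive keyword match purely to illustrate the two-stage idea; pick_tables is a hypothetical name and would miss questions that do not mention table names.

```python
import re


def pick_tables(question: str, table_names: list, max_tables: int = 3) -> list:
    """Rank tables by overlap between question words and table-name parts (a crude stand-in for the LLM step)."""
    # Lower-case words with trailing "s" removed, so "customers" matches "Customer"
    words = {w.lower().rstrip("s") for w in re.findall(r"[A-Za-z]+", question)}
    scored = []
    for name in table_names:
        # Split CamelCase names such as "InvoiceLine" into {"invoice", "line"}
        parts = {p.lower() for p in re.findall(r"[A-Z][a-z]*|[a-z]+", name)}
        score = len(parts & words)
        if score:
            scored.append((score, name))
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [name for _, name in scored[:max_tables]]
```

The selected names could then be passed as include_tables when building the SQLDatabase for the second step.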
Adding Sample Rows
Providing sample data can help the LLM construct correct queries when the data format is not obvious.
For example, we can tell the LLM that artists are saved with their full names by providing two rows from the Track table.
db = SQLDatabase.from_uri(
    "sqlite:///docs/Chinook.db",
    include_tables=['Track'],  # we include only one table to save tokens in the prompt :)
    sample_rows_in_table_info=2)
The sample rows are added to the prompt after each corresponding table’s column information.
We can use db.table_info and check which sample rows are included:
print(db.table_info)
CREATE TABLE "Track" (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"AlbumId" INTEGER,
"MediaTypeId" INTEGER NOT NULL,
"GenreId" INTEGER,
"Composer" NVARCHAR(220),
"Milliseconds" INTEGER NOT NULL,
"Bytes" INTEGER,
"UnitPrice" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("TrackId"),
FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"),
FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"),
FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)
/*
2 rows from Track table:
TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice
1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99
2 Balls to the Wall 2 2 1 None 342562 5510424 0.99
*/
6 Case 3: SQL agents
LangChain has a SQL Agent that is more flexible than the ‘SQLDatabaseChain’ in communicating with SQL Databases.
The following are the primary benefits of utilising the SQL Agent:
- It can answer questions based on the schema as well as the content of the databases (for example, describing a specific table).
- It can recover from errors by running a generated query, catching the traceback and rebuilding the query correctly.
In this article the author described reasons why you might want to consider using an agent for SQL queries rather than just a chain:
‘…Let us first understand what is an agent and why it might be preferred over a simple SQLChain. An agent is a component that has access to a suite of tools, including a Large Language Model (LLM). Its distinguishing characteristic lies in its ability to make informed decisions based on user input, utilizing the appropriate tools until it achieves a satisfactory answer. For example in the context of text-to-SQL, the LangChain SQLAgent will not give up if there is an error in executing the generated SQL. Instead, it will attempt to recover by interpreting the error in a subsequent LLM call and rectify the issue. Therefore, in theory, SQLAgent should outperform SQLChain in productivity and accuracy’
And this is what that author found from their experiments:
‘…During our tests, we ran multiple questions on both SQLChain and SQLAgent using GPT-3.5 and compared their respective results. Our findings revealed that SQLAgent outperformed SQLChain by answering a greater number of questions…For accuracy, however, our findings also indicate a higher incidence of incorrect responses from SQLAgent. Besides the general shortcomings of using LLM to query database, we hypothesize that SQLAgent will occasionally make its best attempt to answer a question even when concrete results cannot be obtained from the SQL query.’
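The recover-and-retry behaviour described above can be sketched as a small loop: run the query and, on failure, hand the error message back to a fixer. Here fix_query is a hypothetical callable that stands in for the LLM call the agent would make. The loop illustrates the idea and is not LangChain's agent code.

```python
import sqlite3
from typing import Callable


def run_with_retries(
    conn: sqlite3.Connection,
    query: str,
    fix_query: Callable[[str, str], str],
    max_attempts: int = 3,
) -> list:
    """Execute query; on a database error, ask fix_query for a corrected version and retry."""
    last_error = None
    for _ in range(max_attempts):
        try:
            return conn.execute(query).fetchall()
        except sqlite3.Error as exc:
            last_error = exc
            # In an agent, this error text becomes the next "Observation" the LLM reasons about
            query = fix_query(query, str(exc))
    raise RuntimeError(f"Query still failing after {max_attempts} attempts: {last_error}")
```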
The ‘create_sql_agent’ method is used to initialise the agent.
This agent includes the ‘SQLDatabaseToolkit,’ which includes tools for:
- Create and run queries
- Verify query syntax
- Get table descriptions
- … and much more
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
# from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType
db = SQLDatabase.from_uri("sqlite:///docs/Chinook.db")
llm = OpenAI(temperature=0, verbose=True)
agent_executor = create_sql_agent(
    llm=OpenAI(temperature=0),
    toolkit=SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0)),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
6.1 Agent task example #1 - Running queries
agent_executor.run("List the total sales per country. Which country's customers spent the most?")
> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: I should query the schema of the Invoice and Customer tables.
Action: sql_db_schema
Action Input: Invoice, Customer
Observation:
CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)
/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/
CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)
/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2009-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2009-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
Thought: I should query the total sales per country.
Action: sql_db_query
Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10
Observation: [('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]
Thought: I now know the final answer
Final Answer: The country with the highest total sales is the USA, with a total of $523.06.
> Finished chain.
'The country with the highest total sales is the USA, with a total of $523.06.'
Looking at the LangSmith trace, we can see:
- The agent is using a ReAct style prompt
- First, it looks at the tables using the tool sql_db_list_tables:
Action: sql_db_list_tables
- Given the tables as an observation, it thinks and then determines the next action:
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: I should query the schema of the Invoice and Customer tables.
Action: sql_db_schema
Action Input: Invoice, Customer
- It then formulates the query using the schema from the tool sql_db_schema:
Thought: I should query the total sales per country.
Action: sql_db_query
Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10
- It finally executes the generated query using the tool sql_db_query
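Without the LLM, the three tools in this trace map onto simple database operations. The functions below are hypothetical standard-library stand-ins, named after sql_db_list_tables, sql_db_schema and sql_db_query, that show what each step returns. They are not the toolkit's implementation (for instance, the real schema tool also appends sample rows).

```python
import sqlite3


def list_tables(conn: sqlite3.Connection) -> str:
    """Like sql_db_list_tables: comma-separated table names."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' "
        "AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    return ", ".join(name for (name,) in rows)


def table_schema(conn: sqlite3.Connection, tables: str) -> str:
    """Like sql_db_schema: CREATE TABLE statements for a comma-separated list of tables."""
    wanted = [t.strip() for t in tables.split(",") if t.strip()]
    statements = []
    for name in wanted:
        row = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (name,)
        ).fetchone()
        statements.append(row[0] if row else f"-- table {name} not found")
    return "\n\n".join(statements)


def run_query(conn: sqlite3.Connection, query: str) -> str:
    """Like sql_db_query: run the SQL and return the rows as a string."""
    return str(conn.execute(query).fetchall())
```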
6.2 Agent task example #2 - Describing a Table
agent_executor.run("Describe the playlisttrack table")
> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: I should query the schema of the PlaylistTrack table
Action: sql_db_schema
Action Input: PlaylistTrack
Observation:
CREATE TABLE "PlaylistTrack" (
"PlaylistId" INTEGER NOT NULL,
"TrackId" INTEGER NOT NULL,
PRIMARY KEY ("PlaylistId", "TrackId"),
FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)
/*
3 rows from PlaylistTrack table:
PlaylistId TrackId
1 3402
1 3389
1 3390
*/
Thought: I now know the final answer
Final Answer: The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.
> Finished chain.
'The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.'
7 Extending the SQL Toolkit with Domain Specific Knowledge Tools
In a recent LangChain blog article (5/9/23), they highlighted how you can use few-shot examples to bring in domain-specific knowledge as an alternative approach.
Although the LangChain SQL Toolkit includes all of the tools needed to begin working on a database, several additional tools may be beneficial for increasing the agent’s capabilities. This is especially important when attempting to integrate domain-specific information in the solution to increase overall performance.
Here are a few examples:
- Including dynamic few-shot examples
- Identifying misspellings of proper nouns for use as column filters
We can develop distinct tools to address these unique use cases and include them as an addition to the regular SQL Toolkit. Let’s look at how to incorporate these two bespoke tools.
7.1 Including dynamic few-shot examples
To integrate dynamic few-shot examples, we require a custom Retriever Tool that searches the vector database for examples that are semantically related to the user’s query.
Let’s begin by making a dictionary out of several examples:
few_shots = {'List all artists.': 'SELECT * FROM artists;',
             "Find all albums for the artist 'AC/DC'.": "SELECT * FROM albums WHERE ArtistId = (SELECT ArtistId FROM artists WHERE Name = 'AC/DC');",
"List all tracks in the 'Rock' genre.": "SELECT * FROM tracks WHERE GenreId = (SELECT GenreId FROM genres WHERE Name = 'Rock');",
'Find the total duration of all tracks.': 'SELECT SUM(Milliseconds) FROM tracks;',
'List all customers from Canada.': "SELECT * FROM customers WHERE Country = 'Canada';",
'How many tracks are there in the album with ID 5?': 'SELECT COUNT(*) FROM tracks WHERE AlbumId = 5;',
'Find the total number of invoices.': 'SELECT COUNT(*) FROM invoices;',
'List all tracks that are longer than 5 minutes.': 'SELECT * FROM tracks WHERE Milliseconds > 300000;',
'Who are the top 5 customers by total purchase?': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;',
'Which albums are from the year 2000?': "SELECT * FROM albums WHERE strftime('%Y', ReleaseDate) = '2000';",
'How many employees are there': 'SELECT COUNT(*) FROM "employee"'
}
We can then create a retriever using the list of questions, assigning the target SQL query as metadata:
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
embeddings = OpenAIEmbeddings()
few_shot_docs = [Document(page_content=question, metadata={'sql_query': few_shots[question]}) for question in few_shots.keys()]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()
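If you want to experiment with the retrieval step before paying for embeddings, a bag-of-words cosine similarity gives a rough, dependency-free approximation of what the FAISS retriever does. This swaps semantic embeddings for word overlap, so it only finds questions that share vocabulary with the user's question. similar_examples is a hypothetical helper. It returns the stored SQL alongside each matched question, mirroring the sql_query metadata above.

```python
import math
import re
from collections import Counter


def _vector(text: str) -> Counter:
    """Lower-cased word counts for a piece of text."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def _cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two word-count vectors."""
    dot = sum(a[word] * b[word] for word in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def similar_examples(question: str, examples: dict, k: int = 4) -> list:
    """Return up to k (question, sql) pairs ranked by bag-of-words cosine similarity."""
    query_vec = _vector(question)
    ranked = sorted(
        examples.items(),
        key=lambda item: _cosine(query_vec, _vector(item[0])),
        reverse=True,
    )
    return ranked[:k]
```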
Now we can create our own custom tool and append it as a new tool in the create_sql_agent function:
from langchain.agents.agent_toolkits import create_retriever_tool
tool_description = """
This tool will help you understand similar examples to adapt them to the user question.
Input to this tool should be the user question.
"""
retriever_tool = create_retriever_tool(
    retriever,
    name='sql_get_similar_examples',
    description=tool_description
)
custom_tool_list = [retriever_tool]
We can now create the agent by modifying the normal SQL Agent suffix to reflect our use case. Although including it in the tool description is the simplest method to manage this, it is frequently insufficient, and we must express it in the agent prompt using the suffix argument in the constructor.
from langchain.agents import create_sql_agent, AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.utilities import SQLDatabase
from langchain.chat_models import ChatOpenAI
db = SQLDatabase.from_uri("sqlite:///docs/Chinook.db")
llm = ChatOpenAI(model_name='gpt-4', temperature=0)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
custom_suffix = """
I should first get the similar examples I know.
If the examples are enough to construct the query, I can build it.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables
"""
agent = create_sql_agent(llm=llm,
                         toolkit=toolkit,
                         verbose=True,
                         agent_type=AgentType.OPENAI_FUNCTIONS,
                         extra_tools=custom_tool_list,
                         suffix=custom_suffix
                         )
agent.run("How many employees do we have?")
> Entering new AgentExecutor chain...
Invoking: `sql_get_similar_examples` with `How many employees do we have?`
[Document(page_content='How many employees are there', metadata={'sql_query': 'SELECT COUNT(*) FROM "employee"'}), Document(page_content='Find the total number of invoices.', metadata={'sql_query': 'SELECT COUNT(*) FROM invoices;'}), Document(page_content='Who are the top 5 customers by total purchase?', metadata={'sql_query': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;'}), Document(page_content='List all customers from Canada.', metadata={'sql_query': "SELECT * FROM customers WHERE Country = 'Canada';"})]
Invoking: `sql_db_query_checker` with `SELECT COUNT(*) FROM employee`
responded: {content}
SELECT COUNT(*) FROM employee
Invoking: `sql_db_query` with `SELECT COUNT(*) FROM employee`
[(8,)]We have 8 employees.
> Finished chain.
'We have 8 employees.'
7.2 Identifying and fixing proper noun misspellings
To accurately filter data from columns that contain proper nouns such as addresses, song titles, or artists, we must first double-check the spelling.
We can accomplish this by creating a vector store containing all of the distinct proper nouns in the database. The agent can then query that vector store each time a proper noun is included in a question to find the right spelling for that word. Before constructing the target query, the agent can ensure that it understands which entity the user is referring to.
Let’s use a similar technique to the few-shot examples, but without the metadata: embed the proper nouns and then query to find the one that is most similar to the misspelt word in the user question.
First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:
import ast
import re
def run_query_save_results(db, query):
    res = db.run(query)
    # db.run returns a string such as "[('AC/DC',), ('Accept',)]": flatten it and drop empty values
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    # Remove standalone numbers and surrounding whitespace
    res = [re.sub(r'\b\d+\b', '', string).strip() for string in res]
    return res
artists = run_query_save_results(db, "SELECT Name FROM Artist")
albums = run_query_save_results(db, "SELECT Title FROM Album")
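To see what run_query_save_results does to the string returned by db.run, here is the same parsing logic applied to a made-up literal string, with no database involved. The helper name parse_db_run_output is illustrative. Note that the digit-stripping step removes every standalone number, so a title such as "Greatest Hits 2" becomes "Greatest Hits".

```python
import ast
import re


def parse_db_run_output(raw: str) -> list:
    """Flatten a stringified list of tuples (the shape db.run returns) into cleaned strings."""
    rows = ast.literal_eval(raw)                       # "[('AC/DC',), ...]" -> list of tuples
    values = [el for row in rows for el in row if el]  # flatten and drop None/empty values
    # Remove standalone numbers and trim the leftover whitespace
    return [re.sub(r"\b\d+\b", "", value).strip() for value in values]
```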
Now we can proceed with creating the custom retriever tool and the final agent:
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
texts = (artists + albums)
embeddings = OpenAIEmbeddings()
vector_db = FAISS.from_texts(texts, embeddings)
retriever = vector_db.as_retriever()
retriever_tool = create_retriever_tool(
    retriever,
    name='name_search',
    description='use to learn how a piece of data is actually written, can be from names, surnames addresses etc'
)
custom_tool_list = [retriever_tool]
from langchain.agents import create_sql_agent, AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.utilities import SQLDatabase
from langchain.chat_models import ChatOpenAI
# db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model_name='gpt-4', temperature=0)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
custom_suffix = """
If a user asks for me to filter based on proper nouns, I should first check the spelling using the name_search tool.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables
"""
agent = create_sql_agent(llm=llm,
                         toolkit=toolkit,
                         verbose=True,
                         agent_type=AgentType.OPENAI_FUNCTIONS,
                         extra_tools=custom_tool_list,
                         suffix=custom_suffix
                         )
agent.run("How many albums does alis in pains have?")
> Entering new AgentExecutor chain...
Invoking: `name_search` with `alis in pains`
[Document(page_content='House of Pain', metadata={}), Document(page_content='Alice In Chains', metadata={}), Document(page_content='Aisha Duo', metadata={}), Document(page_content='House Of Pain', metadata={})]
Invoking: `sql_db_list_tables` with ``
responded: {content}
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `Album, Artist`
responded: {content}
CREATE TABLE "Album" (
"AlbumId" INTEGER NOT NULL,
"Title" NVARCHAR(160) NOT NULL,
"ArtistId" INTEGER NOT NULL,
PRIMARY KEY ("AlbumId"),
FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)
/*
3 rows from Album table:
AlbumId Title ArtistId
1 For Those About To Rock We Salute You 1
2 Balls to the Wall 2
3 Restless and Wild 2
*/
CREATE TABLE "Artist" (
"ArtistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("ArtistId")
)
/*
3 rows from Artist table:
ArtistId Name
1 AC/DC
2 Accept
3 Aerosmith
*/
Invoking: `sql_db_query_checker` with `SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')`
responded: {content}
SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')
Invoking: `sql_db_query` with `SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')`
[(1,)]Alice In Chains has 1 album in the database.
> Finished chain.
'Alice In Chains has 1 album in the database.'
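For a quick, offline approximation of the name_search step, Python's difflib can rank candidate spellings by character-level similarity. This is a swapped-in technique: string similarity rather than the embedding search used above, so it handles typos well but not synonyms. closest_names is a hypothetical helper, and the short candidate list in the test is taken from the trace output.

```python
import difflib


def closest_names(query: str, candidates: list, n: int = 3, cutoff: float = 0.6) -> list:
    """Return up to n candidates that look most like query, ignoring case."""
    lowered = {}
    for name in candidates:
        # Keep the first original spelling for each lower-cased form
        lowered.setdefault(name.lower(), name)
    matches = difflib.get_close_matches(query.lower(), list(lowered), n=n, cutoff=cutoff)
    return [lowered[m] for m in matches]
```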
8 Further Reading
To learn more about the SQL Agent and how it works, please refer to the SQL Agent Toolkit and LangChain Use cases - SQL documentation.
You can also check Agents for other document types:
- Pandas Agent
- CSV Agent
9 Acknowledgements
I’d like to express my thanks to the wonderful Langsmith Documentation and acknowledge the use of some images and other materials from the documentation in this article.