Seq2SQL: 用強化學習與你的資料庫對話

Introduction

「關聯式資料庫」(Relational Database) 在現今已經是儲存 structured data 中不可或缺的工具之一,但是使用者要搜尋資料庫的話,必須得先學會怎麼使用複雜的 SQL (Structured Query Language) 語法。有鑑於此,有沒有更加直覺的方式來降低使用上的門檻,讓一般使用者也能夠輕鬆駕馭這些系統呢?

本篇要介紹來自 Salesforce 的最新研究 "Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning" 就提出了兩點貢獻:
1. Seq2SQL:利用 RL (Reinforcement Learning) 將自然語言轉換成 SQL 指令,讓使用者透過這種自然語言介面 (Natural Language Interface) 就能輕鬆與資料庫做互動。
2. WikiSQL:開源了從 Wikipedia 擷取的 87,673 筆資料,包含 natural language questions, SQL queries, 與 SQL tables。

Model

一般來說不管像是機器翻譯 (input source language; output target language)、文章摘要 (input document; output summary),或是本篇所介紹的語意解析 (input natural language; output SQL query) 等這一類 "sequence-to-sequence learning" 的問題會使用目前很紅的深度學習模型-"seq2seq model" 作為 baseline。也就是說,你有一個 neural network encoder 將 input sequence 用 token-by-token 的方式在每一個 time step 下一個一個將 token 讀進來之後作 encoding 得到一個向量 (也就是 encoder 在最後一個 time step 的 hidden state vector),以及另一個 neural network decoder 從 encoder 的 hidden state vector 一樣也是用 token-by-token 的方式來 decode 出 output sequence。

Augmented Pointer Network

不過 seq2seq 模型在一般的應用中,它們的 output vocabulary size 通常很龐大。例如,如果今天的問題 setting 是機器翻譯,則 model 的 output vocabulary size 可能就是 target langauage 語料庫中所有的 token 種類。但在本篇的問題 setting 中,output 的 sequence 通常只會包含 "SELECT", "COUNT" 等有限的 SQL 保留字以及來自 input sequence 中的關鍵字,因此我們可以利用所謂的 "Pointer Network" 中的 "Copy Mechanism",透過讓 decoder 在某個 time step 下,學習該在何時直接 copy 出現在 input sequence 中的關鍵字或是直接從 output vocabulary(SQL 保留字)來當作 output token,藉此達到大量減少 output vocabulary size 的效果。而在這邊的 encoder 與 decoder 分別採用的是 2-layer bi-LSTM 以及 2-layer LSTM。

註:這種技巧其實也有用在前陣子 Salesforce 利用強化學習做文章摘要 的研究中喔!

Seq2SQL

雖然透過 pointer network 可以解決 vocabulary size 的問題,但是其實可以利用 SQL 語法規則來幫助我們設計更好的 model,而不是單純直接用 seq2seq 來直接產生 SQL query,因為這樣有可能也會產生語法上 invalid 的 SQL query。

如上圖所示,Seq2SQL 包含了三個 components,分別是:
1. Aggregation classifier:對 input question 作語意頗析,例如,圖中的例子為 "How many...",此語意對應到 SQL 語法則為 "aggregation functions" 中的 "COUNT" operator。當然也需要額外定義 "NULL" operator 來 handle 不是問關於統計資訊的 question。
2. SELECT column pointer:一樣根據 question 來決定要從哪個 column name 撈資料出來。例如,圖中的例子為 "How many engine types...",此語意表示我們需要從 column = 'Engine' 的資料 select 資料出來。
3. WHERE clause pointer decoder:對應到 input question 中的條件語意產生 SQL query,而這邊會採用 RL 中的 policy gradient 去訓練。

註:Paper 中有詳述每個 component 的計算方式,在此不多贅述。

Generating Equivalent Queries using Policy Gradient

或許對於熟知 deep learning 或是 seq2seq model 的朋友們會問道:「既然現在有 seq2seq 了,為什麼還需要 reinforcement learning 呢?」,這是一個好問題!由於現今大部分的 seq2seq model 基本上在訓練的時候,會有一個 ground truth sequence 來監督 model 在每個 time step 產生的 token 需要與 ground truth 一致,若是不一致則會有 cross-entropy loss 的 penality(也就是所謂的 "Teacher Forcing")。但在 SQL 語法中,有可能有好幾種不同的 SQL queries 其實是 equivalent 的!

拿以下這兩個功能一樣的 SQL queries 來說:
1. SELECT * FROM mytable WHERE column1 = 'A' AND column2 = 'B'
2. SELECT * FROM mytable WHERE column2 = 'B' AND column1 = 'A'

註:從資料表 mytable 中撈出欄位 column1 等於 'A' 值且欄位 column2 等於 'B' 值的資料。

雖然這兩個 SQL queries 長得不一樣,但由於 WHERE 背後接的條件順序不同並不會導致這兩個 SQL queries 功能不一樣,因此對我們來說,model 不管 output 哪兩個 SQL query 對我們來說都沒差,但如果今天在訓練 model 的 ground truth 是第一個 query 而非第二個 query,那麼當 model 產生第二個 query 時,雖然是正確答案,但也會遭受到 cross-entropy loss 的 penality。

要解決這項問題,他們使用 RL 中的 policy gradient 的訓練方式來獎勵 model (WHERE clause pointer decoder) 生成不一樣但卻 equivalent 的 SQL query(若執行結果與 ground truth query 一致則視為 equivalent);反之,若生成的是 invalid 的 query 或是與 ground truth query 執行結果不一致的話,則被 penalized,如下圖所示:


$y$ 為接在 WHERE 背後的 generated query;$q(y)$ 為整個完整的 generated query;$q_g$ 為 ground truth query。

Experiment

以下為 model 在 WikiSQL 上的 performance 的比較。其中 $Acc_{if}$ 的意思是 "Logical form accuracy",指 model 生成的 SQL query 與 ground truth query 一致;$Acc_{ex}$ 的意思是 "Execution accuracy",指 model 生成的 SQL query 在資料庫執行出來的結果與 ground truth query 在資料庫執行出來的結果一致。"Logical form accuracy" 可衡量為 real performace 的 lower bound,而 "Execution accuracy" 則衡量 real performance 的 upper bound。

DEMO

Prediction 過程

Prediction 結果

Links

有興趣的讀者可在下方連結查看更多細節:
- Blog:在 Footnotes 的部分 reference 了許多背景知識,例如:natural language interfaces, semantic parsing, attentional seq2seq, pointer network 與 policy gradient methods。
- Paper:包含更詳細的實作細節、實驗結果以及實驗分析討論等。
- Dataset (WikiSQL):87,673 筆包含 natural language questions, SQL queries, 與 SQL tables。

Share the joy
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  

近期文章

近期迴響

彙整

分類

其它

Howard Lo Written by:

Be First to Comment

發表迴響

你的電子郵件位址並不會被公開。 必要欄位標記為 *