394 lines
14 KiB
Rust
394 lines
14 KiB
Rust
use sqlparser::ast::{
|
|
BinaryOperator, Expr, Function, FunctionArg, Join, JoinConstraint, JoinOperator,
|
|
Query, Select, SelectItem, SetExpr, Statement, TableWithJoins, Value
|
|
};
|
|
use sqlparser::dialect::PostgreSqlDialect;
|
|
use sqlparser::parser::Parser;
|
|
|
|
/// Line-oriented pretty-printer state for a single SQL statement.
///
/// Output is accumulated one entry per line in `lines`; `indent` is the number
/// of spaces prepended whenever a new line is started.
#[derive(Debug, Default)]
pub struct SqlFormatter {
    /// Formatted output, one element per line.
    pub lines: Vec<String>,
    /// Current indentation, in spaces, applied to each newly started line.
    pub indent: usize,
}
|
|
|
|
impl SqlFormatter {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
lines: Vec::new(),
|
|
indent: 0,
|
|
}
|
|
}
|
|
|
|
pub fn format(sql: &str) -> Vec<String> {
|
|
let dialect = PostgreSqlDialect {};
|
|
let ast = match Parser::parse_sql(&dialect, sql) {
|
|
Ok(ast) => ast,
|
|
Err(e) => {
|
|
println!("DEBUG PARSE SQL ERROR: {:?}", e);
|
|
return vec![sql.to_string()];
|
|
}
|
|
};
|
|
|
|
if ast.is_empty() {
|
|
return vec![sql.to_string()];
|
|
}
|
|
|
|
let mut formatter = SqlFormatter::new();
|
|
formatter.format_statement(&ast[0]);
|
|
formatter.lines
|
|
}
|
|
|
|
fn push_str(&mut self, s: &str) {
|
|
if self.lines.is_empty() {
|
|
self.lines.push(format!("{}{}", " ".repeat(self.indent), s.replace("JSONB", "jsonb")));
|
|
} else {
|
|
let last = self.lines.last_mut().unwrap();
|
|
last.push_str(&s.replace("JSONB", "jsonb"));
|
|
}
|
|
}
|
|
|
|
fn push_line(&mut self, s: &str) {
|
|
self.lines.push(format!("{}{}", " ".repeat(self.indent), s.replace("JSONB", "jsonb")));
|
|
}
|
|
|
|
fn format_statement(&mut self, stmt: &Statement) {
|
|
match stmt {
|
|
Statement::Query(query) => {
|
|
self.push_line("(");
|
|
self.format_query(query);
|
|
self.push_str(")");
|
|
}
|
|
Statement::Update(_update) => {
|
|
let sql = stmt.to_string();
|
|
self.format_update_fallback(&sql);
|
|
}
|
|
_ => {
|
|
let sql = stmt.to_string();
|
|
if sql.starts_with("INSERT") {
|
|
self.format_insert_fallback(&sql);
|
|
} else {
|
|
self.push_line(&sql);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn format_insert_fallback(&mut self, sql: &str) {
|
|
let s = sql.to_string();
|
|
if let Some(values_idx) = s.find(" VALUES (") {
|
|
let prefix = &s[..values_idx];
|
|
let suffix = &s[values_idx + 9..];
|
|
|
|
if let Some(paren_idx) = prefix.find(" (") {
|
|
self.push_line(&format!("{} (", &prefix[..paren_idx]));
|
|
self.indent += 2;
|
|
let cols = &prefix[paren_idx + 2..prefix.len() - 1];
|
|
let cols_split: Vec<&str> = cols.split(", ").collect();
|
|
for (i, col) in cols_split.iter().enumerate() {
|
|
let comma = if i < cols_split.len() - 1 { "," } else { "" };
|
|
let c = col.replace("\"", "");
|
|
self.push_line(&format!("\"{}\"{}", c, comma));
|
|
}
|
|
self.indent -= 2;
|
|
self.push_line(")");
|
|
} else {
|
|
self.push_line(prefix);
|
|
}
|
|
|
|
self.push_line("VALUES (");
|
|
self.indent += 2;
|
|
|
|
let vals = if suffix.ends_with(")") { &suffix[..suffix.len() - 1] } else { suffix };
|
|
let mut val_tokens = Vec::new();
|
|
let mut curr = String::new();
|
|
let mut in_str = false;
|
|
for c in vals.chars() {
|
|
if c == '\'' {
|
|
in_str = !in_str;
|
|
curr.push(c);
|
|
} else if c == ',' && !in_str {
|
|
val_tokens.push(curr.trim().to_string());
|
|
curr = String::new();
|
|
} else {
|
|
curr.push(c);
|
|
}
|
|
}
|
|
if !curr.trim().is_empty() {
|
|
val_tokens.push(curr.trim().to_string());
|
|
}
|
|
|
|
for (i, val) in val_tokens.iter().enumerate() {
|
|
let comma = if i < val_tokens.len() - 1 { "," } else { "" };
|
|
|
|
if val.starts_with("'{") && val.ends_with("}'") {
|
|
let inner = &val[1..val.len() - 1];
|
|
// Unescape single quotes from SQL strings
|
|
let unescaped = inner.replace("''", "'");
|
|
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&unescaped) {
|
|
if let Ok(pretty) = serde_json::to_string_pretty(&json) {
|
|
let lines: Vec<&str> = pretty.split('\n').collect();
|
|
self.push_line("'{");
|
|
self.indent += 2;
|
|
for (j, line) in lines.iter().skip(1).enumerate() {
|
|
if j == lines.len() - 2 {
|
|
self.indent -= 2;
|
|
// re-escape single quotes for SQL
|
|
self.push_line(&format!("{}'{}", line.replace("'", "''"), comma));
|
|
} else {
|
|
self.push_line(&line.replace("'", "''"));
|
|
}
|
|
}
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
self.push_line(&format!("{}{}", val, comma));
|
|
}
|
|
self.indent -= 2;
|
|
self.push_line(")");
|
|
} else {
|
|
self.push_line(&s);
|
|
}
|
|
}
|
|
|
|
fn format_update_fallback(&mut self, sql: &str) {
|
|
let s = sql.to_string();
|
|
if let Some(set_idx) = s.find(" SET ") {
|
|
self.push_line(&format!("{} SET", &s[..set_idx]));
|
|
self.indent += 2;
|
|
|
|
let after_set = &s[set_idx + 5..];
|
|
let where_idx = after_set.find(" WHERE ");
|
|
let assigns = if let Some(w) = where_idx { &after_set[..w] } else { after_set };
|
|
let assigns_split: Vec<&str> = assigns.split(", ").collect();
|
|
for (i, assign) in assigns_split.iter().enumerate() {
|
|
let comma = if i < assigns_split.len() - 1 { "," } else { "" };
|
|
self.push_line(&format!("{}{}", assign.replace("\"", ""), comma));
|
|
}
|
|
self.indent -= 2;
|
|
|
|
if let Some(w) = where_idx {
|
|
self.push_line("WHERE");
|
|
self.indent += 2;
|
|
self.push_line(&after_set[w + 7..]);
|
|
self.indent -= 2;
|
|
}
|
|
} else {
|
|
self.push_line(&s);
|
|
}
|
|
}
|
|
|
|
fn format_query(&mut self, query: &Query) {
|
|
match &*query.body {
|
|
SetExpr::Select(select) => self.format_select(select),
|
|
SetExpr::Query(inner_query) => {
|
|
self.push_str("(");
|
|
self.format_query(inner_query);
|
|
self.push_str(")");
|
|
}
|
|
_ => self.push_str(&query.to_string()),
|
|
}
|
|
}
|
|
|
|
/// Formats a `SELECT`: the projection on the current line, `FROM` with its
/// tables on a new line, and the `WHERE` condition on its own indented lines.
///
/// The `WHERE` clause is emitted whenever a selection is present, including
/// for a `SELECT` without a `FROM` clause.
///
/// NOTE(review): only `projection`, `from` and `selection` are read here; any
/// other parts of `select` are not rendered by this method.
fn format_select(&mut self, select: &Select) {
    self.push_str("SELECT ");
    let last_item = select.projection.len().saturating_sub(1);
    for (i, item) in select.projection.iter().enumerate() {
        self.format_select_item(item);
        if i < last_item {
            self.push_str(", ");
        }
    }

    if !select.from.is_empty() {
        self.push_line("FROM ");
        let last_table = select.from.len() - 1;
        for (i, table) in select.from.iter().enumerate() {
            self.format_table_with_joins(table);
            if i < last_table {
                self.push_str(", ");
            }
        }
    }

    if let Some(selection) = &select.selection {
        self.push_line("WHERE");
        self.indent += 2;
        // Start a fresh, indented line that the condition is appended to.
        self.push_line("");
        self.format_expr(selection);
        self.indent -= 2;
    }
}
|
|
|
|
fn format_select_item(&mut self, item: &SelectItem) {
|
|
match item {
|
|
SelectItem::UnnamedExpr(expr) => self.format_expr(expr),
|
|
SelectItem::ExprWithAlias { expr, alias } => {
|
|
self.format_expr(expr);
|
|
self.push_str(&format!(" AS {}", alias));
|
|
}
|
|
_ => self.push_str(&item.to_string()),
|
|
}
|
|
}
|
|
|
|
/// Emits the base relation on the current line, then each join on a new line
/// of its own.
fn format_table_with_joins(&mut self, table: &TableWithJoins) {
    let relation = table.relation.to_string();
    self.push_str(&relation);
    table.joins.iter().for_each(|join| {
        // Open a fresh line for the join to be appended to.
        self.push_line("");
        self.format_join(join);
    });
}
|
|
|
|
fn format_join(&mut self, join: &Join) {
|
|
let op = match &join.join_operator {
|
|
JoinOperator::Inner(_) => "JOIN",
|
|
JoinOperator::LeftOuter(_) => "LEFT JOIN",
|
|
_ => "JOIN",
|
|
};
|
|
self.push_str(&format!("{} {} ON ", op, join.relation));
|
|
|
|
match &join.join_operator {
|
|
JoinOperator::Inner(JoinConstraint::On(expr)) => self.format_expr(expr),
|
|
JoinOperator::LeftOuter(JoinConstraint::On(expr)) => self.format_expr(expr),
|
|
JoinOperator::Join(JoinConstraint::On(expr)) => self.format_expr(expr),
|
|
_ => {
|
|
println!("FALLBACK JOIN OP: {:?}", join.join_operator);
|
|
}
|
|
}
|
|
}
|
|
|
|
fn format_expr(&mut self, expr: &Expr) {
|
|
match expr {
|
|
Expr::Function(func) => self.format_function(func),
|
|
Expr::BinaryOp { left, op, right } => {
|
|
if *op == BinaryOperator::And || *op == BinaryOperator::Or {
|
|
self.format_expr(left);
|
|
self.push_line(&format!("{} ", op));
|
|
self.format_expr(right);
|
|
} else {
|
|
self.format_expr(left);
|
|
self.push_str(&format!(" {} ", op));
|
|
self.format_expr(right);
|
|
}
|
|
}
|
|
Expr::Nested(inner) => {
|
|
self.push_str("(");
|
|
self.format_expr(inner);
|
|
self.push_str(")");
|
|
}
|
|
Expr::IsNull(inner) => {
|
|
self.format_expr(inner);
|
|
self.push_str(" IS NULL");
|
|
}
|
|
Expr::IsNotNull(inner) => {
|
|
self.format_expr(inner);
|
|
self.push_str(" IS NOT NULL");
|
|
}
|
|
Expr::Subquery(query) => {
|
|
self.push_str("(");
|
|
self.indent += 2;
|
|
self.push_line("");
|
|
self.format_query(query);
|
|
self.indent -= 2;
|
|
self.push_line(")");
|
|
}
|
|
Expr::Case { operand, conditions, else_result, .. } => {
|
|
self.push_str("CASE");
|
|
if let Some(op) = operand {
|
|
self.push_str(" ");
|
|
self.format_expr(op);
|
|
}
|
|
self.indent += 2;
|
|
for when in conditions {
|
|
self.push_line("WHEN ");
|
|
self.format_expr(&when.condition);
|
|
self.push_str(" THEN ");
|
|
self.format_expr(&when.result);
|
|
}
|
|
if let Some(els) = else_result {
|
|
self.push_line("ELSE ");
|
|
self.format_expr(els);
|
|
}
|
|
self.indent -= 2;
|
|
self.push_line("END");
|
|
}
|
|
Expr::UnaryOp { op, expr: inner } => {
|
|
self.push_str(&format!("{} ", op));
|
|
self.format_expr(inner);
|
|
}
|
|
|
|
Expr::Value(sqlparser::ast::ValueWithSpan { value: Value::SingleQuotedString(s), .. }) | Expr::Value(sqlparser::ast::ValueWithSpan { value: Value::EscapedStringLiteral(s), .. }) => {
|
|
if s.starts_with('{') && s.ends_with('}') {
|
|
if let Ok(json) = serde_json::from_str::<serde_json::Value>(s) {
|
|
if let Ok(pretty) = serde_json::to_string_pretty(&json) {
|
|
let lines: Vec<&str> = pretty.split('\n').collect();
|
|
self.push_str("'{");
|
|
self.indent += 2;
|
|
for (j, line) in lines.iter().skip(1).enumerate() {
|
|
if j == lines.len() - 2 {
|
|
self.indent -= 2;
|
|
self.push_line(&format!("{}'", line.replace("'", "''")));
|
|
} else {
|
|
self.push_line(&line.replace("'", "''"));
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
self.push_str(&expr.to_string());
|
|
}
|
|
_ => {
|
|
self.push_str(&expr.to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
fn format_function(&mut self, func: &Function) {
|
|
let name = func.name.to_string();
|
|
self.push_str(&format!("{}(", name));
|
|
|
|
if let sqlparser::ast::FunctionArguments::List(list) = &func.args {
|
|
if name == "jsonb_build_object" {
|
|
self.indent += 2;
|
|
self.push_line("");
|
|
let mut i = 0;
|
|
while i < list.args.len() {
|
|
let arg_key = &list.args[i];
|
|
let arg_val = if i + 1 < list.args.len() { Some(&list.args[i+1]) } else { None };
|
|
|
|
self.format_function_arg(arg_key);
|
|
self.push_str(", ");
|
|
if let Some(val) = arg_val {
|
|
self.format_function_arg(val);
|
|
}
|
|
|
|
if i + 2 < list.args.len() {
|
|
self.push_str(",");
|
|
self.push_line("");
|
|
}
|
|
i += 2;
|
|
}
|
|
self.indent -= 2;
|
|
self.push_line(")");
|
|
} else {
|
|
for (i, arg) in list.args.iter().enumerate() {
|
|
let comma = if i < list.args.len() - 1 { ", " } else { "" };
|
|
self.format_function_arg(arg);
|
|
self.push_str(comma);
|
|
}
|
|
self.push_str(")");
|
|
}
|
|
} else {
|
|
self.push_str(")");
|
|
}
|
|
}
|
|
|
|
fn format_function_arg(&mut self, arg: &FunctionArg) {
|
|
match arg {
|
|
FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(expr)) => self.format_expr(expr),
|
|
_ => {
|
|
println!("FALLBACK ARG: {:?}", arg);
|
|
self.push_str(&arg.to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|