use sqlparser::ast::{ BinaryOperator, Expr, Function, FunctionArg, Join, JoinConstraint, JoinOperator, Query, Select, SelectItem, SetExpr, Statement, TableWithJoins, Value }; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; pub struct SqlFormatter { pub lines: Vec, pub indent: usize, } impl SqlFormatter { pub fn new() -> Self { Self { lines: Vec::new(), indent: 0, } } pub fn format(sql: &str) -> Vec { let dialect = PostgreSqlDialect {}; let ast = match Parser::parse_sql(&dialect, sql) { Ok(ast) => ast, Err(e) => { println!("DEBUG PARSE SQL ERROR: {:?}", e); return vec![sql.to_string()]; } }; if ast.is_empty() { return vec![sql.to_string()]; } let mut formatter = SqlFormatter::new(); formatter.format_statement(&ast[0]); formatter.lines } fn push_str(&mut self, s: &str) { if self.lines.is_empty() { self.lines.push(format!("{}{}", " ".repeat(self.indent), s.replace("JSONB", "jsonb"))); } else { let last = self.lines.last_mut().unwrap(); last.push_str(&s.replace("JSONB", "jsonb")); } } fn push_line(&mut self, s: &str) { self.lines.push(format!("{}{}", " ".repeat(self.indent), s.replace("JSONB", "jsonb"))); } fn format_statement(&mut self, stmt: &Statement) { match stmt { Statement::Query(query) => { self.push_line("("); self.format_query(query); self.push_str(")"); } Statement::Update(_update) => { let sql = stmt.to_string(); self.format_update_fallback(&sql); } _ => { let sql = stmt.to_string(); if sql.starts_with("INSERT") { self.format_insert_fallback(&sql); } else { self.push_line(&sql); } } } } fn format_insert_fallback(&mut self, sql: &str) { let s = sql.to_string(); if let Some(values_idx) = s.find(" VALUES (") { let prefix = &s[..values_idx]; let suffix = &s[values_idx + 9..]; if let Some(paren_idx) = prefix.find(" (") { self.push_line(&format!("{} (", &prefix[..paren_idx])); self.indent += 2; let cols = &prefix[paren_idx + 2..prefix.len() - 1]; let cols_split: Vec<&str> = cols.split(", ").collect(); for (i, col) in cols_split.iter().enumerate() { let comma = if i < cols_split.len() - 1 { "," } else { "" }; let c = col.replace("\"", ""); self.push_line(&format!("\"{}\"{}", c, comma)); } self.indent -= 2; self.push_line(")"); } else { self.push_line(prefix); } self.push_line("VALUES ("); self.indent += 2; let vals = if suffix.ends_with(")") { &suffix[..suffix.len() - 1] } else { suffix }; let mut val_tokens = Vec::new(); let mut curr = String::new(); let mut in_str = false; for c in vals.chars() { if c == '\'' { in_str = !in_str; curr.push(c); } else if c == ',' && !in_str { val_tokens.push(curr.trim().to_string()); curr = String::new(); } else { curr.push(c); } } if !curr.trim().is_empty() { val_tokens.push(curr.trim().to_string()); } for (i, val) in val_tokens.iter().enumerate() { let comma = if i < val_tokens.len() - 1 { "," } else { "" }; if val.starts_with("'{") && val.ends_with("}'") { let inner = &val[1..val.len() - 1]; // Unescape single quotes from SQL strings let unescaped = inner.replace("''", "'"); if let Ok(json) = serde_json::from_str::(&unescaped) { if let Ok(pretty) = serde_json::to_string_pretty(&json) { let lines: Vec<&str> = pretty.split('\n').collect(); self.push_line("'{"); self.indent += 2; for (j, line) in lines.iter().skip(1).enumerate() { if j == lines.len() - 2 { self.indent -= 2; // re-escape single quotes for SQL self.push_line(&format!("{}'{}", line.replace("'", "''"), comma)); } else { self.push_line(&line.replace("'", "''")); } } continue; } } } self.push_line(&format!("{}{}", val, comma)); } self.indent -= 2; self.push_line(")"); } else { self.push_line(&s); } } fn format_update_fallback(&mut self, sql: &str) { let s = sql.to_string(); if let Some(set_idx) = s.find(" SET ") { self.push_line(&format!("{} SET", &s[..set_idx])); self.indent += 2; let after_set = &s[set_idx + 5..]; let where_idx = after_set.find(" WHERE "); let assigns = if let Some(w) = where_idx { &after_set[..w] } else { after_set }; let assigns_split: Vec<&str> = assigns.split(", ").collect(); for (i, assign) in assigns_split.iter().enumerate() { let comma = if i < assigns_split.len() - 1 { "," } else { "" }; self.push_line(&format!("{}{}", assign.replace("\"", ""), comma)); } self.indent -= 2; if let Some(w) = where_idx { self.push_line("WHERE"); self.indent += 2; self.push_line(&after_set[w + 7..]); self.indent -= 2; } } else { self.push_line(&s); } } fn format_query(&mut self, query: &Query) { match &*query.body { SetExpr::Select(select) => self.format_select(select), SetExpr::Query(inner_query) => { self.push_str("("); self.format_query(inner_query); self.push_str(")"); } _ => self.push_str(&query.to_string()), } } fn format_select(&mut self, select: &Select) { self.push_str("SELECT "); for (i, p) in select.projection.iter().enumerate() { let comma = if i < select.projection.len() - 1 { ", " } else { "" }; self.format_select_item(p); self.push_str(comma); } if !select.from.is_empty() { self.push_line("FROM "); for (i, table) in select.from.iter().enumerate() { let comma = if i < select.from.len() - 1 { ", " } else { "" }; self.format_table_with_joins(table); self.push_str(comma); } if let Some(selection) = &select.selection { self.push_line("WHERE"); self.indent += 2; self.push_line(""); // new line for where clauses self.format_expr(selection); self.indent -= 2; } } } fn format_select_item(&mut self, item: &SelectItem) { match item { SelectItem::UnnamedExpr(expr) => self.format_expr(expr), SelectItem::ExprWithAlias { expr, alias } => { self.format_expr(expr); self.push_str(&format!(" AS {}", alias)); } _ => self.push_str(&item.to_string()), } } fn format_table_with_joins(&mut self, table: &TableWithJoins) { self.push_str(&table.relation.to_string()); for join in &table.joins { self.push_line(""); self.format_join(join); } } fn format_join(&mut self, join: &Join) { let op = match &join.join_operator { JoinOperator::Inner(_) => "JOIN", JoinOperator::LeftOuter(_) => "LEFT JOIN", _ => "JOIN", }; self.push_str(&format!("{} {} ON ", op, join.relation)); match &join.join_operator { JoinOperator::Inner(JoinConstraint::On(expr)) => self.format_expr(expr), JoinOperator::LeftOuter(JoinConstraint::On(expr)) => self.format_expr(expr), JoinOperator::Join(JoinConstraint::On(expr)) => self.format_expr(expr), _ => { println!("FALLBACK JOIN OP: {:?}", join.join_operator); } } } fn format_expr(&mut self, expr: &Expr) { match expr { Expr::Function(func) => self.format_function(func), Expr::BinaryOp { left, op, right } => { if *op == BinaryOperator::And || *op == BinaryOperator::Or { self.format_expr(left); self.push_line(&format!("{} ", op)); self.format_expr(right); } else { self.format_expr(left); self.push_str(&format!(" {} ", op)); self.format_expr(right); } } Expr::Nested(inner) => { self.push_str("("); self.format_expr(inner); self.push_str(")"); } Expr::IsNull(inner) => { self.format_expr(inner); self.push_str(" IS NULL"); } Expr::IsNotNull(inner) => { self.format_expr(inner); self.push_str(" IS NOT NULL"); } Expr::Subquery(query) => { self.push_str("("); self.indent += 2; self.push_line(""); self.format_query(query); self.indent -= 2; self.push_line(")"); } Expr::Case { operand, conditions, else_result, .. } => { self.push_str("CASE"); if let Some(op) = operand { self.push_str(" "); self.format_expr(op); } self.indent += 2; for when in conditions { self.push_line("WHEN "); self.format_expr(&when.condition); self.push_str(" THEN "); self.format_expr(&when.result); } if let Some(els) = else_result { self.push_line("ELSE "); self.format_expr(els); } self.indent -= 2; self.push_line("END"); } Expr::UnaryOp { op, expr: inner } => { self.push_str(&format!("{} ", op)); self.format_expr(inner); } Expr::Value(sqlparser::ast::ValueWithSpan { value: Value::SingleQuotedString(s), .. }) | Expr::Value(sqlparser::ast::ValueWithSpan { value: Value::EscapedStringLiteral(s), .. }) => { if s.starts_with('{') && s.ends_with('}') { if let Ok(json) = serde_json::from_str::(s) { if let Ok(pretty) = serde_json::to_string_pretty(&json) { let lines: Vec<&str> = pretty.split('\n').collect(); self.push_str("'{"); self.indent += 2; for (j, line) in lines.iter().skip(1).enumerate() { if j == lines.len() - 2 { self.indent -= 2; self.push_line(&format!("{}'", line.replace("'", "''"))); } else { self.push_line(&line.replace("'", "''")); } } return; } } } self.push_str(&expr.to_string()); } _ => { self.push_str(&expr.to_string()); } } } fn format_function(&mut self, func: &Function) { let name = func.name.to_string(); self.push_str(&format!("{}(", name)); if let sqlparser::ast::FunctionArguments::List(list) = &func.args { if name == "jsonb_build_object" { self.indent += 2; self.push_line(""); let mut i = 0; while i < list.args.len() { let arg_key = &list.args[i]; let arg_val = if i + 1 < list.args.len() { Some(&list.args[i+1]) } else { None }; self.format_function_arg(arg_key); self.push_str(", "); if let Some(val) = arg_val { self.format_function_arg(val); } if i + 2 < list.args.len() { self.push_str(","); self.push_line(""); } i += 2; } self.indent -= 2; self.push_line(")"); } else { for (i, arg) in list.args.iter().enumerate() { let comma = if i < list.args.len() - 1 { ", " } else { "" }; self.format_function_arg(arg); self.push_str(comma); } self.push_str(")"); } } else { self.push_str(")"); } } fn format_function_arg(&mut self, arg: &FunctionArg) { match arg { FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(expr)) => self.format_expr(expr), _ => { println!("FALLBACK ARG: {:?}", arg); self.push_str(&arg.to_string()); } } } }