Files
jspg/format_sql.py

112 lines
4.2 KiB
Python

import json
import re
# Load the fixture file and select the single test case whose expected SQL this
# script reformats. JSON text is UTF-8 (RFC 8259), so decode it explicitly
# rather than relying on the platform's locale encoding.
with open("fixtures/merger.json", "r", encoding="utf-8") as f:
    data = json.load(f)
test_case = next(
    (
        t
        for t in data[0]["tests"]
        if t["description"] == "Test organization_id syntactic sugar permutations"
    ),
    None,
)
if test_case is None:
    # A bare StopIteration from next() would give no hint about what is missing.
    raise LookupError(
        "fixtures/merger.json: no test named "
        "'Test organization_id syntactic sugar permutations' in data[0]['tests']"
    )
# Head of an INSERT statement: group 1 is "INSERT INTO <table>", group 2 is the
# raw text between the column-list parentheses.
_INSERT_HEAD_RE = re.compile(r"(INSERT INTO [a-zA-Z0-9_.\"]+) \((.*)\)")
# NOTIFY on the 'entity' channel; group 1 is the body of the payload literal.
# DOTALL so a payload containing a raw newline is still recognised.
_NOTIFY_RE = re.compile(r"SELECT pg_notify\('entity', '(.*)'\)", re.DOTALL)
# One top-level JSON member whose value is an object: group 1 is the quoted
# key, group 2 is the object's body without its enclosing braces.
_OBJECT_MEMBER_RE = re.compile(r'("(?:[^"\\]|\\.)*"):\{(.*)\}', re.DOTALL)
# A '{{uuid}}' placeholder in one of these columns is rewritten to the zero UUID.
_ZERO_UUID_LITERAL = "'00000000-0000-0000-0000-000000000000'"
_ZERO_UUID_COLUMNS = frozenset(
    {'"created_by"', '"modified_by"', "created_by", "modified_by"}
)


def _split_sql_values(values_str):
    """Split the inside of a SQL ``VALUES (...)`` list on top-level commas.

    Commas inside single-quoted literals, or inside parentheses/brackets, are
    not separators. Postgres escapes a quote inside a literal by doubling it
    (``''``); toggling the in-quote flag on every ``'`` handles that for free,
    because the pair flips the flag out and straight back in. Each returned
    item is stripped of surrounding whitespace.
    """
    items = []
    start = 0
    depth = 0
    in_quote = False
    for idx, ch in enumerate(values_str):
        if ch == "'":
            in_quote = not in_quote
        elif in_quote:
            continue
        elif ch in "([":
            depth += 1
        elif ch in ")]":
            # Clamp so a stray closer (e.g. multi-row "('a'), ('b')") cannot
            # disable splitting for the rest of the string.
            depth = max(depth - 1, 0)
        elif ch == "," and depth == 0:
            items.append(values_str[start:idx].strip())
            start = idx + 1
    items.append(values_str[start:].strip())
    return items


def _split_json_members(body):
    """Split the body of a JSON object (text between its braces) on top-level commas.

    Commas inside JSON strings (honouring backslash escapes) or inside nested
    objects/arrays are not separators. The pieces are returned unstripped, so
    ``",".join(result) == body`` always holds.
    """
    members = []
    start = 0
    depth = 0
    in_string = False
    escaped = False
    for idx, ch in enumerate(body):
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch in "{[":
            depth += 1
        elif ch in "}]":
            depth = max(depth - 1, 0)
        elif ch == "," and depth == 0:
            members.append(body[start:idx])
            start = idx + 1
    members.append(body[start:])
    return members


def _json_member_lines(body):
    """Return one ``" <member>"`` line per top-level member of a JSON object body.

    Every line except the last gets a trailing comma.
    """
    members = _split_json_members(body)
    last = len(members) - 1
    return [f" {member}" + ("," if j < last else "") for j, member in enumerate(members)]


def _format_insert(sql_str):
    """Format ``INSERT INTO t (cols) VALUES (vals)``; return None if it doesn't parse."""
    # Split on the FIRST " VALUES " only: a literal may contain that text too.
    insert_part, sep, values_part = sql_str.partition(" VALUES ")
    if not sep:
        return None
    head = _INSERT_HEAD_RE.match(insert_part)
    values_part = values_part.strip()
    if head is None or not (values_part.startswith("(") and values_part.endswith(")")):
        return None

    table = head.group(1)
    cols = [c.strip() for c in head.group(2).split(",")]
    vals = _split_sql_values(values_part[1:-1])

    lines = [f"{table} ("]
    lines.extend(f" {col}" + ("," if i < len(cols) - 1 else "") for i, col in enumerate(cols))
    lines.append(")")
    lines.append("VALUES (")
    last_val = len(vals) - 1
    for i, val in enumerate(vals):
        trailer = "," if i < last_val else ""
        # Guard against more values than columns instead of raising IndexError.
        col = cols[i] if i < len(cols) else None
        # Must run before the JSON test: '{{uuid}}' also starts with '{ and
        # ends with }', which previously made this replacement unreachable.
        if val == "'{{uuid}}'" and col in _ZERO_UUID_COLUMNS:
            val = _ZERO_UUID_LITERAL
        # '{{...}}' is a template placeholder, not JSON (a JSON object cannot
        # open with "{{" because its first key must be a string).
        if val.startswith("'{") and val.endswith("}'") and not val.startswith("'{{"):
            lines.append(" '{")
            lines.extend(_json_member_lines(val[2:-2]))
            lines.append(" }'" + trailer)
        else:
            lines.append(f" {val}" + trailer)
    lines.append(")")
    return lines


def _format_notify(sql_str):
    """Format ``SELECT pg_notify('entity', '{...}')``; return None if it doesn't parse."""
    match = _NOTIFY_RE.match(sql_str)
    if match is None:
        return None
    payload = match.group(1)
    if not (payload.startswith("{") and payload.endswith("}")):
        return None

    lines = ["SELECT pg_notify('entity', '{"]
    members = _split_json_members(payload[1:-1])
    last = len(members) - 1
    for j, member in enumerate(members):
        trailer = "," if j < last else ""
        obj = _OBJECT_MEMBER_RE.fullmatch(member)
        if obj is None:
            # A scalar/array member (or something unexpected): keep it on one line.
            lines.append(f" {member}" + trailer)
            continue
        lines.append(f" {obj.group(1)}:{{")
        lines.extend(_json_member_lines(obj.group(2)))
        lines.append(" }" + trailer)
    lines.append(" }')")
    return lines


def format_sql(sql_str):
    """Split one SQL statement into a list of display lines for the fixture.

    Handles two shapes:
    - ``INSERT INTO <table> (<cols>) VALUES (<vals>)``: one line per column
      and per value. JSON-object literals are expanded to one line per
      top-level member. A ``'{{uuid}}'`` placeholder in a created_by/modified_by
      column is replaced with the all-zero UUID.
    - ``SELECT pg_notify('entity', '<json>')``: each top-level member whose
      value is an object is expanded to one line per nested member.

    Each line after the first lost separator gets its comma back, so joining
    the lines reproduces the statement apart from whitespace (and the UUID
    substitution above).

    Args:
        sql_str: A single SQL statement.

    Returns:
        The list of lines. Returns ``[sql_str]`` unchanged for any other
        statement, or when the statement does not match the expected shape.
    """
    if sql_str.startswith("INSERT INTO"):
        return _format_insert(sql_str) or [sql_str]
    if sql_str.startswith("SELECT pg_notify"):
        return _format_notify(sql_str) or [sql_str]
    return [sql_str]
# Turn each expected SQL entry of the selected test case into its formatted
# line list. An entry may be a plain string or a list of fragments, and "".join
# accepts both. Then overwrite the fixture file with the updated document.
new_sql = [format_sql("".join(sql_group)) for sql_group in test_case["expect"]["sql"]]
test_case["expect"]["sql"] = new_sql
with open("fixtures/merger.json", "w") as f:
    json.dump(data, f, indent=2)