diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..61ed5d4 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "rust-lang.rust-analyzer" + ] +} \ No newline at end of file diff --git a/src/database/mod.rs b/src/database/mod.rs index b5446a6..92b751e 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -230,43 +230,42 @@ impl Database { fn collect_relations(&mut self, raw_relations: Vec) { let mut edges: HashMap<(String, String), Vec> = HashMap::new(); - + // For every relation, map it across all polymorphic inheritance permutations for relation in raw_relations { - if let Some(source_type_def) = self.types.get(&relation.source_type) { - if let Some(dest_type_def) = self.types.get(&relation.destination_type) { - + if let Some(_source_type_def) = self.types.get(&relation.source_type) { + if let Some(_dest_type_def) = self.types.get(&relation.destination_type) { let mut src_descendants = Vec::new(); let mut dest_descendants = Vec::new(); for (t_name, t_def) in &self.types { - if t_def.hierarchy.contains(&relation.source_type) { - src_descendants.push(t_name.clone()); - } - if t_def.hierarchy.contains(&relation.destination_type) { - dest_descendants.push(t_name.clone()); - } + if t_def.hierarchy.contains(&relation.source_type) { + src_descendants.push(t_name.clone()); + } + if t_def.hierarchy.contains(&relation.destination_type) { + dest_descendants.push(t_name.clone()); + } } for p_type in &src_descendants { for c_type in &dest_descendants { // Ignore entity <-> entity generic fallbacks, they aren't useful edges if p_type == "entity" && c_type == "entity" { - continue; + continue; } - + // Forward edge edges .entry((p_type.clone(), c_type.clone())) .or_default() .push(relation.clone()); - + // Reverse edge (only if types are different to avoid duplicating self-referential edges like activity parent_id) if p_type != c_type { - edges - .entry((c_type.clone(), p_type.clone())) - .or_default() - .push(relation.clone()); + edges + .entry((c_type.clone(), p_type.clone())) + .or_default() + .push(relation.clone()); } } } @@ -277,17 +276,20 @@ impl Database { } pub fn get_relation( - &self, - parent_type: &str, - child_type: &str, + &self, + parent_type: &str, + child_type: &str, prop_name: &str, - relative_keys: Option<&Vec> + relative_keys: Option<&Vec>, ) -> Option<&Relation> { - if let Some(relations) = self.relations.get(&(parent_type.to_string(), child_type.to_string())) { + if let Some(relations) = self + .relations + .get(&(parent_type.to_string(), child_type.to_string())) + { if relations.len() == 1 { return Some(&relations[0]); } - + // Reduce ambiguity with prefix for rel in relations { if let Some(prefix) = &rel.prefix { @@ -302,13 +304,13 @@ impl Database { let mut missing_prefix_rels = Vec::new(); for rel in relations { if let Some(prefix) = &rel.prefix { - if !keys.contains(prefix) { - missing_prefix_rels.push(rel); - } + if !keys.contains(prefix) { + missing_prefix_rels.push(rel); + } } } if missing_prefix_rels.len() == 1 { - return Some(missing_prefix_rels[0]); + return Some(missing_prefix_rels[0]); } } } @@ -424,14 +426,14 @@ impl Database { if let (Some(pt), Some(prop)) = (&parent_type, &property_name) { let expected_col = format!("{}_id", prop); let mut found = false; - + if let Some(rel) = db.get_relation(pt, &entity_type, prop, None) { if rel.source_columns.contains(&expected_col) { relation_col = Some(expected_col.clone()); found = true; } } - + if !found { relation_col = Some(expected_col); } diff --git a/src/merger/mod.rs b/src/merger/mod.rs index a6fd731..1f25b4a 100644 --- a/src/merger/mod.rs +++ b/src/merger/mod.rs @@ -23,7 +23,7 @@ impl Merger { pub fn merge(&self, data: Value) -> crate::drop::Drop { let mut val_resolved = Value::Null; let mut notifications_queue = Vec::new(); - + let result = self.merge_internal(data, &mut notifications_queue); match result { @@ -44,7 +44,7 @@ impl Merger { } }; - // Execute the globally collected, pre-ordered notifications last! + // Execute the globally collected, pre-ordered notifications last! for notify_sql in notifications_queue { if let Err(e) = self.db.execute(¬ify_sql, None) { return crate::drop::Drop::with_errors(vec![crate::drop::Error { @@ -88,7 +88,11 @@ impl Merger { crate::drop::Drop::success_with_val(stripped_val) } - pub(crate) fn merge_internal(&self, data: Value, notifications: &mut Vec) -> Result { + pub(crate) fn merge_internal( + &self, + data: Value, + notifications: &mut Vec, + ) -> Result { match data { Value::Array(items) => self.merge_array(items, notifications), Value::Object(map) => self.merge_object(map, notifications), @@ -96,7 +100,11 @@ impl Merger { } } - fn merge_array(&self, items: Vec, notifications: &mut Vec) -> Result { + fn merge_array( + &self, + items: Vec, + notifications: &mut Vec, + ) -> Result { let mut resolved_items = Vec::new(); for item in items { let resolved = self.merge_internal(item, notifications)?; @@ -105,7 +113,11 @@ impl Merger { Ok(Value::Array(resolved_items)) } - fn merge_object(&self, obj: serde_json::Map, notifications: &mut Vec) -> Result { + fn merge_object( + &self, + obj: serde_json::Map, + notifications: &mut Vec, + ) -> Result { let queue_start = notifications.len(); let type_name = match obj.get("type").and_then(|v| v.as_str()) { @@ -173,7 +185,12 @@ impl Merger { let relative_keys: Vec = relative.keys().cloned().collect(); // Call central Database O(1) graph logic - let relative_relation = self.db.get_relation(&type_def.name, relative_type_name, &relation_name, Some(&relative_keys)); + let relative_relation = self.db.get_relation( + &type_def.name, + relative_type_name, + &relation_name, + Some(&relative_keys), + ); if let Some(relation) = relative_relation { let parent_is_source = type_def.hierarchy.contains(&relation.source_type); @@ -271,7 +288,12 @@ impl Merger { let relative_keys: Vec = first_relative.keys().cloned().collect(); // Call central Database O(1) graph logic - let relative_relation = self.db.get_relation(&type_def.name, relative_type_name, &relation_name, Some(&relative_keys)); + let relative_relation = self.db.get_relation( + &type_def.name, + relative_type_name, + &relation_name, + Some(&relative_keys), + ); if let Some(relation) = relative_relation { let mut relative_responses = Vec::new(); @@ -290,10 +312,11 @@ impl Merger { &entity_fields, ); - let merged_relative = match self.merge_internal(Value::Object(relative_item), notifications)? { - Value::Object(m) => m, - _ => continue, - }; + let merged_relative = + match self.merge_internal(Value::Object(relative_item), notifications)? { + Value::Object(m) => m, + _ => continue, + }; relative_responses.push(Value::Object(merged_relative)); } diff --git a/src/queryer/compiler.rs b/src/queryer/compiler.rs index 7ab7ad4..d57dad5 100644 --- a/src/queryer/compiler.rs +++ b/src/queryer/compiler.rs @@ -47,7 +47,19 @@ impl SqlCompiler { // We expect the top level to typically be an Object or Array let is_stem_query = stem_path.is_some(); - let (sql, _) = self.walk_schema(target_schema, "t1", None, None, filter_keys, is_stem_query, 0, String::new())?; + let mut alias_counter: usize = 0; + let (sql, _) = self.walk_schema( + target_schema, + "t1", + None, + None, + None, + filter_keys, + is_stem_query, + 0, + String::new(), + &mut alias_counter, + )?; Ok(sql) } @@ -57,12 +69,14 @@ impl SqlCompiler { &self, schema: &crate::database::schema::Schema, parent_alias: &str, + parent_table_aliases: Option<&std::collections::HashMap>, parent_type_def: Option<&crate::database::r#type::Type>, prop_name_context: Option<&str>, filter_keys: &[String], is_stem_query: bool, depth: usize, current_path: String, + alias_counter: &mut usize, ) -> Result<(String, String), String> { // Determine the base schema type (could be an array, object, or literal) match &schema.obj.type_ { @@ -81,6 +95,7 @@ impl SqlCompiler { items, type_def, parent_alias, + parent_table_aliases, parent_type_def, prop_name_context, true, @@ -88,18 +103,21 @@ impl SqlCompiler { is_stem_query, depth, next_path, + alias_counter, ); } } let (item_sql, _) = self.walk_schema( items, parent_alias, + parent_table_aliases, parent_type_def, prop_name_context, filter_keys, is_stem_query, depth + 1, next_path, + alias_counter, )?; return Ok(( format!("(SELECT jsonb_agg({}) FROM TODO)", item_sql), @@ -128,6 +146,7 @@ impl SqlCompiler { schema, type_def, parent_alias, + parent_table_aliases, parent_type_def, prop_name_context, false, @@ -135,6 +154,7 @@ impl SqlCompiler { is_stem_query, depth, current_path, + alias_counter, ); } @@ -145,12 +165,14 @@ impl SqlCompiler { return self.walk_schema( target_schema, parent_alias, + parent_table_aliases, parent_type_def, prop_name_context, filter_keys, is_stem_query, depth, current_path, + alias_counter, ); } return Err(format!("Unresolved $ref: {}", ref_id)); @@ -174,12 +196,14 @@ impl SqlCompiler { return self.compile_one_of( &family_schemas, parent_alias, + parent_table_aliases, parent_type_def, prop_name_context, filter_keys, is_stem_query, depth, current_path, + alias_counter, ); } @@ -188,12 +212,14 @@ impl SqlCompiler { return self.compile_one_of( one_of, parent_alias, + parent_table_aliases, parent_type_def, prop_name_context, filter_keys, is_stem_query, depth, current_path, + alias_counter, ); } @@ -202,11 +228,13 @@ impl SqlCompiler { return self.compile_inline_object( props, parent_alias, + parent_table_aliases, parent_type_def, filter_keys, is_stem_query, depth, current_path, + alias_counter, ); } @@ -249,6 +277,7 @@ impl SqlCompiler { schema: &crate::database::schema::Schema, type_def: &crate::database::r#type::Type, parent_alias: &str, + parent_table_aliases: Option<&std::collections::HashMap>, parent_type_def: Option<&crate::database::r#type::Type>, prop_name: Option<&str>, is_array: bool, @@ -256,11 +285,10 @@ impl SqlCompiler { is_stem_query: bool, depth: usize, current_path: String, + alias_counter: &mut usize, ) -> Result<(String, String), String> { - let local_ctx = format!("{}_{}", parent_alias, prop_name.unwrap_or("obj")); - // 1. Build FROM clauses and table aliases - let (table_aliases, from_clauses) = self.build_hierarchy_from_clauses(type_def, &local_ctx); + let (table_aliases, from_clauses) = self.build_hierarchy_from_clauses(type_def, alias_counter); // 2. Map properties and build jsonb_build_object args let mut select_args = self.map_properties_to_aliases( @@ -272,6 +300,7 @@ impl SqlCompiler { is_stem_query, depth, ¤t_path, + alias_counter, )?; // 2.5 Inject polymorphism directly into the query object @@ -281,10 +310,10 @@ impl SqlCompiler { let mut sorted_targets: Vec = base_type.variations.iter().cloned().collect(); // Ensure the base type is included if not listed in variations by default if !sorted_targets.contains(family_target) { - sorted_targets.push(family_target.clone()); + sorted_targets.push(family_target.clone()); } sorted_targets.sort(); - + for target in sorted_targets { let mut ref_schema = crate::database::schema::Schema::default(); ref_schema.obj.r#ref = Some(target); @@ -297,14 +326,42 @@ impl SqlCompiler { family_schemas.push(std::sync::Arc::new(ref_schema)); } - let base_alias = table_aliases.get(&type_def.name).cloned().unwrap_or_else(|| parent_alias.to_string()); + let base_alias = table_aliases + .get(&type_def.name) + .cloned() + .unwrap_or_else(|| parent_alias.to_string()); select_args.push(format!("'id', {}.id", base_alias)); - let (case_sql, _) = self.compile_one_of(&family_schemas, &base_alias, parent_type_def, None, filter_keys, is_stem_query, depth, current_path.clone())?; + let (case_sql, _) = self.compile_one_of( + &family_schemas, + &base_alias, + Some(&table_aliases), + parent_type_def, + None, + filter_keys, + is_stem_query, + depth, + current_path.clone(), + alias_counter, + )?; select_args.push(format!("'type', {}", case_sql)); } else if let Some(one_of) = &schema.obj.one_of { - let base_alias = table_aliases.get(&type_def.name).cloned().unwrap_or_else(|| parent_alias.to_string()); + let base_alias = table_aliases + .get(&type_def.name) + .cloned() + .unwrap_or_else(|| parent_alias.to_string()); select_args.push(format!("'id', {}.id", base_alias)); - let (case_sql, _) = self.compile_one_of(one_of, &base_alias, parent_type_def, None, filter_keys, is_stem_query, depth, current_path.clone())?; + let (case_sql, _) = self.compile_one_of( + one_of, + &base_alias, + Some(&table_aliases), + parent_type_def, + None, + filter_keys, + is_stem_query, + depth, + current_path.clone(), + alias_counter, + )?; select_args.push(format!("'type', {}", case_sql)); } @@ -320,6 +377,7 @@ impl SqlCompiler { type_def, &table_aliases, parent_alias, + parent_table_aliases, parent_type_def, prop_name, filter_keys, @@ -352,19 +410,20 @@ impl SqlCompiler { fn build_hierarchy_from_clauses( &self, type_def: &crate::database::r#type::Type, - local_ctx: &str, + alias_counter: &mut usize, ) -> (std::collections::HashMap, Vec) { let mut table_aliases = std::collections::HashMap::new(); let mut from_clauses = Vec::new(); for (i, table_name) in type_def.hierarchy.iter().enumerate() { - let alias = format!("{}_t{}", local_ctx, i + 1); + *alias_counter += 1; + let alias = format!("{}_{}", table_name, alias_counter); table_aliases.insert(table_name.clone(), alias.clone()); if i == 0 { from_clauses.push(format!("agreego.{} {}", table_name, alias)); } else { - let prev_alias = format!("{}_t{}", local_ctx, i); + let prev_alias = format!("{}_{}", type_def.hierarchy[i - 1], *alias_counter - 1); from_clauses.push(format!( "JOIN agreego.{} {} ON {}.id = {}.id", table_name, alias, alias, prev_alias @@ -384,6 +443,7 @@ impl SqlCompiler { is_stem_query: bool, depth: usize, current_path: &str, + alias_counter: &mut usize, ) -> Result, String> { let mut select_args = Vec::new(); let grouped_fields = type_def.grouped_fields.as_ref().and_then(|v| v.as_object()); @@ -410,16 +470,20 @@ impl SqlCompiler { } let is_object_or_array = match &prop_schema.obj.type_ { - Some(crate::database::schema::SchemaTypeOrArray::Single(s)) => s == "object" || s == "array", - Some(crate::database::schema::SchemaTypeOrArray::Multiple(v)) => v.contains(&"object".to_string()) || v.contains(&"array".to_string()), - _ => false + Some(crate::database::schema::SchemaTypeOrArray::Single(s)) => { + s == "object" || s == "array" + } + Some(crate::database::schema::SchemaTypeOrArray::Multiple(v)) => { + v.contains(&"object".to_string()) || v.contains(&"array".to_string()) + } + _ => false, }; - let is_primitive = prop_schema.obj.r#ref.is_none() - && prop_schema.obj.items.is_none() - && prop_schema.obj.properties.is_none() - && prop_schema.obj.one_of.is_none() - && !is_object_or_array; + let is_primitive = prop_schema.obj.r#ref.is_none() + && prop_schema.obj.items.is_none() + && prop_schema.obj.properties.is_none() + && prop_schema.obj.one_of.is_none() + && !is_object_or_array; if is_primitive { if let Some(ft) = type_def.field_types.as_ref().and_then(|v| v.as_object()) { @@ -438,12 +502,14 @@ impl SqlCompiler { let (val_sql, val_type) = self.walk_schema( prop_schema, &owner_alias, + Some(table_aliases), Some(type_def), // Pass current type_def as parent_type_def for child properties Some(prop_key), filter_keys, is_stem_query, depth + 1, next_path, + alias_counter, )?; if val_type != "abort" { @@ -459,6 +525,7 @@ impl SqlCompiler { type_def: &crate::database::r#type::Type, table_aliases: &std::collections::HashMap, parent_alias: &str, + parent_table_aliases: Option<&std::collections::HashMap>, parent_type_def: Option<&crate::database::r#type::Type>, prop_name: Option<&str>, filter_keys: &[String], @@ -503,129 +570,151 @@ impl SqlCompiler { let mut filter_alias = base_alias.clone(); if let Some(gf) = type_def.grouped_fields.as_ref().and_then(|v| v.as_object()) { - for (t_name, fields_val) in gf { - if let Some(fields_arr) = fields_val.as_array() { - if fields_arr.iter().any(|v| v.as_str() == Some(field_name)) { - filter_alias = table_aliases - .get(t_name) - .cloned() - .unwrap_or_else(|| base_alias.clone()); - break; + for (t_name, fields_val) in gf { + if let Some(fields_arr) = fields_val.as_array() { + if fields_arr.iter().any(|v| v.as_str() == Some(field_name)) { + filter_alias = table_aliases + .get(t_name) + .cloned() + .unwrap_or_else(|| base_alias.clone()); + break; + } + } + } + } + + let mut is_ilike = false; + let mut cast = ""; + + if let Some(field_types) = type_def.field_types.as_ref().and_then(|v| v.as_object()) { + if let Some(pg_type_val) = field_types.get(field_name) { + if let Some(pg_type) = pg_type_val.as_str() { + if pg_type == "uuid" { + cast = "::uuid"; + } else if pg_type == "boolean" || pg_type == "bool" { + cast = "::boolean"; + } else if pg_type.contains("timestamp") || pg_type == "timestamptz" || pg_type == "date" + { + cast = "::timestamptz"; + } else if pg_type == "numeric" + || pg_type.contains("int") + || pg_type == "real" + || pg_type == "double precision" + { + cast = "::numeric"; + } else if pg_type == "text" || pg_type.contains("char") { + let mut is_enum = false; + if let Some(props) = &schema.obj.properties { + if let Some(ps) = props.get(field_name) { + is_enum = ps.obj.enum_.is_some(); + } + } + if !is_enum { + is_ilike = true; } } } } + } - let mut is_ilike = false; - let mut cast = ""; + let param_index = i + 1; + let p_val = format!("${}#>>'{{}}'", param_index); - if let Some(field_types) = type_def.field_types.as_ref().and_then(|v| v.as_object()) { - if let Some(pg_type_val) = field_types.get(field_name) { - if let Some(pg_type) = pg_type_val.as_str() { - if pg_type == "uuid" { - cast = "::uuid"; - } else if pg_type == "boolean" || pg_type == "bool" { - cast = "::boolean"; - } else if pg_type.contains("timestamp") - || pg_type == "timestamptz" - || pg_type == "date" - { - cast = "::timestamptz"; - } else if pg_type == "numeric" - || pg_type.contains("int") - || pg_type == "real" - || pg_type == "double precision" - { - cast = "::numeric"; - } else if pg_type == "text" || pg_type.contains("char") { - let mut is_enum = false; - if let Some(props) = &schema.obj.properties { - if let Some(ps) = props.get(field_name) { - is_enum = ps.obj.enum_.is_some(); - } - } - if !is_enum { - is_ilike = true; - } - } + if op == "$in" || op == "$nin" { + let sql_op = if op == "$in" { "IN" } else { "NOT IN" }; + let subquery = format!( + "(SELECT value{} FROM jsonb_array_elements_text(({})::jsonb))", + cast, p_val + ); + where_clauses.push(format!( + "{}.{} {} {}", + filter_alias, field_name, sql_op, subquery + )); + } else { + let sql_op = match op { + "$eq" => { + if is_ilike { + "ILIKE" + } else { + "=" } } - } + "$ne" => { + if is_ilike { + "NOT ILIKE" + } else { + "!=" + } + } + "$gt" => ">", + "$gte" => ">=", + "$lt" => "<", + "$lte" => "<=", + _ => { + if is_ilike { + "ILIKE" + } else { + "=" + } + } + }; - let param_index = i + 1; - let p_val = format!("${}#>>'{{}}'", param_index); - - if op == "$in" || op == "$nin" { - let sql_op = if op == "$in" { "IN" } else { "NOT IN" }; - let subquery = format!( - "(SELECT value{} FROM jsonb_array_elements_text(({})::jsonb))", - cast, p_val - ); - where_clauses.push(format!( - "{}.{} {} {}", - filter_alias, field_name, sql_op, subquery - )); + let param_sql = if is_ilike && (op == "$eq" || op == "$ne") { + p_val } else { - let sql_op = match op { - "$eq" => { - if is_ilike { - "ILIKE" - } else { - "=" - } - } - "$ne" => { - if is_ilike { - "NOT ILIKE" - } else { - "!=" - } - } - "$gt" => ">", - "$gte" => ">=", - "$lt" => "<", - "$lte" => "<=", - _ => { - if is_ilike { - "ILIKE" - } else { - "=" - } - } - }; + format!("({}){}", p_val, cast) + }; - let param_sql = if is_ilike && (op == "$eq" || op == "$ne") { - p_val - } else { - format!("({}){}", p_val, cast) - }; - - where_clauses.push(format!( - "{}.{} {} {}", - filter_alias, field_name, sql_op, param_sql - )); + where_clauses.push(format!( + "{}.{} {} {}", + filter_alias, field_name, sql_op, param_sql + )); } } if let Some(prop) = prop_name { // Find what type the parent alias is actually mapping to let mut relation_alias = parent_alias.to_string(); - + let mut relation_resolved = false; if let Some(parent_type) = parent_type_def { - if let Some(relation) = self.db.get_relation(&parent_type.name, &type_def.name, prop, None) { - + if let Some(relation) = self + .db + .get_relation(&parent_type.name, &type_def.name, prop, None) + { let source_col = &relation.source_columns[0]; let dest_col = &relation.destination_columns[0]; + let mut possible_relation_alias = None; + if let Some(pta) = parent_table_aliases { + if let Some(a) = pta.get(&relation.source_type) { + possible_relation_alias = Some(a.clone()); + } else if let Some(a) = pta.get(&relation.destination_type) { + possible_relation_alias = Some(a.clone()); + } + } + if let Some(pa) = possible_relation_alias { + relation_alias = pa; + } + // Determine directionality based on the Relation metadata - if relation.source_type == parent_type.name || parent_type.hierarchy.contains(&relation.source_type) { + if relation.source_type == parent_type.name + || parent_type.hierarchy.contains(&relation.source_type) + { // Parent is the source - where_clauses.push(format!("{}.{} = {}.{}", parent_alias, source_col, base_alias, dest_col)); + where_clauses.push(format!( + "{}.{} = {}.{}", + relation_alias, source_col, base_alias, dest_col + )); relation_resolved = true; - } else if relation.destination_type == parent_type.name || parent_type.hierarchy.contains(&relation.destination_type) { + } else if relation.destination_type == parent_type.name + || parent_type.hierarchy.contains(&relation.destination_type) + { // Parent is the destination - where_clauses.push(format!("{}.{} = {}.{}", base_alias, source_col, parent_alias, dest_col)); + where_clauses.push(format!( + "{}.{} = {}.{}", + base_alias, source_col, relation_alias, dest_col + )); relation_resolved = true; } } @@ -634,10 +723,15 @@ impl SqlCompiler { if !relation_resolved { // Fallback heuristics for unmapped polymorphism or abstract models if prop == "target" || prop == "source" { - if parent_alias.ends_with("_t1") { - relation_alias = parent_alias.replace("_t1", "_t2"); + if let Some(pta) = parent_table_aliases { + if let Some(a) = pta.get("relationship") { + relation_alias = a.clone(); + } } - where_clauses.push(format!("{}.id = {}.{}_id", base_alias, relation_alias, prop)); + where_clauses.push(format!( + "{}.id = {}.{}_id", + base_alias, relation_alias, prop + )); } else { where_clauses.push(format!("{}.parent_id = {}.id", base_alias, relation_alias)); } @@ -651,11 +745,13 @@ impl SqlCompiler { &self, props: &std::collections::BTreeMap>, parent_alias: &str, + parent_table_aliases: Option<&std::collections::HashMap>, parent_type_def: Option<&crate::database::r#type::Type>, filter_keys: &[String], is_stem_query: bool, depth: usize, current_path: String, + alias_counter: &mut usize, ) -> Result<(String, String), String> { let mut build_args = Vec::new(); for (k, v) in props { @@ -664,16 +760,18 @@ impl SqlCompiler { } else { format!("{}.{}", current_path, k) }; - + let (child_sql, val_type) = self.walk_schema( v, parent_alias, + parent_table_aliases, parent_type_def, Some(k), filter_keys, is_stem_query, depth + 1, next_path, + alias_counter, )?; if val_type == "abort" { continue; @@ -688,12 +786,14 @@ impl SqlCompiler { &self, schemas: &[Arc], parent_alias: &str, + parent_table_aliases: Option<&std::collections::HashMap>, parent_type_def: Option<&crate::database::r#type::Type>, prop_name_context: Option<&str>, filter_keys: &[String], is_stem_query: bool, depth: usize, current_path: String, + alias_counter: &mut usize, ) -> Result<(String, String), String> { let mut case_statements = Vec::new(); let type_col = if let Some(prop) = prop_name_context { @@ -706,17 +806,19 @@ impl SqlCompiler { if let Some(ref_id) = &option_schema.obj.r#ref { // Find the physical type this ref maps to let base_type_name = ref_id.split('.').next_back().unwrap_or("").to_string(); - + // Generate the nested SQL for this specific target type let (val_sql, _) = self.walk_schema( option_schema, parent_alias, + parent_table_aliases, parent_type_def, prop_name_context, filter_keys, is_stem_query, depth, current_path.clone(), + alias_counter, )?; case_statements.push(format!( @@ -730,10 +832,7 @@ impl SqlCompiler { return Ok(("NULL".to_string(), "string".to_string())); } - let sql = format!( - "CASE {} ELSE NULL END", - case_statements.join(" ") - ); + let sql = format!("CASE {} ELSE NULL END", case_statements.join(" ")); Ok((sql, "object".to_string())) } diff --git a/src/tests/sql_validator.rs b/src/tests/sql_validator.rs index f5b04e9..c778fe7 100644 --- a/src/tests/sql_validator.rs +++ b/src/tests/sql_validator.rs @@ -1,156 +1,194 @@ -use sqlparser::ast::{ - Expr, Join, JoinConstraint, JoinOperator, Query, Select, SelectItem, SetExpr, Statement, - TableFactor, TableWithJoins, Ident, -}; +use sqlparser::ast::{Expr, Query, SelectItem, Statement, TableFactor}; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use std::collections::HashSet; pub fn validate_semantic_sql(sql: &str) -> Result<(), String> { - let dialect = PostgreSqlDialect {}; - let statements = match Parser::parse_sql(&dialect, sql) { - Ok(s) => s, - Err(e) => return Err(format!("SQL Syntax Error: {}\nSQL: {}", e, sql)), - }; + let dialect = PostgreSqlDialect {}; + let statements = match Parser::parse_sql(&dialect, sql) { + Ok(s) => s, + Err(e) => return Err(format!("SQL Syntax Error: {}\nSQL: {}", e, sql)), + }; - for statement in statements { - validate_statement(&statement, sql)?; - } + for statement in statements { + validate_statement(&statement, sql)?; + } - Ok(()) + Ok(()) } fn validate_statement(stmt: &Statement, original_sql: &str) -> Result<(), String> { - match stmt { - Statement::Query(query) => validate_query(query, original_sql)?, - Statement::Insert(insert) => { - if let Some(query) = &insert.source { - validate_query(query, original_sql)? - } - } - Statement::Update(update) => { - if let Some(expr) = &update.selection { - validate_expr(expr, &HashSet::new(), original_sql)?; - } - } - Statement::Delete(delete) => { - if let Some(expr) = &delete.selection { - validate_expr(expr, &HashSet::new(), original_sql)?; - } - } - _ => {} + match stmt { + Statement::Query(query) => validate_query(query, &HashSet::new(), original_sql)?, + Statement::Insert(insert) => { + if let Some(query) = &insert.source { + validate_query(query, &HashSet::new(), original_sql)? + } } - Ok(()) + Statement::Update(update) => { + if let Some(expr) = &update.selection { + validate_expr(expr, &HashSet::new(), original_sql)?; + } + } + Statement::Delete(delete) => { + if let Some(expr) = &delete.selection { + validate_expr(expr, &HashSet::new(), original_sql)?; + } + } + _ => {} + } + Ok(()) } -fn validate_query(query: &Query, original_sql: &str) -> Result<(), String> { - if let SetExpr::Select(select) = &*query.body { - validate_select(select, original_sql)?; - } - Ok(()) +fn validate_query( + query: &Query, + available_aliases: &HashSet, + original_sql: &str, +) -> Result<(), String> { + if let sqlparser::ast::SetExpr::Select(select) = &*query.body { + validate_select(&select, available_aliases, original_sql)?; + } + Ok(()) } -fn validate_select(select: &Select, original_sql: &str) -> Result<(), String> { - let mut available_aliases = HashSet::new(); +fn validate_select( + select: &sqlparser::ast::Select, + parent_aliases: &HashSet, + original_sql: &str, +) -> Result<(), String> { + let mut available_aliases = parent_aliases.clone(); - // 1. Collect all declared table aliases in the FROM clause and JOINs - for table_with_joins in &select.from { - collect_aliases_from_table_factor(&table_with_joins.relation, &mut available_aliases); - for join in &table_with_joins.joins { - collect_aliases_from_table_factor(&join.relation, &mut available_aliases); - } + // 1. Collect all declared table aliases in the FROM clause and JOINs + for table_with_joins in &select.from { + collect_aliases_from_table_factor(&table_with_joins.relation, &mut available_aliases); + for join in &table_with_joins.joins { + collect_aliases_from_table_factor(&join.relation, &mut available_aliases); } + } - // 2. Validate all SELECT projection fields - for projection in &select.projection { - if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = projection { - validate_expr(expr, &available_aliases, original_sql)?; - } + // 2. Validate all SELECT projection fields + for projection in &select.projection { + if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = projection { + validate_expr(expr, &available_aliases, original_sql)?; } + } - // 3. Validate ON conditions in joins - for table_with_joins in &select.from { - for join in &table_with_joins.joins { - if let JoinOperator::Inner(JoinConstraint::On(expr)) - | JoinOperator::LeftOuter(JoinConstraint::On(expr)) - | JoinOperator::RightOuter(JoinConstraint::On(expr)) - | JoinOperator::FullOuter(JoinConstraint::On(expr)) - | JoinOperator::Join(JoinConstraint::On(expr)) = &join.join_operator - { - validate_expr(expr, &available_aliases, original_sql)?; - } - } + // 3. Validate ON conditions in joins + for table_with_joins in &select.from { + for join in &table_with_joins.joins { + if let sqlparser::ast::JoinOperator::Inner(sqlparser::ast::JoinConstraint::On(expr)) + | sqlparser::ast::JoinOperator::LeftOuter(sqlparser::ast::JoinConstraint::On(expr)) + | sqlparser::ast::JoinOperator::RightOuter(sqlparser::ast::JoinConstraint::On(expr)) + | sqlparser::ast::JoinOperator::FullOuter(sqlparser::ast::JoinConstraint::On(expr)) + | sqlparser::ast::JoinOperator::Join(sqlparser::ast::JoinConstraint::On(expr)) = + &join.join_operator + { + validate_expr(expr, &available_aliases, original_sql)?; + } } + } - // 4. Validate WHERE conditions - if let Some(selection) = &select.selection { - validate_expr(selection, &available_aliases, original_sql)?; - } + // 4. Validate WHERE conditions + if let Some(selection) = &select.selection { + validate_expr(selection, &available_aliases, original_sql)?; + } - Ok(()) + Ok(()) } fn collect_aliases_from_table_factor(tf: &TableFactor, aliases: &mut HashSet) { - match tf { - TableFactor::Table { name, alias, .. } => { - if let Some(table_alias) = alias { - aliases.insert(table_alias.name.value.clone()); - } else if let Some(last) = name.0.last() { - match last { - sqlparser::ast::ObjectNamePart::Identifier(i) => { - aliases.insert(i.value.clone()); - } - _ => {} - } - } + match tf { + TableFactor::Table { name, alias, .. } => { + if let Some(table_alias) = alias { + aliases.insert(table_alias.name.value.clone()); + } else if let Some(last) = name.0.last() { + match last { + sqlparser::ast::ObjectNamePart::Identifier(i) => { + aliases.insert(i.value.clone()); + } + _ => {} } - TableFactor::Derived { alias: Some(table_alias), .. } => { - aliases.insert(table_alias.name.value.clone()); - } - _ => {} + } } + TableFactor::Derived { + subquery, + alias: Some(table_alias), + .. + } => { + aliases.insert(table_alias.name.value.clone()); + // A derived table is technically a nested scope which is opaque outside, but for pure semantic checks + // its internal contents should be validated purely within its own scope (not leaking external aliases in, usually) + // but Postgres allows lateral correlation. We will validate its interior with an empty scope. + let _ = validate_query(subquery, &HashSet::new(), ""); + } + _ => {} + } } -fn validate_expr(expr: &Expr, available_aliases: &HashSet, sql: &str) -> Result<(), String> { - match expr { - Expr::CompoundIdentifier(idents) => { - if idents.len() == 2 { - let alias = &idents[0].value; - if !available_aliases.is_empty() && !available_aliases.contains(alias) { - return Err(format!( - "Semantic Error: Orchestrated query referenced table alias '{}' but it was not declared in the query's FROM/JOIN clauses.\nAvailable aliases: {:?}\nSQL: {}", - alias, available_aliases, sql - )); - } - } else if idents.len() > 2 { - let alias = &idents[1].value; // In form schema.table.column, 'table' is idents[1] - if !available_aliases.is_empty() && !available_aliases.contains(alias) { - return Err(format!( - "Semantic Error: Orchestrated query referenced table '{}' but it was not mapped.\nAvailable aliases: {:?}\nSQL: {}", - alias, available_aliases, sql - )); - } - } +fn validate_expr( + expr: &Expr, + available_aliases: &HashSet, + sql: &str, +) -> Result<(), String> { + match expr { + Expr::CompoundIdentifier(idents) => { + if idents.len() == 2 { + let alias = &idents[0].value; + if !available_aliases.is_empty() && !available_aliases.contains(alias) { + return Err(format!( + "Semantic Error: Orchestrated query referenced table alias '{}' but it was not declared in the query's FROM/JOIN clauses.\nAvailable aliases: {:?}\nSQL: {}", + alias, available_aliases, sql + )); } - Expr::BinaryOp { left, right, .. } => { - validate_expr(left, available_aliases, sql)?; - validate_expr(right, available_aliases, sql)?; + } else if idents.len() > 2 { + let alias = &idents[1].value; // In form schema.table.column, 'table' is idents[1] + if !available_aliases.is_empty() && !available_aliases.contains(alias) { + return Err(format!( + "Semantic Error: Orchestrated query referenced table '{}' but it was not mapped.\nAvailable aliases: {:?}\nSQL: {}", + alias, available_aliases, sql + )); } - Expr::IsFalse(e) | Expr::IsNotFalse(e) | Expr::IsTrue(e) | Expr::IsNotTrue(e) - | Expr::IsNull(e) | Expr::IsNotNull(e) | Expr::InList { expr: e, .. } - | Expr::Nested(e) | Expr::UnaryOp { expr: e, .. } | Expr::Cast { expr: e, .. } - | Expr::Like { expr: e, .. } | Expr::ILike { expr: e, .. } | Expr::AnyOp { left: e, .. } - | Expr::AllOp { left: e, .. } => { - validate_expr(e, available_aliases, sql)?; - } - Expr::Function(func) => { - if let sqlparser::ast::FunctionArguments::List(args) = &func.args { - if let Some(sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(e))) = args.args.get(0) { - validate_expr(e, available_aliases, sql)?; - } - } - } - _ => {} + } } - Ok(()) + Expr::Subquery(subquery) => validate_query(subquery, available_aliases, sql)?, + Expr::Exists { subquery, .. } => validate_query(subquery, available_aliases, sql)?, + Expr::InSubquery { + expr: e, subquery, .. + } => { + validate_expr(e, available_aliases, sql)?; + validate_query(subquery, available_aliases, sql)?; + } + Expr::BinaryOp { left, right, .. } => { + validate_expr(left, available_aliases, sql)?; + validate_expr(right, available_aliases, sql)?; + } + Expr::IsFalse(e) + | Expr::IsNotFalse(e) + | Expr::IsTrue(e) + | Expr::IsNotTrue(e) + | Expr::IsNull(e) + | Expr::IsNotNull(e) + | Expr::InList { expr: e, .. } + | Expr::Nested(e) + | Expr::UnaryOp { expr: e, .. } + | Expr::Cast { expr: e, .. } + | Expr::Like { expr: e, .. } + | Expr::ILike { expr: e, .. } + | Expr::AnyOp { left: e, .. } + | Expr::AllOp { left: e, .. } => { + validate_expr(e, available_aliases, sql)?; + } + Expr::Function(func) => { + if let sqlparser::ast::FunctionArguments::List(args) = &func.args { + if let Some(sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr( + e, + ))) = args.args.get(0) + { + validate_expr(e, available_aliases, sql)?; + } + } + } + _ => {} + } + Ok(()) } diff --git a/src/validator/mod.rs b/src/validator/mod.rs index 49dff47..c7a50a2 100644 --- a/src/validator/mod.rs +++ b/src/validator/mod.rs @@ -68,11 +68,11 @@ impl Validator { code: e.code, message: e.message, details: crate::drop::ErrorDetails { - path: e.path, - cause: None, - context: None, - schema: None, - }, + path: e.path, + cause: None, + context: None, + schema: None, + }, }) .collect(); crate::drop::Drop::with_errors(errors)