use crate::database::executors::DatabaseExecutor; use pgrx::prelude::*; use serde_json::Value; /// The production executor that wraps `pgrx::spi::Spi`. pub struct SpiExecutor; impl SpiExecutor { pub fn new() -> Self { Self {} } fn transact(&self, f: F) -> Result where F: FnOnce() -> Result, { unsafe { let oldcontext = pgrx::pg_sys::CurrentMemoryContext; let oldowner = pgrx::pg_sys::CurrentResourceOwner; pgrx::pg_sys::BeginInternalSubTransaction(std::ptr::null()); pgrx::pg_sys::MemoryContextSwitchTo(oldcontext); let runner = std::panic::AssertUnwindSafe(move || { let res = f(); pgrx::pg_sys::ReleaseCurrentSubTransaction(); pgrx::pg_sys::MemoryContextSwitchTo(oldcontext); pgrx::pg_sys::CurrentResourceOwner = oldowner; res }); pgrx::PgTryBuilder::new(runner) .catch_rust_panic(|cause| { pgrx::pg_sys::RollbackAndReleaseCurrentSubTransaction(); pgrx::pg_sys::MemoryContextSwitchTo(oldcontext); pgrx::pg_sys::CurrentResourceOwner = oldowner; // Rust panics are fatal bugs, not validation errors. Rethrow so they bubble up. cause.rethrow() }) .catch_others(|cause| { pgrx::pg_sys::RollbackAndReleaseCurrentSubTransaction(); pgrx::pg_sys::MemoryContextSwitchTo(oldcontext); pgrx::pg_sys::CurrentResourceOwner = oldowner; let error_msg = match &cause { pgrx::pg_sys::panic::CaughtError::PostgresError(e) | pgrx::pg_sys::panic::CaughtError::ErrorReport(e) => { let json_err = serde_json::json!({ "error": e.message(), "code": format!("{:?}", e.sql_error_code()), "detail": e.detail(), "hint": e.hint() }); json_err.to_string() } _ => format!("{:?}", cause), }; pgrx::warning!("JSPG Caught Native Postgres Error: {}", error_msg); Err(error_msg) }) .execute() } } } impl DatabaseExecutor for SpiExecutor { fn query(&self, sql: &str, args: Option<&[Value]>) -> Result { let mut json_args = Vec::new(); let mut args_with_oid: Vec = Vec::new(); if let Some(params) = args { for val in params { json_args.push(pgrx::JsonB(val.clone())); } for j_val in json_args.into_iter() { args_with_oid.push(pgrx::datum::DatumWithOid::from(j_val)); } } self.transact(|| { Spi::connect(|client| { pgrx::notice!("JSPG_SQL: {}", sql); match client.select(sql, Some(args_with_oid.len() as i64), &args_with_oid) { Ok(tup_table) => { let mut results = Vec::new(); for row in tup_table { if let Ok(Some(jsonb)) = row.get::(1) { results.push(jsonb.0); } } Ok(Value::Array(results)) } Err(e) => Err(format!("SPI Query Fetch Failure: {}", e)), } }) }) } fn execute(&self, sql: &str, args: Option<&[Value]>) -> Result<(), String> { let mut json_args = Vec::new(); let mut args_with_oid: Vec = Vec::new(); if let Some(params) = args { for val in params { json_args.push(pgrx::JsonB(val.clone())); } for j_val in json_args.into_iter() { args_with_oid.push(pgrx::datum::DatumWithOid::from(j_val)); } } self.transact(|| { Spi::connect_mut(|client| { pgrx::notice!("JSPG_SQL: {}", sql); match client.update(sql, Some(args_with_oid.len() as i64), &args_with_oid) { Ok(_) => Ok(()), Err(e) => Err(format!("SPI Execution Failure: {}", e)), } }) }) } fn auth_user_id(&self) -> Result { self.transact(|| { Spi::connect(|client| { let mut tup_table = client .select( "SELECT COALESCE(current_setting('auth.user_id', true), 'ffffffff-ffff-ffff-ffff-ffffffffffff')", None, &[], ) .map_err(|e| format!("SPI Select Error: {}", e))?; let row = tup_table .next() .ok_or("No user id setting returned from context".to_string())?; let user_id: Option = row.get(1).map_err(|e| e.to_string())?; user_id.ok_or("Missing user_id".to_string()) }) }) } fn timestamp(&self) -> Result { self.transact(|| { Spi::connect(|client| { let mut tup_table = client .select("SELECT clock_timestamp()::text", None, &[]) .map_err(|e| format!("SPI Select Error: {}", e))?; let row = tup_table .next() .ok_or("No clock timestamp returned".to_string())?; let timestamp: Option = row.get(1).map_err(|e| e.to_string())?; timestamp.ok_or("Missing timestamp".to_string()) }) }) } }