Skip to content

Instantly share code, notes, and snippets.

@pitrou
Created November 17, 2023 15:05
Show Gist options
  • Save pitrou/698144c6277d913448398c996aaac3eb to your computer and use it in GitHub Desktop.
Save pitrou/698144c6277d913448398c996aaac3eb to your computer and use it in GitHub Desktop.
diff --git a/arrow-schema/src/error.rs b/arrow-schema/src/error.rs
index b7bf8d6e12..1155c9d18b 100644
--- a/arrow-schema/src/error.rs
+++ b/arrow-schema/src/error.rs
@@ -122,6 +122,56 @@ impl Error for ArrowError {
}
}
+/// Returns early from the enclosing function with an [`ArrowError`] if a
+/// condition does not hold.
+///
+/// This is a non-panicking counterpart to [`assert!`]: when `$cond` evaluates
+/// to `false`, the macro executes
+/// `return Err(ArrowError::$kind(message))`, so the enclosing function must
+/// return `Result<_, ArrowError>`.
+///
+/// * `$cond` — a `bool` expression that is expected to hold.
+/// * `$kind` — the name of an [`ArrowError`] variant that wraps a `String`,
+///   e.g. `InvalidArgumentError` or `CDataInterface`.
+/// * The remaining arguments are a format string literal, optionally followed
+///   by format arguments, with the same syntax as [`format!`] (including
+///   inline captures such as `"{x}"` and a trailing comma).
+///
+/// # Examples
+///
+/// ```
+/// # use arrow_schema::{ensure, ArrowError};
+/// fn check_even(a: i32) -> Result<(), ArrowError> {
+///     ensure!(a % 2 == 0, InvalidArgumentError, "{a} is not even");
+///     Ok(())
+/// }
+/// assert!(check_even(42).is_ok());
+/// assert_eq!(
+///     check_even(43).unwrap_err().to_string(),
+///     "Invalid argument error: 43 is not even"
+/// );
+/// ```
+#[macro_export]
+macro_rules! ensure {
+    ($cond:expr, $kind:ident, $msg:literal $(,)?) => {
+        if !$cond {
+            // Go through `format!` (rather than `to_string`) so that a lone
+            // literal behaves like the formatted arm and like `assert!`:
+            // inline captures are interpolated and `{{`/`}}` are unescaped.
+            return ::std::result::Result::Err($crate::ArrowError::$kind(
+                ::std::format!($msg),
+            ));
+        }
+    };
+    ($cond:expr, $kind:ident, $fmt:literal, $($fmt_arg:tt)*) => {
+        if !$cond {
+            return ::std::result::Result::Err($crate::ArrowError::$kind(
+                ::std::format!($fmt, $($fmt_arg)*),
+            ));
+        }
+    };
+}
+
#[cfg(test)]
mod test {
use super::*;
@@ -150,4 +200,37 @@ mod test {
assert!(matches!(source, ArrowError::DivideByZero));
}
+
+    #[test]
+    fn ensure() {
+        fn check_even(a: i32) -> Result<(), ArrowError> {
+            ensure!(a % 2 == 0, InvalidArgumentError, "{} is not even", a);
+            Ok(())
+        }
+        assert!(matches!(check_even(42), Ok(())));
+        let e = check_even(43).unwrap_err();
+        assert!(matches!(e, ArrowError::InvalidArgumentError(_)));
+        assert_eq!(e.to_string(), "Invalid argument error: 43 is not even");
+
+        // A lone literal is used as the message, and a trailing comma is accepted.
+        fn check_positive(a: i32) -> Result<(), ArrowError> {
+            ensure!(a > 0, InvalidArgumentError, "value must be positive",);
+            Ok(())
+        }
+        assert!(check_positive(1).is_ok());
+        assert_eq!(
+            check_positive(0).unwrap_err().to_string(),
+            "Invalid argument error: value must be positive"
+        );
+
+        // A lone literal still supports inline captures and brace escapes.
+        fn check_small(a: i32) -> Result<(), ArrowError> {
+            ensure!((0..10).contains(&a), InvalidArgumentError, "{a} is not in {{0..10}}");
+            Ok(())
+        }
+        assert_eq!(
+            check_small(12).unwrap_err().to_string(),
+            "Invalid argument error: 12 is not in {0..10}"
+        );
+    }
}
diff --git a/arrow/src/error.rs b/arrow/src/error.rs
index f7acec0b34..1fbb940a3a 100644
--- a/arrow/src/error.rs
+++ b/arrow/src/error.rs
@@ -17,6 +17,6 @@
//! Defines `ArrowError` for representing failures in various Arrow operations.
-pub use arrow_schema::ArrowError;
+pub use arrow_schema::{ensure, ArrowError};
pub type Result<T> = std::result::Result<T, ArrowError>;
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index 3760673ab7..bac07010a2 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -114,7 +114,7 @@ use arrow_schema::UnionMode;
use crate::array::{layout, ArrayData};
use crate::buffer::{Buffer, MutableBuffer};
use crate::datatypes::DataType;
-use crate::error::{ArrowError, Result};
+use crate::error::{ensure, ArrowError, Result};
use crate::util::bit_util;
// returns the number of bits that buffer `i` (in the C data interface) is expected to have.
@@ -290,7 +290,11 @@ impl<'a> ImportedArrowArray<'a> {
if let Some(d) = self.dictionary()? {
// For dictionary type there should only be a single child, so we don't need to worry if
// there are other children added above.
- assert!(child_data.is_empty());
+ ensure!(
+ child_data.is_empty(),
+ CDataInterface,
+ "unexpected child arrays for dictionary array"
+ );
child_data.push(d.consume()?);
}
@@ -315,7 +319,13 @@ impl<'a> ImportedArrowArray<'a> {
| DataType::LargeList(field)
| DataType::Map(field, _) => Ok([self.consume_child(0, field.data_type())?].to_vec()),
DataType::Struct(fields) => {
- assert!(fields.len() == self.array.num_children());
+ ensure!(
+ fields.len() == self.array.num_children(),
+ CDataInterface,
+ "mismatching field lengths for struct: got {}, expected {}",
+ self.array.num_children(),
+ fields.len(),
+ );
fields
.iter()
.enumerate()
@@ -323,7 +333,13 @@ impl<'a> ImportedArrowArray<'a> {
.collect::<Result<Vec<_>>>()
}
DataType::Union(union_fields, _) => {
- assert!(union_fields.len() == self.array.num_children());
+ ensure!(
+ union_fields.len() == self.array.num_children(),
+ CDataInterface,
+ "mismatching field lengths for union: got {}, expected {}",
+ self.array.num_children(),
+ union_fields.len(),
+ );
union_fields
.iter()
.enumerate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment