tiders_query/
lib.rs

1//! # tiders-query
2//!
3//! Arrow RecordBatch query engine with row filtering, column projection, and cross-table joins.
4//!
5//! This crate operates on `BTreeMap<String, RecordBatch>` where keys are table names
6//! (e.g. "blocks", "transactions", "logs"). It supports:
7//!
8//! - **Row filtering** via [`Contains`] (set membership) and [`StartsWith`] (prefix matching).
9//! - **Column projection** via [`select_fields`].
10//! - **Cross-table joins** via [`Include`], which filters rows in one table based on
11//!   matching column values in another.
12//!
13//! Filtering uses xxhash3 for fast hash-table lookups (for sets >= 128 elements) and
14//! rayon for parallel execution across tables and selections.
15
16use anyhow::{anyhow, Context, Result};
17use arrow::array::{
18    Array, ArrowPrimitiveType, BinaryArray, BooleanArray, BooleanBuilder, GenericByteArray,
19    Int16Array, Int32Array, Int64Array, Int8Array, PrimitiveArray, StringArray, UInt16Array,
20    UInt32Array, UInt64Array, UInt8Array,
21};
22use arrow::buffer::BooleanBuffer;
23use arrow::compute;
24use arrow::datatypes::{ByteArrayType, DataType, ToByteSlice};
25use arrow::record_batch::RecordBatch;
26use arrow::row::{RowConverter, SortField};
27use hashbrown::HashTable;
28use rayon::prelude::*;
29use std::collections::btree_map::Entry;
30use std::collections::BTreeMap;
31use std::sync::Arc;
32use xxhash_rust::xxh3::xxh3_64;
33
/// Name of a table; used as the key of the per-table maps in this crate.
type TableName = String;
/// Name of a column (field) inside a table.
type FieldName = String;
36
37/// A query definition specifying which rows and columns to return from a set of tables.
38///
39/// `selection` defines per-table row filters and cross-table join rules.
40/// `fields` defines which columns to project for each table.
41#[derive(Clone)]
42pub struct Query {
43    /// Row selection rules keyed by table name. Multiple [`TableSelection`]s for the
44    /// same table are OR-combined (a row passes if it matches any selection).
45    pub selection: Arc<BTreeMap<TableName, Vec<TableSelection>>>,
46    /// Columns to include in the output, keyed by table name.
47    pub fields: BTreeMap<TableName, Vec<FieldName>>,
48}
49
50impl Query {
51    /// Adds all filter and include column names to `self.fields` so that they are
52    /// present in the projected output. This ensures columns referenced by filters
53    /// and joins are fetched from the data source.
54    pub fn add_request_and_include_fields(&mut self) -> Result<()> {
55        for (table_name, selections) in &*self.selection {
56            for selection in selections {
57                for col_name in selection.filters.keys() {
58                    let table_fields = self
59                        .fields
60                        .get_mut(table_name)
61                        .with_context(|| format!("get fields for table {table_name}"))?;
62                    table_fields.push(col_name.to_owned());
63                }
64
65                for include in &selection.include {
66                    let other_table_fields = self
67                        .fields
68                        .get_mut(&include.other_table_name)
69                        .with_context(|| {
70                            format!("get fields for other table {}", include.other_table_name)
71                        })?;
72                    other_table_fields.extend_from_slice(&include.other_table_field_names);
73                    let table_fields = self
74                        .fields
75                        .get_mut(table_name)
76                        .with_context(|| format!("get fields for table {table_name}"))?;
77                    table_fields.extend_from_slice(&include.field_names);
78                }
79            }
80        }
81
82        Ok(())
83    }
84}
85
86/// A single selection rule for one table: column filters (AND-combined) plus
87/// cross-table join includes.
88pub struct TableSelection {
89    /// Column-level filters. All filters within a selection are AND-combined:
90    /// a row must pass every filter to be included.
91    pub filters: BTreeMap<FieldName, Filter>,
92    /// Cross-table joins. Rows from `other_table_name` are included if their
93    /// join columns match filtered rows from this table.
94    pub include: Vec<Include>,
95}
96
97/// Defines a cross-table join: after filtering the source table, include rows from
98/// `other_table_name` whose `other_table_field_names` columns match the source
99/// table's `field_names` columns.
100pub struct Include {
101    /// The target table to join into.
102    pub other_table_name: TableName,
103    /// Column names in the source table used for the join key.
104    pub field_names: Vec<FieldName>,
105    /// Corresponding column names in the target table.
106    pub other_table_field_names: Vec<FieldName>,
107}
108
109/// A row-level filter applied to a single column.
110pub enum Filter {
111    /// Set membership check: row passes if the column value is in the set.
112    Contains(Contains),
113    /// Prefix match: row passes if the column value starts with any of the prefixes.
114    StartsWith(StartsWith),
115    /// Boolean equality: row passes if the column's boolean value matches.
116    Bool(bool),
117}
118
119impl Filter {
120    /// Creates a [`Contains`] filter from the given array of allowed values.
121    pub fn contains(arr: Arc<dyn Array>) -> Result<Self> {
122        Ok(Self::Contains(Contains::new(arr)?))
123    }
124
125    /// Creates a [`StartsWith`] filter from the given array of prefixes.
126    pub fn starts_with(arr: Arc<dyn Array>) -> Result<Self> {
127        Ok(Self::StartsWith(StartsWith::new(arr)?))
128    }
129
130    /// Creates a boolean equality filter.
131    pub fn bool(b: bool) -> Self {
132        Self::Bool(b)
133    }
134
135    fn check(&self, arr: &dyn Array) -> Result<BooleanArray> {
136        match self {
137            Self::Contains(ct) => ct.contains(arr),
138            Self::StartsWith(sw) => sw.starts_with(arr),
139            Self::Bool(b) => {
140                let arr = arr
141                    .as_any()
142                    .downcast_ref::<BooleanArray>()
143                    .context("cast array to boolean array")?;
144
145                let mut filter = if *b {
146                    arr.clone()
147                } else {
148                    compute::not(arr).context("negate array")?
149                };
150
151                if let Some(nulls) = filter.nulls() {
152                    if nulls.null_count() > 0 {
153                        let nulls = BooleanArray::from(nulls.inner().clone());
154                        filter = compute::and(&filter, &nulls)
155                            .context("apply null mask to boolean filter")?;
156                    }
157                }
158
159                Ok(filter)
160            }
161        }
162    }
163}
164
165/// Set membership filter backed by a hash table for fast lookups.
166///
167/// Supports integer (i8–i64, u8–u64), binary, and string column types.
168/// For sets with fewer than 128 elements, falls back to linear scan;
169/// for larger sets, uses an xxhash3-based hash table.
170pub struct Contains {
171    array: Arc<dyn Array>,
172    hash_table: Option<HashTable<usize>>,
173}
174
175impl Contains {
176    fn ht_from_primitive<T: ArrowPrimitiveType>(arr: &PrimitiveArray<T>) -> HashTable<usize> {
177        assert!(!arr.is_nullable());
178
179        let mut ht = HashTable::with_capacity(arr.len());
180
181        for (i, v) in arr.values().iter().enumerate() {
182            ht.insert_unique(xxh3_64(v.to_byte_slice()), i, |i| {
183                xxh3_64(unsafe { arr.value_unchecked(*i).to_byte_slice() })
184            });
185        }
186
187        ht
188    }
189
190    fn ht_from_bytes<T: ByteArrayType<Offset = i32>>(
191        arr: &GenericByteArray<T>,
192    ) -> HashTable<usize> {
193        assert!(!arr.is_nullable());
194
195        let mut ht = HashTable::with_capacity(arr.len());
196
197        for (i, v) in iter_byte_array_without_validity(arr).enumerate() {
198            ht.insert_unique(xxh3_64(v), i, |i| {
199                xxh3_64(unsafe { byte_array_get_unchecked(arr, *i) })
200            });
201        }
202
203        ht
204    }
205
206    fn ht_from_array(array: &dyn Array) -> Result<HashTable<usize>> {
207        let ht = match *array.data_type() {
208            DataType::UInt8 => {
209                let array = array
210                    .as_any()
211                    .downcast_ref::<UInt8Array>()
212                    .context("downcast to UInt8Array failed")?;
213                Self::ht_from_primitive(array)
214            }
215            DataType::UInt16 => {
216                let array = array
217                    .as_any()
218                    .downcast_ref::<UInt16Array>()
219                    .context("downcast to UInt16Array failed")?;
220                Self::ht_from_primitive(array)
221            }
222            DataType::UInt32 => {
223                let array = array
224                    .as_any()
225                    .downcast_ref::<UInt32Array>()
226                    .context("downcast to UInt32Array failed")?;
227                Self::ht_from_primitive(array)
228            }
229            DataType::UInt64 => {
230                let array = array
231                    .as_any()
232                    .downcast_ref::<UInt64Array>()
233                    .context("downcast to UInt64Array failed")?;
234                Self::ht_from_primitive(array)
235            }
236            DataType::Int8 => {
237                let array = array
238                    .as_any()
239                    .downcast_ref::<Int8Array>()
240                    .context("downcast to Int8Array failed")?;
241                Self::ht_from_primitive(array)
242            }
243            DataType::Int16 => {
244                let array = array
245                    .as_any()
246                    .downcast_ref::<Int16Array>()
247                    .context("downcast to Int16Array failed")?;
248                Self::ht_from_primitive(array)
249            }
250            DataType::Int32 => {
251                let array = array
252                    .as_any()
253                    .downcast_ref::<Int32Array>()
254                    .context("downcast to Int32Array failed")?;
255                Self::ht_from_primitive(array)
256            }
257            DataType::Int64 => {
258                let array = array
259                    .as_any()
260                    .downcast_ref::<Int64Array>()
261                    .context("downcast to Int64Array failed")?;
262                Self::ht_from_primitive(array)
263            }
264            DataType::Binary => {
265                let array = array
266                    .as_any()
267                    .downcast_ref::<BinaryArray>()
268                    .context("downcast to BinaryArray failed")?;
269                Self::ht_from_bytes(array)
270            }
271            DataType::Utf8 => {
272                let array = array
273                    .as_any()
274                    .downcast_ref::<StringArray>()
275                    .context("downcast to StringArray failed")?;
276                Self::ht_from_bytes(array)
277            }
278            _ => {
279                return Err(anyhow!("unsupported data type: {}", array.data_type()));
280            }
281        };
282
283        Ok(ht)
284    }
285
286    /// Creates a new containment filter from a non-nullable array of allowed values.
287    ///
288    /// Uses a hash table for sets of 128+ elements, linear scan otherwise.
289    pub fn new(array: Arc<dyn Array>) -> Result<Self> {
290        if array.is_nullable() {
291            return Err(anyhow!(
292                "cannot construct contains filter with a nullable array"
293            ));
294        }
295
296        let hash_table = if array.len() >= 128 {
297            Some(Self::ht_from_array(&array).context("construct hash table")?)
298        } else {
299            None
300        };
301
302        Ok(Self { array, hash_table })
303    }
304
305    fn contains(&self, arr: &dyn Array) -> Result<BooleanArray> {
306        if arr.data_type() != self.array.data_type() {
307            return Err(anyhow!(
308                "filter array is of type {} but array to be filtered is of type {}",
309                self.array.data_type(),
310                arr.data_type(),
311            ));
312        }
313        anyhow::ensure!(
314            !self.array.is_nullable(),
315            "filter array must not be nullable"
316        );
317
318        let filter = match *arr.data_type() {
319            DataType::UInt8 => {
320                let self_arr = self
321                    .array
322                    .as_any()
323                    .downcast_ref::<UInt8Array>()
324                    .context("downcast to UInt8Array failed")?;
325                let other_arr = arr
326                    .as_any()
327                    .downcast_ref()
328                    .context("downcast other to UInt8Array failed")?;
329                self.contains_primitive(self_arr, other_arr)
330            }
331            DataType::UInt16 => {
332                let self_arr = self
333                    .array
334                    .as_any()
335                    .downcast_ref::<UInt16Array>()
336                    .context("downcast to UInt16Array failed")?;
337                let other_arr = arr
338                    .as_any()
339                    .downcast_ref()
340                    .context("downcast other to UInt16Array failed")?;
341                self.contains_primitive(self_arr, other_arr)
342            }
343            DataType::UInt32 => {
344                let self_arr = self
345                    .array
346                    .as_any()
347                    .downcast_ref::<UInt32Array>()
348                    .context("downcast to UInt32Array failed")?;
349                let other_arr = arr
350                    .as_any()
351                    .downcast_ref()
352                    .context("downcast other to UInt32Array failed")?;
353                self.contains_primitive(self_arr, other_arr)
354            }
355            DataType::UInt64 => {
356                let self_arr = self
357                    .array
358                    .as_any()
359                    .downcast_ref::<UInt64Array>()
360                    .context("downcast to UInt64Array failed")?;
361                let other_arr = arr
362                    .as_any()
363                    .downcast_ref()
364                    .context("downcast other to UInt64Array failed")?;
365                self.contains_primitive(self_arr, other_arr)
366            }
367            DataType::Int8 => {
368                let self_arr = self
369                    .array
370                    .as_any()
371                    .downcast_ref::<Int8Array>()
372                    .context("downcast to Int8Array failed")?;
373                let other_arr = arr
374                    .as_any()
375                    .downcast_ref()
376                    .context("downcast other to Int8Array failed")?;
377                self.contains_primitive(self_arr, other_arr)
378            }
379            DataType::Int16 => {
380                let self_arr = self
381                    .array
382                    .as_any()
383                    .downcast_ref::<Int16Array>()
384                    .context("downcast to Int16Array failed")?;
385                let other_arr = arr
386                    .as_any()
387                    .downcast_ref()
388                    .context("downcast other to Int16Array failed")?;
389                self.contains_primitive(self_arr, other_arr)
390            }
391            DataType::Int32 => {
392                let self_arr = self
393                    .array
394                    .as_any()
395                    .downcast_ref::<Int32Array>()
396                    .context("downcast to Int32Array failed")?;
397                let other_arr = arr
398                    .as_any()
399                    .downcast_ref()
400                    .context("downcast other to Int32Array failed")?;
401                self.contains_primitive(self_arr, other_arr)
402            }
403            DataType::Int64 => {
404                let self_arr = self
405                    .array
406                    .as_any()
407                    .downcast_ref::<Int64Array>()
408                    .context("downcast to Int64Array failed")?;
409                let other_arr = arr
410                    .as_any()
411                    .downcast_ref()
412                    .context("downcast other to Int64Array failed")?;
413                self.contains_primitive(self_arr, other_arr)
414            }
415            DataType::Binary => {
416                let self_arr = self
417                    .array
418                    .as_any()
419                    .downcast_ref::<BinaryArray>()
420                    .context("downcast to BinaryArray failed")?;
421                let other_arr = arr
422                    .as_any()
423                    .downcast_ref()
424                    .context("downcast other to BinaryArray failed")?;
425                self.contains_bytes(self_arr, other_arr)
426            }
427            DataType::Utf8 => {
428                let self_arr = self
429                    .array
430                    .as_any()
431                    .downcast_ref::<StringArray>()
432                    .context("downcast to StringArray failed")?;
433                let other_arr = arr
434                    .as_any()
435                    .downcast_ref()
436                    .context("downcast other to StringArray failed")?;
437                self.contains_bytes(self_arr, other_arr)
438            }
439            _ => {
440                return Err(anyhow!("unsupported data type: {}", arr.data_type()));
441            }
442        };
443
444        let mut filter = filter;
445
446        if let Some(nulls) = arr.nulls() {
447            if nulls.null_count() > 0 {
448                let nulls = BooleanArray::from(nulls.inner().clone());
449                filter =
450                    compute::and(&filter, &nulls).context("apply null mask to contains filter")?;
451            }
452        }
453
454        Ok(filter)
455    }
456
457    fn contains_primitive<T: ArrowPrimitiveType>(
458        &self,
459        self_arr: &PrimitiveArray<T>,
460        other_arr: &PrimitiveArray<T>,
461    ) -> BooleanArray {
462        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
463
464        if let Some(ht) = self.hash_table.as_ref() {
465            let hash_one = |v: &T::Native| -> u64 { xxh3_64(v.to_byte_slice()) };
466
467            for v in other_arr.values() {
468                let c = ht
469                    .find(hash_one(v), |idx| unsafe {
470                        self_arr.values().get_unchecked(*idx) == v
471                    })
472                    .is_some();
473                filter.append_value(c);
474            }
475        } else {
476            for v in other_arr.values() {
477                filter.append_value(self_arr.values().iter().any(|x| x == v));
478            }
479        }
480
481        filter.finish()
482    }
483
484    fn contains_bytes<T: ByteArrayType<Offset = i32>>(
485        &self,
486        self_arr: &GenericByteArray<T>,
487        other_arr: &GenericByteArray<T>,
488    ) -> BooleanArray {
489        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
490
491        if let Some(ht) = self.hash_table.as_ref() {
492            for v in iter_byte_array_without_validity(other_arr) {
493                let c = ht
494                    .find(xxh3_64(v), |idx| unsafe {
495                        byte_array_get_unchecked(self_arr, *idx) == v
496                    })
497                    .is_some();
498                filter.append_value(c);
499            }
500        } else {
501            for v in iter_byte_array_without_validity(other_arr) {
502                filter.append_value(iter_byte_array_without_validity(self_arr).any(|x| x == v));
503            }
504        }
505
506        filter.finish()
507    }
508}
509
510/// Prefix matching filter for binary and string columns.
511///
512/// A row passes if its column value starts with any of the prefix values
513/// in the filter array.
514pub struct StartsWith {
515    array: Arc<dyn Array>,
516}
517
518impl StartsWith {
519    /// Creates a new prefix filter from a non-nullable array of prefix values.
520    pub fn new(array: Arc<dyn Array>) -> Result<Self> {
521        if array.is_nullable() {
522            return Err(anyhow!(
523                "cannot construct starts_with filter with a nullable array"
524            ));
525        }
526
527        Ok(Self { array })
528    }
529
530    fn starts_with(&self, arr: &dyn Array) -> Result<BooleanArray> {
531        if arr.data_type() != self.array.data_type() {
532            return Err(anyhow!(
533                "filter array is of type {} but array to be filtered is of type {}",
534                self.array.data_type(),
535                arr.data_type(),
536            ));
537        }
538        anyhow::ensure!(
539            !self.array.is_nullable(),
540            "filter array must not be nullable"
541        );
542
543        let mut filter = match *arr.data_type() {
544            DataType::Binary => {
545                let self_arr = self
546                    .array
547                    .as_any()
548                    .downcast_ref::<BinaryArray>()
549                    .context("downcast to BinaryArray failed")?;
550                let other_arr = arr
551                    .as_any()
552                    .downcast_ref()
553                    .context("downcast other to BinaryArray failed")?;
554                Self::starts_with_bytes(self_arr, other_arr)
555            }
556            DataType::Utf8 => {
557                let self_arr = self
558                    .array
559                    .as_any()
560                    .downcast_ref::<StringArray>()
561                    .context("downcast to StringArray failed")?;
562                let other_arr = arr
563                    .as_any()
564                    .downcast_ref()
565                    .context("downcast other to StringArray failed")?;
566                Self::starts_with_bytes(self_arr, other_arr)
567            }
568            _ => {
569                return Err(anyhow!("unsupported data type: {}", arr.data_type()));
570            }
571        };
572
573        if let Some(nulls) = arr.nulls() {
574            if nulls.null_count() > 0 {
575                let nulls = BooleanArray::from(nulls.inner().clone());
576                filter = compute::and(&filter, &nulls)
577                    .context("apply null mask to starts_with filter")?;
578            }
579        }
580
581        Ok(filter)
582    }
583
584    fn starts_with_bytes<T: ByteArrayType<Offset = i32>>(
585        self_arr: &GenericByteArray<T>,
586        other_arr: &GenericByteArray<T>,
587    ) -> BooleanArray {
588        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
589
590        for v in iter_byte_array_without_validity(other_arr) {
591            let mut found = false;
592            for prefix in iter_byte_array_without_validity(self_arr) {
593                if v.starts_with(prefix) {
594                    found = true;
595                    break;
596                }
597            }
598            filter.append_value(found);
599        }
600
601        filter.finish()
602    }
603}
604
605/// Unchecked byte access into a [`GenericByteArray`] — adapted from arrow-rs internals.
606#[expect(clippy::unwrap_used, reason = "i32 offsets always fit in isize/usize")]
607unsafe fn byte_array_get_unchecked<T: ByteArrayType<Offset = i32>>(
608    arr: &GenericByteArray<T>,
609    i: usize,
610) -> &[u8] {
611    let end = *arr.value_offsets().get_unchecked(i + 1);
612    let start = *arr.value_offsets().get_unchecked(i);
613
614    std::slice::from_raw_parts(
615        arr.value_data()
616            .as_ptr()
617            .offset(isize::try_from(start).unwrap()),
618        usize::try_from(end - start).unwrap(),
619    )
620}
621
622fn iter_byte_array_without_validity<T: ByteArrayType<Offset = i32>>(
623    arr: &GenericByteArray<T>,
624) -> impl Iterator<Item = &[u8]> {
625    (0..arr.len()).map(|i| unsafe { byte_array_get_unchecked(arr, i) })
626}
627
628/// Executes a query against a set of named tables, returning filtered and projected results.
629///
630/// Applies all selection filters in parallel (via rayon), OR-combines filters for the
631/// same table, then projects the requested fields. Tables not referenced by any
632/// selection are excluded from the output.
633pub fn run_query(
634    data: &BTreeMap<TableName, RecordBatch>,
635    query: &Query,
636) -> Result<BTreeMap<TableName, RecordBatch>> {
637    let filters = query
638        .selection
639        .par_iter()
640        .map(|(table_name, selections)| {
641            selections
642                .par_iter()
643                .enumerate()
644                .map(|(i, selection)| {
645                    run_table_selection(data, table_name, selection).with_context(|| {
646                        format!("run table selection no:{i} for table {table_name}")
647                    })
648                })
649                .collect::<Result<Vec<_>>>()
650        })
651        .collect::<Result<Vec<_>>>()?;
652
653    let data = select_fields(data, &query.fields).context("select fields")?;
654
655    data.par_iter()
656        .filter_map(|(table_name, table_data)| {
657            let mut combined_filter: Option<BooleanArray> = None;
658
659            for f in &filters {
660                for f in f {
661                    let Some(filter) = f.get(table_name) else {
662                        continue;
663                    };
664
665                    match combined_filter.as_ref() {
666                        Some(e) => {
667                            let f = compute::or(e, filter)
668                                .with_context(|| format!("combine filters for {table_name}"));
669                            let f = match f {
670                                Ok(v) => v,
671                                Err(err) => return Some(Err(err)),
672                            };
673                            combined_filter = Some(f);
674                        }
675                        None => {
676                            combined_filter = Some(filter.clone());
677                        }
678                    }
679                }
680            }
681
682            let combined_filter = combined_filter?;
683
684            let table_data = compute::filter_record_batch(table_data, &combined_filter)
685                .context("filter record batch");
686            let table_data = match table_data {
687                Ok(v) => v,
688                Err(err) => return Some(Err(err)),
689            };
690
691            Some(Ok((table_name.to_owned(), table_data)))
692        })
693        .collect()
694}
695
696/// Projects each table to include only the specified columns.
697pub fn select_fields(
698    data: &BTreeMap<TableName, RecordBatch>,
699    fields: &BTreeMap<TableName, Vec<FieldName>>,
700) -> Result<BTreeMap<TableName, RecordBatch>> {
701    let mut out = BTreeMap::new();
702
703    for (table_name, field_names) in fields {
704        let table_data = data
705            .get(table_name)
706            .with_context(|| format!("get data for table {table_name}"))?;
707
708        let indices = table_data
709            .schema_ref()
710            .fields()
711            .iter()
712            .enumerate()
713            .filter(|(_, field)| field_names.contains(field.name()))
714            .map(|(i, _)| i)
715            .collect::<Vec<usize>>();
716
717        let table_data = table_data
718            .project(&indices)
719            .with_context(|| format!("project table {table_name}"))?;
720        out.insert(table_name.to_owned(), table_data);
721    }
722
723    Ok(out)
724}
725
726fn run_table_selection(
727    data: &BTreeMap<TableName, RecordBatch>,
728    table_name: &str,
729    selection: &TableSelection,
730) -> Result<BTreeMap<TableName, BooleanArray>> {
731    let mut out = BTreeMap::new();
732
733    let table_data = data.get(table_name).context("get table data")?;
734    let mut combined_filter = None;
735    for (field_name, filter) in &selection.filters {
736        let col = table_data
737            .column_by_name(field_name)
738            .with_context(|| format!("get field {field_name}"))?;
739
740        let f = filter
741            .check(&col)
742            .with_context(|| format!("check filter for column {field_name}"))?;
743
744        match combined_filter {
745            Some(cf) => {
746                combined_filter = Some(
747                    compute::and(&cf, &f)
748                        .with_context(|| format!("combine filter for column {field_name}"))?,
749                );
750            }
751            None => {
752                combined_filter = Some(f);
753            }
754        }
755    }
756
757    let combined_filter = match combined_filter {
758        Some(cf) => cf,
759        None => BooleanArray::new(BooleanBuffer::new_set(table_data.num_rows()), None),
760    };
761
762    out.insert(table_name.to_owned(), combined_filter.clone());
763
764    let mut filtered_cache = BTreeMap::new();
765
766    for (i, inc) in selection.include.iter().enumerate() {
767        if inc.other_table_field_names.len() != inc.field_names.len() {
768            return Err(anyhow!(
769                "field names are different for self table and other table while processing include no: {}. {} {}",
770                i,
771                inc.field_names.len(),
772                inc.other_table_field_names.len(),
773            ));
774        }
775
776        let other_table_data = data.get(&inc.other_table_name).with_context(|| {
777            format!(
778                "get data for table {} as other table data",
779                inc.other_table_name
780            )
781        })?;
782
783        let self_arr = columns_to_binary_array(table_data, &inc.field_names)
784            .context("get row format binary arr for self")?;
785
786        let contains = match filtered_cache.entry(inc.field_names.clone()) {
787            Entry::Vacant(entry) => {
788                let self_arr = compute::filter(&self_arr, &combined_filter)
789                    .context("apply combined filter to self arr")?;
790                let contains =
791                    Contains::new(Arc::new(self_arr)).context("create contains filter")?;
792                let contains = Arc::new(contains);
793                entry.insert(Arc::clone(&contains));
794                contains
795            }
796            Entry::Occupied(entry) => Arc::clone(entry.get()),
797        };
798
799        let other_arr = columns_to_binary_array(other_table_data, &inc.other_table_field_names)
800            .with_context(|| {
801                format!(
802                    "get row format binary arr for other table {}",
803                    inc.other_table_name
804                )
805            })?;
806
807        let f = contains
808            .contains(&other_arr)
809            .with_context(|| format!("run contains for other table {}", inc.other_table_name))?;
810
811        match out.entry(inc.other_table_name.clone()) {
812            Entry::Vacant(entry) => {
813                entry.insert(f);
814            }
815            Entry::Occupied(mut entry) => {
816                let new = compute::or(entry.get(), &f).with_context(|| {
817                    format!("or include filters for table {}", inc.other_table_name)
818                })?;
819                entry.insert(new);
820            }
821        }
822    }
823
824    Ok(out)
825}
826
827fn columns_to_binary_array(
828    table_data: &RecordBatch,
829    column_names: &[String],
830) -> Result<BinaryArray> {
831    let fields = column_names
832        .iter()
833        .map(|field_name| {
834            let f = table_data
835                .schema_ref()
836                .field_with_name(field_name)
837                .with_context(|| format!("get field {field_name} from schema"))?;
838            Ok(SortField::new(f.data_type().clone()))
839        })
840        .collect::<Result<Vec<_>>>()?;
841    let conv = RowConverter::new(fields).context("create row converter")?;
842
843    let columns = column_names
844        .iter()
845        .map(|field_name| {
846            let c = table_data
847                .column_by_name(field_name)
848                .with_context(|| format!("get data for column {field_name}"))?;
849            let c = Arc::clone(c);
850            Ok(c)
851        })
852        .collect::<Result<Vec<_>>>()?;
853
854    let rows = conv
855        .convert_columns(&columns)
856        .context("convert columns to row format")?;
857    let out = rows
858        .try_into_binary()
859        .context("convert row format to binary array")?;
860
861    Ok(out)
862}
863
// End-to-end tests that drive `run_query` over small in-memory record batches.
#[cfg(test)]
mod tests {
    use arrow::{
        array::AsArray,
        datatypes::{Field, Schema},
    };

    use super::*;

    /// Wraps string values in a type-erased UTF-8 column.
    fn utf8_column(values: &[&str]) -> Arc<dyn Array> {
        Arc::new(StringArray::from_iter_values(values.iter().copied()))
    }

    /// Wraps unsigned integers in a type-erased `UInt64` column.
    fn u64_column(values: &[u64]) -> Arc<dyn Array> {
        Arc::new(UInt64Array::from_iter(values.iter().copied()))
    }

    /// Wraps byte strings in a type-erased binary column.
    fn binary_column(values: &[&[u8]]) -> Arc<dyn Array> {
        Arc::new(BinaryArray::from_iter_values(values.iter().copied()))
    }

    /// Assembles a [`RecordBatch`] from `(name, data type, values)` triples.
    /// Every field is declared nullable.
    fn make_batch(columns: Vec<(&str, DataType, Arc<dyn Array>)>) -> RecordBatch {
        let mut fields = Vec::with_capacity(columns.len());
        let mut arrays = Vec::with_capacity(columns.len());
        for (name, data_type, values) in columns {
            fields.push(Arc::new(Field::new(name, data_type, true)));
            arrays.push(values);
        }
        RecordBatch::try_new(Arc::new(Schema::new(fields)), arrays).unwrap()
    }

    /// Filters `team_a` by name, then pulls in extra rows through two includes: one into
    /// `team_b` keyed on (age, height) and one back into `team_a` keyed on height alone.
    #[test]
    fn basic_test_tiders_query() {
        let team_a = make_batch(vec![
            (
                "name",
                DataType::Utf8,
                utf8_column(&["kamil", "mahmut", "qwe", "kazim"]),
            ),
            ("age", DataType::UInt64, u64_column(&[11, 12, 13, 31])),
            ("height", DataType::UInt64, u64_column(&[50, 60, 70, 60])),
        ]);
        let team_b = make_batch(vec![
            (
                "name2",
                DataType::Utf8,
                utf8_column(&["yusuf", "abuzer", "asd"]),
            ),
            ("age2", DataType::UInt64, u64_column(&[11, 12, 13])),
            ("height2", DataType::UInt64, u64_column(&[50, 61, 70])),
        ]);

        let name_filter = Filter::Contains(
            Contains::new(Arc::new(StringArray::from_iter_values(["kamil", "mahmut"]))).unwrap(),
        );
        let into_team_b = Include {
            field_names: vec!["age".to_owned(), "height".to_owned()],
            other_table_field_names: vec!["age2".to_owned(), "height2".to_owned()],
            other_table_name: "team_b".to_owned(),
        };
        // Self-include: rows of team_a sharing a height with an already selected row.
        let into_team_a = Include {
            field_names: vec!["height".to_owned()],
            other_table_field_names: vec!["height".to_owned()],
            other_table_name: "team_a".to_owned(),
        };
        let team_a_selection = TableSelection {
            filters: [("name".to_owned(), name_filter)].into_iter().collect(),
            include: vec![into_team_b, into_team_a],
        };

        let query = Query {
            fields: [
                ("team_a".to_owned(), vec!["name".to_owned()]),
                ("team_b".to_owned(), vec!["name2".to_owned()]),
            ]
            .into_iter()
            .collect(),
            selection: Arc::new(
                [("team_a".to_owned(), vec![team_a_selection])]
                    .into_iter()
                    .collect(),
            ),
        };

        let data = BTreeMap::from([("team_a".to_owned(), team_a), ("team_b".to_owned(), team_b)]);

        let res = run_query(&data, &query).unwrap();

        let out_a = res.get("team_a").unwrap();
        let out_b = res.get("team_b").unwrap();

        assert_eq!(res.len(), 2);

        let name = out_a.column_by_name("name").unwrap();
        let name2 = out_b.column_by_name("name2").unwrap();

        assert_eq!(out_a.num_columns(), 1);
        assert_eq!(out_b.num_columns(), 1);

        // "kazim" comes in through the height self-include, "yusuf" through (age, height).
        assert_eq!(
            name.as_string::<i32>(),
            &StringArray::from_iter_values(["kamil", "mahmut", "kazim"])
        );
        assert_eq!(
            name2.as_string::<i32>(),
            &StringArray::from_iter_values(["yusuf"])
        );
    }

    /// Applies a `StartsWith` prefix filter to a UTF-8 column and a binary column at once.
    #[test]
    fn test_starts_with_filter() {
        let batch = make_batch(vec![
            (
                "name",
                DataType::Utf8,
                utf8_column(&["hello", "world", "helloworld", "goodbye", "hell"]),
            ),
            (
                "binary",
                DataType::Binary,
                binary_column(&[b"hello", b"world", b"hepto", b"grace", b"heheh"]),
            ),
        ]);

        let name_prefix = Filter::StartsWith(
            StartsWith::new(Arc::new(StringArray::from_iter_values(["he"]))).unwrap(),
        );
        let binary_prefix = Filter::StartsWith(
            StartsWith::new(Arc::new(BinaryArray::from_iter_values([b"he"]))).unwrap(),
        );
        let selection = TableSelection {
            filters: [
                ("name".to_owned(), name_prefix),
                ("binary".to_owned(), binary_prefix),
            ]
            .into_iter()
            .collect(),
            include: vec![],
        };

        let query = Query {
            fields: [(
                "data".to_owned(),
                vec!["name".to_owned(), "binary".to_owned()],
            )]
            .into_iter()
            .collect(),
            selection: Arc::new([("data".to_owned(), vec![selection])].into_iter().collect()),
        };

        let data = BTreeMap::from([("data".to_owned(), batch)]);

        let res = run_query(&data, &query).unwrap();
        let filtered = res.get("data").unwrap();

        let name = filtered.column_by_name("name").unwrap();
        let binary = filtered.column_by_name("binary").unwrap();
        assert_eq!(
            name.as_string::<i32>(),
            &StringArray::from_iter_values(["hello", "helloworld", "hell"])
        );
        assert_eq!(
            binary.as_binary::<i32>(),
            &BinaryArray::from_iter_values([b"hello", b"hepto", b"heheh"])
        );
    }
}