// tiders_cast/lib.rs

//! # tiders-cast
//!
//! Type casting and encoding/decoding utilities for Arrow columns and schemas.
//!
//! Provides batch-level and column-level operations for:
//! - **Type casting**: Convert columns between Arrow data types by name or by source type.
//! - **Hex encoding/decoding**: Binary columns to/from hex strings (with optional `0x` prefix).
//! - **Base58 encoding/decoding**: Binary columns to/from Base58 strings (Bitcoin alphabet).
//! - **U256 conversion**: Between `Decimal256(76,0)` and big-endian binary representations.

#![allow(clippy::manual_div_ceil)]

use std::sync::Arc;

use anyhow::{Context, Result};
use arrow::{
    array::{
        builder, Array, BinaryArray, Decimal256Array, GenericBinaryArray, GenericStringArray,
        LargeBinaryArray, OffsetSizeTrait, RecordBatch,
    },
    compute::CastOptions,
    datatypes::{DataType, Field, Schema},
};
24
25/// Casts columns according to given (column name, target data type) pairs.
26///
27/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
28/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
29pub fn cast<S: AsRef<str>>(
30    map: &[(S, DataType)],
31    data: &RecordBatch,
32    allow_cast_fail: bool,
33) -> Result<RecordBatch> {
34    let schema = cast_schema(map, data.schema_ref()).context("cast schema")?;
35
36    let mut arrays = Vec::with_capacity(data.num_columns());
37
38    let cast_opt = CastOptions {
39        safe: allow_cast_fail,
40        ..Default::default()
41    };
42
43    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
44        let cast_target = map.iter().find(|x| x.0.as_ref() == field.name());
45
46        let col = match cast_target {
47            Some(tgt) => {
48                // allow precision loss for decimal types into floating point types
49                if matches!(
50                    col.data_type(),
51                    DataType::Decimal256(..) | DataType::Decimal128(..)
52                ) && tgt.1.is_floating()
53                {
54                    let string_col =
55                        arrow::compute::cast_with_options(col, &DataType::Utf8, &cast_opt)
56                            .with_context(|| {
57                                format!(
58                            "Failed when casting column '{}' to string as intermediate step",
59                            field.name()
60                        )
61                            })?;
62                    Arc::new(
63                        arrow::compute::cast_with_options(&string_col, &tgt.1, &cast_opt)
64                            .with_context(|| {
65                                format!(
66                                    "Failed when casting column '{}' to {:?}",
67                                    field.name(),
68                                    tgt.1
69                                )
70                            })?,
71                    )
72                } else {
73                    Arc::new(
74                        arrow::compute::cast_with_options(col, &tgt.1, &cast_opt).with_context(
75                            || {
76                                format!(
77                                    "Failed when casting column '{}' from {:?} to {:?}",
78                                    field.name(),
79                                    col.data_type(),
80                                    tgt.1
81                                )
82                            },
83                        )?,
84                    )
85                }
86            }
87            None => col.clone(),
88        };
89
90        arrays.push(col);
91    }
92
93    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
94
95    Ok(batch)
96}
97
98/// Casts column types according to given (column name, target data type) pairs.
99pub fn cast_schema<S: AsRef<str>>(map: &[(S, DataType)], schema: &Schema) -> Result<Schema> {
100    let mut fields = schema.fields().to_vec();
101
102    for f in &mut fields {
103        let cast_target = map.iter().find(|x| x.0.as_ref() == f.name());
104
105        if let Some(tgt) = cast_target {
106            *f = Arc::new(Field::new(f.name(), tgt.1.clone(), f.is_nullable()));
107        }
108    }
109
110    Ok(Schema::new(fields))
111}
112
113/// Casts all columns with from_type to to_type.
114///
115/// Returns error if casting a row fails and `allow_cast_fail` is set to `false`.
116/// Writes `null` to output if casting a row fails and `allow_cast_fail` is set to `true`.
117pub fn cast_by_type(
118    data: &RecordBatch,
119    from_type: &DataType,
120    to_type: &DataType,
121    allow_cast_fail: bool,
122) -> Result<RecordBatch> {
123    let schema =
124        cast_schema_by_type(data.schema_ref(), from_type, to_type).context("cast schema")?;
125
126    let mut arrays = Vec::with_capacity(data.num_columns());
127
128    let cast_opt = CastOptions {
129        safe: allow_cast_fail,
130        ..Default::default()
131    };
132
133    for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
134        let col = if col.data_type() == from_type {
135            // allow precision loss for decimal types into floating point types
136            if matches!(
137                col.data_type(),
138                DataType::Decimal256(..) | DataType::Decimal128(..)
139            ) && to_type.is_floating()
140            {
141                let string_col = arrow::compute::cast_with_options(col, &DataType::Utf8, &cast_opt)
142                    .with_context(|| {
143                        format!(
144                            "Failed when casting_by_type column '{}' to string as intermediate step",
145                            field.name()
146                        )
147                    })?;
148                Arc::new(
149                    arrow::compute::cast_with_options(&string_col, to_type, &cast_opt)
150                        .with_context(|| {
151                            format!(
152                                "Failed when casting_by_type column '{}' to {:?}",
153                                field.name(),
154                                to_type
155                            )
156                        })?,
157                )
158            } else {
159                Arc::new(
160                    arrow::compute::cast_with_options(col, to_type, &cast_opt).with_context(
161                        || {
162                            format!(
163                                "Failed when casting_by_type column '{}' to {:?}",
164                                field.name(),
165                                to_type
166                            )
167                        },
168                    )?,
169                )
170            }
171        } else {
172            col.clone()
173        };
174
175        arrays.push(col);
176    }
177
178    let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
179
180    Ok(batch)
181}
182
183/// Casts columns with from_type to to_type
184pub fn cast_schema_by_type(
185    schema: &Schema,
186    from_type: &DataType,
187    to_type: &DataType,
188) -> Result<Schema> {
189    let mut fields = schema.fields().to_vec();
190
191    for f in &mut fields {
192        if f.data_type() == from_type {
193            *f = Arc::new(Field::new(f.name(), to_type.clone(), f.is_nullable()));
194        }
195    }
196
197    Ok(Schema::new(fields))
198}
199
200#[expect(
201    clippy::unwrap_used,
202    reason = "downcast is guaranteed by prior data type check"
203)]
204/// Encodes all Binary and LargeBinary columns in the batch to Base58 (Bitcoin alphabet) strings.
205pub fn base58_encode(data: &RecordBatch) -> Result<RecordBatch> {
206    let schema = schema_binary_to_string(data.schema_ref());
207    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
208
209    for col in data.columns() {
210        if col.data_type() == &DataType::Binary {
211            columns.push(Arc::new(base58_encode_column(
212                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
213            )));
214        } else if col.data_type() == &DataType::LargeBinary {
215            columns.push(Arc::new(base58_encode_column(
216                col.as_any().downcast_ref::<LargeBinaryArray>().unwrap(),
217            )));
218        } else {
219            columns.push(col.clone());
220        }
221    }
222
223    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
224}
225
226/// Encodes a single binary column to Base58 strings.
227pub fn base58_encode_column<I: OffsetSizeTrait>(
228    col: &GenericBinaryArray<I>,
229) -> GenericStringArray<I> {
230    let mut arr = builder::GenericStringBuilder::<I>::with_capacity(
231        col.len(),
232        (col.value_data().len() + 2) * 2,
233    );
234
235    for v in col {
236        match v {
237            Some(v) => {
238                let v = bs58::encode(v)
239                    .with_alphabet(bs58::Alphabet::BITCOIN)
240                    .into_string();
241                arr.append_value(v);
242            }
243            None => arr.append_null(),
244        }
245    }
246
247    arr.finish()
248}
249
250#[expect(
251    clippy::unwrap_used,
252    reason = "downcast is guaranteed by prior data type check"
253)]
254/// Encodes all Binary and LargeBinary columns in the batch to hex strings.
255///
256/// When `PREFIXED` is `true`, output strings include the `0x` prefix.
257pub fn hex_encode<const PREFIXED: bool>(data: &RecordBatch) -> Result<RecordBatch> {
258    let schema = schema_binary_to_string(data.schema_ref());
259    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
260
261    for col in data.columns() {
262        if col.data_type() == &DataType::Binary {
263            columns.push(Arc::new(hex_encode_column::<PREFIXED, i32>(
264                col.as_any().downcast_ref::<BinaryArray>().unwrap(),
265            )));
266        } else if col.data_type() == &DataType::LargeBinary {
267            columns.push(Arc::new(hex_encode_column::<PREFIXED, i64>(
268                col.as_any().downcast_ref::<LargeBinaryArray>().unwrap(),
269            )));
270        } else {
271            columns.push(col.clone());
272        }
273    }
274
275    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
276}
277
278/// Encodes a single binary column to hex strings.
279///
280/// When `PREFIXED` is `true`, output strings include the `0x` prefix.
281pub fn hex_encode_column<const PREFIXED: bool, I: OffsetSizeTrait>(
282    col: &GenericBinaryArray<I>,
283) -> GenericStringArray<I> {
284    let mut arr = builder::GenericStringBuilder::<I>::with_capacity(
285        col.len(),
286        (col.value_data().len() + 2) * 2,
287    );
288
289    for v in col {
290        match v {
291            Some(v) => {
292                // TODO: avoid allocation here and use a scratch buffer to encode hex into or write to arrow buffer
293                // directly somehow.
294                let v = if PREFIXED {
295                    format!("0x{}", faster_hex::hex_string(v))
296                } else {
297                    faster_hex::hex_string(v)
298                };
299
300                arr.append_value(v);
301            }
302            None => arr.append_null(),
303        }
304    }
305
306    arr.finish()
307}
308
309/// Converts binary fields to string in the schema
310///
311/// Intended to be used with encode hex functions
312pub fn schema_binary_to_string(schema: &Schema) -> Schema {
313    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
314
315    for f in schema.fields() {
316        if f.data_type() == &DataType::Binary {
317            fields.push(Arc::new(Field::new(
318                f.name().clone(),
319                DataType::Utf8,
320                f.is_nullable(),
321            )));
322        } else if f.data_type() == &DataType::LargeBinary {
323            fields.push(Arc::new(Field::new(
324                f.name().clone(),
325                DataType::LargeUtf8,
326                f.is_nullable(),
327            )));
328        } else {
329            fields.push(f.clone());
330        }
331    }
332
333    Schema::new(fields)
334}
335
336/// Converts decimal256 fields to binary in the schema
337///
338/// Intended to be used with u256_to_binary function
339pub fn schema_decimal256_to_binary(schema: &Schema) -> Schema {
340    let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
341
342    for f in schema.fields() {
343        if f.data_type() == &DataType::Decimal256(76, 0) {
344            fields.push(Arc::new(Field::new(
345                f.name().clone(),
346                DataType::Binary,
347                f.is_nullable(),
348            )));
349        } else {
350            fields.push(f.clone());
351        }
352    }
353
354    Schema::new(fields)
355}
356
357/// Decodes a Base58-encoded string column to binary.
358pub fn base58_decode_column<I: OffsetSizeTrait>(
359    col: &GenericStringArray<I>,
360) -> Result<GenericBinaryArray<I>> {
361    let mut arr =
362        builder::GenericBinaryBuilder::<I>::with_capacity(col.len(), col.value_data().len() / 2);
363
364    for v in col {
365        match v {
366            // TODO: this should be optimized by removing allocations if needed
367            Some(v) => {
368                let v = bs58::decode(v)
369                    .with_alphabet(bs58::Alphabet::BITCOIN)
370                    .into_vec()
371                    .context("bs58 decode")?;
372                arr.append_value(v);
373            }
374            None => arr.append_null(),
375        }
376    }
377
378    Ok(arr.finish())
379}
380
381/// Decodes a hex-encoded string column to binary.
382///
383/// When `PREFIXED` is `true`, expects and strips the `0x` prefix from each value.
384pub fn hex_decode_column<const PREFIXED: bool, I: OffsetSizeTrait>(
385    col: &GenericStringArray<I>,
386) -> Result<GenericBinaryArray<I>> {
387    let mut arr =
388        builder::GenericBinaryBuilder::<I>::with_capacity(col.len(), col.value_data().len() / 2);
389
390    for v in col {
391        match v {
392            // TODO: this should be optimized by removing allocations if needed
393            Some(v) => {
394                let v = v.as_bytes();
395                let v = if PREFIXED {
396                    v.get(2..).context("index into prefix hex encoded value")?
397                } else {
398                    v
399                };
400
401                let len = v.len();
402                let mut dst = vec![0; (len + 1) / 2];
403
404                faster_hex::hex_decode(v, &mut dst).context("hex decode")?;
405
406                arr.append_value(dst);
407            }
408            None => arr.append_null(),
409        }
410    }
411
412    Ok(arr.finish())
413}
414
415/// Converts a big-endian binary column (up to 32 bytes) to Decimal256(76,0).
416pub fn u256_column_from_binary<I: OffsetSizeTrait>(
417    col: &GenericBinaryArray<I>,
418) -> Result<Decimal256Array> {
419    let mut arr = builder::Decimal256Builder::with_capacity(col.len());
420
421    for v in col {
422        match v {
423            Some(v) => {
424                let num = ruint::aliases::U256::try_from_be_slice(v).context("parse ruint u256")?;
425                let num = alloy_primitives::I256::try_from(num)
426                    .with_context(|| format!("u256 to i256. val was {num}"))?;
427
428                let val = arrow::datatypes::i256::from_be_bytes(num.to_be_bytes::<32>());
429                arr.append_value(val);
430            }
431            None => arr.append_null(),
432        }
433    }
434
435    Ok(arr
436        .with_precision_and_scale(76, 0)
437        .context("set precision and scale for Decimal256")?
438        .finish())
439}
440
441/// Converts a Decimal256(76,0) column to trimmed big-endian binary.
442pub fn u256_column_to_binary(col: &Decimal256Array) -> Result<BinaryArray> {
443    let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
444
445    for v in col {
446        match v {
447            Some(v) => {
448                let num = alloy_primitives::I256::from_be_bytes::<32>(v.to_be_bytes());
449                let num = ruint::aliases::U256::try_from(num).context("convert i256 to u256")?;
450                arr.append_value(num.to_be_bytes_trimmed_vec());
451            }
452            None => {
453                arr.append_null();
454            }
455        }
456    }
457
458    Ok(arr.finish())
459}
460
461/// Converts all Decimal256 (U256) columns in the batch to big endian binary values
462#[expect(
463    clippy::unwrap_used,
464    reason = "downcast is guaranteed by prior data type check"
465)]
466pub fn u256_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
467    let schema = schema_decimal256_to_binary(data.schema_ref());
468    let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
469
470    for (i, col) in data.columns().iter().enumerate() {
471        if col.data_type() == &DataType::Decimal256(76, 0) {
472            let col = col.as_any().downcast_ref::<Decimal256Array>().unwrap();
473            let x = u256_column_to_binary(col)
474                .with_context(|| format!("col {} to binary", data.schema().fields()[i].name()))?;
475            columns.push(Arc::new(x));
476        } else {
477            columns.push(col.clone());
478        }
479    }
480
481    RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use std::fs::File;

    // Manual integration test: reads a local parquet file, casts a handful of
    // columns, and writes the result back out. Marked `#[ignore]` because it
    // depends on a `data.parquet` file existing in the working directory.
    #[test]
    #[ignore]
    fn test_cast() {
        use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

        let builder =
            ParquetRecordBatchReaderBuilder::try_new(File::open("data.parquet").unwrap()).unwrap();
        let mut reader = builder.build().unwrap();
        // Only the first batch from the file is exercised.
        let table = reader.next().unwrap().unwrap();

        // Mixed targets: decimals, floats, and ints — exercising both the
        // direct cast path and the decimal-to-float (via Utf8) path in `cast`.
        let type_mappings = vec![
            ("amount0In", DataType::Decimal128(15, 0)),
            ("amount1In", DataType::Float32),
            ("amount0Out", DataType::Float64),
            ("amount1Out", DataType::Decimal128(38, 0)),
            ("timestamp", DataType::Int64),
        ];

        let result = cast(&type_mappings, &table, true).unwrap();

        // Save the filtered instructions to a new parquet file
        let mut file = File::create("result.parquet").unwrap();
        let mut writer =
            parquet::arrow::ArrowWriter::try_new(&mut file, result.schema(), None).unwrap();
        writer.write(&result).unwrap();
        writer.close().unwrap();
    }
}