1#![allow(clippy::manual_div_ceil)]
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use anyhow::{Context, Result};
17use arrow::{
18 array::{
19 builder, make_array, Array, BinaryArray, Decimal128Array, Decimal256Array,
20 FixedSizeListArray, GenericBinaryArray, GenericStringArray, Int32Array, LargeBinaryArray,
21 OffsetSizeTrait, RecordBatch, StructArray,
22 },
23 buffer::NullBuffer,
24 compute::CastOptions,
25 datatypes::{DataType, Field, Schema},
26};
27
28pub fn cast<S: AsRef<str>>(
33 map: &[(S, DataType)],
34 data: &RecordBatch,
35 allow_cast_fail: bool,
36) -> Result<RecordBatch> {
37 let schema = cast_schema(map, data.schema_ref()).context("cast schema")?;
38
39 let mut arrays = Vec::with_capacity(data.num_columns());
40
41 let cast_opt = CastOptions {
42 safe: allow_cast_fail,
43 ..Default::default()
44 };
45
46 for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
47 let cast_target = map.iter().find(|x| x.0.as_ref() == field.name());
48
49 let col = match cast_target {
50 Some(tgt) => {
51 if matches!(
53 col.data_type(),
54 DataType::Decimal256(..) | DataType::Decimal128(..)
55 ) && tgt.1.is_floating()
56 {
57 let string_col =
58 arrow::compute::cast_with_options(col, &DataType::Utf8, &cast_opt)
59 .with_context(|| {
60 format!(
61 "Failed when casting column '{}' to string as intermediate step",
62 field.name()
63 )
64 })?;
65 Arc::new(
66 arrow::compute::cast_with_options(&string_col, &tgt.1, &cast_opt)
67 .with_context(|| {
68 format!(
69 "Failed when casting column '{}' to {:?}",
70 field.name(),
71 tgt.1
72 )
73 })?,
74 )
75 } else {
76 Arc::new(
77 arrow::compute::cast_with_options(col, &tgt.1, &cast_opt).with_context(
78 || {
79 format!(
80 "Failed when casting column '{}' from {:?} to {:?}",
81 field.name(),
82 col.data_type(),
83 tgt.1
84 )
85 },
86 )?,
87 )
88 }
89 }
90 None => col.clone(),
91 };
92
93 arrays.push(col);
94 }
95
96 let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
97
98 Ok(batch)
99}
100
101pub fn cast_schema<S: AsRef<str>>(map: &[(S, DataType)], schema: &Schema) -> Result<Schema> {
103 let mut fields = schema.fields().to_vec();
104
105 for f in &mut fields {
106 let cast_target = map.iter().find(|x| x.0.as_ref() == f.name());
107
108 if let Some(tgt) = cast_target {
109 *f = Arc::new(Field::new(f.name(), tgt.1.clone(), f.is_nullable()));
110 }
111 }
112
113 Ok(Schema::new(fields))
114}
115
116pub fn cast_by_type(
121 data: &RecordBatch,
122 from_type: &DataType,
123 to_type: &DataType,
124 allow_cast_fail: bool,
125) -> Result<RecordBatch> {
126 let schema =
127 cast_schema_by_type(data.schema_ref(), from_type, to_type).context("cast schema")?;
128
129 let mut arrays = Vec::with_capacity(data.num_columns());
130
131 let cast_opt = CastOptions {
132 safe: allow_cast_fail,
133 ..Default::default()
134 };
135
136 for (col, field) in data.columns().iter().zip(data.schema_ref().fields().iter()) {
137 let col = if col.data_type() == from_type {
138 if matches!(
140 col.data_type(),
141 DataType::Decimal256(..) | DataType::Decimal128(..)
142 ) && to_type.is_floating()
143 {
144 let string_col = arrow::compute::cast_with_options(col, &DataType::Utf8, &cast_opt)
145 .with_context(|| {
146 format!(
147 "Failed when casting_by_type column '{}' to string as intermediate step",
148 field.name()
149 )
150 })?;
151 Arc::new(
152 arrow::compute::cast_with_options(&string_col, to_type, &cast_opt)
153 .with_context(|| {
154 format!(
155 "Failed when casting_by_type column '{}' to {:?}",
156 field.name(),
157 to_type
158 )
159 })?,
160 )
161 } else {
162 Arc::new(
163 arrow::compute::cast_with_options(col, to_type, &cast_opt).with_context(
164 || {
165 format!(
166 "Failed when casting_by_type column '{}' to {:?}",
167 field.name(),
168 to_type
169 )
170 },
171 )?,
172 )
173 }
174 } else {
175 col.clone()
176 };
177
178 arrays.push(col);
179 }
180
181 let batch = RecordBatch::try_new(Arc::new(schema), arrays).context("construct record batch")?;
182
183 Ok(batch)
184}
185
186pub fn cast_schema_by_type(
188 schema: &Schema,
189 from_type: &DataType,
190 to_type: &DataType,
191) -> Result<Schema> {
192 let mut fields = schema.fields().to_vec();
193
194 for f in &mut fields {
195 if f.data_type() == from_type {
196 *f = Arc::new(Field::new(f.name(), to_type.clone(), f.is_nullable()));
197 }
198 }
199
200 Ok(Schema::new(fields))
201}
202
203#[expect(
204 clippy::unwrap_used,
205 reason = "downcast is guaranteed by prior data type check"
206)]
207pub fn base58_encode(data: &RecordBatch) -> Result<RecordBatch> {
209 let schema = schema_binary_to_string(data.schema_ref());
210 let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
211
212 for col in data.columns() {
213 if col.data_type() == &DataType::Binary {
214 columns.push(Arc::new(base58_encode_column(
215 col.as_any().downcast_ref::<BinaryArray>().unwrap(),
216 )));
217 } else if col.data_type() == &DataType::LargeBinary {
218 columns.push(Arc::new(base58_encode_column(
219 col.as_any().downcast_ref::<LargeBinaryArray>().unwrap(),
220 )));
221 } else {
222 columns.push(col.clone());
223 }
224 }
225
226 RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
227}
228
229pub fn base58_encode_column<I: OffsetSizeTrait>(
231 col: &GenericBinaryArray<I>,
232) -> GenericStringArray<I> {
233 let mut arr = builder::GenericStringBuilder::<I>::with_capacity(
234 col.len(),
235 (col.value_data().len() + 2) * 2,
236 );
237
238 for v in col {
239 match v {
240 Some(v) => {
241 let v = bs58::encode(v)
242 .with_alphabet(bs58::Alphabet::BITCOIN)
243 .into_string();
244 arr.append_value(v);
245 }
246 None => arr.append_null(),
247 }
248 }
249
250 arr.finish()
251}
252
253#[expect(
254 clippy::unwrap_used,
255 reason = "downcast is guaranteed by prior data type check"
256)]
257pub fn hex_encode<const PREFIXED: bool>(data: &RecordBatch) -> Result<RecordBatch> {
261 let schema = schema_binary_to_string(data.schema_ref());
262 let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
263
264 for col in data.columns() {
265 if col.data_type() == &DataType::Binary {
266 columns.push(Arc::new(hex_encode_column::<PREFIXED, i32>(
267 col.as_any().downcast_ref::<BinaryArray>().unwrap(),
268 )));
269 } else if col.data_type() == &DataType::LargeBinary {
270 columns.push(Arc::new(hex_encode_column::<PREFIXED, i64>(
271 col.as_any().downcast_ref::<LargeBinaryArray>().unwrap(),
272 )));
273 } else {
274 columns.push(col.clone());
275 }
276 }
277
278 RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
279}
280
281pub fn hex_encode_column<const PREFIXED: bool, I: OffsetSizeTrait>(
285 col: &GenericBinaryArray<I>,
286) -> GenericStringArray<I> {
287 let mut arr = builder::GenericStringBuilder::<I>::with_capacity(
288 col.len(),
289 (col.value_data().len() + 2) * 2,
290 );
291
292 for v in col {
293 match v {
294 Some(v) => {
295 let v = if PREFIXED {
298 format!("0x{}", faster_hex::hex_string(v))
299 } else {
300 faster_hex::hex_string(v)
301 };
302
303 arr.append_value(v);
304 }
305 None => arr.append_null(),
306 }
307 }
308
309 arr.finish()
310}
311
312pub fn schema_binary_to_string(schema: &Schema) -> Schema {
316 let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
317
318 for f in schema.fields() {
319 if f.data_type() == &DataType::Binary {
320 fields.push(Arc::new(Field::new(
321 f.name().clone(),
322 DataType::Utf8,
323 f.is_nullable(),
324 )));
325 } else if f.data_type() == &DataType::LargeBinary {
326 fields.push(Arc::new(Field::new(
327 f.name().clone(),
328 DataType::LargeUtf8,
329 f.is_nullable(),
330 )));
331 } else {
332 fields.push(f.clone());
333 }
334 }
335
336 Schema::new(fields)
337}
338
339pub fn schema_decimal256_to_binary(schema: &Schema) -> Schema {
343 let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
344
345 for f in schema.fields() {
346 if f.data_type() == &DataType::Decimal256(76, 0) {
347 fields.push(Arc::new(Field::new(
348 f.name().clone(),
349 DataType::Binary,
350 f.is_nullable(),
351 )));
352 } else {
353 fields.push(f.clone());
354 }
355 }
356
357 Schema::new(fields)
358}
359
360pub fn base58_decode_column<I: OffsetSizeTrait>(
362 col: &GenericStringArray<I>,
363) -> Result<GenericBinaryArray<I>> {
364 let mut arr =
365 builder::GenericBinaryBuilder::<I>::with_capacity(col.len(), col.value_data().len() / 2);
366
367 for v in col {
368 match v {
369 Some(v) => {
371 let v = bs58::decode(v)
372 .with_alphabet(bs58::Alphabet::BITCOIN)
373 .into_vec()
374 .context("bs58 decode")?;
375 arr.append_value(v);
376 }
377 None => arr.append_null(),
378 }
379 }
380
381 Ok(arr.finish())
382}
383
384pub fn hex_decode_column<const PREFIXED: bool, I: OffsetSizeTrait>(
388 col: &GenericStringArray<I>,
389) -> Result<GenericBinaryArray<I>> {
390 let mut arr =
391 builder::GenericBinaryBuilder::<I>::with_capacity(col.len(), col.value_data().len() / 2);
392
393 for v in col {
394 match v {
395 Some(v) => {
397 let v = v.as_bytes();
398 let v = if PREFIXED {
399 v.get(2..).context("index into prefix hex encoded value")?
400 } else {
401 v
402 };
403
404 let len = v.len();
405 let mut dst = vec![0; (len + 1) / 2];
406
407 faster_hex::hex_decode(v, &mut dst).context("hex decode")?;
408
409 arr.append_value(dst);
410 }
411 None => arr.append_null(),
412 }
413 }
414
415 Ok(arr.finish())
416}
417
418pub fn u256_column_from_binary<I: OffsetSizeTrait>(
420 col: &GenericBinaryArray<I>,
421) -> Result<Decimal256Array> {
422 let mut arr = builder::Decimal256Builder::with_capacity(col.len());
423
424 for v in col {
425 match v {
426 Some(v) => {
427 let num = ruint::aliases::U256::try_from_be_slice(v).context("parse ruint u256")?;
428 let num = alloy_primitives::I256::try_from(num)
429 .with_context(|| format!("u256 to i256. val was {num}"))?;
430
431 let val = arrow::datatypes::i256::from_be_bytes(num.to_be_bytes::<32>());
432 arr.append_value(val);
433 }
434 None => arr.append_null(),
435 }
436 }
437
438 Ok(arr
439 .with_precision_and_scale(76, 0)
440 .context("set precision and scale for Decimal256")?
441 .finish())
442}
443
444pub fn u256_column_to_binary(col: &Decimal256Array) -> Result<BinaryArray> {
446 let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
447
448 for v in col {
449 match v {
450 Some(v) => {
451 let num = alloy_primitives::I256::from_be_bytes::<32>(v.to_be_bytes());
452 let num = ruint::aliases::U256::try_from(num).context("convert i256 to u256")?;
453 arr.append_value(num.to_be_bytes_trimmed_vec());
454 }
455 None => {
456 arr.append_null();
457 }
458 }
459 }
460
461 Ok(arr.finish())
462}
463
464#[expect(
466 clippy::unwrap_used,
467 reason = "downcast is guaranteed by prior data type check"
468)]
469pub fn u256_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
470 let schema = schema_decimal256_to_binary(data.schema_ref());
471 let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
472
473 for (i, col) in data.columns().iter().enumerate() {
474 if col.data_type() == &DataType::Decimal256(76, 0) {
475 let col = col.as_any().downcast_ref::<Decimal256Array>().unwrap();
476 let x = u256_column_to_binary(col)
477 .with_context(|| format!("col {} to binary", data.schema().fields()[i].name()))?;
478 columns.push(Arc::new(x));
479 } else {
480 columns.push(col.clone());
481 }
482 }
483
484 RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
485}
486
487pub fn decimal256_to_be32(col: &Decimal256Array) -> BinaryArray {
494 let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 32);
495
496 for v in col {
497 match v {
498 Some(v) => arr.append_value(v.to_be_bytes()),
499 None => arr.append_null(),
500 }
501 }
502
503 arr.finish()
504}
505
506pub fn decimal128_to_be16(col: &Decimal128Array) -> BinaryArray {
512 let mut arr = builder::BinaryBuilder::with_capacity(col.len(), col.len() * 16);
513
514 for v in col {
515 match v {
516 Some(v) => arr.append_value(v.to_be_bytes()),
517 None => arr.append_null(),
518 }
519 }
520
521 arr.finish()
522}
523
524#[expect(
532 clippy::unwrap_used,
533 reason = "downcast is guaranteed by prior data type check"
534)]
535pub fn large_ints_to_binary(data: &RecordBatch) -> Result<RecordBatch> {
536 let schema = schema_large_int_to_binary(data.schema_ref());
537 let mut columns = Vec::<Arc<dyn Array>>::with_capacity(data.columns().len());
538
539 for col in data.columns() {
540 let new: Arc<dyn Array> = match col.data_type() {
541 DataType::Decimal256(_, 0) => Arc::new(decimal256_to_be32(
542 col.as_any().downcast_ref::<Decimal256Array>().unwrap(),
543 )),
544 DataType::Decimal128(_, 0) => Arc::new(decimal128_to_be16(
545 col.as_any().downcast_ref::<Decimal128Array>().unwrap(),
546 )),
547 _ => col.clone(),
548 };
549 columns.push(new);
550 }
551
552 RecordBatch::try_new(Arc::new(schema), columns).context("construct arrow batch")
553}
554
555pub fn schema_large_int_to_binary(schema: &Schema) -> Schema {
558 let mut fields = Vec::<Arc<Field>>::with_capacity(schema.fields().len());
559
560 for f in schema.fields() {
561 match f.data_type() {
562 DataType::Decimal256(_, 0) | DataType::Decimal128(_, 0) => {
563 fields.push(Arc::new(Field::new(
564 f.name().clone(),
565 DataType::Binary,
566 f.is_nullable(),
567 )));
568 }
569 _ => fields.push(f.clone()),
570 }
571 }
572
573 Schema::new(fields)
574}
575
576pub fn flatten_record_batch(batch: &RecordBatch) -> Result<RecordBatch> {
588 let mut out_fields: Vec<Arc<Field>> = Vec::new();
589 let mut out_arrays: Vec<Arc<dyn Array>> = Vec::new();
590
591 for (field, col) in batch.schema().fields().iter().zip(batch.columns()) {
592 expand_column(field.name(), col, &mut out_fields, &mut out_arrays)?;
593 }
594
595 resolve_name_collisions(&mut out_fields);
596
597 RecordBatch::try_new(Arc::new(Schema::new(out_fields)), out_arrays)
598 .context("construct flattened batch")
599}
600
601pub fn flatten_schema(schema: &Schema) -> Schema {
603 let mut out_fields: Vec<Arc<Field>> = Vec::new();
604
605 for field in schema.fields() {
606 expand_field(field.name(), field.data_type(), &mut out_fields);
607 }
608
609 resolve_name_collisions(&mut out_fields);
610
611 Schema::new(out_fields)
612}
613
614fn expand_column(
615 name: &str,
616 col: &Arc<dyn Array>,
617 out_fields: &mut Vec<Arc<Field>>,
618 out_arrays: &mut Vec<Arc<dyn Array>>,
619) -> Result<()> {
620 match col.data_type() {
621 DataType::Struct(inner_fields) => {
622 let struct_arr = col
623 .as_any()
624 .downcast_ref::<StructArray>()
625 .context("downcast to StructArray")?;
626 for (i, inner_field) in inner_fields.iter().enumerate() {
627 let child = struct_arr.column(i).clone();
628 let child = propagate_nulls(col.nulls(), child);
629 expand_column(
630 &format!("{}.{}", name, inner_field.name()),
631 &child,
632 out_fields,
633 out_arrays,
634 )?;
635 }
636 }
637 DataType::FixedSizeList(inner_field, n)
638 if matches!(inner_field.data_type(), DataType::Struct(_)) =>
639 {
640 let n = usize::try_from(*n).context("FixedSizeList size must be non-negative")?;
641 let list_arr = col
642 .as_any()
643 .downcast_ref::<FixedSizeListArray>()
644 .context("downcast to FixedSizeListArray")?;
645 let values = list_arr.values();
646 let num_rows = list_arr.len();
647
648 for i in 0..n {
649 let indices: Int32Array = (0..num_rows)
650 .map(|r| i32::try_from(r * n + i).context("index overflows i32"))
651 .collect::<Result<Vec<_>>>()?
652 .into();
653 let element = arrow::compute::take(values.as_ref(), &indices, None)
654 .context("take element from FixedSizeList")?;
655 let element = propagate_nulls(list_arr.nulls(), element);
656 expand_column(&format!("{name}.{i}"), &element, out_fields, out_arrays)?;
657 }
658 }
659 DataType::List(inner_field) if matches!(inner_field.data_type(), DataType::Struct(_)) => {
660 let str_col = arrow::compute::cast_with_options(
661 col.as_ref(),
662 &DataType::Utf8,
663 &CastOptions {
664 safe: true,
665 ..Default::default()
666 },
667 )
668 .context("cast List<Struct> to Utf8")?;
669 out_fields.push(Arc::new(Field::new(name, DataType::Utf8, true)));
670 out_arrays.push(str_col);
671 }
672 _ => {
673 out_fields.push(Arc::new(Field::new(name, col.data_type().clone(), true)));
674 out_arrays.push(col.clone());
675 }
676 }
677 Ok(())
678}
679
680fn expand_field(name: &str, dtype: &DataType, out: &mut Vec<Arc<Field>>) {
681 match dtype {
682 DataType::Struct(inner_fields) => {
683 for f in inner_fields {
684 expand_field(&format!("{}.{}", name, f.name()), f.data_type(), out);
685 }
686 }
687 DataType::FixedSizeList(inner_field, n)
688 if matches!(inner_field.data_type(), DataType::Struct(_)) =>
689 {
690 for i in 0..usize::try_from(*n).unwrap_or(0) {
691 expand_field(&format!("{name}.{i}"), inner_field.data_type(), out);
692 }
693 }
694 DataType::List(inner_field) if matches!(inner_field.data_type(), DataType::Struct(_)) => {
695 out.push(Arc::new(Field::new(name, DataType::Utf8, true)));
696 }
697 _ => {
698 out.push(Arc::new(Field::new(name, dtype.clone(), true)));
699 }
700 }
701}
702
703fn propagate_nulls(parent_nulls: Option<&NullBuffer>, col: Arc<dyn Array>) -> Arc<dyn Array> {
706 let Some(parent_nulls) = parent_nulls else {
707 return col;
708 };
709 let merged = NullBuffer::union(Some(parent_nulls), col.nulls());
710 let null_count = merged.as_ref().map_or(0, NullBuffer::null_count);
711 let data = col.into_data();
712 let new_data = unsafe {
714 data.into_builder()
715 .null_bit_buffer(merged.map(|nb| nb.into_inner().into_inner()))
716 .null_count(null_count)
717 .build_unchecked()
718 };
719 make_array(new_data)
720}
721
722fn resolve_name_collisions(fields: &mut [Arc<Field>]) {
725 let mut counts: HashMap<String, usize> = HashMap::new();
726 for f in fields.iter() {
727 *counts.entry(f.name().clone()).or_insert(0) += 1;
728 }
729 let mut seen: HashMap<String, usize> = HashMap::new();
730 for field in fields.iter_mut() {
731 let name = field.name().clone();
732 if *counts.get(&name).unwrap_or(&0) > 1 {
733 let idx = seen.entry(name.clone()).or_insert(0);
734 if *idx > 0 {
735 let new_name = format!("{name}_{idx}");
736 *field = Arc::new(Field::new(
737 new_name,
738 field.data_type().clone(),
739 field.is_nullable(),
740 ));
741 }
742 *idx += 1;
743 }
744 }
745}
746
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use std::fs::File;

    /// Manual smoke test: reads the first batch of `data.parquet`, casts a handful of columns
    /// and writes the result to `result.parquet`. Ignored by default because it depends on a
    /// local input file.
    #[test]
    #[ignore]
    fn test_cast() {
        use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

        let input = File::open("data.parquet").unwrap();
        let mut reader = ParquetRecordBatchReaderBuilder::try_new(input)
            .unwrap()
            .build()
            .unwrap();
        let table = reader.next().unwrap().unwrap();

        let type_mappings = [
            ("amount0In", DataType::Decimal128(15, 0)),
            ("amount1In", DataType::Float32),
            ("amount0Out", DataType::Float64),
            ("amount1Out", DataType::Decimal128(38, 0)),
            ("timestamp", DataType::Int64),
        ];

        let result = cast(&type_mappings, &table, true).unwrap();

        // Persist the cast batch so it can be inspected by hand.
        let mut output = File::create("result.parquet").unwrap();
        let mut writer =
            parquet::arrow::ArrowWriter::try_new(&mut output, result.schema(), None).unwrap();
        writer.write(&result).unwrap();
        writer.close().unwrap();
    }
}