datafusion_flight_sql_server/
service.rs

1use std::{collections::BTreeMap, pin::Pin, sync::Arc};
2
3use arrow::{
4    array::{ArrayRef, RecordBatch, StringArray},
5    compute::concat_batches,
6    datatypes::{DataType, Field, SchemaBuilder, SchemaRef},
7    error::ArrowError,
8    ipc::{
9        reader::StreamReader,
10        writer::{IpcWriteOptions, StreamWriter},
11    },
12};
13use arrow_flight::{
14    decode::{DecodedPayload, FlightDataDecoder},
15    sql::{
16        self,
17        server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
18        ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
19        ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
20        ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
21        ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
22        ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
23        CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
24        CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
25        CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
26        CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
27        CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo,
28        TicketStatementQuery,
29    },
30};
31use arrow_flight::{
32    encode::FlightDataEncoderBuilder,
33    error::FlightError,
34    flight_service_server::{FlightService, FlightServiceServer},
35    Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
36    IpcMessage, SchemaAsIpc, Ticket,
37};
38use datafusion::{
39    common::{arrow::datatypes::Schema, ParamValues},
40    dataframe::DataFrame,
41    datasource::TableType,
42    error::{DataFusionError, Result as DataFusionResult},
43    execution::context::{SQLOptions, SessionContext, SessionState},
44    logical_expr::LogicalPlan,
45    physical_plan::SendableRecordBatchStream,
46    scalar::ScalarValue,
47};
48use datafusion_substrait::{
49    logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes,
50};
51use futures::{Stream, StreamExt, TryStreamExt};
52use log::info;
53use once_cell::sync::Lazy;
54use prost::bytes::Bytes;
55use prost::Message;
56use tonic::transport::Server;
57use tonic::{Request, Response, Status, Streaming};
58
59use super::session::{SessionStateProvider, StaticSessionStateProvider};
60use super::state::{CommandTicket, QueryHandle};
61
62type Result<T, E = Status> = std::result::Result<T, E>;
63
64/// FlightSqlService is a basic stateless FlightSqlService implementation.
65pub struct FlightSqlService {
66    provider: Box<dyn SessionStateProvider>,
67    sql_options: Option<SQLOptions>,
68}
69
70impl FlightSqlService {
71    /// Creates a new FlightSqlService with a static SessionState.
72    pub fn new(state: SessionState) -> Self {
73        Self::new_with_provider(Box::new(StaticSessionStateProvider::new(state)))
74    }
75
76    /// Creates a new FlightSqlService with a SessionStateProvider.
77    pub fn new_with_provider(provider: Box<dyn SessionStateProvider>) -> Self {
78        Self {
79            provider,
80            sql_options: None,
81        }
82    }
83
84    /// Replaces the sql_options with the provided options.
85    /// These options are used to verify all SQL queries.
86    /// When None the default [`SQLOptions`] are used.
87    pub fn with_sql_options(self, sql_options: SQLOptions) -> Self {
88        Self {
89            sql_options: Some(sql_options),
90            ..self
91        }
92    }
93
94    // Federate substrait plans instead of SQL
95    // pub fn with_substrait() -> Self {
96    // TODO: Substrait federation
97    // }
98
99    // Serves straightforward on the specified address.
100    pub async fn serve(self, addr: String) -> Result<(), Box<dyn std::error::Error>> {
101        let addr = addr.parse()?;
102        info!("Listening on {addr:?}");
103
104        let svc = FlightServiceServer::new(self);
105
106        Ok(Server::builder().add_service(svc).serve(addr).await?)
107    }
108
109    async fn new_context<T>(
110        &self,
111        request: Request<T>,
112    ) -> Result<(Request<T>, FlightSqlSessionContext)> {
113        let (metadata, extensions, msg) = request.into_parts();
114        let inspect_request = Request::from_parts(metadata, extensions, ());
115
116        let state = self.provider.new_context(&inspect_request).await?;
117        let ctx = SessionContext::new_with_state(state);
118
119        let (metadata, extensions, _) = inspect_request.into_parts();
120        Ok((
121            Request::from_parts(metadata, extensions, msg),
122            FlightSqlSessionContext {
123                inner: ctx,
124                sql_options: self.sql_options,
125            },
126        ))
127    }
128}
129
130/// The schema for GetTableTypes
131static GET_TABLE_TYPES_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
132    //TODO: Move this into arrow-flight itself, similar to the builder pattern for CommandGetCatalogs and CommandGetDbSchemas
133    Arc::new(Schema::new(vec![Field::new(
134        "table_type",
135        DataType::Utf8,
136        false,
137    )]))
138});
139
140struct FlightSqlSessionContext {
141    inner: SessionContext,
142    sql_options: Option<SQLOptions>,
143}
144
145impl FlightSqlSessionContext {
146    async fn sql_to_logical_plan(&self, sql: &str) -> DataFusionResult<LogicalPlan> {
147        let plan = self.inner.state().create_logical_plan(sql).await?;
148        let verifier = self.sql_options.unwrap_or_default();
149        verifier.verify_plan(&plan)?;
150        Ok(plan)
151    }
152
153    async fn execute_sql(&self, sql: &str) -> DataFusionResult<SendableRecordBatchStream> {
154        let plan = self.sql_to_logical_plan(sql).await?;
155        self.execute_logical_plan(plan).await
156    }
157
158    async fn execute_logical_plan(
159        &self,
160        plan: LogicalPlan,
161    ) -> DataFusionResult<SendableRecordBatchStream> {
162        self.inner
163            .execute_logical_plan(plan)
164            .await?
165            .execute_stream()
166            .await
167    }
168}
169
170#[tonic::async_trait]
171impl ArrowFlightSqlService for FlightSqlService {
172    type FlightService = FlightSqlService;
173
174    async fn do_handshake(
175        &self,
176        _request: Request<Streaming<HandshakeRequest>>,
177    ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse>> + Send>>>> {
178        info!("do_handshake");
179        // Favor middleware over handshake
180        // https://github.com/apache/arrow/issues/23836
181        // https://github.com/apache/arrow/issues/25848
182        Err(Status::unimplemented("handshake is not supported"))
183    }
184
185    async fn do_get_fallback(
186        &self,
187        request: Request<Ticket>,
188        _message: Any,
189    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
190        let (request, ctx) = self.new_context(request).await?;
191
192        let ticket = CommandTicket::try_decode(request.into_inner().ticket)
193            .map_err(flight_error_to_status)?;
194
195        match ticket.command {
196            sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => {
197                // print!("Query: {query}\n");
198
199                let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?;
200                let arrow_schema = stream.schema();
201                let arrow_stream = stream.map(|i| {
202                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
203                    Ok(batch)
204                });
205
206                let flight_data_stream = FlightDataEncoderBuilder::new()
207                    .with_schema(arrow_schema)
208                    .build(arrow_stream)
209                    .map_err(flight_error_to_status)
210                    .boxed();
211
212                Ok(Response::new(flight_data_stream))
213            }
214            sql::Command::CommandPreparedStatementQuery(CommandPreparedStatementQuery {
215                prepared_statement_handle,
216            }) => {
217                let handle = QueryHandle::try_decode(prepared_statement_handle)?;
218
219                let mut plan = ctx
220                    .sql_to_logical_plan(handle.query())
221                    .await
222                    .map_err(df_error_to_status)?;
223
224                if let Some(param_values) =
225                    decode_param_values(handle.parameters()).map_err(arrow_error_to_status)?
226                {
227                    plan = plan
228                        .with_param_values(param_values)
229                        .map_err(df_error_to_status)?;
230                }
231
232                let stream = ctx
233                    .execute_logical_plan(plan)
234                    .await
235                    .map_err(df_error_to_status)?;
236                let arrow_schema = stream.schema();
237                let arrow_stream = stream.map(|i| {
238                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
239                    Ok(batch)
240                });
241
242                let flight_data_stream = FlightDataEncoderBuilder::new()
243                    .with_schema(arrow_schema)
244                    .build(arrow_stream)
245                    .map_err(flight_error_to_status)
246                    .boxed();
247
248                Ok(Response::new(flight_data_stream))
249            }
250            sql::Command::CommandStatementSubstraitPlan(CommandStatementSubstraitPlan {
251                plan,
252                ..
253            }) => {
254                let substrait_bytes = &plan
255                    .ok_or(Status::invalid_argument(
256                        "Expected substrait plan, found None",
257                    ))?
258                    .plan;
259
260                let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
261
262                let state = ctx.inner.state();
263                let df = DataFrame::new(state, plan);
264
265                let stream = df.execute_stream().await.map_err(df_error_to_status)?;
266                let arrow_schema = stream.schema();
267                let arrow_stream = stream.map(|i| {
268                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
269                    Ok(batch)
270                });
271
272                let flight_data_stream = FlightDataEncoderBuilder::new()
273                    .with_schema(arrow_schema)
274                    .build(arrow_stream)
275                    .map_err(flight_error_to_status)
276                    .boxed();
277
278                Ok(Response::new(flight_data_stream))
279            }
280            _ => {
281                return Err(Status::internal(format!(
282                    "statement handle not found: {:?}",
283                    ticket.command
284                )));
285            }
286        }
287    }
288
289    async fn get_flight_info_statement(
290        &self,
291        query: CommandStatementQuery,
292        request: Request<FlightDescriptor>,
293    ) -> Result<Response<FlightInfo>> {
294        let (request, ctx) = self.new_context(request).await?;
295
296        let sql = &query.query;
297        info!("get_flight_info_statement with query={sql}");
298
299        let flight_descriptor = request.into_inner();
300
301        let plan = ctx
302            .sql_to_logical_plan(sql)
303            .await
304            .map_err(df_error_to_status)?;
305
306        let dataset_schema = get_schema_for_plan(&plan);
307
308        // Form the response ticket (that the client will pass back to DoGet)
309        let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query))
310            .try_encode()
311            .map_err(flight_error_to_status)?;
312
313        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
314
315        let flight_info = FlightInfo::new()
316            .with_endpoint(endpoint)
317            // return descriptor we were passed
318            .with_descriptor(flight_descriptor)
319            .try_with_schema(dataset_schema.as_ref())
320            .map_err(arrow_error_to_status)?;
321
322        Ok(Response::new(flight_info))
323    }
324
325    async fn get_flight_info_substrait_plan(
326        &self,
327        query: CommandStatementSubstraitPlan,
328        request: Request<FlightDescriptor>,
329    ) -> Result<Response<FlightInfo>> {
330        info!("get_flight_info_substrait_plan");
331        let (request, ctx) = self.new_context(request).await?;
332
333        let substrait_bytes = &query
334            .plan
335            .as_ref()
336            .ok_or(Status::invalid_argument(
337                "Expected substrait plan, found None",
338            ))?
339            .plan;
340
341        let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
342
343        let flight_descriptor = request.into_inner();
344
345        let dataset_schema = get_schema_for_plan(&plan);
346
347        // Form the response ticket (that the client will pass back to DoGet)
348        let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query))
349            .try_encode()
350            .map_err(flight_error_to_status)?;
351
352        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
353
354        let flight_info = FlightInfo::new()
355            .with_endpoint(endpoint)
356            // return descriptor we were passed
357            .with_descriptor(flight_descriptor)
358            .try_with_schema(dataset_schema.as_ref())
359            .map_err(arrow_error_to_status)?;
360
361        Ok(Response::new(flight_info))
362    }
363
364    async fn get_flight_info_prepared_statement(
365        &self,
366        cmd: CommandPreparedStatementQuery,
367        request: Request<FlightDescriptor>,
368    ) -> Result<Response<FlightInfo>> {
369        let (request, ctx) = self.new_context(request).await?;
370
371        let handle = QueryHandle::try_decode(cmd.prepared_statement_handle.clone())
372            .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?;
373
374        info!("get_flight_info_prepared_statement with handle={handle}");
375
376        let flight_descriptor = request.into_inner();
377
378        let sql = handle.query();
379        let plan = ctx
380            .sql_to_logical_plan(sql)
381            .await
382            .map_err(df_error_to_status)?;
383
384        let dataset_schema = get_schema_for_plan(&plan);
385
386        // Form the response ticket (that the client will pass back to DoGet)
387        let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd))
388            .try_encode()
389            .map_err(flight_error_to_status)?;
390
391        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
392
393        let flight_info = FlightInfo::new()
394            .with_endpoint(endpoint)
395            // return descriptor we were passed
396            .with_descriptor(flight_descriptor)
397            .try_with_schema(dataset_schema.as_ref())
398            .map_err(arrow_error_to_status)?;
399
400        Ok(Response::new(flight_info))
401    }
402
403    async fn get_flight_info_catalogs(
404        &self,
405        query: CommandGetCatalogs,
406        request: Request<FlightDescriptor>,
407    ) -> Result<Response<FlightInfo>> {
408        info!("get_flight_info_catalogs");
409        let (request, _ctx) = self.new_context(request).await?;
410
411        let flight_descriptor = request.into_inner();
412        let ticket = Ticket {
413            ticket: query.as_any().encode_to_vec().into(),
414        };
415        let endpoint = FlightEndpoint::new().with_ticket(ticket);
416
417        let flight_info = FlightInfo::new()
418            .try_with_schema(&query.into_builder().schema())
419            .map_err(arrow_error_to_status)?
420            .with_endpoint(endpoint)
421            .with_descriptor(flight_descriptor);
422
423        Ok(Response::new(flight_info))
424    }
425
426    async fn get_flight_info_schemas(
427        &self,
428        query: CommandGetDbSchemas,
429        request: Request<FlightDescriptor>,
430    ) -> Result<Response<FlightInfo>> {
431        info!("get_flight_info_schemas");
432        let (request, _ctx) = self.new_context(request).await?;
433        let flight_descriptor = request.into_inner();
434        let ticket = Ticket {
435            ticket: query.as_any().encode_to_vec().into(),
436        };
437        let endpoint = FlightEndpoint::new().with_ticket(ticket);
438
439        let flight_info = FlightInfo::new()
440            .try_with_schema(&query.into_builder().schema())
441            .map_err(arrow_error_to_status)?
442            .with_endpoint(endpoint)
443            .with_descriptor(flight_descriptor);
444
445        Ok(Response::new(flight_info))
446    }
447
448    async fn get_flight_info_tables(
449        &self,
450        query: CommandGetTables,
451        request: Request<FlightDescriptor>,
452    ) -> Result<Response<FlightInfo>> {
453        info!("get_flight_info_tables");
454        let (request, _ctx) = self.new_context(request).await?;
455
456        let flight_descriptor = request.into_inner();
457        let ticket = Ticket {
458            ticket: query.as_any().encode_to_vec().into(),
459        };
460        let endpoint = FlightEndpoint::new().with_ticket(ticket);
461
462        let flight_info = FlightInfo::new()
463            .try_with_schema(&query.into_builder().schema())
464            .map_err(arrow_error_to_status)?
465            .with_endpoint(endpoint)
466            .with_descriptor(flight_descriptor);
467
468        Ok(Response::new(flight_info))
469    }
470
471    async fn get_flight_info_table_types(
472        &self,
473        query: CommandGetTableTypes,
474        request: Request<FlightDescriptor>,
475    ) -> Result<Response<FlightInfo>> {
476        info!("get_flight_info_table_types");
477        let (request, _ctx) = self.new_context(request).await?;
478
479        let flight_descriptor = request.into_inner();
480        let ticket = Ticket {
481            ticket: query.as_any().encode_to_vec().into(),
482        };
483        let endpoint = FlightEndpoint::new().with_ticket(ticket);
484
485        let flight_info = FlightInfo::new()
486            .try_with_schema(&GET_TABLE_TYPES_SCHEMA)
487            .map_err(arrow_error_to_status)?
488            .with_endpoint(endpoint)
489            .with_descriptor(flight_descriptor);
490
491        Ok(Response::new(flight_info))
492    }
493
494    async fn get_flight_info_sql_info(
495        &self,
496        _query: CommandGetSqlInfo,
497        request: Request<FlightDescriptor>,
498    ) -> Result<Response<FlightInfo>> {
499        info!("get_flight_info_sql_info");
500        let (_, _) = self.new_context(request).await?;
501
502        Err(Status::unimplemented("Implement CommandGetSqlInfo"))
503    }
504
505    async fn get_flight_info_primary_keys(
506        &self,
507        _query: CommandGetPrimaryKeys,
508        request: Request<FlightDescriptor>,
509    ) -> Result<Response<FlightInfo>> {
510        info!("get_flight_info_primary_keys");
511        let (_, _) = self.new_context(request).await?;
512
513        Err(Status::unimplemented(
514            "Implement get_flight_info_primary_keys",
515        ))
516    }
517
518    async fn get_flight_info_exported_keys(
519        &self,
520        _query: CommandGetExportedKeys,
521        request: Request<FlightDescriptor>,
522    ) -> Result<Response<FlightInfo>> {
523        info!("get_flight_info_exported_keys");
524        let (_, _) = self.new_context(request).await?;
525
526        Err(Status::unimplemented(
527            "Implement get_flight_info_exported_keys",
528        ))
529    }
530
531    async fn get_flight_info_imported_keys(
532        &self,
533        _query: CommandGetImportedKeys,
534        request: Request<FlightDescriptor>,
535    ) -> Result<Response<FlightInfo>> {
536        info!("get_flight_info_imported_keys");
537        let (_, _) = self.new_context(request).await?;
538
539        Err(Status::unimplemented(
540            "Implement get_flight_info_imported_keys",
541        ))
542    }
543
544    async fn get_flight_info_cross_reference(
545        &self,
546        _query: CommandGetCrossReference,
547        request: Request<FlightDescriptor>,
548    ) -> Result<Response<FlightInfo>> {
549        info!("get_flight_info_cross_reference");
550        let (_, _) = self.new_context(request).await?;
551
552        Err(Status::unimplemented(
553            "Implement get_flight_info_cross_reference",
554        ))
555    }
556
557    async fn get_flight_info_xdbc_type_info(
558        &self,
559        _query: CommandGetXdbcTypeInfo,
560        request: Request<FlightDescriptor>,
561    ) -> Result<Response<FlightInfo>> {
562        info!("get_flight_info_xdbc_type_info");
563        let (_, _) = self.new_context(request).await?;
564
565        Err(Status::unimplemented(
566            "Implement get_flight_info_xdbc_type_info",
567        ))
568    }
569
570    async fn do_get_statement(
571        &self,
572        _ticket: TicketStatementQuery,
573        request: Request<Ticket>,
574    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
575        info!("do_get_statement");
576        let (_, _) = self.new_context(request).await?;
577
578        Err(Status::unimplemented("Implement do_get_statement"))
579    }
580
581    async fn do_get_prepared_statement(
582        &self,
583        _query: CommandPreparedStatementQuery,
584        request: Request<Ticket>,
585    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
586        info!("do_get_prepared_statement");
587        let (_, _) = self.new_context(request).await?;
588
589        Err(Status::unimplemented("Implement do_get_prepared_statement"))
590    }
591
592    async fn do_get_catalogs(
593        &self,
594        query: CommandGetCatalogs,
595        request: Request<Ticket>,
596    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
597        info!("do_get_catalogs");
598        let (_request, ctx) = self.new_context(request).await?;
599        let catalog_names = ctx.inner.catalog_names();
600
601        let mut builder = query.into_builder();
602        for catalog_name in &catalog_names {
603            builder.append(catalog_name);
604        }
605        let schema = builder.schema();
606        let batch = builder.build();
607        let stream = FlightDataEncoderBuilder::new()
608            .with_schema(schema)
609            .build(futures::stream::once(async { batch }))
610            .map_err(Status::from);
611        Ok(Response::new(Box::pin(stream)))
612    }
613
614    async fn do_get_schemas(
615        &self,
616        query: CommandGetDbSchemas,
617        request: Request<Ticket>,
618    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
619        info!("do_get_schemas");
620        let (_request, ctx) = self.new_context(request).await?;
621        let catalog_name = query.catalog.clone();
622        // Append all schemas to builder, the builder handles applying the filters.
623        let mut builder = query.into_builder();
624        if let Some(catalog_name) = &catalog_name {
625            if let Some(catalog) = ctx.inner.catalog(catalog_name) {
626                for schema_name in &catalog.schema_names() {
627                    builder.append(catalog_name, schema_name);
628                }
629            }
630        };
631
632        let schema = builder.schema();
633        let batch = builder.build();
634        let stream = FlightDataEncoderBuilder::new()
635            .with_schema(schema)
636            .build(futures::stream::once(async { batch }))
637            .map_err(Status::from);
638        Ok(Response::new(Box::pin(stream)))
639    }
640
641    async fn do_get_tables(
642        &self,
643        query: CommandGetTables,
644        request: Request<Ticket>,
645    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
646        info!("do_get_tables");
647        let (_request, ctx) = self.new_context(request).await?;
648        let catalog_name = query.catalog.clone();
649        let mut builder = query.into_builder();
650        // Append all schemas/tables to builder, the builder handles applying the filters.
651        if let Some(catalog_name) = &catalog_name {
652            if let Some(catalog) = ctx.inner.catalog(catalog_name) {
653                for schema_name in &catalog.schema_names() {
654                    if let Some(schema) = catalog.schema(schema_name) {
655                        for table_name in &schema.table_names() {
656                            if let Some(table) =
657                                schema.table(table_name).await.map_err(df_error_to_status)?
658                            {
659                                builder
660                                    .append(
661                                        catalog_name,
662                                        schema_name,
663                                        table_name,
664                                        table.table_type().to_string(),
665                                        &table.schema(),
666                                    )
667                                    .map_err(flight_error_to_status)?;
668                            }
669                        }
670                    }
671                }
672            }
673        };
674
675        let schema = builder.schema();
676        let batch = builder.build();
677        let stream = FlightDataEncoderBuilder::new()
678            .with_schema(schema)
679            .build(futures::stream::once(async { batch }))
680            .map_err(Status::from);
681        Ok(Response::new(Box::pin(stream)))
682    }
683
684    async fn do_get_table_types(
685        &self,
686        _query: CommandGetTableTypes,
687        request: Request<Ticket>,
688    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
689        info!("do_get_table_types");
690        let (_, _) = self.new_context(request).await?;
691
692        // Report all variants of table types that datafusion uses.
693        let table_types: ArrayRef = Arc::new(StringArray::from(
694            vec![TableType::Base, TableType::View, TableType::Temporary]
695                .into_iter()
696                .map(|tt| tt.to_string())
697                .collect::<Vec<String>>(),
698        ));
699
700        let batch = RecordBatch::try_from_iter(vec![("table_type", table_types)]).unwrap();
701
702        let stream = FlightDataEncoderBuilder::new()
703            .with_schema(GET_TABLE_TYPES_SCHEMA.clone())
704            .build(futures::stream::once(async { Ok(batch) }))
705            .map_err(Status::from);
706        Ok(Response::new(Box::pin(stream)))
707    }
708
709    async fn do_get_sql_info(
710        &self,
711        _query: CommandGetSqlInfo,
712        request: Request<Ticket>,
713    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
714        info!("do_get_sql_info");
715        let (_, _) = self.new_context(request).await?;
716
717        Err(Status::unimplemented("Implement do_get_sql_info"))
718    }
719
720    async fn do_get_primary_keys(
721        &self,
722        _query: CommandGetPrimaryKeys,
723        request: Request<Ticket>,
724    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
725        info!("do_get_primary_keys");
726        let (_, _) = self.new_context(request).await?;
727
728        Err(Status::unimplemented("Implement do_get_primary_keys"))
729    }
730
731    async fn do_get_exported_keys(
732        &self,
733        _query: CommandGetExportedKeys,
734        request: Request<Ticket>,
735    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
736        info!("do_get_exported_keys");
737        let (_, _) = self.new_context(request).await?;
738
739        Err(Status::unimplemented("Implement do_get_exported_keys"))
740    }
741
742    async fn do_get_imported_keys(
743        &self,
744        _query: CommandGetImportedKeys,
745        request: Request<Ticket>,
746    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
747        info!("do_get_imported_keys");
748        let (_, _) = self.new_context(request).await?;
749
750        Err(Status::unimplemented("Implement do_get_imported_keys"))
751    }
752
753    async fn do_get_cross_reference(
754        &self,
755        _query: CommandGetCrossReference,
756        request: Request<Ticket>,
757    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
758        info!("do_get_cross_reference");
759        let (_, _) = self.new_context(request).await?;
760
761        Err(Status::unimplemented("Implement do_get_cross_reference"))
762    }
763
764    async fn do_get_xdbc_type_info(
765        &self,
766        _query: CommandGetXdbcTypeInfo,
767        request: Request<Ticket>,
768    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
769        info!("do_get_xdbc_type_info");
770        let (_, _) = self.new_context(request).await?;
771
772        Err(Status::unimplemented("Implement do_get_xdbc_type_info"))
773    }
774
775    async fn do_put_statement_update(
776        &self,
777        _ticket: CommandStatementUpdate,
778        request: Request<PeekableFlightDataStream>,
779    ) -> Result<i64, Status> {
780        info!("do_put_statement_update");
781        let (_, _) = self.new_context(request).await?;
782
783        Err(Status::unimplemented("Implement do_put_statement_update"))
784    }
785
786    async fn do_put_prepared_statement_query(
787        &self,
788        query: CommandPreparedStatementQuery,
789        request: Request<PeekableFlightDataStream>,
790    ) -> Result<DoPutPreparedStatementResult, Status> {
791        info!("do_put_prepared_statement_query");
792        let (request, _) = self.new_context(request).await?;
793
794        let mut handle = QueryHandle::try_decode(query.prepared_statement_handle)?;
795
796        info!(
797            "do_action_create_prepared_statement query={:?}",
798            handle.query()
799        );
800        // Collect request flight data as parameters
801        // Decode and encode as a single ipc stream
802        let mut decoder =
803            FlightDataDecoder::new(request.into_inner().map_err(status_to_flight_error));
804        let schema = decode_schema(&mut decoder).await?;
805        let mut parameters = Vec::new();
806        let mut encoder =
807            StreamWriter::try_new(&mut parameters, &schema).map_err(arrow_error_to_status)?;
808        let mut total_rows = 0;
809        while let Some(msg) = decoder.try_next().await? {
810            match msg.payload {
811                DecodedPayload::None => {}
812                DecodedPayload::Schema(_) => {
813                    return Err(Status::invalid_argument(
814                        "parameter flight data must contain a single schema",
815                    ));
816                }
817                DecodedPayload::RecordBatch(record_batch) => {
818                    total_rows += record_batch.num_rows();
819                    encoder
820                        .write(&record_batch)
821                        .map_err(arrow_error_to_status)?;
822                }
823            }
824        }
825        if total_rows > 1 {
826            return Err(Status::invalid_argument(
827                "parameters should contain a single row",
828            ));
829        }
830
831        handle.set_parameters(Some(parameters.into()));
832
833        let res = DoPutPreparedStatementResult {
834            prepared_statement_handle: Some(Bytes::from(handle)),
835        };
836
837        Ok(res)
838    }
839
840    async fn do_put_prepared_statement_update(
841        &self,
842        _handle: CommandPreparedStatementUpdate,
843        request: Request<PeekableFlightDataStream>,
844    ) -> Result<i64, Status> {
845        info!("do_put_prepared_statement_update");
846        let (_, _) = self.new_context(request).await?;
847
848        // statements like "CREATE TABLE.." or "SET datafusion.nnn.." call this function
849        // and we are required to return some row count here
850        Ok(-1)
851    }
852
853    async fn do_put_substrait_plan(
854        &self,
855        _query: CommandStatementSubstraitPlan,
856        request: Request<PeekableFlightDataStream>,
857    ) -> Result<i64, Status> {
858        info!("do_put_prepared_statement_update");
859        let (_, _) = self.new_context(request).await?;
860
861        Err(Status::unimplemented(
862            "Implement do_put_prepared_statement_update",
863        ))
864    }
865
866    async fn do_action_create_prepared_statement(
867        &self,
868        query: ActionCreatePreparedStatementRequest,
869        request: Request<Action>,
870    ) -> Result<ActionCreatePreparedStatementResult, Status> {
871        let (_, ctx) = self.new_context(request).await?;
872
873        let sql = query.query.clone();
874        info!(
875            "do_action_create_prepared_statement query={:?}",
876            query.query
877        );
878
879        let plan = ctx
880            .sql_to_logical_plan(sql.as_str())
881            .await
882            .map_err(df_error_to_status)?;
883
884        let dataset_schema = get_schema_for_plan(&plan);
885        let parameter_schema = parameter_schema_for_plan(&plan)?;
886
887        let dataset_schema =
888            encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?;
889        let parameter_schema =
890            encode_schema(parameter_schema.as_ref()).map_err(arrow_error_to_status)?;
891
892        let handle = QueryHandle::new(sql, None);
893
894        let res = ActionCreatePreparedStatementResult {
895            prepared_statement_handle: Bytes::from(handle),
896            dataset_schema,
897            parameter_schema,
898        };
899
900        Ok(res)
901    }
902
903    async fn do_action_close_prepared_statement(
904        &self,
905        query: ActionClosePreparedStatementRequest,
906        request: Request<Action>,
907    ) -> Result<(), Status> {
908        let (_, _) = self.new_context(request).await?;
909
910        let handle = query.prepared_statement_handle.as_ref();
911        if let Ok(handle) = std::str::from_utf8(handle) {
912            info!(
913                "do_action_close_prepared_statement with handle {:?}",
914                handle
915            );
916
917            // NOP since stateless
918        }
919        Ok(())
920    }
921
922    async fn do_action_create_prepared_substrait_plan(
923        &self,
924        _query: ActionCreatePreparedSubstraitPlanRequest,
925        request: Request<Action>,
926    ) -> Result<ActionCreatePreparedStatementResult, Status> {
927        info!("do_action_create_prepared_substrait_plan");
928        let (_, _) = self.new_context(request).await?;
929
930        Err(Status::unimplemented(
931            "Implement do_action_create_prepared_substrait_plan",
932        ))
933    }
934
935    async fn do_action_begin_transaction(
936        &self,
937        _query: ActionBeginTransactionRequest,
938        request: Request<Action>,
939    ) -> Result<ActionBeginTransactionResult, Status> {
940        let (_, _) = self.new_context(request).await?;
941
942        info!("do_action_begin_transaction");
943        Err(Status::unimplemented(
944            "Implement do_action_begin_transaction",
945        ))
946    }
947
948    async fn do_action_end_transaction(
949        &self,
950        _query: ActionEndTransactionRequest,
951        request: Request<Action>,
952    ) -> Result<(), Status> {
953        info!("do_action_end_transaction");
954        let (_, _) = self.new_context(request).await?;
955
956        Err(Status::unimplemented("Implement do_action_end_transaction"))
957    }
958
959    async fn do_action_begin_savepoint(
960        &self,
961        _query: ActionBeginSavepointRequest,
962        request: Request<Action>,
963    ) -> Result<ActionBeginSavepointResult, Status> {
964        info!("do_action_begin_savepoint");
965        let (_, _) = self.new_context(request).await?;
966
967        Err(Status::unimplemented("Implement do_action_begin_savepoint"))
968    }
969
970    async fn do_action_end_savepoint(
971        &self,
972        _query: ActionEndSavepointRequest,
973        request: Request<Action>,
974    ) -> Result<(), Status> {
975        info!("do_action_end_savepoint");
976        let (_, _) = self.new_context(request).await?;
977
978        Err(Status::unimplemented("Implement do_action_end_savepoint"))
979    }
980
981    async fn do_action_cancel_query(
982        &self,
983        _query: ActionCancelQueryRequest,
984        request: Request<Action>,
985    ) -> Result<ActionCancelQueryResult, Status> {
986        info!("do_action_cancel_query");
987        let (_, _) = self.new_context(request).await?;
988
989        Err(Status::unimplemented("Implement do_action_cancel_query"))
990    }
991
992    async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
993}
994
995/// Takes a substrait plan serialized as [Bytes] and deserializes this to
996/// a Datafusion [LogicalPlan]
997async fn parse_substrait_bytes(
998    ctx: &FlightSqlSessionContext,
999    substrait: &Bytes,
1000) -> Result<LogicalPlan> {
1001    let substrait_plan = deserialize_bytes(substrait.to_vec())
1002        .await
1003        .map_err(df_error_to_status)?;
1004
1005    from_substrait_plan(&ctx.inner.state(), &substrait_plan)
1006        .await
1007        .map_err(df_error_to_status)
1008}
1009
1010/// Encodes the schema IPC encoded (schema_bytes)
1011fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
1012    let options = IpcWriteOptions::default();
1013
1014    // encode the schema into the correct form
1015    let message: Result<IpcMessage, ArrowError> = SchemaAsIpc::new(schema, &options).try_into();
1016
1017    let IpcMessage(schema) = message?;
1018
1019    Ok(schema)
1020}
1021
1022/// Return the schema for the specified logical plan
1023fn get_schema_for_plan(logical_plan: &LogicalPlan) -> SchemaRef {
1024    // gather real schema, but only
1025    let schema = Schema::from(logical_plan.schema().as_ref()).into();
1026
1027    // Use an empty FlightDataEncoder to determine the schema of the encoded flight data.
1028    // This is necessary as the schema can change based on dictionary hydration behavior.
1029    let flight_data_stream = FlightDataEncoderBuilder::new()
1030        // Inform the builder of the input stream schema
1031        .with_schema(schema)
1032        .build(futures::stream::iter([]));
1033
1034    // Retrieve the schema of the encoded data
1035    flight_data_stream
1036        .known_schema()
1037        .expect("flight data schema should be known when explicitly provided via `with_schema`")
1038}
1039
1040fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result<SchemaRef, Status> {
1041    let parameters = plan
1042        .get_parameter_types()
1043        .map_err(df_error_to_status)?
1044        .into_iter()
1045        .map(|(name, dt)| {
1046            dt.map(|dt| (name.clone(), dt)).ok_or_else(|| {
1047                Status::internal(format!(
1048                    "unable to determine type of query parameter {name}"
1049                ))
1050            })
1051        })
1052        // Collect into BTreeMap so we get a consistent order of the parameters
1053        .collect::<Result<BTreeMap<_, _>, Status>>()?;
1054
1055    let mut builder = SchemaBuilder::new();
1056    parameters
1057        .into_iter()
1058        .for_each(|(name, typ)| builder.push(Field::new(name, typ, false)));
1059    Ok(builder.finish().into())
1060}
1061
1062fn arrow_error_to_status(err: ArrowError) -> Status {
1063    Status::internal(format!("{err:?}"))
1064}
1065
1066fn flight_error_to_status(err: FlightError) -> Status {
1067    Status::internal(format!("{err:?}"))
1068}
1069
1070fn df_error_to_status(err: DataFusionError) -> Status {
1071    Status::internal(format!("{err:?}"))
1072}
1073
1074fn status_to_flight_error(status: Status) -> FlightError {
1075    FlightError::Tonic(status)
1076}
1077
1078async fn decode_schema(decoder: &mut FlightDataDecoder) -> Result<SchemaRef, Status> {
1079    while let Some(msg) = decoder.try_next().await? {
1080        match msg.payload {
1081            DecodedPayload::None => {}
1082            DecodedPayload::Schema(schema) => {
1083                return Ok(schema);
1084            }
1085            DecodedPayload::RecordBatch(_) => {
1086                return Err(Status::invalid_argument(
1087                    "parameter flight data must have a known schema",
1088                ));
1089            }
1090        }
1091    }
1092
1093    Err(Status::invalid_argument(
1094        "parameter flight data must have a schema",
1095    ))
1096}
1097
1098// Decode parameter ipc stream as ParamValues
1099fn decode_param_values(
1100    parameters: Option<&[u8]>,
1101) -> Result<Option<ParamValues>, arrow::error::ArrowError> {
1102    parameters
1103        .map(|parameters| {
1104            let decoder = StreamReader::try_new(parameters, None)?;
1105            let schema = decoder.schema();
1106            let batches = decoder.into_iter().collect::<Result<Vec<_>, _>>()?;
1107            let batch = concat_batches(&schema, batches.iter())?;
1108            Ok(record_to_param_values(&batch)?)
1109        })
1110        .transpose()
1111}
1112
1113// Converts a record batch with a single row into ParamValues
1114fn record_to_param_values(batch: &RecordBatch) -> Result<ParamValues, DataFusionError> {
1115    let mut param_values: Vec<(String, Option<usize>, ScalarValue)> = Vec::new();
1116
1117    let mut is_list = true;
1118    for col_index in 0..batch.num_columns() {
1119        let array = batch.column(col_index);
1120        let scalar = ScalarValue::try_from_array(array, 0)?;
1121        let name = batch
1122            .schema_ref()
1123            .field(col_index)
1124            .name()
1125            .trim_start_matches('$')
1126            .to_string();
1127        let index = name.parse().ok();
1128        is_list &= index.is_some();
1129        param_values.push((name, index, scalar));
1130    }
1131    if is_list {
1132        let mut values: Vec<(Option<usize>, ScalarValue)> = param_values
1133            .into_iter()
1134            .map(|(_name, index, value)| (index, value))
1135            .collect();
1136        values.sort_by_key(|(index, _value)| *index);
1137        Ok(values
1138            .into_iter()
1139            .map(|(_index, value)| value)
1140            .collect::<Vec<ScalarValue>>()
1141            .into())
1142    } else {
1143        Ok(param_values
1144            .into_iter()
1145            .map(|(name, _index, value)| (name, value))
1146            .collect::<Vec<(String, ScalarValue)>>()
1147            .into())
1148    }
1149}