datafusion_flight_sql_server/
service.rs

1use std::{collections::BTreeMap, pin::Pin, sync::Arc};
2
3use arrow::{
4    array::{ArrayRef, RecordBatch, StringArray},
5    compute::concat_batches,
6    datatypes::{DataType, Field, SchemaBuilder, SchemaRef},
7    error::ArrowError,
8    ipc::{
9        reader::StreamReader,
10        writer::{IpcWriteOptions, StreamWriter},
11    },
12};
13use arrow_flight::{
14    decode::{DecodedPayload, FlightDataDecoder},
15    sql::{
16        self,
17        server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
18        ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
19        ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
20        ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
21        ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
22        ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
23        CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
24        CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
25        CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
26        CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
27        CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo,
28        TicketStatementQuery,
29    },
30};
31use arrow_flight::{
32    encode::FlightDataEncoderBuilder,
33    error::FlightError,
34    flight_service_server::{FlightService, FlightServiceServer},
35    Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
36    IpcMessage, SchemaAsIpc, Ticket,
37};
38use datafusion::{
39    common::{arrow::datatypes::Schema, ParamValues},
40    dataframe::DataFrame,
41    datasource::TableType,
42    error::{DataFusionError, Result as DataFusionResult},
43    execution::context::{SQLOptions, SessionContext, SessionState},
44    logical_expr::LogicalPlan,
45    physical_plan::SendableRecordBatchStream,
46    scalar::ScalarValue,
47};
48use datafusion_substrait::{
49    logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes,
50};
51
52use futures::{Stream, StreamExt, TryStreamExt};
53use log::info;
54use once_cell::sync::Lazy;
55use prost::bytes::Bytes;
56use prost::Message;
57use tonic::transport::Server;
58use tonic::{Request, Response, Status, Streaming};
59
60use super::config::FlightSqlServiceConfig;
61use super::session::{SessionStateProvider, StaticSessionStateProvider};
62use super::state::{CommandTicket, QueryHandle};
63
/// Result alias whose error type defaults to [`tonic::Status`], the error type
/// returned by the gRPC handlers in this file.
type Result<T, E = Status> = std::result::Result<T, E>;
65
/// FlightSqlService is a basic stateless FlightSqlService implementation.
///
/// Every handler builds a fresh DataFusion session for the incoming request
/// from the configured [`SessionStateProvider`].
pub struct FlightSqlService {
    /// Produces the `SessionState` used for each incoming request.
    provider: Box<dyn SessionStateProvider>,
    /// Options used to verify SQL plans; `None` means `SQLOptions::default()`.
    sql_options: Option<SQLOptions>,
    /// Service-level settings (e.g. `schema_with_metadata`).
    config: FlightSqlServiceConfig,
}
72
73impl FlightSqlService {
74    /// Creates a new FlightSqlService with a static SessionState.
75    pub fn new(state: SessionState) -> Self {
76        Self::new_with_provider(Box::new(StaticSessionStateProvider::new(state)))
77    }
78
79    /// Creates a new FlightSqlService with a SessionStateProvider.
80    pub fn new_with_provider(provider: Box<dyn SessionStateProvider>) -> Self {
81        Self {
82            provider,
83            sql_options: None,
84            config: FlightSqlServiceConfig::default(),
85        }
86    }
87
88    /// Replaces the FlightSqlServiceConfig with the provided config.
89    pub fn with_config(self, config: FlightSqlServiceConfig) -> Self {
90        Self { config, ..self }
91    }
92
93    /// Replaces the sql_options with the provided options.
94    /// These options are used to verify all SQL queries.
95    /// When None the default [`SQLOptions`] are used.
96    pub fn with_sql_options(self, sql_options: SQLOptions) -> Self {
97        Self {
98            sql_options: Some(sql_options),
99            ..self
100        }
101    }
102
103    // Federate substrait plans instead of SQL
104    // pub fn with_substrait() -> Self {
105    // TODO: Substrait federation
106    // }
107
108    // Serves straightforward on the specified address.
109    pub async fn serve(self, addr: String) -> Result<(), Box<dyn std::error::Error>> {
110        let addr = addr.parse()?;
111        info!("Listening on {addr:?}");
112
113        let svc = FlightServiceServer::new(self);
114
115        Ok(Server::builder().add_service(svc).serve(addr).await?)
116    }
117
118    async fn new_context<T>(
119        &self,
120        request: Request<T>,
121    ) -> Result<(Request<T>, FlightSqlSessionContext)> {
122        let (metadata, extensions, msg) = request.into_parts();
123        let inspect_request = Request::from_parts(metadata, extensions, ());
124
125        let state = self.provider.new_context(&inspect_request).await?;
126        let ctx = SessionContext::new_with_state(state);
127
128        let (metadata, extensions, _) = inspect_request.into_parts();
129        Ok((
130            Request::from_parts(metadata, extensions, msg),
131            FlightSqlSessionContext {
132                inner: ctx,
133                sql_options: self.sql_options,
134            },
135        ))
136    }
137}
138
139/// The schema for GetTableTypes
140static GET_TABLE_TYPES_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
141    //TODO: Move this into arrow-flight itself, similar to the builder pattern for CommandGetCatalogs and CommandGetDbSchemas
142    Arc::new(Schema::new(vec![Field::new(
143        "table_type",
144        DataType::Utf8,
145        false,
146    )]))
147});
148
/// A per-request DataFusion [`SessionContext`] paired with the [`SQLOptions`]
/// used to verify logical plans created from SQL.
struct FlightSqlSessionContext {
    /// The session built for the current request.
    inner: SessionContext,
    /// Plan-verification options; `None` falls back to `SQLOptions::default()`.
    sql_options: Option<SQLOptions>,
}
153
154impl FlightSqlSessionContext {
155    async fn sql_to_logical_plan(&self, sql: &str) -> DataFusionResult<LogicalPlan> {
156        let plan = self.inner.state().create_logical_plan(sql).await?;
157        let verifier = self.sql_options.unwrap_or_default();
158        verifier.verify_plan(&plan)?;
159        Ok(plan)
160    }
161
162    async fn execute_sql(&self, sql: &str) -> DataFusionResult<SendableRecordBatchStream> {
163        let plan = self.sql_to_logical_plan(sql).await?;
164        self.execute_logical_plan(plan).await
165    }
166
167    async fn execute_logical_plan(
168        &self,
169        plan: LogicalPlan,
170    ) -> DataFusionResult<SendableRecordBatchStream> {
171        self.inner
172            .execute_logical_plan(plan)
173            .await?
174            .execute_stream()
175            .await
176    }
177}
178
179#[tonic::async_trait]
180impl ArrowFlightSqlService for FlightSqlService {
181    type FlightService = FlightSqlService;
182
183    async fn do_handshake(
184        &self,
185        _request: Request<Streaming<HandshakeRequest>>,
186    ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse>> + Send>>>> {
187        info!("do_handshake");
188        // Favor middleware over handshake
189        // https://github.com/apache/arrow/issues/23836
190        // https://github.com/apache/arrow/issues/25848
191        Err(Status::unimplemented("handshake is not supported"))
192    }
193
194    async fn do_get_fallback(
195        &self,
196        request: Request<Ticket>,
197        _message: Any,
198    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
199        let (request, ctx) = self.new_context(request).await?;
200
201        let ticket = CommandTicket::try_decode(request.into_inner().ticket)
202            .map_err(flight_error_to_status)?;
203
204        match ticket.command {
205            sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => {
206                // print!("Query: {query}\n");
207
208                let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?;
209                let arrow_schema = stream.schema();
210                let arrow_stream = stream.map(|i| {
211                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
212                    Ok(batch)
213                });
214
215                let flight_data_stream = FlightDataEncoderBuilder::new()
216                    .with_schema(arrow_schema)
217                    .build(arrow_stream)
218                    .map_err(flight_error_to_status)
219                    .boxed();
220
221                Ok(Response::new(flight_data_stream))
222            }
223            sql::Command::CommandPreparedStatementQuery(CommandPreparedStatementQuery {
224                prepared_statement_handle,
225            }) => {
226                let handle = QueryHandle::try_decode(prepared_statement_handle)?;
227
228                let mut plan = ctx
229                    .sql_to_logical_plan(handle.query())
230                    .await
231                    .map_err(df_error_to_status)?;
232
233                if let Some(param_values) =
234                    decode_param_values(handle.parameters()).map_err(arrow_error_to_status)?
235                {
236                    plan = plan
237                        .with_param_values(param_values)
238                        .map_err(df_error_to_status)?;
239                }
240
241                let stream = ctx
242                    .execute_logical_plan(plan)
243                    .await
244                    .map_err(df_error_to_status)?;
245                let arrow_schema = stream.schema();
246                let arrow_stream = stream.map(|i| {
247                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
248                    Ok(batch)
249                });
250
251                let flight_data_stream = FlightDataEncoderBuilder::new()
252                    .with_schema(arrow_schema)
253                    .build(arrow_stream)
254                    .map_err(flight_error_to_status)
255                    .boxed();
256
257                Ok(Response::new(flight_data_stream))
258            }
259            sql::Command::CommandStatementSubstraitPlan(CommandStatementSubstraitPlan {
260                plan,
261                ..
262            }) => {
263                let substrait_bytes = &plan
264                    .ok_or(Status::invalid_argument(
265                        "Expected substrait plan, found None",
266                    ))?
267                    .plan;
268
269                let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
270
271                let state = ctx.inner.state();
272                let df = DataFrame::new(state, plan);
273
274                let stream = df.execute_stream().await.map_err(df_error_to_status)?;
275                let arrow_schema = stream.schema();
276                let arrow_stream = stream.map(|i| {
277                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
278                    Ok(batch)
279                });
280
281                let flight_data_stream = FlightDataEncoderBuilder::new()
282                    .with_schema(arrow_schema)
283                    .build(arrow_stream)
284                    .map_err(flight_error_to_status)
285                    .boxed();
286
287                Ok(Response::new(flight_data_stream))
288            }
289            _ => {
290                return Err(Status::internal(format!(
291                    "statement handle not found: {:?}",
292                    ticket.command
293                )));
294            }
295        }
296    }
297
298    async fn get_flight_info_statement(
299        &self,
300        query: CommandStatementQuery,
301        request: Request<FlightDescriptor>,
302    ) -> Result<Response<FlightInfo>> {
303        let (request, ctx) = self.new_context(request).await?;
304
305        let sql = &query.query;
306        info!("get_flight_info_statement with query={sql}");
307
308        let flight_descriptor = request.into_inner();
309
310        let plan = ctx
311            .sql_to_logical_plan(sql)
312            .await
313            .map_err(df_error_to_status)?;
314
315        let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
316
317        // Form the response ticket (that the client will pass back to DoGet)
318        let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query))
319            .try_encode()
320            .map_err(flight_error_to_status)?;
321
322        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
323
324        let flight_info = FlightInfo::new()
325            .with_endpoint(endpoint)
326            // return descriptor we were passed
327            .with_descriptor(flight_descriptor)
328            .try_with_schema(dataset_schema.as_ref())
329            .map_err(arrow_error_to_status)?;
330
331        Ok(Response::new(flight_info))
332    }
333
334    async fn get_flight_info_substrait_plan(
335        &self,
336        query: CommandStatementSubstraitPlan,
337        request: Request<FlightDescriptor>,
338    ) -> Result<Response<FlightInfo>> {
339        info!("get_flight_info_substrait_plan");
340        let (request, ctx) = self.new_context(request).await?;
341
342        let substrait_bytes = &query
343            .plan
344            .as_ref()
345            .ok_or(Status::invalid_argument(
346                "Expected substrait plan, found None",
347            ))?
348            .plan;
349
350        let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
351
352        let flight_descriptor = request.into_inner();
353
354        let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
355
356        // Form the response ticket (that the client will pass back to DoGet)
357        let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query))
358            .try_encode()
359            .map_err(flight_error_to_status)?;
360
361        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
362
363        let flight_info = FlightInfo::new()
364            .with_endpoint(endpoint)
365            // return descriptor we were passed
366            .with_descriptor(flight_descriptor)
367            .try_with_schema(dataset_schema.as_ref())
368            .map_err(arrow_error_to_status)?;
369
370        Ok(Response::new(flight_info))
371    }
372
373    async fn get_flight_info_prepared_statement(
374        &self,
375        cmd: CommandPreparedStatementQuery,
376        request: Request<FlightDescriptor>,
377    ) -> Result<Response<FlightInfo>> {
378        let (request, ctx) = self.new_context(request).await?;
379
380        let handle = QueryHandle::try_decode(cmd.prepared_statement_handle.clone())
381            .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?;
382
383        info!("get_flight_info_prepared_statement with handle={handle}");
384
385        let flight_descriptor = request.into_inner();
386
387        let sql = handle.query();
388        let plan = ctx
389            .sql_to_logical_plan(sql)
390            .await
391            .map_err(df_error_to_status)?;
392
393        let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
394
395        // Form the response ticket (that the client will pass back to DoGet)
396        let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd))
397            .try_encode()
398            .map_err(flight_error_to_status)?;
399
400        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
401
402        let flight_info = FlightInfo::new()
403            .with_endpoint(endpoint)
404            // return descriptor we were passed
405            .with_descriptor(flight_descriptor)
406            .try_with_schema(dataset_schema.as_ref())
407            .map_err(arrow_error_to_status)?;
408
409        Ok(Response::new(flight_info))
410    }
411
412    async fn get_flight_info_catalogs(
413        &self,
414        query: CommandGetCatalogs,
415        request: Request<FlightDescriptor>,
416    ) -> Result<Response<FlightInfo>> {
417        info!("get_flight_info_catalogs");
418        let (request, _ctx) = self.new_context(request).await?;
419
420        let flight_descriptor = request.into_inner();
421        let ticket = Ticket {
422            ticket: query.as_any().encode_to_vec().into(),
423        };
424        let endpoint = FlightEndpoint::new().with_ticket(ticket);
425
426        let flight_info = FlightInfo::new()
427            .try_with_schema(&query.into_builder().schema())
428            .map_err(arrow_error_to_status)?
429            .with_endpoint(endpoint)
430            .with_descriptor(flight_descriptor);
431
432        Ok(Response::new(flight_info))
433    }
434
435    async fn get_flight_info_schemas(
436        &self,
437        query: CommandGetDbSchemas,
438        request: Request<FlightDescriptor>,
439    ) -> Result<Response<FlightInfo>> {
440        info!("get_flight_info_schemas");
441        let (request, _ctx) = self.new_context(request).await?;
442        let flight_descriptor = request.into_inner();
443        let ticket = Ticket {
444            ticket: query.as_any().encode_to_vec().into(),
445        };
446        let endpoint = FlightEndpoint::new().with_ticket(ticket);
447
448        let flight_info = FlightInfo::new()
449            .try_with_schema(&query.into_builder().schema())
450            .map_err(arrow_error_to_status)?
451            .with_endpoint(endpoint)
452            .with_descriptor(flight_descriptor);
453
454        Ok(Response::new(flight_info))
455    }
456
457    async fn get_flight_info_tables(
458        &self,
459        query: CommandGetTables,
460        request: Request<FlightDescriptor>,
461    ) -> Result<Response<FlightInfo>> {
462        info!("get_flight_info_tables");
463        let (request, _ctx) = self.new_context(request).await?;
464
465        let flight_descriptor = request.into_inner();
466        let ticket = Ticket {
467            ticket: query.as_any().encode_to_vec().into(),
468        };
469        let endpoint = FlightEndpoint::new().with_ticket(ticket);
470
471        let flight_info = FlightInfo::new()
472            .try_with_schema(&query.into_builder().schema())
473            .map_err(arrow_error_to_status)?
474            .with_endpoint(endpoint)
475            .with_descriptor(flight_descriptor);
476
477        Ok(Response::new(flight_info))
478    }
479
480    async fn get_flight_info_table_types(
481        &self,
482        query: CommandGetTableTypes,
483        request: Request<FlightDescriptor>,
484    ) -> Result<Response<FlightInfo>> {
485        info!("get_flight_info_table_types");
486        let (request, _ctx) = self.new_context(request).await?;
487
488        let flight_descriptor = request.into_inner();
489        let ticket = Ticket {
490            ticket: query.as_any().encode_to_vec().into(),
491        };
492        let endpoint = FlightEndpoint::new().with_ticket(ticket);
493
494        let flight_info = FlightInfo::new()
495            .try_with_schema(&GET_TABLE_TYPES_SCHEMA)
496            .map_err(arrow_error_to_status)?
497            .with_endpoint(endpoint)
498            .with_descriptor(flight_descriptor);
499
500        Ok(Response::new(flight_info))
501    }
502
503    async fn get_flight_info_sql_info(
504        &self,
505        _query: CommandGetSqlInfo,
506        request: Request<FlightDescriptor>,
507    ) -> Result<Response<FlightInfo>> {
508        info!("get_flight_info_sql_info");
509        let (_, _) = self.new_context(request).await?;
510
511        Err(Status::unimplemented("Implement CommandGetSqlInfo"))
512    }
513
514    async fn get_flight_info_primary_keys(
515        &self,
516        _query: CommandGetPrimaryKeys,
517        request: Request<FlightDescriptor>,
518    ) -> Result<Response<FlightInfo>> {
519        info!("get_flight_info_primary_keys");
520        let (_, _) = self.new_context(request).await?;
521
522        Err(Status::unimplemented(
523            "Implement get_flight_info_primary_keys",
524        ))
525    }
526
527    async fn get_flight_info_exported_keys(
528        &self,
529        _query: CommandGetExportedKeys,
530        request: Request<FlightDescriptor>,
531    ) -> Result<Response<FlightInfo>> {
532        info!("get_flight_info_exported_keys");
533        let (_, _) = self.new_context(request).await?;
534
535        Err(Status::unimplemented(
536            "Implement get_flight_info_exported_keys",
537        ))
538    }
539
540    async fn get_flight_info_imported_keys(
541        &self,
542        _query: CommandGetImportedKeys,
543        request: Request<FlightDescriptor>,
544    ) -> Result<Response<FlightInfo>> {
545        info!("get_flight_info_imported_keys");
546        let (_, _) = self.new_context(request).await?;
547
548        Err(Status::unimplemented(
549            "Implement get_flight_info_imported_keys",
550        ))
551    }
552
553    async fn get_flight_info_cross_reference(
554        &self,
555        _query: CommandGetCrossReference,
556        request: Request<FlightDescriptor>,
557    ) -> Result<Response<FlightInfo>> {
558        info!("get_flight_info_cross_reference");
559        let (_, _) = self.new_context(request).await?;
560
561        Err(Status::unimplemented(
562            "Implement get_flight_info_cross_reference",
563        ))
564    }
565
566    async fn get_flight_info_xdbc_type_info(
567        &self,
568        _query: CommandGetXdbcTypeInfo,
569        request: Request<FlightDescriptor>,
570    ) -> Result<Response<FlightInfo>> {
571        info!("get_flight_info_xdbc_type_info");
572        let (_, _) = self.new_context(request).await?;
573
574        Err(Status::unimplemented(
575            "Implement get_flight_info_xdbc_type_info",
576        ))
577    }
578
579    async fn do_get_statement(
580        &self,
581        _ticket: TicketStatementQuery,
582        request: Request<Ticket>,
583    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
584        info!("do_get_statement");
585        let (_, _) = self.new_context(request).await?;
586
587        Err(Status::unimplemented("Implement do_get_statement"))
588    }
589
590    async fn do_get_prepared_statement(
591        &self,
592        _query: CommandPreparedStatementQuery,
593        request: Request<Ticket>,
594    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
595        info!("do_get_prepared_statement");
596        let (_, _) = self.new_context(request).await?;
597
598        Err(Status::unimplemented("Implement do_get_prepared_statement"))
599    }
600
601    async fn do_get_catalogs(
602        &self,
603        query: CommandGetCatalogs,
604        request: Request<Ticket>,
605    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
606        info!("do_get_catalogs");
607        let (_request, ctx) = self.new_context(request).await?;
608        let catalog_names = ctx.inner.catalog_names();
609
610        let mut builder = query.into_builder();
611        for catalog_name in &catalog_names {
612            builder.append(catalog_name);
613        }
614        let schema = builder.schema();
615        let batch = builder.build();
616        let stream = FlightDataEncoderBuilder::new()
617            .with_schema(schema)
618            .build(futures::stream::once(async { batch }))
619            .map_err(Status::from);
620        Ok(Response::new(Box::pin(stream)))
621    }
622
623    async fn do_get_schemas(
624        &self,
625        query: CommandGetDbSchemas,
626        request: Request<Ticket>,
627    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
628        info!("do_get_schemas");
629        let (_request, ctx) = self.new_context(request).await?;
630        let catalog_name = query.catalog.clone();
631        // Append all schemas to builder, the builder handles applying the filters.
632        let mut builder = query.into_builder();
633        if let Some(catalog_name) = &catalog_name {
634            if let Some(catalog) = ctx.inner.catalog(catalog_name) {
635                for schema_name in &catalog.schema_names() {
636                    builder.append(catalog_name, schema_name);
637                }
638            }
639        };
640
641        let schema = builder.schema();
642        let batch = builder.build();
643        let stream = FlightDataEncoderBuilder::new()
644            .with_schema(schema)
645            .build(futures::stream::once(async { batch }))
646            .map_err(Status::from);
647        Ok(Response::new(Box::pin(stream)))
648    }
649
650    async fn do_get_tables(
651        &self,
652        query: CommandGetTables,
653        request: Request<Ticket>,
654    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
655        info!("do_get_tables");
656        let (_request, ctx) = self.new_context(request).await?;
657        let catalog_name = query.catalog.clone();
658        let mut builder = query.into_builder();
659        // Append all schemas/tables to builder, the builder handles applying the filters.
660        if let Some(catalog_name) = &catalog_name {
661            if let Some(catalog) = ctx.inner.catalog(catalog_name) {
662                for schema_name in &catalog.schema_names() {
663                    if let Some(schema) = catalog.schema(schema_name) {
664                        for table_name in &schema.table_names() {
665                            if let Some(table) =
666                                schema.table(table_name).await.map_err(df_error_to_status)?
667                            {
668                                builder
669                                    .append(
670                                        catalog_name,
671                                        schema_name,
672                                        table_name,
673                                        table.table_type().to_string(),
674                                        &table.schema(),
675                                    )
676                                    .map_err(flight_error_to_status)?;
677                            }
678                        }
679                    }
680                }
681            }
682        };
683
684        let schema = builder.schema();
685        let batch = builder.build();
686        let stream = FlightDataEncoderBuilder::new()
687            .with_schema(schema)
688            .build(futures::stream::once(async { batch }))
689            .map_err(Status::from);
690        Ok(Response::new(Box::pin(stream)))
691    }
692
693    async fn do_get_table_types(
694        &self,
695        _query: CommandGetTableTypes,
696        request: Request<Ticket>,
697    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
698        info!("do_get_table_types");
699        let (_, _) = self.new_context(request).await?;
700
701        // Report all variants of table types that datafusion uses.
702        let table_types: ArrayRef = Arc::new(StringArray::from(
703            vec![TableType::Base, TableType::View, TableType::Temporary]
704                .into_iter()
705                .map(|tt| tt.to_string())
706                .collect::<Vec<String>>(),
707        ));
708
709        let batch = RecordBatch::try_from_iter(vec![("table_type", table_types)]).unwrap();
710
711        let stream = FlightDataEncoderBuilder::new()
712            .with_schema(GET_TABLE_TYPES_SCHEMA.clone())
713            .build(futures::stream::once(async { Ok(batch) }))
714            .map_err(Status::from);
715        Ok(Response::new(Box::pin(stream)))
716    }
717
718    async fn do_get_sql_info(
719        &self,
720        _query: CommandGetSqlInfo,
721        request: Request<Ticket>,
722    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
723        info!("do_get_sql_info");
724        let (_, _) = self.new_context(request).await?;
725
726        Err(Status::unimplemented("Implement do_get_sql_info"))
727    }
728
729    async fn do_get_primary_keys(
730        &self,
731        _query: CommandGetPrimaryKeys,
732        request: Request<Ticket>,
733    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
734        info!("do_get_primary_keys");
735        let (_, _) = self.new_context(request).await?;
736
737        Err(Status::unimplemented("Implement do_get_primary_keys"))
738    }
739
740    async fn do_get_exported_keys(
741        &self,
742        _query: CommandGetExportedKeys,
743        request: Request<Ticket>,
744    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
745        info!("do_get_exported_keys");
746        let (_, _) = self.new_context(request).await?;
747
748        Err(Status::unimplemented("Implement do_get_exported_keys"))
749    }
750
751    async fn do_get_imported_keys(
752        &self,
753        _query: CommandGetImportedKeys,
754        request: Request<Ticket>,
755    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
756        info!("do_get_imported_keys");
757        let (_, _) = self.new_context(request).await?;
758
759        Err(Status::unimplemented("Implement do_get_imported_keys"))
760    }
761
762    async fn do_get_cross_reference(
763        &self,
764        _query: CommandGetCrossReference,
765        request: Request<Ticket>,
766    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
767        info!("do_get_cross_reference");
768        let (_, _) = self.new_context(request).await?;
769
770        Err(Status::unimplemented("Implement do_get_cross_reference"))
771    }
772
773    async fn do_get_xdbc_type_info(
774        &self,
775        _query: CommandGetXdbcTypeInfo,
776        request: Request<Ticket>,
777    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
778        info!("do_get_xdbc_type_info");
779        let (_, _) = self.new_context(request).await?;
780
781        Err(Status::unimplemented("Implement do_get_xdbc_type_info"))
782    }
783
784    async fn do_put_statement_update(
785        &self,
786        _ticket: CommandStatementUpdate,
787        request: Request<PeekableFlightDataStream>,
788    ) -> Result<i64, Status> {
789        info!("do_put_statement_update");
790        let (_, _) = self.new_context(request).await?;
791
792        Err(Status::unimplemented("Implement do_put_statement_update"))
793    }
794
795    async fn do_put_prepared_statement_query(
796        &self,
797        query: CommandPreparedStatementQuery,
798        request: Request<PeekableFlightDataStream>,
799    ) -> Result<DoPutPreparedStatementResult, Status> {
800        info!("do_put_prepared_statement_query");
801        let (request, _) = self.new_context(request).await?;
802
803        let mut handle = QueryHandle::try_decode(query.prepared_statement_handle)?;
804
805        info!(
806            "do_action_create_prepared_statement query={:?}",
807            handle.query()
808        );
809        // Collect request flight data as parameters
810        // Decode and encode as a single ipc stream
811        let mut decoder =
812            FlightDataDecoder::new(request.into_inner().map_err(status_to_flight_error));
813        let schema = decode_schema(&mut decoder).await?;
814        let mut parameters = Vec::new();
815        let mut encoder =
816            StreamWriter::try_new(&mut parameters, &schema).map_err(arrow_error_to_status)?;
817        let mut total_rows = 0;
818        while let Some(msg) = decoder.try_next().await? {
819            match msg.payload {
820                DecodedPayload::None => {}
821                DecodedPayload::Schema(_) => {
822                    return Err(Status::invalid_argument(
823                        "parameter flight data must contain a single schema",
824                    ));
825                }
826                DecodedPayload::RecordBatch(record_batch) => {
827                    total_rows += record_batch.num_rows();
828                    encoder
829                        .write(&record_batch)
830                        .map_err(arrow_error_to_status)?;
831                }
832            }
833        }
834        if total_rows > 1 {
835            return Err(Status::invalid_argument(
836                "parameters should contain a single row",
837            ));
838        }
839
840        handle.set_parameters(Some(parameters.into()));
841
842        let res = DoPutPreparedStatementResult {
843            prepared_statement_handle: Some(Bytes::from(handle)),
844        };
845
846        Ok(res)
847    }
848
849    async fn do_put_prepared_statement_update(
850        &self,
851        _handle: CommandPreparedStatementUpdate,
852        request: Request<PeekableFlightDataStream>,
853    ) -> Result<i64, Status> {
854        info!("do_put_prepared_statement_update");
855        let (_, _) = self.new_context(request).await?;
856
857        // statements like "CREATE TABLE.." or "SET datafusion.nnn.." call this function
858        // and we are required to return some row count here
859        Ok(-1)
860    }
861
862    async fn do_put_substrait_plan(
863        &self,
864        _query: CommandStatementSubstraitPlan,
865        request: Request<PeekableFlightDataStream>,
866    ) -> Result<i64, Status> {
867        info!("do_put_prepared_statement_update");
868        let (_, _) = self.new_context(request).await?;
869
870        Err(Status::unimplemented(
871            "Implement do_put_prepared_statement_update",
872        ))
873    }
874
875    async fn do_action_create_prepared_statement(
876        &self,
877        query: ActionCreatePreparedStatementRequest,
878        request: Request<Action>,
879    ) -> Result<ActionCreatePreparedStatementResult, Status> {
880        let (_, ctx) = self.new_context(request).await?;
881
882        let sql = query.query.clone();
883        info!(
884            "do_action_create_prepared_statement query={:?}",
885            query.query
886        );
887
888        let plan = ctx
889            .sql_to_logical_plan(sql.as_str())
890            .await
891            .map_err(df_error_to_status)?;
892
893        let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
894        let parameter_schema = parameter_schema_for_plan(&plan).map_err(|e| e.as_ref().clone())?;
895
896        let dataset_schema =
897            encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?;
898        let parameter_schema =
899            encode_schema(parameter_schema.as_ref()).map_err(arrow_error_to_status)?;
900
901        let handle = QueryHandle::new(sql, None);
902
903        let res = ActionCreatePreparedStatementResult {
904            prepared_statement_handle: Bytes::from(handle),
905            dataset_schema,
906            parameter_schema,
907        };
908
909        Ok(res)
910    }
911
912    async fn do_action_close_prepared_statement(
913        &self,
914        query: ActionClosePreparedStatementRequest,
915        request: Request<Action>,
916    ) -> Result<(), Status> {
917        let (_, _) = self.new_context(request).await?;
918
919        let handle = query.prepared_statement_handle.as_ref();
920        if let Ok(handle) = std::str::from_utf8(handle) {
921            info!("do_action_close_prepared_statement with handle {handle:?}",);
922
923            // NOP since stateless
924        }
925        Ok(())
926    }
927
928    async fn do_action_create_prepared_substrait_plan(
929        &self,
930        _query: ActionCreatePreparedSubstraitPlanRequest,
931        request: Request<Action>,
932    ) -> Result<ActionCreatePreparedStatementResult, Status> {
933        info!("do_action_create_prepared_substrait_plan");
934        let (_, _) = self.new_context(request).await?;
935
936        Err(Status::unimplemented(
937            "Implement do_action_create_prepared_substrait_plan",
938        ))
939    }
940
941    async fn do_action_begin_transaction(
942        &self,
943        _query: ActionBeginTransactionRequest,
944        request: Request<Action>,
945    ) -> Result<ActionBeginTransactionResult, Status> {
946        let (_, _) = self.new_context(request).await?;
947
948        info!("do_action_begin_transaction");
949        Err(Status::unimplemented(
950            "Implement do_action_begin_transaction",
951        ))
952    }
953
954    async fn do_action_end_transaction(
955        &self,
956        _query: ActionEndTransactionRequest,
957        request: Request<Action>,
958    ) -> Result<(), Status> {
959        info!("do_action_end_transaction");
960        let (_, _) = self.new_context(request).await?;
961
962        Err(Status::unimplemented("Implement do_action_end_transaction"))
963    }
964
965    async fn do_action_begin_savepoint(
966        &self,
967        _query: ActionBeginSavepointRequest,
968        request: Request<Action>,
969    ) -> Result<ActionBeginSavepointResult, Status> {
970        info!("do_action_begin_savepoint");
971        let (_, _) = self.new_context(request).await?;
972
973        Err(Status::unimplemented("Implement do_action_begin_savepoint"))
974    }
975
976    async fn do_action_end_savepoint(
977        &self,
978        _query: ActionEndSavepointRequest,
979        request: Request<Action>,
980    ) -> Result<(), Status> {
981        info!("do_action_end_savepoint");
982        let (_, _) = self.new_context(request).await?;
983
984        Err(Status::unimplemented("Implement do_action_end_savepoint"))
985    }
986
987    async fn do_action_cancel_query(
988        &self,
989        _query: ActionCancelQueryRequest,
990        request: Request<Action>,
991    ) -> Result<ActionCancelQueryResult, Status> {
992        info!("do_action_cancel_query");
993        let (_, _) = self.new_context(request).await?;
994
995        Err(Status::unimplemented("Implement do_action_cancel_query"))
996    }
997
998    async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
999}
1000
1001/// Takes a substrait plan serialized as [Bytes] and deserializes this to
1002/// a Datafusion [LogicalPlan]
1003async fn parse_substrait_bytes(
1004    ctx: &FlightSqlSessionContext,
1005    substrait: &Bytes,
1006) -> Result<LogicalPlan> {
1007    let substrait_plan = deserialize_bytes(substrait.to_vec())
1008        .await
1009        .map_err(df_error_to_status)?;
1010
1011    from_substrait_plan(&ctx.inner.state(), &substrait_plan)
1012        .await
1013        .map_err(df_error_to_status)
1014}
1015
1016/// Encodes the schema IPC encoded (schema_bytes)
1017fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
1018    let options = IpcWriteOptions::default();
1019
1020    // encode the schema into the correct form
1021    let message: Result<IpcMessage, ArrowError> = SchemaAsIpc::new(schema, &options).try_into();
1022
1023    let IpcMessage(schema) = message?;
1024
1025    Ok(schema)
1026}
1027
1028/// Return the schema for the specified logical plan
1029fn get_schema_for_plan(logical_plan: &LogicalPlan, with_metadata: bool) -> SchemaRef {
1030    let schema: SchemaRef = if with_metadata {
1031        // Get the DFSchema which contains table qualifiers
1032        let df_schema = logical_plan.schema();
1033
1034        // Convert to Arrow Schema and add table name metadata to fields
1035        let fields_with_metadata: Vec<_> = df_schema
1036            .iter()
1037            .map(|(qualifier, field)| {
1038                // If there's a table qualifier, add it as metadata
1039                if let Some(table_ref) = qualifier {
1040                    let mut metadata = field.metadata().clone();
1041                    metadata.insert("table_name".to_string(), table_ref.to_string());
1042                    field.as_ref().clone().with_metadata(metadata)
1043                } else {
1044                    field.as_ref().clone()
1045                }
1046            })
1047            .collect();
1048
1049        Arc::new(Schema::new_with_metadata(
1050            fields_with_metadata,
1051            df_schema.as_ref().metadata().clone(),
1052        ))
1053    } else {
1054        Arc::new(Schema::from(logical_plan.schema().as_ref()))
1055    };
1056
1057    // Use an empty FlightDataEncoder to determine the schema of the encoded flight data.
1058    // This is necessary as the schema can change based on dictionary hydration behavior.
1059    let flight_data_stream = FlightDataEncoderBuilder::new()
1060        // Inform the builder of the input stream schema
1061        .with_schema(schema)
1062        .build(futures::stream::iter([]));
1063
1064    // Retrieve the schema of the encoded data
1065    flight_data_stream
1066        .known_schema()
1067        .expect("flight data schema should be known when explicitly provided via `with_schema`")
1068}
1069
1070fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result<SchemaRef, Box<Status>> {
1071    let parameters = plan
1072        .get_parameter_types()
1073        .map_err(df_error_to_status)?
1074        .into_iter()
1075        .map(|(name, dt)| {
1076            dt.map(|dt| (name.clone(), dt)).ok_or_else(|| {
1077                Status::internal(format!(
1078                    "unable to determine type of query parameter {name}"
1079                ))
1080            })
1081        })
1082        // Collect into BTreeMap so we get a consistent order of the parameters
1083        .collect::<Result<BTreeMap<_, _>, Status>>()?;
1084
1085    let mut builder = SchemaBuilder::new();
1086    parameters
1087        .into_iter()
1088        .for_each(|(name, typ)| builder.push(Field::new(name, typ, false)));
1089    Ok(builder.finish().into())
1090}
1091
1092fn arrow_error_to_status(err: ArrowError) -> Status {
1093    Status::internal(format!("{err:?}"))
1094}
1095
1096fn flight_error_to_status(err: FlightError) -> Status {
1097    Status::internal(format!("{err:?}"))
1098}
1099
1100fn df_error_to_status(err: DataFusionError) -> Status {
1101    Status::internal(format!("{err:?}"))
1102}
1103
1104fn status_to_flight_error(status: Status) -> FlightError {
1105    FlightError::Tonic(Box::new(status))
1106}
1107
1108async fn decode_schema(decoder: &mut FlightDataDecoder) -> Result<SchemaRef, Status> {
1109    while let Some(msg) = decoder.try_next().await? {
1110        match msg.payload {
1111            DecodedPayload::None => {}
1112            DecodedPayload::Schema(schema) => {
1113                return Ok(schema);
1114            }
1115            DecodedPayload::RecordBatch(_) => {
1116                return Err(Status::invalid_argument(
1117                    "parameter flight data must have a known schema",
1118                ));
1119            }
1120        }
1121    }
1122
1123    Err(Status::invalid_argument(
1124        "parameter flight data must have a schema",
1125    ))
1126}
1127
1128// Decode parameter ipc stream as ParamValues
1129fn decode_param_values(
1130    parameters: Option<&[u8]>,
1131) -> Result<Option<ParamValues>, arrow::error::ArrowError> {
1132    parameters
1133        .map(|parameters| {
1134            let decoder = StreamReader::try_new(parameters, None)?;
1135            let schema = decoder.schema();
1136            let batches = decoder.into_iter().collect::<Result<Vec<_>, _>>()?;
1137            let batch = concat_batches(&schema, batches.iter())?;
1138            Ok(record_to_param_values(&batch)?)
1139        })
1140        .transpose()
1141}
1142
1143// Converts a record batch with a single row into ParamValues
1144fn record_to_param_values(batch: &RecordBatch) -> Result<ParamValues, DataFusionError> {
1145    let mut param_values: Vec<(String, Option<usize>, ScalarValue)> = Vec::new();
1146
1147    let mut is_list = true;
1148    for col_index in 0..batch.num_columns() {
1149        let array = batch.column(col_index);
1150        let scalar = ScalarValue::try_from_array(array, 0)?;
1151        let name = batch
1152            .schema_ref()
1153            .field(col_index)
1154            .name()
1155            .trim_start_matches('$')
1156            .to_string();
1157        let index = name.parse().ok();
1158        is_list &= index.is_some();
1159        param_values.push((name, index, scalar));
1160    }
1161    if is_list {
1162        let mut values: Vec<(Option<usize>, ScalarValue)> = param_values
1163            .into_iter()
1164            .map(|(_name, index, value)| (index, value))
1165            .collect();
1166        values.sort_by_key(|(index, _value)| *index);
1167        Ok(values
1168            .into_iter()
1169            .map(|(_index, value)| value)
1170            .collect::<Vec<ScalarValue>>()
1171            .into())
1172    } else {
1173        Ok(param_values
1174            .into_iter()
1175            .map(|(name, _index, value)| (name, value))
1176            .collect::<Vec<(String, ScalarValue)>>()
1177            .into())
1178    }
1179}