datafusion_flight_sql_server/
service.rs

1use std::{collections::BTreeMap, pin::Pin, sync::Arc};
2
3use arrow_flight::{
4    decode::{DecodedPayload, FlightDataDecoder},
5    sql::{
6        self,
7        server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
8        ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
9        ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
10        ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
11        ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
12        ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
13        CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
14        CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
15        CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
16        CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
17        CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo,
18        TicketStatementQuery,
19    },
20};
21use arrow_flight::{
22    encode::FlightDataEncoderBuilder,
23    error::FlightError,
24    flight_service_server::{FlightService, FlightServiceServer},
25    Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
26    IpcMessage, SchemaAsIpc, Ticket,
27};
28use datafusion::arrow::{
29    array::{ArrayRef, RecordBatch, StringArray},
30    compute::concat_batches,
31    datatypes::{DataType, Field, SchemaBuilder, SchemaRef},
32    error::ArrowError,
33    ipc::{
34        reader::StreamReader,
35        writer::{IpcWriteOptions, StreamWriter},
36    },
37};
38use datafusion::{
39    common::{arrow::datatypes::Schema, ParamValues},
40    dataframe::DataFrame,
41    datasource::TableType,
42    error::{DataFusionError, Result as DataFusionResult},
43    execution::context::{SQLOptions, SessionContext, SessionState},
44    logical_expr::LogicalPlan,
45    physical_plan::SendableRecordBatchStream,
46    scalar::ScalarValue,
47};
48use datafusion_substrait::{
49    logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes,
50};
51
52use futures::{Stream, StreamExt, TryStreamExt};
53use log::info;
54use once_cell::sync::Lazy;
55use prost::bytes::Bytes;
56use prost::Message;
57use tonic::transport::Server;
58use tonic::{Request, Response, Status, Streaming};
59
60use super::config::FlightSqlServiceConfig;
61use super::session::{SessionStateProvider, StaticSessionStateProvider};
62use super::state::{CommandTicket, QueryHandle};
63
64type Result<T, E = Status> = std::result::Result<T, E>;
65
66/// FlightSqlService is a basic stateless FlightSqlService implementation.
67pub struct FlightSqlService {
68    provider: Box<dyn SessionStateProvider>,
69    sql_options: Option<SQLOptions>,
70    config: FlightSqlServiceConfig,
71}
72
73impl FlightSqlService {
74    /// Creates a new FlightSqlService with a static SessionState.
75    pub fn new(state: SessionState) -> Self {
76        Self::new_with_provider(Box::new(StaticSessionStateProvider::new(state)))
77    }
78
79    /// Creates a new FlightSqlService with a SessionStateProvider.
80    pub fn new_with_provider(provider: Box<dyn SessionStateProvider>) -> Self {
81        Self {
82            provider,
83            sql_options: None,
84            config: FlightSqlServiceConfig::default(),
85        }
86    }
87
88    /// Replaces the FlightSqlServiceConfig with the provided config.
89    pub fn with_config(self, config: FlightSqlServiceConfig) -> Self {
90        Self { config, ..self }
91    }
92
93    /// Replaces the sql_options with the provided options.
94    /// These options are used to verify all SQL queries.
95    /// When None the default [`SQLOptions`] are used.
96    pub fn with_sql_options(self, sql_options: SQLOptions) -> Self {
97        Self {
98            sql_options: Some(sql_options),
99            ..self
100        }
101    }
102
103    // Federate substrait plans instead of SQL
104    // pub fn with_substrait() -> Self {
105    // TODO: Substrait federation
106    // }
107
108    // Serves straightforward on the specified address.
109    pub async fn serve(self, addr: String) -> Result<(), Box<dyn std::error::Error>> {
110        let addr = addr.parse()?;
111        info!("Listening on {addr:?}");
112
113        let svc = FlightServiceServer::new(self);
114
115        Ok(Server::builder().add_service(svc).serve(addr).await?)
116    }
117
118    pub async fn serve_with_listener(
119        self,
120        listener: std::net::TcpListener,
121    ) -> Result<(), Box<dyn std::error::Error>> {
122        info!("Listening on {}", listener.local_addr()?);
123
124        let svc = FlightServiceServer::new(self);
125        let listener = tokio::net::TcpListener::from_std(listener)?;
126
127        Ok(Server::builder()
128            .add_service(svc)
129            .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
130            .await?)
131    }
132
133    async fn new_context<T>(
134        &self,
135        request: Request<T>,
136    ) -> Result<(Request<T>, FlightSqlSessionContext)> {
137        let (metadata, extensions, msg) = request.into_parts();
138        let inspect_request = Request::from_parts(metadata, extensions, ());
139
140        let state = self.provider.new_context(&inspect_request).await?;
141        let ctx = SessionContext::new_with_state(state);
142
143        let (metadata, extensions, _) = inspect_request.into_parts();
144        Ok((
145            Request::from_parts(metadata, extensions, msg),
146            FlightSqlSessionContext {
147                inner: ctx,
148                sql_options: self.sql_options,
149            },
150        ))
151    }
152}
153
154/// The schema for GetTableTypes
155static GET_TABLE_TYPES_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
156    //TODO: Move this into arrow-flight itself, similar to the builder pattern for CommandGetCatalogs and CommandGetDbSchemas
157    Arc::new(Schema::new(vec![Field::new(
158        "table_type",
159        DataType::Utf8,
160        false,
161    )]))
162});
163
164struct FlightSqlSessionContext {
165    inner: SessionContext,
166    sql_options: Option<SQLOptions>,
167}
168
169impl FlightSqlSessionContext {
170    async fn sql_to_logical_plan(&self, sql: &str) -> DataFusionResult<LogicalPlan> {
171        let plan = self.inner.state().create_logical_plan(sql).await?;
172        let verifier = self.sql_options.unwrap_or_default();
173        verifier.verify_plan(&plan)?;
174        Ok(plan)
175    }
176
177    async fn execute_sql(&self, sql: &str) -> DataFusionResult<SendableRecordBatchStream> {
178        let plan = self.sql_to_logical_plan(sql).await?;
179        self.execute_logical_plan(plan).await
180    }
181
182    async fn execute_logical_plan(
183        &self,
184        plan: LogicalPlan,
185    ) -> DataFusionResult<SendableRecordBatchStream> {
186        self.inner
187            .execute_logical_plan(plan)
188            .await?
189            .execute_stream()
190            .await
191    }
192}
193
194#[tonic::async_trait]
195impl ArrowFlightSqlService for FlightSqlService {
196    type FlightService = FlightSqlService;
197
198    async fn do_handshake(
199        &self,
200        _request: Request<Streaming<HandshakeRequest>>,
201    ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse>> + Send>>>> {
202        info!("do_handshake");
203        // Favor middleware over handshake
204        // https://github.com/apache/arrow/issues/23836
205        // https://github.com/apache/arrow/issues/25848
206        Err(Status::unimplemented("handshake is not supported"))
207    }
208
209    async fn do_get_fallback(
210        &self,
211        request: Request<Ticket>,
212        _message: Any,
213    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
214        let (request, ctx) = self.new_context(request).await?;
215
216        let ticket = CommandTicket::try_decode(request.into_inner().ticket)
217            .map_err(flight_error_to_status)?;
218
219        match ticket.command {
220            sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => {
221                // print!("Query: {query}\n");
222
223                let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?;
224                let arrow_schema = stream.schema();
225                let arrow_stream = stream.map(|i| {
226                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
227                    Ok(batch)
228                });
229
230                let flight_data_stream = FlightDataEncoderBuilder::new()
231                    .with_schema(arrow_schema)
232                    .build(arrow_stream)
233                    .map_err(flight_error_to_status)
234                    .boxed();
235
236                Ok(Response::new(flight_data_stream))
237            }
238            sql::Command::CommandPreparedStatementQuery(CommandPreparedStatementQuery {
239                prepared_statement_handle,
240            }) => {
241                let handle = QueryHandle::try_decode(prepared_statement_handle)?;
242
243                let mut plan = ctx
244                    .sql_to_logical_plan(handle.query())
245                    .await
246                    .map_err(df_error_to_status)?;
247
248                if let Some(param_values) =
249                    decode_param_values(handle.parameters()).map_err(arrow_error_to_status)?
250                {
251                    plan = plan
252                        .with_param_values(param_values)
253                        .map_err(df_error_to_status)?;
254                }
255
256                let stream = ctx
257                    .execute_logical_plan(plan)
258                    .await
259                    .map_err(df_error_to_status)?;
260                let arrow_schema = stream.schema();
261                let arrow_stream = stream.map(|i| {
262                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
263                    Ok(batch)
264                });
265
266                let flight_data_stream = FlightDataEncoderBuilder::new()
267                    .with_schema(arrow_schema)
268                    .build(arrow_stream)
269                    .map_err(flight_error_to_status)
270                    .boxed();
271
272                Ok(Response::new(flight_data_stream))
273            }
274            sql::Command::CommandStatementSubstraitPlan(CommandStatementSubstraitPlan {
275                plan,
276                ..
277            }) => {
278                let substrait_bytes = &plan
279                    .ok_or(Status::invalid_argument(
280                        "Expected substrait plan, found None",
281                    ))?
282                    .plan;
283
284                let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
285
286                let state = ctx.inner.state();
287                let df = DataFrame::new(state, plan);
288
289                let stream = df.execute_stream().await.map_err(df_error_to_status)?;
290                let arrow_schema = stream.schema();
291                let arrow_stream = stream.map(|i| {
292                    let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
293                    Ok(batch)
294                });
295
296                let flight_data_stream = FlightDataEncoderBuilder::new()
297                    .with_schema(arrow_schema)
298                    .build(arrow_stream)
299                    .map_err(flight_error_to_status)
300                    .boxed();
301
302                Ok(Response::new(flight_data_stream))
303            }
304            _ => {
305                return Err(Status::internal(format!(
306                    "statement handle not found: {:?}",
307                    ticket.command
308                )));
309            }
310        }
311    }
312
313    async fn get_flight_info_statement(
314        &self,
315        query: CommandStatementQuery,
316        request: Request<FlightDescriptor>,
317    ) -> Result<Response<FlightInfo>> {
318        let (request, ctx) = self.new_context(request).await?;
319
320        let sql = &query.query;
321        info!("get_flight_info_statement with query={sql}");
322
323        let flight_descriptor = request.into_inner();
324
325        let plan = ctx
326            .sql_to_logical_plan(sql)
327            .await
328            .map_err(df_error_to_status)?;
329
330        let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
331
332        // Form the response ticket (that the client will pass back to DoGet)
333        let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query))
334            .try_encode()
335            .map_err(flight_error_to_status)?;
336
337        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
338
339        let flight_info = FlightInfo::new()
340            .with_endpoint(endpoint)
341            // return descriptor we were passed
342            .with_descriptor(flight_descriptor)
343            .try_with_schema(dataset_schema.as_ref())
344            .map_err(arrow_error_to_status)?;
345
346        Ok(Response::new(flight_info))
347    }
348
349    async fn get_flight_info_substrait_plan(
350        &self,
351        query: CommandStatementSubstraitPlan,
352        request: Request<FlightDescriptor>,
353    ) -> Result<Response<FlightInfo>> {
354        info!("get_flight_info_substrait_plan");
355        let (request, ctx) = self.new_context(request).await?;
356
357        let substrait_bytes = &query
358            .plan
359            .as_ref()
360            .ok_or(Status::invalid_argument(
361                "Expected substrait plan, found None",
362            ))?
363            .plan;
364
365        let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
366
367        let flight_descriptor = request.into_inner();
368
369        let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
370
371        // Form the response ticket (that the client will pass back to DoGet)
372        let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query))
373            .try_encode()
374            .map_err(flight_error_to_status)?;
375
376        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
377
378        let flight_info = FlightInfo::new()
379            .with_endpoint(endpoint)
380            // return descriptor we were passed
381            .with_descriptor(flight_descriptor)
382            .try_with_schema(dataset_schema.as_ref())
383            .map_err(arrow_error_to_status)?;
384
385        Ok(Response::new(flight_info))
386    }
387
388    async fn get_flight_info_prepared_statement(
389        &self,
390        cmd: CommandPreparedStatementQuery,
391        request: Request<FlightDescriptor>,
392    ) -> Result<Response<FlightInfo>> {
393        let (request, ctx) = self.new_context(request).await?;
394
395        let handle = QueryHandle::try_decode(cmd.prepared_statement_handle.clone())
396            .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?;
397
398        info!("get_flight_info_prepared_statement with handle={handle}");
399
400        let flight_descriptor = request.into_inner();
401
402        let sql = handle.query();
403        let plan = ctx
404            .sql_to_logical_plan(sql)
405            .await
406            .map_err(df_error_to_status)?;
407
408        let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
409
410        // Form the response ticket (that the client will pass back to DoGet)
411        let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd))
412            .try_encode()
413            .map_err(flight_error_to_status)?;
414
415        let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
416
417        let flight_info = FlightInfo::new()
418            .with_endpoint(endpoint)
419            // return descriptor we were passed
420            .with_descriptor(flight_descriptor)
421            .try_with_schema(dataset_schema.as_ref())
422            .map_err(arrow_error_to_status)?;
423
424        Ok(Response::new(flight_info))
425    }
426
427    async fn get_flight_info_catalogs(
428        &self,
429        query: CommandGetCatalogs,
430        request: Request<FlightDescriptor>,
431    ) -> Result<Response<FlightInfo>> {
432        info!("get_flight_info_catalogs");
433        let (request, _ctx) = self.new_context(request).await?;
434
435        let flight_descriptor = request.into_inner();
436        let ticket = Ticket {
437            ticket: query.as_any().encode_to_vec().into(),
438        };
439        let endpoint = FlightEndpoint::new().with_ticket(ticket);
440
441        let flight_info = FlightInfo::new()
442            .try_with_schema(&query.into_builder().schema())
443            .map_err(arrow_error_to_status)?
444            .with_endpoint(endpoint)
445            .with_descriptor(flight_descriptor);
446
447        Ok(Response::new(flight_info))
448    }
449
450    async fn get_flight_info_schemas(
451        &self,
452        query: CommandGetDbSchemas,
453        request: Request<FlightDescriptor>,
454    ) -> Result<Response<FlightInfo>> {
455        info!("get_flight_info_schemas");
456        let (request, _ctx) = self.new_context(request).await?;
457        let flight_descriptor = request.into_inner();
458        let ticket = Ticket {
459            ticket: query.as_any().encode_to_vec().into(),
460        };
461        let endpoint = FlightEndpoint::new().with_ticket(ticket);
462
463        let flight_info = FlightInfo::new()
464            .try_with_schema(&query.into_builder().schema())
465            .map_err(arrow_error_to_status)?
466            .with_endpoint(endpoint)
467            .with_descriptor(flight_descriptor);
468
469        Ok(Response::new(flight_info))
470    }
471
472    async fn get_flight_info_tables(
473        &self,
474        query: CommandGetTables,
475        request: Request<FlightDescriptor>,
476    ) -> Result<Response<FlightInfo>> {
477        info!("get_flight_info_tables");
478        let (request, _ctx) = self.new_context(request).await?;
479
480        let flight_descriptor = request.into_inner();
481        let ticket = Ticket {
482            ticket: query.as_any().encode_to_vec().into(),
483        };
484        let endpoint = FlightEndpoint::new().with_ticket(ticket);
485
486        let flight_info = FlightInfo::new()
487            .try_with_schema(&query.into_builder().schema())
488            .map_err(arrow_error_to_status)?
489            .with_endpoint(endpoint)
490            .with_descriptor(flight_descriptor);
491
492        Ok(Response::new(flight_info))
493    }
494
495    async fn get_flight_info_table_types(
496        &self,
497        query: CommandGetTableTypes,
498        request: Request<FlightDescriptor>,
499    ) -> Result<Response<FlightInfo>> {
500        info!("get_flight_info_table_types");
501        let (request, _ctx) = self.new_context(request).await?;
502
503        let flight_descriptor = request.into_inner();
504        let ticket = Ticket {
505            ticket: query.as_any().encode_to_vec().into(),
506        };
507        let endpoint = FlightEndpoint::new().with_ticket(ticket);
508
509        let flight_info = FlightInfo::new()
510            .try_with_schema(&GET_TABLE_TYPES_SCHEMA)
511            .map_err(arrow_error_to_status)?
512            .with_endpoint(endpoint)
513            .with_descriptor(flight_descriptor);
514
515        Ok(Response::new(flight_info))
516    }
517
518    async fn get_flight_info_sql_info(
519        &self,
520        _query: CommandGetSqlInfo,
521        request: Request<FlightDescriptor>,
522    ) -> Result<Response<FlightInfo>> {
523        info!("get_flight_info_sql_info");
524        let (_, _) = self.new_context(request).await?;
525
526        Err(Status::unimplemented("Implement CommandGetSqlInfo"))
527    }
528
529    async fn get_flight_info_primary_keys(
530        &self,
531        _query: CommandGetPrimaryKeys,
532        request: Request<FlightDescriptor>,
533    ) -> Result<Response<FlightInfo>> {
534        info!("get_flight_info_primary_keys");
535        let (_, _) = self.new_context(request).await?;
536
537        Err(Status::unimplemented(
538            "Implement get_flight_info_primary_keys",
539        ))
540    }
541
542    async fn get_flight_info_exported_keys(
543        &self,
544        _query: CommandGetExportedKeys,
545        request: Request<FlightDescriptor>,
546    ) -> Result<Response<FlightInfo>> {
547        info!("get_flight_info_exported_keys");
548        let (_, _) = self.new_context(request).await?;
549
550        Err(Status::unimplemented(
551            "Implement get_flight_info_exported_keys",
552        ))
553    }
554
555    async fn get_flight_info_imported_keys(
556        &self,
557        _query: CommandGetImportedKeys,
558        request: Request<FlightDescriptor>,
559    ) -> Result<Response<FlightInfo>> {
560        info!("get_flight_info_imported_keys");
561        let (_, _) = self.new_context(request).await?;
562
563        Err(Status::unimplemented(
564            "Implement get_flight_info_imported_keys",
565        ))
566    }
567
568    async fn get_flight_info_cross_reference(
569        &self,
570        _query: CommandGetCrossReference,
571        request: Request<FlightDescriptor>,
572    ) -> Result<Response<FlightInfo>> {
573        info!("get_flight_info_cross_reference");
574        let (_, _) = self.new_context(request).await?;
575
576        Err(Status::unimplemented(
577            "Implement get_flight_info_cross_reference",
578        ))
579    }
580
581    async fn get_flight_info_xdbc_type_info(
582        &self,
583        _query: CommandGetXdbcTypeInfo,
584        request: Request<FlightDescriptor>,
585    ) -> Result<Response<FlightInfo>> {
586        info!("get_flight_info_xdbc_type_info");
587        let (_, _) = self.new_context(request).await?;
588
589        Err(Status::unimplemented(
590            "Implement get_flight_info_xdbc_type_info",
591        ))
592    }
593
594    async fn do_get_statement(
595        &self,
596        _ticket: TicketStatementQuery,
597        request: Request<Ticket>,
598    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
599        info!("do_get_statement");
600        let (_, _) = self.new_context(request).await?;
601
602        Err(Status::unimplemented("Implement do_get_statement"))
603    }
604
605    async fn do_get_prepared_statement(
606        &self,
607        _query: CommandPreparedStatementQuery,
608        request: Request<Ticket>,
609    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
610        info!("do_get_prepared_statement");
611        let (_, _) = self.new_context(request).await?;
612
613        Err(Status::unimplemented("Implement do_get_prepared_statement"))
614    }
615
616    async fn do_get_catalogs(
617        &self,
618        query: CommandGetCatalogs,
619        request: Request<Ticket>,
620    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
621        info!("do_get_catalogs");
622        let (_request, ctx) = self.new_context(request).await?;
623        let catalog_names = ctx.inner.catalog_names();
624
625        let mut builder = query.into_builder();
626        for catalog_name in &catalog_names {
627            builder.append(catalog_name);
628        }
629        let schema = builder.schema();
630        let batch = builder.build();
631        let stream = FlightDataEncoderBuilder::new()
632            .with_schema(schema)
633            .build(futures::stream::once(async { batch }))
634            .map_err(Status::from);
635        Ok(Response::new(Box::pin(stream)))
636    }
637
638    async fn do_get_schemas(
639        &self,
640        query: CommandGetDbSchemas,
641        request: Request<Ticket>,
642    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
643        info!("do_get_schemas");
644        let (_request, ctx) = self.new_context(request).await?;
645        let catalog_name = query.catalog.clone();
646        // Append all schemas to builder, the builder handles applying the filters.
647        let mut builder = query.into_builder();
648        if let Some(catalog_name) = &catalog_name {
649            if let Some(catalog) = ctx.inner.catalog(catalog_name) {
650                for schema_name in &catalog.schema_names() {
651                    builder.append(catalog_name, schema_name);
652                }
653            }
654        };
655
656        let schema = builder.schema();
657        let batch = builder.build();
658        let stream = FlightDataEncoderBuilder::new()
659            .with_schema(schema)
660            .build(futures::stream::once(async { batch }))
661            .map_err(Status::from);
662        Ok(Response::new(Box::pin(stream)))
663    }
664
665    async fn do_get_tables(
666        &self,
667        query: CommandGetTables,
668        request: Request<Ticket>,
669    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
670        info!("do_get_tables");
671        let (_request, ctx) = self.new_context(request).await?;
672        let catalog_name = query.catalog.clone();
673        let mut builder = query.into_builder();
674        // Append all schemas/tables to builder, the builder handles applying the filters.
675        if let Some(catalog_name) = &catalog_name {
676            if let Some(catalog) = ctx.inner.catalog(catalog_name) {
677                for schema_name in &catalog.schema_names() {
678                    if let Some(schema) = catalog.schema(schema_name) {
679                        for table_name in &schema.table_names() {
680                            if let Some(table) =
681                                schema.table(table_name).await.map_err(df_error_to_status)?
682                            {
683                                builder
684                                    .append(
685                                        catalog_name,
686                                        schema_name,
687                                        table_name,
688                                        table.table_type().to_string(),
689                                        &table.schema(),
690                                    )
691                                    .map_err(flight_error_to_status)?;
692                            }
693                        }
694                    }
695                }
696            }
697        };
698
699        let schema = builder.schema();
700        let batch = builder.build();
701        let stream = FlightDataEncoderBuilder::new()
702            .with_schema(schema)
703            .build(futures::stream::once(async { batch }))
704            .map_err(Status::from);
705        Ok(Response::new(Box::pin(stream)))
706    }
707
708    async fn do_get_table_types(
709        &self,
710        _query: CommandGetTableTypes,
711        request: Request<Ticket>,
712    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
713        info!("do_get_table_types");
714        let (_, _) = self.new_context(request).await?;
715
716        // Report all variants of table types that datafusion uses.
717        let table_types: ArrayRef = Arc::new(StringArray::from(
718            vec![TableType::Base, TableType::View, TableType::Temporary]
719                .into_iter()
720                .map(|tt| tt.to_string())
721                .collect::<Vec<String>>(),
722        ));
723
724        let batch = RecordBatch::try_from_iter(vec![("table_type", table_types)]).unwrap();
725
726        let stream = FlightDataEncoderBuilder::new()
727            .with_schema(GET_TABLE_TYPES_SCHEMA.clone())
728            .build(futures::stream::once(async { Ok(batch) }))
729            .map_err(Status::from);
730        Ok(Response::new(Box::pin(stream)))
731    }
732
733    async fn do_get_sql_info(
734        &self,
735        _query: CommandGetSqlInfo,
736        request: Request<Ticket>,
737    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
738        info!("do_get_sql_info");
739        let (_, _) = self.new_context(request).await?;
740
741        Err(Status::unimplemented("Implement do_get_sql_info"))
742    }
743
744    async fn do_get_primary_keys(
745        &self,
746        _query: CommandGetPrimaryKeys,
747        request: Request<Ticket>,
748    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
749        info!("do_get_primary_keys");
750        let (_, _) = self.new_context(request).await?;
751
752        Err(Status::unimplemented("Implement do_get_primary_keys"))
753    }
754
755    async fn do_get_exported_keys(
756        &self,
757        _query: CommandGetExportedKeys,
758        request: Request<Ticket>,
759    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
760        info!("do_get_exported_keys");
761        let (_, _) = self.new_context(request).await?;
762
763        Err(Status::unimplemented("Implement do_get_exported_keys"))
764    }
765
766    async fn do_get_imported_keys(
767        &self,
768        _query: CommandGetImportedKeys,
769        request: Request<Ticket>,
770    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
771        info!("do_get_imported_keys");
772        let (_, _) = self.new_context(request).await?;
773
774        Err(Status::unimplemented("Implement do_get_imported_keys"))
775    }
776
777    async fn do_get_cross_reference(
778        &self,
779        _query: CommandGetCrossReference,
780        request: Request<Ticket>,
781    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
782        info!("do_get_cross_reference");
783        let (_, _) = self.new_context(request).await?;
784
785        Err(Status::unimplemented("Implement do_get_cross_reference"))
786    }
787
788    async fn do_get_xdbc_type_info(
789        &self,
790        _query: CommandGetXdbcTypeInfo,
791        request: Request<Ticket>,
792    ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
793        info!("do_get_xdbc_type_info");
794        let (_, _) = self.new_context(request).await?;
795
796        Err(Status::unimplemented("Implement do_get_xdbc_type_info"))
797    }
798
799    async fn do_put_statement_update(
800        &self,
801        _ticket: CommandStatementUpdate,
802        request: Request<PeekableFlightDataStream>,
803    ) -> Result<i64, Status> {
804        info!("do_put_statement_update");
805        let (_, _) = self.new_context(request).await?;
806
807        Err(Status::unimplemented("Implement do_put_statement_update"))
808    }
809
810    async fn do_put_prepared_statement_query(
811        &self,
812        query: CommandPreparedStatementQuery,
813        request: Request<PeekableFlightDataStream>,
814    ) -> Result<DoPutPreparedStatementResult, Status> {
815        info!("do_put_prepared_statement_query");
816        let (request, _) = self.new_context(request).await?;
817
818        let mut handle = QueryHandle::try_decode(query.prepared_statement_handle)?;
819
820        info!(
821            "do_action_create_prepared_statement query={:?}",
822            handle.query()
823        );
824        // Collect request flight data as parameters
825        // Decode and encode as a single ipc stream
826        let mut decoder =
827            FlightDataDecoder::new(request.into_inner().map_err(status_to_flight_error));
828        let schema = decode_schema(&mut decoder).await?;
829        let mut parameters = Vec::new();
830        let mut encoder =
831            StreamWriter::try_new(&mut parameters, &schema).map_err(arrow_error_to_status)?;
832        let mut total_rows = 0;
833        while let Some(msg) = decoder.try_next().await? {
834            match msg.payload {
835                DecodedPayload::None => {}
836                DecodedPayload::Schema(_) => {
837                    return Err(Status::invalid_argument(
838                        "parameter flight data must contain a single schema",
839                    ));
840                }
841                DecodedPayload::RecordBatch(record_batch) => {
842                    total_rows += record_batch.num_rows();
843                    encoder
844                        .write(&record_batch)
845                        .map_err(arrow_error_to_status)?;
846                }
847            }
848        }
849        if total_rows > 1 {
850            return Err(Status::invalid_argument(
851                "parameters should contain a single row",
852            ));
853        }
854
855        handle.set_parameters(Some(parameters.into()));
856
857        let res = DoPutPreparedStatementResult {
858            prepared_statement_handle: Some(Bytes::from(handle)),
859        };
860
861        Ok(res)
862    }
863
864    async fn do_put_prepared_statement_update(
865        &self,
866        _handle: CommandPreparedStatementUpdate,
867        request: Request<PeekableFlightDataStream>,
868    ) -> Result<i64, Status> {
869        info!("do_put_prepared_statement_update");
870        let (_, _) = self.new_context(request).await?;
871
872        // statements like "CREATE TABLE.." or "SET datafusion.nnn.." call this function
873        // and we are required to return some row count here
874        Ok(-1)
875    }
876
877    async fn do_put_substrait_plan(
878        &self,
879        _query: CommandStatementSubstraitPlan,
880        request: Request<PeekableFlightDataStream>,
881    ) -> Result<i64, Status> {
882        info!("do_put_prepared_statement_update");
883        let (_, _) = self.new_context(request).await?;
884
885        Err(Status::unimplemented(
886            "Implement do_put_prepared_statement_update",
887        ))
888    }
889
890    async fn do_action_create_prepared_statement(
891        &self,
892        query: ActionCreatePreparedStatementRequest,
893        request: Request<Action>,
894    ) -> Result<ActionCreatePreparedStatementResult, Status> {
895        let (_, ctx) = self.new_context(request).await?;
896
897        let sql = query.query.clone();
898        info!(
899            "do_action_create_prepared_statement query={:?}",
900            query.query
901        );
902
903        let plan = ctx
904            .sql_to_logical_plan(sql.as_str())
905            .await
906            .map_err(df_error_to_status)?;
907
908        let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
909        let parameter_schema = parameter_schema_for_plan(&plan).map_err(|e| e.as_ref().clone())?;
910
911        let dataset_schema =
912            encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?;
913        let parameter_schema =
914            encode_schema(parameter_schema.as_ref()).map_err(arrow_error_to_status)?;
915
916        let handle = QueryHandle::new(sql, None);
917
918        let res = ActionCreatePreparedStatementResult {
919            prepared_statement_handle: Bytes::from(handle),
920            dataset_schema,
921            parameter_schema,
922        };
923
924        Ok(res)
925    }
926
927    async fn do_action_close_prepared_statement(
928        &self,
929        query: ActionClosePreparedStatementRequest,
930        request: Request<Action>,
931    ) -> Result<(), Status> {
932        let (_, _) = self.new_context(request).await?;
933
934        let handle = query.prepared_statement_handle.as_ref();
935        if let Ok(handle) = std::str::from_utf8(handle) {
936            info!("do_action_close_prepared_statement with handle {handle:?}",);
937
938            // NOP since stateless
939        }
940        Ok(())
941    }
942
943    async fn do_action_create_prepared_substrait_plan(
944        &self,
945        _query: ActionCreatePreparedSubstraitPlanRequest,
946        request: Request<Action>,
947    ) -> Result<ActionCreatePreparedStatementResult, Status> {
948        info!("do_action_create_prepared_substrait_plan");
949        let (_, _) = self.new_context(request).await?;
950
951        Err(Status::unimplemented(
952            "Implement do_action_create_prepared_substrait_plan",
953        ))
954    }
955
956    async fn do_action_begin_transaction(
957        &self,
958        _query: ActionBeginTransactionRequest,
959        request: Request<Action>,
960    ) -> Result<ActionBeginTransactionResult, Status> {
961        let (_, _) = self.new_context(request).await?;
962
963        info!("do_action_begin_transaction");
964        Err(Status::unimplemented(
965            "Implement do_action_begin_transaction",
966        ))
967    }
968
969    async fn do_action_end_transaction(
970        &self,
971        _query: ActionEndTransactionRequest,
972        request: Request<Action>,
973    ) -> Result<(), Status> {
974        info!("do_action_end_transaction");
975        let (_, _) = self.new_context(request).await?;
976
977        Err(Status::unimplemented("Implement do_action_end_transaction"))
978    }
979
980    async fn do_action_begin_savepoint(
981        &self,
982        _query: ActionBeginSavepointRequest,
983        request: Request<Action>,
984    ) -> Result<ActionBeginSavepointResult, Status> {
985        info!("do_action_begin_savepoint");
986        let (_, _) = self.new_context(request).await?;
987
988        Err(Status::unimplemented("Implement do_action_begin_savepoint"))
989    }
990
991    async fn do_action_end_savepoint(
992        &self,
993        _query: ActionEndSavepointRequest,
994        request: Request<Action>,
995    ) -> Result<(), Status> {
996        info!("do_action_end_savepoint");
997        let (_, _) = self.new_context(request).await?;
998
999        Err(Status::unimplemented("Implement do_action_end_savepoint"))
1000    }
1001
1002    async fn do_action_cancel_query(
1003        &self,
1004        _query: ActionCancelQueryRequest,
1005        request: Request<Action>,
1006    ) -> Result<ActionCancelQueryResult, Status> {
1007        info!("do_action_cancel_query");
1008        let (_, _) = self.new_context(request).await?;
1009
1010        Err(Status::unimplemented("Implement do_action_cancel_query"))
1011    }
1012
1013    async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
1014}
1015
1016/// Takes a substrait plan serialized as [Bytes] and deserializes this to
1017/// a Datafusion [LogicalPlan]
1018async fn parse_substrait_bytes(
1019    ctx: &FlightSqlSessionContext,
1020    substrait: &Bytes,
1021) -> Result<LogicalPlan> {
1022    let substrait_plan = deserialize_bytes(substrait.to_vec())
1023        .await
1024        .map_err(df_error_to_status)?;
1025
1026    from_substrait_plan(&ctx.inner.state(), &substrait_plan)
1027        .await
1028        .map_err(df_error_to_status)
1029}
1030
1031/// Encodes the schema IPC encoded (schema_bytes)
1032fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
1033    let options = IpcWriteOptions::default();
1034
1035    // encode the schema into the correct form
1036    let message: Result<IpcMessage, ArrowError> = SchemaAsIpc::new(schema, &options).try_into();
1037
1038    let IpcMessage(schema) = message?;
1039
1040    Ok(schema)
1041}
1042
1043/// Return the schema for the specified logical plan
1044fn get_schema_for_plan(logical_plan: &LogicalPlan, with_metadata: bool) -> SchemaRef {
1045    let schema: SchemaRef = if with_metadata {
1046        // Get the DFSchema which contains table qualifiers
1047        let df_schema = logical_plan.schema();
1048
1049        // Convert to Arrow Schema and add table name metadata to fields
1050        let fields_with_metadata: Vec<_> = df_schema
1051            .iter()
1052            .map(|(qualifier, field)| {
1053                // If there's a table qualifier, add it as metadata
1054                if let Some(table_ref) = qualifier {
1055                    let mut metadata = field.metadata().clone();
1056                    metadata.insert("table_name".to_string(), table_ref.to_string());
1057                    field.as_ref().clone().with_metadata(metadata)
1058                } else {
1059                    field.as_ref().clone()
1060                }
1061            })
1062            .collect();
1063
1064        Arc::new(Schema::new_with_metadata(
1065            fields_with_metadata,
1066            df_schema.as_ref().metadata().clone(),
1067        ))
1068    } else {
1069        Arc::new(logical_plan.schema().as_arrow().clone())
1070    };
1071
1072    // Use an empty FlightDataEncoder to determine the schema of the encoded flight data.
1073    // This is necessary as the schema can change based on dictionary hydration behavior.
1074    let flight_data_stream = FlightDataEncoderBuilder::new()
1075        // Inform the builder of the input stream schema
1076        .with_schema(schema)
1077        .build(futures::stream::iter([]));
1078
1079    // Retrieve the schema of the encoded data
1080    flight_data_stream
1081        .known_schema()
1082        .expect("flight data schema should be known when explicitly provided via `with_schema`")
1083}
1084
1085fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result<SchemaRef, Box<Status>> {
1086    let parameters = plan
1087        .get_parameter_types()
1088        .map_err(df_error_to_status)?
1089        .into_iter()
1090        .map(|(name, dt)| {
1091            dt.map(|dt| (name.clone(), dt)).ok_or_else(|| {
1092                Status::internal(format!(
1093                    "unable to determine type of query parameter {name}"
1094                ))
1095            })
1096        })
1097        // Collect into BTreeMap so we get a consistent order of the parameters
1098        .collect::<Result<BTreeMap<_, _>, Status>>()?;
1099
1100    let mut builder = SchemaBuilder::new();
1101    parameters
1102        .into_iter()
1103        .for_each(|(name, typ)| builder.push(Field::new(name, typ, false)));
1104    Ok(builder.finish().into())
1105}
1106
1107fn arrow_error_to_status(err: ArrowError) -> Status {
1108    Status::internal(format!("{err:?}"))
1109}
1110
1111fn flight_error_to_status(err: FlightError) -> Status {
1112    Status::internal(format!("{err:?}"))
1113}
1114
1115fn df_error_to_status(err: DataFusionError) -> Status {
1116    Status::internal(format!("{err:?}"))
1117}
1118
1119fn status_to_flight_error(status: Status) -> FlightError {
1120    FlightError::Tonic(Box::new(status))
1121}
1122
1123async fn decode_schema(decoder: &mut FlightDataDecoder) -> Result<SchemaRef, Status> {
1124    while let Some(msg) = decoder.try_next().await? {
1125        match msg.payload {
1126            DecodedPayload::None => {}
1127            DecodedPayload::Schema(schema) => {
1128                return Ok(schema);
1129            }
1130            DecodedPayload::RecordBatch(_) => {
1131                return Err(Status::invalid_argument(
1132                    "parameter flight data must have a known schema",
1133                ));
1134            }
1135        }
1136    }
1137
1138    Err(Status::invalid_argument(
1139        "parameter flight data must have a schema",
1140    ))
1141}
1142
1143// Decode parameter ipc stream as ParamValues
1144fn decode_param_values(parameters: Option<&[u8]>) -> Result<Option<ParamValues>, ArrowError> {
1145    parameters
1146        .map(|parameters| {
1147            let decoder = StreamReader::try_new(parameters, None)?;
1148            let schema = decoder.schema();
1149            let batches = decoder.into_iter().collect::<Result<Vec<_>, _>>()?;
1150            let batch = concat_batches(&schema, batches.iter())?;
1151            Ok(record_to_param_values(&batch)?)
1152        })
1153        .transpose()
1154}
1155
1156// Converts a record batch with a single row into ParamValues
1157fn record_to_param_values(batch: &RecordBatch) -> Result<ParamValues, DataFusionError> {
1158    let mut param_values: Vec<(String, Option<usize>, ScalarValue)> = Vec::new();
1159
1160    let mut is_list = true;
1161    for col_index in 0..batch.num_columns() {
1162        let array = batch.column(col_index);
1163        let scalar = ScalarValue::try_from_array(array, 0)?;
1164        let name = batch
1165            .schema_ref()
1166            .field(col_index)
1167            .name()
1168            .trim_start_matches('$')
1169            .to_string();
1170        let index = name.parse().ok();
1171        is_list &= index.is_some();
1172        param_values.push((name, index, scalar));
1173    }
1174    if is_list {
1175        let mut values: Vec<(Option<usize>, ScalarValue)> = param_values
1176            .into_iter()
1177            .map(|(_name, index, value)| (index, value))
1178            .collect();
1179        values.sort_by_key(|(index, _value)| *index);
1180        Ok(values
1181            .into_iter()
1182            .map(|(_index, value)| value)
1183            .collect::<Vec<ScalarValue>>()
1184            .into())
1185    } else {
1186        Ok(param_values
1187            .into_iter()
1188            .map(|(name, _index, value)| (name, value))
1189            .collect::<Vec<(String, ScalarValue)>>()
1190            .into())
1191    }
1192}