1use std::{collections::BTreeMap, pin::Pin, sync::Arc};
2
3use arrow::{
4 array::{ArrayRef, RecordBatch, StringArray},
5 compute::concat_batches,
6 datatypes::{DataType, Field, SchemaBuilder, SchemaRef},
7 error::ArrowError,
8 ipc::{
9 reader::StreamReader,
10 writer::{IpcWriteOptions, StreamWriter},
11 },
12};
13use arrow_flight::{
14 decode::{DecodedPayload, FlightDataDecoder},
15 sql::{
16 self,
17 server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
18 ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
19 ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
20 ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
21 ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
22 ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
23 CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
24 CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
25 CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
26 CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
27 CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo,
28 TicketStatementQuery,
29 },
30};
31use arrow_flight::{
32 encode::FlightDataEncoderBuilder,
33 error::FlightError,
34 flight_service_server::{FlightService, FlightServiceServer},
35 Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
36 IpcMessage, SchemaAsIpc, Ticket,
37};
38use datafusion::{
39 common::{arrow::datatypes::Schema, ParamValues},
40 dataframe::DataFrame,
41 datasource::TableType,
42 error::{DataFusionError, Result as DataFusionResult},
43 execution::context::{SQLOptions, SessionContext, SessionState},
44 logical_expr::LogicalPlan,
45 physical_plan::SendableRecordBatchStream,
46 scalar::ScalarValue,
47};
48use datafusion_substrait::{
49 logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes,
50};
51use futures::{Stream, StreamExt, TryStreamExt};
52use log::info;
53use once_cell::sync::Lazy;
54use prost::bytes::Bytes;
55use prost::Message;
56use tonic::transport::Server;
57use tonic::{Request, Response, Status, Streaming};
58
59use super::session::{SessionStateProvider, StaticSessionStateProvider};
60use super::state::{CommandTicket, QueryHandle};
61
62type Result<T, E = Status> = std::result::Result<T, E>;
63
64pub struct FlightSqlService {
66 provider: Box<dyn SessionStateProvider>,
67 sql_options: Option<SQLOptions>,
68}
69
70impl FlightSqlService {
71 pub fn new(state: SessionState) -> Self {
73 Self::new_with_provider(Box::new(StaticSessionStateProvider::new(state)))
74 }
75
76 pub fn new_with_provider(provider: Box<dyn SessionStateProvider>) -> Self {
78 Self {
79 provider,
80 sql_options: None,
81 }
82 }
83
84 pub fn with_sql_options(self, sql_options: SQLOptions) -> Self {
88 Self {
89 sql_options: Some(sql_options),
90 ..self
91 }
92 }
93
94 pub async fn serve(self, addr: String) -> Result<(), Box<dyn std::error::Error>> {
101 let addr = addr.parse()?;
102 info!("Listening on {addr:?}");
103
104 let svc = FlightServiceServer::new(self);
105
106 Ok(Server::builder().add_service(svc).serve(addr).await?)
107 }
108
109 async fn new_context<T>(
110 &self,
111 request: Request<T>,
112 ) -> Result<(Request<T>, FlightSqlSessionContext)> {
113 let (metadata, extensions, msg) = request.into_parts();
114 let inspect_request = Request::from_parts(metadata, extensions, ());
115
116 let state = self.provider.new_context(&inspect_request).await?;
117 let ctx = SessionContext::new_with_state(state);
118
119 let (metadata, extensions, _) = inspect_request.into_parts();
120 Ok((
121 Request::from_parts(metadata, extensions, msg),
122 FlightSqlSessionContext {
123 inner: ctx,
124 sql_options: self.sql_options,
125 },
126 ))
127 }
128}
129
130static GET_TABLE_TYPES_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
132 Arc::new(Schema::new(vec![Field::new(
134 "table_type",
135 DataType::Utf8,
136 false,
137 )]))
138});
139
140struct FlightSqlSessionContext {
141 inner: SessionContext,
142 sql_options: Option<SQLOptions>,
143}
144
145impl FlightSqlSessionContext {
146 async fn sql_to_logical_plan(&self, sql: &str) -> DataFusionResult<LogicalPlan> {
147 let plan = self.inner.state().create_logical_plan(sql).await?;
148 let verifier = self.sql_options.unwrap_or_default();
149 verifier.verify_plan(&plan)?;
150 Ok(plan)
151 }
152
153 async fn execute_sql(&self, sql: &str) -> DataFusionResult<SendableRecordBatchStream> {
154 let plan = self.sql_to_logical_plan(sql).await?;
155 self.execute_logical_plan(plan).await
156 }
157
158 async fn execute_logical_plan(
159 &self,
160 plan: LogicalPlan,
161 ) -> DataFusionResult<SendableRecordBatchStream> {
162 self.inner
163 .execute_logical_plan(plan)
164 .await?
165 .execute_stream()
166 .await
167 }
168}
169
170#[tonic::async_trait]
171impl ArrowFlightSqlService for FlightSqlService {
172 type FlightService = FlightSqlService;
173
174 async fn do_handshake(
175 &self,
176 _request: Request<Streaming<HandshakeRequest>>,
177 ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse>> + Send>>>> {
178 info!("do_handshake");
179 Err(Status::unimplemented("handshake is not supported"))
183 }
184
185 async fn do_get_fallback(
186 &self,
187 request: Request<Ticket>,
188 _message: Any,
189 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
190 let (request, ctx) = self.new_context(request).await?;
191
192 let ticket = CommandTicket::try_decode(request.into_inner().ticket)
193 .map_err(flight_error_to_status)?;
194
195 match ticket.command {
196 sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => {
197 let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?;
200 let arrow_schema = stream.schema();
201 let arrow_stream = stream.map(|i| {
202 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
203 Ok(batch)
204 });
205
206 let flight_data_stream = FlightDataEncoderBuilder::new()
207 .with_schema(arrow_schema)
208 .build(arrow_stream)
209 .map_err(flight_error_to_status)
210 .boxed();
211
212 Ok(Response::new(flight_data_stream))
213 }
214 sql::Command::CommandPreparedStatementQuery(CommandPreparedStatementQuery {
215 prepared_statement_handle,
216 }) => {
217 let handle = QueryHandle::try_decode(prepared_statement_handle)?;
218
219 let mut plan = ctx
220 .sql_to_logical_plan(handle.query())
221 .await
222 .map_err(df_error_to_status)?;
223
224 if let Some(param_values) =
225 decode_param_values(handle.parameters()).map_err(arrow_error_to_status)?
226 {
227 plan = plan
228 .with_param_values(param_values)
229 .map_err(df_error_to_status)?;
230 }
231
232 let stream = ctx
233 .execute_logical_plan(plan)
234 .await
235 .map_err(df_error_to_status)?;
236 let arrow_schema = stream.schema();
237 let arrow_stream = stream.map(|i| {
238 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
239 Ok(batch)
240 });
241
242 let flight_data_stream = FlightDataEncoderBuilder::new()
243 .with_schema(arrow_schema)
244 .build(arrow_stream)
245 .map_err(flight_error_to_status)
246 .boxed();
247
248 Ok(Response::new(flight_data_stream))
249 }
250 sql::Command::CommandStatementSubstraitPlan(CommandStatementSubstraitPlan {
251 plan,
252 ..
253 }) => {
254 let substrait_bytes = &plan
255 .ok_or(Status::invalid_argument(
256 "Expected substrait plan, found None",
257 ))?
258 .plan;
259
260 let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
261
262 let state = ctx.inner.state();
263 let df = DataFrame::new(state, plan);
264
265 let stream = df.execute_stream().await.map_err(df_error_to_status)?;
266 let arrow_schema = stream.schema();
267 let arrow_stream = stream.map(|i| {
268 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
269 Ok(batch)
270 });
271
272 let flight_data_stream = FlightDataEncoderBuilder::new()
273 .with_schema(arrow_schema)
274 .build(arrow_stream)
275 .map_err(flight_error_to_status)
276 .boxed();
277
278 Ok(Response::new(flight_data_stream))
279 }
280 _ => {
281 return Err(Status::internal(format!(
282 "statement handle not found: {:?}",
283 ticket.command
284 )));
285 }
286 }
287 }
288
289 async fn get_flight_info_statement(
290 &self,
291 query: CommandStatementQuery,
292 request: Request<FlightDescriptor>,
293 ) -> Result<Response<FlightInfo>> {
294 let (request, ctx) = self.new_context(request).await?;
295
296 let sql = &query.query;
297 info!("get_flight_info_statement with query={sql}");
298
299 let flight_descriptor = request.into_inner();
300
301 let plan = ctx
302 .sql_to_logical_plan(sql)
303 .await
304 .map_err(df_error_to_status)?;
305
306 let dataset_schema = get_schema_for_plan(&plan);
307
308 let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query))
310 .try_encode()
311 .map_err(flight_error_to_status)?;
312
313 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
314
315 let flight_info = FlightInfo::new()
316 .with_endpoint(endpoint)
317 .with_descriptor(flight_descriptor)
319 .try_with_schema(dataset_schema.as_ref())
320 .map_err(arrow_error_to_status)?;
321
322 Ok(Response::new(flight_info))
323 }
324
325 async fn get_flight_info_substrait_plan(
326 &self,
327 query: CommandStatementSubstraitPlan,
328 request: Request<FlightDescriptor>,
329 ) -> Result<Response<FlightInfo>> {
330 info!("get_flight_info_substrait_plan");
331 let (request, ctx) = self.new_context(request).await?;
332
333 let substrait_bytes = &query
334 .plan
335 .as_ref()
336 .ok_or(Status::invalid_argument(
337 "Expected substrait plan, found None",
338 ))?
339 .plan;
340
341 let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
342
343 let flight_descriptor = request.into_inner();
344
345 let dataset_schema = get_schema_for_plan(&plan);
346
347 let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query))
349 .try_encode()
350 .map_err(flight_error_to_status)?;
351
352 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
353
354 let flight_info = FlightInfo::new()
355 .with_endpoint(endpoint)
356 .with_descriptor(flight_descriptor)
358 .try_with_schema(dataset_schema.as_ref())
359 .map_err(arrow_error_to_status)?;
360
361 Ok(Response::new(flight_info))
362 }
363
364 async fn get_flight_info_prepared_statement(
365 &self,
366 cmd: CommandPreparedStatementQuery,
367 request: Request<FlightDescriptor>,
368 ) -> Result<Response<FlightInfo>> {
369 let (request, ctx) = self.new_context(request).await?;
370
371 let handle = QueryHandle::try_decode(cmd.prepared_statement_handle.clone())
372 .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?;
373
374 info!("get_flight_info_prepared_statement with handle={handle}");
375
376 let flight_descriptor = request.into_inner();
377
378 let sql = handle.query();
379 let plan = ctx
380 .sql_to_logical_plan(sql)
381 .await
382 .map_err(df_error_to_status)?;
383
384 let dataset_schema = get_schema_for_plan(&plan);
385
386 let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd))
388 .try_encode()
389 .map_err(flight_error_to_status)?;
390
391 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
392
393 let flight_info = FlightInfo::new()
394 .with_endpoint(endpoint)
395 .with_descriptor(flight_descriptor)
397 .try_with_schema(dataset_schema.as_ref())
398 .map_err(arrow_error_to_status)?;
399
400 Ok(Response::new(flight_info))
401 }
402
403 async fn get_flight_info_catalogs(
404 &self,
405 query: CommandGetCatalogs,
406 request: Request<FlightDescriptor>,
407 ) -> Result<Response<FlightInfo>> {
408 info!("get_flight_info_catalogs");
409 let (request, _ctx) = self.new_context(request).await?;
410
411 let flight_descriptor = request.into_inner();
412 let ticket = Ticket {
413 ticket: query.as_any().encode_to_vec().into(),
414 };
415 let endpoint = FlightEndpoint::new().with_ticket(ticket);
416
417 let flight_info = FlightInfo::new()
418 .try_with_schema(&query.into_builder().schema())
419 .map_err(arrow_error_to_status)?
420 .with_endpoint(endpoint)
421 .with_descriptor(flight_descriptor);
422
423 Ok(Response::new(flight_info))
424 }
425
426 async fn get_flight_info_schemas(
427 &self,
428 query: CommandGetDbSchemas,
429 request: Request<FlightDescriptor>,
430 ) -> Result<Response<FlightInfo>> {
431 info!("get_flight_info_schemas");
432 let (request, _ctx) = self.new_context(request).await?;
433 let flight_descriptor = request.into_inner();
434 let ticket = Ticket {
435 ticket: query.as_any().encode_to_vec().into(),
436 };
437 let endpoint = FlightEndpoint::new().with_ticket(ticket);
438
439 let flight_info = FlightInfo::new()
440 .try_with_schema(&query.into_builder().schema())
441 .map_err(arrow_error_to_status)?
442 .with_endpoint(endpoint)
443 .with_descriptor(flight_descriptor);
444
445 Ok(Response::new(flight_info))
446 }
447
448 async fn get_flight_info_tables(
449 &self,
450 query: CommandGetTables,
451 request: Request<FlightDescriptor>,
452 ) -> Result<Response<FlightInfo>> {
453 info!("get_flight_info_tables");
454 let (request, _ctx) = self.new_context(request).await?;
455
456 let flight_descriptor = request.into_inner();
457 let ticket = Ticket {
458 ticket: query.as_any().encode_to_vec().into(),
459 };
460 let endpoint = FlightEndpoint::new().with_ticket(ticket);
461
462 let flight_info = FlightInfo::new()
463 .try_with_schema(&query.into_builder().schema())
464 .map_err(arrow_error_to_status)?
465 .with_endpoint(endpoint)
466 .with_descriptor(flight_descriptor);
467
468 Ok(Response::new(flight_info))
469 }
470
471 async fn get_flight_info_table_types(
472 &self,
473 query: CommandGetTableTypes,
474 request: Request<FlightDescriptor>,
475 ) -> Result<Response<FlightInfo>> {
476 info!("get_flight_info_table_types");
477 let (request, _ctx) = self.new_context(request).await?;
478
479 let flight_descriptor = request.into_inner();
480 let ticket = Ticket {
481 ticket: query.as_any().encode_to_vec().into(),
482 };
483 let endpoint = FlightEndpoint::new().with_ticket(ticket);
484
485 let flight_info = FlightInfo::new()
486 .try_with_schema(&GET_TABLE_TYPES_SCHEMA)
487 .map_err(arrow_error_to_status)?
488 .with_endpoint(endpoint)
489 .with_descriptor(flight_descriptor);
490
491 Ok(Response::new(flight_info))
492 }
493
494 async fn get_flight_info_sql_info(
495 &self,
496 _query: CommandGetSqlInfo,
497 request: Request<FlightDescriptor>,
498 ) -> Result<Response<FlightInfo>> {
499 info!("get_flight_info_sql_info");
500 let (_, _) = self.new_context(request).await?;
501
502 Err(Status::unimplemented("Implement CommandGetSqlInfo"))
503 }
504
505 async fn get_flight_info_primary_keys(
506 &self,
507 _query: CommandGetPrimaryKeys,
508 request: Request<FlightDescriptor>,
509 ) -> Result<Response<FlightInfo>> {
510 info!("get_flight_info_primary_keys");
511 let (_, _) = self.new_context(request).await?;
512
513 Err(Status::unimplemented(
514 "Implement get_flight_info_primary_keys",
515 ))
516 }
517
518 async fn get_flight_info_exported_keys(
519 &self,
520 _query: CommandGetExportedKeys,
521 request: Request<FlightDescriptor>,
522 ) -> Result<Response<FlightInfo>> {
523 info!("get_flight_info_exported_keys");
524 let (_, _) = self.new_context(request).await?;
525
526 Err(Status::unimplemented(
527 "Implement get_flight_info_exported_keys",
528 ))
529 }
530
531 async fn get_flight_info_imported_keys(
532 &self,
533 _query: CommandGetImportedKeys,
534 request: Request<FlightDescriptor>,
535 ) -> Result<Response<FlightInfo>> {
536 info!("get_flight_info_imported_keys");
537 let (_, _) = self.new_context(request).await?;
538
539 Err(Status::unimplemented(
540 "Implement get_flight_info_imported_keys",
541 ))
542 }
543
544 async fn get_flight_info_cross_reference(
545 &self,
546 _query: CommandGetCrossReference,
547 request: Request<FlightDescriptor>,
548 ) -> Result<Response<FlightInfo>> {
549 info!("get_flight_info_cross_reference");
550 let (_, _) = self.new_context(request).await?;
551
552 Err(Status::unimplemented(
553 "Implement get_flight_info_cross_reference",
554 ))
555 }
556
557 async fn get_flight_info_xdbc_type_info(
558 &self,
559 _query: CommandGetXdbcTypeInfo,
560 request: Request<FlightDescriptor>,
561 ) -> Result<Response<FlightInfo>> {
562 info!("get_flight_info_xdbc_type_info");
563 let (_, _) = self.new_context(request).await?;
564
565 Err(Status::unimplemented(
566 "Implement get_flight_info_xdbc_type_info",
567 ))
568 }
569
570 async fn do_get_statement(
571 &self,
572 _ticket: TicketStatementQuery,
573 request: Request<Ticket>,
574 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
575 info!("do_get_statement");
576 let (_, _) = self.new_context(request).await?;
577
578 Err(Status::unimplemented("Implement do_get_statement"))
579 }
580
581 async fn do_get_prepared_statement(
582 &self,
583 _query: CommandPreparedStatementQuery,
584 request: Request<Ticket>,
585 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
586 info!("do_get_prepared_statement");
587 let (_, _) = self.new_context(request).await?;
588
589 Err(Status::unimplemented("Implement do_get_prepared_statement"))
590 }
591
592 async fn do_get_catalogs(
593 &self,
594 query: CommandGetCatalogs,
595 request: Request<Ticket>,
596 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
597 info!("do_get_catalogs");
598 let (_request, ctx) = self.new_context(request).await?;
599 let catalog_names = ctx.inner.catalog_names();
600
601 let mut builder = query.into_builder();
602 for catalog_name in &catalog_names {
603 builder.append(catalog_name);
604 }
605 let schema = builder.schema();
606 let batch = builder.build();
607 let stream = FlightDataEncoderBuilder::new()
608 .with_schema(schema)
609 .build(futures::stream::once(async { batch }))
610 .map_err(Status::from);
611 Ok(Response::new(Box::pin(stream)))
612 }
613
614 async fn do_get_schemas(
615 &self,
616 query: CommandGetDbSchemas,
617 request: Request<Ticket>,
618 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
619 info!("do_get_schemas");
620 let (_request, ctx) = self.new_context(request).await?;
621 let catalog_name = query.catalog.clone();
622 let mut builder = query.into_builder();
624 if let Some(catalog_name) = &catalog_name {
625 if let Some(catalog) = ctx.inner.catalog(catalog_name) {
626 for schema_name in &catalog.schema_names() {
627 builder.append(catalog_name, schema_name);
628 }
629 }
630 };
631
632 let schema = builder.schema();
633 let batch = builder.build();
634 let stream = FlightDataEncoderBuilder::new()
635 .with_schema(schema)
636 .build(futures::stream::once(async { batch }))
637 .map_err(Status::from);
638 Ok(Response::new(Box::pin(stream)))
639 }
640
641 async fn do_get_tables(
642 &self,
643 query: CommandGetTables,
644 request: Request<Ticket>,
645 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
646 info!("do_get_tables");
647 let (_request, ctx) = self.new_context(request).await?;
648 let catalog_name = query.catalog.clone();
649 let mut builder = query.into_builder();
650 if let Some(catalog_name) = &catalog_name {
652 if let Some(catalog) = ctx.inner.catalog(catalog_name) {
653 for schema_name in &catalog.schema_names() {
654 if let Some(schema) = catalog.schema(schema_name) {
655 for table_name in &schema.table_names() {
656 if let Some(table) =
657 schema.table(table_name).await.map_err(df_error_to_status)?
658 {
659 builder
660 .append(
661 catalog_name,
662 schema_name,
663 table_name,
664 table.table_type().to_string(),
665 &table.schema(),
666 )
667 .map_err(flight_error_to_status)?;
668 }
669 }
670 }
671 }
672 }
673 };
674
675 let schema = builder.schema();
676 let batch = builder.build();
677 let stream = FlightDataEncoderBuilder::new()
678 .with_schema(schema)
679 .build(futures::stream::once(async { batch }))
680 .map_err(Status::from);
681 Ok(Response::new(Box::pin(stream)))
682 }
683
684 async fn do_get_table_types(
685 &self,
686 _query: CommandGetTableTypes,
687 request: Request<Ticket>,
688 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
689 info!("do_get_table_types");
690 let (_, _) = self.new_context(request).await?;
691
692 let table_types: ArrayRef = Arc::new(StringArray::from(
694 vec![TableType::Base, TableType::View, TableType::Temporary]
695 .into_iter()
696 .map(|tt| tt.to_string())
697 .collect::<Vec<String>>(),
698 ));
699
700 let batch = RecordBatch::try_from_iter(vec![("table_type", table_types)]).unwrap();
701
702 let stream = FlightDataEncoderBuilder::new()
703 .with_schema(GET_TABLE_TYPES_SCHEMA.clone())
704 .build(futures::stream::once(async { Ok(batch) }))
705 .map_err(Status::from);
706 Ok(Response::new(Box::pin(stream)))
707 }
708
709 async fn do_get_sql_info(
710 &self,
711 _query: CommandGetSqlInfo,
712 request: Request<Ticket>,
713 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
714 info!("do_get_sql_info");
715 let (_, _) = self.new_context(request).await?;
716
717 Err(Status::unimplemented("Implement do_get_sql_info"))
718 }
719
720 async fn do_get_primary_keys(
721 &self,
722 _query: CommandGetPrimaryKeys,
723 request: Request<Ticket>,
724 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
725 info!("do_get_primary_keys");
726 let (_, _) = self.new_context(request).await?;
727
728 Err(Status::unimplemented("Implement do_get_primary_keys"))
729 }
730
731 async fn do_get_exported_keys(
732 &self,
733 _query: CommandGetExportedKeys,
734 request: Request<Ticket>,
735 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
736 info!("do_get_exported_keys");
737 let (_, _) = self.new_context(request).await?;
738
739 Err(Status::unimplemented("Implement do_get_exported_keys"))
740 }
741
742 async fn do_get_imported_keys(
743 &self,
744 _query: CommandGetImportedKeys,
745 request: Request<Ticket>,
746 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
747 info!("do_get_imported_keys");
748 let (_, _) = self.new_context(request).await?;
749
750 Err(Status::unimplemented("Implement do_get_imported_keys"))
751 }
752
753 async fn do_get_cross_reference(
754 &self,
755 _query: CommandGetCrossReference,
756 request: Request<Ticket>,
757 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
758 info!("do_get_cross_reference");
759 let (_, _) = self.new_context(request).await?;
760
761 Err(Status::unimplemented("Implement do_get_cross_reference"))
762 }
763
764 async fn do_get_xdbc_type_info(
765 &self,
766 _query: CommandGetXdbcTypeInfo,
767 request: Request<Ticket>,
768 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
769 info!("do_get_xdbc_type_info");
770 let (_, _) = self.new_context(request).await?;
771
772 Err(Status::unimplemented("Implement do_get_xdbc_type_info"))
773 }
774
775 async fn do_put_statement_update(
776 &self,
777 _ticket: CommandStatementUpdate,
778 request: Request<PeekableFlightDataStream>,
779 ) -> Result<i64, Status> {
780 info!("do_put_statement_update");
781 let (_, _) = self.new_context(request).await?;
782
783 Err(Status::unimplemented("Implement do_put_statement_update"))
784 }
785
786 async fn do_put_prepared_statement_query(
787 &self,
788 query: CommandPreparedStatementQuery,
789 request: Request<PeekableFlightDataStream>,
790 ) -> Result<DoPutPreparedStatementResult, Status> {
791 info!("do_put_prepared_statement_query");
792 let (request, _) = self.new_context(request).await?;
793
794 let mut handle = QueryHandle::try_decode(query.prepared_statement_handle)?;
795
796 info!(
797 "do_action_create_prepared_statement query={:?}",
798 handle.query()
799 );
800 let mut decoder =
803 FlightDataDecoder::new(request.into_inner().map_err(status_to_flight_error));
804 let schema = decode_schema(&mut decoder).await?;
805 let mut parameters = Vec::new();
806 let mut encoder =
807 StreamWriter::try_new(&mut parameters, &schema).map_err(arrow_error_to_status)?;
808 let mut total_rows = 0;
809 while let Some(msg) = decoder.try_next().await? {
810 match msg.payload {
811 DecodedPayload::None => {}
812 DecodedPayload::Schema(_) => {
813 return Err(Status::invalid_argument(
814 "parameter flight data must contain a single schema",
815 ));
816 }
817 DecodedPayload::RecordBatch(record_batch) => {
818 total_rows += record_batch.num_rows();
819 encoder
820 .write(&record_batch)
821 .map_err(arrow_error_to_status)?;
822 }
823 }
824 }
825 if total_rows > 1 {
826 return Err(Status::invalid_argument(
827 "parameters should contain a single row",
828 ));
829 }
830
831 handle.set_parameters(Some(parameters.into()));
832
833 let res = DoPutPreparedStatementResult {
834 prepared_statement_handle: Some(Bytes::from(handle)),
835 };
836
837 Ok(res)
838 }
839
840 async fn do_put_prepared_statement_update(
841 &self,
842 _handle: CommandPreparedStatementUpdate,
843 request: Request<PeekableFlightDataStream>,
844 ) -> Result<i64, Status> {
845 info!("do_put_prepared_statement_update");
846 let (_, _) = self.new_context(request).await?;
847
848 Ok(-1)
851 }
852
853 async fn do_put_substrait_plan(
854 &self,
855 _query: CommandStatementSubstraitPlan,
856 request: Request<PeekableFlightDataStream>,
857 ) -> Result<i64, Status> {
858 info!("do_put_prepared_statement_update");
859 let (_, _) = self.new_context(request).await?;
860
861 Err(Status::unimplemented(
862 "Implement do_put_prepared_statement_update",
863 ))
864 }
865
866 async fn do_action_create_prepared_statement(
867 &self,
868 query: ActionCreatePreparedStatementRequest,
869 request: Request<Action>,
870 ) -> Result<ActionCreatePreparedStatementResult, Status> {
871 let (_, ctx) = self.new_context(request).await?;
872
873 let sql = query.query.clone();
874 info!(
875 "do_action_create_prepared_statement query={:?}",
876 query.query
877 );
878
879 let plan = ctx
880 .sql_to_logical_plan(sql.as_str())
881 .await
882 .map_err(df_error_to_status)?;
883
884 let dataset_schema = get_schema_for_plan(&plan);
885 let parameter_schema = parameter_schema_for_plan(&plan)?;
886
887 let dataset_schema =
888 encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?;
889 let parameter_schema =
890 encode_schema(parameter_schema.as_ref()).map_err(arrow_error_to_status)?;
891
892 let handle = QueryHandle::new(sql, None);
893
894 let res = ActionCreatePreparedStatementResult {
895 prepared_statement_handle: Bytes::from(handle),
896 dataset_schema,
897 parameter_schema,
898 };
899
900 Ok(res)
901 }
902
903 async fn do_action_close_prepared_statement(
904 &self,
905 query: ActionClosePreparedStatementRequest,
906 request: Request<Action>,
907 ) -> Result<(), Status> {
908 let (_, _) = self.new_context(request).await?;
909
910 let handle = query.prepared_statement_handle.as_ref();
911 if let Ok(handle) = std::str::from_utf8(handle) {
912 info!(
913 "do_action_close_prepared_statement with handle {:?}",
914 handle
915 );
916
917 }
919 Ok(())
920 }
921
922 async fn do_action_create_prepared_substrait_plan(
923 &self,
924 _query: ActionCreatePreparedSubstraitPlanRequest,
925 request: Request<Action>,
926 ) -> Result<ActionCreatePreparedStatementResult, Status> {
927 info!("do_action_create_prepared_substrait_plan");
928 let (_, _) = self.new_context(request).await?;
929
930 Err(Status::unimplemented(
931 "Implement do_action_create_prepared_substrait_plan",
932 ))
933 }
934
935 async fn do_action_begin_transaction(
936 &self,
937 _query: ActionBeginTransactionRequest,
938 request: Request<Action>,
939 ) -> Result<ActionBeginTransactionResult, Status> {
940 let (_, _) = self.new_context(request).await?;
941
942 info!("do_action_begin_transaction");
943 Err(Status::unimplemented(
944 "Implement do_action_begin_transaction",
945 ))
946 }
947
948 async fn do_action_end_transaction(
949 &self,
950 _query: ActionEndTransactionRequest,
951 request: Request<Action>,
952 ) -> Result<(), Status> {
953 info!("do_action_end_transaction");
954 let (_, _) = self.new_context(request).await?;
955
956 Err(Status::unimplemented("Implement do_action_end_transaction"))
957 }
958
959 async fn do_action_begin_savepoint(
960 &self,
961 _query: ActionBeginSavepointRequest,
962 request: Request<Action>,
963 ) -> Result<ActionBeginSavepointResult, Status> {
964 info!("do_action_begin_savepoint");
965 let (_, _) = self.new_context(request).await?;
966
967 Err(Status::unimplemented("Implement do_action_begin_savepoint"))
968 }
969
970 async fn do_action_end_savepoint(
971 &self,
972 _query: ActionEndSavepointRequest,
973 request: Request<Action>,
974 ) -> Result<(), Status> {
975 info!("do_action_end_savepoint");
976 let (_, _) = self.new_context(request).await?;
977
978 Err(Status::unimplemented("Implement do_action_end_savepoint"))
979 }
980
981 async fn do_action_cancel_query(
982 &self,
983 _query: ActionCancelQueryRequest,
984 request: Request<Action>,
985 ) -> Result<ActionCancelQueryResult, Status> {
986 info!("do_action_cancel_query");
987 let (_, _) = self.new_context(request).await?;
988
989 Err(Status::unimplemented("Implement do_action_cancel_query"))
990 }
991
992 async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
993}
994
995async fn parse_substrait_bytes(
998 ctx: &FlightSqlSessionContext,
999 substrait: &Bytes,
1000) -> Result<LogicalPlan> {
1001 let substrait_plan = deserialize_bytes(substrait.to_vec())
1002 .await
1003 .map_err(df_error_to_status)?;
1004
1005 from_substrait_plan(&ctx.inner.state(), &substrait_plan)
1006 .await
1007 .map_err(df_error_to_status)
1008}
1009
1010fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
1012 let options = IpcWriteOptions::default();
1013
1014 let message: Result<IpcMessage, ArrowError> = SchemaAsIpc::new(schema, &options).try_into();
1016
1017 let IpcMessage(schema) = message?;
1018
1019 Ok(schema)
1020}
1021
1022fn get_schema_for_plan(logical_plan: &LogicalPlan) -> SchemaRef {
1024 let schema = Schema::from(logical_plan.schema().as_ref()).into();
1026
1027 let flight_data_stream = FlightDataEncoderBuilder::new()
1030 .with_schema(schema)
1032 .build(futures::stream::iter([]));
1033
1034 flight_data_stream
1036 .known_schema()
1037 .expect("flight data schema should be known when explicitly provided via `with_schema`")
1038}
1039
1040fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result<SchemaRef, Status> {
1041 let parameters = plan
1042 .get_parameter_types()
1043 .map_err(df_error_to_status)?
1044 .into_iter()
1045 .map(|(name, dt)| {
1046 dt.map(|dt| (name.clone(), dt)).ok_or_else(|| {
1047 Status::internal(format!(
1048 "unable to determine type of query parameter {name}"
1049 ))
1050 })
1051 })
1052 .collect::<Result<BTreeMap<_, _>, Status>>()?;
1054
1055 let mut builder = SchemaBuilder::new();
1056 parameters
1057 .into_iter()
1058 .for_each(|(name, typ)| builder.push(Field::new(name, typ, false)));
1059 Ok(builder.finish().into())
1060}
1061
1062fn arrow_error_to_status(err: ArrowError) -> Status {
1063 Status::internal(format!("{err:?}"))
1064}
1065
1066fn flight_error_to_status(err: FlightError) -> Status {
1067 Status::internal(format!("{err:?}"))
1068}
1069
1070fn df_error_to_status(err: DataFusionError) -> Status {
1071 Status::internal(format!("{err:?}"))
1072}
1073
1074fn status_to_flight_error(status: Status) -> FlightError {
1075 FlightError::Tonic(status)
1076}
1077
1078async fn decode_schema(decoder: &mut FlightDataDecoder) -> Result<SchemaRef, Status> {
1079 while let Some(msg) = decoder.try_next().await? {
1080 match msg.payload {
1081 DecodedPayload::None => {}
1082 DecodedPayload::Schema(schema) => {
1083 return Ok(schema);
1084 }
1085 DecodedPayload::RecordBatch(_) => {
1086 return Err(Status::invalid_argument(
1087 "parameter flight data must have a known schema",
1088 ));
1089 }
1090 }
1091 }
1092
1093 Err(Status::invalid_argument(
1094 "parameter flight data must have a schema",
1095 ))
1096}
1097
1098fn decode_param_values(
1100 parameters: Option<&[u8]>,
1101) -> Result<Option<ParamValues>, arrow::error::ArrowError> {
1102 parameters
1103 .map(|parameters| {
1104 let decoder = StreamReader::try_new(parameters, None)?;
1105 let schema = decoder.schema();
1106 let batches = decoder.into_iter().collect::<Result<Vec<_>, _>>()?;
1107 let batch = concat_batches(&schema, batches.iter())?;
1108 Ok(record_to_param_values(&batch)?)
1109 })
1110 .transpose()
1111}
1112
1113fn record_to_param_values(batch: &RecordBatch) -> Result<ParamValues, DataFusionError> {
1115 let mut param_values: Vec<(String, Option<usize>, ScalarValue)> = Vec::new();
1116
1117 let mut is_list = true;
1118 for col_index in 0..batch.num_columns() {
1119 let array = batch.column(col_index);
1120 let scalar = ScalarValue::try_from_array(array, 0)?;
1121 let name = batch
1122 .schema_ref()
1123 .field(col_index)
1124 .name()
1125 .trim_start_matches('$')
1126 .to_string();
1127 let index = name.parse().ok();
1128 is_list &= index.is_some();
1129 param_values.push((name, index, scalar));
1130 }
1131 if is_list {
1132 let mut values: Vec<(Option<usize>, ScalarValue)> = param_values
1133 .into_iter()
1134 .map(|(_name, index, value)| (index, value))
1135 .collect();
1136 values.sort_by_key(|(index, _value)| *index);
1137 Ok(values
1138 .into_iter()
1139 .map(|(_index, value)| value)
1140 .collect::<Vec<ScalarValue>>()
1141 .into())
1142 } else {
1143 Ok(param_values
1144 .into_iter()
1145 .map(|(name, _index, value)| (name, value))
1146 .collect::<Vec<(String, ScalarValue)>>()
1147 .into())
1148 }
1149}