1use std::{collections::BTreeMap, pin::Pin, sync::Arc};
2
3use arrow_flight::{
4 decode::{DecodedPayload, FlightDataDecoder},
5 sql::{
6 self,
7 server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
8 ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
9 ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
10 ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
11 ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
12 ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
13 CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
14 CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
15 CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
16 CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
17 CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo,
18 TicketStatementQuery,
19 },
20};
21use arrow_flight::{
22 encode::FlightDataEncoderBuilder,
23 error::FlightError,
24 flight_service_server::{FlightService, FlightServiceServer},
25 Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
26 IpcMessage, SchemaAsIpc, Ticket,
27};
28use datafusion::arrow::{
29 array::{ArrayRef, RecordBatch, StringArray},
30 compute::concat_batches,
31 datatypes::{DataType, Field, SchemaBuilder, SchemaRef},
32 error::ArrowError,
33 ipc::{
34 reader::StreamReader,
35 writer::{IpcWriteOptions, StreamWriter},
36 },
37};
38use datafusion::{
39 common::{arrow::datatypes::Schema, ParamValues},
40 dataframe::DataFrame,
41 datasource::TableType,
42 error::{DataFusionError, Result as DataFusionResult},
43 execution::context::{SQLOptions, SessionContext, SessionState},
44 logical_expr::LogicalPlan,
45 physical_plan::SendableRecordBatchStream,
46 scalar::ScalarValue,
47};
48use datafusion_substrait::{
49 logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes,
50};
51
52use futures::{Stream, StreamExt, TryStreamExt};
53use log::info;
54use once_cell::sync::Lazy;
55use prost::bytes::Bytes;
56use prost::Message;
57use tonic::transport::Server;
58use tonic::{Request, Response, Status, Streaming};
59
60use super::config::FlightSqlServiceConfig;
61use super::session::{SessionStateProvider, StaticSessionStateProvider};
62use super::state::{CommandTicket, QueryHandle};
63
64type Result<T, E = Status> = std::result::Result<T, E>;
65
66pub struct FlightSqlService {
68 provider: Box<dyn SessionStateProvider>,
69 sql_options: Option<SQLOptions>,
70 config: FlightSqlServiceConfig,
71}
72
73impl FlightSqlService {
74 pub fn new(state: SessionState) -> Self {
76 Self::new_with_provider(Box::new(StaticSessionStateProvider::new(state)))
77 }
78
79 pub fn new_with_provider(provider: Box<dyn SessionStateProvider>) -> Self {
81 Self {
82 provider,
83 sql_options: None,
84 config: FlightSqlServiceConfig::default(),
85 }
86 }
87
88 pub fn with_config(self, config: FlightSqlServiceConfig) -> Self {
90 Self { config, ..self }
91 }
92
93 pub fn with_sql_options(self, sql_options: SQLOptions) -> Self {
97 Self {
98 sql_options: Some(sql_options),
99 ..self
100 }
101 }
102
103 pub async fn serve(self, addr: String) -> Result<(), Box<dyn std::error::Error>> {
110 let addr = addr.parse()?;
111 info!("Listening on {addr:?}");
112
113 let svc = FlightServiceServer::new(self);
114
115 Ok(Server::builder().add_service(svc).serve(addr).await?)
116 }
117
118 pub async fn serve_with_listener(
119 self,
120 listener: std::net::TcpListener,
121 ) -> Result<(), Box<dyn std::error::Error>> {
122 info!("Listening on {}", listener.local_addr()?);
123
124 let svc = FlightServiceServer::new(self);
125 let listener = tokio::net::TcpListener::from_std(listener)?;
126
127 Ok(Server::builder()
128 .add_service(svc)
129 .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
130 .await?)
131 }
132
133 async fn new_context<T>(
134 &self,
135 request: Request<T>,
136 ) -> Result<(Request<T>, FlightSqlSessionContext)> {
137 let (metadata, extensions, msg) = request.into_parts();
138 let inspect_request = Request::from_parts(metadata, extensions, ());
139
140 let state = self.provider.new_context(&inspect_request).await?;
141 let ctx = SessionContext::new_with_state(state);
142
143 let (metadata, extensions, _) = inspect_request.into_parts();
144 Ok((
145 Request::from_parts(metadata, extensions, msg),
146 FlightSqlSessionContext {
147 inner: ctx,
148 sql_options: self.sql_options,
149 },
150 ))
151 }
152}
153
154static GET_TABLE_TYPES_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
156 Arc::new(Schema::new(vec![Field::new(
158 "table_type",
159 DataType::Utf8,
160 false,
161 )]))
162});
163
164struct FlightSqlSessionContext {
165 inner: SessionContext,
166 sql_options: Option<SQLOptions>,
167}
168
169impl FlightSqlSessionContext {
170 async fn sql_to_logical_plan(&self, sql: &str) -> DataFusionResult<LogicalPlan> {
171 let plan = self.inner.state().create_logical_plan(sql).await?;
172 let verifier = self.sql_options.unwrap_or_default();
173 verifier.verify_plan(&plan)?;
174 Ok(plan)
175 }
176
177 async fn execute_sql(&self, sql: &str) -> DataFusionResult<SendableRecordBatchStream> {
178 let plan = self.sql_to_logical_plan(sql).await?;
179 self.execute_logical_plan(plan).await
180 }
181
182 async fn execute_logical_plan(
183 &self,
184 plan: LogicalPlan,
185 ) -> DataFusionResult<SendableRecordBatchStream> {
186 self.inner
187 .execute_logical_plan(plan)
188 .await?
189 .execute_stream()
190 .await
191 }
192}
193
194#[tonic::async_trait]
195impl ArrowFlightSqlService for FlightSqlService {
196 type FlightService = FlightSqlService;
197
198 async fn do_handshake(
199 &self,
200 _request: Request<Streaming<HandshakeRequest>>,
201 ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse>> + Send>>>> {
202 info!("do_handshake");
203 Err(Status::unimplemented("handshake is not supported"))
207 }
208
209 async fn do_get_fallback(
210 &self,
211 request: Request<Ticket>,
212 _message: Any,
213 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
214 let (request, ctx) = self.new_context(request).await?;
215
216 let ticket = CommandTicket::try_decode(request.into_inner().ticket)
217 .map_err(flight_error_to_status)?;
218
219 match ticket.command {
220 sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => {
221 let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?;
224 let arrow_schema = stream.schema();
225 let arrow_stream = stream.map(|i| {
226 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
227 Ok(batch)
228 });
229
230 let flight_data_stream = FlightDataEncoderBuilder::new()
231 .with_schema(arrow_schema)
232 .build(arrow_stream)
233 .map_err(flight_error_to_status)
234 .boxed();
235
236 Ok(Response::new(flight_data_stream))
237 }
238 sql::Command::CommandPreparedStatementQuery(CommandPreparedStatementQuery {
239 prepared_statement_handle,
240 }) => {
241 let handle = QueryHandle::try_decode(prepared_statement_handle)?;
242
243 let mut plan = ctx
244 .sql_to_logical_plan(handle.query())
245 .await
246 .map_err(df_error_to_status)?;
247
248 if let Some(param_values) =
249 decode_param_values(handle.parameters()).map_err(arrow_error_to_status)?
250 {
251 plan = plan
252 .with_param_values(param_values)
253 .map_err(df_error_to_status)?;
254 }
255
256 let stream = ctx
257 .execute_logical_plan(plan)
258 .await
259 .map_err(df_error_to_status)?;
260 let arrow_schema = stream.schema();
261 let arrow_stream = stream.map(|i| {
262 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
263 Ok(batch)
264 });
265
266 let flight_data_stream = FlightDataEncoderBuilder::new()
267 .with_schema(arrow_schema)
268 .build(arrow_stream)
269 .map_err(flight_error_to_status)
270 .boxed();
271
272 Ok(Response::new(flight_data_stream))
273 }
274 sql::Command::CommandStatementSubstraitPlan(CommandStatementSubstraitPlan {
275 plan,
276 ..
277 }) => {
278 let substrait_bytes = &plan
279 .ok_or(Status::invalid_argument(
280 "Expected substrait plan, found None",
281 ))?
282 .plan;
283
284 let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
285
286 let state = ctx.inner.state();
287 let df = DataFrame::new(state, plan);
288
289 let stream = df.execute_stream().await.map_err(df_error_to_status)?;
290 let arrow_schema = stream.schema();
291 let arrow_stream = stream.map(|i| {
292 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
293 Ok(batch)
294 });
295
296 let flight_data_stream = FlightDataEncoderBuilder::new()
297 .with_schema(arrow_schema)
298 .build(arrow_stream)
299 .map_err(flight_error_to_status)
300 .boxed();
301
302 Ok(Response::new(flight_data_stream))
303 }
304 _ => {
305 return Err(Status::internal(format!(
306 "statement handle not found: {:?}",
307 ticket.command
308 )));
309 }
310 }
311 }
312
313 async fn get_flight_info_statement(
314 &self,
315 query: CommandStatementQuery,
316 request: Request<FlightDescriptor>,
317 ) -> Result<Response<FlightInfo>> {
318 let (request, ctx) = self.new_context(request).await?;
319
320 let sql = &query.query;
321 info!("get_flight_info_statement with query={sql}");
322
323 let flight_descriptor = request.into_inner();
324
325 let plan = ctx
326 .sql_to_logical_plan(sql)
327 .await
328 .map_err(df_error_to_status)?;
329
330 let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
331
332 let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query))
334 .try_encode()
335 .map_err(flight_error_to_status)?;
336
337 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
338
339 let flight_info = FlightInfo::new()
340 .with_endpoint(endpoint)
341 .with_descriptor(flight_descriptor)
343 .try_with_schema(dataset_schema.as_ref())
344 .map_err(arrow_error_to_status)?;
345
346 Ok(Response::new(flight_info))
347 }
348
349 async fn get_flight_info_substrait_plan(
350 &self,
351 query: CommandStatementSubstraitPlan,
352 request: Request<FlightDescriptor>,
353 ) -> Result<Response<FlightInfo>> {
354 info!("get_flight_info_substrait_plan");
355 let (request, ctx) = self.new_context(request).await?;
356
357 let substrait_bytes = &query
358 .plan
359 .as_ref()
360 .ok_or(Status::invalid_argument(
361 "Expected substrait plan, found None",
362 ))?
363 .plan;
364
365 let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
366
367 let flight_descriptor = request.into_inner();
368
369 let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
370
371 let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query))
373 .try_encode()
374 .map_err(flight_error_to_status)?;
375
376 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
377
378 let flight_info = FlightInfo::new()
379 .with_endpoint(endpoint)
380 .with_descriptor(flight_descriptor)
382 .try_with_schema(dataset_schema.as_ref())
383 .map_err(arrow_error_to_status)?;
384
385 Ok(Response::new(flight_info))
386 }
387
388 async fn get_flight_info_prepared_statement(
389 &self,
390 cmd: CommandPreparedStatementQuery,
391 request: Request<FlightDescriptor>,
392 ) -> Result<Response<FlightInfo>> {
393 let (request, ctx) = self.new_context(request).await?;
394
395 let handle = QueryHandle::try_decode(cmd.prepared_statement_handle.clone())
396 .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?;
397
398 info!("get_flight_info_prepared_statement with handle={handle}");
399
400 let flight_descriptor = request.into_inner();
401
402 let sql = handle.query();
403 let plan = ctx
404 .sql_to_logical_plan(sql)
405 .await
406 .map_err(df_error_to_status)?;
407
408 let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
409
410 let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd))
412 .try_encode()
413 .map_err(flight_error_to_status)?;
414
415 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
416
417 let flight_info = FlightInfo::new()
418 .with_endpoint(endpoint)
419 .with_descriptor(flight_descriptor)
421 .try_with_schema(dataset_schema.as_ref())
422 .map_err(arrow_error_to_status)?;
423
424 Ok(Response::new(flight_info))
425 }
426
427 async fn get_flight_info_catalogs(
428 &self,
429 query: CommandGetCatalogs,
430 request: Request<FlightDescriptor>,
431 ) -> Result<Response<FlightInfo>> {
432 info!("get_flight_info_catalogs");
433 let (request, _ctx) = self.new_context(request).await?;
434
435 let flight_descriptor = request.into_inner();
436 let ticket = Ticket {
437 ticket: query.as_any().encode_to_vec().into(),
438 };
439 let endpoint = FlightEndpoint::new().with_ticket(ticket);
440
441 let flight_info = FlightInfo::new()
442 .try_with_schema(&query.into_builder().schema())
443 .map_err(arrow_error_to_status)?
444 .with_endpoint(endpoint)
445 .with_descriptor(flight_descriptor);
446
447 Ok(Response::new(flight_info))
448 }
449
450 async fn get_flight_info_schemas(
451 &self,
452 query: CommandGetDbSchemas,
453 request: Request<FlightDescriptor>,
454 ) -> Result<Response<FlightInfo>> {
455 info!("get_flight_info_schemas");
456 let (request, _ctx) = self.new_context(request).await?;
457 let flight_descriptor = request.into_inner();
458 let ticket = Ticket {
459 ticket: query.as_any().encode_to_vec().into(),
460 };
461 let endpoint = FlightEndpoint::new().with_ticket(ticket);
462
463 let flight_info = FlightInfo::new()
464 .try_with_schema(&query.into_builder().schema())
465 .map_err(arrow_error_to_status)?
466 .with_endpoint(endpoint)
467 .with_descriptor(flight_descriptor);
468
469 Ok(Response::new(flight_info))
470 }
471
472 async fn get_flight_info_tables(
473 &self,
474 query: CommandGetTables,
475 request: Request<FlightDescriptor>,
476 ) -> Result<Response<FlightInfo>> {
477 info!("get_flight_info_tables");
478 let (request, _ctx) = self.new_context(request).await?;
479
480 let flight_descriptor = request.into_inner();
481 let ticket = Ticket {
482 ticket: query.as_any().encode_to_vec().into(),
483 };
484 let endpoint = FlightEndpoint::new().with_ticket(ticket);
485
486 let flight_info = FlightInfo::new()
487 .try_with_schema(&query.into_builder().schema())
488 .map_err(arrow_error_to_status)?
489 .with_endpoint(endpoint)
490 .with_descriptor(flight_descriptor);
491
492 Ok(Response::new(flight_info))
493 }
494
495 async fn get_flight_info_table_types(
496 &self,
497 query: CommandGetTableTypes,
498 request: Request<FlightDescriptor>,
499 ) -> Result<Response<FlightInfo>> {
500 info!("get_flight_info_table_types");
501 let (request, _ctx) = self.new_context(request).await?;
502
503 let flight_descriptor = request.into_inner();
504 let ticket = Ticket {
505 ticket: query.as_any().encode_to_vec().into(),
506 };
507 let endpoint = FlightEndpoint::new().with_ticket(ticket);
508
509 let flight_info = FlightInfo::new()
510 .try_with_schema(&GET_TABLE_TYPES_SCHEMA)
511 .map_err(arrow_error_to_status)?
512 .with_endpoint(endpoint)
513 .with_descriptor(flight_descriptor);
514
515 Ok(Response::new(flight_info))
516 }
517
518 async fn get_flight_info_sql_info(
519 &self,
520 _query: CommandGetSqlInfo,
521 request: Request<FlightDescriptor>,
522 ) -> Result<Response<FlightInfo>> {
523 info!("get_flight_info_sql_info");
524 let (_, _) = self.new_context(request).await?;
525
526 Err(Status::unimplemented("Implement CommandGetSqlInfo"))
527 }
528
529 async fn get_flight_info_primary_keys(
530 &self,
531 _query: CommandGetPrimaryKeys,
532 request: Request<FlightDescriptor>,
533 ) -> Result<Response<FlightInfo>> {
534 info!("get_flight_info_primary_keys");
535 let (_, _) = self.new_context(request).await?;
536
537 Err(Status::unimplemented(
538 "Implement get_flight_info_primary_keys",
539 ))
540 }
541
542 async fn get_flight_info_exported_keys(
543 &self,
544 _query: CommandGetExportedKeys,
545 request: Request<FlightDescriptor>,
546 ) -> Result<Response<FlightInfo>> {
547 info!("get_flight_info_exported_keys");
548 let (_, _) = self.new_context(request).await?;
549
550 Err(Status::unimplemented(
551 "Implement get_flight_info_exported_keys",
552 ))
553 }
554
555 async fn get_flight_info_imported_keys(
556 &self,
557 _query: CommandGetImportedKeys,
558 request: Request<FlightDescriptor>,
559 ) -> Result<Response<FlightInfo>> {
560 info!("get_flight_info_imported_keys");
561 let (_, _) = self.new_context(request).await?;
562
563 Err(Status::unimplemented(
564 "Implement get_flight_info_imported_keys",
565 ))
566 }
567
568 async fn get_flight_info_cross_reference(
569 &self,
570 _query: CommandGetCrossReference,
571 request: Request<FlightDescriptor>,
572 ) -> Result<Response<FlightInfo>> {
573 info!("get_flight_info_cross_reference");
574 let (_, _) = self.new_context(request).await?;
575
576 Err(Status::unimplemented(
577 "Implement get_flight_info_cross_reference",
578 ))
579 }
580
581 async fn get_flight_info_xdbc_type_info(
582 &self,
583 _query: CommandGetXdbcTypeInfo,
584 request: Request<FlightDescriptor>,
585 ) -> Result<Response<FlightInfo>> {
586 info!("get_flight_info_xdbc_type_info");
587 let (_, _) = self.new_context(request).await?;
588
589 Err(Status::unimplemented(
590 "Implement get_flight_info_xdbc_type_info",
591 ))
592 }
593
594 async fn do_get_statement(
595 &self,
596 _ticket: TicketStatementQuery,
597 request: Request<Ticket>,
598 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
599 info!("do_get_statement");
600 let (_, _) = self.new_context(request).await?;
601
602 Err(Status::unimplemented("Implement do_get_statement"))
603 }
604
605 async fn do_get_prepared_statement(
606 &self,
607 _query: CommandPreparedStatementQuery,
608 request: Request<Ticket>,
609 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
610 info!("do_get_prepared_statement");
611 let (_, _) = self.new_context(request).await?;
612
613 Err(Status::unimplemented("Implement do_get_prepared_statement"))
614 }
615
616 async fn do_get_catalogs(
617 &self,
618 query: CommandGetCatalogs,
619 request: Request<Ticket>,
620 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
621 info!("do_get_catalogs");
622 let (_request, ctx) = self.new_context(request).await?;
623 let catalog_names = ctx.inner.catalog_names();
624
625 let mut builder = query.into_builder();
626 for catalog_name in &catalog_names {
627 builder.append(catalog_name);
628 }
629 let schema = builder.schema();
630 let batch = builder.build();
631 let stream = FlightDataEncoderBuilder::new()
632 .with_schema(schema)
633 .build(futures::stream::once(async { batch }))
634 .map_err(Status::from);
635 Ok(Response::new(Box::pin(stream)))
636 }
637
638 async fn do_get_schemas(
639 &self,
640 query: CommandGetDbSchemas,
641 request: Request<Ticket>,
642 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
643 info!("do_get_schemas");
644 let (_request, ctx) = self.new_context(request).await?;
645 let catalog_name = query.catalog.clone();
646 let mut builder = query.into_builder();
648 if let Some(catalog_name) = &catalog_name {
649 if let Some(catalog) = ctx.inner.catalog(catalog_name) {
650 for schema_name in &catalog.schema_names() {
651 builder.append(catalog_name, schema_name);
652 }
653 }
654 };
655
656 let schema = builder.schema();
657 let batch = builder.build();
658 let stream = FlightDataEncoderBuilder::new()
659 .with_schema(schema)
660 .build(futures::stream::once(async { batch }))
661 .map_err(Status::from);
662 Ok(Response::new(Box::pin(stream)))
663 }
664
665 async fn do_get_tables(
666 &self,
667 query: CommandGetTables,
668 request: Request<Ticket>,
669 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
670 info!("do_get_tables");
671 let (_request, ctx) = self.new_context(request).await?;
672 let catalog_name = query.catalog.clone();
673 let mut builder = query.into_builder();
674 if let Some(catalog_name) = &catalog_name {
676 if let Some(catalog) = ctx.inner.catalog(catalog_name) {
677 for schema_name in &catalog.schema_names() {
678 if let Some(schema) = catalog.schema(schema_name) {
679 for table_name in &schema.table_names() {
680 if let Some(table) =
681 schema.table(table_name).await.map_err(df_error_to_status)?
682 {
683 builder
684 .append(
685 catalog_name,
686 schema_name,
687 table_name,
688 table.table_type().to_string(),
689 &table.schema(),
690 )
691 .map_err(flight_error_to_status)?;
692 }
693 }
694 }
695 }
696 }
697 };
698
699 let schema = builder.schema();
700 let batch = builder.build();
701 let stream = FlightDataEncoderBuilder::new()
702 .with_schema(schema)
703 .build(futures::stream::once(async { batch }))
704 .map_err(Status::from);
705 Ok(Response::new(Box::pin(stream)))
706 }
707
708 async fn do_get_table_types(
709 &self,
710 _query: CommandGetTableTypes,
711 request: Request<Ticket>,
712 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
713 info!("do_get_table_types");
714 let (_, _) = self.new_context(request).await?;
715
716 let table_types: ArrayRef = Arc::new(StringArray::from(
718 vec![TableType::Base, TableType::View, TableType::Temporary]
719 .into_iter()
720 .map(|tt| tt.to_string())
721 .collect::<Vec<String>>(),
722 ));
723
724 let batch = RecordBatch::try_from_iter(vec![("table_type", table_types)]).unwrap();
725
726 let stream = FlightDataEncoderBuilder::new()
727 .with_schema(GET_TABLE_TYPES_SCHEMA.clone())
728 .build(futures::stream::once(async { Ok(batch) }))
729 .map_err(Status::from);
730 Ok(Response::new(Box::pin(stream)))
731 }
732
733 async fn do_get_sql_info(
734 &self,
735 _query: CommandGetSqlInfo,
736 request: Request<Ticket>,
737 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
738 info!("do_get_sql_info");
739 let (_, _) = self.new_context(request).await?;
740
741 Err(Status::unimplemented("Implement do_get_sql_info"))
742 }
743
744 async fn do_get_primary_keys(
745 &self,
746 _query: CommandGetPrimaryKeys,
747 request: Request<Ticket>,
748 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
749 info!("do_get_primary_keys");
750 let (_, _) = self.new_context(request).await?;
751
752 Err(Status::unimplemented("Implement do_get_primary_keys"))
753 }
754
755 async fn do_get_exported_keys(
756 &self,
757 _query: CommandGetExportedKeys,
758 request: Request<Ticket>,
759 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
760 info!("do_get_exported_keys");
761 let (_, _) = self.new_context(request).await?;
762
763 Err(Status::unimplemented("Implement do_get_exported_keys"))
764 }
765
766 async fn do_get_imported_keys(
767 &self,
768 _query: CommandGetImportedKeys,
769 request: Request<Ticket>,
770 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
771 info!("do_get_imported_keys");
772 let (_, _) = self.new_context(request).await?;
773
774 Err(Status::unimplemented("Implement do_get_imported_keys"))
775 }
776
777 async fn do_get_cross_reference(
778 &self,
779 _query: CommandGetCrossReference,
780 request: Request<Ticket>,
781 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
782 info!("do_get_cross_reference");
783 let (_, _) = self.new_context(request).await?;
784
785 Err(Status::unimplemented("Implement do_get_cross_reference"))
786 }
787
788 async fn do_get_xdbc_type_info(
789 &self,
790 _query: CommandGetXdbcTypeInfo,
791 request: Request<Ticket>,
792 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
793 info!("do_get_xdbc_type_info");
794 let (_, _) = self.new_context(request).await?;
795
796 Err(Status::unimplemented("Implement do_get_xdbc_type_info"))
797 }
798
799 async fn do_put_statement_update(
800 &self,
801 _ticket: CommandStatementUpdate,
802 request: Request<PeekableFlightDataStream>,
803 ) -> Result<i64, Status> {
804 info!("do_put_statement_update");
805 let (_, _) = self.new_context(request).await?;
806
807 Err(Status::unimplemented("Implement do_put_statement_update"))
808 }
809
810 async fn do_put_prepared_statement_query(
811 &self,
812 query: CommandPreparedStatementQuery,
813 request: Request<PeekableFlightDataStream>,
814 ) -> Result<DoPutPreparedStatementResult, Status> {
815 info!("do_put_prepared_statement_query");
816 let (request, _) = self.new_context(request).await?;
817
818 let mut handle = QueryHandle::try_decode(query.prepared_statement_handle)?;
819
820 info!(
821 "do_action_create_prepared_statement query={:?}",
822 handle.query()
823 );
824 let mut decoder =
827 FlightDataDecoder::new(request.into_inner().map_err(status_to_flight_error));
828 let schema = decode_schema(&mut decoder).await?;
829 let mut parameters = Vec::new();
830 let mut encoder =
831 StreamWriter::try_new(&mut parameters, &schema).map_err(arrow_error_to_status)?;
832 let mut total_rows = 0;
833 while let Some(msg) = decoder.try_next().await? {
834 match msg.payload {
835 DecodedPayload::None => {}
836 DecodedPayload::Schema(_) => {
837 return Err(Status::invalid_argument(
838 "parameter flight data must contain a single schema",
839 ));
840 }
841 DecodedPayload::RecordBatch(record_batch) => {
842 total_rows += record_batch.num_rows();
843 encoder
844 .write(&record_batch)
845 .map_err(arrow_error_to_status)?;
846 }
847 }
848 }
849 if total_rows > 1 {
850 return Err(Status::invalid_argument(
851 "parameters should contain a single row",
852 ));
853 }
854
855 handle.set_parameters(Some(parameters.into()));
856
857 let res = DoPutPreparedStatementResult {
858 prepared_statement_handle: Some(Bytes::from(handle)),
859 };
860
861 Ok(res)
862 }
863
864 async fn do_put_prepared_statement_update(
865 &self,
866 _handle: CommandPreparedStatementUpdate,
867 request: Request<PeekableFlightDataStream>,
868 ) -> Result<i64, Status> {
869 info!("do_put_prepared_statement_update");
870 let (_, _) = self.new_context(request).await?;
871
872 Ok(-1)
875 }
876
877 async fn do_put_substrait_plan(
878 &self,
879 _query: CommandStatementSubstraitPlan,
880 request: Request<PeekableFlightDataStream>,
881 ) -> Result<i64, Status> {
882 info!("do_put_prepared_statement_update");
883 let (_, _) = self.new_context(request).await?;
884
885 Err(Status::unimplemented(
886 "Implement do_put_prepared_statement_update",
887 ))
888 }
889
890 async fn do_action_create_prepared_statement(
891 &self,
892 query: ActionCreatePreparedStatementRequest,
893 request: Request<Action>,
894 ) -> Result<ActionCreatePreparedStatementResult, Status> {
895 let (_, ctx) = self.new_context(request).await?;
896
897 let sql = query.query.clone();
898 info!(
899 "do_action_create_prepared_statement query={:?}",
900 query.query
901 );
902
903 let plan = ctx
904 .sql_to_logical_plan(sql.as_str())
905 .await
906 .map_err(df_error_to_status)?;
907
908 let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
909 let parameter_schema = parameter_schema_for_plan(&plan).map_err(|e| e.as_ref().clone())?;
910
911 let dataset_schema =
912 encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?;
913 let parameter_schema =
914 encode_schema(parameter_schema.as_ref()).map_err(arrow_error_to_status)?;
915
916 let handle = QueryHandle::new(sql, None);
917
918 let res = ActionCreatePreparedStatementResult {
919 prepared_statement_handle: Bytes::from(handle),
920 dataset_schema,
921 parameter_schema,
922 };
923
924 Ok(res)
925 }
926
927 async fn do_action_close_prepared_statement(
928 &self,
929 query: ActionClosePreparedStatementRequest,
930 request: Request<Action>,
931 ) -> Result<(), Status> {
932 let (_, _) = self.new_context(request).await?;
933
934 let handle = query.prepared_statement_handle.as_ref();
935 if let Ok(handle) = std::str::from_utf8(handle) {
936 info!("do_action_close_prepared_statement with handle {handle:?}",);
937
938 }
940 Ok(())
941 }
942
943 async fn do_action_create_prepared_substrait_plan(
944 &self,
945 _query: ActionCreatePreparedSubstraitPlanRequest,
946 request: Request<Action>,
947 ) -> Result<ActionCreatePreparedStatementResult, Status> {
948 info!("do_action_create_prepared_substrait_plan");
949 let (_, _) = self.new_context(request).await?;
950
951 Err(Status::unimplemented(
952 "Implement do_action_create_prepared_substrait_plan",
953 ))
954 }
955
956 async fn do_action_begin_transaction(
957 &self,
958 _query: ActionBeginTransactionRequest,
959 request: Request<Action>,
960 ) -> Result<ActionBeginTransactionResult, Status> {
961 let (_, _) = self.new_context(request).await?;
962
963 info!("do_action_begin_transaction");
964 Err(Status::unimplemented(
965 "Implement do_action_begin_transaction",
966 ))
967 }
968
969 async fn do_action_end_transaction(
970 &self,
971 _query: ActionEndTransactionRequest,
972 request: Request<Action>,
973 ) -> Result<(), Status> {
974 info!("do_action_end_transaction");
975 let (_, _) = self.new_context(request).await?;
976
977 Err(Status::unimplemented("Implement do_action_end_transaction"))
978 }
979
980 async fn do_action_begin_savepoint(
981 &self,
982 _query: ActionBeginSavepointRequest,
983 request: Request<Action>,
984 ) -> Result<ActionBeginSavepointResult, Status> {
985 info!("do_action_begin_savepoint");
986 let (_, _) = self.new_context(request).await?;
987
988 Err(Status::unimplemented("Implement do_action_begin_savepoint"))
989 }
990
991 async fn do_action_end_savepoint(
992 &self,
993 _query: ActionEndSavepointRequest,
994 request: Request<Action>,
995 ) -> Result<(), Status> {
996 info!("do_action_end_savepoint");
997 let (_, _) = self.new_context(request).await?;
998
999 Err(Status::unimplemented("Implement do_action_end_savepoint"))
1000 }
1001
1002 async fn do_action_cancel_query(
1003 &self,
1004 _query: ActionCancelQueryRequest,
1005 request: Request<Action>,
1006 ) -> Result<ActionCancelQueryResult, Status> {
1007 info!("do_action_cancel_query");
1008 let (_, _) = self.new_context(request).await?;
1009
1010 Err(Status::unimplemented("Implement do_action_cancel_query"))
1011 }
1012
1013 async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
1014}
1015
1016async fn parse_substrait_bytes(
1019 ctx: &FlightSqlSessionContext,
1020 substrait: &Bytes,
1021) -> Result<LogicalPlan> {
1022 let substrait_plan = deserialize_bytes(substrait.to_vec())
1023 .await
1024 .map_err(df_error_to_status)?;
1025
1026 from_substrait_plan(&ctx.inner.state(), &substrait_plan)
1027 .await
1028 .map_err(df_error_to_status)
1029}
1030
1031fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
1033 let options = IpcWriteOptions::default();
1034
1035 let message: Result<IpcMessage, ArrowError> = SchemaAsIpc::new(schema, &options).try_into();
1037
1038 let IpcMessage(schema) = message?;
1039
1040 Ok(schema)
1041}
1042
1043fn get_schema_for_plan(logical_plan: &LogicalPlan, with_metadata: bool) -> SchemaRef {
1045 let schema: SchemaRef = if with_metadata {
1046 let df_schema = logical_plan.schema();
1048
1049 let fields_with_metadata: Vec<_> = df_schema
1051 .iter()
1052 .map(|(qualifier, field)| {
1053 if let Some(table_ref) = qualifier {
1055 let mut metadata = field.metadata().clone();
1056 metadata.insert("table_name".to_string(), table_ref.to_string());
1057 field.as_ref().clone().with_metadata(metadata)
1058 } else {
1059 field.as_ref().clone()
1060 }
1061 })
1062 .collect();
1063
1064 Arc::new(Schema::new_with_metadata(
1065 fields_with_metadata,
1066 df_schema.as_ref().metadata().clone(),
1067 ))
1068 } else {
1069 Arc::new(logical_plan.schema().as_arrow().clone())
1070 };
1071
1072 let flight_data_stream = FlightDataEncoderBuilder::new()
1075 .with_schema(schema)
1077 .build(futures::stream::iter([]));
1078
1079 flight_data_stream
1081 .known_schema()
1082 .expect("flight data schema should be known when explicitly provided via `with_schema`")
1083}
1084
1085fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result<SchemaRef, Box<Status>> {
1086 let parameters = plan
1087 .get_parameter_types()
1088 .map_err(df_error_to_status)?
1089 .into_iter()
1090 .map(|(name, dt)| {
1091 dt.map(|dt| (name.clone(), dt)).ok_or_else(|| {
1092 Status::internal(format!(
1093 "unable to determine type of query parameter {name}"
1094 ))
1095 })
1096 })
1097 .collect::<Result<BTreeMap<_, _>, Status>>()?;
1099
1100 let mut builder = SchemaBuilder::new();
1101 parameters
1102 .into_iter()
1103 .for_each(|(name, typ)| builder.push(Field::new(name, typ, false)));
1104 Ok(builder.finish().into())
1105}
1106
1107fn arrow_error_to_status(err: ArrowError) -> Status {
1108 Status::internal(format!("{err:?}"))
1109}
1110
1111fn flight_error_to_status(err: FlightError) -> Status {
1112 Status::internal(format!("{err:?}"))
1113}
1114
1115fn df_error_to_status(err: DataFusionError) -> Status {
1116 Status::internal(format!("{err:?}"))
1117}
1118
1119fn status_to_flight_error(status: Status) -> FlightError {
1120 FlightError::Tonic(Box::new(status))
1121}
1122
1123async fn decode_schema(decoder: &mut FlightDataDecoder) -> Result<SchemaRef, Status> {
1124 while let Some(msg) = decoder.try_next().await? {
1125 match msg.payload {
1126 DecodedPayload::None => {}
1127 DecodedPayload::Schema(schema) => {
1128 return Ok(schema);
1129 }
1130 DecodedPayload::RecordBatch(_) => {
1131 return Err(Status::invalid_argument(
1132 "parameter flight data must have a known schema",
1133 ));
1134 }
1135 }
1136 }
1137
1138 Err(Status::invalid_argument(
1139 "parameter flight data must have a schema",
1140 ))
1141}
1142
1143fn decode_param_values(parameters: Option<&[u8]>) -> Result<Option<ParamValues>, ArrowError> {
1145 parameters
1146 .map(|parameters| {
1147 let decoder = StreamReader::try_new(parameters, None)?;
1148 let schema = decoder.schema();
1149 let batches = decoder.into_iter().collect::<Result<Vec<_>, _>>()?;
1150 let batch = concat_batches(&schema, batches.iter())?;
1151 Ok(record_to_param_values(&batch)?)
1152 })
1153 .transpose()
1154}
1155
1156fn record_to_param_values(batch: &RecordBatch) -> Result<ParamValues, DataFusionError> {
1158 let mut param_values: Vec<(String, Option<usize>, ScalarValue)> = Vec::new();
1159
1160 let mut is_list = true;
1161 for col_index in 0..batch.num_columns() {
1162 let array = batch.column(col_index);
1163 let scalar = ScalarValue::try_from_array(array, 0)?;
1164 let name = batch
1165 .schema_ref()
1166 .field(col_index)
1167 .name()
1168 .trim_start_matches('$')
1169 .to_string();
1170 let index = name.parse().ok();
1171 is_list &= index.is_some();
1172 param_values.push((name, index, scalar));
1173 }
1174 if is_list {
1175 let mut values: Vec<(Option<usize>, ScalarValue)> = param_values
1176 .into_iter()
1177 .map(|(_name, index, value)| (index, value))
1178 .collect();
1179 values.sort_by_key(|(index, _value)| *index);
1180 Ok(values
1181 .into_iter()
1182 .map(|(_index, value)| value)
1183 .collect::<Vec<ScalarValue>>()
1184 .into())
1185 } else {
1186 Ok(param_values
1187 .into_iter()
1188 .map(|(name, _index, value)| (name, value))
1189 .collect::<Vec<(String, ScalarValue)>>()
1190 .into())
1191 }
1192}