1use std::{collections::BTreeMap, pin::Pin, sync::Arc};
2
3use arrow::{
4 array::{ArrayRef, RecordBatch, StringArray},
5 compute::concat_batches,
6 datatypes::{DataType, Field, SchemaBuilder, SchemaRef},
7 error::ArrowError,
8 ipc::{
9 reader::StreamReader,
10 writer::{IpcWriteOptions, StreamWriter},
11 },
12};
13use arrow_flight::{
14 decode::{DecodedPayload, FlightDataDecoder},
15 sql::{
16 self,
17 server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
18 ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
19 ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
20 ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
21 ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
22 ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
23 CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
24 CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
25 CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
26 CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
27 CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo,
28 TicketStatementQuery,
29 },
30};
31use arrow_flight::{
32 encode::FlightDataEncoderBuilder,
33 error::FlightError,
34 flight_service_server::{FlightService, FlightServiceServer},
35 Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
36 IpcMessage, SchemaAsIpc, Ticket,
37};
38use datafusion::{
39 common::{arrow::datatypes::Schema, ParamValues},
40 dataframe::DataFrame,
41 datasource::TableType,
42 error::{DataFusionError, Result as DataFusionResult},
43 execution::context::{SQLOptions, SessionContext, SessionState},
44 logical_expr::LogicalPlan,
45 physical_plan::SendableRecordBatchStream,
46 scalar::ScalarValue,
47};
48use datafusion_substrait::{
49 logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes,
50};
51
52use futures::{Stream, StreamExt, TryStreamExt};
53use log::info;
54use once_cell::sync::Lazy;
55use prost::bytes::Bytes;
56use prost::Message;
57use tonic::transport::Server;
58use tonic::{Request, Response, Status, Streaming};
59
60use super::config::FlightSqlServiceConfig;
61use super::session::{SessionStateProvider, StaticSessionStateProvider};
62use super::state::{CommandTicket, QueryHandle};
63
64type Result<T, E = Status> = std::result::Result<T, E>;
65
66pub struct FlightSqlService {
68 provider: Box<dyn SessionStateProvider>,
69 sql_options: Option<SQLOptions>,
70 config: FlightSqlServiceConfig,
71}
72
73impl FlightSqlService {
74 pub fn new(state: SessionState) -> Self {
76 Self::new_with_provider(Box::new(StaticSessionStateProvider::new(state)))
77 }
78
79 pub fn new_with_provider(provider: Box<dyn SessionStateProvider>) -> Self {
81 Self {
82 provider,
83 sql_options: None,
84 config: FlightSqlServiceConfig::default(),
85 }
86 }
87
88 pub fn with_config(self, config: FlightSqlServiceConfig) -> Self {
90 Self { config, ..self }
91 }
92
93 pub fn with_sql_options(self, sql_options: SQLOptions) -> Self {
97 Self {
98 sql_options: Some(sql_options),
99 ..self
100 }
101 }
102
103 pub async fn serve(self, addr: String) -> Result<(), Box<dyn std::error::Error>> {
110 let addr = addr.parse()?;
111 info!("Listening on {addr:?}");
112
113 let svc = FlightServiceServer::new(self);
114
115 Ok(Server::builder().add_service(svc).serve(addr).await?)
116 }
117
118 async fn new_context<T>(
119 &self,
120 request: Request<T>,
121 ) -> Result<(Request<T>, FlightSqlSessionContext)> {
122 let (metadata, extensions, msg) = request.into_parts();
123 let inspect_request = Request::from_parts(metadata, extensions, ());
124
125 let state = self.provider.new_context(&inspect_request).await?;
126 let ctx = SessionContext::new_with_state(state);
127
128 let (metadata, extensions, _) = inspect_request.into_parts();
129 Ok((
130 Request::from_parts(metadata, extensions, msg),
131 FlightSqlSessionContext {
132 inner: ctx,
133 sql_options: self.sql_options,
134 },
135 ))
136 }
137}
138
139static GET_TABLE_TYPES_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
141 Arc::new(Schema::new(vec![Field::new(
143 "table_type",
144 DataType::Utf8,
145 false,
146 )]))
147});
148
149struct FlightSqlSessionContext {
150 inner: SessionContext,
151 sql_options: Option<SQLOptions>,
152}
153
154impl FlightSqlSessionContext {
155 async fn sql_to_logical_plan(&self, sql: &str) -> DataFusionResult<LogicalPlan> {
156 let plan = self.inner.state().create_logical_plan(sql).await?;
157 let verifier = self.sql_options.unwrap_or_default();
158 verifier.verify_plan(&plan)?;
159 Ok(plan)
160 }
161
162 async fn execute_sql(&self, sql: &str) -> DataFusionResult<SendableRecordBatchStream> {
163 let plan = self.sql_to_logical_plan(sql).await?;
164 self.execute_logical_plan(plan).await
165 }
166
167 async fn execute_logical_plan(
168 &self,
169 plan: LogicalPlan,
170 ) -> DataFusionResult<SendableRecordBatchStream> {
171 self.inner
172 .execute_logical_plan(plan)
173 .await?
174 .execute_stream()
175 .await
176 }
177}
178
179#[tonic::async_trait]
180impl ArrowFlightSqlService for FlightSqlService {
181 type FlightService = FlightSqlService;
182
183 async fn do_handshake(
184 &self,
185 _request: Request<Streaming<HandshakeRequest>>,
186 ) -> Result<Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse>> + Send>>>> {
187 info!("do_handshake");
188 Err(Status::unimplemented("handshake is not supported"))
192 }
193
194 async fn do_get_fallback(
195 &self,
196 request: Request<Ticket>,
197 _message: Any,
198 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
199 let (request, ctx) = self.new_context(request).await?;
200
201 let ticket = CommandTicket::try_decode(request.into_inner().ticket)
202 .map_err(flight_error_to_status)?;
203
204 match ticket.command {
205 sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => {
206 let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?;
209 let arrow_schema = stream.schema();
210 let arrow_stream = stream.map(|i| {
211 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
212 Ok(batch)
213 });
214
215 let flight_data_stream = FlightDataEncoderBuilder::new()
216 .with_schema(arrow_schema)
217 .build(arrow_stream)
218 .map_err(flight_error_to_status)
219 .boxed();
220
221 Ok(Response::new(flight_data_stream))
222 }
223 sql::Command::CommandPreparedStatementQuery(CommandPreparedStatementQuery {
224 prepared_statement_handle,
225 }) => {
226 let handle = QueryHandle::try_decode(prepared_statement_handle)?;
227
228 let mut plan = ctx
229 .sql_to_logical_plan(handle.query())
230 .await
231 .map_err(df_error_to_status)?;
232
233 if let Some(param_values) =
234 decode_param_values(handle.parameters()).map_err(arrow_error_to_status)?
235 {
236 plan = plan
237 .with_param_values(param_values)
238 .map_err(df_error_to_status)?;
239 }
240
241 let stream = ctx
242 .execute_logical_plan(plan)
243 .await
244 .map_err(df_error_to_status)?;
245 let arrow_schema = stream.schema();
246 let arrow_stream = stream.map(|i| {
247 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
248 Ok(batch)
249 });
250
251 let flight_data_stream = FlightDataEncoderBuilder::new()
252 .with_schema(arrow_schema)
253 .build(arrow_stream)
254 .map_err(flight_error_to_status)
255 .boxed();
256
257 Ok(Response::new(flight_data_stream))
258 }
259 sql::Command::CommandStatementSubstraitPlan(CommandStatementSubstraitPlan {
260 plan,
261 ..
262 }) => {
263 let substrait_bytes = &plan
264 .ok_or(Status::invalid_argument(
265 "Expected substrait plan, found None",
266 ))?
267 .plan;
268
269 let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
270
271 let state = ctx.inner.state();
272 let df = DataFrame::new(state, plan);
273
274 let stream = df.execute_stream().await.map_err(df_error_to_status)?;
275 let arrow_schema = stream.schema();
276 let arrow_stream = stream.map(|i| {
277 let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
278 Ok(batch)
279 });
280
281 let flight_data_stream = FlightDataEncoderBuilder::new()
282 .with_schema(arrow_schema)
283 .build(arrow_stream)
284 .map_err(flight_error_to_status)
285 .boxed();
286
287 Ok(Response::new(flight_data_stream))
288 }
289 _ => {
290 return Err(Status::internal(format!(
291 "statement handle not found: {:?}",
292 ticket.command
293 )));
294 }
295 }
296 }
297
298 async fn get_flight_info_statement(
299 &self,
300 query: CommandStatementQuery,
301 request: Request<FlightDescriptor>,
302 ) -> Result<Response<FlightInfo>> {
303 let (request, ctx) = self.new_context(request).await?;
304
305 let sql = &query.query;
306 info!("get_flight_info_statement with query={sql}");
307
308 let flight_descriptor = request.into_inner();
309
310 let plan = ctx
311 .sql_to_logical_plan(sql)
312 .await
313 .map_err(df_error_to_status)?;
314
315 let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
316
317 let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query))
319 .try_encode()
320 .map_err(flight_error_to_status)?;
321
322 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
323
324 let flight_info = FlightInfo::new()
325 .with_endpoint(endpoint)
326 .with_descriptor(flight_descriptor)
328 .try_with_schema(dataset_schema.as_ref())
329 .map_err(arrow_error_to_status)?;
330
331 Ok(Response::new(flight_info))
332 }
333
334 async fn get_flight_info_substrait_plan(
335 &self,
336 query: CommandStatementSubstraitPlan,
337 request: Request<FlightDescriptor>,
338 ) -> Result<Response<FlightInfo>> {
339 info!("get_flight_info_substrait_plan");
340 let (request, ctx) = self.new_context(request).await?;
341
342 let substrait_bytes = &query
343 .plan
344 .as_ref()
345 .ok_or(Status::invalid_argument(
346 "Expected substrait plan, found None",
347 ))?
348 .plan;
349
350 let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;
351
352 let flight_descriptor = request.into_inner();
353
354 let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
355
356 let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query))
358 .try_encode()
359 .map_err(flight_error_to_status)?;
360
361 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
362
363 let flight_info = FlightInfo::new()
364 .with_endpoint(endpoint)
365 .with_descriptor(flight_descriptor)
367 .try_with_schema(dataset_schema.as_ref())
368 .map_err(arrow_error_to_status)?;
369
370 Ok(Response::new(flight_info))
371 }
372
373 async fn get_flight_info_prepared_statement(
374 &self,
375 cmd: CommandPreparedStatementQuery,
376 request: Request<FlightDescriptor>,
377 ) -> Result<Response<FlightInfo>> {
378 let (request, ctx) = self.new_context(request).await?;
379
380 let handle = QueryHandle::try_decode(cmd.prepared_statement_handle.clone())
381 .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?;
382
383 info!("get_flight_info_prepared_statement with handle={handle}");
384
385 let flight_descriptor = request.into_inner();
386
387 let sql = handle.query();
388 let plan = ctx
389 .sql_to_logical_plan(sql)
390 .await
391 .map_err(df_error_to_status)?;
392
393 let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
394
395 let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd))
397 .try_encode()
398 .map_err(flight_error_to_status)?;
399
400 let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket });
401
402 let flight_info = FlightInfo::new()
403 .with_endpoint(endpoint)
404 .with_descriptor(flight_descriptor)
406 .try_with_schema(dataset_schema.as_ref())
407 .map_err(arrow_error_to_status)?;
408
409 Ok(Response::new(flight_info))
410 }
411
412 async fn get_flight_info_catalogs(
413 &self,
414 query: CommandGetCatalogs,
415 request: Request<FlightDescriptor>,
416 ) -> Result<Response<FlightInfo>> {
417 info!("get_flight_info_catalogs");
418 let (request, _ctx) = self.new_context(request).await?;
419
420 let flight_descriptor = request.into_inner();
421 let ticket = Ticket {
422 ticket: query.as_any().encode_to_vec().into(),
423 };
424 let endpoint = FlightEndpoint::new().with_ticket(ticket);
425
426 let flight_info = FlightInfo::new()
427 .try_with_schema(&query.into_builder().schema())
428 .map_err(arrow_error_to_status)?
429 .with_endpoint(endpoint)
430 .with_descriptor(flight_descriptor);
431
432 Ok(Response::new(flight_info))
433 }
434
435 async fn get_flight_info_schemas(
436 &self,
437 query: CommandGetDbSchemas,
438 request: Request<FlightDescriptor>,
439 ) -> Result<Response<FlightInfo>> {
440 info!("get_flight_info_schemas");
441 let (request, _ctx) = self.new_context(request).await?;
442 let flight_descriptor = request.into_inner();
443 let ticket = Ticket {
444 ticket: query.as_any().encode_to_vec().into(),
445 };
446 let endpoint = FlightEndpoint::new().with_ticket(ticket);
447
448 let flight_info = FlightInfo::new()
449 .try_with_schema(&query.into_builder().schema())
450 .map_err(arrow_error_to_status)?
451 .with_endpoint(endpoint)
452 .with_descriptor(flight_descriptor);
453
454 Ok(Response::new(flight_info))
455 }
456
457 async fn get_flight_info_tables(
458 &self,
459 query: CommandGetTables,
460 request: Request<FlightDescriptor>,
461 ) -> Result<Response<FlightInfo>> {
462 info!("get_flight_info_tables");
463 let (request, _ctx) = self.new_context(request).await?;
464
465 let flight_descriptor = request.into_inner();
466 let ticket = Ticket {
467 ticket: query.as_any().encode_to_vec().into(),
468 };
469 let endpoint = FlightEndpoint::new().with_ticket(ticket);
470
471 let flight_info = FlightInfo::new()
472 .try_with_schema(&query.into_builder().schema())
473 .map_err(arrow_error_to_status)?
474 .with_endpoint(endpoint)
475 .with_descriptor(flight_descriptor);
476
477 Ok(Response::new(flight_info))
478 }
479
480 async fn get_flight_info_table_types(
481 &self,
482 query: CommandGetTableTypes,
483 request: Request<FlightDescriptor>,
484 ) -> Result<Response<FlightInfo>> {
485 info!("get_flight_info_table_types");
486 let (request, _ctx) = self.new_context(request).await?;
487
488 let flight_descriptor = request.into_inner();
489 let ticket = Ticket {
490 ticket: query.as_any().encode_to_vec().into(),
491 };
492 let endpoint = FlightEndpoint::new().with_ticket(ticket);
493
494 let flight_info = FlightInfo::new()
495 .try_with_schema(&GET_TABLE_TYPES_SCHEMA)
496 .map_err(arrow_error_to_status)?
497 .with_endpoint(endpoint)
498 .with_descriptor(flight_descriptor);
499
500 Ok(Response::new(flight_info))
501 }
502
503 async fn get_flight_info_sql_info(
504 &self,
505 _query: CommandGetSqlInfo,
506 request: Request<FlightDescriptor>,
507 ) -> Result<Response<FlightInfo>> {
508 info!("get_flight_info_sql_info");
509 let (_, _) = self.new_context(request).await?;
510
511 Err(Status::unimplemented("Implement CommandGetSqlInfo"))
512 }
513
514 async fn get_flight_info_primary_keys(
515 &self,
516 _query: CommandGetPrimaryKeys,
517 request: Request<FlightDescriptor>,
518 ) -> Result<Response<FlightInfo>> {
519 info!("get_flight_info_primary_keys");
520 let (_, _) = self.new_context(request).await?;
521
522 Err(Status::unimplemented(
523 "Implement get_flight_info_primary_keys",
524 ))
525 }
526
527 async fn get_flight_info_exported_keys(
528 &self,
529 _query: CommandGetExportedKeys,
530 request: Request<FlightDescriptor>,
531 ) -> Result<Response<FlightInfo>> {
532 info!("get_flight_info_exported_keys");
533 let (_, _) = self.new_context(request).await?;
534
535 Err(Status::unimplemented(
536 "Implement get_flight_info_exported_keys",
537 ))
538 }
539
540 async fn get_flight_info_imported_keys(
541 &self,
542 _query: CommandGetImportedKeys,
543 request: Request<FlightDescriptor>,
544 ) -> Result<Response<FlightInfo>> {
545 info!("get_flight_info_imported_keys");
546 let (_, _) = self.new_context(request).await?;
547
548 Err(Status::unimplemented(
549 "Implement get_flight_info_imported_keys",
550 ))
551 }
552
553 async fn get_flight_info_cross_reference(
554 &self,
555 _query: CommandGetCrossReference,
556 request: Request<FlightDescriptor>,
557 ) -> Result<Response<FlightInfo>> {
558 info!("get_flight_info_cross_reference");
559 let (_, _) = self.new_context(request).await?;
560
561 Err(Status::unimplemented(
562 "Implement get_flight_info_cross_reference",
563 ))
564 }
565
566 async fn get_flight_info_xdbc_type_info(
567 &self,
568 _query: CommandGetXdbcTypeInfo,
569 request: Request<FlightDescriptor>,
570 ) -> Result<Response<FlightInfo>> {
571 info!("get_flight_info_xdbc_type_info");
572 let (_, _) = self.new_context(request).await?;
573
574 Err(Status::unimplemented(
575 "Implement get_flight_info_xdbc_type_info",
576 ))
577 }
578
579 async fn do_get_statement(
580 &self,
581 _ticket: TicketStatementQuery,
582 request: Request<Ticket>,
583 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
584 info!("do_get_statement");
585 let (_, _) = self.new_context(request).await?;
586
587 Err(Status::unimplemented("Implement do_get_statement"))
588 }
589
590 async fn do_get_prepared_statement(
591 &self,
592 _query: CommandPreparedStatementQuery,
593 request: Request<Ticket>,
594 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
595 info!("do_get_prepared_statement");
596 let (_, _) = self.new_context(request).await?;
597
598 Err(Status::unimplemented("Implement do_get_prepared_statement"))
599 }
600
601 async fn do_get_catalogs(
602 &self,
603 query: CommandGetCatalogs,
604 request: Request<Ticket>,
605 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
606 info!("do_get_catalogs");
607 let (_request, ctx) = self.new_context(request).await?;
608 let catalog_names = ctx.inner.catalog_names();
609
610 let mut builder = query.into_builder();
611 for catalog_name in &catalog_names {
612 builder.append(catalog_name);
613 }
614 let schema = builder.schema();
615 let batch = builder.build();
616 let stream = FlightDataEncoderBuilder::new()
617 .with_schema(schema)
618 .build(futures::stream::once(async { batch }))
619 .map_err(Status::from);
620 Ok(Response::new(Box::pin(stream)))
621 }
622
623 async fn do_get_schemas(
624 &self,
625 query: CommandGetDbSchemas,
626 request: Request<Ticket>,
627 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
628 info!("do_get_schemas");
629 let (_request, ctx) = self.new_context(request).await?;
630 let catalog_name = query.catalog.clone();
631 let mut builder = query.into_builder();
633 if let Some(catalog_name) = &catalog_name {
634 if let Some(catalog) = ctx.inner.catalog(catalog_name) {
635 for schema_name in &catalog.schema_names() {
636 builder.append(catalog_name, schema_name);
637 }
638 }
639 };
640
641 let schema = builder.schema();
642 let batch = builder.build();
643 let stream = FlightDataEncoderBuilder::new()
644 .with_schema(schema)
645 .build(futures::stream::once(async { batch }))
646 .map_err(Status::from);
647 Ok(Response::new(Box::pin(stream)))
648 }
649
650 async fn do_get_tables(
651 &self,
652 query: CommandGetTables,
653 request: Request<Ticket>,
654 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
655 info!("do_get_tables");
656 let (_request, ctx) = self.new_context(request).await?;
657 let catalog_name = query.catalog.clone();
658 let mut builder = query.into_builder();
659 if let Some(catalog_name) = &catalog_name {
661 if let Some(catalog) = ctx.inner.catalog(catalog_name) {
662 for schema_name in &catalog.schema_names() {
663 if let Some(schema) = catalog.schema(schema_name) {
664 for table_name in &schema.table_names() {
665 if let Some(table) =
666 schema.table(table_name).await.map_err(df_error_to_status)?
667 {
668 builder
669 .append(
670 catalog_name,
671 schema_name,
672 table_name,
673 table.table_type().to_string(),
674 &table.schema(),
675 )
676 .map_err(flight_error_to_status)?;
677 }
678 }
679 }
680 }
681 }
682 };
683
684 let schema = builder.schema();
685 let batch = builder.build();
686 let stream = FlightDataEncoderBuilder::new()
687 .with_schema(schema)
688 .build(futures::stream::once(async { batch }))
689 .map_err(Status::from);
690 Ok(Response::new(Box::pin(stream)))
691 }
692
693 async fn do_get_table_types(
694 &self,
695 _query: CommandGetTableTypes,
696 request: Request<Ticket>,
697 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
698 info!("do_get_table_types");
699 let (_, _) = self.new_context(request).await?;
700
701 let table_types: ArrayRef = Arc::new(StringArray::from(
703 vec![TableType::Base, TableType::View, TableType::Temporary]
704 .into_iter()
705 .map(|tt| tt.to_string())
706 .collect::<Vec<String>>(),
707 ));
708
709 let batch = RecordBatch::try_from_iter(vec![("table_type", table_types)]).unwrap();
710
711 let stream = FlightDataEncoderBuilder::new()
712 .with_schema(GET_TABLE_TYPES_SCHEMA.clone())
713 .build(futures::stream::once(async { Ok(batch) }))
714 .map_err(Status::from);
715 Ok(Response::new(Box::pin(stream)))
716 }
717
718 async fn do_get_sql_info(
719 &self,
720 _query: CommandGetSqlInfo,
721 request: Request<Ticket>,
722 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
723 info!("do_get_sql_info");
724 let (_, _) = self.new_context(request).await?;
725
726 Err(Status::unimplemented("Implement do_get_sql_info"))
727 }
728
729 async fn do_get_primary_keys(
730 &self,
731 _query: CommandGetPrimaryKeys,
732 request: Request<Ticket>,
733 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
734 info!("do_get_primary_keys");
735 let (_, _) = self.new_context(request).await?;
736
737 Err(Status::unimplemented("Implement do_get_primary_keys"))
738 }
739
740 async fn do_get_exported_keys(
741 &self,
742 _query: CommandGetExportedKeys,
743 request: Request<Ticket>,
744 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
745 info!("do_get_exported_keys");
746 let (_, _) = self.new_context(request).await?;
747
748 Err(Status::unimplemented("Implement do_get_exported_keys"))
749 }
750
751 async fn do_get_imported_keys(
752 &self,
753 _query: CommandGetImportedKeys,
754 request: Request<Ticket>,
755 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
756 info!("do_get_imported_keys");
757 let (_, _) = self.new_context(request).await?;
758
759 Err(Status::unimplemented("Implement do_get_imported_keys"))
760 }
761
762 async fn do_get_cross_reference(
763 &self,
764 _query: CommandGetCrossReference,
765 request: Request<Ticket>,
766 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
767 info!("do_get_cross_reference");
768 let (_, _) = self.new_context(request).await?;
769
770 Err(Status::unimplemented("Implement do_get_cross_reference"))
771 }
772
773 async fn do_get_xdbc_type_info(
774 &self,
775 _query: CommandGetXdbcTypeInfo,
776 request: Request<Ticket>,
777 ) -> Result<Response<<Self as FlightService>::DoGetStream>> {
778 info!("do_get_xdbc_type_info");
779 let (_, _) = self.new_context(request).await?;
780
781 Err(Status::unimplemented("Implement do_get_xdbc_type_info"))
782 }
783
784 async fn do_put_statement_update(
785 &self,
786 _ticket: CommandStatementUpdate,
787 request: Request<PeekableFlightDataStream>,
788 ) -> Result<i64, Status> {
789 info!("do_put_statement_update");
790 let (_, _) = self.new_context(request).await?;
791
792 Err(Status::unimplemented("Implement do_put_statement_update"))
793 }
794
795 async fn do_put_prepared_statement_query(
796 &self,
797 query: CommandPreparedStatementQuery,
798 request: Request<PeekableFlightDataStream>,
799 ) -> Result<DoPutPreparedStatementResult, Status> {
800 info!("do_put_prepared_statement_query");
801 let (request, _) = self.new_context(request).await?;
802
803 let mut handle = QueryHandle::try_decode(query.prepared_statement_handle)?;
804
805 info!(
806 "do_action_create_prepared_statement query={:?}",
807 handle.query()
808 );
809 let mut decoder =
812 FlightDataDecoder::new(request.into_inner().map_err(status_to_flight_error));
813 let schema = decode_schema(&mut decoder).await?;
814 let mut parameters = Vec::new();
815 let mut encoder =
816 StreamWriter::try_new(&mut parameters, &schema).map_err(arrow_error_to_status)?;
817 let mut total_rows = 0;
818 while let Some(msg) = decoder.try_next().await? {
819 match msg.payload {
820 DecodedPayload::None => {}
821 DecodedPayload::Schema(_) => {
822 return Err(Status::invalid_argument(
823 "parameter flight data must contain a single schema",
824 ));
825 }
826 DecodedPayload::RecordBatch(record_batch) => {
827 total_rows += record_batch.num_rows();
828 encoder
829 .write(&record_batch)
830 .map_err(arrow_error_to_status)?;
831 }
832 }
833 }
834 if total_rows > 1 {
835 return Err(Status::invalid_argument(
836 "parameters should contain a single row",
837 ));
838 }
839
840 handle.set_parameters(Some(parameters.into()));
841
842 let res = DoPutPreparedStatementResult {
843 prepared_statement_handle: Some(Bytes::from(handle)),
844 };
845
846 Ok(res)
847 }
848
849 async fn do_put_prepared_statement_update(
850 &self,
851 _handle: CommandPreparedStatementUpdate,
852 request: Request<PeekableFlightDataStream>,
853 ) -> Result<i64, Status> {
854 info!("do_put_prepared_statement_update");
855 let (_, _) = self.new_context(request).await?;
856
857 Ok(-1)
860 }
861
862 async fn do_put_substrait_plan(
863 &self,
864 _query: CommandStatementSubstraitPlan,
865 request: Request<PeekableFlightDataStream>,
866 ) -> Result<i64, Status> {
867 info!("do_put_prepared_statement_update");
868 let (_, _) = self.new_context(request).await?;
869
870 Err(Status::unimplemented(
871 "Implement do_put_prepared_statement_update",
872 ))
873 }
874
875 async fn do_action_create_prepared_statement(
876 &self,
877 query: ActionCreatePreparedStatementRequest,
878 request: Request<Action>,
879 ) -> Result<ActionCreatePreparedStatementResult, Status> {
880 let (_, ctx) = self.new_context(request).await?;
881
882 let sql = query.query.clone();
883 info!(
884 "do_action_create_prepared_statement query={:?}",
885 query.query
886 );
887
888 let plan = ctx
889 .sql_to_logical_plan(sql.as_str())
890 .await
891 .map_err(df_error_to_status)?;
892
893 let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
894 let parameter_schema = parameter_schema_for_plan(&plan).map_err(|e| e.as_ref().clone())?;
895
896 let dataset_schema =
897 encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?;
898 let parameter_schema =
899 encode_schema(parameter_schema.as_ref()).map_err(arrow_error_to_status)?;
900
901 let handle = QueryHandle::new(sql, None);
902
903 let res = ActionCreatePreparedStatementResult {
904 prepared_statement_handle: Bytes::from(handle),
905 dataset_schema,
906 parameter_schema,
907 };
908
909 Ok(res)
910 }
911
912 async fn do_action_close_prepared_statement(
913 &self,
914 query: ActionClosePreparedStatementRequest,
915 request: Request<Action>,
916 ) -> Result<(), Status> {
917 let (_, _) = self.new_context(request).await?;
918
919 let handle = query.prepared_statement_handle.as_ref();
920 if let Ok(handle) = std::str::from_utf8(handle) {
921 info!("do_action_close_prepared_statement with handle {handle:?}",);
922
923 }
925 Ok(())
926 }
927
928 async fn do_action_create_prepared_substrait_plan(
929 &self,
930 _query: ActionCreatePreparedSubstraitPlanRequest,
931 request: Request<Action>,
932 ) -> Result<ActionCreatePreparedStatementResult, Status> {
933 info!("do_action_create_prepared_substrait_plan");
934 let (_, _) = self.new_context(request).await?;
935
936 Err(Status::unimplemented(
937 "Implement do_action_create_prepared_substrait_plan",
938 ))
939 }
940
941 async fn do_action_begin_transaction(
942 &self,
943 _query: ActionBeginTransactionRequest,
944 request: Request<Action>,
945 ) -> Result<ActionBeginTransactionResult, Status> {
946 let (_, _) = self.new_context(request).await?;
947
948 info!("do_action_begin_transaction");
949 Err(Status::unimplemented(
950 "Implement do_action_begin_transaction",
951 ))
952 }
953
954 async fn do_action_end_transaction(
955 &self,
956 _query: ActionEndTransactionRequest,
957 request: Request<Action>,
958 ) -> Result<(), Status> {
959 info!("do_action_end_transaction");
960 let (_, _) = self.new_context(request).await?;
961
962 Err(Status::unimplemented("Implement do_action_end_transaction"))
963 }
964
965 async fn do_action_begin_savepoint(
966 &self,
967 _query: ActionBeginSavepointRequest,
968 request: Request<Action>,
969 ) -> Result<ActionBeginSavepointResult, Status> {
970 info!("do_action_begin_savepoint");
971 let (_, _) = self.new_context(request).await?;
972
973 Err(Status::unimplemented("Implement do_action_begin_savepoint"))
974 }
975
976 async fn do_action_end_savepoint(
977 &self,
978 _query: ActionEndSavepointRequest,
979 request: Request<Action>,
980 ) -> Result<(), Status> {
981 info!("do_action_end_savepoint");
982 let (_, _) = self.new_context(request).await?;
983
984 Err(Status::unimplemented("Implement do_action_end_savepoint"))
985 }
986
987 async fn do_action_cancel_query(
988 &self,
989 _query: ActionCancelQueryRequest,
990 request: Request<Action>,
991 ) -> Result<ActionCancelQueryResult, Status> {
992 info!("do_action_cancel_query");
993 let (_, _) = self.new_context(request).await?;
994
995 Err(Status::unimplemented("Implement do_action_cancel_query"))
996 }
997
998 async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
999}
1000
1001async fn parse_substrait_bytes(
1004 ctx: &FlightSqlSessionContext,
1005 substrait: &Bytes,
1006) -> Result<LogicalPlan> {
1007 let substrait_plan = deserialize_bytes(substrait.to_vec())
1008 .await
1009 .map_err(df_error_to_status)?;
1010
1011 from_substrait_plan(&ctx.inner.state(), &substrait_plan)
1012 .await
1013 .map_err(df_error_to_status)
1014}
1015
1016fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
1018 let options = IpcWriteOptions::default();
1019
1020 let message: Result<IpcMessage, ArrowError> = SchemaAsIpc::new(schema, &options).try_into();
1022
1023 let IpcMessage(schema) = message?;
1024
1025 Ok(schema)
1026}
1027
1028fn get_schema_for_plan(logical_plan: &LogicalPlan, with_metadata: bool) -> SchemaRef {
1030 let schema: SchemaRef = if with_metadata {
1031 let df_schema = logical_plan.schema();
1033
1034 let fields_with_metadata: Vec<_> = df_schema
1036 .iter()
1037 .map(|(qualifier, field)| {
1038 if let Some(table_ref) = qualifier {
1040 let mut metadata = field.metadata().clone();
1041 metadata.insert("table_name".to_string(), table_ref.to_string());
1042 field.as_ref().clone().with_metadata(metadata)
1043 } else {
1044 field.as_ref().clone()
1045 }
1046 })
1047 .collect();
1048
1049 Arc::new(Schema::new_with_metadata(
1050 fields_with_metadata,
1051 df_schema.as_ref().metadata().clone(),
1052 ))
1053 } else {
1054 Arc::new(Schema::from(logical_plan.schema().as_ref()))
1055 };
1056
1057 let flight_data_stream = FlightDataEncoderBuilder::new()
1060 .with_schema(schema)
1062 .build(futures::stream::iter([]));
1063
1064 flight_data_stream
1066 .known_schema()
1067 .expect("flight data schema should be known when explicitly provided via `with_schema`")
1068}
1069
1070fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result<SchemaRef, Box<Status>> {
1071 let parameters = plan
1072 .get_parameter_types()
1073 .map_err(df_error_to_status)?
1074 .into_iter()
1075 .map(|(name, dt)| {
1076 dt.map(|dt| (name.clone(), dt)).ok_or_else(|| {
1077 Status::internal(format!(
1078 "unable to determine type of query parameter {name}"
1079 ))
1080 })
1081 })
1082 .collect::<Result<BTreeMap<_, _>, Status>>()?;
1084
1085 let mut builder = SchemaBuilder::new();
1086 parameters
1087 .into_iter()
1088 .for_each(|(name, typ)| builder.push(Field::new(name, typ, false)));
1089 Ok(builder.finish().into())
1090}
1091
1092fn arrow_error_to_status(err: ArrowError) -> Status {
1093 Status::internal(format!("{err:?}"))
1094}
1095
1096fn flight_error_to_status(err: FlightError) -> Status {
1097 Status::internal(format!("{err:?}"))
1098}
1099
1100fn df_error_to_status(err: DataFusionError) -> Status {
1101 Status::internal(format!("{err:?}"))
1102}
1103
1104fn status_to_flight_error(status: Status) -> FlightError {
1105 FlightError::Tonic(Box::new(status))
1106}
1107
1108async fn decode_schema(decoder: &mut FlightDataDecoder) -> Result<SchemaRef, Status> {
1109 while let Some(msg) = decoder.try_next().await? {
1110 match msg.payload {
1111 DecodedPayload::None => {}
1112 DecodedPayload::Schema(schema) => {
1113 return Ok(schema);
1114 }
1115 DecodedPayload::RecordBatch(_) => {
1116 return Err(Status::invalid_argument(
1117 "parameter flight data must have a known schema",
1118 ));
1119 }
1120 }
1121 }
1122
1123 Err(Status::invalid_argument(
1124 "parameter flight data must have a schema",
1125 ))
1126}
1127
1128fn decode_param_values(
1130 parameters: Option<&[u8]>,
1131) -> Result<Option<ParamValues>, arrow::error::ArrowError> {
1132 parameters
1133 .map(|parameters| {
1134 let decoder = StreamReader::try_new(parameters, None)?;
1135 let schema = decoder.schema();
1136 let batches = decoder.into_iter().collect::<Result<Vec<_>, _>>()?;
1137 let batch = concat_batches(&schema, batches.iter())?;
1138 Ok(record_to_param_values(&batch)?)
1139 })
1140 .transpose()
1141}
1142
1143fn record_to_param_values(batch: &RecordBatch) -> Result<ParamValues, DataFusionError> {
1145 let mut param_values: Vec<(String, Option<usize>, ScalarValue)> = Vec::new();
1146
1147 let mut is_list = true;
1148 for col_index in 0..batch.num_columns() {
1149 let array = batch.column(col_index);
1150 let scalar = ScalarValue::try_from_array(array, 0)?;
1151 let name = batch
1152 .schema_ref()
1153 .field(col_index)
1154 .name()
1155 .trim_start_matches('$')
1156 .to_string();
1157 let index = name.parse().ok();
1158 is_list &= index.is_some();
1159 param_values.push((name, index, scalar));
1160 }
1161 if is_list {
1162 let mut values: Vec<(Option<usize>, ScalarValue)> = param_values
1163 .into_iter()
1164 .map(|(_name, index, value)| (index, value))
1165 .collect();
1166 values.sort_by_key(|(index, _value)| *index);
1167 Ok(values
1168 .into_iter()
1169 .map(|(_index, value)| value)
1170 .collect::<Vec<ScalarValue>>()
1171 .into())
1172 } else {
1173 Ok(param_values
1174 .into_iter()
1175 .map(|(name, _index, value)| (name, value))
1176 .collect::<Vec<(String, ScalarValue)>>()
1177 .into())
1178 }
1179}