diff --git a/Cargo.lock b/Cargo.lock index 6e7bd09af8b9d..3dbe9b24b70e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -215,9 +215,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.81" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" dependencies = [ "backtrace", ] @@ -11411,8 +11411,10 @@ dependencies = [ "itertools 0.12.1", "matches", "serde", + "thiserror", "tracing", "tracing-subscriber", + "winnow 0.6.8 (git+https://github.com/TennyZhuang/winnow.git?rev=a6b1f04)", "workspace-hack", ] @@ -14110,9 +14112,9 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.58" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] @@ -14141,9 +14143,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "1.0.58" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", @@ -14540,7 +14542,7 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "winnow 0.6.5", + "winnow 0.6.8 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -16220,9 +16222,17 @@ dependencies = [ [[package]] name = "winnow" -version = "0.6.5" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8" +checksum = "c3c52e9c97a68071b23e836c9380edae937f17b9c4667bd021973efc689f618d" +dependencies = [ + "memchr", +] + +[[package]] +name = "winnow" +version = "0.6.8" +source = "git+https://github.com/TennyZhuang/winnow.git?rev=a6b1f04#a6b1f04cbe9b39d218da5a121e610144ebf961b0" dependencies = [ "memchr", ] diff --git a/src/sqlparser/Cargo.toml b/src/sqlparser/Cargo.toml index ef57da9aa62f8..8c20eb9bf2f29 100644 --- a/src/sqlparser/Cargo.toml +++ b/src/sqlparser/Cargo.toml @@ -1,10 +1,7 @@ [package] name = "risingwave_sqlparser" license = "Apache-2.0" -include = [ - "src/**/*.rs", - "Cargo.toml", -] +include = ["src/**/*.rs", "Cargo.toml"] version = { workspace = true } edition = { workspace = true } homepage = { workspace = true } @@ -27,8 +24,10 @@ normal = ["workspace-hack"] [dependencies] itertools = { workspace = true } serde = { version = "1.0", features = ["derive"], optional = true } +thiserror = "1.0.61" tracing = "0.1" tracing-subscriber = "0.3" +winnow = { version = "0.6.8", git = "https://github.com/TennyZhuang/winnow.git", rev = "a6b1f04" } [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../workspace-hack" } diff --git a/src/sqlparser/src/lib.rs b/src/sqlparser/src/lib.rs index 612b11078eac4..a102e5428edae 100644 --- a/src/sqlparser/src/lib.rs +++ b/src/sqlparser/src/lib.rs @@ -45,6 +45,7 @@ extern crate alloc; pub mod ast; pub mod keywords; pub mod parser; +pub mod parser_v2; pub mod tokenizer; #[doc(hidden)] diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 9b86b098b14ba..b70fc1b627abe 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -27,6 +27,7 
@@ use tracing::{debug, instrument}; use crate::ast::*; use crate::keywords::{self, Keyword}; +use crate::parser_v2; use crate::tokenizer::*; pub(crate) const UPSTREAM_SOURCE_KEY: &str = "connector"; @@ -172,19 +173,51 @@ pub struct Parser { tokens: Vec<TokenWithLocation>, /// The index of the first unprocessed token in `self.tokens` index: usize, - /// Since we cannot distinguish `>>` and double `>`, so use `angle_brackets_num` to store the - /// number of `<` to match `>` in sql like `struct<v1 struct<v2 int>>`. - angle_brackets_num: i32, } impl Parser { /// Parse the specified tokens pub fn new(tokens: Vec<TokenWithLocation>) -> Self { - Parser { - tokens, - index: 0, - angle_brackets_num: 0, - } + Parser { tokens, index: 0 } + } + + /// Adaptor for [`parser_v2`]. + /// + /// You can call a v2 parser from the original parser by using this method. + pub(crate) fn parse_v2<'a, O>( + &'a mut self, + mut parse_next: impl winnow::Parser< + winnow::Located<parser_v2::TokenStreamWrapper<'a>>, + O, + winnow::error::ContextError, + >, + ) -> Result<O, ParserError> { + use winnow::stream::Location; + + let mut token_stream = winnow::Located::new(parser_v2::TokenStreamWrapper { + tokens: &self.tokens[self.index..], + }); + let output = parse_next.parse_next(&mut token_stream).map_err(|e| { + let msg = if let Some(e) = e.into_inner() + && let Some(cause) = e.cause() + { + format!(": {}", cause) + } else { + "".to_string() + }; + ParserError::ParserError(format!( + "Unexpected {}{}", + if self.index + token_stream.location() >= self.tokens.len() { + &"EOF" as &dyn std::fmt::Display + } else { + &self.tokens[self.index + token_stream.location()] as &dyn std::fmt::Display + }, + msg + )) + }); + let offset = token_stream.location(); + self.index += offset; + output } /// Parse a SQL statement and produce an Abstract Syntax Tree (AST) @@ -3806,136 +3839,7 @@ impl Parser { /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) and convert /// into an array of that datatype if needed pub fn parse_data_type(&mut self) -> Result<DataType, ParserError> { - let mut data_type = self.parse_data_type_inner()?; - while self.consume_token(&Token::LBracket) { - self.expect_token(&Token::RBracket)?; - data_type = DataType::Array(Box::new(data_type)); - } - Ok(data_type) - } - - /// Parse struct `data_type` e.g.`<v1 int, v2 int>`.
- pub fn parse_struct_data_type(&mut self) -> Result<Vec<StructField>, ParserError> { - let mut columns = vec![]; - if !self.consume_token(&Token::Lt) { - return self.expected("'<' after struct", self.peek_token()); - } - self.angle_brackets_num += 1; - - loop { - if let Token::Word(_) = self.peek_token().token { - let name = self.parse_identifier_non_reserved()?; - let data_type = self.parse_data_type()?; - columns.push(StructField { name, data_type }) - } else { - return self.expected("struct field name", self.peek_token()); - } - if self.angle_brackets_num == 0 { - break; - } else if self.consume_token(&Token::Gt) { - self.angle_brackets_num -= 1; - break; - } else if self.consume_token(&Token::ShiftRight) { - if self.angle_brackets_num >= 1 { - self.angle_brackets_num -= 2; - break; - } else { - return parser_err!("too much '>'"); - } - } else if !self.consume_token(&Token::Comma) { - return self.expected("',' or '>' after column definition", self.peek_token()); - } - } - - Ok(columns) - } - - /// Parse a SQL datatype - pub fn parse_data_type_inner(&mut self) -> Result<DataType, ParserError> { - let token = self.next_token(); - match token.token { - Token::Word(w) => match w.keyword { - Keyword::BOOLEAN | Keyword::BOOL => Ok(DataType::Boolean), - Keyword::FLOAT => { - let precision = self.parse_optional_precision()?; - match precision { - Some(0) => Err(ParserError::ParserError( - "precision for type float must be at least 1 bit".to_string(), - )), - Some(54..) => Err(ParserError::ParserError( - "precision for type float must be less than 54 bits".to_string(), - )), - _ => Ok(DataType::Float(precision)), - } - } - Keyword::REAL => Ok(DataType::Real), - Keyword::DOUBLE => { - let _ = self.parse_keyword(Keyword::PRECISION); - Ok(DataType::Double) - } - Keyword::SMALLINT => Ok(DataType::SmallInt), - Keyword::INT | Keyword::INTEGER => Ok(DataType::Int), - Keyword::BIGINT => Ok(DataType::BigInt), - Keyword::STRING | Keyword::VARCHAR => Ok(DataType::Varchar), - Keyword::CHAR | Keyword::CHARACTER => { - if self.parse_keyword(Keyword::VARYING) { - Ok(DataType::Varchar) - } else { - Ok(DataType::Char(self.parse_optional_precision()?)) - } - } - Keyword::UUID => Ok(DataType::Uuid), - Keyword::DATE => Ok(DataType::Date), - Keyword::TIMESTAMP => { - let with_time_zone = self.parse_keyword(Keyword::WITH); - if with_time_zone || self.parse_keyword(Keyword::WITHOUT) { - self.expect_keywords(&[Keyword::TIME, Keyword::ZONE])?; - } - Ok(DataType::Timestamp(with_time_zone)) - } - Keyword::TIME => { - let with_time_zone = self.parse_keyword(Keyword::WITH); - if with_time_zone || self.parse_keyword(Keyword::WITHOUT) { - self.expect_keywords(&[Keyword::TIME, Keyword::ZONE])?; - } - Ok(DataType::Time(with_time_zone)) - } - // Interval types can be followed by a complicated interval - // qualifier that we don't currently support. See - // parse_interval_literal for a taste.
- Keyword::INTERVAL => Ok(DataType::Interval), - Keyword::REGCLASS => Ok(DataType::Regclass), - Keyword::REGPROC => Ok(DataType::Regproc), - Keyword::TEXT => { - if self.consume_token(&Token::LBracket) { - // Note: this is postgresql-specific - self.expect_token(&Token::RBracket)?; - Ok(DataType::Array(Box::new(DataType::Text))) - } else { - Ok(DataType::Text) - } - } - Keyword::STRUCT => Ok(DataType::Struct(self.parse_struct_data_type()?)), - Keyword::BYTEA => Ok(DataType::Bytea), - Keyword::NUMERIC | Keyword::DECIMAL | Keyword::DEC => { - let (precision, scale) = self.parse_optional_precision_scale()?; - Ok(DataType::Decimal(precision, scale)) - } - _ => { - self.prev_token(); - let type_name = self.parse_object_name()?; - // JSONB is not a keyword - if type_name.to_string().eq_ignore_ascii_case("jsonb") { - Ok(DataType::Jsonb) - } else { - Ok(DataType::Custom(type_name)) - } - } - }, - unexpected => { - self.expected("a data type name", unexpected.with_location(token.location)) - } - } + self.parse_v2(parser_v2::data_type) } /// Parse `AS identifier` (or simply `identifier` if it's not a reserved keyword) diff --git a/src/sqlparser/src/parser_v2/data_type.rs b/src/sqlparser/src/parser_v2/data_type.rs new file mode 100644 index 0000000000000..a8e9ea20f3ba4 --- /dev/null +++ b/src/sqlparser/src/parser_v2/data_type.rs @@ -0,0 +1,231 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Parsers for data types. +//! +//! This module contains parsers for data types. To handle the ambiguity of `>>` and `> >` in struct definitions, +//! we need to use a stateful parser here. See [`with_state`] for more information. + +use core::cell::RefCell; +use std::rc::Rc; + +use winnow::combinator::{ + alt, cut_err, delimited, dispatch, empty, fail, opt, preceded, repeat, separated, seq, + terminated, trace, +}; +use winnow::error::{ContextError, ErrMode, ErrorKind, FromExternalError, StrContext}; +use winnow::{PResult, Parser, Stateful}; + +use super::{ + identifier_non_reserved, keyword, literal_uint, object_name, precision_in_range, with_state, + TokenStream, +}; +use crate::ast::{DataType, StructField}; +use crate::keywords::Keyword; +use crate::tokenizer::Token; + +#[derive(Default, Debug)] +struct DataTypeParsingState { + /// Since we can't distinguish between `>>` and `> >` in the tokenizer, we need to handle this case in the parser. + /// When we want a [`>`][Token::Gt] but actually consumed a [`>>`][Token::ShiftRight], we set this to `true`. + /// When the value is `true` and we want a [`>`][Token::Gt], we just set it back to `false` instead of really consuming a token. + remaining_close: Rc<RefCell<bool>>, +} + +type StatefulStream<S> = Stateful<S, DataTypeParsingState>; + +/// Consume a struct type definition. +fn struct_data_type<S>(input: &mut StatefulStream<S>) -> PResult<Vec<StructField>> +where + S: TokenStream, +{ + let remaining_close1 = input.state.remaining_close.clone(); + let remaining_close2 = input.state.remaining_close.clone(); + + // Consume an abstract `>`; it may be the `remaining_close1` flag set by a previous `>>`.
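+ // For example, in `STRUCT<a STRUCT<b INT>>` the tokenizer emits a single `>>` for the two closing brackets: + // the inner struct's close consumes the `>>` and sets `remaining_close`, and the outer struct's close + // then clears the flag instead of consuming another token.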
+ let consume_close = trace( + "consume_struct_close", + alt(( + trace( + "consume_remaining_close", + move |input: &mut StatefulStream<S>| -> PResult<()> { + if *remaining_close1.borrow() { + *remaining_close1.borrow_mut() = false; + Ok(()) + } else { + fail(input) + } + }, + ) + .void(), + trace( + "produce_remaining_close", + ( + Token::ShiftRight, + move |_input: &mut StatefulStream<S>| -> PResult<()> { + *remaining_close2.borrow_mut() = true; + Ok(()) + }, + ) + .void(), + ), + Token::Gt.void(), + )), + ); + + // If there is an `over-consumed` `>`, we shouldn't handle `,`. + let sep = |input: &mut StatefulStream<S>| -> PResult<()> { + if *input.state.remaining_close.borrow() { + fail(input) + } else { + Token::Comma.void().parse_next(input) + } + }; + + delimited( + Token::Lt, + cut_err(separated( + 1.., + trace( + "struct_field", + seq! { + StructField { + name: identifier_non_reserved, + data_type: data_type_stateful, + } + }, + ), + sep, + )), + cut_err(consume_close), + ) + .context(StrContext::Label("struct_data_type")) + .parse_next(input) +} + +/// Consume a data type definition. +/// +/// This parser is the main entry point for data type parsing. +pub fn data_type<S>(input: &mut S) -> PResult<DataType> +where + S: TokenStream, +{ + #[derive(Debug, thiserror::Error)] + #[error("Unconsumed `>>`")] + struct UnconsumedShiftRight; + + with_state::<S, DataTypeParsingState, DataType, _>(terminated( + data_type_stateful, + trace("data_type_verify_state", |input: &mut StatefulStream<S>| { + // If there is a remaining `>`, we should fail. + if *input.state.remaining_close.borrow() { + Err(ErrMode::Cut(ContextError::from_external_error( + input, + ErrorKind::Fail, + UnconsumedShiftRight, + ))) + } else { + Ok(()) + } + }), + )) + .context(StrContext::Label("data_type")) + .parse_next(input) +} + +/// Data type parsing with a stateful stream. +fn data_type_stateful<S>(input: &mut StatefulStream<S>) -> PResult<DataType> +where + S: TokenStream, +{ + repeat(0.., (Token::LBracket, cut_err(Token::RBracket))) + .fold1(data_type_stateful_inner, |mut acc, _| { + acc = DataType::Array(Box::new(acc)); + acc + }) + .parse_next(input) +} + +/// Consume a data type except [`DataType::Array`]. +fn data_type_stateful_inner<S>(input: &mut StatefulStream<S>) -> PResult<DataType> +where + S: TokenStream, +{ + let with_time_zone = || { + opt(alt(( + (Keyword::WITH, Keyword::TIME, Keyword::ZONE).value(true), + (Keyword::WITHOUT, Keyword::TIME, Keyword::ZONE).value(false), + ))) + .map(|x| x.unwrap_or(false)) + }; + + let precision_and_scale = || { + opt(delimited( + Token::LParen, + (literal_uint, opt(preceded(Token::Comma, literal_uint))), + Token::RParen, + )) + .map(|p| match p { + Some((x, y)) => (Some(x), y), + None => (None, None), + }) + }; + + let keywords = dispatch! {keyword; + Keyword::BOOLEAN | Keyword::BOOL => empty.value(DataType::Boolean), + Keyword::FLOAT => opt(precision_in_range(1..54)).map(DataType::Float), + Keyword::REAL => empty.value(DataType::Real), + Keyword::DOUBLE => opt(Keyword::PRECISION).value(DataType::Double), + Keyword::SMALLINT => empty.value(DataType::SmallInt), + Keyword::INT | Keyword::INTEGER => empty.value(DataType::Int), + Keyword::BIGINT => empty.value(DataType::BigInt), + Keyword::STRING | Keyword::VARCHAR => empty.value(DataType::Varchar), + Keyword::CHAR | Keyword::CHARACTER => dispatch!
{keyword; + Keyword::VARYING => empty.value(DataType::Varchar), + _ => opt(precision_in_range(..)).map(DataType::Char), + }, + Keyword::UUID => empty.value(DataType::Uuid), + Keyword::DATE => empty.value(DataType::Date), + Keyword::TIMESTAMP => with_time_zone().map(DataType::Timestamp), + Keyword::TIME => with_time_zone().map(DataType::Time), + // TODO: Support complex interval type parsing. + Keyword::INTERVAL => empty.value(DataType::Interval), + Keyword::REGCLASS => empty.value(DataType::Regclass), + Keyword::REGPROC => empty.value(DataType::Regproc), + Keyword::TEXT => empty.value(DataType::Text), + Keyword::STRUCT => cut_err(struct_data_type).map(DataType::Struct), + Keyword::BYTEA => empty.value(DataType::Bytea), + Keyword::NUMERIC | Keyword::DECIMAL | Keyword::DEC => cut_err(precision_and_scale()).map(|(precision, scale)| { + DataType::Decimal(precision, scale) + }), + _ => fail, + }; + + trace( + "data_type_inner", + alt(( + keywords, + trace( + "non_keyword_data_type", + object_name.map(|name| { + if name.to_string().eq_ignore_ascii_case("jsonb") { + // JSONB is not a keyword + DataType::Jsonb + } else { + DataType::Custom(name) + } + }), + ), + )), + ) + .parse_next(input) +} diff --git a/src/sqlparser/src/parser_v2/impl_.rs b/src/sqlparser/src/parser_v2/impl_.rs new file mode 100644 index 0000000000000..fc09a02b7b591 --- /dev/null +++ b/src/sqlparser/src/parser_v2/impl_.rs @@ -0,0 +1,153 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use winnow::error::Needed; +use winnow::stream::{Checkpoint, Offset, SliceLen, Stream, StreamIsPartial, UpdateSlice}; + +use crate::tokenizer::{Token, TokenWithLocation, Whitespace}; + +#[derive(Copy, Clone, Debug)] +pub struct CheckpointWrapper<'a>(Checkpoint<&'a [TokenWithLocation], &'a [TokenWithLocation]>); + +impl<'a> Offset<CheckpointWrapper<'a>> for CheckpointWrapper<'a> { + #[inline(always)] + fn offset_from(&self, start: &Self) -> usize { + self.0.offset_from(&start.0) + } +} + +/// Customized wrapper that implements [`TokenStream`][super::TokenStream], overriding the [`Debug`] implementation for better diagnostics.
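+/// +/// The customized [`Debug`] prints the remaining tokens roughly as their original SQL text (with +/// newlines escaped), which is much easier to read in winnow's `trace` output than the default slice dump.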
+#[derive(Default, Copy, Clone)] +pub struct TokenStreamWrapper<'a> { + pub tokens: &'a [TokenWithLocation], +} + +impl<'a> std::fmt::Debug for TokenStreamWrapper<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for tok in self.tokens { + let tok = &tok.token; + if matches!(tok, Token::Whitespace(Whitespace::Newline)) { + write!(f, "\\n")?; + } else { + write!(f, "{}", tok)?; + } + } + Ok(()) + } +} + +impl<'a> Offset<TokenStreamWrapper<'a>> for TokenStreamWrapper<'a> { + #[inline(always)] + fn offset_from(&self, start: &Self) -> usize { + self.tokens.offset_from(&start.tokens) + } +} + +impl<'a> Offset<CheckpointWrapper<'a>> for TokenStreamWrapper<'a> { + #[inline(always)] + fn offset_from(&self, start: &CheckpointWrapper<'a>) -> usize { + self.tokens.offset_from(&start.0) + } +} + +impl<'a> SliceLen for TokenStreamWrapper<'a> { + #[inline(always)] + fn slice_len(&self) -> usize { + self.tokens.len() + } +} + +impl<'a> StreamIsPartial for TokenStreamWrapper<'a> { + type PartialState = <&'a [TokenWithLocation] as StreamIsPartial>::PartialState; + + #[must_use] + #[inline(always)] + fn complete(&mut self) -> Self::PartialState { + self.tokens.complete() + } + + #[inline(always)] + fn restore_partial(&mut self, state: Self::PartialState) { + self.tokens.restore_partial(state) + } + + #[inline(always)] + fn is_partial_supported() -> bool { + <&'a [TokenWithLocation] as StreamIsPartial>::is_partial_supported() + } +} + +impl<'a> Stream for TokenStreamWrapper<'a> { + type Checkpoint = CheckpointWrapper<'a>; + type IterOffsets = <&'a [TokenWithLocation] as Stream>::IterOffsets; + type Slice = TokenStreamWrapper<'a>; + type Token = <&'a [TokenWithLocation] as Stream>::Token; + + #[inline(always)] + fn iter_offsets(&self) -> Self::IterOffsets { + self.tokens.iter_offsets() + } + + #[inline(always)] + fn eof_offset(&self) -> usize { + self.tokens.eof_offset() + } + + #[inline(always)] + fn next_token(&mut self) -> Option<Self::Token> { + self.tokens.next_token() + } + + #[inline(always)] + fn offset_for<P>
(&self, predicate: P) -> Option<usize> + where + P: Fn(Self::Token) -> bool, + { + self.tokens.offset_for(predicate) + } + + #[inline(always)] + fn offset_at(&self, tokens: usize) -> Result<usize, Needed> { + self.tokens.offset_at(tokens) + } + + #[inline(always)] + fn next_slice(&mut self, offset: usize) -> Self::Slice { + TokenStreamWrapper { + tokens: self.tokens.next_slice(offset), + } + } + + #[inline(always)] + fn checkpoint(&self) -> Self::Checkpoint { + CheckpointWrapper(self.tokens.checkpoint()) + } + + #[inline(always)] + fn reset(&mut self, checkpoint: &Self::Checkpoint) { + self.tokens.reset(&checkpoint.0) + } + + #[inline(always)] + fn raw(&self) -> &dyn std::fmt::Debug { + // We customized the `Debug` implementation in the wrapper, so don't return `self.tokens` here. + self + } +} + +impl<'a> UpdateSlice for TokenStreamWrapper<'a> { + #[inline(always)] + fn update_slice(self, inner: Self::Slice) -> Self { + TokenStreamWrapper { + tokens: self.tokens.update_slice(inner.tokens), + } + } +} diff --git a/src/sqlparser/src/parser_v2/mod.rs b/src/sqlparser/src/parser_v2/mod.rs new file mode 100644 index 0000000000000..03ff0af218f35 --- /dev/null +++ b/src/sqlparser/src/parser_v2/mod.rs @@ -0,0 +1,190 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use winnow::combinator::{preceded, separated, trace}; +use winnow::error::{ContextError, StrContext}; +use winnow::stream::{Location, Stream, StreamIsPartial}; +use winnow::token::{any, take_while}; +use winnow::{PResult, Parser, Stateful}; + +use crate::ast::{Ident, ObjectName}; +use crate::keywords::{self, Keyword}; +use crate::tokenizer::{Token, TokenWithLocation}; + +mod data_type; +mod impl_; +mod number; + +pub(crate) use data_type::*; +pub(crate) use impl_::TokenStreamWrapper; +pub(crate) use number::*; + +/// Bundle trait requirements from winnow, so that we don't need to write them everywhere. +/// +/// All combinators should accept a generic `S` that implements `TokenStream`. +pub trait TokenStream: + Stream<Token = TokenWithLocation> + StreamIsPartial + Location + Default +{ +} + +impl<S> TokenStream for S where + S: Stream<Token = TokenWithLocation> + StreamIsPartial + Location + Default +{ +} + +/// Consume any token, including whitespaces. In almost all cases, you should use [`token`] instead. +fn any_token<S>(input: &mut S) -> PResult<TokenWithLocation> +where + S: TokenStream, +{ + any(input) +} + +/// Consume any non-whitespace token. +/// +/// If you need to consume a specific token, use [`Token::?`][Token] directly, which already implements [`Parser`]. +fn token<S>(input: &mut S) -> PResult<TokenWithLocation> +where + S: TokenStream, +{ + preceded( + take_while(0.., |token: TokenWithLocation| { + matches!(token.token, Token::Whitespace(_)) + }), + any_token, + ) + .parse_next(input) +} + +/// Consume a keyword. +/// +/// If you need to consume a specific keyword, use [`Keyword::?`][Keyword] directly, which already implements [`Parser`].
+fn keyword<S>(input: &mut S) -> PResult<Keyword> +where + S: TokenStream, +{ + token + .verify_map(|t| match &t.token { + Token::Word(w) if w.keyword != Keyword::NoKeyword => Some(w.keyword), + _ => None, + }) + .context(StrContext::Label("keyword")) + .parse_next(input) +} + +impl<I> Parser<I, TokenWithLocation, ContextError> for Token +where + I: TokenStream, +{ + fn parse_next(&mut self, input: &mut I) -> PResult<TokenWithLocation> { + trace( + format!("token {}", self), + token.verify(move |t: &TokenWithLocation| t.token == *self), + ) + .parse_next(input) + } +} + +impl<I> Parser<I, Keyword, ContextError> for Keyword +where + I: TokenStream, +{ + fn parse_next(&mut self, input: &mut I) -> PResult<Keyword> { + token + .verify_map(move |t| match &t.token { + Token::Word(w) if *self == w.keyword => Some(w.keyword), + _ => None, + }) + .parse_next(input) + } +} + +/// Consume an identifier that is not a reserved keyword. +fn identifier_non_reserved<S>(input: &mut S) -> PResult<Ident> +where + S: TokenStream, +{ + // FIXME: Reporting error correctly. + token + .verify_map(|t| match &t.token { + Token::Word(w) if !keywords::RESERVED_FOR_COLUMN_OR_TABLE_NAME.contains(&w.keyword) => { + w.to_ident().ok() + } + _ => None, + }) + .parse_next(input) +} + +/// Consume an object name. +/// +/// FIXME: Object name is extremely complex; we only handle a subset here. +fn object_name<S>(input: &mut S) -> PResult<ObjectName> +where + S: TokenStream, +{ + separated(1.., identifier_non_reserved, Token::Period) + .map(ObjectName) + .parse_next(input) +} + +/// Accept a subparser that carries a given state. +/// +/// The state will be constructed using [`Default::default()`]. +fn with_state<S, State, O, ParseNext>(mut parse_next: ParseNext) -> impl Parser<S, O, ContextError> +where + S: TokenStream, + State: Default, + ParseNext: Parser<Stateful<S, State>, O, ContextError>, +{ + move |input: &mut S| -> PResult<O> { + let state = State::default(); + let input2 = std::mem::take(input); + let mut stateful = Stateful { + input: input2, + state, + }; + let output = parse_next.parse_next(&mut stateful); + *input = stateful.input; + output + } +} + +#[cfg(test)] +mod tests { + use winnow::Located; + + use super::*; + use crate::tokenizer::Tokenizer; + + #[test] + fn test_basic() { + let input = "SELECT 1"; + let tokens = Tokenizer::new(input).tokenize_with_location().unwrap(); + let mut token_stream = Located::new(&*tokens); + Token::make_keyword("SELECT") + .parse_next(&mut token_stream) + .unwrap(); + } + + #[test] + fn test_stateful() { + let input = "SELECT 1"; + let tokens = Tokenizer::new(input).tokenize_with_location().unwrap(); + let mut token_stream = Located::new(&*tokens); + with_state(|input: &mut Stateful<_, usize>| -> PResult<()> { + input.state += 1; + Token::make_keyword("SELECT").void().parse_next(input) + }) + .parse_next(&mut token_stream) + .unwrap(); + } +} diff --git a/src/sqlparser/src/parser_v2/number.rs b/src/sqlparser/src/parser_v2/number.rs new file mode 100644 index 0000000000000..e466d9f3019c3 --- /dev/null +++ b/src/sqlparser/src/parser_v2/number.rs @@ -0,0 +1,70 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
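+ +//! Parsers for numeric literals, e.g. the precision argument of a data type.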
+ +use core::ops::RangeBounds; + +use winnow::combinator::{cut_err, delimited}; +use winnow::error::ContextError; +use winnow::{PResult, Parser}; + +use super::{token, TokenStream}; +use crate::tokenizer::Token; + +/// Consume a [number][Token::Number] from token. +pub fn token_number<S>(input: &mut S) -> PResult<String> +where + S: TokenStream, +{ + token + .verify_map(|t| { + if let Token::Number(number) = t.token { + Some(number) + } else { + None + } + }) + .parse_next(input) +} + +/// Consume an unsigned literal integer/long +pub fn literal_uint<S>(input: &mut S) -> PResult<u64> +where + S: TokenStream, +{ + token_number.try_map(|s| s.parse::<u64>()).parse_next(input) +} + +/// Consume a precision definition in some types, e.g. `FLOAT(32)`. +/// +/// The precision must be in the given range. +pub fn precision_in_range<S>( + range: impl RangeBounds<u64> + std::fmt::Debug, +) -> impl Parser<S, u64, ContextError> +where + S: TokenStream, +{ + #[derive(Debug, thiserror::Error)] + #[error("Precision must be in range {0}")] + struct OutOfRange(String); + + delimited( + Token::LParen, + cut_err(literal_uint.try_map(move |v| { + if range.contains(&v) { + Ok(v) + } else { + Err(OutOfRange(format!("{:?}", range))) + } + })), + cut_err(Token::RParen), + ) +} diff --git a/src/sqlparser/tests/testdata/array.yaml b/src/sqlparser/tests/testdata/array.yaml index 9af94c041fdcb..24d1ff8031658 100644 --- a/src/sqlparser/tests/testdata/array.yaml +++ b/src/sqlparser/tests/testdata/array.yaml @@ -6,13 +6,9 @@ - input: CREATE TABLE t(a int[][][]); formatted_sql: CREATE TABLE t (a INT[][][]) - input: CREATE TABLE t(a int[); - error_msg: |- - sql parser error: Expected ], found: ) at line:1, column:23 - Near "CREATE TABLE t(a int[" + error_msg: 'sql parser error: Unexpected ) at line:1, column:23' - input: CREATE TABLE t(a int[[]); - error_msg: |- - sql parser error: Expected ], found: [ at line:1, column:23 - Near "CREATE TABLE t(a int[" + error_msg: 'sql parser error: Unexpected [ at line:1, column:23' - input: CREATE TABLE t(a int]); error_msg: |- sql parser error: Expected ',' or ')' after column definition, found: ] at line:1, column:22 diff --git a/src/sqlparser/tests/testdata/select.yaml b/src/sqlparser/tests/testdata/select.yaml index b8bcfe18e87e4..c2edc6391367a 100644 --- a/src/sqlparser/tests/testdata/select.yaml +++ b/src/sqlparser/tests/testdata/select.yaml @@ -120,9 +120,9 @@ - input: SELECT 0x error_msg: 'sql parser error: incomplete integer literal at Line: 1, Column 8' - input: SELECT 1::float(0) - error_msg: 'sql parser error: precision for type float must be at least 1 bit' + error_msg: 'sql parser error: Unexpected 0 at line:1, column:17: Precision must be in range 1..54' - input: SELECT 1::float(54) - error_msg: 'sql parser error: precision for type float must be less than 54 bits' + error_msg: 'sql parser error: Unexpected 54 at line:1, column:18: Precision must be in range 1..54' - input: SELECT 1::int(2) error_msg: |- sql parser error: Expected end of statement, found: ( at line:1, column:14 diff --git a/src/sqlparser/tests/testdata/struct.yaml b/src/sqlparser/tests/testdata/struct.yaml index 4898714fc5858..92b7e533b2428 100644 --- a/src/sqlparser/tests/testdata/struct.yaml +++ b/src/sqlparser/tests/testdata/struct.yaml @@ -3,3 +3,9 @@ formatted_sql: SELECT CAST(ROW(1 * 2, 1.0) AS foo) - input: SELECT ROW(1 * 2, 1.0)::foo; formatted_sql: SELECT CAST(ROW(1 * 2, 1.0) AS foo) +- input: SELECT NULL::STRUCT<a INT> + formatted_sql: SELECT CAST(NULL AS STRUCT<a INT>) +- input: create table st (v1 int, v2 struct<v1 int, v2 struct<v1 int, v2 int>>, v3 struct<v1 int, v2 struct<v1 int, v2 int>>) + formatted_sql: CREATE TABLE st
(v1 INT, v2 STRUCT<v1 INT, v2 STRUCT<v1 INT, v2 INT>>, v3 STRUCT<v1 INT, v2 STRUCT<v1 INT, v2 INT>>) +- input: SELECT NULL::STRUCT<a INT>> + error_msg: 'sql parser error: Unexpected EOF: Unconsumed `>>`'
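Review note: the v2 parsers are ordinary winnow parsers over a token stream, so they can also be exercised without going through `Parser::parse_v2`. A minimal sketch modeled on the `test_basic`/`test_stateful` tests above (hypothetical test code: it has to live inside the crate since the `parser_v2` items are `pub(crate)`, and it assumes `DataType`'s `Display` round-trips the definition, as in `struct.yaml`):

#[cfg(test)]
mod data_type_smoke_tests {
    use winnow::Located;

    use crate::parser_v2;
    use crate::tokenizer::Tokenizer;

    #[test]
    fn parse_nested_struct_directly() {
        // The trailing `>>` is tokenized as a single ShiftRight token; the
        // stateful data_type parser splits it across the two closing brackets.
        let tokens = Tokenizer::new("STRUCT<a INT, b STRUCT<c INT>>")
            .tokenize_with_location()
            .unwrap();
        let mut stream = Located::new(&*tokens);
        let data_type = parser_v2::data_type(&mut stream).unwrap();
        assert_eq!(data_type.to_string(), "STRUCT<a INT, b STRUCT<c INT>>");
    }
}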