Skip to content

Commit

Permalink
SONARPY-2415 Add stubs for pymsql, mysql and pgdb
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-dequenne-sonarsource committed Dec 6, 2024
1 parent 63e9286 commit e738b08
Show file tree
Hide file tree
Showing 16 changed files with 155 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ public class DbNoPasswordCheck extends PythonSubscriptionCheck {
private static final List<String> CONNECT_METHODS = Arrays.asList(
"mysql.connector.connect",
"mysql.connector.connection.MySQLConnection",
"pymysql.connect",
"pymysql.connections.connect",
"psycopg2.connect",
"pgdb.connect",
"pgdb.connect.connect",
"pg.DB",
"pg.connect"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ private Map<String, Integer> sensitiveArgumentByFQN() {
sensitiveArgumentByFQN.put("mysql.connector.connect", 2);
sensitiveArgumentByFQN.put("mysql.connector.connection.MySQLConnection", 2);
sensitiveArgumentByFQN.put("pymysql.connect", 2);
sensitiveArgumentByFQN.put("pymysql.connections.connect", 2);
sensitiveArgumentByFQN.put("pymysql.connections.Connection", 2);
sensitiveArgumentByFQN.put("psycopg2.connect", 2);
sensitiveArgumentByFQN.put("pgdb.connect", 2);
sensitiveArgumentByFQN.put("pgdb.connect.connect", 2);
sensitiveArgumentByFQN.put("pg.DB", 5);
sensitiveArgumentByFQN.put("pg.connect", 5);
sensitiveArgumentByFQN = Collections.unmodifiableMap(sensitiveArgumentByFQN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,19 @@ void shouldResolvePackages() {
assertThat(provider.descriptorsForModule("kazoo")).isEmpty();
}

@Test
void customDbStubs() {
var provider = typeshedDescriptorsProvider();
var pgdb = provider.descriptorsForModule("pgdb");
assertThat(pgdb.get("connect")).isInstanceOf(FunctionDescriptor.class);

var mysql = provider.descriptorsForModule("mysql.connector");
assertThat(mysql.get("connect")).isInstanceOf(FunctionDescriptor.class);

var pymysql = provider.descriptorsForModule("pymysql");
assertThat(pymysql.get("connect")).isInstanceOf(FunctionDescriptor.class);
}

@Test
void unknownModule() {
var provider = typeshedDescriptorsProvider();
Expand Down
128 changes: 70 additions & 58 deletions python-frontend/src/test/java/org/sonar/python/types/TypeShedTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ void package_lxml_reexported_symbol_fqn() {
@Test
void package_sqlite3_connect_type_in_ambiguous_symbol() {
Map<String, Symbol> sqlite3Symbols = symbolsForModule("sqlite3");
ClassSymbol connectionSymbol = (ClassSymbol) sqlite3Symbols.get("Connection");
ClassSymbol connectionSymbol = (ClassSymbol) sqlite3Symbols.get("Connection");
AmbiguousSymbol cursorFunction = connectionSymbol.declaredMembers().stream().filter(m -> "cursor".equals(m.name())).findFirst().map(AmbiguousSymbol.class::cast).get();
Set<Symbol> alternatives = cursorFunction.alternatives();
assertThat(alternatives)
Expand Down Expand Up @@ -378,21 +378,21 @@ void deserialize_nonexistent_or_incorrect_protobuf() {
void class_symbols_from_protobuf() throws TextFormat.ParseException {
SymbolsProtos.ModuleSymbol moduleSymbol = moduleSymbol(
"fully_qualified_name: \"mod\"\n" +
"classes {\n" +
" name: \"Base\"\n" +
" fully_qualified_name: \"mod.Base\"\n" +
" super_classes: \"builtins.object\"\n" +
"}\n" +
"classes {\n" +
" name: \"C\"\n" +
" fully_qualified_name: \"mod.C\"\n" +
" super_classes: \"builtins.str\"\n" +
"}\n" +
"classes {\n" +
" name: \"D\"\n" +
" fully_qualified_name: \"mod.D\"\n" +
" super_classes: \"NOT_EXISTENT\"\n" +
"}");
"classes {\n" +
" name: \"Base\"\n" +
" fully_qualified_name: \"mod.Base\"\n" +
" super_classes: \"builtins.object\"\n" +
"}\n" +
"classes {\n" +
" name: \"C\"\n" +
" fully_qualified_name: \"mod.C\"\n" +
" super_classes: \"builtins.str\"\n" +
"}\n" +
"classes {\n" +
" name: \"D\"\n" +
" fully_qualified_name: \"mod.D\"\n" +
" super_classes: \"NOT_EXISTENT\"\n" +
"}");
Map<String, Symbol> symbols = TypeShed.getSymbolsFromProtobufModule(moduleSymbol);
assertThat(symbols.values()).extracting(Symbol::kind, Symbol::fullyQualifiedName)
.containsExactlyInAnyOrder(tuple(Kind.CLASS, "mod.Base"), tuple(Kind.CLASS, "mod.C"), tuple(Kind.CLASS, "mod.D"));
Expand All @@ -408,44 +408,44 @@ void class_symbols_from_protobuf() throws TextFormat.ParseException {
void function_symbols_from_protobuf() throws TextFormat.ParseException {
SymbolsProtos.ModuleSymbol moduleSymbol = moduleSymbol(
"fully_qualified_name: \"mod\"\n" +
"functions {\n" +
" name: \"foo\"\n" +
" fully_qualified_name: \"mod.foo\"\n" +
" parameters {\n" +
" name: \"p\"\n" +
" kind: POSITIONAL_OR_KEYWORD\n" +
" }\n" +
"}\n" +
"overloaded_functions {\n" +
" name: \"bar\"\n" +
" fullname: \"mod.bar\"\n" +
" definitions {\n" +
" name: \"bar\"\n" +
" fully_qualified_name: \"mod.bar\"\n" +
" parameters {\n" +
" name: \"x\"\n" +
" kind: POSITIONAL_OR_KEYWORD\n" +
" }\n" +
" has_decorators: true\n" +
" resolved_decorator_names: \"typing.overload\"\n" +
" is_overload: true\n" +
" }\n" +
" definitions {\n" +
" name: \"bar\"\n" +
" fully_qualified_name: \"mod.bar\"\n" +
" parameters {\n" +
" name: \"x\"\n" +
" kind: POSITIONAL_OR_KEYWORD\n" +
" }\n" +
" parameters {\n" +
" name: \"y\"\n" +
" kind: POSITIONAL_OR_KEYWORD\n" +
" }\n" +
" has_decorators: true\n" +
" resolved_decorator_names: \"typing.overload\"\n" +
" is_overload: true\n" +
" }\n" +
"}\n");
"functions {\n" +
" name: \"foo\"\n" +
" fully_qualified_name: \"mod.foo\"\n" +
" parameters {\n" +
" name: \"p\"\n" +
" kind: POSITIONAL_OR_KEYWORD\n" +
" }\n" +
"}\n" +
"overloaded_functions {\n" +
" name: \"bar\"\n" +
" fullname: \"mod.bar\"\n" +
" definitions {\n" +
" name: \"bar\"\n" +
" fully_qualified_name: \"mod.bar\"\n" +
" parameters {\n" +
" name: \"x\"\n" +
" kind: POSITIONAL_OR_KEYWORD\n" +
" }\n" +
" has_decorators: true\n" +
" resolved_decorator_names: \"typing.overload\"\n" +
" is_overload: true\n" +
" }\n" +
" definitions {\n" +
" name: \"bar\"\n" +
" fully_qualified_name: \"mod.bar\"\n" +
" parameters {\n" +
" name: \"x\"\n" +
" kind: POSITIONAL_OR_KEYWORD\n" +
" }\n" +
" parameters {\n" +
" name: \"y\"\n" +
" kind: POSITIONAL_OR_KEYWORD\n" +
" }\n" +
" has_decorators: true\n" +
" resolved_decorator_names: \"typing.overload\"\n" +
" is_overload: true\n" +
" }\n" +
"}\n");
Map<String, Symbol> symbols = TypeShed.getSymbolsFromProtobufModule(moduleSymbol);
assertThat(symbols.values()).extracting(Symbol::kind, Symbol::fullyQualifiedName)
.containsExactlyInAnyOrder(tuple(Kind.FUNCTION, "mod.foo"), tuple(Kind.AMBIGUOUS, "mod.bar"));
Expand All @@ -457,7 +457,7 @@ void function_symbols_from_protobuf() throws TextFormat.ParseException {
@Test
void pythonVersions() {
Symbol range = TypeShed.builtinSymbols().get("range");
assertThat(((SymbolImpl) range).validForPythonVersions()).containsExactlyInAnyOrder( "38", "39", "310", "311", "312", "313");
assertThat(((SymbolImpl) range).validForPythonVersions()).containsExactlyInAnyOrder("38", "39", "310", "311", "312", "313");
assertThat(range.kind()).isEqualTo(Kind.CLASS);

// python 2
Expand Down Expand Up @@ -524,7 +524,7 @@ private static SymbolsProtos.ModuleSymbol moduleSymbol(String protobuf) throws T
@Test
void variables_from_protobuf() throws TextFormat.ParseException {
SymbolsProtos.ModuleSymbol moduleSymbol = moduleSymbol(
"fully_qualified_name: \"mod\"\n" +
"fully_qualified_name: \"mod\"\n" +
"vars {\n" +
" name: \"foo\"\n" +
" fully_qualified_name: \"mod.foo\"\n" +
Expand Down Expand Up @@ -564,15 +564,15 @@ void symbol_from_submodule_access() {
}

@Test
void typeshed_private_modules_should_not_affect_fqn() {
void typeshed_private_modules_should_not_affect_fqn() {
Map<String, Symbol> socketModule = symbolsForModule("socket");
ClassSymbol socket = (ClassSymbol) socketModule.get("socket");
assertThat(socket.declaredMembers()).extracting(Symbol::name, Symbol::fullyQualifiedName).contains(tuple("connect", "socket.socket.connect"));
assertThat(socket.superClasses()).extracting(Symbol::fullyQualifiedName).containsExactly("object");
}

@Test
void overloaded_function_alias_has_function_annotated_type() {
void overloaded_function_alias_has_function_annotated_type() {
Map<String, Symbol> gettextModule = symbolsForModule("gettext");
Symbol translation = gettextModule.get("translation");
Symbol catalog = gettextModule.get("Catalog");
Expand All @@ -587,4 +587,16 @@ void stubFilesSymbols_third_party_symbols_should_not_be_null() {
symbolsForModule("six");
assertThat(TypeShed.stubFilesSymbols()).doesNotContainNull();
}

@Test
void customDbStubs() {
var pgdb = symbolsForModule("pgdb");
assertThat(pgdb.get("connect")).isInstanceOf(FunctionSymbol.class);

var mysql = symbolsForModule("mysql.connector");
assertThat(mysql.get("connect")).isInstanceOf(FunctionSymbol.class);

var pymysql = symbolsForModule("pymysql");
assertThat(pymysql.get("connect")).isInstanceOf(FunctionSymbol.class);
}
}
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .connection import MySQLConnection
from .cursor import MySQLCursor

def connect(dsn: str | None = None,
user: str | None = None, password: str | None = None,
host: str | None = None, database: str | None = None,
**kwargs: Any) -> MySQLConnection:
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .cursor import MySQLCursor
class MySQLConnection:
def cursor(self) -> MySQLCursor:
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Sequence, Union
class MySQLCursor:
def execute(self, operation: str, parameters: Union[Sequence, None] = None
): # should return Cursor
...

def executemany(self, operation: str,
seq_of_parameters: Sequence[Union[Sequence, None]]): # should return Cursor
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .connect import connect
from .connection import Connection
from .cursor import Cursor
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .connection import Connection


def connect(dsn: str | None = None,
user: str | None = None, password: str | None = None,
host: str | None = None, database: str | None = None,
**kwargs: Any) -> Connection:
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .cursor import Cursor
class Connection:
def cursor(self) -> Cursor:
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Sequence, Union
class Cursor:
def execute(self, operation: str, parameters: Union[Sequence, None] = None
): # should return Cursor
...

def executemany(self, operation: str,
seq_of_parameters: Sequence[Union[Sequence, None]]): # should return Cursor
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .connections import Connection, connect
from .cursors import Cursor
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .cursors import Cursor

class Connection:
def cursor(self) -> Cursor:
...

def connect(dsn: str | None = None,
user: str | None = None, password: str | None = None,
host: str | None = None, database: str | None = None,
**kwargs: Any) -> Connection:
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Sequence, Union
class Cursor:
def execute(self, operation: str, parameters: Union[Sequence, None] = None
): # should return Cursor
...

def executemany(self, operation: str,
seq_of_parameters: Sequence[Union[Sequence, None]]): # should return Cursor
...
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_custom_stubs_serializer(typeshed_custom_stubs):
custom_stubs_serializer.serialize()
assert custom_stubs_serializer.get_build_result.call_count == 1
# Not every files from "typeshed_custom_stubs" build are serialized, as some are builtins
assert symbols.save_module.call_count == 146
assert symbols.save_module.call_count == 157


def test_importer_serializer():
Expand Down

0 comments on commit e738b08

Please sign in to comment.