# test_connection.py (forked from microsoft/promptflow)
# 132 lines (122 loc) · 5.83 KB
# (Line-number gutter 1-132 from the scraped web page omitted.)
import os
import uuid
import pydash
import pytest
from _constants import PROMPTFLOW_ROOT
from mock import mock
from promptflow._sdk._errors import ConnectionNameNotSetError
from promptflow._sdk._pf_client import PFClient
from promptflow._sdk.entities import AzureOpenAIConnection, CustomConnection, OpenAIConnection
from promptflow.constants import ConnectionDefaultApiVersion
# Single PFClient shared by every test in this module; it is constructed once,
# at module import time.
_client = PFClient()

# Connection YAML fixtures live under <PROMPTFLOW_ROOT>/tests/test_configs/connections.
TEST_ROOT = PROMPTFLOW_ROOT / "tests"
CONNECTION_ROOT = TEST_ROOT / "test_configs/connections"
@pytest.mark.cli_test
@pytest.mark.e2etest
class TestConnection:
    """End-to-end tests for connection CRUD through the shared ``_client`` (a ``PFClient``).

    Tests that create a connection delete it in a ``finally`` clause, so a failing
    assertion does not leave a stale ``Connection_xxxx`` entry behind where a later
    run's short random name could collide with it.
    """

    @staticmethod
    def _new_connection_name():
        """Return a random connection name of the form ``Connection_<4 hex chars>``."""
        return f"Connection_{str(uuid.uuid4())[:4]}"

    @staticmethod
    def _stable_dict(connection):
        """Return ``connection._to_dict()`` without the keys that differ between runs."""
        # Timestamps and the random name change on every run, so they are excluded
        # from exact dict comparisons.
        return pydash.omit(connection._to_dict(), ["created_date", "last_modified_date", "name"])

    @staticmethod
    def _expected_azure_openai_dict(api_base):
        """Expected ``_stable_dict`` of the AzureOpenAIConnection used in ``test_connection_operations``.

        :param api_base: The ``api_base`` value the connection should currently hold.
        """
        return {
            "module": "promptflow.connections",
            "type": "azure_open_ai",
            "api_key": "test",  # get returns the real key now
            "auth_mode": "key",
            "api_base": api_base,
            "api_type": "azure",
            "api_version": ConnectionDefaultApiVersion.AZURE_OPEN_AI,
        }

    def test_connection_operations(self):
        """Create, get, update, list and delete an AzureOpenAIConnection."""
        name = self._new_connection_name()
        conn = AzureOpenAIConnection(name=name, api_key="test", api_base="test")
        # Create
        _client.connections.create_or_update(conn)
        try:
            # Get
            result = _client.connections.get(name)
            assert self._stable_dict(result) == self._expected_azure_openai_dict("test")
            # Update
            conn.api_base = "test2"
            result = _client.connections.create_or_update(conn)
            assert self._stable_dict(result) == self._expected_azure_openai_dict("test2")
            # List
            result = _client.connections.list()
            assert len(result) > 0
        finally:
            # Delete (also runs when an assertion above fails, so nothing leaks).
            _client.connections.delete(name)
        with pytest.raises(Exception) as e:
            _client.connections.get(name)
        assert "is not found." in str(e.value)

    def test_connection_get_and_update(self):
        """Updating a non-secret field keeps the stored api_key; a scrubbed secret is rejected."""
        name = self._new_connection_name()
        conn = AzureOpenAIConnection(name=name, api_key="test_key", api_base="test")
        result = _client.connections.create_or_update(conn)
        try:
            assert result.api_key == "test_key"
            assert "test_key" not in str(result)  # Assert key scrubbed when print
            # Update api_base only; this must not raise.
            result.api_base = "test2"
            result = _client.connections.create_or_update(result)
            assert result._to_dict()["api_base"] == "test2"
            # Assert value not scrubbed
            assert result._secrets["api_key"] == "test_key"
        finally:
            _client.connections.delete(name)
        # Invalid update: after deletion, submitting the scrubbed "****" secret is rejected.
        result._secrets = {}
        result.secrets["api_key"] = "****"
        with pytest.raises(Exception) as e:
            _client.connections.create_or_update(result)
        assert "secrets ['api_key'] value invalid, please fill them" in str(e.value)

    def test_custom_connection_get_and_update(self):
        """Same as ``test_connection_get_and_update`` but for a CustomConnection."""
        name = self._new_connection_name()
        conn = CustomConnection(name=name, secrets={"api_key": "test_key"}, configs={"api_base": "test"})
        result = _client.connections.create_or_update(conn)
        try:
            assert "test_key" not in str(result)  # Assert key scrubbed when print
            assert result.secrets["api_key"] == "test_key"
            # Update api_base only; this must not raise.
            result.configs["api_base"] = "test2"
            result = _client.connections.create_or_update(result)
            assert result._to_dict()["configs"]["api_base"] == "test2"
            # Assert value not scrubbed
            assert result._secrets["api_key"] == "test_key"
        finally:
            _client.connections.delete(name)
        # Invalid update: after deletion, submitting the scrubbed "****" secret is rejected.
        result._secrets = {}
        result.secrets["api_key"] = "****"
        with pytest.raises(Exception) as e:
            _client.connections.create_or_update(result)
        assert "secrets ['api_key'] value invalid, please fill them" in str(e.value)

    @pytest.mark.parametrize(
        "file_name, expected_updated_item, expected_secret_item",
        [
            ("azure_openai_connection.yaml", ("api_base", "new_value"), ("api_key", "<to-be-replaced>")),
            ("custom_connection.yaml", ("key1", "new_value"), ("key2", "test2")),
        ],
    )
    def test_upsert_connection_from_file(self, file_name, expected_updated_item, expected_secret_item):
        """Upserting from ``update_<file>`` updates configs but keeps previously stored secrets.

        :param file_name: Fixture YAML under ``CONNECTION_ROOT`` used for the initial create.
        :param expected_updated_item: ``(config_key, value)`` expected after the update.
        :param expected_secret_item: ``(secret_key, value)`` expected to survive the update.
        """
        from promptflow._cli._pf._connection import _upsert_connection_from_file

        name = self._new_connection_name()
        result = _upsert_connection_from_file(file=CONNECTION_ROOT / file_name, params_override=[{"name": name}])
        try:
            assert result is not None
            update_file_name = f"update_{file_name}"
            result = _upsert_connection_from_file(
                file=CONNECTION_ROOT / update_file_name, params_override=[{"name": name}]
            )
            # Test secrets not updated, and configs updated
            config_key, expected_config = expected_updated_item
            secret_key, expected_secret = expected_secret_item
            actual_config = result.configs[config_key]
            assert (
                actual_config == expected_config
            ), f"Assert configs updated failed, expected: {expected_config}, actual: {actual_config}"
            actual_secret = result._secrets[secret_key]
            assert (
                actual_secret == expected_secret
            ), f"Assert secrets not updated failed, expected: {expected_secret}, actual: {actual_secret}"
        finally:
            # The original test never removed this connection; clean it up so runs don't accumulate entries.
            _client.connections.delete(name)

    def test_create_connection_no_name(self):
        """Creating a connection built from env vars without a name raises ConnectionNameNotSetError."""
        with mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}):
            connection = OpenAIConnection.from_env()
            with pytest.raises(ConnectionNameNotSetError):
                _client.connections.create_or_update(connection)