-
Notifications
You must be signed in to change notification settings - Fork 21
/
database.py
150 lines (102 loc) · 3.83 KB
/
database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import sqlite3
import numpy as np
from config import DB_PATH, DATA_TYPES
class Base_DB(object):
    """SQLite-backed store for rows that carry a numpy ``embedding`` column.

    A subclass passes its ``CREATE TABLE`` statement to ``__init__``. The
    table used for a given ``data_type`` is named ``{data_type}_data``, and
    ``data_type`` must be one of ``DATA_TYPES`` because it is interpolated
    into SQL text.

    A fresh connection is opened and closed inside every call, so an
    instance holds no open handle between calls. Embeddings are written as
    the raw bytes of a float32 array and decoded with
    ``np.frombuffer(..., dtype=np.float32)``.
    """

    def __init__(self, sql) -> None:
        """Remember the database path and execute the schema statement *sql*."""
        self.db_path = DB_PATH
        with self._connect() as conn:
            conn.execute(sql)
            conn.commit()

    def _connect(self):
        """Return a context manager that closes a new connection on exit.

        A bare ``sqlite3.Connection`` used in ``with`` only commits or rolls
        back; it does not close, so it is wrapped in ``closing``.
        """
        from contextlib import closing
        return closing(sqlite3.connect(self.db_path))

    @staticmethod
    def _table(data_type):
        """Check *data_type* against DATA_TYPES and return its table name."""
        # The name is formatted into SQL, so it must come from the allow-list.
        assert data_type in DATA_TYPES
        return f'{data_type}_data'

    def insert_data(self, data_list, data_type):
        """Insert each dict of *data_list* as one row of ``{data_type}_data``.

        Each dict maps column names to values and must contain an
        ``'embedding'`` array-like, which is stored as float32 bytes. The
        caller's dicts and arrays are left unmodified.

        NOTE(review): dict keys are formatted into the SQL as column names,
        so they must be trusted identifiers.

        Returns:
            The row id (``lastrowid``) of every inserted row, in input
            order; an empty list when *data_list* is None.
        """
        if data_list is None:
            return []
        table = self._table(data_type)
        inserted_ids = []
        with self._connect() as conn:
            cur = conn.cursor()
            for data in data_list:
                # Work on a copy so the caller's dict is not mutated. Cast to
                # float32 so the stored bytes match the float32 decoding used
                # when embeddings are read back.
                row = dict(data)
                row['embedding'] = np.asarray(row['embedding'], dtype=np.float32).tobytes()
                columns = ', '.join(row)
                placeholders = ', '.join(['?'] * len(row))
                cur.execute(
                    f'INSERT INTO {table} ({columns}) VALUES ({placeholders})',
                    tuple(row.values()),
                )
                inserted_ids.append(cur.lastrowid)
            # Commit once for the whole batch. This is faster than committing
            # per row, and a failure part-way through inserts nothing.
            conn.commit()
        return inserted_ids

    def delete_data(self, path, data_type, is_directory=False):
        """Delete the rows for *path* and return what remains in the table.

        With ``is_directory=True``, every row whose ``file_path`` starts with
        *path* is removed; a trailing ``/`` is appended to *path* if missing.
        Otherwise only rows whose ``file_path`` equals *path* are removed.

        Returns:
            ``(remaining_embeddings, remaining_ids)`` covering every row left
            in the table: float32 arrays and their ids, index-aligned.
        """
        table = self._table(data_type)
        if is_directory:
            if not path.endswith('/'):
                path += '/'
            # Exact, case-sensitive prefix test. LIKE would treat '_' and '%'
            # in the path as wildcards and is ASCII case-insensitive by
            # default, so it could delete rows belonging to other directories.
            where = 'substr(file_path, 1, ?) = ?'
            params = (len(path), path)
        else:
            where = 'file_path = ?'
            params = (path,)
        with self._connect() as conn:
            cur = conn.cursor()
            cur.execute(f'DELETE FROM {table} WHERE {where}', params)
            conn.commit()
            cur.execute(f'SELECT id, embedding FROM {table}')
            remaining = cur.fetchall()
        remaining_ids = [row_id for row_id, _ in remaining]
        remaining_embeddings = [np.frombuffer(blob, dtype=np.float32) for _, blob in remaining]
        return remaining_embeddings, remaining_ids

    def get_existing_file_paths(self, data_type):
        """Return the set of every ``file_path`` value in ``{data_type}_data``."""
        table = self._table(data_type)
        with self._connect() as conn:
            rows = conn.execute(f'SELECT file_path FROM {table}').fetchall()
        return {file_path for (file_path,) in rows}

    def retrieve_data(self, data_type, indices=None, query=None):
        """Fetch rows, either by a raw SQL *query* or by the ids in *indices*.

        Args:
            data_type: one of DATA_TYPES; selects ``{data_type}_data`` when
                *indices* is used.
            indices: ids to fetch. Any array-like of ints is accepted (list,
                or a numpy array of any shape, which is flattened). Used only
                when *query* is falsy.
            query: a complete SQL statement executed verbatim without bound
                parameters. NOTE(review): never pass untrusted text here.

        Returns:
            ``(column_names, rows)``, where rows are the tuples from
            ``fetchall()``. Without ORDER BY, SQL does not guarantee the rows
            come back in the order of *indices*; reorder by ``id`` if needed.

        Raises:
            ValueError: if neither *query* nor *indices* is provided.
        """
        table = self._table(data_type)
        if not query and indices is None:
            raise ValueError('retrieve_data needs either query or indices')
        with self._connect() as conn:
            cur = conn.cursor()
            if query:
                cur.execute(query)
            else:
                # Flatten so that e.g. a (1, k) id matrix works, and convert
                # numpy integers to plain Python ints that sqlite3 can bind.
                ids = np.asarray(indices).ravel().tolist()
                placeholders = ','.join(['?'] * len(ids))
                cur.execute(f'SELECT * FROM {table} WHERE id IN ({placeholders})', ids)
            rows = cur.fetchall()
            column_names = [desc[0] for desc in cur.description]
        return column_names, rows

    def close(self):
        """No-op: connections are already closed at the end of every call."""
        pass
class Text_DB(Base_DB):
    """Store whose rows live in the ``text_data`` table."""

    def __init__(self) -> None:
        """Create the ``text_data`` table unless it already exists."""
        # Adjacent string literals keep the DDL indented with the code while
        # adding no extra whitespace to the SQL text itself.
        schema = (
            "\n"
            "CREATE TABLE IF NOT EXISTS text_data (\n"
            "id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
            "title TEXT,\n"
            "author TEXT,\n"
            "page INTEGER,\n"
            "file_path TEXT,\n"
            "subject TEXT,\n"
            "content TEXT,\n"
            "embedding BLOB\n"
            ")\n"
        )
        super().__init__(schema)
class Image_DB(Base_DB):
    """Store whose rows live in the ``image_data`` table."""

    def __init__(self) -> None:
        """Create the ``image_data`` table unless it already exists."""
        # Adjacent string literals keep the DDL indented with the code while
        # adding no extra whitespace to the SQL text itself.
        schema = (
            "\n"
            "CREATE TABLE IF NOT EXISTS image_data (\n"
            "id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
            "file_path TEXT,\n"
            "content TEXT,\n"
            "embedding BLOB\n"
            ")\n"
        )
        super().__init__(schema)