-
Notifications
You must be signed in to change notification settings - Fork 0
/
csv_script.py
95 lines (72 loc) · 2.99 KB
/
csv_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import sys
import os
import pandas as pd
import numpy as np
import mysql.connector
from dotenv import load_dotenv
class DB:
    """Load a CSV/XLSX file into the MySQL table whose columns match it.

    Connection settings are read from the ``MYSQL_HOST``, ``MYSQL_PORT``,
    ``MYSQL_USER``, ``MYSQL_PASSWORD`` and ``MYSQL_DB`` environment variables
    (a ``.env`` file is loaded first via ``load_dotenv``).

    Attributes:
        conn: Open MySQL connection, or ``None`` if connecting failed.
        database_name: Value of ``MYSQL_DB``.
        db_schema: Maps each base-table name to the list of
            ``(column_name, data_type)`` rows read from INFORMATION_SCHEMA.
    """

    def __init__(self):
        """Connect to MySQL and cache the schema of the configured database.

        Errors raised while connecting or reading the schema are printed, not
        propagated; in that case ``conn`` may be ``None`` and ``db_schema``
        may be empty or partial.
        """
        # load config for database
        load_dotenv()
        # Define every attribute before connecting so that a failed connection
        # leaves a consistent object rather than one with missing attributes.
        self.conn = None
        self.database_name = os.environ.get('MYSQL_DB')
        self.db_schema = {}
        try:
            # initialize mysql client
            self.conn = mysql.connector.connect(
                host=os.environ.get('MYSQL_HOST'),
                # Fall back to MySQL's default port when MYSQL_PORT is unset
                # or empty instead of failing on int(None).
                port=int(os.environ.get('MYSQL_PORT') or 3306),
                user=os.environ.get('MYSQL_USER'),
                password=os.environ.get('MYSQL_PASSWORD'),
                database=self.database_name,
            )
            self.get_tables_columns()
        except Exception as e:
            print(e)

    def get_tables_columns(self):
        """Fill ``self.db_schema`` with the columns of every base table.

        Each entry maps a table name to the list of
        ``(COLUMN_NAME, DATA_TYPE)`` rows returned by INFORMATION_SCHEMA.
        """
        # Values are bound as parameters rather than concatenated into the
        # SQL text, so schema/table names cannot alter the statement.
        tables_query = (
            'SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES '
            'WHERE TABLE_TYPE = %s AND TABLE_SCHEMA = %s'
        )
        columns_query = (
            'SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS '
            'WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s'
        )
        cursor = self.conn.cursor()
        try:
            # get table names
            cursor.execute(tables_query, ('BASE TABLE', self.database_name))
            table_names = [row[0] for row in cursor.fetchall()]
            # get column names
            for table in table_names:
                cursor.execute(columns_query, (self.database_name, table))
                self.db_schema[table] = list(cursor.fetchall())
        finally:
            # One cursor is reused for all queries and always released.
            cursor.close()

    def read_file(self, path):
        """Read a ``.csv`` or ``.xlsx`` file and find the table it belongs to.

        Args:
            path: Path of the file to read; the extension is matched
                case-insensitively.

        Returns:
            A ``(df, table_name)`` tuple. ``table_name`` is the first table in
            ``self.db_schema`` that contains every column of ``df``, or ``""``
            when no table matches.

        Raises:
            ValueError: If the file extension is neither ``.csv`` nor ``.xlsx``.
        """
        lowered = os.fspath(path).lower()
        if lowered.endswith('.csv'):
            df = pd.read_csv(path)
        elif lowered.endswith('.xlsx'):
            df = pd.read_excel(path)
        else:
            raise ValueError('File format not supported')
        table_name = ""
        # match column names
        for table, columns in self.db_schema.items():
            known_columns = {column[0] for column in columns}
            if all(col_name in known_columns for col_name in df.columns):
                table_name = table
                break
        return df, table_name

    @staticmethod
    def _to_native(value):
        """Return ``value`` as a plain Python object suitable for a SQL parameter."""
        if isinstance(value, np.datetime64):
            # .item() on a nanosecond datetime64 yields an int, so go through
            # pandas to get a real datetime; NaT becomes NULL.
            return None if pd.isna(value) else pd.Timestamp(value).to_pydatetime()
        if isinstance(value, np.bool_):
            return bool(value)
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            value = float(value)
        # NaN is how pandas represents a missing cell; store it as NULL.
        if isinstance(value, float) and value != value:
            return None
        return value

    @staticmethod
    def _quote_identifier(name):
        """Wrap ``name`` in backticks, doubling any embedded backtick."""
        return '`' + str(name).replace('`', '``') + '`'

    def convert_datatypes(self, ls):
        """Convert NumPy scalars in ``ls`` to native Python values, in place.

        NumPy integers, floats, booleans and datetimes are converted to their
        Python equivalents; NaN/NaT become ``None``. Other values are left
        unchanged.

        Args:
            ls: List of values from one row.

        Returns:
            The same list object, after conversion.
        """
        for i, value in enumerate(ls):
            ls[i] = self._to_native(value)
        return ls

    def insert_csv(self, path):
        """Insert every row of the file at ``path`` into its matching table.

        Rows are written with ``INSERT IGNORE`` and committed in one batch.

        Args:
            path: Path of a ``.csv`` or ``.xlsx`` file.

        Raises:
            RuntimeError: If no database connection was established.
            ValueError: If the file format is unsupported or no table contains
                all of the file's columns.
        """
        if self.conn is None:
            raise RuntimeError('Database connection is not available')
        df, table_name = self.read_file(path)
        if not table_name:
            # Without this check the statement would target "db." with no
            # table name and fail with an opaque SQL syntax error.
            raise ValueError('No table matches the columns of ' + str(path))
        records = [
            tuple(self.convert_datatypes(list(row)))
            for row in df.to_records(index=False)
        ]
        columns_str = ','.join(self._quote_identifier(col) for col in df.columns)
        placeholders_str = ','.join('%s' for _ in df.columns)
        query = 'INSERT IGNORE INTO {}.{} ({}) VALUES ({})'.format(
            self._quote_identifier(self.database_name),
            self._quote_identifier(table_name),
            columns_str,
            placeholders_str,
        )
        cursor = self.conn.cursor()
        try:
            cursor.executemany(query, records)
            self.conn.commit()
        finally:
            cursor.close()
        print('INSERTED SUCCESSFULLY!')
if __name__ == "__main__":
    # The script needs the path of the CSV/XLSX file as its first argument;
    # exit with a usage message instead of an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit('Usage: python ' + sys.argv[0] + ' <file.csv|file.xlsx>')
    db = DB()
    try:
        db.insert_csv(sys.argv[1])
    finally:
        # Release the connection even when the insert fails; `conn` may be
        # absent or None if the connection could not be established.
        conn = getattr(db, 'conn', None)
        if conn is not None:
            conn.close()