diff --git a/importer/datasetImporter.py b/importer/datasetImporter.py index d9380fd..9ddb8de 100644 --- a/importer/datasetImporter.py +++ b/importer/datasetImporter.py @@ -6,12 +6,18 @@ class DatasetImporter: - def __init__(self, filename): - content = open(filename, 'r').readlines() - self.repos, self.target = zip(*[line.strip().split(',') for line in content]) + def __init__(self, filename, complete_set=False): + if complete_set: + df = pd.read_csv(filename) + self.repos = df['repo'] + self.target = df['y'] + self.data = df.iloc[:,3:] + else: + content = open(filename, 'r').readlines() + self.repos, self.target = zip(*[line.strip().split(',') for line in content]) - self.target = np.array(self.target) - self.data = self.get_data(self.repos) + self.target = np.array(self.target) + self.data = self.get_data(self.repos) @staticmethod def get_data(repo_links): diff --git a/main.py b/main.py index 7204d93..d209fef 100644 --- a/main.py +++ b/main.py @@ -161,10 +161,12 @@ def trainAndPredict(repos): ]) #train the classifier - importer = DatasetImporter('data/testset.csv') + #importer = DatasetImporter('data/testset.csv') + importer = DatasetImporter('enriched_data.csv', complete_set=True) classifier.fit(logarithmitize(importer.data), importer.target) # predict gives repositories + repos = [repo.strip() for repo in repos if repo.strip() != ''] prediction = classifier.predict(logarithmitize(DatasetImporter.get_data(repos))) for repo, category in zip(repos, prediction): print(repo + ', ' + category) diff --git a/requirements.txt b/requirements.txt index 101501e..367ff82 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ pygithub numpy pandas +scipy scikit-learn entropy -scikit-learn