Skip to content

Commit

Permalink
moved MongoClient to fw_config.py; pylint & flakes
Browse files Browse the repository at this point in the history
  • Loading branch information
ikondov committed Feb 20, 2024
1 parent a4af47c commit c56dad0
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 36 deletions.
19 changes: 4 additions & 15 deletions fireworks/core/launchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,17 @@
from bson import ObjectId
from monty.os.path import zpath
from monty.serialization import loadfn
import pymongo
from pymongo import ASCENDING, DESCENDING
from pymongo.errors import DocumentTooLarge
import mongomock
import mongomock.gridfs
from tqdm import tqdm

from fireworks.core.firework import Firework, FWAction, Launch, Tracker, Workflow
from fireworks.fw_config import MongoClient
from fireworks.fw_config import (
GRIDFS_FALLBACK_COLLECTION,
LAUNCHPAD_LOC,
MAINTAIN_INTERVAL,
MONGO_SOCKET_TIMEOUT_MS,
MONGOMOCK_SERVERSTORE_FILE,
RESERVATION_EXPIRATION_SECS,
RUN_EXPIRATION_SECS,
SORT_FWS,
Expand Down Expand Up @@ -199,22 +196,14 @@ def __init__(
self.user_indices = user_indices if user_indices else []
self.wf_user_indices = wf_user_indices if wf_user_indices else []

if MONGOMOCK_SERVERSTORE_FILE:
os.environ['MONGOMOCK_SERVERSTORE_FILE'] = MONGOMOCK_SERVERSTORE_FILE
mongoclient_cls = getattr(mongomock, 'MongoClient')
if GRIDFS_FALLBACK_COLLECTION:
mongomock.gridfs.enable_gridfs_integration()
else:
mongoclient_cls = getattr(pymongo, 'MongoClient')

# get connection
if uri_mode:
self.connection = mongoclient_cls(host, **self.mongoclient_kwargs)
self.connection = MongoClient(host, **self.mongoclient_kwargs)
if self.name is None:
raise ValueError("Must specify a database name when using a MongoDB URI string.")
self.db = self.connection[self.name]
else:
self.connection = mongoclient_cls(
self.connection = MongoClient(
self.host,
self.port,
socketTimeoutMS=MONGO_SOCKET_TIMEOUT_MS,
Expand Down Expand Up @@ -423,7 +412,7 @@ def bulk_add_wfs(self, wfs):
"""
# Make all fireworks workflows
wfs = [Workflow.from_firework(wf) if isinstance(wf, Firework) else wf for wf in wfs]
wfs = [Workflow.from_Firework(wf) if isinstance(wf, Firework) else wf for wf in wfs]

# Initialize new firework counter, starting from the next fw id
total_num_fws = sum(len(wf) for wf in wfs)
Expand Down
3 changes: 1 addition & 2 deletions fireworks/core/tests/test_launchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import pytest
from monty.os import cd
from mongomock import MongoClient
from pymongo import __version__ as PYMONGO_VERSION
from pymongo.errors import OperationFailure

Expand Down Expand Up @@ -43,7 +42,7 @@ class AuthenticationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
try:
client = MongoClient()
client = fireworks.fw_config.MongoClient()
client.not_the_admin_db.command("createUser", "myuser", pwd="mypassword", roles=["dbOwner"])
except Exception:
raise unittest.SkipTest("MongoDB is not running in localhost:27017! Skipping tests.")
Expand Down
15 changes: 15 additions & 0 deletions fireworks/fw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from monty.design_patterns import singleton
from monty.serialization import dumpfn, loadfn
import pymongo
import mongomock
import mongomock.gridfs

__author__ = "Anubhav Jain"
__copyright__ = "Copyright 2012, The Materials Project"
Expand Down Expand Up @@ -105,6 +108,9 @@
# path to a database file to use with mongomock, do not use mongomock if None
MONGOMOCK_SERVERSTORE_FILE = None

# default mongoclient class
MongoClient = pymongo.MongoClient


def override_user_settings() -> None:
module_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -156,6 +162,15 @@ def override_user_settings() -> None:
if len(m_paths) > 0:
globals()[k] = m_paths[0]

if 'MONGOMOCK_SERVERSTORE_FILE' in os.environ:
globals()['MONGOMOCK_SERVERSTORE_FILE'] = os.environ['MONGOMOCK_SERVERSTORE_FILE']
if globals()['MONGOMOCK_SERVERSTORE_FILE']:
if not os.environ.get('MONGOMOCK_SERVERSTORE_FILE'):
os.environ['MONGOMOCK_SERVERSTORE_FILE'] = globals()['MONGOMOCK_SERVERSTORE_FILE']
globals()['MongoClient'] = getattr(mongomock, 'MongoClient')
if globals()['GRIDFS_FALLBACK_COLLECTION']:
mongomock.gridfs.enable_gridfs_integration()


override_user_settings()

Expand Down
16 changes: 6 additions & 10 deletions fireworks/user_objects/firetasks/filepad_tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os

import json
from glob import glob
from pymongo import DESCENDING
from ruamel.yaml import YAML
from fireworks.core.firework import FiretaskBase
from fireworks.utilities.filepad import FilePad
from fireworks.utilities.dict_mods import arrow_to_dot

__author__ = "Kiran Mathew, Johannes Hoermann"
__email__ = "[email protected], [email protected]"
Expand All @@ -28,7 +32,6 @@ class AddFilesTask(FiretaskBase):
optional_params = ["identifiers", "directory", "filepad_file", "compress", "metadata"]

def run_task(self, fw_spec):
from glob import glob

directory = os.path.abspath(self.get("directory", "."))

Expand Down Expand Up @@ -143,19 +146,12 @@ class GetFilesByQueryTask(FiretaskBase):
]

def run_task(self, fw_spec):
import json

import pymongo
from ruamel.yaml import YAML

from fireworks.utilities.dict_mods import arrow_to_dot

fpad = get_fpad(self.get("filepad_file", None))
dest_dir = self.get("dest_dir", os.path.abspath("."))
new_file_names = self.get("new_file_names", [])
query = self.get("query", {})
sort_key = self.get("sort_key", None)
sort_direction = self.get("sort_direction", pymongo.DESCENDING)
sort_direction = self.get("sort_direction", DESCENDING)
limit = self.get("limit", None)
fizzle_empty_result = self.get("fizzle_empty_result", True)
fizzle_degenerate_file_name = self.get("fizzle_degenerate_file_name", True)
Expand Down
13 changes: 4 additions & 9 deletions fireworks/utilities/filepad.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,19 @@
import zlib

import gridfs
import pymongo
from pymongo import DESCENDING
from monty.json import MSONable
from monty.serialization import loadfn
from bson.objectid import ObjectId

import mongomock.gridfs
from mongomock import MongoClient

from fireworks.fw_config import MongoClient
from fireworks.fw_config import LAUNCHPAD_LOC, MONGO_SOCKET_TIMEOUT_MS
from fireworks.utilities.fw_utilities import get_fw_logger

__author__ = "Kiran Mathew"
__email__ = "[email protected]"
__credits__ = "Anubhav Jain"

mongomock.gridfs.enable_gridfs_integration()


class FilePad(MSONable):
def __init__(
Expand Down Expand Up @@ -180,7 +177,7 @@ def get_file_by_id(self, gfs_id):
doc = self.filepad.find_one({"gfs_id": gfs_id})
return self._get_file_contents(doc)

def get_file_by_query(self, query, sort_key=None, sort_direction=pymongo.DESCENDING):
def get_file_by_query(self, query, sort_key=None, sort_direction=DESCENDING):
"""
Args:
Expand Down Expand Up @@ -293,8 +290,6 @@ def _get_file_contents(self, doc):
Returns:
(str, dict): the file content as a string, document dictionary
"""
from bson.objectid import ObjectId

if doc:
gfs_id = doc["gfs_id"]
file_contents = self.gridfs.get(ObjectId(gfs_id)).read()
Expand Down

0 comments on commit c56dad0

Please sign in to comment.