diff --git a/simple/stats/config.py b/simple/stats/config.py index 8e8d034e..d8aa2084 100644 --- a/simple/stats/config.py +++ b/simple/stats/config.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re - from stats import constants from stats.data import AggregationConfig from stats.data import EntityType @@ -23,6 +21,8 @@ from stats.data import Provenance from stats.data import Source from stats.data import StatVar +from util.file_match import match +from util.filesystem import File _INPUT_FILES_FIELD = "inputFiles" _IMPORT_TYPE_FIELD = "importType" @@ -55,6 +55,7 @@ _GROUP_STAT_VARS_BY_PROPERTY = "groupStatVarsByProperty" _GENERATE_TOPICS = "generateTopics" _OBSERVATION_PROPERTIES = "observationProperties" +_INCLUDE_INPUT_SUBDIRS_PROPERTY = "includeInputSubdirs" class Config: @@ -67,11 +68,12 @@ def __init__(self, data: dict) -> None: self.data = data self._input_files_config: dict[str, dict] = self.data.get( _INPUT_FILES_FIELD, {}) - # If input file names are specified with wildcards - e.g. "foo*.csv", - # this dict maintains a mapping from actual file name to the wildcard key + # If input file paths are specified with wildcards - e.g. "gs://bucket/foo*.csv", + # this dict maintains a mapping from actual file path to the wildcard key # for fast lookup. - # e.g. "foo1.csv" -> "foo*.csv", "foo2.csv" -> "foo*.csv", etc. - self._input_file_name_keys: dict[str, str] = {} + # e.g. "foo1.csv" -> "foo*.csv", "foo2.csv" -> "foo*.csv", + # "path/to/foo.csv" -> "**/foo.csv, etc. + self._config_key_by_full_path: dict[str, str] = {} # dict from provenance name to Provenance self.provenances: dict[str, Provenance] = {} # dict from provenance name to Source @@ -87,28 +89,29 @@ def data_download_urls(self) -> list[str]: raise ValueError( f"{_DATA_DOWNLOAD_URL_FIELD} can only be a list, found: {cfg}") - def import_type(self, input_file_name: str) -> ImportType: - import_type_str = self._input_file(input_file_name).get(_IMPORT_TYPE_FIELD) + def import_type(self, input_file: File) -> ImportType: + import_type_str = self._per_file_config(input_file).get(_IMPORT_TYPE_FIELD) if not import_type_str: return ImportType.OBSERVATIONS if import_type_str not in iter(ImportType): raise ValueError( - f"Unsupported import type: {import_type_str} ({input_file_name})") + f"Unsupported import type: {import_type_str} ({input_file.full_path()})" + ) return ImportType(import_type_str) - def format(self, input_file_name: str) -> ImportType | None: - format_str = self._input_file(input_file_name).get(_FORMAT_FIELD) + def format(self, input_file: File) -> ImportType | None: + format_str = self._per_file_config(input_file).get(_FORMAT_FIELD) if not format_str: return None if format_str not in iter(InputFileFormat): - raise ValueError(f"Unsupported format: {format_str} ({input_file_name})") + raise ValueError(f"Unsupported format: {format_str} ({input_file})") return InputFileFormat(format_str) - def column_mappings(self, input_file_name: str) -> dict[str, str]: - return self._input_file(input_file_name).get(_COLUMN_MAPPINGS_FIELD, {}) + def column_mappings(self, input_file: File) -> dict[str, str]: + return self._per_file_config(input_file).get(_COLUMN_MAPPINGS_FIELD, {}) - def computed_variables(self, input_file_name: str) -> list[str]: - return self._input_file(input_file_name).get(_COMPUTED_VARIABLES_FIELD, []) + def computed_variables(self, input_file: File) -> list[str]: + return self._per_file_config(input_file).get(_COMPUTED_VARIABLES_FIELD, 
[]) def variable(self, variable_name: str) -> StatVar: var_cfg = self.data.get(_VARIABLES_FIELD, {}).get(variable_name, {}) @@ -130,8 +133,8 @@ def aggregation(self, variable_name: str) -> AggregationConfig: .get(_AGGREGATION_FIELD, {}) return AggregationConfig(**aggregation_cfg) - def event_type(self, input_file_name: str) -> str: - return self._input_file(input_file_name).get(_EVENT_TYPE_FIELD, "") + def event_type(self, input_file: File) -> str: + return self._per_file_config(input_file).get(_EVENT_TYPE_FIELD, "") def event(self, event_type_name: str) -> EventType: event_type_cfg = self.data.get(_EVENTS_FIELD, {}).get(event_type_name, {}) @@ -146,27 +149,27 @@ def entity(self, entity_type_name: str) -> EntityType: entity_type_cfg.get(_NAME_FIELD, entity_type_name), description=entity_type_cfg.get(_DESCRIPTION_FIELD, "")) - def id_column(self, input_file_name: str) -> str: - return self._input_file(input_file_name).get(_ID_COLUMN_FIELD, "") + def id_column(self, input_file: File) -> str: + return self._per_file_config(input_file).get(_ID_COLUMN_FIELD, "") - def entity_type(self, input_file_name: str) -> str: - return self._input_file(input_file_name).get(_ENTITY_TYPE_FIELD, "") + def entity_type(self, input_file: File) -> str: + return self._per_file_config(input_file).get(_ENTITY_TYPE_FIELD, "") - def ignore_columns(self, input_file_name: str) -> list[str]: - return self._input_file(input_file_name).get(_IGNORE_COLUMNS_FIELD, []) + def ignore_columns(self, input_file: File) -> list[str]: + return self._per_file_config(input_file).get(_IGNORE_COLUMNS_FIELD, []) - def provenance_name(self, input_file_name: str) -> str: - return self._input_file(input_file_name).get(_PROVENANCE_FIELD, - input_file_name) + def provenance_name(self, input_file: File) -> str: + return self._per_file_config(input_file).get(_PROVENANCE_FIELD, + input_file.path) - def row_entity_type(self, input_file_name: str) -> str: - return self._input_file(input_file_name).get(_ROW_ENTITY_TYPE_FIELD, "") + def row_entity_type(self, input_file: File) -> str: + return self._per_file_config(input_file).get(_ROW_ENTITY_TYPE_FIELD, "") - def entity_columns(self, input_file_name: str) -> list[str]: - return self._input_file(input_file_name).get(_ENTITY_COLUMNS, []) + def entity_columns(self, input_file: File) -> list[str]: + return self._per_file_config(input_file).get(_ENTITY_COLUMNS, []) - def observation_properties(self, input_file_name: str) -> dict[str, str]: - return self._input_file(input_file_name).get(_OBSERVATION_PROPERTIES, {}) + def observation_properties(self, input_file: File) -> dict[str, str]: + return self._per_file_config(input_file).get(_OBSERVATION_PROPERTIES, {}) def database(self) -> dict: return self.data.get(_DATABASE_FIELD) @@ -174,45 +177,54 @@ def database(self) -> dict: def generate_hierarchy(self) -> bool: return self.data.get(_GROUP_STAT_VARS_BY_PROPERTY) or False + def include_input_subdirs(self) -> bool: + return self.data.get(_INCLUDE_INPUT_SUBDIRS_PROPERTY) or False + def special_files(self) -> dict[str, str]: special_files: dict[str, str] = {} for special_file_type in constants.SPECIAL_FILE_TYPES: - special_file = self.data.get(special_file_type, "") - if special_file: - special_files[special_file] = special_file_type + special_file_name = self.data.get(special_file_type, "") + if special_file_name: + special_files[special_file_type] = special_file_name return special_files def generate_topics(self) -> bool: return self.data.get(_GENERATE_TOPICS) or False - def _input_file(self, input_file_name: 
str) -> dict: - # Exact match. - input_file_config = self._input_files_config.get(input_file_name, {}) - if input_file_config: - return input_file_config + def _per_file_config(self, input_file: File) -> dict: + """ Looks up the config for a given file. + + The lookup process is as follows: + - If the file name exactly matches a config key, the config for that key + is returned. + - Else if the file's path relative to the input directory exactly matches + a config key, the config for that key is returned. + - Else, we attempt to match the file with each config key in order, + returning the first matching result. + Matches are checked with the match function in simple/util/file_match.py. + """ + for_exact_name = self._input_files_config.get(input_file.name(), {}) + if for_exact_name: + return for_exact_name + + for_exact_full_path = self._input_files_config.get(input_file.full_path(), + {}) + if for_exact_full_path: + return for_exact_full_path + + if input_file.full_path() not in self._config_key_by_full_path.keys(): + self._config_key_by_full_path[input_file.full_path( + )] = self._find_first_matching_config_key(input_file) - # Wildcard match - if input_file_name not in self._input_file_name_keys.keys(): - self._input_file_name_keys[input_file_name] = self._input_file_name_match( - input_file_name) return self._input_files_config.get( - self._input_file_name_keys[input_file_name], {}) + self._config_key_by_full_path[input_file.full_path()], {}) - def _input_file_name_match(self, input_file_name: str) -> str | None: + def _find_first_matching_config_key(self, input_file: File) -> str | None: for input_file_pattern in self._input_files_config.keys(): - if "*" not in input_file_pattern: - continue - regex = self._input_file_pattern_to_regex(input_file_pattern) - if re.match(regex, input_file_name) is not None: + if match(input_file, input_file_pattern): return input_file_pattern return None - def _input_file_pattern_to_regex(self, input_file_pattern: str) -> str: - """ - Transforms a string of the form "a*b.c" to the regex "a.*b\.c". 
- """ - return input_file_pattern.replace(".", r"\.").replace("*", ".*") - def _parse_provenances_and_sources(self): sources_cfg = self.data.get(_SOURCES_FIELD, {}) for source_name, source_cfg in sources_cfg.items(): diff --git a/simple/stats/db.py b/simple/stats/db.py index fc40b4e0..06b4c712 100644 --- a/simple/stats/db.py +++ b/simple/stats/db.py @@ -20,7 +20,6 @@ import logging import os import sqlite3 -import tempfile from typing import Any from google.cloud.sql.connector.connector import Connector @@ -32,8 +31,9 @@ from stats.data import STAT_VAR_GROUP from stats.data import STATISTICAL_VARIABLE from stats.data import Triple -from util.filehandler import create_file_handler -from util.filehandler import is_gcs_path +from util.filesystem import create_store +from util.filesystem import Dir +from util.filesystem import File FIELD_DB_TYPE = "type" FIELD_DB_PARAMS = "params" @@ -41,7 +41,7 @@ TYPE_SQLITE = "sqlite" TYPE_MAIN_DC = "maindc" -SQLITE_DB_FILE_PATH = "dbFilePath" +SQLITE_DB_FILE = "dbFile" CLOUD_MY_SQL_INSTANCE = "instance" CLOUD_MY_SQL_USER = "user" @@ -205,7 +205,7 @@ def insert_triples(self, triples: list[Triple]): pass def insert_observations(self, observations: list[Observation], - input_file_name: str): + input_file: File): pass def insert_key_value(self, key: str, value: str): @@ -236,8 +236,7 @@ def __init__(self, db_params: dict) -> None: assert db_params assert MAIN_DC_OUTPUT_DIR in db_params - self.output_dir_fh = create_file_handler(db_params[MAIN_DC_OUTPUT_DIR], - is_dir=True) + self.output_dir = db_params[MAIN_DC_OUTPUT_DIR] # dcid to node dict self.nodes: dict[str, McfNode] = {} @@ -246,14 +245,19 @@ def insert_triples(self, triples: list[Triple]): self._add_triple(triple) def insert_observations(self, observations: list[Observation], - input_file_name: str): + input_file: File): df = pd.DataFrame(observations) # Drop the provenance and properties columns. # Provenance is specified differently for main dc. # TODO: Include obs properties in main DC output. df = df.drop(columns=["provenance", "properties"]) - self.output_dir_fh.make_file(input_file_name).write_string( - df.to_csv(index=False)) + # Right now, this overwrites any file with the same name, + # so if different input sources have files with the same relative path, + # they will clobber each others output. Treating this as an edge case + # for now since it only affects the main DC case, but we could resolve + # it in the future by allowing input sources to be mapped to output + # locations. + self.output_dir.open_file(input_file.path).write(df.to_csv(index=False)) def insert_import_info(self, status: ImportStatus): # No-op for now. @@ -264,10 +268,10 @@ def commit_and_close(self): filtered = filter(lambda node: node.node_type in MCF_NODE_TYPES_ALLOWLIST, self.nodes.values()) mcf = "\n\n".join(map(lambda node: node.to_mcf(), filtered)) - self.output_dir_fh.make_file(SCHEMA_MCF_FILE_NAME).write_string(mcf) + self.output_dir.open_file(SCHEMA_MCF_FILE_NAME).write(mcf) # TMCF - self.output_dir_fh.make_file(OBSERVATIONS_TMCF_FILE_NAME).write_string( + self.output_dir.open_file(OBSERVATIONS_TMCF_FILE_NAME).write( OBSERVATIONS_TMCF) # Not supported for main DC at this time. 
@@ -306,7 +310,7 @@ def insert_triples(self, triples: list[Triple]): [triple.db_tuple() for triple in triples]) def insert_observations(self, observations: list[Observation], - input_file_name: str): + input_file: File): logging.info("Writing %s observations to [%s]", len(observations), self.engine) self.num_observations += len(observations) @@ -382,18 +386,25 @@ class SqliteDbEngine(DbEngine): def __init__(self, db_params: dict) -> None: assert db_params - assert SQLITE_DB_FILE_PATH in db_params - - self.db_file_path = db_params[SQLITE_DB_FILE_PATH] - # If file path is a GCS path, we create the DB in a local temp file - # and upload to GCS on commit. - self.local_db_file_path: str = self.db_file_path - if is_gcs_path(self.db_file_path): - self.local_db_file_path = tempfile.NamedTemporaryFile().name - - logging.info("Connecting to SQLite: %s", self.local_db_file_path) - self.connection = sqlite3.connect(self.local_db_file_path) - logging.info("Connected to SQLite: %s", self.local_db_file_path) + assert SQLITE_DB_FILE in db_params + + self.db_temp_store = None + self.db_final_file = None + self.db_file = db_params[SQLITE_DB_FILE] + if self.db_file.syspath() is None: + logging.info("Copying DB to local storage from %s", + self.db_file.full_path()) + # Copy to disk + self.db_temp_store = create_store("temp://") + self.db_final_file = self.db_file + self.db_file = self.db_temp_store.as_dir().open_file( + "local.db", create_if_missing=True) + self.db_final_file.copy_to(self.db_file) + logging.info("Local copy of DB is %s", self.db_file.syspath()) + + logging.info("Connecting to SQLite: %s", self.db_file.full_path()) + self.connection = sqlite3.connect(self.db_file.syspath()) + logging.info("Connected to SQLite: %s", self.db_file.full_path()) self.cursor = self.connection.cursor() @@ -424,7 +435,7 @@ def _create_indexes(self) -> None: logging.info("Index created: %s", index.index_name) def __str__(self) -> str: - return f"{TYPE_SQLITE}: {self.db_file_path}" + return f"{TYPE_SQLITE}: {self.db_file.full_path()}" def init_or_update_tables(self): for statement in _INIT_TABLE_STATEMENTS: @@ -458,13 +469,11 @@ def commit_and_close(self): self._create_indexes() self.connection.commit() self.connection.close() - # Copy file if local and actual DB file paths are different. - if self.local_db_file_path != self.db_file_path: - local_db = create_file_handler(self.local_db_file_path, - is_dir=False).read_bytes() - logging.info("Writing to sqlite db: %s (%s bytes)", - self.local_db_file_path, len(local_db)) - create_file_handler(self.db_file_path, is_dir=False).write_bytes(local_db) + + if self.db_temp_store: + logging.info("Copying temp DB back to permanent storage") + self.db_file.copy_to(self.db_final_file) + self.db_temp_store.close() # Parameters needed to connect to a Cloud SQL instance. 
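The SqliteDbEngine change above swaps the GCS-specific temp-file logic for a storage-agnostic round trip: when the configured DB File has no local syspath, it is copied into a temp:// store so sqlite3 can open it by OS path, then copied back to its permanent location in commit_and_close. A rough sketch of that flow, using placeholder names and assuming the File.copy_to / File.syspath behavior implied by this diff:

import sqlite3

from util.filesystem import create_store

# Stand-in for a non-local store (e.g. a GCS bucket); mem:// keeps the sketch self-contained.
remote_store = create_store("mem://")
final_db_file = remote_store.as_dir().open_file("data.db", create_if_missing=True)

# Copy the DB to local scratch space so sqlite3 can open it.
temp_store = create_store("temp://")
local_db_file = temp_store.as_dir().open_file("local.db", create_if_missing=True)
final_db_file.copy_to(local_db_file)

connection = sqlite3.connect(local_db_file.syspath())
# ... run the import against the local copy ...
connection.commit()
connection.close()

# Push the finished DB back and clean up.
local_db_file.copy_to(final_db_file)
temp_store.close()
remote_store.close()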
@@ -633,24 +642,21 @@ def create_and_update_db(config: dict) -> Db: return SqlDb(config) -def create_sqlite_config(sqlite_db_file_path: str) -> dict: +def create_sqlite_config(sqlite_db_file: File) -> dict: return { FIELD_DB_TYPE: TYPE_SQLITE, FIELD_DB_PARAMS: { - SQLITE_DB_FILE_PATH: sqlite_db_file_path + SQLITE_DB_FILE: sqlite_db_file } } -def create_main_dc_config(output_dir: str) -> dict: +def create_main_dc_config(output_dir: Dir) -> dict: return {FIELD_DB_TYPE: TYPE_MAIN_DC, MAIN_DC_OUTPUT_DIR: output_dir} -def get_sqlite_config_from_env() -> dict | None: - sqlite_db_file_path = os.getenv(ENV_SQLITE_PATH) - if not sqlite_db_file_path: - return None - return create_sqlite_config(sqlite_db_file_path) +def get_sqlite_path_from_env() -> str | None: + return os.getenv(ENV_SQLITE_PATH) def get_cloud_sql_config_from_env() -> dict | None: diff --git a/simple/stats/entities_importer.py b/simple/stats/entities_importer.py index 03142b64..06eb833d 100644 --- a/simple/stats/entities_importer.py +++ b/simple/stats/entities_importer.py @@ -22,36 +22,35 @@ from stats.importer import Importer from stats.nodes import Nodes from stats.reporter import FileImportReporter -from util.filehandler import FileHandler +from util.filesystem import File class EntitiesImporter(Importer): """Imports a single entities input file. - Key behaviors at this time: + Key behaviors at this time: - All un-ignored columns will be encoded as property triples. - If an id column was configured, it will be used as the entity dcid. Else a new dcid will be generated for each entity. - Columns specified as entity columns will be encoded as object_id in the triples tables. Others will be encoded as object_value. + Currently this importer does not resolve any entities and all entities are assumed to be pre-resolved into dcids. """ - def __init__(self, input_fh: FileHandler, db: Db, - reporter: FileImportReporter, nodes: Nodes) -> None: - self.input_fh = input_fh + def __init__(self, input_file: File, db: Db, reporter: FileImportReporter, + nodes: Nodes) -> None: + self.input_file = input_file self.db = db self.reporter = reporter self.nodes = nodes - self.input_file_name = self.input_fh.basename() self.config = nodes.config - self.ignore_columns = self.config.ignore_columns(self.input_file_name) - self.provenance = self.nodes.provenance(self.input_file_name).id + self.ignore_columns = self.config.ignore_columns(self.input_file) + self.provenance = self.nodes.provenance(self.input_file).id - self.row_entity_type = self.config.row_entity_type(self.input_file_name) - assert self.row_entity_type, f"Row entity type must be specified: {self.input_file_name}" + self.row_entity_type = self.config.row_entity_type(self.input_file) + assert self.row_entity_type, f"Row entity type must be specified: {self.input_file.full_path()}" - self.id_column = self.config.id_column(self.input_file_name) + self.id_column = self.config.id_column(self.input_file) # Reassigned when renaming columns. 
- self.entity_columns = set(self.config.entity_columns(self.input_file_name)) + self.entity_columns = set(self.config.entity_columns(self.input_file)) self.df = pd.DataFrame() @@ -72,7 +71,7 @@ def _read_csv(self) -> None: # Read CSVs with the following behaviors: # - Strip leading whitespaces # - Treat comma as a thousands separator - self.df = pd.read_csv(self.input_fh.read_string_io(), + self.df = pd.read_csv(self.input_file.read_string_io(), skipinitialspace=True, thousands=",") logging.info("Read %s rows.", self.df.index.size) @@ -110,7 +109,7 @@ def _write_row_entity_triples(self) -> None: # Add event type node - it will be written to DB later. # This is to avoid duplicate entity types in scenarios where entities of the same type # are spread across files. - self.nodes.entity_type(self.row_entity_type, self.input_file_name) + self.nodes.entity_type(self.row_entity_type, self.input_file) # All property columns would've been renamed to their dcids by now. # So use the id column's dcid as the id column name. diff --git a/simple/stats/events_importer.py b/simple/stats/events_importer.py index 636142fa..71ba4edc 100644 --- a/simple/stats/events_importer.py +++ b/simple/stats/events_importer.py @@ -29,7 +29,7 @@ from stats.importer import Importer from stats.nodes import Nodes from stats.reporter import FileImportReporter -from util.filehandler import FileHandler +from util.filesystem import File from util import dc_client as dc @@ -42,26 +42,24 @@ class EventsImporter(Importer): """Imports a single events input file. """ - def __init__(self, input_fh: FileHandler, db: Db, - debug_resolve_fh: FileHandler, reporter: FileImportReporter, - nodes: Nodes) -> None: - self.input_fh = input_fh + def __init__(self, input_file: File, db: Db, debug_resolve_file: File, + reporter: FileImportReporter, nodes: Nodes) -> None: + self.input_file = input_file self.db = db - self.debug_resolve_fh = debug_resolve_fh + self.debug_resolve_file = debug_resolve_file self.reporter = reporter self.nodes = nodes - self.input_file_name = self.input_fh.basename() self.config = nodes.config - self.entity_type = self.config.entity_type(self.input_file_name) - self.ignore_columns = self.config.ignore_columns(self.input_file_name) - self.provenance = self.nodes.provenance(self.input_file_name).id + self.entity_type = self.config.entity_type(self.input_file) + self.ignore_columns = self.config.ignore_columns(self.input_file) + self.provenance = self.nodes.provenance(self.input_file).id # Reassign after reading CSV. self.entity_column_name = constants.COLUMN_DCID - self.event_type = self.config.event_type(self.input_file_name) - assert self.event_type, f"Event type must be specified: {self.input_file_name}" + self.event_type = self.config.event_type(self.input_file) + assert self.event_type, f"Event type must be specified: {self.input_file.full_path()}" - self.id_column = self.config.id_column(self.input_file_name) + self.id_column = self.config.id_column(self.input_file) self.df = pd.DataFrame() self.debug_resolve_df = None @@ -88,7 +86,7 @@ def _read_csv(self) -> None: # - Set 1st column (i.e. 
the entity column) to type str (so that geoIds like "01" are not treated as ints and converted to 1) # - Strip leading whitespaces # - Treat comma as a thousands separator - self.df = pd.read_csv(self.input_fh.read_string_io(), + self.df = pd.read_csv(self.input_file.read_string_io(), dtype={0: str}, skipinitialspace=True, thousands=",") @@ -125,17 +123,17 @@ def _rename_columns(self) -> None: self.df = self.df.rename(columns=renamed) def _write_observations(self) -> None: - sv_names = self.config.computed_variables(self.input_file_name) + sv_names = self.config.computed_variables(self.input_file) if not sv_names: logging.warning("No computed variables specified: %s", - self.input_file_name) + self.input_file.full_path()) return for sv_name in sv_names: - sv_dcid = self.nodes.variable(sv_name, self.input_file_name).id + sv_dcid = self.nodes.variable(sv_name, self.input_file).id aggr_cfg = self.config.aggregation(sv_name) observations = self._compute_sv_observations(sv_dcid, aggr_cfg) - self.db.insert_observations(observations, self.input_file_name) + self.db.insert_observations(observations, self.input_file) def _compute_sv_observations( self, sv_dcid: str, aggr_cfg: AggregationConfig = AggregationConfig() @@ -173,7 +171,7 @@ def _write_event_triples(self) -> None: # Add event type node - it will be written to DB later. # This is to avoid duplicate event types in scenarios where events of the same type # are spread across files. - self.nodes.event_type(self.event_type, self.input_file_name) + self.nodes.event_type(self.event_type, self.input_file) # All property columns would've been renamed to their dcids by now. # So use the id column's dcid as the id column name. @@ -305,9 +303,8 @@ def _create_debug_resolve_dataframe( def _write_debug_csvs(self) -> None: if self.debug_resolve_df is not None: logging.info("Writing resolutions (for debugging) to: %s", - self.debug_resolve_fh) - self.debug_resolve_fh.write_string( - self.debug_resolve_df.to_csv(index=False)) + self.debug_resolve_file) + self.debug_resolve_file.write(self.debug_resolve_df.to_csv(index=False)) # Utility methods diff --git a/simple/stats/main.py b/simple/stats/main.py index 54132d06..30588061 100644 --- a/simple/stats/main.py +++ b/simple/stats/main.py @@ -61,9 +61,9 @@ def _init_logger(): def _run(): - Runner(config_file=FLAGS.config_file, - input_dir=FLAGS.input_dir, - output_dir=FLAGS.output_dir, + Runner(config_file_path=FLAGS.config_file, + input_dir_path=FLAGS.input_dir, + output_dir_path=FLAGS.output_dir, mode=FLAGS.mode).run() diff --git a/simple/stats/mcf_importer.py b/simple/stats/mcf_importer.py index d7c6fd52..ca44ac28 100644 --- a/simple/stats/mcf_importer.py +++ b/simple/stats/mcf_importer.py @@ -23,7 +23,7 @@ from stats.importer import Importer from stats.nodes import Nodes from stats.reporter import FileImportReporter -from util.filehandler import FileHandler +from util.filesystem import File _ID = 'ID' _DCID = 'dcid' @@ -36,13 +36,12 @@ class McfImporter(Importer): For custom DC, the MCF nodes are inserted as triples in the DB. 
""" - def __init__(self, input_fh: FileHandler, output_fh: FileHandler, db: Db, + def __init__(self, input_file: File, output_file: File, db: Db, reporter: FileImportReporter, is_main_dc: bool) -> None: - self.input_fh = input_fh - self.output_fh = output_fh + self.input_file = input_file + self.output_file = output_file self.db = db self.reporter = reporter - self.input_file_name = self.input_fh.basename() self.is_main_dc = is_main_dc def do_import(self) -> None: @@ -50,11 +49,11 @@ def do_import(self) -> None: try: # For main DC, simply copy the file over. if self.is_main_dc: - self.output_fh.write_string(self.input_fh.read_string()) + self.output_file.write(self.input_file.read()) else: triples = self._mcf_to_triples() logging.info("Inserting %s triples from %s", len(triples), - self.input_file_name) + self.input_file.full_path()) self.db.insert_triples(triples) self.reporter.report_success() @@ -66,7 +65,7 @@ def _mcf_to_triples(self) -> list[Triple]: parser_triples: list[list[str]] = [] # DCID references local2dcid: dict[str, str] = {} - for parser_triple in mcf_to_triples(self.input_fh.read_string_io()): + for parser_triple in mcf_to_triples(self.input_file.read_string_io()): [subject_id, predicate, value, _] = parser_triple if predicate == _DCID: local2dcid[subject_id] = value diff --git a/simple/stats/nl.py b/simple/stats/nl.py index 63739ee4..cedb4ae0 100644 --- a/simple/stats/nl.py +++ b/simple/stats/nl.py @@ -23,7 +23,8 @@ from stats.nl_constants import CUSTOM_MODEL from stats.nl_constants import CUSTOM_MODEL_PATH import stats.schema_constants as sc -from util.filehandler import FileHandler +from util.filesystem import Dir +from util.filesystem import File import yaml _DCID_COL = "dcid" @@ -37,16 +38,16 @@ _TOPIC_CACHE_JSON_FILE = "custom_dc_topic_cache.json" -def generate_nl_sentences(triples: list[Triple], nl_dir_fh: FileHandler): +def generate_nl_sentences(triples: list[Triple], nl_dir: Dir): """Generates NL sentences based on name and searchDescription triples. This method should only be called for triples of types for which NL sentences should be generated. Currently it is StatisticalVariable and Topic. - This method does not do the type checks itself and the onus is on the caller + This method does not do the type checks itself and the onus is on the caller to filter triples. - The dcids and sentences are written to a CSV using the specified FileHandler + The dcids and sentences are written to a CSV using the specified File. """ dcid2candidates: dict[str, SentenceCandidates] = {} @@ -64,29 +65,27 @@ def generate_nl_sentences(triples: list[Triple], nl_dir_fh: FileHandler): dataframe = pd.DataFrame(rows) - sentences_fh = nl_dir_fh.make_file(_SENTENCES_FILE) - logging.info("Writing %s NL sentences to: %s", dataframe.size, sentences_fh) - sentences_fh.write_string(dataframe.to_csv(index=False)) + sentences_file = nl_dir.open_file(_SENTENCES_FILE) + logging.info("Writing %s NL sentences to: %s", dataframe.size, sentences_file) + sentences_file.write(dataframe.to_csv(index=False)) - # The trailing "/" is used by the file handler to create a directory. 
- embeddings_dir_fh = nl_dir_fh.make_file(f"{_EMBEDDINGS_DIR}/") - embeddings_dir_fh.make_dirs() - embeddings_fh = embeddings_dir_fh.make_file(_EMBEDDINGS_FILE) - catalog_fh = embeddings_dir_fh.make_file(_CUSTOM_CATALOG_YAML) - catalog_dict = _catalog_dict(nl_dir_fh.path, embeddings_fh.path) + embeddings_dir = nl_dir.open_dir(_EMBEDDINGS_DIR) + embeddings_file = embeddings_dir.open_file(_EMBEDDINGS_FILE) + catalog_file = embeddings_dir.open_file(_CUSTOM_CATALOG_YAML) + catalog_dict = _catalog_dict(nl_dir, embeddings_file) catalog_yaml = yaml.safe_dump(catalog_dict) - logging.info("Writing custom catalog to path %s:\n%s", catalog_fh, + logging.info("Writing custom catalog to path %s:\n%s", catalog_file, catalog_yaml) - catalog_fh.write_string(catalog_yaml) + catalog_file.write(catalog_yaml) -def generate_topic_cache(triples: list[Triple], nl_dir_fh: FileHandler): +def generate_topic_cache(triples: list[Triple], nl_dir: Dir): """Generates topic cache based on Topic and StatVarPeerGroup triples. This method should only be called for triples of types for which topic cache should be generated (Topic and StatVarPeerGroup). - This method does not do the type checks itself and the onus is on the caller + This method does not do the type checks itself and the onus is on the caller to filter triples. The topic cache is written to a custom_dc_topic_cache.json file in the specified directory. @@ -102,20 +101,20 @@ def generate_topic_cache(triples: list[Triple], nl_dir_fh: FileHandler): nodes.append(node.json()) result = {"nodes": nodes} - topic_cache_fh = nl_dir_fh.make_file(_TOPIC_CACHE_JSON_FILE) + topic_cache_file = nl_dir.open_file(_TOPIC_CACHE_JSON_FILE) logging.info("Writing %s topic cache nodes to: %s", len(nodes), - topic_cache_fh) - topic_cache_fh.write_string(json.dumps(result, indent=1)) + topic_cache_file) + topic_cache_file.write(json.dumps(result, indent=1)) -def _catalog_dict(nl_dir: str, embeddings_path: str) -> dict: +def _catalog_dict(nl_dir: Dir, embeddings_file: File) -> dict: return { "version": "1", "indexes": { CUSTOM_EMBEDDINGS_INDEX: { "store_type": "MEMORY", - "source_path": nl_dir, - "embeddings_path": embeddings_path, + "source_path": nl_dir.full_path(), + "embeddings_path": embeddings_file.full_path(), "model": CUSTOM_MODEL }, }, diff --git a/simple/stats/nodes.py b/simple/stats/nodes.py index 67d50480..850c0de5 100644 --- a/simple/stats/nodes.py +++ b/simple/stats/nodes.py @@ -28,7 +28,7 @@ from stats.data import StatVarGroup from stats.data import Triple import stats.schema_constants as sc -from util.filehandler import FileHandler +from util.filesystem import File _CUSTOM_SV_ID_PREFIX = "custom/statvar_" _CUSTOM_GROUP_ID_PREFIX = "custom/g/group_" @@ -123,11 +123,11 @@ def _source_id(self, source_cfg: Source | None) -> str: return source.id - def provenance(self, input_file_name: str) -> Provenance: - prov_name = self.config.provenance_name(input_file_name) + def provenance(self, input_file: File) -> Provenance: + prov_name = self.config.provenance_name(input_file) return self.provenances.get(prov_name, _DEFAULT_PROVENANCE) - def variable(self, sv_column_name: str, input_file_name: str) -> StatVar: + def variable(self, sv_column_name: str, input_file: File) -> StatVar: if not sv_column_name in self.variables: var_cfg = self.config.variable(sv_column_name) group = self.group(var_cfg.group_path) @@ -141,7 +141,7 @@ def variable(self, sv_column_name: str, input_file_name: str) -> StatVar: properties=var_cfg.properties) return 
self._add_provenance(self.variables[sv_column_name], - self.provenance(input_file_name)) + self.provenance(input_file)) def property(self, property_column_name: str) -> Property: if not property_column_name in self.properties: @@ -150,7 +150,7 @@ def property(self, property_column_name: str) -> Property: return self.properties[property_column_name] - def event_type(self, event_type_name: str, input_file_name: str) -> EventType: + def event_type(self, event_type_name: str, input_file: File) -> EventType: if not event_type_name in self.event_types: event_type_cfg = self.config.event(event_type_name) self.event_types[event_type_name] = EventType( @@ -159,10 +159,9 @@ def event_type(self, event_type_name: str, input_file_name: str) -> EventType: description=event_type_cfg.description) return self.event_types[event_type_name].add_provenance( - self.provenance(input_file_name)) + self.provenance(input_file)) - def entity_type(self, entity_type_name: str, - input_file_name: str) -> EntityType: + def entity_type(self, entity_type_name: str, input_file: File) -> EntityType: if not entity_type_name in self.entity_types: entity_type_cfg = self.config.entity(entity_type_name) self.entity_types[entity_type_name] = EntityType( @@ -171,7 +170,7 @@ def entity_type(self, entity_type_name: str, description=entity_type_cfg.description) return self.entity_types[entity_type_name].add_provenance( - self.provenance(input_file_name)) + self.provenance(input_file)) def _add_provenance(self, sv: StatVar, provenance: Provenance) -> StatVar: sv.add_provenance(provenance) @@ -270,7 +269,7 @@ def entities_with_types(self, dcid2type: dict[str, str]): for entity_dcid, entity_type in dcid2type.items(): self.entity_with_type(entity_dcid, entity_type) - def triples(self, triples_fh: FileHandler | None = None) -> list[Triple]: + def triples(self, triples_file: File | None = None) -> list[Triple]: triples: list[Triple] = [] for source in self.sources.values(): triples.extend(source.triples()) @@ -289,8 +288,8 @@ def triples(self, triples_fh: FileHandler | None = None) -> list[Triple]: for entities in self.entities.values(): triples.extend(entities.triples()) - if triples_fh: - logging.info("Writing %s triples to: %s", len(triples), str(triples_fh)) - triples_fh.write_string(pd.DataFrame(triples).to_csv(index=False)) + if triples_file: + logging.info("Writing %s triples to: %s", len(triples), triples_file) + triples_file.write(pd.DataFrame(triples).to_csv(index=False)) return triples diff --git a/simple/stats/observations_importer.py b/simple/stats/observations_importer.py index e937f7c3..5c17ffa3 100644 --- a/simple/stats/observations_importer.py +++ b/simple/stats/observations_importer.py @@ -23,7 +23,7 @@ from stats.importer import Importer from stats.nodes import Nodes from stats.reporter import FileImportReporter -from util.filehandler import FileHandler +from util.filesystem import File from util import dc_client as dc @@ -32,18 +32,16 @@ class ObservationsImporter(Importer): """Imports a single observations input file. 
""" - def __init__(self, input_fh: FileHandler, db: Db, - debug_resolve_fh: FileHandler, reporter: FileImportReporter, - nodes: Nodes) -> None: - self.input_fh = input_fh + def __init__(self, input_file: File, db: Db, debug_resolve_file: File, + reporter: FileImportReporter, nodes: Nodes) -> None: + self.input_file = input_file self.db = db - self.debug_resolve_fh = debug_resolve_fh + self.debug_resolve_file = debug_resolve_file self.reporter = reporter self.nodes = nodes - self.input_file_name = self.input_fh.basename() self.config = nodes.config - self.entity_type = self.config.entity_type(self.input_file_name) - self.ignore_columns = self.config.ignore_columns(self.input_file_name) + self.entity_type = self.config.entity_type(self.input_file) + self.ignore_columns = self.config.ignore_columns(self.input_file) # Reassign after reading CSV. self.entity_column_name = constants.COLUMN_DCID self.df = pd.DataFrame() @@ -71,7 +69,7 @@ def _read_csv(self) -> None: # - Set 1st column (i.e. the entity column) to type str (so that geoIds like "01" are not treated as ints and converted to 1) # - Strip leading whitespaces # - Treat comma as a thousands separator - self.df = pd.read_csv(self.input_fh.read_string_io(), + self.df = pd.read_csv(self.input_file.read_string_io(), dtype={0: str}, skipinitialspace=True, thousands=",") @@ -98,7 +96,7 @@ def _rename_columns(self) -> None: # Rename SV columns to their IDs sv_column_names = self.df.columns[2:] sv_ids = [ - self.nodes.variable(sv_column_name, self.input_file_name).id + self.nodes.variable(sv_column_name, self.input_file).id for sv_column_name in sv_column_names ] renamed.update({col: id for col, id in zip(sv_column_names, sv_ids)}) @@ -115,9 +113,9 @@ def _write_observations(self) -> None: value_name=constants.COLUMN_VALUE, ) - provenance = self.nodes.provenance(self.input_file_name).id + provenance = self.nodes.provenance(self.input_file).id obs_props = ObservationProperties.new( - self.config.observation_properties(self.input_file_name)) + self.config.observation_properties(self.input_file)) observations: list[Observation] = [] for _, row in observations_df.iterrows(): @@ -129,7 +127,7 @@ def _write_observations(self) -> None: properties=obs_props) if observation.value and observation.value != "": observations.append(observation) - self.db.insert_observations(observations, self.input_file_name) + self.db.insert_observations(observations, self.input_file) def _add_entity_nodes(self) -> None: # Convert entity dcids to dict. @@ -262,6 +260,5 @@ def _create_debug_resolve_dataframe( def _write_debug_csvs(self) -> None: if self.debug_resolve_df is not None: logging.info("Writing resolutions (for debugging) to: %s", - self.debug_resolve_fh) - self.debug_resolve_fh.write_string( - self.debug_resolve_df.to_csv(index=False)) + self.debug_resolve_file.path) + self.debug_resolve_file.write(self.debug_resolve_df.to_csv(index=False)) diff --git a/simple/stats/reporter.py b/simple/stats/reporter.py index 1b90546c..51232dd3 100644 --- a/simple/stats/reporter.py +++ b/simple/stats/reporter.py @@ -19,7 +19,7 @@ import json import time -from util.filehandler import FileHandler +from util.filesystem import File # Minimum interval before a report should be saved to disk or cloud. # This keeps it from reporting too frequently and running into GCS rate limit issues. @@ -43,14 +43,14 @@ class ImportReporter: The report is written to report.json in the process directory. 
""" - def __init__(self, report_fh: FileHandler) -> None: + def __init__(self, report_file: File) -> None: self.status = Status.NOT_STARTED self.start_time = None self.last_update = datetime.now() self.last_reported: float | None = None - self.report_fh = report_fh + self.report_file = report_file self.data = {} - self.import_files: dict[str, FileImportReporter] = {} + self.file_reporters_by_full_path: dict[str, FileImportReporter] = {} # Functions decorated with @_report will result in the report being saved # upon function execution. @@ -65,11 +65,13 @@ def wrapper(self, *args, **kwargs): return wrapper @_report - def report_started(self, import_files: list[str]): + def report_started(self, import_files: list[File]): self.status = Status.STARTED self.start_time = datetime.now() for import_file in import_files: - self.import_files[import_file] = FileImportReporter(import_file, self) + full_path = import_file.full_path() + self.file_reporters_by_full_path[full_path] = FileImportReporter( + full_path, self) @_report def report_done(self): @@ -81,10 +83,10 @@ def report_failure(self, error: str): self.status = Status.FAILURE self.data["error"] = error - def import_file(self, import_file: str): - return self.import_files[import_file] + def get_file_reporter(self, import_file: File): + return self.file_reporters_by_full_path[import_file.full_path()] - def import_file_update(self, import_file: str): + def recompute_progress(self): self._compute_all_done() self.save() @@ -95,8 +97,8 @@ def _compute_all_done(self): self.status = Status.FAILURE def _all_file_imports(self, status: Status) -> bool: - return all( - reporter.status == status for reporter in self.import_files.values()) + return all(reporter.status == status + for reporter in self.file_reporters_by_full_path.values()) def json(self) -> dict: report = {} @@ -113,11 +115,11 @@ def _maybe_report(field: str, func=None): report["startTime"] = str(self.start_time) report["lastUpdate"] = str(self.last_update) - import_files = {} - for import_file, import_file_reporter in self.import_files.items(): - import_files[import_file] = import_file_reporter.json() + import_files_output = {} + for full_path, file_reporter in self.file_reporters_by_full_path.items(): + import_files_output[full_path] = file_reporter.json() - report["importFiles"] = import_files + report["importFiles"] = import_files_output return report @@ -128,18 +130,19 @@ def save(self) -> None: self.status) if should_report: self.last_reported = time.time() - self.report_fh.write_string(json.dumps(self.json(), indent=2)) + self.report_file.write(json.dumps(self.json(), indent=2)) class FileImportReporter: """Generates a report on every reported change for a single file import. """ - def __init__(self, import_file: str, parent: ImportReporter) -> None: + def __init__(self, import_file_full_path: str, + parent: ImportReporter) -> None: self.status = Status.NOT_STARTED self.start_time = None self.last_update = datetime.now() - self.import_file = import_file + self.import_file_full_path = import_file_full_path self.parent = parent self.data = {} @@ -188,4 +191,4 @@ def _maybe_report(field: str, func=None): def report(self) -> None: self.last_update = datetime.now() - self.parent.import_file_update(self.import_file) + self.parent.recompute_progress() diff --git a/simple/stats/runner.py b/simple/stats/runner.py index 73ccfb8a..dfe3e306 100644 --- a/simple/stats/runner.py +++ b/simple/stats/runner.py @@ -1,21 +1,9 @@ -# Copyright 2023 Google Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from enum import StrEnum import json import logging +import os +import fs.path as fspath from stats import constants from stats import schema from stats import stat_var_hierarchy_generator @@ -30,7 +18,7 @@ from stats.db import create_main_dc_config from stats.db import create_sqlite_config from stats.db import get_cloud_sql_config_from_env -from stats.db import get_sqlite_config_from_env +from stats.db import get_sqlite_path_from_env from stats.db import ImportStatus from stats.entities_importer import EntitiesImporter from stats.events_importer import EventsImporter @@ -42,8 +30,11 @@ from stats.reporter import ImportReporter import stats.schema_constants as sc from stats.variable_per_row_importer import VariablePerRowImporter -from util.filehandler import create_file_handler -from util.filehandler import FileHandler +from util.file_match import match +from util.filesystem import create_store +from util.filesystem import Dir +from util.filesystem import File +from util.filesystem import Store class RunMode(StrEnum): @@ -57,62 +48,69 @@ class Runner: """ def __init__(self, - config_file: str, - input_dir: str, - output_dir: str, + config_file_path: str, + input_dir_path: str, + output_dir_path: str, mode: RunMode = RunMode.CUSTOM_DC) -> None: - assert config_file or input_dir, "One of config_file or input_dir must be specified" - assert output_dir, "output_dir must be specified" + assert config_file_path or input_dir_path, "One of config_file or input_dir must be specified" + assert output_dir_path, "output_dir must be specified" self.mode = mode - self.input_handlers: list[FileHandler] = [] + + # File systems, both input and output. Must be closed when run finishes. + self.all_stores: list[Store] = [] + # Input-only stores + self.input_stores: list[Store] = [] + # "Special" file handlers. # i.e. if files of these types are present, they are handled in specific ways. - self.special_handlers: dict[str, FileHandler] = {} + self.special_files: dict[str, File] = {} self.svg_specialized_names: ParentSVG2ChildSpecializedNames = {} - # Config file driven. - if config_file: - config_fh = create_file_handler(config_file, is_dir=False) - if not config_fh.exists(): - raise FileNotFoundError("Config file must be provided.") - self.config = Config(data=json.loads(config_fh.read_string())) + # Config file driven (input paths pulled from config) + if config_file_path: + with create_store(config_file_path) as config_store: + config_data = config_store.as_file().read() + self.config = Config(data=json.loads(config_data)) input_urls = self.config.data_download_urls() if not input_urls: raise ValueError("Data Download URLs not found in config.") for input_url in input_urls: - self.input_handlers.append(create_file_handler(input_url, is_dir=True)) + input_store = create_store(input_url) + self.all_stores.append(input_store) + self.input_stores.append(input_store) - #Input dir driven. 
+ # Input dir driven (config file found in input dir) else: - input_dir_fh = create_file_handler(input_dir, is_dir=True) - if not input_dir_fh.isdir: - raise NotADirectoryError( - f"Input path must be a directory: {input_dir}. If it is a GCS path, ensure it ends with a '/'." - ) - self.input_handlers.append(input_dir_fh) + input_store = create_store(input_dir_path) + self.all_stores.append(input_store) + self.input_stores.append(input_store) - config_fh = input_dir_fh.make_file(constants.CONFIG_JSON_FILE_NAME) - if not config_fh.exists(): - raise FileNotFoundError("Config file must be provided.") - self.config = Config(data=json.loads(config_fh.read_string())) + config_file = input_store.as_dir().open_file( + constants.CONFIG_JSON_FILE_NAME, create_if_missing=False) + self.config = Config(data=json.loads(config_file.read())) - self.special_file_to_type = self.config.special_files() + # Get dict of special file type string to special file name. + # Example entry: verticalSpecsFile -> vertical_specs.json + self.special_file_names_by_type = self.config.special_files() - # Output directories - self.output_dir_fh = create_file_handler(output_dir, is_dir=True) - self.nl_dir_fh = self.output_dir_fh.make_file(f"{constants.NL_DIR_NAME}/") - self.process_dir_fh = self.output_dir_fh.make_file( - f"{constants.PROCESS_DIR_NAME}/") + # New option to traverse subdirs of input dir(s). Defaults to false. + self.include_input_subdirs = self.config.include_input_subdirs() - self.output_dir_fh.make_dirs() - self.nl_dir_fh.make_dirs() - self.process_dir_fh.make_dirs() + # Output directories + output_store = create_store(output_dir_path, create_if_missing=True) + if self.include_input_subdirs: + for input_store in self.input_stores: + _check_not_overlapping(input_store, output_store) + self.all_stores.append(output_store) + self.output_dir = output_store.as_dir() + self.nl_dir = self.output_dir.open_dir(constants.NL_DIR_NAME) + self.process_dir = self.output_dir.open_dir(constants.PROCESS_DIR_NAME) # Reporter. - self.reporter = ImportReporter(report_fh=self.process_dir_fh.make_file( - constants.REPORT_JSON_FILE_NAME)) + self.reporter = ImportReporter( + report_file=self.process_dir.open_file(constants.REPORT_JSON_FILE_NAME)) self.nodes = Nodes(self.config) self.db = None @@ -136,6 +134,12 @@ def run(self): # Report done. self.reporter.report_done() + + # Close all file storage. + for store in self.all_stores: + store.close() + logging.info("File storage closed.") + except Exception as e: logging.exception("Error updating stats") self.reporter.report_failure(error=str(e)) @@ -143,20 +147,25 @@ def run(self): def _get_db_config(self) -> dict: if self.mode == RunMode.MAIN_DC: logging.info("Using Main DC config.") - return create_main_dc_config(self.output_dir_fh.path) + return create_main_dc_config(self.output_dir.path) # Attempt to get from env (cloud sql, then sqlite), # then config file, then default. 
db_cfg = get_cloud_sql_config_from_env() if db_cfg: logging.info("Using Cloud SQL settings from env.") return db_cfg - db_cfg = get_sqlite_config_from_env() - if db_cfg: + sqlite_path_from_env = get_sqlite_path_from_env() + if sqlite_path_from_env: logging.info("Using SQLite settings from env.") - return db_cfg - logging.info("Using default DB settings.") - return create_sqlite_config( - self.output_dir_fh.make_file(constants.DB_FILE_NAME).path) + sqlite_env_store = create_store(sqlite_path_from_env, + create_if_missing=True, + treat_as_file=True) + self.all_stores.append(sqlite_env_store) + sqlite_file = sqlite_env_store.as_file() + else: + logging.info("Using default SQLite settings.") + sqlite_file = self.output_dir.open_file(constants.DB_FILE_NAME) + return create_sqlite_config(sqlite_file) def _run_imports_and_do_post_import_work(self): # (SQL only) Drop data in existing tables (except import metadata). @@ -194,13 +203,13 @@ def _generate_nl_artifacts(self): sc.TYPE_STATISTICAL_VARIABLE) # Generate sentences. - nl.generate_nl_sentences(triples, self.nl_dir_fh) + nl.generate_nl_sentences(triples, self.nl_dir) # If generating topics, fetch svpg triples as well and generate topic cache if generate_topics: triples = triples + self.db.select_triples_by_subject_type( sc.TYPE_STAT_VAR_PEER_GROUP) - nl.generate_topic_cache(triples, self.nl_dir_fh) + nl.generate_topic_cache(triples, self.nl_dir) def _generate_svg_hierarchy(self): if self.mode == RunMode.MAIN_DC: @@ -218,13 +227,13 @@ def _generate_svg_hierarchy(self): logging.info("Generating SVG hierarchy for %s SV triples.", len(sv_triples)) vertical_specs: list[VerticalSpec] = [] - vertical_specs_fh = self.special_handlers.get( + vertical_specs_file = self.special_files.get( constants.VERTICAL_SPECS_FILE_TYPE) - if vertical_specs_fh: + if vertical_specs_file: logging.info("Loading vertical specs from: %s", - vertical_specs_fh.basename()) + vertical_specs_file.name()) vertical_specs = stat_var_hierarchy_generator.load_vertical_specs( - vertical_specs_fh.read_string()) + vertical_specs_file.read()) # Collect all dcids that can be used to generate SVG names and get their schema names. schema_dcids = list( @@ -261,100 +270,115 @@ def _vertical_specs_dcids(self, def _generate_svg_cache(self): generate_svg_cache(self.db, self.svg_specialized_names) - # If the fh is a "special" file, append it to the self.special_handlers dict. - # Returns true if it is, otherwise false. - def _maybe_set_special_fh(self, fh: FileHandler) -> bool: - file_name = fh.basename() - file_type = self.special_file_to_type.get(file_name) - if file_type: - self.special_handlers[file_type] = fh - return True + def _check_if_special_file(self, file: File) -> bool: + for file_type in self.special_file_names_by_type.keys(): + if file_type in self.special_files: + # Already found this special file. 
+ continue + file_name = self.special_file_names_by_type[file_type] + if match(file, file_name): + self.special_files[file_type] = file + return True return False def _run_all_data_imports(self): - input_fhs: list[FileHandler] = [] - input_mcf_fhs: list[FileHandler] = [] - for input_handler in self.input_handlers: - if not input_handler.isdir: - if self._maybe_set_special_fh(input_handler): - continue - input_file_name = input_handler.basename() - if input_file_name.endswith(".mcf"): - input_mcf_fhs.append(input_handler) - else: - input_fhs.append(input_handler) + input_files: list[File] = [] + input_csv_files: list[File] = [] + input_mcf_files: list[File] = [] + + for input_store in self.input_stores: + if input_store.isdir(): + input_files.extend(input_store.as_dir().all_files( + self.include_input_subdirs)) else: - for input_file in sorted(input_handler.list_files(extension=".csv")): - fh = input_handler.make_file(input_file) - if not self._maybe_set_special_fh(fh): - input_fhs.append(fh) - for input_file in sorted(input_handler.list_files(extension=".mcf")): - fh = input_handler.make_file(input_file) - if not self._maybe_set_special_fh(fh): - input_mcf_fhs.append(fh) - for input_file in sorted(input_handler.list_files(extension=".json")): - fh = input_handler.make_file(input_file) - self._maybe_set_special_fh(fh) - - self.reporter.report_started(import_files=list( - map(lambda fh: fh.basename(), input_fhs + input_mcf_fhs))) - for input_fh in input_fhs: - self._run_single_import(input_fh) - for input_mcf_fh in input_mcf_fhs: - self._run_single_mcf_import(input_mcf_fh) - - def _run_single_import(self, input_fh: FileHandler): - logging.info("Importing file: %s", input_fh.basename()) - self._create_importer(input_fh).do_import() - - def _run_single_mcf_import(self, input_mcf_fh: FileHandler): - logging.info("Importing MCF file: %s", input_mcf_fh.basename()) - self._create_mcf_importer(input_mcf_fh, self.output_dir_fh, + input_files.append(input_store.as_file()) + + for input_file in input_files: + if self._check_if_special_file(input_file): + continue + if match(input_file, "*.csv"): + input_csv_files.append(input_file) + if match(input_file, "*.mcf"): + input_mcf_files.append(input_file) + + # Sort input files alphabetically. 
+ input_csv_files.sort(key=lambda f: f.full_path()) + input_mcf_files.sort(key=lambda f: f.full_path()) + + self.reporter.report_started(import_files=list(input_csv_files + + input_mcf_files)) + for input_csv_file in input_csv_files: + self._run_single_import(input_csv_file) + for input_mcf_file in input_mcf_files: + self._run_single_mcf_import(input_mcf_file) + + def _run_single_import(self, input_file: File): + logging.info("Importing file: %s", input_file) + self._create_importer(input_file).do_import() + + def _run_single_mcf_import(self, input_mcf_file: File): + logging.info("Importing MCF file: %s", input_mcf_file) + self._create_mcf_importer(input_mcf_file, self.output_dir, self.mode == RunMode.MAIN_DC).do_import() - def _create_mcf_importer(self, input_fh: FileHandler, - output_dir_fh: FileHandler, + def _create_mcf_importer(self, input_file: File, output_dir: Dir, is_main_dc: bool) -> Importer: - mcf_file_name = input_fh.basename() - output_fh = output_dir_fh.make_file(mcf_file_name) - reporter = self.reporter.import_file(mcf_file_name) - return McfImporter(input_fh=input_fh, - output_fh=output_fh, + # Right now, this overwrites any file with the same name, + # so if different input sources have files with the same relative path, + # they will clobber each others output. Treating this as an edge case + # for now since it only affects the main DC case, but we could resolve + # it in the future by allowing input sources to be mapped to output + # locations. + output_file = output_dir.open_file(input_file.path) + reporter = self.reporter.get_file_reporter(input_file) + return McfImporter(input_file=input_file, + output_file=output_file, db=self.db, reporter=reporter, is_main_dc=is_main_dc) - def _create_importer(self, input_fh: FileHandler) -> Importer: - input_file = input_fh.basename() + def _create_importer(self, input_file: File) -> Importer: import_type = self.config.import_type(input_file) - debug_resolve_fh = self.process_dir_fh.make_file( - f"{constants.DEBUG_RESOLVE_FILE_NAME_PREFIX}_{input_file}") - reporter = self.reporter.import_file(input_file) + sanitized_path = input_file.full_path().replace("://", + "_").replace("/", "_") + debug_resolve_file = self.process_dir.open_file( + f"{constants.DEBUG_RESOLVE_FILE_NAME_PREFIX}_{sanitized_path}") + reporter = self.reporter.get_file_reporter(input_file) if import_type == ImportType.OBSERVATIONS: input_file_format = self.config.format(input_file) if input_file_format == InputFileFormat.VARIABLE_PER_ROW: - return VariablePerRowImporter(input_fh=input_fh, + return VariablePerRowImporter(input_file=input_file, db=self.db, reporter=reporter, nodes=self.nodes) - return ObservationsImporter(input_fh=input_fh, + return ObservationsImporter(input_file=input_file, db=self.db, - debug_resolve_fh=debug_resolve_fh, + debug_resolve_file=debug_resolve_file, reporter=reporter, nodes=self.nodes) if import_type == ImportType.EVENTS: - return EventsImporter(input_fh=input_fh, + return EventsImporter(input_file=input_file, db=self.db, - debug_resolve_fh=debug_resolve_fh, + debug_resolve_file=debug_resolve_file, reporter=reporter, nodes=self.nodes) if import_type == ImportType.ENTITIES: - return EntitiesImporter(input_fh=input_fh, + return EntitiesImporter(input_file=input_file, db=self.db, reporter=reporter, nodes=self.nodes) - raise ValueError(f"Unsupported import type: {import_type} ({input_file})") + raise ValueError( + f"Unsupported import type: {import_type} ({input_file.full_path()})") + + +def _check_not_overlapping(input_store: Store, 
output_store: Store): + input_path = input_store.full_path() + output_path = output_store.full_path() + if fspath.issamedir(input_path, output_path) or fspath.isparent( + input_path, output_path) or fspath.isparent(output_path, input_path): + raise ValueError( + f"Input path (${input_path}) overlaps with output dir ({output_path})") diff --git a/simple/stats/variable_per_row_importer.py b/simple/stats/variable_per_row_importer.py index b7490cef..d58c727b 100644 --- a/simple/stats/variable_per_row_importer.py +++ b/simple/stats/variable_per_row_importer.py @@ -23,7 +23,7 @@ from stats.importer import Importer from stats.nodes import Nodes from stats.reporter import FileImportReporter -from util.filehandler import FileHandler +from util.filesystem import File from util import dc_client as dc @@ -49,15 +49,14 @@ class VariablePerRowImporter(Importer): This is in contrast to the ObservationsImporter where variables are specified in columns. Currently this importer only writes observations and no entities. - It also does not resolve any entities and expects all entities to be pre-resolved. + It also does not resolve any entities and expects all entities to be pre-resolved. """ - def __init__(self, input_fh: FileHandler, db: Db, - reporter: FileImportReporter, nodes: Nodes) -> None: - self.input_fh = input_fh + def __init__(self, input_file: File, db: Db, reporter: FileImportReporter, + nodes: Nodes) -> None: + self.input_file = input_file self.db = db self.reporter = reporter - self.input_file_name = self.input_fh.basename() self.nodes = nodes self.config = nodes.config # Reassign after reading CSV. @@ -80,10 +79,10 @@ def do_import(self) -> None: raise e def _read_csv(self) -> None: - self.reader = DictReader(self.input_fh.read_string_io()) + self.reader = DictReader(self.input_file.read_string_io()) def _map_columns(self): - config_mappings = self.config.column_mappings(self.input_file_name) + config_mappings = self.config.column_mappings(self.input_file) # Required columns. 
for key in self.column_mappings.keys(): @@ -107,9 +106,9 @@ def _map_columns(self): ) def _write_observations(self) -> None: - provenance = self.nodes.provenance(self.input_file_name).id + provenance = self.nodes.provenance(self.input_file).id obs_props = ObservationProperties.new( - self.config.observation_properties(self.input_file_name)) + self.config.observation_properties(self.input_file)) observations: list[Observation] = [] for row in self.reader: @@ -125,7 +124,7 @@ def _write_observations(self) -> None: properties=row_obs_props) observations.append(observation) self.entity_dcids[entity_dcid] = True - self.db.insert_observations(observations, self.input_file_name) + self.db.insert_observations(observations, self.input_file) def _get_row_obs(self, row: dict[str, str]) -> dict[str, str]: properties: dict[str, str] = {} diff --git a/simple/tests/stats/config_test.py b/simple/tests/stats/config_test.py index b82e865a..f25e1b26 100644 --- a/simple/tests/stats/config_test.py +++ b/simple/tests/stats/config_test.py @@ -23,6 +23,8 @@ from stats.data import Source from stats.data import StatVar from stats.data import TimePeriod +from util.filesystem import create_store +from util.filesystem import File CONFIG_DATA = { "inputFiles": { @@ -124,6 +126,15 @@ def __init__(self, methodName: str = "runTest") -> None: super().__init__(methodName) self.maxDiff = None + def make_file(self, path: str) -> File: + return self.store.as_dir().open_file(path) + + def setUp(self): + self.store = create_store("mem://") + + def tearDown(self): + self.store.close() + def test_variable(self): config = Config(CONFIG_DATA) self.assertEqual( @@ -151,15 +162,18 @@ def test_variable(self): def test_entity_type(self): config = Config(CONFIG_DATA) - self.assertEqual(config.entity_type("a.csv"), "Country") - self.assertEqual(config.entity_type("b.csv"), "") - self.assertEqual(config.entity_type("not-in-config.csv"), "") + self.assertEqual(config.entity_type(self.make_file("a.csv")), "Country") + self.assertEqual(config.entity_type(self.make_file("b.csv")), "") + self.assertEqual(config.entity_type(self.make_file("not_in_config.csv")), + "") def test_ignore_columns(self): config = Config(CONFIG_DATA) - self.assertEqual(config.ignore_columns("a.csv"), []) - self.assertEqual(config.ignore_columns("b.csv"), ["ignore1", "ignore2"]) - self.assertEqual(config.ignore_columns("not-in-config.csv"), []) + self.assertEqual(config.ignore_columns(self.make_file("a.csv")), []) + self.assertEqual(config.ignore_columns(self.make_file("b.csv")), + ["ignore1", "ignore2"]) + self.assertEqual(config.ignore_columns(self.make_file("not_in_config.csv")), + []) def test_provenances_and_sources(self): config = Config(CONFIG_DATA) @@ -168,19 +182,20 @@ def test_provenances_and_sources(self): def test_provenance_name(self): config = Config(CONFIG_DATA) - self.assertEqual(config.provenance_name("a.csv"), "Provenance21 Name") - self.assertEqual(config.provenance_name("b.csv"), "b.csv") + self.assertEqual(config.provenance_name(self.make_file("a.csv")), + "Provenance21 Name") + self.assertEqual(config.provenance_name(self.make_file("b.csv")), "b.csv") def test_import_type(self): config = Config(CONFIG_DATA) - self.assertEqual(config.import_type("a.csv"), ImportType.OBSERVATIONS, - "default import type") - self.assertEqual(config.import_type("observations.csv"), + self.assertEqual(config.import_type(self.make_file("a.csv")), + ImportType.OBSERVATIONS, "default import type") + self.assertEqual(config.import_type(self.make_file("observations.csv")), 
ImportType.OBSERVATIONS, "observations import type") - self.assertEqual(config.import_type("events.csv"), ImportType.EVENTS, - "events import type") + self.assertEqual(config.import_type(self.make_file("events.csv")), + ImportType.EVENTS, "events import type") with self.assertRaisesRegex(ValueError, "Unsupported import type"): - config.import_type("invalid_import_type.csv") + config.import_type(self.make_file("invalid_import_type.csv")) def test_aggregation(self): config = Config(CONFIG_DATA) @@ -198,11 +213,11 @@ def test_aggregation(self): def test_empty_config(self): config = Config({}) self.assertEqual(config.variable("Variable 1"), StatVar("", "Variable 1")) - self.assertEqual(config.entity_type("a.csv"), "") - self.assertEqual(config.ignore_columns("a.csv"), []) + self.assertEqual(config.entity_type(self.make_file("a.csv")), "") + self.assertEqual(config.ignore_columns(self.make_file("a.csv")), []) self.assertDictEqual(config.provenances, {}) self.assertDictEqual(config.provenance_sources, {}) - self.assertEqual(config.provenance_name("a.csv"), "a.csv") + self.assertEqual(config.provenance_name(self.make_file("a.csv")), "a.csv") def test_data_download_urls(self): self.assertListEqual(Config({}).data_download_urls(), [], "empty") @@ -216,7 +231,8 @@ def test_data_download_urls(self): }).data_download_urls(), ["foo", "bar"], "two urls") def test_input_file(self): - self.assertDictEqual(Config({})._input_file("foo.csv"), {}, "empty") + self.assertDictEqual( + Config({})._per_file_config(self.make_file("foo.csv")), {}, "empty") self.assertDictEqual( Config({ "inputFiles": { @@ -224,7 +240,43 @@ def test_input_file(self): "x": "y" } } - })._input_file("foo.csv"), {"x": "y"}, "exact match") + })._per_file_config(self.make_file("foo.csv")), {"x": "y"}, + "exact match") + self.assertDictEqual( + Config({ + "inputFiles": { + "foo.csv": { + "x": "y" + } + } + })._per_file_config(self.make_file("path/to/foo.csv")), {"x": "y"}, + "subdir match") + self.assertDictEqual( + Config({ + "inputFiles": { + "path/to/*.csv": { + "x": "y" + } + } + })._per_file_config(self.make_file("path/to/foo.csv")), {"x": "y"}, + "subdir match with path") + self.assertDictEqual( + Config({ + "inputFiles": { + "path/to/*.csv": { + "x": "y" + } + } + })._per_file_config(self.make_file("foo.csv")), {}, "different subdir") + self.assertDictEqual( + Config({ + "inputFiles": { + "//to/*.csv": { + "x": "y" + } + } + })._per_file_config(self.make_file("path/to/foo.csv")), {}, + "wrong subdir") self.assertDictEqual( Config({ "inputFiles": { @@ -232,7 +284,8 @@ def test_input_file(self): "x": "y" } } - })._input_file("foo1.csv"), {"x": "y"}, "wildcard match") + })._per_file_config(self.make_file("foo1.csv")), {"x": "y"}, + "wildcard match") self.assertDictEqual( Config({ "inputFiles": { @@ -240,7 +293,7 @@ def test_input_file(self): "x": "y" } } - })._input_file("foo.csv"), {}, "no exact match") + })._per_file_config(self.make_file("foo.csv")), {}, "no exact match") self.assertDictEqual( Config({ "inputFiles": { @@ -248,14 +301,16 @@ def test_input_file(self): "x": "y" } } - })._input_file("foo1.csv"), {}, "no wildcard match") + })._per_file_config(self.make_file("foo1.csv")), {}, + "no wildcard match") def test_input_file_format(self): config = Config({}) - self.assertEqual(config.format("foo.csv"), None, "empty") + self.assertEqual(config.format(self.make_file("foo.csv")), None, "empty") config = Config({"inputFiles": {"foo.csv": {"format": "variablePerRow"}}}) - self.assertEqual(config.format("foo.csv"), 
InputFileFormat.VARIABLE_PER_ROW) + self.assertEqual(config.format(self.make_file("foo.csv")), + InputFileFormat.VARIABLE_PER_ROW) config = Config( {"inputFiles": { @@ -263,33 +318,38 @@ def test_input_file_format(self): "format": "variablePerColumn" } }}) - self.assertEqual(config.format("foo.csv"), + self.assertEqual(config.format(self.make_file("foo.csv")), InputFileFormat.VARIABLE_PER_COLUMN) config = Config({"inputFiles": {"foo.csv": {"format": "INVALID"}}}) with self.assertRaisesRegex(ValueError, "Unsupported format"): - config.format("foo.csv") + config.format(self.make_file("foo.csv")) def test_column_mappings(self): config = Config({}) - self.assertDictEqual(config.column_mappings("foo.csv"), {}, "empty") + self.assertDictEqual(config.column_mappings(self.make_file("foo.csv")), {}, + "empty") config = Config({"inputFiles": {"foo.csv": {"columnMappings": {"x": "y"}}}}) - self.assertDictEqual(config.column_mappings("foo.csv"), {"x": "y"}) + self.assertDictEqual(config.column_mappings(self.make_file("foo.csv")), + {"x": "y"}) def test_row_entity_type(self): config = Config({}) - self.assertEqual(config.row_entity_type("foo.csv"), "", "empty") + self.assertEqual(config.row_entity_type(self.make_file("foo.csv")), "", + "empty") config = Config({"inputFiles": {"foo.csv": {"rowEntityType": "Foo"}}}) - self.assertEqual(config.row_entity_type("foo.csv"), "Foo") + self.assertEqual(config.row_entity_type(self.make_file("foo.csv")), "Foo") config = Config({"inputFiles": {"foo.csv": {}}}) - self.assertEqual(config.row_entity_type("foo.csv"), "", "unspecified") + self.assertEqual(config.row_entity_type(self.make_file("foo.csv")), "", + "unspecified") def test_entity_columns(self): config = Config({}) - self.assertListEqual(config.entity_columns("foo.csv"), [], "empty") + self.assertListEqual(config.entity_columns(self.make_file("foo.csv")), [], + "empty") config = Config( {"inputFiles": { @@ -297,7 +357,19 @@ def test_entity_columns(self): "entityColumns": ["foo", "bar"] } }}) - self.assertListEqual(config.entity_columns("foo.csv"), ["foo", "bar"]) + self.assertListEqual(config.entity_columns(self.make_file("foo.csv")), + ["foo", "bar"]) config = Config({"inputFiles": {"foo.csv": {}}}) - self.assertListEqual(config.entity_columns("foo.csv"), [], "unspecified") + self.assertListEqual(config.entity_columns(self.make_file("foo.csv")), [], + "unspecified") + + def test_include_input_subdirs(self): + config = Config({}) + self.assertFalse(config.include_input_subdirs()) + + config = Config({"includeInputSubdirs": False}) + self.assertFalse(config.include_input_subdirs()) + + config = Config({"includeInputSubdirs": True}) + self.assertTrue(config.include_input_subdirs()) diff --git a/simple/tests/stats/db_test.py b/simple/tests/stats/db_test.py index b44e2ceb..d317249b 100644 --- a/simple/tests/stats/db_test.py +++ b/simple/tests/stats/db_test.py @@ -27,11 +27,12 @@ from stats.db import create_main_dc_config from stats.db import create_sqlite_config from stats.db import get_cloud_sql_config_from_env -from stats.db import get_sqlite_config_from_env +from stats.db import get_sqlite_path_from_env from stats.db import ImportStatus from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from tests.stats.test_util import read_full_db_from_file +from util.filesystem import create_store _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data", "db") @@ -118,10 +119,14 @@ def test_sql_db(self): Compares resulting DB contents with expected 
values. """ with tempfile.TemporaryDirectory() as temp_dir: - db_file_path = os.path.join(temp_dir, "datacommons.db") - db = create_and_update_db(create_sqlite_config(db_file_path)) + temp_store = create_store(temp_dir) + db_file_name = "datacommons.db" + db_file_path = os.path.join(temp_dir, db_file_name) + db_file = temp_store.as_dir().open_file(db_file_name) + db = create_and_update_db(create_sqlite_config(db_file)) db.insert_triples(_TRIPLES) - db.insert_observations(_OBSERVATIONS, "foo.csv") + foo_file = temp_store.as_dir().open_file("foo.csv") + db.insert_observations(_OBSERVATIONS, foo_file) db.insert_key_value(_KEY_VALUE[0], _KEY_VALUE[1]) db.insert_import_info(status=ImportStatus.SUCCESS) @@ -151,10 +156,13 @@ def test_sql_db_schema_update(self): without modifying existing data. """ with tempfile.TemporaryDirectory() as temp_dir: - db_file_path = os.path.join(temp_dir, "datacommons.db") + temp_store = create_store(temp_dir) + db_file_name = "datacommons.db" + db_file_path = os.path.join(temp_dir, db_file_name) + db_file = temp_store.as_dir().open_file(db_file_name) self._seed_db_from_input(db_file_path, "sqlite_old_schema_populated.sql") - db = create_and_update_db(create_sqlite_config(db_file_path)) + db = create_and_update_db(create_sqlite_config(db_file)) db.commit_and_close() self._verify_db_contents( @@ -172,15 +180,19 @@ def test_sql_db_reimport_with_schema_update(self): database has an old schema. """ with tempfile.TemporaryDirectory() as temp_dir: - db_file_path = os.path.join(temp_dir, "datacommons.db") + temp_store = create_store(temp_dir) + db_file_name = "datacommons.db" + db_file_path = os.path.join(temp_dir, db_file_name) + db_file = temp_store.as_dir().open_file(db_file_name) self._seed_db_from_input(db_file_path, "sqlite_old_schema_populated.sql") - db = create_and_update_db(create_sqlite_config(db_file_path)) + db = create_and_update_db(create_sqlite_config(db_file)) db.maybe_clear_before_import() db.insert_triples(_TRIPLES) - db.insert_observations(_OBSERVATIONS, "foo.csv") + foo_file = temp_store.as_dir().open_file("foo.csv") + db.insert_observations(_OBSERVATIONS, foo_file) db.insert_key_value(_KEY_VALUE[0], _KEY_VALUE[1]) db.insert_import_info(status=ImportStatus.SUCCESS) @@ -201,16 +213,20 @@ def test_sql_db_reimport_without_schema_update(self): all tables except the imports table. """ with tempfile.TemporaryDirectory() as temp_dir: - db_file_path = os.path.join(temp_dir, "datacommons.db") + temp_store = create_store(temp_dir) + db_file_name = "datacommons.db" + db_file_path = os.path.join(temp_dir, db_file_name) + db_file = temp_store.as_dir().open_file(db_file_name) self._seed_db_from_input(db_file_path, "sqlite_current_schema_populated.sql") - db = create_and_update_db(create_sqlite_config(db_file_path)) + db = create_and_update_db(create_sqlite_config(db_file)) db.maybe_clear_before_import() db.insert_triples(_TRIPLES) - db.insert_observations(_OBSERVATIONS, "foo.csv") + foo_file = temp_store.as_dir().open_file("foo.csv") + db.insert_observations(_OBSERVATIONS, foo_file) db.insert_key_value(_KEY_VALUE[0], _KEY_VALUE[1]) db.insert_import_info(status=ImportStatus.SUCCESS) @@ -233,7 +249,9 @@ def test_main_dc_db(self): In write mode, replaces the goldens instead. 
""" with tempfile.TemporaryDirectory() as temp_dir: - observations_file = os.path.join(temp_dir, "observations.csv") + temp_store = create_store(temp_dir) + observations_file_name = "observations.csv" + observations_file_path = os.path.join(temp_dir, observations_file_name) expected_observations_file = os.path.join(_EXPECTED_DIR, "observations.csv") tmcf_file = os.path.join(temp_dir, "observations.tmcf") @@ -241,19 +259,20 @@ def test_main_dc_db(self): mcf_file = os.path.join(temp_dir, "schema.mcf") expected_mcf_file = os.path.join(_EXPECTED_DIR, "schema.mcf") - db = create_and_update_db(create_main_dc_config(temp_dir)) + db = create_and_update_db(create_main_dc_config(temp_store.as_dir())) db.insert_triples(_TRIPLES) - db.insert_observations(_OBSERVATIONS, "observations.csv") + observations_file = temp_store.as_dir().open_file(observations_file_name) + db.insert_observations(_OBSERVATIONS, observations_file) db.insert_import_info(status=ImportStatus.SUCCESS) db.commit_and_close() if is_write_mode(): - shutil.copy(observations_file, expected_observations_file) + shutil.copy(observations_file_path, expected_observations_file) shutil.copy(tmcf_file, expected_tmcf_file) shutil.copy(mcf_file, expected_mcf_file) return - compare_files(self, observations_file, expected_observations_file) + compare_files(self, observations_file_path, expected_observations_file) compare_files(self, tmcf_file, expected_tmcf_file) compare_files(self, mcf_file, expected_mcf_file) @@ -291,14 +310,9 @@ def test_get_cloud_sql_config_from_env_invalid(self): get_cloud_sql_config_from_env() @mock.patch.dict(os.environ, {}) - def test_get_sqlite_config_from_env_empty(self): - self.assertIsNone(get_sqlite_config_from_env()) + def test_get_sqlite_path_from_env_empty(self): + self.assertIsNone(get_sqlite_path_from_env()) @mock.patch.dict(os.environ, {"SQLITE_PATH": "/path/datacommons.db"}) - def test_get_sqlite_config_from_env(self): - self.assertDictEqual(get_sqlite_config_from_env(), { - "type": "sqlite", - "params": { - "dbFilePath": "/path/datacommons.db" - } - }) + def test_get_sqlite_path_from_env(self): + self.assertEqual(get_sqlite_path_from_env(), "/path/datacommons.db") diff --git a/simple/tests/stats/entities_importer_test.py b/simple/tests/stats/entities_importer_test.py index 454a9b42..b4a567ff 100644 --- a/simple/tests/stats/entities_importer_test.py +++ b/simple/tests/stats/entities_importer_test.py @@ -32,7 +32,7 @@ from tests.stats.test_util import is_write_mode from tests.stats.test_util import use_fake_gzip_time from tests.stats.test_util import write_triples -from util.filehandler import LocalFileHandler +from util.filesystem import create_store _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data", "entities_importer") @@ -46,26 +46,33 @@ def _test_import(test: unittest.TestCase, test_name: str): test.maxDiff = None with tempfile.TemporaryDirectory() as temp_dir: - input_file = f"{test_name}.csv" - input_path = os.path.join(_INPUT_DIR, input_file) - input_config_path = os.path.join(_INPUT_DIR, "config.json") - db_path = os.path.join(temp_dir, f"{test_name}.db") + input_store = create_store(_INPUT_DIR) + temp_store = create_store(temp_dir) + + input_file_name = f"{test_name}.csv" + input_file = input_store.as_dir().open_file(input_file_name, + create_if_missing=False) + input_config_file = input_store.as_dir().open_file("config.json", + create_if_missing=False) + db_file_name = f"{test_name}.db" + db_path = os.path.join(temp_dir, db_file_name) + db_file = 
temp_store.as_dir().open_file(db_file_name) output_triples_path = os.path.join(temp_dir, f"{test_name}.triples.db.csv") expected_triples_path = os.path.join(_EXPECTED_DIR, f"{test_name}.triples.db.csv") - input_fh = LocalFileHandler(input_path) - - input_config_fh = LocalFileHandler(input_config_path) - config = Config(data=json.loads(input_config_fh.read_string())) + config = Config(data=json.loads(input_config_file.read())) nodes = Nodes(config) - db = create_and_update_db(create_sqlite_config(db_path)) - report_fh = LocalFileHandler(os.path.join(temp_dir, "report.json")) - reporter = FileImportReporter(input_path, ImportReporter(report_fh)) + db = create_and_update_db(create_sqlite_config(db_file)) + report_file = temp_store.as_dir().open_file("report.json") + reporter = FileImportReporter(input_file.full_path(), + ImportReporter(report_file)) - EntitiesImporter(input_fh=input_fh, db=db, reporter=reporter, + EntitiesImporter(input_file=input_file, + db=db, + reporter=reporter, nodes=nodes).do_import() db.insert_triples(nodes.triples()) db.commit_and_close() @@ -78,6 +85,9 @@ def _test_import(test: unittest.TestCase, test_name: str): compare_files(test, output_triples_path, expected_triples_path) + input_store.close() + temp_store.close() + class TestEntitiesImporter(unittest.TestCase): diff --git a/simple/tests/stats/events_importer_test.py b/simple/tests/stats/events_importer_test.py index 2451df3e..aa0b4585 100644 --- a/simple/tests/stats/events_importer_test.py +++ b/simple/tests/stats/events_importer_test.py @@ -15,14 +15,10 @@ import json import os import shutil -import sqlite3 import tempfile import unittest -import pandas as pd from stats.config import Config -from stats.data import Observation -from stats.data import Triple from stats.db import create_and_update_db from stats.db import create_sqlite_config from stats.events_importer import EventsImporter @@ -33,7 +29,7 @@ from tests.stats.test_util import is_write_mode from tests.stats.test_util import write_observations from tests.stats.test_util import write_triples -from util.filehandler import LocalFileHandler +from util.filesystem import create_store _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data", "events_importer") @@ -45,10 +41,17 @@ def _test_import(test: unittest.TestCase, test_name: str): test.maxDiff = None with tempfile.TemporaryDirectory() as temp_dir: - input_file = f"{test_name}.csv" - input_path = os.path.join(_INPUT_DIR, input_file) - input_config_path = os.path.join(_INPUT_DIR, "config.json") - db_path = os.path.join(temp_dir, f"{test_name}.db") + input_store = create_store(_INPUT_DIR) + temp_store = create_store(temp_dir) + + input_file_name = f"{test_name}.csv" + input_file = input_store.as_dir().open_file(input_file_name, + create_if_missing=False) + input_config_file = input_store.as_dir().open_file("config.json", + create_if_missing=False) + db_file_name = f"{test_name}.db" + db_path = os.path.join(temp_dir, db_file_name) + db_file = temp_store.as_dir().open_file(db_file_name) output_triples_path = os.path.join(temp_dir, f"{test_name}.triples.db.csv") expected_triples_path = os.path.join(_EXPECTED_DIR, @@ -58,20 +61,18 @@ def _test_import(test: unittest.TestCase, test_name: str): expected_observations_path = os.path.join( _EXPECTED_DIR, f"{test_name}.observations.db.csv") - input_fh = LocalFileHandler(input_path) - - input_config_fh = LocalFileHandler(input_config_path) - config = Config(data=json.loads(input_config_fh.read_string())) + config = 
Config(data=json.loads(input_config_file.read())) nodes = Nodes(config) - db = create_and_update_db(create_sqlite_config(db_path)) - debug_resolve_fh = LocalFileHandler(os.path.join(temp_dir, "debug.csv")) - report_fh = LocalFileHandler(os.path.join(temp_dir, "report.json")) - reporter = FileImportReporter(input_path, ImportReporter(report_fh)) + db = create_and_update_db(create_sqlite_config(db_file)) + debug_resolve_file = temp_store.as_dir().open_file("debug.csv") + report_file = temp_store.as_dir().open_file("report.json") + reporter = FileImportReporter(input_file.full_path(), + ImportReporter(report_file)) - EventsImporter(input_fh=input_fh, + EventsImporter(input_file=input_file, db=db, - debug_resolve_fh=debug_resolve_fh, + debug_resolve_file=debug_resolve_file, reporter=reporter, nodes=nodes).do_import() db.insert_triples(nodes.triples()) @@ -88,6 +89,9 @@ def _test_import(test: unittest.TestCase, test_name: str): compare_files(test, output_triples_path, expected_triples_path) compare_files(test, output_observations_path, expected_observations_path) + input_store.close() + temp_store.close() + class TestEventsImporter(unittest.TestCase): diff --git a/simple/tests/stats/mcf_importer_test.py b/simple/tests/stats/mcf_importer_test.py index 309f1fd8..236e36ac 100644 --- a/simple/tests/stats/mcf_importer_test.py +++ b/simple/tests/stats/mcf_importer_test.py @@ -12,26 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os import shutil -import sqlite3 import tempfile import unittest -import pandas as pd -from stats.config import Config -from stats.data import Triple from stats.db import create_and_update_db from stats.db import create_sqlite_config from stats.mcf_importer import McfImporter -from stats.nodes import Nodes from stats.reporter import FileImportReporter from stats.reporter import ImportReporter from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from tests.stats.test_util import write_triples -from util.filehandler import LocalFileHandler +from util.filesystem import create_store _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data", "mcf_importer") @@ -46,9 +40,16 @@ def _test_import(test: unittest.TestCase, test.maxDiff = None with tempfile.TemporaryDirectory() as temp_dir: - input_file = f"{test_name}.mcf" - input_mcf_path = os.path.join(_INPUT_DIR, input_file) - db_path = os.path.join(temp_dir, f"{test_name}.db") + input_store = create_store(_INPUT_DIR) + temp_store = create_store(temp_dir) + + input_file_name = f"{test_name}.mcf" + input_file = input_store.as_dir().open_file(input_file_name, + create_if_missing=False) + + db_file_name = f"{test_name}.db" + db_path = os.path.join(temp_dir, db_file_name) + db_file = temp_store.as_dir().open_file(db_file_name) output_mcf_path = os.path.join(temp_dir, f"{test_name}.mcf") output_triples_path = os.path.join(temp_dir, f"{test_name}.triples.db.csv") @@ -56,15 +57,17 @@ def _test_import(test: unittest.TestCase, expected_triples_path = os.path.join(_EXPECTED_DIR, f"{test_name}.triples.db.csv") - input_fh = LocalFileHandler(input_mcf_path) - output_fh = LocalFileHandler(output_mcf_path) - - db = create_and_update_db(create_sqlite_config(db_path)) - report_fh = LocalFileHandler(os.path.join(temp_dir, "report.json")) - reporter = FileImportReporter(input_mcf_path, ImportReporter(report_fh)) - - importer = McfImporter(input_fh=input_fh, - output_fh=output_fh, + db = 
create_and_update_db(create_sqlite_config(db_file)) + report_file = temp_store.as_dir().open_file("report.json") + reporter = FileImportReporter(input_file.full_path(), + ImportReporter(report_file)) + + output_store = create_store(output_mcf_path, + create_if_missing=True, + treat_as_file=True) + output_file = output_store.as_file() + importer = McfImporter(input_file=input_file, + output_file=output_file, db=db, reporter=reporter, is_main_dc=is_main_dc) @@ -92,6 +95,10 @@ def _test_import(test: unittest.TestCase, compare_files(test, output_mcf_path, expected_mcf_path) + input_store.close() + output_store.close() + temp_store.close() + class TestMcfImporter(unittest.TestCase): diff --git a/simple/tests/stats/nl_test.py b/simple/tests/stats/nl_test.py index 2497abbf..6fd29454 100644 --- a/simple/tests/stats/nl_test.py +++ b/simple/tests/stats/nl_test.py @@ -23,7 +23,7 @@ from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from tests.stats.test_util import read_triples_csv -from util.filehandler import LocalFileHandler +from util.filesystem import create_store _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data", "nl") @@ -41,10 +41,11 @@ def _rewrite_catalog_for_testing(catalog_yaml_path: str, temp_dir: str) -> None: To consistently test the catalog out against a golden file, we replace the temp paths with a constant fake path. """ - catalog_fh = LocalFileHandler(catalog_yaml_path) - content = catalog_fh.read_string() - content = content.replace(temp_dir, _FAKE_PATH) - catalog_fh.write_string(content) + with create_store(catalog_yaml_path) as store: + catalog_file = store.as_file() + content = catalog_file.read() + content = content.replace(temp_dir, _FAKE_PATH) + catalog_file.write(content) def _test_generate_nl_sentences(test: unittest.TestCase, @@ -53,6 +54,7 @@ def _test_generate_nl_sentences(test: unittest.TestCase, test.maxDiff = None with tempfile.TemporaryDirectory() as temp_dir: + temp_store = create_store(temp_dir) input_triples_path = os.path.join(_INPUT_DIR, f"{test_name}.csv") input_triples = read_triples_csv(input_triples_path) @@ -70,15 +72,14 @@ def _test_generate_nl_sentences(test: unittest.TestCase, expected_topic_cache_json_path = os.path.join(_EXPECTED_DIR, test_name, "custom_dc_topic_cache.json") - nl_dir_fh = LocalFileHandler(temp_dir) - # Sentences are not generated for StatVarPeerGroup triples. # So remove them first. 
- nl.generate_nl_sentences(_without_svpg_triples(input_triples), nl_dir_fh) + nl.generate_nl_sentences(_without_svpg_triples(input_triples), + nl_dir=temp_store.as_dir()) _rewrite_catalog_for_testing(output_catalog_yaml_path, temp_dir) if generate_topics: - nl.generate_topic_cache(input_triples, nl_dir_fh) + nl.generate_topic_cache(input_triples, nl_dir=temp_store.as_dir()) if is_write_mode(): shutil.copy(output_sentences_csv_path, expected_sentences_csv_path) @@ -94,6 +95,8 @@ def _test_generate_nl_sentences(test: unittest.TestCase, compare_files(test, output_topic_cache_json_path, expected_topic_cache_json_path) + temp_store.close() + def _without_svpg_triples(triples: list[Triple]) -> list[Triple]: svpg_dcids = set() diff --git a/simple/tests/stats/nodes_test.py b/simple/tests/stats/nodes_test.py index 580db182..d79d6c03 100644 --- a/simple/tests/stats/nodes_test.py +++ b/simple/tests/stats/nodes_test.py @@ -25,7 +25,8 @@ from stats.nodes import Nodes from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode -from util.filehandler import LocalFileHandler +from util.filesystem import create_store +from util.filesystem import File _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data", "nodes") @@ -113,19 +114,40 @@ class TestNodes(unittest.TestCase): + def make_file(self, path: str) -> File: + return self.store.as_dir().open_file(path) + + def setUp(self): + self.store = create_store("mem://") + self.a = self.make_file("a.csv") + self.b = self.make_file("b.csv") + self.events = self.make_file("events.csv") + self.entities = self.make_file("entities.csv") + self.x = self.make_file("x.csv") + self.map = { + "a.csv": self.a, + "b.csv": self.b, + "events.csv": self.events, + "entities.csv": self.entities, + "x.csv": self.x, + } + + def tearDown(self): + self.store.close() + def test_triples(self): nodes = Nodes(CONFIG) for sv_column_name, input_file_name in TEST_SV_COLUMN_AND_INPUT_FILE_NAMES: - nodes.variable(sv_column_name, input_file_name) + nodes.variable(sv_column_name, self.map[input_file_name]) nodes.entities_with_type(TEST_ENTITY_DCIDS_1, TEST_ENTITY_TYPE_1) nodes.entities_with_type(TEST_ENTITY_DCIDS_2, TEST_ENTITY_TYPE_2) for event_type_name, input_file_name in TEST_EVENT_TYPE_INPUT_FILE_NAMES: - nodes.event_type(event_type_name, input_file_name) + nodes.event_type(event_type_name, self.map[input_file_name]) for entity_type_name, input_file_name in TEST_ENTITY_TYPE_INPUT_FILE_NAMES: - nodes.entity_type(entity_type_name, input_file_name) + nodes.entity_type(entity_type_name, self.map[input_file_name]) for property_column_name in TEST_PROPERTY_COLUMNS: nodes.property(property_column_name) @@ -133,7 +155,9 @@ def test_triples(self): with tempfile.TemporaryDirectory() as temp_dir: output_path = os.path.join(temp_dir, f"triples.csv") expected_path = os.path.join(_EXPECTED_DIR, f"triples.csv") - nodes.triples(LocalFileHandler(output_path)) + nodes.triples( + create_store(output_path, create_if_missing=True, + treat_as_file=True).as_file()) if is_write_mode(): shutil.copy(output_path, expected_path) @@ -143,7 +167,7 @@ def test_triples(self): def test_variable_with_no_config(self): nodes = Nodes(CONFIG) - sv = nodes.variable("Variable with no config", "a.csv") + sv = nodes.variable("Variable with no config", self.a) self.assertEqual( sv, StatVar( @@ -157,7 +181,7 @@ def test_variable_with_no_config(self): def test_variable_with_config(self): nodes = Nodes(CONFIG) - sv = nodes.variable("var3", "a.csv") + sv = 
nodes.variable("var3", self.a) self.assertEqual( sv, StatVar( @@ -173,7 +197,7 @@ def test_variable_with_config(self): def test_variable_with_group(self): nodes = Nodes(CONFIG) - sv = nodes.variable("Variable 1", "a.csv") + sv = nodes.variable("Variable 1", self.a) self.assertEqual( sv, StatVar( @@ -206,7 +230,7 @@ def test_variable_with_group(self): def test_multiple_variables_in_same_group(self): nodes = Nodes(CONFIG) - sv = nodes.variable("Variable 1", "a.csv") + sv = nodes.variable("Variable 1", self.a) self.assertEqual( sv, StatVar( @@ -217,7 +241,7 @@ def test_multiple_variables_in_same_group(self): source_ids=["c/s/1"], ), ) - sv = nodes.variable("Variable 2", "a.csv") + sv = nodes.variable("Variable 2", self.a) self.assertEqual( sv, StatVar( @@ -250,17 +274,17 @@ def test_multiple_variables_in_same_group(self): def test_provenance(self): nodes = Nodes(CONFIG) - nodes.variable("Variable 1", "a.csv") - nodes.variable("Variable X", "x.csv") + nodes.variable("Variable 1", self.a) + nodes.variable("Variable X", self.x) self.assertEqual( - nodes.provenance("a.csv"), + nodes.provenance(self.a), Provenance(id="c/p/1", source_id="c/s/1", name="Provenance1", url="http://source1.com/provenance1")) self.assertEqual( - nodes.provenance("x.csv"), + nodes.provenance(self.x), Provenance(id="c/p/default", source_id="c/s/default", name="Custom Import", @@ -268,7 +292,7 @@ def test_provenance(self): def test_multiple_parent_groups(self): """This is to test a bug fix related to groups. - + The bug was that if there are multiple custom parent groups and if a variable is inserted inbetween, the second parent is put under custom/g/Root instead of dc/g/Root. @@ -277,7 +301,7 @@ def test_multiple_parent_groups(self): """ nodes = Nodes(Config({})) nodes.group("Parent 1/Child 1") - nodes.variable("foo", "x.csv") + nodes.variable("foo", self.x) nodes.group("Parent 2/Child 1") self.assertEqual(nodes.groups["Parent 1"].parent_id, "dc/g/Root") diff --git a/simple/tests/stats/observations_importer_test.py b/simple/tests/stats/observations_importer_test.py index c40207c6..037af453 100644 --- a/simple/tests/stats/observations_importer_test.py +++ b/simple/tests/stats/observations_importer_test.py @@ -31,7 +31,7 @@ from tests.stats.test_util import is_write_mode from tests.stats.test_util import use_fake_gzip_time from tests.stats.test_util import write_observations -from util.filehandler import LocalFileHandler +from util.filesystem import create_store from util import dc_client @@ -47,34 +47,41 @@ def _test_import(test: unittest.TestCase, test_name: str): test.maxDiff = None with tempfile.TemporaryDirectory() as temp_dir: + input_store = create_store(_INPUT_DIR) + temp_store = create_store(temp_dir) + input_dir = os.path.join(_INPUT_DIR, test_name) expected_dir = os.path.join(_EXPECTED_DIR, test_name) - input_path = os.path.join(input_dir, "input.csv") + input_file_name = "input.csv" + input_path = os.path.join(input_dir, input_file_name) config_path = os.path.join(input_dir, "config.json") - db_path = os.path.join(temp_dir, f"{test_name}.db") + db_file_name = f"{test_name}.db" + db_path = os.path.join(temp_dir, db_file_name) + db_file = temp_store.as_dir().open_file(db_file_name) output_path = os.path.join(temp_dir, f"{test_name}.db.csv") expected_path = os.path.join(_EXPECTED_DIR, f"{test_name}.db.csv") output_path = os.path.join(temp_dir, "observations.db.csv") expected_path = os.path.join(expected_dir, "observations.db.csv") - input_fh = LocalFileHandler(input_path) + input_file = 
input_store.as_dir().open_dir(test_name).open_file( + input_file_name) with open(config_path) as config_file: config = Config(json.load(config_file)) - db = create_and_update_db(create_sqlite_config(db_path)) - debug_resolve_fh = LocalFileHandler(os.path.join(temp_dir, "debug.csv")) - report_fh = LocalFileHandler(os.path.join(temp_dir, "report.json")) - reporter = FileImportReporter(input_path, ImportReporter(report_fh)) + db = create_and_update_db(create_sqlite_config(db_file)) + debug_resolve_file = temp_store.as_dir().open_file("debug.csv") + report_file = temp_store.as_dir().open_file("report.json") + reporter = FileImportReporter(input_path, ImportReporter(report_file)) nodes = Nodes(config) dc_client.get_property_of_entities = MagicMock(return_value={}) - ObservationsImporter(input_fh=input_fh, + ObservationsImporter(input_file=input_file, db=db, - debug_resolve_fh=debug_resolve_fh, + debug_resolve_file=debug_resolve_file, reporter=reporter, nodes=nodes).do_import() db.commit_and_close() @@ -87,6 +94,9 @@ def _test_import(test: unittest.TestCase, test_name: str): compare_files(test, output_path, expected_path) + input_store.close() + temp_store.close() + class TestObservationsImporter(unittest.TestCase): diff --git a/simple/tests/stats/runner_test.py b/simple/tests/stats/runner_test.py index b39f4ccf..a4c521c2 100644 --- a/simple/tests/stats/runner_test.py +++ b/simple/tests/stats/runner_test.py @@ -16,15 +16,11 @@ import os from pathlib import Path import shutil -import sqlite3 import tempfile import unittest -from unittest.mock import MagicMock +from unittest import mock -import pandas as pd from stats import constants -from stats.data import Observation -from stats.data import Triple from stats.runner import RunMode from stats.runner import Runner from tests.stats.test_util import compare_files @@ -46,18 +42,17 @@ def _test_runner(test: unittest.TestCase, test_name: str, - is_config_driven: bool = True, + config_path: str = None, + output_dir_name: str = None, run_mode: RunMode = RunMode.CUSTOM_DC, input_db_file_name: str = None): test.maxDiff = None with tempfile.TemporaryDirectory() as temp_dir: - if is_config_driven: - config_path = os.path.join(_CONFIG_DIR, f"{test_name}.json") + if config_path: input_dir = None remote_entity_types_path = None else: - config_path = None input_dir = os.path.join(_INPUT_DIR, test_name) remote_entity_types_path = os.path.join(input_dir, "remote_entity_types.json") @@ -67,7 +62,8 @@ def _test_runner(test: unittest.TestCase, input_db_file = os.path.join(input_dir, input_db_file_name) read_full_db_from_file(db_path, input_db_file) - expected_dir = os.path.join(_EXPECTED_DIR, test_name) + output_dir_name = output_dir_name if output_dir_name else test_name + expected_dir = os.path.join(_EXPECTED_DIR, output_dir_name) expected_nl_dir = os.path.join(expected_dir, constants.NL_DIR_NAME) Path(expected_nl_dir).mkdir(parents=True, exist_ok=True) @@ -90,15 +86,15 @@ def _test_runner(test: unittest.TestCase, expected_topic_cache_json_path = os.path.join( expected_dir, constants.NL_DIR_NAME, constants.TOPIC_CACHE_FILE_NAME) - dc_client.get_property_of_entities = MagicMock(return_value={}) + dc_client.get_property_of_entities = mock.MagicMock(return_value={}) if remote_entity_types_path and os.path.exists(remote_entity_types_path): with open(remote_entity_types_path, "r") as f: - dc_client.get_property_of_entities = MagicMock( + dc_client.get_property_of_entities = mock.MagicMock( return_value=json.load(f)) - Runner(config_file=config_path, - 
input_dir=input_dir, - output_dir=temp_dir, + Runner(config_file_path=config_path, + input_dir_path=input_dir, + output_dir_path=temp_dir, mode=run_mode).run() write_triples(db_path, output_triples_path) @@ -138,35 +134,52 @@ def __init__(self, methodName: str = "runTest") -> None: use_fake_gzip_time() def test_config_driven(self): - _test_runner(self, "config_driven") + _test_runner(self, + "config_driven", + config_path=os.path.join(_CONFIG_DIR, "config_driven.json")) def test_config_with_wildcards(self): - _test_runner(self, "config_with_wildcards") + _test_runner(self, + "config_with_wildcards", + config_path=os.path.join(_CONFIG_DIR, + "config_with_wildcards.json")) def test_input_dir_driven(self): - _test_runner(self, "input_dir_driven", is_config_driven=False) + _test_runner(self, "input_dir_driven") def test_input_dir_driven_with_existing_old_schema_data(self): _test_runner(self, "input_dir_driven_with_existing_old_schema_data", - is_config_driven=False, input_db_file_name="sqlite_old_schema_populated.sql") def test_generate_svg_hierarchy(self): - _test_runner(self, "generate_svg_hierarchy", is_config_driven=False) + _test_runner(self, "generate_svg_hierarchy") def test_sv_nl_sentences(self): - _test_runner(self, "sv_nl_sentences", is_config_driven=False) + _test_runner(self, "sv_nl_sentences") def test_topic_nl_sentences(self): - _test_runner(self, "topic_nl_sentences", is_config_driven=False) + _test_runner(self, "topic_nl_sentences") def test_remote_entity_types(self): - _test_runner(self, "remote_entity_types", is_config_driven=False) + _test_runner(self, "remote_entity_types") def test_schema_update_only(self): _test_runner(self, "schema_update_only", - is_config_driven=False, run_mode=RunMode.SCHEMA_UPDATE, input_db_file_name="sqlite_old_schema_populated.sql") + + def test_with_subdirs_excluded(self): + _test_runner(self, + "with_subdirs", + config_path=os.path.join(_CONFIG_DIR, + "config_exclude_subdirs.json"), + output_dir_name="with_subdirs_excluded") + + def test_with_subdirs_included(self): + _test_runner(self, + "with_subdirs", + config_path=os.path.join(_CONFIG_DIR, + "config_include_subdirs.json"), + output_dir_name="with_subdirs_included") diff --git a/simple/tests/stats/schema_test.py b/simple/tests/stats/schema_test.py index 45d3060f..c0e9009c 100644 --- a/simple/tests/stats/schema_test.py +++ b/simple/tests/stats/schema_test.py @@ -23,6 +23,7 @@ from stats.data import Triple from stats.db import create_and_update_db from stats.db import create_sqlite_config +from util.filesystem import create_store def _to_triples(dcid2name: dict[str, str]) -> list[Triple]: @@ -91,8 +92,8 @@ def test_get_schema_names(self, desc: str, db_names: dict[str, str], str], input_dcids: list[str], output_names: dict[str, str], mock_dc_client): with tempfile.TemporaryDirectory() as temp_dir: - db_file_path = os.path.join(temp_dir, "datacommons.db") - db = create_and_update_db(create_sqlite_config(db_file_path)) + db_file = create_store(temp_dir).as_dir().open_file("datacommons.db") + db = create_and_update_db(create_sqlite_config(db_file)) db.insert_triples(_to_triples(db_names)) mock_dc_client.return_value = remote_names diff --git a/simple/tests/stats/test_data/runner/config/config_exclude_subdirs.json b/simple/tests/stats/test_data/runner/config/config_exclude_subdirs.json new file mode 100644 index 00000000..555a8fc8 --- /dev/null +++ b/simple/tests/stats/test_data/runner/config/config_exclude_subdirs.json @@ -0,0 +1,48 @@ +{ + "includeInputSubdirs": false, + "dataDownloadUrl": [ + 
"tests/stats/test_data/runner/input/with_subdirs" + ], + "inputFiles": { + "countries.csv": { + "importType": "observations", + "format": "variablePerColumn", + "entityType": "Country", + "provenance": "Provenance1 Name" + }, + "wikidataids.csv": { + "importType": "observations", + "format": "variablePerColumn", + "entityType": "Country", + "provenance": "Provenance1 Name" + }, + "variable_per_row.csv": { + "importType": "observations", + "format": "variablePerRow", + "entityType": "Country", + "provenance": "Provenance1 Name" + }, + "author_entities.csv": { + "importType": "entities", + "rowEntityType": "Author", + "idColumn": "author_id", + "entityColumns": ["author_country"], + "provenance": "Provenance1 Name" + }, + "article_entities.csv": { + "importType": "entities", + "rowEntityType": "Article", + "idColumn": "article_id", + "entityColumns": ["article_author"], + "provenance": "Provenance1 Name" + } + }, + "sources": { + "Source1 Name": { + "url": "http://source1.com", + "provenances": { + "Provenance1 Name": "http://source1.com/provenance1" + } + } + } +} diff --git a/simple/tests/stats/test_data/runner/config/config_include_subdirs.json b/simple/tests/stats/test_data/runner/config/config_include_subdirs.json new file mode 100644 index 00000000..966b9ed2 --- /dev/null +++ b/simple/tests/stats/test_data/runner/config/config_include_subdirs.json @@ -0,0 +1,48 @@ +{ + "includeInputSubdirs": true, + "dataDownloadUrl": [ + "tests/stats/test_data/runner/input/with_subdirs" + ], + "inputFiles": { + "countries.csv": { + "importType": "observations", + "format": "variablePerColumn", + "entityType": "Country", + "provenance": "Provenance1 Name" + }, + "wikidataids.csv": { + "importType": "observations", + "format": "variablePerColumn", + "entityType": "Country", + "provenance": "Provenance1 Name" + }, + "variable_per_row.csv": { + "importType": "observations", + "format": "variablePerRow", + "entityType": "Country", + "provenance": "Provenance1 Name" + }, + "author_entities.csv": { + "importType": "entities", + "rowEntityType": "Author", + "idColumn": "author_id", + "entityColumns": ["author_country"], + "provenance": "Provenance1 Name" + }, + "article_entities.csv": { + "importType": "entities", + "rowEntityType": "Article", + "idColumn": "article_id", + "entityColumns": ["article_author"], + "provenance": "Provenance1 Name" + } + }, + "sources": { + "Source1 Name": { + "url": "http://source1.com", + "provenances": { + "Provenance1 Name": "http://source1.com/provenance1" + } + } + } +} diff --git a/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/key_value_store.db.csv b/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/key_value_store.db.csv new file mode 100644 index 00000000..23bb062c --- /dev/null +++ b/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/key_value_store.db.csv @@ -0,0 +1,2 @@ +lookup_key,value +StatVarGroups,H4sIAAAAAAAC/+OS5OJMSdZP1w/Kzy8R4pHi4uKA8bjcEGwhKy4B59LikvxchbDEoszEpJzUYiEhLpayxCJDKTCpBCahYkZgMSOwmBEAlEss8mMAAAA= diff --git a/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/nl/sentences.csv b/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/nl/sentences.csv new file mode 100644 index 00000000..8e9cf978 --- /dev/null +++ b/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/nl/sentences.csv @@ -0,0 +1,4 @@ +dcid,sentence +some_var1,Some Variable 1 Name +var1,var1 +var2,var2 diff --git 
a/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/observations.db.csv b/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/observations.db.csv new file mode 100644 index 00000000..ac0a5c28 --- /dev/null +++ b/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/observations.db.csv @@ -0,0 +1,9 @@ +entity,variable,date,value,provenance,unit,scaling_factor,measurement_method,observation_period,properties +country/IND,var1,2020,0.16,c/p/1,,,,, +country/IND,var2,2020,53,c/p/1,,,,, +country/CHN,var1,2020,0.23,c/p/1,,,,, +country/CHN,var2,2020,67,c/p/1,,,,, +country/USA,var1,2021,555,c/p/1,,,,, +country/IND,var1,2022,321,c/p/1,,,,, +country/USA,var2,2021,666,c/p/1,,,,, +country/IND,var2,2022,123,c/p/1,,,,, diff --git a/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/triples.db.csv b/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/triples.db.csv new file mode 100644 index 00000000..46431fee --- /dev/null +++ b/simple/tests/stats/test_data/runner/expected/with_subdirs_excluded/triples.db.csv @@ -0,0 +1,65 @@ +subject_id,predicate,object_id,object_value +author1,typeOf,Author, +author1,includedIn,c/p/1, +author1,author_id,,author1 +author1,author_name,,Jane Doe +author1,author_country,country/USA, +author2,typeOf,Author, +author2,includedIn,c/p/1, +author2,author_id,,author2 +author2,author_name,,Joe Smith +author2,author_country,country/CAN, +author3,typeOf,Author, +author3,includedIn,c/p/1, +author3,author_id,,author3 +author3,author_name,,Jane Smith +author3,author_country,country/USA, +some_var1,typeOf,StatisticalVariable, +some_var1,measuredProperty,value, +some_var1,name,,Some Variable 1 Name +some_var1,description,,Some Variable 1 Description +c/s/default,typeOf,Source, +c/s/default,name,,Custom Data Commons +c/s/1,typeOf,Source, +c/s/1,name,,Source1 Name +c/s/1,url,,http://source1.com +c/s/1,domain,,source1.com +c/p/default,typeOf,Provenance, +c/p/default,name,,Custom Import +c/p/default,source,c/s/default, +c/p/default,url,,custom-import +c/p/1,typeOf,Provenance, +c/p/1,name,,Provenance1 Name +c/p/1,source,c/s/1, +c/p/1,url,,http://source1.com/provenance1 +c/g/Root,typeOf,StatVarGroup, +c/g/Root,name,,Custom Variables +c/g/Root,specializationOf,dc/g/Root, +var1,typeOf,StatisticalVariable, +var1,name,,var1 +var1,memberOf,c/g/Root, +var1,includedIn,c/p/1, +var1,includedIn,c/s/1, +var1,populationType,Thing, +var1,statType,measuredValue, +var1,measuredProperty,var1, +var2,typeOf,StatisticalVariable, +var2,name,,var2 +var2,memberOf,c/g/Root, +var2,includedIn,c/p/1, +var2,includedIn,c/s/1, +var2,populationType,Thing, +var2,statType,measuredValue, +var2,measuredProperty,var2, +Author,typeOf,Class, +Author,name,,Author +Author,includedIn,c/p/1, +Author,includedIn,c/s/1, +author_id,typeOf,Property, +author_id,name,,author_id +author_name,typeOf,Property, +author_name,name,,author_name +author_country,typeOf,Property, +author_country,name,,author_country +country/USA,typeOf,Country, +country/IND,typeOf,Country, diff --git a/simple/tests/stats/test_data/runner/expected/with_subdirs_included/key_value_store.db.csv b/simple/tests/stats/test_data/runner/expected/with_subdirs_included/key_value_store.db.csv new file mode 100644 index 00000000..23bb062c --- /dev/null +++ b/simple/tests/stats/test_data/runner/expected/with_subdirs_included/key_value_store.db.csv @@ -0,0 +1,2 @@ +lookup_key,value 
+StatVarGroups,H4sIAAAAAAAC/+OS5OJMSdZP1w/Kzy8R4pHi4uKA8bjcEGwhKy4B59LikvxchbDEoszEpJzUYiEhLpayxCJDKTCpBCahYkZgMSOwmBEAlEss8mMAAAA= diff --git a/simple/tests/stats/test_data/runner/expected/with_subdirs_included/nl/sentences.csv b/simple/tests/stats/test_data/runner/expected/with_subdirs_included/nl/sentences.csv new file mode 100644 index 00000000..9e8746c2 --- /dev/null +++ b/simple/tests/stats/test_data/runner/expected/with_subdirs_included/nl/sentences.csv @@ -0,0 +1,5 @@ +dcid,sentence +some_var2,Some Variable 2 Name +some_var1,Some Variable 1 Name +var1,var1 +var2,var2 diff --git a/simple/tests/stats/test_data/runner/expected/with_subdirs_included/observations.db.csv b/simple/tests/stats/test_data/runner/expected/with_subdirs_included/observations.db.csv new file mode 100644 index 00000000..bea81291 --- /dev/null +++ b/simple/tests/stats/test_data/runner/expected/with_subdirs_included/observations.db.csv @@ -0,0 +1,31 @@ +entity,variable,date,value,provenance,unit,scaling_factor,measurement_method,observation_period,properties +country/AFG,var1,2023,0.19,c/p/1,,,,, +country/YEM,var1,2023,0.21,c/p/1,,,,, +country/AGO,var1,2023,0.29,c/p/1,,,,, +country/ZMB,var1,2023,0.31,c/p/1,,,,, +country/ZWE,var1,2023,0.37,c/p/1,,,,, +country/ALB,var1,2023,0.5,c/p/1,,,,, +wikidataId/Q22062741,var1,2023,0.5,c/p/1,,,,, +country/DZA,var1,2023,0.52,c/p/1,,,,, +country/AND,var1,2023,0.76,c/p/1,,,,, +country/AFG,var2,2023,6,c/p/1,,,,, +country/YEM,var2,2023,56,c/p/1,,,,, +country/AGO,var2,2023,6,c/p/1,,,,, +country/ZMB,var2,2023,34,c/p/1,,,,, +country/ZWE,var2,2023,76,c/p/1,,,,, +country/ALB,var2,2023,34,c/p/1,,,,, +wikidataId/Q22062741,var2,2023,97,c/p/1,,,,, +country/DZA,var2,2023,92,c/p/1,,,,, +country/AND,var2,2023,9,c/p/1,,,,, +country/ASM,var2,2023,34,c/p/1,,,,, +country/AIA,var2,2023,42,c/p/1,,,,, +country/WLF,var2,2023,75,c/p/1,,,,, +country/ESH,var2,2023,65,c/p/1,,,,, +country/IND,var1,2020,0.16,c/p/1,,,,, +country/IND,var2,2020,53,c/p/1,,,,, +country/CHN,var1,2020,0.23,c/p/1,,,,, +country/CHN,var2,2020,67,c/p/1,,,,, +country/USA,var1,2021,555,c/p/1,,,,, +country/IND,var1,2022,321,c/p/1,,,,, +country/USA,var2,2021,666,c/p/1,,,,, +country/IND,var2,2022,123,c/p/1,,,,, diff --git a/simple/tests/stats/test_data/runner/expected/with_subdirs_included/triples.db.csv b/simple/tests/stats/test_data/runner/expected/with_subdirs_included/triples.db.csv new file mode 100644 index 00000000..ae7b6dc1 --- /dev/null +++ b/simple/tests/stats/test_data/runner/expected/with_subdirs_included/triples.db.csv @@ -0,0 +1,109 @@ +subject_id,predicate,object_id,object_value +author1,typeOf,Author, +author1,includedIn,c/p/1, +author1,author_id,,author1 +author1,author_name,,Jane Doe +author1,author_country,country/USA, +author2,typeOf,Author, +author2,includedIn,c/p/1, +author2,author_id,,author2 +author2,author_name,,Joe Smith +author2,author_country,country/CAN, +author3,typeOf,Author, +author3,includedIn,c/p/1, +author3,author_id,,author3 +author3,author_name,,Jane Smith +author3,author_country,country/USA, +article1,typeOf,Article, +article1,includedIn,c/p/1, +article1,article_id,,article1 +article1,article_title,,Article 1 +article1,article_author,author1, +article2,typeOf,Article, +article2,includedIn,c/p/1, +article2,article_id,,article2 +article2,article_title,,Article 2 +article2,article_author,author1, +article2,article_author,author2, +article3,typeOf,Article, +article3,includedIn,c/p/1, +article3,article_id,,article3 +article3,article_title,,Article 3 +article3,article_author,author2, 
+article3,article_author,author3,
+some_var2,typeOf,StatisticalVariable,
+some_var2,measuredProperty,value,
+some_var2,name,,Some Variable 2 Name
+some_var2,description,,Some Variable 2 Description
+some_var1,typeOf,StatisticalVariable,
+some_var1,measuredProperty,value,
+some_var1,name,,Some Variable 1 Name
+some_var1,description,,Some Variable 1 Description
+c/s/default,typeOf,Source,
+c/s/default,name,,Custom Data Commons
+c/s/1,typeOf,Source,
+c/s/1,name,,Source1 Name
+c/s/1,url,,http://source1.com
+c/s/1,domain,,source1.com
+c/p/default,typeOf,Provenance,
+c/p/default,name,,Custom Import
+c/p/default,source,c/s/default,
+c/p/default,url,,custom-import
+c/p/1,typeOf,Provenance,
+c/p/1,name,,Provenance1 Name
+c/p/1,source,c/s/1,
+c/p/1,url,,http://source1.com/provenance1
+c/g/Root,typeOf,StatVarGroup,
+c/g/Root,name,,Custom Variables
+c/g/Root,specializationOf,dc/g/Root,
+var1,typeOf,StatisticalVariable,
+var1,name,,var1
+var1,memberOf,c/g/Root,
+var1,includedIn,c/p/1,
+var1,includedIn,c/s/1,
+var1,populationType,Thing,
+var1,statType,measuredValue,
+var1,measuredProperty,var1,
+var2,typeOf,StatisticalVariable,
+var2,name,,var2
+var2,memberOf,c/g/Root,
+var2,includedIn,c/p/1,
+var2,includedIn,c/s/1,
+var2,populationType,Thing,
+var2,statType,measuredValue,
+var2,measuredProperty,var2,
+Author,typeOf,Class,
+Author,name,,Author
+Author,includedIn,c/p/1,
+Author,includedIn,c/s/1,
+Article,typeOf,Class,
+Article,name,,Article
+Article,includedIn,c/p/1,
+Article,includedIn,c/s/1,
+author_id,typeOf,Property,
+author_id,name,,author_id
+author_name,typeOf,Property,
+author_name,name,,author_name
+author_country,typeOf,Property,
+author_country,name,,author_country
+article_id,typeOf,Property,
+article_id,name,,article_id
+article_title,typeOf,Property,
+article_title,name,,article_title
+article_author,typeOf,Property,
+article_author,name,,article_author
+country/AFG,typeOf,Country,
+country/YEM,typeOf,Country,
+country/AGO,typeOf,Country,
+country/ZMB,typeOf,Country,
+country/ZWE,typeOf,Country,
+country/ALB,typeOf,Country,
+wikidataId/Q22062741,typeOf,Country,
+country/DZA,typeOf,Country,
+country/AND,typeOf,Country,
+country/ASM,typeOf,Country,
+country/AIA,typeOf,Country,
+country/WLF,typeOf,Country,
+country/ESH,typeOf,Country,
+country/USA,typeOf,Country,
+country/IND,typeOf,Country,
diff --git a/simple/tests/stats/test_data/runner/input/with_subdirs/author_entities.csv b/simple/tests/stats/test_data/runner/input/with_subdirs/author_entities.csv
new file mode 100644
index 00000000..404fb9cd
--- /dev/null
+++ b/simple/tests/stats/test_data/runner/input/with_subdirs/author_entities.csv
@@ -0,0 +1,4 @@
+author_id,author_name,author_country
+author1,"Jane Doe",country/USA
+author2,"Joe Smith",country/CAN
+author3,"Jane Smith",country/USA
diff --git a/simple/tests/stats/test_data/runner/input/with_subdirs/subdir1/article_entities.csv b/simple/tests/stats/test_data/runner/input/with_subdirs/subdir1/article_entities.csv
new file mode 100644
index 00000000..61cbba4a
--- /dev/null
+++ b/simple/tests/stats/test_data/runner/input/with_subdirs/subdir1/article_entities.csv
@@ -0,0 +1,4 @@
+article_id,article_title,article_author
+article1,"Article 1",author1
+article2,"Article 2","author1,author2"
+article3,"Article 3","author2,author3"
\ No newline at end of file
diff --git a/simple/tests/stats/test_data/runner/input/with_subdirs/subdir1/variables.mcf b/simple/tests/stats/test_data/runner/input/with_subdirs/subdir1/variables.mcf
new file mode 100644
index 00000000..76df83a7
--- /dev/null
+++ b/simple/tests/stats/test_data/runner/input/with_subdirs/subdir1/variables.mcf
@@ -0,0 +1,5 @@
+Node: dcid:some_var2
+typeOf: dcs:StatisticalVariable
+measuredProperty: dcs:value
+name: "Some Variable 2 Name"
+description: "Some Variable 2 Description"
diff --git a/simple/tests/stats/test_data/runner/input/with_subdirs/subdir2/subsubdir/countries.csv b/simple/tests/stats/test_data/runner/input/with_subdirs/subdir2/subsubdir/countries.csv
new file mode 100644
index 00000000..e62275b3
--- /dev/null
+++ b/simple/tests/stats/test_data/runner/input/with_subdirs/subdir2/subsubdir/countries.csv
@@ -0,0 +1,15 @@
+place,year,var1,var2
+Afghanistan,2023,0.19,6
+Yemen,2023,0.21,56
+Angola,2023,0.29,6
+Zambia,2023,0.31,34
+Zimbabwe,2023,0.37,76
+Albania,2023,0.50,34
+dcid: wikidataId/Q22062741,2023,0.50,97
+Algeria,2023,0.52,92
+West Bank and Gaza,2023,0.53,64
+Andorra,2023,0.76,9
+American Samoa,2023,#N/A,34
+Anguilla,2023,#N/A,42
+Wallis and Futuna Islands,2023,#N/A,75
+Western Sahara,2023,#N/A,65
diff --git a/simple/tests/stats/test_data/runner/input/with_subdirs/variable_per_row.csv b/simple/tests/stats/test_data/runner/input/with_subdirs/variable_per_row.csv
new file mode 100644
index 00000000..d5b70ebc
--- /dev/null
+++ b/simple/tests/stats/test_data/runner/input/with_subdirs/variable_per_row.csv
@@ -0,0 +1,5 @@
+entity,variable,date,value
+country/IND,var1,2020,0.16
+country/IND,var2,2020,53
+country/CHN,var1,2020,0.23
+country/CHN,var2,2020,67
\ No newline at end of file
diff --git a/simple/tests/stats/test_data/runner/input/with_subdirs/variables.mcf b/simple/tests/stats/test_data/runner/input/with_subdirs/variables.mcf
new file mode 100644
index 00000000..9120930b
--- /dev/null
+++ b/simple/tests/stats/test_data/runner/input/with_subdirs/variables.mcf
@@ -0,0 +1,6 @@
+Node: v1
+dcid:"some_var1"
+typeOf: dcs:StatisticalVariable
+measuredProperty: dcs:value
+name: "Some Variable 1 Name"
+description: "Some Variable 1 Description"
diff --git a/simple/tests/stats/test_data/runner/input/with_subdirs/wikidataids.csv b/simple/tests/stats/test_data/runner/input/with_subdirs/wikidataids.csv
new file mode 100644
index 00000000..a9870983
--- /dev/null
+++ b/simple/tests/stats/test_data/runner/input/with_subdirs/wikidataids.csv
@@ -0,0 +1,3 @@
+wikidataid,year,var1,var2
+Q30,2021,555,666
+Q668,2022,321,123
diff --git a/simple/tests/stats/variable_per_row_importer_test.py b/simple/tests/stats/variable_per_row_importer_test.py
index 32a5b22e..d5d01331 100644
--- a/simple/tests/stats/variable_per_row_importer_test.py
+++ b/simple/tests/stats/variable_per_row_importer_test.py
@@ -32,7 +32,7 @@
 from tests.stats.test_util import is_write_mode
 from tests.stats.test_util import use_fake_gzip_time
 from tests.stats.test_util import write_observations
-from util.filehandler import LocalFileHandler
+from util.filesystem import create_store

 from util import dc_client

@@ -50,29 +50,32 @@
 def _test_import(test: unittest.TestCase, test_name: str):
   with tempfile.TemporaryDirectory() as temp_dir:
     input_dir = os.path.join(_INPUT_DIR, test_name)
     expected_dir = os.path.join(_EXPECTED_DIR, test_name)
+    temp_store = create_store(temp_dir)
     input_path = os.path.join(input_dir, "input.csv")
     config_path = os.path.join(input_dir, "config.json")
-    db_path = os.path.join(temp_dir, f"{test_name}.db")
+    db_file_name = f"{test_name}.db"
+    db_path = os.path.join(temp_dir, db_file_name)
+    db_file = temp_store.as_dir().open_file(db_file_name)
     output_path = os.path.join(temp_dir, f"{test_name}.db.csv")
     expected_path = os.path.join(_EXPECTED_DIR, f"{test_name}.db.csv")
     output_path = os.path.join(temp_dir, "observations.db.csv")
     expected_path = os.path.join(expected_dir, "observations.db.csv")

-    input_fh = LocalFileHandler(input_path)
+    input_file = create_store(input_path).as_file()
     with open(config_path) as config_file:
       config = Config(json.load(config_file))

-    db = create_and_update_db(create_sqlite_config(db_path))
-    report_fh = LocalFileHandler(os.path.join(temp_dir, "report.json"))
-    reporter = FileImportReporter(input_path, ImportReporter(report_fh))
+    db = create_and_update_db(create_sqlite_config(db_file))
+    report_file = create_store(temp_dir).as_dir().open_file("report.json")
+    reporter = FileImportReporter(input_path, ImportReporter(report_file))
     nodes = Nodes(config)
     dc_client.get_property_of_entities = MagicMock(return_value={})

-    VariablePerRowImporter(input_fh=input_fh,
+    VariablePerRowImporter(input_file=input_file,
                            db=db,
                            reporter=reporter,
                            nodes=nodes).do_import()
diff --git a/simple/tests/util/filesystem_test.py b/simple/tests/util/filesystem_test.py
index d993424d..73b9d018 100644
--- a/simple/tests/util/filesystem_test.py
+++ b/simple/tests/util/filesystem_test.py
@@ -93,11 +93,17 @@ def test_dir(self):
     self.assertEqual(file.full_path(), "mem://dir1/dir2/dir3/foo.txt")
     dir.open_file("bar.txt")
     subdir.open_file("baz.txt")
-    all_file_paths = [file.full_path() for file in dir.all_files()]
-    self.assertListEqual(all_file_paths, [
+    all_file_paths_including_subdirs = [
+        file.full_path() for file in dir.all_files(include_subdirs=True)
+    ]
+    self.assertListEqual(all_file_paths_including_subdirs, [
         "mem://bar.txt", "mem://dir1/dir2/baz.txt",
         "mem://dir1/dir2/dir3/foo.txt"
     ])
+    all_file_paths_excluding_subdirs = [
+        file.full_path() for file in dir.all_files()
+    ]
+    self.assertListEqual(all_file_paths_excluding_subdirs, ["mem://bar.txt"])

   # Test copy_to method on File
   def test_copy_to(self):
diff --git a/simple/util/file_match.py b/simple/util/file_match.py
index fc389f91..e413a037 100644
--- a/simple/util/file_match.py
+++ b/simple/util/file_match.py
@@ -55,7 +55,7 @@ def match(f: File, pattern: str) -> bool:
       return False
     match_from_beginning = True
   else:
-    if "://" in full_path:
+    if "://" in full_path and f.syspath() is not None:
      # Switch to syspath
      abs_path = fspath.relpath(f.syspath())
    else:
diff --git a/simple/util/filehandler.py b/simple/util/filehandler.py
deleted file mode 100644
index 0e14262e..00000000
--- a/simple/util/filehandler.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# Copyright 2023 Google Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""A generic FileHandler abstraction that allows clients to work seamlessly with
-local and GCS files and directories.
-"""
-
-import io
-import logging
-import os
-
-from google.cloud import storage
-
-_GCS_PATH_PREFIX = "gs://"
-
-
-class FileHandler:
-  """(Abstract) base class that should be extended by concrete implementations."""
-
-  def __init__(self, path: str, isdir: bool) -> None:
-    self.path = path
-    self.isdir = isdir
-
-  def __str__(self) -> str:
-    return self.path
-
-  def read_string(self) -> str:
-    pass
-
-  def read_string_io(self) -> io.StringIO:
-    return io.StringIO(self.read_string())
-
-  def write_string(self, content: str) -> None:
-    pass
-
-  def read_bytes(self) -> bytes:
-    pass
-
-  def write_bytes(self, content: bytes) -> None:
-    pass
-
-  def make_file(self, file_name: str) -> "FileHandler":
-    pass
-
-  def make_dirs(self) -> None:
-    pass
-
-  def basename(self) -> str:
-    pass
-
-  def exists(self) -> bool:
-    pass
-
-  def list_files(self, extension: str = None) -> list[str]:
-    pass
-
-
-class LocalFileHandler(FileHandler):
-
-  def __init__(self, path: str) -> None:
-    isdir = os.path.isdir(path)
-    super().__init__(path, isdir)
-
-  def read_string(self) -> str:
-    with open(self.path, "r") as f:
-      return f.read()
-
-  def write_string(self, content: str) -> None:
-    with open(self.path, "w") as f:
-      f.write(content)
-
-  def read_bytes(self) -> bytes:
-    with open(self.path, "rb") as f:
-      return f.read()
-
-  def write_bytes(self, content: bytes) -> None:
-    with open(self.path, "wb") as f:
-      f.write(content)
-
-  def make_file(self, file_name: str) -> FileHandler:
-    return LocalFileHandler(os.path.join(self.path, file_name))
-
-  def make_dirs(self) -> None:
-    return os.makedirs(self.path, exist_ok=True)
-
-  def basename(self) -> str:
-    path = self.path.rstrip(self.path[-1]) if self.path.endswith(
-        os.sep) else self.path
-    return path.split(os.sep)[-1]
-
-  def exists(self) -> bool:
-    return os.path.exists(self.path)
-
-  def list_files(self, extension: str = None) -> list[str]:
-    all_files = os.listdir(self.path)
-    if not extension:
-      return all_files
-    return filter(lambda name: name.lower().endswith(extension.lower()),
-                  all_files)
-
-
-class GcsMeta(type):
-
-  @property
-  def gcs_client(cls) -> storage.Client:
-    if getattr(cls, "_GCS_CLIENT", None) is None:
-      gcs_client = storage.Client()
-      logging.info("Using GCS project: %s", gcs_client.project)
-      cls._GCS_CLIENT = gcs_client
-    return cls._GCS_CLIENT
-
-
-class GcsFileHandler(FileHandler, metaclass=GcsMeta):
-
-  def __init__(self, path: str, is_dir: bool = None) -> None:
-    if not path.startswith(_GCS_PATH_PREFIX):
-      raise ValueError(f"Expected {_GCS_PATH_PREFIX} prefix, got {path}")
-
-    # If is_dir is specified, use that to set the isdir property.
-    if is_dir is not None:
-      isdir = is_dir
-      # If it is a dir, suffix with "/" if needed.
-      if isdir:
-        if not path.endswith("/"):
-          path = f"{path}/"
-    else:
-      isdir = path.endswith("/")
-
-    bucket_name, blob_name = path[len(_GCS_PATH_PREFIX):].split('/', 1)
-    self.bucket = GcsFileHandler.gcs_client.bucket(bucket_name)
-    self.blob = self.bucket.blob(blob_name)
-
-    super().__init__(path, isdir)
-
-  def read_string(self) -> str:
-    return self.blob.download_as_string().decode("utf-8")
-
-  def write_string(self, content: str) -> None:
-    self.blob.upload_from_string(content)
-
-  def read_bytes(self) -> bytes:
-    return self.blob.download_as_bytes()
-
-  def write_bytes(self, content: bytes) -> None:
-    self.blob.upload_from_string(content,
-                                 content_type="application/octet-stream")
-
-  def make_file(self, file_name: str) -> FileHandler:
-    return GcsFileHandler(f"{self.path}{'' if self.isdir else '/'}{file_name}")
-
-  def basename(self) -> str:
-    path = self.path.rstrip(
-        self.path[-1]) if self.path.endswith("/") else self.path
-    return path.split("/")[-1]
-
-  def exists(self) -> bool:
-    return self.blob.exists()
-
-  def list_files(self, extension: str = None) -> list[str]:
-    prefix = self.blob.name if self.path.endswith("/") else f"{self.blob.name}/"
-    all_files = [
-        blob.name[len(prefix):]
-        for blob in self.bucket.list_blobs(prefix=prefix, delimiter="/")
-    ]
-    if not extension:
-      return all_files
-    return filter(lambda name: name.lower().endswith(extension.lower()),
-                  all_files)
-
-
-def is_gcs_path(path: str) -> bool:
-  return path.startswith(_GCS_PATH_PREFIX)
-
-
-def create_file_handler(path: str, is_dir: bool = None) -> FileHandler:
-  if is_gcs_path(path):
-    return GcsFileHandler(path, is_dir)
-  return LocalFileHandler(path)
diff --git a/simple/util/filesystem.py b/simple/util/filesystem.py
index ad482d16..6ddec363 100644
--- a/simple/util/filesystem.py
+++ b/simple/util/filesystem.py
@@ -165,10 +165,11 @@ def open_file(self, path: str, create_if_missing: bool = True) -> "File":
           f"{self.full_path(path)} exists and is a directory, not a file")
     return File(self._store, new_path, create_if_missing)

-  def all_files(self):
+  def all_files(self, include_subdirs: bool = False):
     files = []
     subfs = self.fs().opendir(self.path)
-    for abspath in subfs.walk.files(self.path):
+    max_depth = None if include_subdirs else 1
+    for abspath in subfs.walk.files(self.path, max_depth=max_depth):
       files.append(self.open_file(abspath))
     return files
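Usage sketch (illustration only, not part of the patch): how the new include_subdirs flag on Dir.all_files() from util/filesystem.py is meant to be called, using only names that appear in the diff above (create_store, as_dir, open_file, all_files, full_path). The nested relative path passed to open_file is an assumption made for this example.

import tempfile

from util.filesystem import create_store

with tempfile.TemporaryDirectory() as temp_dir:
  # Wrap the temp directory in a store, mirroring variable_per_row_importer_test.py.
  dir = create_store(temp_dir).as_dir()
  dir.open_file("top.csv")  # file directly under the directory
  dir.open_file("subdir/nested.csv")  # assumed: open_file accepts a nested relative path

  # Default (include_subdirs=False): only files directly under the directory.
  top_only = [f.full_path() for f in dir.all_files()]

  # include_subdirs=True walks subdirectories too, as exercised in filesystem_test.py.
  all_paths = [f.full_path() for f in dir.all_files(include_subdirs=True)]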