-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_pipeline.py
66 lines (56 loc) · 2.2 KB
/
run_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os.path
from datetime import datetime, timedelta
from sml.pipeline import Pipeline
from sml.secrets import Secrets
import yaml
import click
import logging
# Configure logging: our own messages at INFO, chatty HTTP/Azure
# client libraries silenced down to WARNING.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for _noisy in ("requests", "urllib3", "azure", "requests_oauthlib"):
    logging.getLogger(_noisy).setLevel(logging.WARNING)
@click.command()
@click.option("--country", type=str, required=True, help="Country ISO3")
@click.option("--source", type=str, required=True, help="Data source")
@click.option("--channels", type=str, required=True, help="Channels to track")
@click.option("--days", type=int, default=14, help="How many days in the past")
def run_sml_pipeline(country, source, channels, days):
    """Run the SML pipeline end to end: extract, transform, load.

    Scrapes messages from ``source`` for the given ``country`` and
    comma-separated ``channels`` over the last ``days`` days, filters out
    very short messages, translates and classifies them, then saves the
    results to Azure Cosmos DB and Argilla.
    """
    # Capture "now" once so both ends of the date window agree exactly
    # (calling datetime.today() twice yields slightly different instants).
    end_date = datetime.today()
    start_date = end_date - timedelta(days=days)
    country = country.upper()

    # Load secrets (API keys, connection strings) from the local .env file.
    pipe = Pipeline(secrets=Secrets("credentials/.env"))

    logging.info("scraping messages")
    pipe.extract.set_source(source)
    messages = pipe.extract.get_data(
        start_date=start_date,
        country=country,
        channels=channels.split(","),
        store_temp=False,
    )
    # Drop messages shorter than 20 characters — too little text to
    # translate/classify meaningfully.
    messages = [message for message in messages if len(message.text) >= 20]
    logging.info("found %s messages!", len(messages))

    pipe.transform.set_translator(
        model="Microsoft",
        from_lang="",  # empty string means auto-detect language
        to_lang="en",
    )
    pipe.transform.set_classifier(
        type="setfit", model="rodekruis/sml-ukr-message-classifier-2", lang="en"
    )
    messages = pipe.transform.process_messages(messages, translate=True, classify=True)
    logging.info("processed %s messages!", len(messages))

    # Persist results: raw storage in Cosmos DB, plus an Argilla dataset
    # named <COUNTRY>-<start>-<end> in the country's workspace.
    pipe.load.set_storage("Azure Cosmos DB")
    pipe.load.save_messages(messages)
    pipe.load.save_to_argilla(
        messages=messages,
        dataset_name=f"{country}-{start_date.strftime('%Y-%m-%d')}-{end_date.strftime('%Y-%m-%d')}",
        workspace=country,
    )
    logging.info("saved %s messages!", len(messages))
# Script entry point: click parses sys.argv and invokes the command.
if __name__ == "__main__":
    run_sml_pipeline()