# Reconstructed from a two-patch series against m3inference/m3twitter.py:
#   [PATCH 1/2] new function for batch ID processing and corresponding
#               helper functions
#   [PATCH 2/2] added functionality to chunk up ID lists with > 100 IDs
#
# The definitions below are methods (they take ``self``) and are shown here at
# column 0; indent them to method level when placing them inside their class.
# They rely on names provided by the surrounding module and class that are not
# part of this fragment: ``logger``, ``json``, ``UNKNOWN_LANG``, ``get_lang``,
# ``TW_DEFAULT_PROFILE_IMG``, ``get_extension``, ``download_resize_img``,
# ``self.twitter_session``, ``self.cache_dir``, ``self.infer`` and
# ``self._twitter_api``.


def _twitter_api_lookup(self, ids=None, batch_size=16, num_workers=4):
    """Fetch a batch of users via ``users/lookup.json`` and run m3 on them.

    :param ids: Iterable of Twitter user ids. Items are converted with
        ``str()``, so both numeric and string ids are accepted.
    :param batch_size: Forwarded to ``process_twitter_batch``.
    :param num_workers: Forwarded to ``process_twitter_batch``.
    :return: The list produced by ``process_twitter_batch``; an empty list
        when ``ids`` is empty or the response is not a list of user records;
        ``None`` when the session is not initialised, ``ids`` is ``None`` or
        the request/JSON decoding fails.
    """
    if self.twitter_session is None:
        logger.fatal("You must call twitter_init(...) before using this method. Please see https://github.com/euagendas/m3inference/blob/master/README.md for details.")
        return None

    if ids is None:
        logger.fatal("No id or screen_name")
        return None

    # str() each item: ``','.join`` raises TypeError on numeric ids.
    id_param = ",".join(str(user_id) for user_id in ids)
    if not id_param:
        # Nothing to look up; do not send a request with an empty id list.
        return []

    logger.info("GET /users/lookup.json?ids={}".format(id_param))
    try:
        response = self.twitter_session.get("users/lookup.json", params={"user_id": id_param})
        users = response.json()
    except Exception:
        # The concrete exception types depend on the session implementation,
        # which is not visible from here, so this stays deliberately broad.
        logger.warning("Invalid response from Twitter")
        return None

    if not isinstance(users, list):
        # Anything other than a list of user records (e.g. an error object)
        # cannot be processed record by record.
        logger.warning("Unexpected response from Twitter: {}".format(users))
        return []

    return self.process_twitter_batch(users, batch_size, num_workers)


def infer_id(self, id, skip_cache=False):
    """Collect data for one numeric Twitter user id and predict with m3.

    :param id: A Twitter numeric user id.
    :param skip_cache: When False (the default), a result already stored as
        ``<cache_dir>/<id>.json`` is returned without contacting Twitter or
        running m3. When True the cache is not read.
    :return: Whatever ``self._twitter_api(id=id)`` returns (or the cached copy
        of it): a dictionary with the keys "input" (data from Twitter) and
        "output" (the m3 output in the `output_format` described for m3).
    """
    cache_path = "{}/{}.json".format(self.cache_dir, id)

    if not skip_cache:
        # If a json file exists, we'll use that. Otherwise go get the data.
        try:
            with open(cache_path, "r") as fh:
                cached = json.load(fh)
        except (OSError, ValueError):
            # Missing/unreadable file or invalid JSON both count as a miss.
            logger.info("Results not in cache. Fetching data from Twitter for id {}.".format(id))
        else:
            logger.info("Results from cache for id {}.".format(id))
            return cached
    else:
        logger.info("skip_cache is True. Fetching data from Twitter for id {}.".format(id))

    output = self._twitter_api(id=id)
    if output is not None:
        # Do not cache a failed lookup: a stored ``null`` would be returned
        # from the cache on every later call.
        with open(cache_path, "w") as fh:
            json.dump(output, fh)
    return output


def _get_twitter_attrib(self, key, data):
    """Return ``data[key]``, or "" (with a logged warning) if key is absent."""
    if key in data:
        return data[key]

    logger.warning("Could not retreive {}".format(key))
    return ""


def infer_ids(self, id_list, batch_size=16, num_workers=4, skip_cache=False):
    """Collect data for several Twitter user ids and predict with m3.

    Duplicate ids are processed once. Cached results come first in the
    returned list, followed by freshly fetched ones; within each group the
    order follows ``id_list`` / the lookup responses.

    :param id_list: A list of Twitter numeric user ids.
    :param batch_size: Forwarded to ``_twitter_api_lookup``.
    :param num_workers: Forwarded to ``_twitter_api_lookup``.
    :param skip_cache: When False (the default), results already stored as
        ``<cache_dir>/<id>.json`` are reused and only the remaining ids are
        fetched. When True the cache is not read and every id is fetched.
    :return: A list of dictionaries, one per resolved user, each with the keys
        "input" (data from Twitter) and "output" (the m3 output in the
        `output_format` described for m3).
    """
    # The Twitter API handles a maximum of 100 user ids per request.
    api_batch_size = 100

    # dict.fromkeys de-duplicates while keeping first-seen order, so request
    # batches and outputs are deterministic.
    unique_ids = list(dict.fromkeys(id_list))

    outputs = []
    if skip_cache:
        logger.info("skip_cache is True. Fetching data from Twitter for {} ids.".format(len(unique_ids)))
        pending_ids = unique_ids
    else:
        pending_ids = []
        for user_id in unique_ids:
            # If a json file exists, we'll use that. Otherwise go get the data.
            try:
                with open("{}/{}.json".format(self.cache_dir, user_id), "r") as fh:
                    cached = json.load(fh)
            except (OSError, ValueError):
                logger.info("Results not in cache. Fetching data from Twitter for id {}.".format(user_id))
                pending_ids.append(user_id)
            else:
                logger.info("Results from cache for id {}.".format(user_id))
                outputs.append(cached)

    new_outputs = []
    # Stepping by the API limit never produces an empty trailing batch, even
    # when the number of ids is an exact multiple of the limit.
    for start in range(0, len(pending_ids), api_batch_size):
        id_batch = pending_ids[start:start + api_batch_size]
        batch_outputs = self._twitter_api_lookup(
            ids=id_batch,
            batch_size=batch_size,
            num_workers=num_workers,
        )
        if batch_outputs is None:
            # A failed lookup must not discard the batches that succeeded.
            logger.warning("Lookup failed for a batch of {} ids; skipping it.".format(len(id_batch)))
            continue
        new_outputs.extend(batch_outputs)

    # Write any new outputs to the cache.
    for output in new_outputs:
        cache_path = "{}/{}.json".format(self.cache_dir, output["input"]["id"])
        with open(cache_path, "w") as fh:
            json.dump(output, fh)

    outputs.extend(new_outputs)
    return outputs


def process_twitter_batch(self, data_list, batch_size, num_workers):
    """Turn Twitter user records into m3 inputs and run inference on them.

    :param data_list: List of user records (dicts) as returned by Twitter.
    :param batch_size: Forwarded to ``self.infer``.
    :param num_workers: Forwarded to ``self.infer``.
    :return: A list with one ``{"input": ..., "output": ...}`` dictionary per
        record, in the order of ``data_list``. Empty input yields ``[]``
        without calling ``self.infer``.
    """
    if not data_list:
        return []

    records = []
    for index, data in enumerate(data_list):
        screen_name = self._get_twitter_attrib("screen_name", data)
        user_id = self._get_twitter_attrib("id_str", data)
        bio = self._get_twitter_attrib("description", data)
        name = self._get_twitter_attrib("name", data)
        img_path = self._get_twitter_attrib("profile_image_url", data)

        if user_id == "":
            # Predictions are looked up by id below, so the placeholder has
            # to be unique within the batch.
            user_id = "dummy_{}".format(index)

        lang = UNKNOWN_LANG if bio == "" else get_lang(bio)

        if img_path == "" or "default_profile" in img_path:
            logger.warning("Unable to extract image from Twitter. Using default image.")
            img_file_resize = TW_DEFAULT_PROFILE_IMG
        else:
            # Request the 400x400 rendition instead of the smaller variants.
            img_path = img_path.replace("_200x200", "_400x400").replace("_normal", "_400x400")
            # Keep the remote file's extension only if its last path segment
            # actually has one.
            has_extension = "." in img_path.split("/")[-1]
            suffix = f".{img_path[img_path.rfind('.') + 1:]}" if has_extension else ""
            img_file_full = f"{self.cache_dir}/{user_id}" + suffix
            img_file_resize = "{}/{}_224x224.{}".format(self.cache_dir, user_id, get_extension(img_path))
            download_resize_img(img_path, img_file_resize, img_file_full)

        records.append({
            "description": bio,
            "id": user_id,
            "img_path": img_file_resize,
            "lang": lang,
            "name": name,
            "screen_name": screen_name,
        })

    pred = self.infer(records, batch_size=batch_size, num_workers=num_workers)

    return [{"input": record, "output": pred[record["id"]]} for record in records]