User lookup #30
base: master
@@ -193,6 +193,34 @@ def _twitter_api(self,id=None,screen_name=None):
        return self.process_twitter(r.json())

    def _twitter_api_lookup(self, ids=None, batch_size=16, num_workers=4):
        if self.twitter_session == None:
            logger.fatal("You must call twitter_init(...) before using this method. Please see https://github.com/euagendas/m3inference/blob/master/README.md for details.")
            return None

        # if screen_name != None:
        #     logger.info("GET /users/show.json?screen_name={}".format(screen_name))
        #     try:
        #         r = self.twitter_session.get("users/show.json", params={"screen_name": screen_name})
        #     except:
        #         logger.warning("Invalid response from Twitter")
        #         return None
        if ids != None:
            logger.info("GET /users/lookup.json?ids={}".format(','.join(ids)))
            try:
                r = self.twitter_session.get("users/lookup.json", params={"user_id": ','.join(ids)})
            except:
                logger.warning("Invalid response from Twitter")
                return None
        else:
            logger.fatal("No id or screen_name")
            return None

        r = r.json()
        return self.process_twitter_batch(r, batch_size, num_workers)
        # return self.process_twitter(r.json())
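For context on the new method above: `users/lookup.json` returns a JSON array of user objects rather than the single object returned by `users/show.json`, which is why the response is handed to a batched processing method. A rough sketch of the response shape consumed by `process_twitter_batch` (values are invented; field names follow the Twitter v1.1 user object):

```python
# Illustrative shape of a users/lookup.json response; values are placeholders.
# process_twitter_batch (further down in this diff) reads these fields per user.
lookup_response = [
    {
        "id_str": "123456",
        "screen_name": "example_user",
        "name": "Example User",
        "description": "Short bio used for language detection.",
        "profile_image_url": "http://pbs.twimg.com/profile_images/123/example_normal.jpg",
    },
    # ...one object per requested id, at most 100 per request
]
```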
    def infer_id(self, id, skip_cache=False):
        """
        Collect data for a numeric Twitter user id from the Twitter website and predict attributes with m3
@@ -216,13 +244,73 @@ def infer_id(self, id, skip_cache=False):
            json.dump(output, fh)
        return output

    def _get_twitter_attrib(self, key, data):
        if key in data:
            return data[key]
        else:
            logger.warning("Could not retrieve {}".format(key))
            return ""

    def infer_ids(self, id_list, batch_size=16, num_workers=4, skip_cache=False):
        """
        Collect data for a list of numeric Twitter user ids from the Twitter website and predict attributes with m3
        :param id_list: A list of Twitter numeric user ids
        :param skip_cache: Unless this is True, ids whose output already exists in self.cache_dir are reused (i.e., the function will not contact the Twitter website and will not run m3 for them).
        :return: a list of dictionary objects, one per id, each with two keys. "input" contains the data from the Twitter website. "output" contains the m3 output in the `output_format` format described for m3.
        """

        outputs = []
        cached_ids = []

        if not skip_cache:
            for id in id_list:
                # If a json file exists, we'll use that. Otherwise go get the data.
                try:
                    with open("{}/{}.json".format(self.cache_dir, id), "r") as fh:
                        logger.info("Results from cache for id {}.".format(id))
                        outputs.append(json.load(fh))
                        cached_ids.append(id)
                except:
                    logger.info("Results not in cache. Fetching data from Twitter for id {}.".format(id))
        else:
            logger.info("skip_cache is True. Fetching data from Twitter for all ids.")

        id_list = set(id_list)
        cached_ids = set(cached_ids)
        id_list = list(id_list - cached_ids)

        if len(id_list) > 0:
            # the twitter API handles a maximum of 100 user IDs per request. Chunk up the user
            # id list into batches of max 100 IDs and run them through the pipeline sequentially
            API_batch_size = 100
Review comment: This could be defined as a global variable for easy tracking?
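A minimal sketch of that suggestion (the constant and helper names are hypothetical, not part of this PR): hoisting the per-request limit to module level keeps the lookup call and the chunking logic in sync.

```python
# Sketch only: hoist the Twitter per-request limit to module scope so both the
# lookup call and the chunking logic reference a single value.
TWITTER_LOOKUP_LIMIT = 100  # users/lookup.json accepts at most 100 user ids per request


def chunk_ids(id_list, chunk_size=TWITTER_LOOKUP_LIMIT):
    """Split id_list into consecutive chunks of at most chunk_size ids."""
    return [id_list[i:i + chunk_size] for i in range(0, len(id_list), chunk_size)]
```

Stepping by `chunk_size` also avoids the empty trailing batch that integer division can produce when the list length is an exact multiple of the limit.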
            if len(id_list) > 100:
                # ceiling division, so an exact multiple of 100 does not produce an empty trailing batch
                N_batches = -(-len(id_list) // API_batch_size)
                id_batches = [id_list[i * API_batch_size : (i + 1) * API_batch_size] for i in range(N_batches)]
            else:
                id_batches = [id_list]

            new_outputs = []
            for id_batch in id_batches:
                new_outputs.extend(self._twitter_api_lookup(
                    ids=id_batch,
                    batch_size=batch_size,
                    num_workers=num_workers)
                )
        else:
            new_outputs = []

        # write any new outputs to the cache
        for output in new_outputs:
            id = output["input"]["id"]
            with open("{}/{}.json".format(self.cache_dir, id), "w") as fh:
                json.dump(output, fh)

        outputs.extend(new_outputs)
        return outputs
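A hedged usage sketch of the new method, assuming the `M3Twitter` wrapper and the `twitter_init(...)` setup described in the project README; keys, ids, and the cache path below are placeholders:

```python
# Illustrative usage of the new batched method; credentials and ids are placeholders.
from m3inference import M3Twitter

m3twitter = M3Twitter(cache_dir="twitter_cache")
m3twitter.twitter_init(api_key="...", api_secret="...",
                       access_token="...", access_secret="...")

# One call covers many ids; cached ids are served from twitter_cache,
# the rest are fetched in chunks of at most 100 ids per API request.
results = m3twitter.infer_ids(["123456", "789101", "112131"], batch_size=16)
for r in results:
    print(r["input"]["screen_name"], r["output"])
```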

    def process_twitter(self, data):

        screen_name = self._get_twitter_attrib("screen_name", data)
@@ -265,3 +353,54 @@ def process_twitter(self, data):
            "output": pred[id]
        }
        return output

    def process_twitter_batch(self, data_list, batch_size, num_workers):
Review comment: This function seems to largely overlap with the non-batched version. Could you call the non-batched version and then aggregate the results?
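A minimal sketch of that refactor, assuming a hypothetical `_prepare_twitter_input` helper (not in this PR) that holds the per-user logic both methods currently duplicate:

```python
# Sketch only: _prepare_twitter_input is a hypothetical helper containing the
# per-user steps that process_twitter and process_twitter_batch both perform
# (attribute extraction, language detection, image download/resize).
# Method of M3Twitter, shown standalone for brevity.
def process_twitter_batch(self, data_list, batch_size, num_workers):
    new_data_list = [self._prepare_twitter_input(data) for data in data_list]
    pred = self.infer(new_data_list, batch_size=batch_size, num_workers=num_workers)
    return [{"input": entry, "output": pred[entry["id"]]} for entry in new_data_list]
```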
        new_data_list = []

        for data in data_list:

            screen_name = self._get_twitter_attrib("screen_name", data)
            id = self._get_twitter_attrib("id_str", data)
            bio = self._get_twitter_attrib("description", data)
            name = self._get_twitter_attrib("name", data)
            img_path = self._get_twitter_attrib("profile_image_url", data)
            if id == "":
                id = "dummy"  # Placeholder for entries missing id_str

            if bio == "":
                lang = UNKNOWN_LANG
            else:
                lang = get_lang(bio)

            if img_path == "" or "default_profile" in img_path:
                logger.warning("Unable to extract image from Twitter. Using default image.")
                img_file_resize = TW_DEFAULT_PROFILE_IMG
            else:
                img_path = img_path.replace("_200x200", "_400x400").replace("_normal", "_400x400")
                img_file_full = f"{self.cache_dir}/{id}" + (f".{img_path[img_path.rfind('.') + 1:]}" if '.' in img_path.split('/')[-1] else '')
                img_file_resize = "{}/{}_224x224.{}".format(self.cache_dir, id, get_extension(img_path))
                # img_file_full = "{}/{}.{}".format(self.cache_dir, screen_name, img[dotpos + 1:])
                # img_file_resize = "{}/{}_224x224.{}".format(self.cache_dir, screen_name, get_extension(img))
                download_resize_img(img_path, img_file_resize, img_file_full)

            data = [{
                "description": bio,
                "id": id,
                "img_path": img_file_resize,
                "lang": lang,
                "name": name,
                "screen_name": screen_name,
            }]

            new_data_list.extend(data)

        pred = self.infer(new_data_list, batch_size=batch_size, num_workers=num_workers)

        outputs = [{
            "input": new_data_list[i],
            "output": pred[new_data_list[i]["id"]]
        } for i in range(len(new_data_list))]

        return outputs
Review comment: We probably could reuse/extend the `twitter_api` function by extending `id` to `ids`?
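A rough sketch of that idea (the combined signature is a proposal, not the current code, and parameter handling is illustrative): a single `_twitter_api` that accepts either one identifier or a list of ids and picks the endpoint accordingly.

```python
# Sketch only: a combined signature for _twitter_api, as the comment suggests.
# Method of M3Twitter, shown standalone for brevity; error handling mirrors the
# existing methods in this file.
def _twitter_api(self, id=None, screen_name=None, ids=None, batch_size=16, num_workers=4):
    if self.twitter_session is None:
        logger.fatal("You must call twitter_init(...) before using this method.")
        return None

    if ids is not None:
        # Batched path: users/lookup.json, up to 100 comma-separated ids per request.
        r = self.twitter_session.get("users/lookup.json", params={"user_id": ",".join(ids)})
        return self.process_twitter_batch(r.json(), batch_size, num_workers)
    elif id is not None or screen_name is not None:
        # Single-user path: users/show.json, behaviour unchanged.
        params = {"user_id": id} if id is not None else {"screen_name": screen_name}
        r = self.twitter_session.get("users/show.json", params=params)
        return self.process_twitter(r.json())
    else:
        logger.fatal("No id, screen_name, or ids")
        return None
```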