Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User lookup #30

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions m3inference/m3twitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,34 @@ def _twitter_api(self,id=None,screen_name=None):
return self.process_twitter(r.json())


def _twitter_api_lookup(self, ids=None, batch_size=16, num_workers=4):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably could reuse/extend the twitter_api function by extending id to ids?

if self.twitter_session==None:
logger.fatal("You must call twitter_init(...) before using this method. Please see https://github.com/euagendas/m3inference/blob/master/README.md for details.")
return None

#if screen_name!=None:
# logger.info("GET /users/show.json?screen_name={}".format(screen_name))
# try:
# r=self.twitter_session.get("users/show.json",params={"screen_name":screen_name})
# except:
# logger.warning("Invalid response from Twitter")

# return None
if ids!=None:
logger.info("GET /users/lookup.json?ids={}".format(','.join(ids)))
try:
r=self.twitter_session.get("users/lookup.json",params={"user_id":','.join(ids)})
except:
logger.warning("Invalid response from Twitter")
return None
else:
logger.fatal("No id or screen_name")
return None

r = r.json()
return self.process_twitter_batch(r, batch_size, num_workers)
#return self.process_twitter(r.json())

def infer_id(self, id, skip_cache=False):
"""
Collect data for a numeric Twitter user id from the Twitter website and predict attributes with m3
Expand All @@ -216,13 +244,73 @@ def infer_id(self, id, skip_cache=False):
json.dump(output, fh)
return output


def _get_twitter_attrib(self,key,data):
if key in data:
return data[key]
else:
logger.warning("Could not retreive {}".format(key))
return ""


def infer_ids(self, id_list, batch_size=16, num_workers=4, skip_cache=False):
"""
Collect data for a list of numeric Twitter user ids from the Twitter website and predict attributes with m3
:param id_list: A list of Twitter numeric user ids
:param skip_cache: If output for this screen name already exists in self.cache_dir, the results will be reused (i.e., the function will not contact the Twitter website and will not run m3).
:return: a dictionary object with two keys. "input" contains the data from the Twitter website. "output" contains the m3 output in the `output_format` format described for m3.
"""

outputs = []
cached_ids = []

if not skip_cache:
for id in id_list:
# If a json file exists, we'll use that. Otherwise go get the data.
try:
with open("{}/{}.json".format(self.cache_dir, id), "r") as fh:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

os.path.join might be better?

logger.info("Results from cache for id {}.".format(id))
outputs.append(json.load(fh))
cached_ids.append(id)
except:
logger.info("Results not in cache. Fetching data from Twitter for id {}.".format(id))
else:
logger.info("skip_cache is True. Fetching data from Twitter for id {}.".format(id))

id_list = set(id_list)
cached_ids = set(cached_ids)
id_list = list(id_list - cached_ids)

if len(id_list) > 0:
# the twitter API handles a maximum of 100 user IDs per request. Chunk up the user
# id list into batches of max 100 IDs and run them through the pipeline sequentially
API_batch_size = 100
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be defined as a global variable for easy tracking?

if len(id_list) > 100:
N_batches = int(len(id_list) / API_batch_size)
id_batches = [id_list[i * API_batch_size : (i + 1) * API_batch_size] for i in range(N_batches + 1)]
else:
id_batches = [id_list]

new_outputs = []
for id_batch in id_batches:
new_outputs.extend(self._twitter_api_lookup(
ids=id_batch,
batch_size=batch_size,
num_workers=num_workers)
)
else:
new_outputs = []

# write any new outputs to the cache
for output in new_outputs:
id = output["input"]["id"]
with open("{}/{}.json".format(self.cache_dir, id), "w") as fh:
json.dump(output, fh)

outputs.extend(new_outputs)
return outputs


def process_twitter(self, data):

screen_name=self._get_twitter_attrib("screen_name",data)
Expand Down Expand Up @@ -265,3 +353,54 @@ def process_twitter(self, data):
"output": pred[id]
}
return output


def process_twitter_batch(self, data_list, batch_size, num_workers):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function seems largely overlaps with the non-batched version. Could you call the non-batched version and then aggregate the result?


new_data_list = []

for data in data_list:

screen_name=self._get_twitter_attrib("screen_name",data)
id=self._get_twitter_attrib("id_str",data)
bio=self._get_twitter_attrib("description",data)
name=self._get_twitter_attrib("name",data)
img_path=self._get_twitter_attrib("profile_image_url",data)
if id=="":
id="dummy" #Can be anything since batch is of size 1

if bio == "":
lang = UNKNOWN_LANG
else:
lang = get_lang(bio)

if img_path=="" or "default_profile" in img_path:
logger.warning("Unable to extract image from Twitter. Using default image.")
img_file_resize = TW_DEFAULT_PROFILE_IMG
else:
img_path = img_path.replace("_200x200", "_400x400").replace("_normal", "_400x400")
img_file_full = f"{self.cache_dir}/{id}" + (f".{img_path[img_path.rfind('.') + 1:]}" if '.' in img_path.split('/')[-1] else '')
img_file_resize = "{}/{}_224x224.{}".format(self.cache_dir, id, get_extension(img_path))
# img_file_full = "{}/{}.{}".format(self.cache_dir, screen_name, img[dotpos + 1:])
# img_file_resize = "{}/{}_224x224.{}".format(self.cache_dir, screen_name, get_extension(img))
download_resize_img(img_path, img_file_resize, img_file_full)

data = [{
"description": bio,
"id": id,
"img_path": img_file_resize,
"lang": lang,
"name": name,
"screen_name": screen_name,
}]

new_data_list.extend(data)

pred = self.infer(new_data_list, batch_size=batch_size, num_workers=num_workers)

outputs = [{
"input":new_data_list[i],
"output":pred[new_data_list[i]["id"]]
} for i in range(len(new_data_list))]

return outputs