Skip to content

Commit

Permalink
Streamline index_elements()
Browse files Browse the repository at this point in the history
  • Loading branch information
dustalov committed Jul 10, 2024
1 parent c93a1ff commit 8e3caa9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 25 deletions.
25 changes: 10 additions & 15 deletions python/evalica/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import dataclasses
from collections import OrderedDict
from collections.abc import Hashable, Iterable
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar
Expand Down Expand Up @@ -35,16 +34,6 @@
T = TypeVar("T", bound=Hashable)


def enumerate_elements(xs: Iterable[T], *yss: Iterable[T]) -> dict[T, int]:
index: dict[T, int] = OrderedDict()

for ys in (xs, *yss):
for y in ys:
index[y] = index.get(y, len(index))

return index


@dataclass
class IndexedElements(Generic[T]):
index: pd.Index[T] # type: ignore[type-var]
Expand All @@ -53,10 +42,16 @@ class IndexedElements(Generic[T]):


def index_elements(xs: Iterable[T], ys: Iterable[T]) -> IndexedElements[T]:
xy_index = enumerate_elements(xs, ys)
xy_index: dict[T, int] = {}

def get_index(x: T) -> int:
if (index := xy_index.get(x)) is None:
index = xy_index[x] = len(xy_index)

return index

xs_indexed = [xy_index[x] for x in xs]
ys_indexed = [xy_index[y] for y in ys]
xs_indexed = [get_index(x) for x in xs]
ys_indexed = [get_index(y) for y in ys]

return IndexedElements(
index=pd.Index(xy_index),
Expand Down Expand Up @@ -335,7 +330,7 @@ def pairwise_frame(scores: pd.Series[T]) -> pd.DataFrame: # type: ignore[type-v
"counting",
"eigen",
"elo",
"enumerate_elements",
"index_elements",
"matrices",
"newman",
"pagerank",
Expand Down
18 changes: 8 additions & 10 deletions python/evalica/test_evalica.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import pytest
from hypothesis import given
from pandas._testing import assert_series_equal

import evalica
from conftest import Example, elements

if TYPE_CHECKING:
import pandas as pd


def test_version() -> None:
assert isinstance(evalica.__version__, str)
Expand All @@ -23,14 +19,16 @@ def test_exports() -> None:


@given(example=elements())
def test_enumerate_elements(example: Example) -> None: # type: ignore[type-var]
def test_index_elements(example: Example) -> None: # type: ignore[type-var]
xs, ys, ws = example

index = evalica.enumerate_elements(xs, ys)
indexed = evalica.index_elements(xs, ys)

assert isinstance(index, dict)
assert len(index) == len(set(xs) | set(ys))
assert not xs or max(index.values()) == len(index) - 1
assert len(indexed.xs) == len(xs)
assert len(indexed.ys) == len(ys)
assert isinstance(indexed.index, pd.Index)
assert len(indexed.index) == len(set(xs) | set(ys))
assert set(indexed.index.values) == (set(xs) | set(ys))


@given(example=elements())
Expand Down

0 comments on commit 8e3caa9

Please sign in to comment.