From cd2d51c4ffc0d008b5a1a05a396c71335bc2b2c3 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 14:01:22 +0900 Subject: [PATCH 01/72] Migrate to GitHub Action --- .github/workflows/unit-tests.yml | 29 +++++++++++++++++++++++++++++ .travis.yml | 16 ---------------- README.md | 26 +++++++++++--------------- tests/test_hello.py | 2 ++ 4 files changed, 42 insertions(+), 31 deletions(-) create mode 100644 .github/workflows/unit-tests.yml delete mode 100644 .travis.yml create mode 100644 tests/test_hello.py diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 0000000..c281fce --- /dev/null +++ b/.github/workflows/unit-tests.yml @@ -0,0 +1,29 @@ +name: unit-tests + +on: [push] + +jobs: + build: + + name: Test on Python ${{ matrix.python-version }} and ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + max-parallel: 4 + matrix: + os: [ubuntu-latest, macOS-latest, windows-latest] + python-version: [3.5, 3.6, 3.7, 3.8] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + - name: Test with pytest + run: | + pip install pytest + pytest tests diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 80b98c2..0000000 --- a/.travis.yml +++ /dev/null @@ -1,16 +0,0 @@ -language: python -sudo: required -dist: xenial -cache: - - pip - - $HOME/data -git: - depth: 3 - quiet: true -python: 3.7 -install: - - python -m pip install pytest pytest-cov codecov hypothesis -script: - - python -m pytest tests --cov-report term --cov aku -after_success: - - codecov \ No newline at end of file diff --git a/README.md b/README.md index 9b6ffb3..ce7e94d 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,20 @@ # aku [![PyPI Version](https://badge.fury.io/py/aku.svg)](https://pypi.org/project/aku/) -[![Build Status](https://travis-ci.org/speedcell4/aku.svg?branch=master)](https://travis-ci.org/speedcell4/aku) -[![Code Coverage](https://codecov.io/gh/speedcell4/aku/branch/master/graph/badge.svg)](https://codecov.io/gh/speedcell4/aku) +[![Actions Status](https://github.com/speedcell4/aku/workflows/unit-tests/badge.svg)](https://github.com/speedcell4/aku/actions) -setup your argument parser speedily +An Annotation-driven ArgumentParser Generator -## Installation +## Install ```bash -python3.6 -m pip install aku --upgrade +python -m pip install aku --upgrade ``` ## Usage ```python -# file test_single_function.py +# tests/test_single_function.py import aku app = aku.Aku() @@ -29,9 +28,9 @@ def add(a: int, b: int = 2): app.run() ``` -then `aku` will automatically add argument option according to your function signature. +`aku` will automatically add argument options according to your function signature. -```shell +```bash ~ python tests/test_single_function.py --help usage: aku [-h] --a A [--b B] @@ -42,13 +41,13 @@ optional arguments: ``` -if you registered more than one functions, then sub-parser will be utilized. +Registering more than one function will make `aku` add them to sub-parser respectively (and lazily). ```python # file test_multi_functions.py import aku -app = aku.App() +app = aku.Aku() @app.register @@ -64,9 +63,9 @@ def say_hello(name: str): app.run() ``` -your argument parser interface will looks like, +Similarly, your argument parser interface looks like, -```shell +```bash ~ python tests/test_multi_functions.py --help usage: aku [-h] {add,say_hello} ... @@ -82,7 +81,4 @@ usage: aku say_hello [-h] --name NAME optional arguments: -h, --help show this help message and exit --name NAME name (default: None) - -~ python tests/test_multi_functions.py say_hello --name aku -hello aku ``` diff --git a/tests/test_hello.py b/tests/test_hello.py new file mode 100644 index 0000000..3983ffa --- /dev/null +++ b/tests/test_hello.py @@ -0,0 +1,2 @@ +def test_hello(): + assert 1 + 1 == 2 From aa9d58f6bb0045f19196c56d4d15d8f4d626fd00 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 14:06:15 +0900 Subject: [PATCH 02/72] Remove install requirements.txt --- .github/workflows/unit-tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index c281fce..bf51be0 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -22,7 +22,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r requirements.txt - name: Test with pytest run: | pip install pytest From d4896b246555cf884352a24ae50619b7e9e8fc90 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 14:16:55 +0900 Subject: [PATCH 03/72] Update(document) README.md --- README.md | 14 +++++++++----- VERSION | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ce7e94d..145cd87 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,21 @@ -# aku +# Aku [![PyPI Version](https://badge.fury.io/py/aku.svg)](https://pypi.org/project/aku/) [![Actions Status](https://github.com/speedcell4/aku/workflows/unit-tests/badge.svg)](https://github.com/speedcell4/aku/actions) An Annotation-driven ArgumentParser Generator +## Requirements + +Python 3.5 or higher + ## Install ```bash python -m pip install aku --upgrade ``` -## Usage +## Quick Start ```python # tests/test_single_function.py @@ -30,7 +34,7 @@ app.run() `aku` will automatically add argument options according to your function signature. -```bash +``` ~ python tests/test_single_function.py --help usage: aku [-h] --a A [--b B] @@ -63,9 +67,9 @@ def say_hello(name: str): app.run() ``` -Similarly, your argument parser interface looks like, +Similarly, your command line interface will look like, -```bash +``` ~ python tests/test_multi_functions.py --help usage: aku [-h] {add,say_hello} ... diff --git a/VERSION b/VERSION index 341cf11..7dff5b8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.2.0 \ No newline at end of file +0.2.1 \ No newline at end of file From 658479e9a472efef4e753fd172eca43d5c34b6a9 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 15:04:39 +0900 Subject: [PATCH 04/72] Add text-classification example [WIP] --- .github/workflows/unit-tests.yml | 4 +- aku/__init__.py | 23 +++++- examples/naive.py | 24 ------- examples/text_classification.py | 120 +++++++++++++++++++++++++++++++ requirements.txt | 3 + 5 files changed, 147 insertions(+), 27 deletions(-) delete mode 100644 examples/naive.py create mode 100644 examples/text_classification.py create mode 100644 requirements.txt diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index bf51be0..aa8a205 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -24,5 +24,5 @@ jobs: python -m pip install --upgrade pip - name: Test with pytest run: | - pip install pytest - pytest tests + python -m pip install -r requirements.txt + python -m pytest tests diff --git a/aku/__init__.py b/aku/__init__.py index 0a51ec0..f931470 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -1,3 +1,24 @@ +import sys +from typing import List, Tuple, Set, FrozenSet +from typing import Optional +from typing import Type +from typing import Union + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing import _SpecialForm + + Literal = _SpecialForm('Literal', doc='Literal') + from .aku import Aku -__version__ = '0.2.0' +__version__ = '0.2.1' + +__all__ = [ + 'Type', + 'List', 'Tuple', 'Set', 'FrozenSet', + 'Union', 'Optional', 'Literal', + + 'Aku', +] diff --git a/examples/naive.py b/examples/naive.py deleted file mode 100644 index 8d07604..0000000 --- a/examples/naive.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import TypeVar - -import aku - -app = aku.Aku() - - -@app.register -def add(x: int = 1, y: int = 2): - return x + y - - -@app.register -def sub(x: int = 3, y: int = 4): - return x - y - - -@app.register -def mul(encoder: TypeVar('enc', add, sub) = sub): - print(f'op() => {encoder()}') - - -if __name__ == '__main__': - app.run() diff --git a/examples/text_classification.py b/examples/text_classification.py new file mode 100644 index 0000000..10772c8 --- /dev/null +++ b/examples/text_classification.py @@ -0,0 +1,120 @@ +import torch +from torch import Tensor +from torch import nn +from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence +from torchtext.data import Batch +from torchtext.vocab import Vocab + +from aku import Literal, Type, Union, Tuple + + +class WordEmbedding(nn.Embedding): + def __init__(self, freeze: bool = False, *, word_vocab: Vocab) -> None: + num_embeddings, embedding_dim = word_vocab.vectors.size() + super(WordEmbedding, self).__init__( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, + padding_idx=word_vocab.stoi.get('', None), + _weight=word_vocab.vectors, + ) + self.weight.requires_grad = not freeze + + def forward(self, x): + if torch.is_tensor(x): + return super(WordEmbedding, self).forward(x) + elif not isinstance(x, PackedSequence): + x = pack_padded_sequence(*x, batch_first=True, enforce_sorted=False) + return x._replace(data=super(WordEmbedding, self).forward(x.data)) + + +class LstmEncoder(nn.Module): + def __init__(self, hidden_dim: int = 300, bias: bool = True, + num_layers: int = 2, dropout: float = 0.2, *, input_dim: int): + super(LstmEncoder, self).__init__() + + self.rnn = nn.LSTM( + input_size=input_dim, hidden_size=hidden_dim, bias=bias, + num_layers=num_layers, dropout=dropout, + batch_first=True, bidirectional=True, + ) + + self.output_dim = hidden_dim + + def forward(self, embedding: PackedSequence) -> Tensor: + _, (encoding, _) = self.rnn(embedding) + return encoding + + +class ConvEncoder(nn.Module): + def __init__(self, hidden_dim: int = 300, bias: bool = True, + kernel_sizes: Tuple[int, ...] = (3, 5, 7), dropout: float = 0.2, *, input_dim: int): + super(ConvEncoder, self).__init__() + + self.conv_layers = nn.ModuleList([ + nn.Conv2d( + in_channels=input_dim, out_channels=input_dim, bias=bias, + kernel_size=kernel_size, padding=kernel_size // 2, + ) + for kernel_size in kernel_sizes + ]) + self.fc = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(input_dim * len(kernel_sizes), input_dim * len(kernel_sizes)), + nn.ReLU(), + nn.Linear(input_dim * len(kernel_sizes), hidden_dim), + ) + + self.output_dim = hidden_dim + + def forward(self, embedding: Tensor) -> Tensor: + encoding = torch.cat([ + layer(embedding).max(dim=1) + for layer in self.conv_layers + ], dim=-1) + return self.fc(encoding) + + +class Classifier(nn.Sequential): + def __init__(self, input_dim: int, bias: bool = True, *, target_vocab: Vocab): + super(Classifier, self).__init__( + nn.Linear(input_dim, input_dim, bias=bias), + nn.ReLU(), + nn.Linear(input_dim, len(target_vocab), bias=bias), + ) + + +class TextClassification(nn.Module): + def __init__(self, + embedding: Type[WordEmbedding], + encoder: Union[Type[LstmEncoder], Type[ConvEncoder]] = Type[LstmEncoder], + decoder: Type[Classifier] = Type[Classifier], + reduction: Literal['sum', 'mean'] = 'mean', + *, + word_vocab: Vocab, target_vocab: Vocab, + ): + super(TextClassification, self).__init__() + + self.embedding = embedding(word_vocab=word_vocab) + self.encoder = encoder(input_dim=self.embedding.embedding_dim) + self.decoder = decoder(input_dim=self.encoder.output_dim, target_vocab=target_vocab) + + self.criterion = nn.CrossEntropyLoss( + ignore_index=target_vocab.stoi.get('', -100), + reduction=reduction, + ) + + def forward(self, batch: Batch) -> Tensor: + if isinstance(self.encoder, LstmEncoder): + embedding = self.embedding(batch.word) + encoding = self.encoder(embedding) + else: + embedding = self.embedding(batch.word[0]) + encoding = self.encoder(embedding) + return self.decoder(encoding) + + def fit(self, batch: Batch) -> Tensor: + logits = self(batch) + return self.criterion(logits, batch.target) + + def inference(self, batch: Batch) -> float: + prediction = self(batch).argmax(dim=-1) + return (prediction == batch.target).float().mean().item() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ce3263d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +pytest +torch +torchtext From d2f23271c579a605a61e3780f3acca709e1e1eba Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 15:08:31 +0900 Subject: [PATCH 05/72] Remove pytorch from requirements.txt --- aku/__init__.py | 10 +++------- requirements.txt | 4 +--- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/aku/__init__.py b/aku/__init__.py index f931470..0b02732 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -1,8 +1,6 @@ import sys from typing import List, Tuple, Set, FrozenSet -from typing import Optional -from typing import Type -from typing import Union +from typing import Optional, Type, Union if sys.version_info >= (3, 8): from typing import Literal @@ -11,14 +9,12 @@ Literal = _SpecialForm('Literal', doc='Literal') -from .aku import Aku +from aku.aku import Aku __version__ = '0.2.1' __all__ = [ - 'Type', + 'Type', 'Union', 'Optional', 'Literal', 'List', 'Tuple', 'Set', 'FrozenSet', - 'Union', 'Optional', 'Literal', - 'Aku', ] diff --git a/requirements.txt b/requirements.txt index ce3263d..55b033e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1 @@ -pytest -torch -torchtext +pytest \ No newline at end of file From 33ebd5764aa84e125cbb7c6a335d4a8b03106bd2 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 15:46:44 +0900 Subject: [PATCH 06/72] Fix bugs of text-classification example --- examples/text_classification.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/examples/text_classification.py b/examples/text_classification.py index 10772c8..6d9f25e 100644 --- a/examples/text_classification.py +++ b/examples/text_classification.py @@ -37,7 +37,7 @@ def __init__(self, hidden_dim: int = 300, bias: bool = True, batch_first=True, bidirectional=True, ) - self.output_dim = hidden_dim + self.encoding_dim = hidden_dim * (2 if self.rnn.bidirectional else 1) def forward(self, embedding: PackedSequence) -> Tensor: _, (encoding, _) = self.rnn(embedding) @@ -58,12 +58,10 @@ def __init__(self, hidden_dim: int = 300, bias: bool = True, ]) self.fc = nn.Sequential( nn.Dropout(dropout), - nn.Linear(input_dim * len(kernel_sizes), input_dim * len(kernel_sizes)), - nn.ReLU(), - nn.Linear(input_dim * len(kernel_sizes), hidden_dim), + nn.Linear(input_dim * len(kernel_sizes), hidden_dim * 2), ) - self.output_dim = hidden_dim + self.encoding_dim = self.fc[-1].out_features def forward(self, embedding: Tensor) -> Tensor: encoding = torch.cat([ @@ -84,18 +82,18 @@ def __init__(self, input_dim: int, bias: bool = True, *, target_vocab: Vocab): class TextClassification(nn.Module): def __init__(self, - embedding: Type[WordEmbedding], - encoder: Union[Type[LstmEncoder], Type[ConvEncoder]] = Type[LstmEncoder], - decoder: Type[Classifier] = Type[Classifier], + Embedding: Type[WordEmbedding], + Encoder: Union[Type[LstmEncoder], Type[ConvEncoder]] = Type[LstmEncoder], + Decoder: Type[Classifier] = Type[Classifier], reduction: Literal['sum', 'mean'] = 'mean', *, word_vocab: Vocab, target_vocab: Vocab, ): super(TextClassification, self).__init__() - self.embedding = embedding(word_vocab=word_vocab) - self.encoder = encoder(input_dim=self.embedding.embedding_dim) - self.decoder = decoder(input_dim=self.encoder.output_dim, target_vocab=target_vocab) + self.embedding = Embedding(word_vocab=word_vocab) + self.encoder = Encoder(input_dim=self.embedding.embedding_dim) + self.decoder = Decoder(input_dim=self.encoder.encoding_dim, target_vocab=target_vocab) self.criterion = nn.CrossEntropyLoss( ignore_index=target_vocab.stoi.get('', -100), @@ -105,10 +103,9 @@ def __init__(self, def forward(self, batch: Batch) -> Tensor: if isinstance(self.encoder, LstmEncoder): embedding = self.embedding(batch.word) - encoding = self.encoder(embedding) else: embedding = self.embedding(batch.word[0]) - encoding = self.encoder(embedding) + encoding = self.encoder(embedding) return self.decoder(encoding) def fit(self, batch: Batch) -> Tensor: @@ -118,3 +115,9 @@ def fit(self, batch: Batch) -> Tensor: def inference(self, batch: Batch) -> float: prediction = self(batch).argmax(dim=-1) return (prediction == batch.target).float().mean().item() + + +def train_text_classification( + model: Type[TextClassification], +): + raise NotImplementedError From 59a4cefaad116286332ef34826e4231e4dd9ca3a Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 15:55:54 +0900 Subject: [PATCH 07/72] Add optimizers and schedulers to text-classification example --- aku/__init__.py | 9 +++++++-- examples/text_classification.py | 34 +++++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/aku/__init__.py b/aku/__init__.py index 0b02732..eadb0a7 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -1,6 +1,11 @@ import sys -from typing import List, Tuple, Set, FrozenSet -from typing import Optional, Type, Union +from typing import FrozenSet +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import Union if sys.version_info >= (3, 8): from typing import Literal diff --git a/examples/text_classification.py b/examples/text_classification.py index 6d9f25e..af93fc5 100644 --- a/examples/text_classification.py +++ b/examples/text_classification.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from torch import nn +from torch import nn, optim from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence from torchtext.data import Batch from torchtext.vocab import Vocab @@ -117,7 +117,37 @@ def inference(self, batch: Batch) -> float: return (prediction == batch.target).float().mean().item() +def sgd(lr: float = 1e-3, momentum: float = 0.0, + weight_decay: float = 0.0, *, model: nn.Module): + return optim.SGD( + model.parameters(), lr=lr, + momentum=momentum, weight_decay=weight_decay, + ) + + +def adam(lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999, + weight_decay: float = 0.0, *, model: nn.Module): + return optim.Adam( + model.parameters(), lr=lr, + betas=(beta1, beta2), weight_decay=weight_decay, + ) + + +def exponential(gamma: float = 0.98, *, optimizer: optim.Optimizer): + return optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, gamma=gamma, + ) + + +def half_life(half_life_epoch: int, *, optimizer: optim.Optimizer): + return optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, gamma=0.5 ** (1 / half_life_epoch), + ) + + def train_text_classification( - model: Type[TextClassification], + Model: Type[TextClassification], + Optimizer: Union[Type[sgd], Type[adam]] = Type[adam], + Scheduler: Union[Type[exponential], Type[half_life]] = Type[half_life], ): raise NotImplementedError From 0c3ad4943d0969cad9d75f62ffa330e6fdfb8afb Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 16:03:16 +0900 Subject: [PATCH 08/72] Reformat text-classification example --- examples/text_classification.py | 63 ++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/examples/text_classification.py b/examples/text_classification.py index af93fc5..d361378 100644 --- a/examples/text_classification.py +++ b/examples/text_classification.py @@ -9,7 +9,11 @@ class WordEmbedding(nn.Embedding): - def __init__(self, freeze: bool = False, *, word_vocab: Vocab) -> None: + def __init__(self, + freeze: bool = False, + *, + word_vocab: Vocab, + ): num_embeddings, embedding_dim = word_vocab.vectors.size() super(WordEmbedding, self).__init__( num_embeddings=num_embeddings, embedding_dim=embedding_dim, @@ -27,8 +31,14 @@ def forward(self, x): class LstmEncoder(nn.Module): - def __init__(self, hidden_dim: int = 300, bias: bool = True, - num_layers: int = 2, dropout: float = 0.2, *, input_dim: int): + def __init__(self, + hidden_dim: int = 300, + bias: bool = True, + num_layers: int = 2, + dropout: float = 0.2, + *, + input_dim: int, + ): super(LstmEncoder, self).__init__() self.rnn = nn.LSTM( @@ -45,8 +55,14 @@ def forward(self, embedding: PackedSequence) -> Tensor: class ConvEncoder(nn.Module): - def __init__(self, hidden_dim: int = 300, bias: bool = True, - kernel_sizes: Tuple[int, ...] = (3, 5, 7), dropout: float = 0.2, *, input_dim: int): + def __init__(self, + hidden_dim: int = 300, + bias: bool = True, + kernel_sizes: Tuple[int, ...] = (3, 5, 7), + dropout: float = 0.2, + *, + input_dim: int, + ): super(ConvEncoder, self).__init__() self.conv_layers = nn.ModuleList([ @@ -72,7 +88,12 @@ def forward(self, embedding: Tensor) -> Tensor: class Classifier(nn.Sequential): - def __init__(self, input_dim: int, bias: bool = True, *, target_vocab: Vocab): + def __init__(self, + input_dim: int, + bias: bool = True, + *, + target_vocab: Vocab, + ): super(Classifier, self).__init__( nn.Linear(input_dim, input_dim, bias=bias), nn.ReLU(), @@ -117,29 +138,45 @@ def inference(self, batch: Batch) -> float: return (prediction == batch.target).float().mean().item() -def sgd(lr: float = 1e-3, momentum: float = 0.0, - weight_decay: float = 0.0, *, model: nn.Module): +def sgd(lr: float = 1e-3, + momentum: float = 0.0, + weight_decay: float = 0.0, + *, + model: nn.Module, + ): return optim.SGD( model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, ) -def adam(lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999, - weight_decay: float = 0.0, *, model: nn.Module): +def adam(lr: float = 1e-3, + beta1: float = 0.9, + beta2: float = 0.999, + weight_decay: + float = 0.0, + *, + model: nn.Module, + ): return optim.Adam( model.parameters(), lr=lr, betas=(beta1, beta2), weight_decay=weight_decay, ) -def exponential(gamma: float = 0.98, *, optimizer: optim.Optimizer): +def exponential(gamma: float = 0.98, + *, + optimizer: optim.Optimizer, + ): return optim.lr_scheduler.ExponentialLR( optimizer=optimizer, gamma=gamma, ) -def half_life(half_life_epoch: int, *, optimizer: optim.Optimizer): +def half_life(half_life_epoch: int, + *, + optimizer: optim.Optimizer, + ): return optim.lr_scheduler.ExponentialLR( optimizer=optimizer, gamma=0.5 ** (1 / half_life_epoch), ) @@ -148,6 +185,6 @@ def half_life(half_life_epoch: int, *, optimizer: optim.Optimizer): def train_text_classification( Model: Type[TextClassification], Optimizer: Union[Type[sgd], Type[adam]] = Type[adam], - Scheduler: Union[Type[exponential], Type[half_life]] = Type[half_life], + Scheduler: Union[Type[exponential], Type[half_life]] = Type[half_life] ): raise NotImplementedError From 761ca3f40ac3d3bc5aa20ffe491ff7cf78772d02 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 15:35:37 +0900 Subject: [PATCH 09/72] Init parsing-fn module --- aku/parsing_fn.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 aku/parsing_fn.py diff --git a/aku/parsing_fn.py b/aku/parsing_fn.py new file mode 100644 index 0000000..51c8377 --- /dev/null +++ b/aku/parsing_fn.py @@ -0,0 +1,35 @@ +from typing import get_type_hints, Type, Callable, Any, Dict + +ParsingFn = Callable[[str], Any] +parsing_fn_registry: Dict[Any, ParsingFn] = {} + + +def register_parsing_fn(fn: ParsingFn) -> ParsingFn: + retype = get_type_hints(fn)['return'] + assert retype not in parsing_fn_registry, \ + f'the parsing function of {retype} is already registered ({parsing_fn_registry[retype]})' + + parsing_fn_registry[retype] = fn + return fn + + +def get_parsing_fn(retype: Type) -> ParsingFn: + return parsing_fn_registry.get(retype, retype) + + +@register_parsing_fn +def str2bool(option_string: str) -> bool: + option_string = option_string.strip().lower() + if option_string in ('1', 't', 'true', 'y', 'yes'): + return True + if option_string in ('0', 'f', 'false', 'n', 'no'): + return False + raise ValueError(f'{option_string} is not a boolean value.') + + +@register_parsing_fn +def str2none(option_string: str) -> type(None): + option_string = option_string.strip().lower() + if option_string in ('nil', 'null', 'none'): + return True + raise ValueError(f'{option_string} is not a null value.') From c471489098fdaa91a3c8af7abb33aaaad7050b89 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 16:14:46 +0900 Subject: [PATCH 10/72] Update text-classification example --- examples/text_classification.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/text_classification.py b/examples/text_classification.py index d361378..dc699e3 100644 --- a/examples/text_classification.py +++ b/examples/text_classification.py @@ -36,6 +36,7 @@ def __init__(self, bias: bool = True, num_layers: int = 2, dropout: float = 0.2, + bidirectional: bool = True, *, input_dim: int, ): @@ -44,10 +45,10 @@ def __init__(self, self.rnn = nn.LSTM( input_size=input_dim, hidden_size=hidden_dim, bias=bias, num_layers=num_layers, dropout=dropout, - batch_first=True, bidirectional=True, + batch_first=True, bidirectional=bidirectional, ) - self.encoding_dim = hidden_dim * (2 if self.rnn.bidirectional else 1) + self.encoding_dim = self.rnn.hidden_dim * (2 if self.rnn.bidirectional else 1) def forward(self, embedding: PackedSequence) -> Tensor: _, (encoding, _) = self.rnn(embedding) @@ -173,7 +174,7 @@ def exponential(gamma: float = 0.98, ) -def half_life(half_life_epoch: int, +def half_life(half_life_epoch: int = 50, *, optimizer: optim.Optimizer, ): From c3785ade2b1529424ad85e9a06ec720f0f7b5731 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 16:16:25 +0900 Subject: [PATCH 11/72] Remove VERSION file --- VERSION | 1 - setup.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) delete mode 100644 VERSION diff --git a/VERSION b/VERSION deleted file mode 100644 index 7dff5b8..0000000 --- a/VERSION +++ /dev/null @@ -1 +0,0 @@ -0.2.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 00694b0..7c24d72 100644 --- a/setup.py +++ b/setup.py @@ -7,10 +7,9 @@ name='aku', description='Annotation-driven ArgumentParser Generator', long_description=long_description, - version=open('VERSION', mode='r').read(), + version='0.2.1', packages=find_packages(), classifiers=[ - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3 :: Only', 'Development Status :: 2 - Pre-Alpha', 'License :: OSI Approved :: MIT License', @@ -22,5 +21,5 @@ license='MIT', author='Izen', author_email='speedcell4@gmail.com', - python_requires='>=3.7', + python_requires='>=3.5', ) From f24f64c3972b0629c125761e88361d421b38082d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 16:20:14 +0900 Subject: [PATCH 12/72] Init single-function and multiple-functions example --- README.md | 18 +++++++++--------- examples/multiple_functions.py | 16 ++++++++++++++++ examples/single_function.py | 11 +++++++++++ 3 files changed, 36 insertions(+), 9 deletions(-) create mode 100644 examples/multiple_functions.py create mode 100644 examples/single_function.py diff --git a/README.md b/README.md index 145cd87..717b010 100644 --- a/README.md +++ b/README.md @@ -19,17 +19,17 @@ python -m pip install aku --upgrade ```python # tests/test_single_function.py -import aku +from aku import Aku -app = aku.Aku() +aku = Aku() -@app.register +@aku.register def add(a: int, b: int = 2): print(f'{a} + {b} => {a + b}') -app.run() +aku.run() ``` `aku` will automatically add argument options according to your function signature. @@ -49,22 +49,22 @@ Registering more than one function will make `aku` add them to sub-parser respec ```python # file test_multi_functions.py -import aku +from aku import Aku -app = aku.Aku() +aku = Aku() -@app.register +@aku.register def add(a: int, b: int = 2): print(f'{a} + {b} => {a + b}') -@app.register +@aku.register def say_hello(name: str): print(f'hello {name}') -app.run() +aku.run() ``` Similarly, your command line interface will look like, diff --git a/examples/multiple_functions.py b/examples/multiple_functions.py new file mode 100644 index 0000000..41f38bf --- /dev/null +++ b/examples/multiple_functions.py @@ -0,0 +1,16 @@ +from aku import Aku + +aku = Aku() + + +@aku.register +def add(a: int, b: int = 2): + print(f'{a} + {b} => {a + b}') + + +@aku.register +def say_hello(name: str): + print(f'hello {name}') + + +aku.run() diff --git a/examples/single_function.py b/examples/single_function.py new file mode 100644 index 0000000..355ee16 --- /dev/null +++ b/examples/single_function.py @@ -0,0 +1,11 @@ +from aku import Aku + +aku = Aku() + + +@aku.register +def add(a: int, b: int = 2): + print(f'{a} + {b} => {a + b}') + + +aku.run() From 4ea715dea2c4d0958b9f54fdc62bc18569178de6 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 30 Nov 2019 17:52:48 +0900 Subject: [PATCH 13/72] Add str2set, str2list, and str2tuple --- aku/parsing_fn.py | 75 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/aku/parsing_fn.py b/aku/parsing_fn.py index 51c8377..37eacb0 100644 --- a/aku/parsing_fn.py +++ b/aku/parsing_fn.py @@ -1,20 +1,42 @@ -from typing import get_type_hints, Type, Callable, Any, Dict +from typing import get_type_hints, Type, Callable, Any, Dict, Set, Union, List, Tuple -ParsingFn = Callable[[str], Any] +ParsingFn = Union[Callable[[str], Any], Type] parsing_fn_registry: Dict[Any, ParsingFn] = {} +ParsingFnGen = Callable[[Any], ParsingFn] +parsing_fn_gen_registry: Dict[Any, ParsingFnGen] = {} + def register_parsing_fn(fn: ParsingFn) -> ParsingFn: retype = get_type_hints(fn)['return'] assert retype not in parsing_fn_registry, \ - f'the parsing function of {retype} is already registered ({parsing_fn_registry[retype]})' + f'the parsing function of {retype} is already registered ' \ + f'({parsing_fn_registry[retype]})' parsing_fn_registry[retype] = fn return fn +def register_parsing_fn_gen(origin: Any): + def _register_parsing_fn_gen(fn: ParsingFnGen) -> ParsingFnGen: + assert origin not in parsing_fn_gen_registry, \ + f'the parsing function generator of {origin} is already registered ' \ + f'({parsing_fn_gen_registry[origin]})' + + parsing_fn_gen_registry[origin] = fn + return fn + + return _register_parsing_fn_gen + + def get_parsing_fn(retype: Type) -> ParsingFn: - return parsing_fn_registry.get(retype, retype) + if retype in parsing_fn_registry: + return parsing_fn_registry[retype] + origin = getattr(retype, '__origin__', None) + args, *_ = getattr(retype, '__args__', (None,)) + if origin in parsing_fn_gen_registry: + return parsing_fn_gen_registry[origin](args) + return retype @register_parsing_fn @@ -24,7 +46,7 @@ def str2bool(option_string: str) -> bool: return True if option_string in ('0', 'f', 'false', 'n', 'no'): return False - raise ValueError(f'{option_string} is not a boolean value.') + raise ValueError(f'{option_string} is not a {bool} value.') @register_parsing_fn @@ -32,4 +54,45 @@ def str2none(option_string: str) -> type(None): option_string = option_string.strip().lower() if option_string in ('nil', 'null', 'none'): return True - raise ValueError(f'{option_string} is not a null value.') + raise ValueError(f'{option_string} is not a {type(None)} value.') + + +@register_parsing_fn_gen(set) +def str2set(retype: Any, sep: str = ',') -> ParsingFn: + @register_parsing_fn + def _str2set(option_string: str) -> Set[retype]: + option_string = option_string.strip() + if option_string.startswith('{') and option_string.endswith('}'): + return set(map(get_parsing_fn(retype), option_string[1:-1].split(sep))) + raise ValueError(f'{option_string} is not a Set[{retype}] value.') + + return _str2set + + +@register_parsing_fn_gen(list) +def str2list(retype: Any, sep: str = ',') -> ParsingFn: + @register_parsing_fn + def _str2list(option_string: str) -> List[retype]: + option_string = option_string.strip() + if option_string.startswith('[') and option_string.endswith(']'): + return list(map(get_parsing_fn(retype), option_string[1:-1].split(sep))) + raise ValueError(f'{option_string} is not a List[{retype}] value.') + + return _str2list + + +@register_parsing_fn_gen(tuple) +def str2tuple(retype: Any, sep: str = ',') -> ParsingFn: + @register_parsing_fn + def _str2tuple(option_string: str) -> Tuple[retype, ...]: + option_string = option_string.strip() + if option_string.startswith('(') and option_string.endswith(')'): + return tuple(map(get_parsing_fn(retype), option_string[1:-1].split(sep))) + raise ValueError(f'{option_string} is not a Tuple[{retype}, ...] value.') + + return _str2tuple + + +if __name__ == '__main__': + print(get_parsing_fn(List[List[int]])('[[23,23],[34]]')) + # print(get_parsing_fn(List[Set[int]])('[{23,23},{34}]')) From c10391d6a91f8466cb72eedd4c5a63e76c049276 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 13 Apr 2020 23:03:27 +0900 Subject: [PATCH 14/72] Init: Remove them all and start from scratch --- .github/workflows/unit-tests.yml | 17 +-- README.md | 3 +- aku/__init__.py | 25 ---- aku/add_option.py | 164 --------------------- aku/aku.py | 58 -------- aku/annotation.py | 237 ------------------------------- aku/parsing_fn.py | 98 ------------- examples/multiple_functions.py | 16 --- examples/single_function.py | 11 -- examples/text_classification.py | 191 ------------------------- requirements.txt | 1 - setup.py | 24 ++-- 12 files changed, 21 insertions(+), 824 deletions(-) delete mode 100644 aku/add_option.py delete mode 100644 aku/aku.py delete mode 100644 aku/annotation.py delete mode 100644 aku/parsing_fn.py delete mode 100644 examples/multiple_functions.py delete mode 100644 examples/single_function.py delete mode 100644 examples/text_classification.py delete mode 100644 requirements.txt diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index aa8a205..fbecce8 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -5,24 +5,19 @@ on: [push] jobs: build: - name: Test on Python ${{ matrix.python-version }} and ${{ matrix.os }} - runs-on: ${{ matrix.os }} - strategy: - max-parallel: 4 - matrix: - os: [ubuntu-latest, macOS-latest, windows-latest] - python-version: [3.5, 3.6, 3.7, 3.8] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} + - name: Set up Python 3.7 uses: actions/setup-python@v1 with: - python-version: ${{ matrix.python-version }} + python-version: 3.7 - name: Install dependencies run: | python -m pip install --upgrade pip + python -m pip install torch + python -m pip install -e '.[dev]' - name: Test with pytest run: | - python -m pip install -r requirements.txt - python -m pytest tests + python -m pytest tests \ No newline at end of file diff --git a/README.md b/README.md index 717b010..caca0de 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,12 @@ # Aku -[![PyPI Version](https://badge.fury.io/py/aku.svg)](https://pypi.org/project/aku/) [![Actions Status](https://github.com/speedcell4/aku/workflows/unit-tests/badge.svg)](https://github.com/speedcell4/aku/actions) An Annotation-driven ArgumentParser Generator ## Requirements -Python 3.5 or higher +Python 3.7 or higher ## Install diff --git a/aku/__init__.py b/aku/__init__.py index eadb0a7..e69de29 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -1,25 +0,0 @@ -import sys -from typing import FrozenSet -from typing import List -from typing import Optional -from typing import Set -from typing import Tuple -from typing import Type -from typing import Union - -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing import _SpecialForm - - Literal = _SpecialForm('Literal', doc='Literal') - -from aku.aku import Aku - -__version__ = '0.2.1' - -__all__ = [ - 'Type', 'Union', 'Optional', 'Literal', - 'List', 'Tuple', 'Set', 'FrozenSet', - 'Aku', -] diff --git a/aku/add_option.py b/aku/add_option.py deleted file mode 100644 index 6d3217e..0000000 --- a/aku/add_option.py +++ /dev/null @@ -1,164 +0,0 @@ -from argparse import Action, ArgumentParser, Namespace -from typing import Any, Optional - -from aku import annotation as ty -from aku.annotation import annotation_iter - -DONE = '_DONE' - - -def get_name(name, prefix): - if prefix is None: - return name - return f'{prefix}_{name}' - - -def cat_prefix(prefix1, prefix2): - if prefix1 is None: - return prefix2 - if prefix2 is None: - return prefix1 - return f'{prefix1}_{prefix2}' - - -def add_primitive(parser: ArgumentParser, prefix: str, name: str, annotation: Any, default: Any, delays): - dest = get_name(name, prefix) - parser.add_argument( - f'--{dest}', dest=dest, help=name, default=default, - type=ty.primitive_type(annotation), metavar=ty.primitive_metavar(annotation), - ) - return dest - - -class ListAction(Action): - def __call__(self, action_parser: ArgumentParser, namespace: Namespace, values, option_string) -> None: - setattr(namespace, self.dest, getattr(self, self.dest, []) + [values]) - - -def add_list(parser: ArgumentParser, prefix: str, name: str, annotation: Any, default: Any, delays): - dest = get_name(name, prefix) - - parser.add_argument( - f'--{dest}', dest=dest, help=name, default=default, - type=ty.list_type(annotation), metavar=ty.list_metavar(annotation), - action=ListAction, required=True, - ) - - return dest - - -class TupleAction(Action): - def __call__(self, action_parser: ArgumentParser, namespace: Namespace, values, option_string) -> None: - setattr(namespace, self.dest, getattr(self, self.dest, ()) + (values,)) - - -def add_tuple(parser: ArgumentParser, prefix: str, name: str, annotation: Any, default: Any, delays): - dest = get_name(name, prefix) - - parser.add_argument( - f'--{dest}', dest=dest, help=name, default=default, - type=ty.tuple_type(annotation), metavar=ty.tuple_metavar(annotation), - action=TupleAction, required=True, - ) - - return dest - - -def add_value_union(parser: ArgumentParser, prefix: str, name: str, annotation: Any, default: Any, delays): - dest = get_name(name, prefix) - - parser.add_argument( - f'--{dest}', dest=dest, help=name, default=default, - type=ty.value_union_type(annotation), metavar=ty.value_union_metavar(annotation), - choices=annotation, - ) - - return dest - - -def add_type_union(parser: ArgumentParser, prefix: str, name: str, annotation: Any, default: Any, delays): - dest = get_name(name, prefix) - obj_dest = f'@{dest}' - fn_mapping = { - a.__name__: a - for a in ty.type_union_args(annotation) - } - - class TypeUnionAction(Action): - def __call__(self, action_parser: ArgumentParser, namespace: Namespace, values, option_string) -> None: - if not getattr(self, DONE, False): - setattr(self, DONE, True) - setattr(namespace, obj_dest, values) - - func = fn_mapping[values] - add_function( - parser=parser, prefix=prefix, - name=dest, annotation=func, default=None, delays=delays, - ) - - action_parser.set_defaults(**{dest: func}) - action_parser.parse_known_args(namespace=namespace) - - parser.add_argument( - f'--{dest}', dest=obj_dest, help=name, default=default if isinstance(default, str) else default.__name__, - type=ty.type_union_type(annotation), metavar=ty.type_union_metavar(annotation), - action=TypeUnionAction, choices=tuple(fn_mapping.keys()), - ) - - return dest - - -def add_type_var(parser: ArgumentParser, prefix: str, name: str, annotation: Any, default: Any, delays): - dest = get_name(name, prefix) - obj_dest = f'@{dest}' - fn_mapping = { - a.__name__: a - for a in ty.type_var_args(annotation) - } - - class TypeVarAction(Action): - def __call__(self, action_parser: ArgumentParser, namespace: Namespace, values, option_string) -> None: - if not getattr(self, DONE, False): - setattr(self, DONE, True) - setattr(namespace, obj_dest, values) - - func = fn_mapping[values] - add_function( - parser=parser, prefix=cat_prefix(prefix, annotation.__name__), - name=dest, annotation=func, default=None, delays=delays, - ) - - action_parser.set_defaults(**{dest: func}) - action_parser.parse_known_args(namespace=namespace) - - parser.add_argument( - f'--{dest}', dest=obj_dest, help=name, default=default if isinstance(default, str) else default.__name__, - type=ty.type_union_type(annotation), metavar=ty.type_union_metavar(annotation), - action=TypeVarAction, choices=tuple(fn_mapping.keys()), - ) - - return dest - - -def add_function(parser: ArgumentParser, prefix: Optional[str], name: Optional[str], annotation: Any, default: Any, - delays): - for nn, aa, dd in annotation_iter(annotation): - if ty.is_list(aa): - add_op = add_list - elif ty.is_tuple(aa): - add_op = add_tuple - elif ty.is_value_union(aa): - add_op = add_value_union - elif ty.is_type_union(aa): - add_op = add_type_union - elif ty.is_type_var(aa): - add_op = add_type_var - else: - add_op = add_primitive - - key = add_op( - parser=parser, prefix=prefix, - delays=delays, name=nn, annotation=aa, default=dd, - ) - if name is not None: - delays.append((name, nn, key)) diff --git a/aku/aku.py b/aku/aku.py deleted file mode 100644 index 0969829..0000000 --- a/aku/aku.py +++ /dev/null @@ -1,58 +0,0 @@ -import sys -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser -from functools import partial - -from aku import add_option - - -class Aku(object): - def __init__(self, prog: str = None, usage: str = None, description: str = None): - super(Aku, self).__init__() - self.parser = ArgumentParser( - prog=prog, usage=usage, description=description, - formatter_class=ArgumentDefaultsHelpFormatter, - ) - self._funcs = {} - - def register(self, func): - self._funcs[func.__name__] = func - return func - - def run(self, args=None, namespace=None): - self.delays = [] - - if len(self._funcs) == 1: - fn = list(self._funcs.values())[0] - add_option.add_function( - parser=self.parser, prefix=None, name=None, - annotation=fn, default=None, delays=self.delays, - ) - else: - subparsers = self.parser.add_subparsers() - parsers = { - name: subparsers.add_parser(name) - for name, _ in self._funcs.items() - } - if sys.argv.__len__() > 1 and sys.argv[1] in parsers: - fn = self._funcs[sys.argv[1]] - add_option.add_function( - parser=parsers[sys.argv[1]], prefix=None, name=None, - annotation=fn, default=None, delays=self.delays, - ) - - args, _ = self.parser.parse_known_args(args, namespace) - self.raw_args = {k: v for k, v in vars(args).items()} - self.args = {k: v for k, v in vars(args).items()} - - for dest, name, key in reversed(self.delays): - self.args[dest] = partial(self.args[dest], **{name: self.args[key]}) - del self.args[key] - - obj_dest = f'@{dest}' - if obj_dest in self.args: - del self.args[obj_dest] - if dest in self.raw_args and obj_dest in self.raw_args: - self.raw_args[dest] = self.raw_args[obj_dest] - del self.raw_args[obj_dest] - - return fn(**self.args) diff --git a/aku/annotation.py b/aku/annotation.py deleted file mode 100644 index cef6058..0000000 --- a/aku/annotation.py +++ /dev/null @@ -1,237 +0,0 @@ -import inspect -from argparse import SUPPRESS -from itertools import zip_longest -from typing import Optional, TypeVar, Union - - -def add(x: int, y: int = None) -> int: - return x + y - - -class Po(object): - def __init__(self, x: int = 2, y: int = None): - self.x = x - self.y = y - - -class Rec(object): - def __init__(self, x: int, y: int, w: int, h: int): - self.x = x - self.y = y - self.w = w - self.h = h - - -def is_list(ty): - return getattr(ty, '_name', None) == 'List' - - -def is_tuple(ty): - if getattr(ty, '_name', None) == 'Tuple': - args = getattr(ty, '__args__') - if len(args) == 2 and args[-1] is ...: - return True - return False - - -def is_type_var(ty): - return isinstance(ty, TypeVar) - - -def is_value_union(ty): - if isinstance(ty, tuple): - if not any(callable(t) for t in ty): - return True - return False - - -def is_type_union(ty): - if getattr(ty, '__origin__', None) is Union: - args = getattr(ty, '__args__') - if type(None) in args: - return False - return True - return False - - -def is_optional(ty): - if getattr(ty, '__origin__', None) is Union: - args = getattr(ty, '__args__') - if type(None) in args: - return True - return False - - -def list_arg(ty): - return getattr(ty, '__args__')[0] - - -def tuple_arg(ty): - return getattr(ty, '__args__')[0] - - -def type_var_args(ty): - return getattr(ty, '__constraints__') - - -def value_union_arg(ty): - return type(ty[0]) - - -def type_union_args(ty): - return getattr(ty, '__args__') - - -def optional_args(ty): - ret = Union[getattr(ty, '__args__')[:-1]] - if is_type_union(ret): - return type_union_args(ret) - return ret - - -def annotation_iter(func): - spec = inspect.getfullargspec(func) - args = spec.args - if inspect.ismethod(func) or inspect.isclass(func): - args = spec.args[1:] - - names, annotations, defaults = [], [], [] - for name, default in zip_longest(reversed(args), reversed(spec.defaults or []), fillvalue=SUPPRESS): - names.append(name) - annotation = spec.annotations.get(name, str) - annotations.append(Optional[annotation] if default is None else annotation) - defaults.append(default) - - return zip(names[::-1], annotations[::-1], defaults[::-1]) - - -_types = {} - - -def register_type(func): - ty = inspect.getfullargspec(func).annotations.get('return') - if ty not in _types: - _types[ty] = func - return func - - -@register_type -def str2none(option_string: str) -> type(None): - option_string = option_string.lower() - if option_string in ('nil', 'none', 'null'): - return None - raise ValueError(f'{option_string} is not a null value') - - -@register_type -def str2bool(option_string: str) -> bool: - option_string = option_string.lower() - if option_string in ('1', 'y', 'yes', 't', 'true'): - return True - if option_string in ('0', 'n', 'no', 'f', 'false'): - return False - raise ValueError(f'{option_string} is not a boolean value') - - -def combine_types(*fns): - @register_type - def type_fn(option_string: str) -> Union[fns]: - for fn in fns: - try: - return fn(option_string) - except ValueError: - pass - raise ValueError - - return type_fn - - -def get_type(ty): - if is_list(ty): - return list_type(ty) - if is_tuple(ty): - return tuple_type(ty) - if is_value_union(ty): - return value_union_type(ty) - if is_type_union(ty): - return type_union_type(ty) - if is_type_var(ty): - return type_var_type(ty) - if is_optional(ty): - return optional_type(ty) - return primitive_type(ty) - - -def primitive_type(ty): - return _types.get(ty, ty) - - -def list_type(ty): - return get_type(list_arg(ty)) - - -def tuple_type(ty): - return get_type(tuple_arg(ty)) - - -def value_union_type(ty): - return get_type(value_union_arg(ty)) - - -def type_union_type(ty): - return str - - -def type_var_type(ty): - return str - - -def optional_type(ty): - return combine_types( - get_type(optional_args(ty)), - get_type(type(None)) - ) - - -def get_metavar(ty): - if is_list(ty): - return list_metavar(ty) - if is_tuple(ty): - return type_union_metavar(ty) - if is_value_union(ty): - return value_union_metavar(ty) - if is_type_union(ty): - return type_union_metavar(ty) - if is_type_var(ty): - return type_var_metavar(ty) - if is_optional(ty): - return optional_metavar(ty) - return primitive_metavar(ty) - - -def list_metavar(ty): - return f'[{get_metavar(list_arg(ty))}]' - - -def tuple_metavar(ty): - return f'({get_metavar(tuple_arg(ty))})' - - -def value_union_metavar(ty): - return None - - -def type_union_metavar(ty): - return None - - -def type_var_metavar(ty): - return None - - -def optional_metavar(ty): - return f'{get_metavar(optional_args(ty))}?' - - -def primitive_metavar(ty): - return ty.__name__ diff --git a/aku/parsing_fn.py b/aku/parsing_fn.py deleted file mode 100644 index 37eacb0..0000000 --- a/aku/parsing_fn.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import get_type_hints, Type, Callable, Any, Dict, Set, Union, List, Tuple - -ParsingFn = Union[Callable[[str], Any], Type] -parsing_fn_registry: Dict[Any, ParsingFn] = {} - -ParsingFnGen = Callable[[Any], ParsingFn] -parsing_fn_gen_registry: Dict[Any, ParsingFnGen] = {} - - -def register_parsing_fn(fn: ParsingFn) -> ParsingFn: - retype = get_type_hints(fn)['return'] - assert retype not in parsing_fn_registry, \ - f'the parsing function of {retype} is already registered ' \ - f'({parsing_fn_registry[retype]})' - - parsing_fn_registry[retype] = fn - return fn - - -def register_parsing_fn_gen(origin: Any): - def _register_parsing_fn_gen(fn: ParsingFnGen) -> ParsingFnGen: - assert origin not in parsing_fn_gen_registry, \ - f'the parsing function generator of {origin} is already registered ' \ - f'({parsing_fn_gen_registry[origin]})' - - parsing_fn_gen_registry[origin] = fn - return fn - - return _register_parsing_fn_gen - - -def get_parsing_fn(retype: Type) -> ParsingFn: - if retype in parsing_fn_registry: - return parsing_fn_registry[retype] - origin = getattr(retype, '__origin__', None) - args, *_ = getattr(retype, '__args__', (None,)) - if origin in parsing_fn_gen_registry: - return parsing_fn_gen_registry[origin](args) - return retype - - -@register_parsing_fn -def str2bool(option_string: str) -> bool: - option_string = option_string.strip().lower() - if option_string in ('1', 't', 'true', 'y', 'yes'): - return True - if option_string in ('0', 'f', 'false', 'n', 'no'): - return False - raise ValueError(f'{option_string} is not a {bool} value.') - - -@register_parsing_fn -def str2none(option_string: str) -> type(None): - option_string = option_string.strip().lower() - if option_string in ('nil', 'null', 'none'): - return True - raise ValueError(f'{option_string} is not a {type(None)} value.') - - -@register_parsing_fn_gen(set) -def str2set(retype: Any, sep: str = ',') -> ParsingFn: - @register_parsing_fn - def _str2set(option_string: str) -> Set[retype]: - option_string = option_string.strip() - if option_string.startswith('{') and option_string.endswith('}'): - return set(map(get_parsing_fn(retype), option_string[1:-1].split(sep))) - raise ValueError(f'{option_string} is not a Set[{retype}] value.') - - return _str2set - - -@register_parsing_fn_gen(list) -def str2list(retype: Any, sep: str = ',') -> ParsingFn: - @register_parsing_fn - def _str2list(option_string: str) -> List[retype]: - option_string = option_string.strip() - if option_string.startswith('[') and option_string.endswith(']'): - return list(map(get_parsing_fn(retype), option_string[1:-1].split(sep))) - raise ValueError(f'{option_string} is not a List[{retype}] value.') - - return _str2list - - -@register_parsing_fn_gen(tuple) -def str2tuple(retype: Any, sep: str = ',') -> ParsingFn: - @register_parsing_fn - def _str2tuple(option_string: str) -> Tuple[retype, ...]: - option_string = option_string.strip() - if option_string.startswith('(') and option_string.endswith(')'): - return tuple(map(get_parsing_fn(retype), option_string[1:-1].split(sep))) - raise ValueError(f'{option_string} is not a Tuple[{retype}, ...] value.') - - return _str2tuple - - -if __name__ == '__main__': - print(get_parsing_fn(List[List[int]])('[[23,23],[34]]')) - # print(get_parsing_fn(List[Set[int]])('[{23,23},{34}]')) diff --git a/examples/multiple_functions.py b/examples/multiple_functions.py deleted file mode 100644 index 41f38bf..0000000 --- a/examples/multiple_functions.py +++ /dev/null @@ -1,16 +0,0 @@ -from aku import Aku - -aku = Aku() - - -@aku.register -def add(a: int, b: int = 2): - print(f'{a} + {b} => {a + b}') - - -@aku.register -def say_hello(name: str): - print(f'hello {name}') - - -aku.run() diff --git a/examples/single_function.py b/examples/single_function.py deleted file mode 100644 index 355ee16..0000000 --- a/examples/single_function.py +++ /dev/null @@ -1,11 +0,0 @@ -from aku import Aku - -aku = Aku() - - -@aku.register -def add(a: int, b: int = 2): - print(f'{a} + {b} => {a + b}') - - -aku.run() diff --git a/examples/text_classification.py b/examples/text_classification.py deleted file mode 100644 index dc699e3..0000000 --- a/examples/text_classification.py +++ /dev/null @@ -1,191 +0,0 @@ -import torch -from torch import Tensor -from torch import nn, optim -from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence -from torchtext.data import Batch -from torchtext.vocab import Vocab - -from aku import Literal, Type, Union, Tuple - - -class WordEmbedding(nn.Embedding): - def __init__(self, - freeze: bool = False, - *, - word_vocab: Vocab, - ): - num_embeddings, embedding_dim = word_vocab.vectors.size() - super(WordEmbedding, self).__init__( - num_embeddings=num_embeddings, embedding_dim=embedding_dim, - padding_idx=word_vocab.stoi.get('', None), - _weight=word_vocab.vectors, - ) - self.weight.requires_grad = not freeze - - def forward(self, x): - if torch.is_tensor(x): - return super(WordEmbedding, self).forward(x) - elif not isinstance(x, PackedSequence): - x = pack_padded_sequence(*x, batch_first=True, enforce_sorted=False) - return x._replace(data=super(WordEmbedding, self).forward(x.data)) - - -class LstmEncoder(nn.Module): - def __init__(self, - hidden_dim: int = 300, - bias: bool = True, - num_layers: int = 2, - dropout: float = 0.2, - bidirectional: bool = True, - *, - input_dim: int, - ): - super(LstmEncoder, self).__init__() - - self.rnn = nn.LSTM( - input_size=input_dim, hidden_size=hidden_dim, bias=bias, - num_layers=num_layers, dropout=dropout, - batch_first=True, bidirectional=bidirectional, - ) - - self.encoding_dim = self.rnn.hidden_dim * (2 if self.rnn.bidirectional else 1) - - def forward(self, embedding: PackedSequence) -> Tensor: - _, (encoding, _) = self.rnn(embedding) - return encoding - - -class ConvEncoder(nn.Module): - def __init__(self, - hidden_dim: int = 300, - bias: bool = True, - kernel_sizes: Tuple[int, ...] = (3, 5, 7), - dropout: float = 0.2, - *, - input_dim: int, - ): - super(ConvEncoder, self).__init__() - - self.conv_layers = nn.ModuleList([ - nn.Conv2d( - in_channels=input_dim, out_channels=input_dim, bias=bias, - kernel_size=kernel_size, padding=kernel_size // 2, - ) - for kernel_size in kernel_sizes - ]) - self.fc = nn.Sequential( - nn.Dropout(dropout), - nn.Linear(input_dim * len(kernel_sizes), hidden_dim * 2), - ) - - self.encoding_dim = self.fc[-1].out_features - - def forward(self, embedding: Tensor) -> Tensor: - encoding = torch.cat([ - layer(embedding).max(dim=1) - for layer in self.conv_layers - ], dim=-1) - return self.fc(encoding) - - -class Classifier(nn.Sequential): - def __init__(self, - input_dim: int, - bias: bool = True, - *, - target_vocab: Vocab, - ): - super(Classifier, self).__init__( - nn.Linear(input_dim, input_dim, bias=bias), - nn.ReLU(), - nn.Linear(input_dim, len(target_vocab), bias=bias), - ) - - -class TextClassification(nn.Module): - def __init__(self, - Embedding: Type[WordEmbedding], - Encoder: Union[Type[LstmEncoder], Type[ConvEncoder]] = Type[LstmEncoder], - Decoder: Type[Classifier] = Type[Classifier], - reduction: Literal['sum', 'mean'] = 'mean', - *, - word_vocab: Vocab, target_vocab: Vocab, - ): - super(TextClassification, self).__init__() - - self.embedding = Embedding(word_vocab=word_vocab) - self.encoder = Encoder(input_dim=self.embedding.embedding_dim) - self.decoder = Decoder(input_dim=self.encoder.encoding_dim, target_vocab=target_vocab) - - self.criterion = nn.CrossEntropyLoss( - ignore_index=target_vocab.stoi.get('', -100), - reduction=reduction, - ) - - def forward(self, batch: Batch) -> Tensor: - if isinstance(self.encoder, LstmEncoder): - embedding = self.embedding(batch.word) - else: - embedding = self.embedding(batch.word[0]) - encoding = self.encoder(embedding) - return self.decoder(encoding) - - def fit(self, batch: Batch) -> Tensor: - logits = self(batch) - return self.criterion(logits, batch.target) - - def inference(self, batch: Batch) -> float: - prediction = self(batch).argmax(dim=-1) - return (prediction == batch.target).float().mean().item() - - -def sgd(lr: float = 1e-3, - momentum: float = 0.0, - weight_decay: float = 0.0, - *, - model: nn.Module, - ): - return optim.SGD( - model.parameters(), lr=lr, - momentum=momentum, weight_decay=weight_decay, - ) - - -def adam(lr: float = 1e-3, - beta1: float = 0.9, - beta2: float = 0.999, - weight_decay: - float = 0.0, - *, - model: nn.Module, - ): - return optim.Adam( - model.parameters(), lr=lr, - betas=(beta1, beta2), weight_decay=weight_decay, - ) - - -def exponential(gamma: float = 0.98, - *, - optimizer: optim.Optimizer, - ): - return optim.lr_scheduler.ExponentialLR( - optimizer=optimizer, gamma=gamma, - ) - - -def half_life(half_life_epoch: int = 50, - *, - optimizer: optim.Optimizer, - ): - return optim.lr_scheduler.ExponentialLR( - optimizer=optimizer, gamma=0.5 ** (1 / half_life_epoch), - ) - - -def train_text_classification( - Model: Type[TextClassification], - Optimizer: Union[Type[sgd], Type[adam]] = Type[adam], - Scheduler: Union[Type[exponential], Type[half_life]] = Type[half_life] -): - raise NotImplementedError diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 55b033e..0000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -pytest \ No newline at end of file diff --git a/setup.py b/setup.py index 7c24d72..7ef8706 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,10 @@ -from setuptools import setup, find_packages - -with open('README.md', 'r', encoding='utf-8') as fp: - long_description = fp.read() +from setuptools import setup setup( name='aku', - description='Annotation-driven ArgumentParser Generator', - long_description=long_description, - version='0.2.1', - packages=find_packages(), + description='An Annotation-driven ArgumentParser Generator', + version='0.2.0', + packages=['aku'], classifiers=[ 'Programming Language :: Python :: 3 :: Only', 'Development Status :: 2 - Pre-Alpha', @@ -19,7 +15,15 @@ ], url='https://github.com/speedcell4/aku', license='MIT', - author='Izen', + author='speedcell4', author_email='speedcell4@gmail.com', - python_requires='>=3.5', + python_requires='>=3.7', + install_requires=[ + ], + extras_require={ + 'dev': [ + 'pytest', + 'hypothesis', + ], + } ) From 95a32604d77e1e481044a039dd64b5c7eaa1e734 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 13 Apr 2020 23:42:49 +0900 Subject: [PATCH 15/72] Docs: Add components of named entity recognition task example --- examples/named_entity_recognition.py | 120 +++++++++++++++++++++++++++ setup.py | 2 + 2 files changed, 122 insertions(+) create mode 100644 examples/named_entity_recognition.py diff --git a/examples/named_entity_recognition.py b/examples/named_entity_recognition.py new file mode 100644 index 0000000..c2ff2cc --- /dev/null +++ b/examples/named_entity_recognition.py @@ -0,0 +1,120 @@ +from abc import ABCMeta, abstractmethod +from typing import List, Type, Union + +import torch +from einops import rearrange +from torch import Tensor +from torch import nn +from torch.nn.utils.rnn import PackedSequence +from torch.nn.utils.rnn import pad_packed_sequence +from torchglyph.vocab import Vocab + + +class WordEmbedding(nn.Embedding): + def __init__(self, *, word_vocab: Vocab) -> None: + super(WordEmbedding, self).__init__( + num_embeddings=len(word_vocab), + embedding_dim=word_vocab.vec_dim, + padding_idx=word_vocab.pad_idx, + _weight=word_vocab.vectors, + ) + + def forward(self, word: PackedSequence) -> PackedSequence: + data = super(WordEmbedding, self).forward(word.data) + return word._replace(data=data) + + +class Encoder(nn.Module, metaclass=ABCMeta): + encoding_dim: int + + def __init__(self, *, embedding_layer: WordEmbedding) -> None: + super(Encoder, self).__init__() + + @abstractmethod + def forward(self, embedding: PackedSequence) -> Tensor: + raise NotImplementedError + + +class LstmEncoder(Encoder): + def __init__(self, hidden_dim: int, num_layers: int, *, embedding_layer: WordEmbedding) -> None: + super(LstmEncoder, self).__init__(embedding_layer=embedding_layer) + self.rnn = nn.LSTM( + input_size=embedding_layer.embedding_dim, + hidden_size=hidden_dim, num_layers=num_layers, + bias=True, batch_first=True, bidirectional=True, + ) + + self.encoding_dim = self.rnn.hidden_size * (2 if self.rnn.bidirectional else 1) + + def forward(self, embedding: PackedSequence) -> Tensor: + _, (hidden, _) = self.rnn(embedding) + return rearrange(hidden, '(l d) b h -> l b (d h)', l=self.rnn.num_layers)[-1] + + +class ConvEncoder(Encoder): + def __init__(self, kernel_sizes: List[int], hidden_dim: int, *, embedding_layer: WordEmbedding): + super(ConvEncoder, self).__init__(embedding_layer=embedding_layer) + self.convs = nn.ModuleList([ + nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=hidden_dim, + kernel_size=(kernel_size, embedding_layer.embedding_dim), + padding=(kernel_size // 2, 0), bias=True, + ), + nn.AdaptiveMaxPool2d(output_size=(1, 1)), + nn.ReLU(), + ) + for kernel_size in kernel_sizes + ]) + + self.encoding_dim = hidden_dim * len(kernel_sizes) + + def forward(self, embedding: PackedSequence) -> Tensor: + data, _ = pad_packed_sequence(embedding, batch_first=True) + return torch.cat([conv(data[:, None, :, :])[:, :, 0, 0] for conv in self.convs], dim=-1) + + +class Projection(nn.Sequential): + def __init__(self, *, target_vocab: Vocab, encoding_layer: Encoder) -> None: + super(Projection, self).__init__( + nn.Linear(encoding_layer.encoding_dim, encoding_layer.encoding_dim), + nn.ReLU(), + nn.Linear(encoding_layer.encoding_dim, len(target_vocab)), + ) + + +class TextClassifier(nn.Module): + def __init__(self, + Emb: Type[WordEmbedding] = WordEmbedding, + Enc: Type[Union[LstmEncoder, ConvEncoder]] = LstmEncoder, + Proj: Type[Projection] = Projection, + reduction: str = 'mean', *, + word_vocab: Vocab, target_vocab: Vocab) -> None: + super(TextClassifier, self).__init__() + + self.embedding_layer = Emb(word_vocab=word_vocab) + self.encoding_layer = Enc(embedding_layer=self.embedding_layer) + self.projection_layer = Proj(target_vocab=target_vocab, encoding_layer=self.encoding_layer) + + self.criterion = nn.CrossEntropyLoss( + ignore_index=target_vocab.pad_idx, + reduction=reduction, + ) + + def forward(self, word: PackedSequence) -> Tensor: + embedding = self.embedding_layer(word) + encoding = self.encoding_layer(embedding) + return self.projection_layer(encoding) + + def fit(self, word: PackedSequence, target: PackedSequence) -> Tensor: + projection = self(word) + return self.criterion(projection.data, target.data) + + def inference(self, word: PackedSequence) -> List[List[int]]: + predictions, lengths = pad_packed_sequence(self(word), batch_first=True) + predictions = predictions.detach().cpu().tolist() + lengths = lengths.detach().cpu().tolist() + return [ + predictions[i][:l] + for i, l in enumerate(lengths) + ] diff --git a/setup.py b/setup.py index 7ef8706..d27f575 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,8 @@ 'dev': [ 'pytest', 'hypothesis', + 'torchglyph', + 'einops', ], } ) From 2d1d08aa4bacddcd1541317a948cce6d7f343645 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 14 Apr 2020 17:10:51 +0900 Subject: [PATCH 16/72] Feat: Support Literal for Python 3.7 --- aku/__init__.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/aku/__init__.py b/aku/__init__.py index e69de29..d9bc81e 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -0,0 +1,38 @@ +import sys +from typing import List, Tuple, Set, FrozenSet, Dict +from typing import Optional +from typing import Union, Type + +if sys.version_info < (3, 8): + from typing import _SpecialForm + + Literal = _SpecialForm('Literal', doc= + """Special typing form to define literal types (a.k.a. value types). + + This form can be used to indicate to type checkers that the corresponding + variable or function parameter has a value equivalent to the provided + literal (or one of several literals): + + def validate_simple(data: Any) -> Literal[True]: # always returns True + ... + + MODE = Literal['r', 'rb', 'w', 'wb'] + def open_helper(file: str, mode: MODE) -> str: + ... + + open_helper('/some/path', 'r') # Passes type check + open_helper('/other/path', 'typo') # Error in type checker + + Literal[...] cannot be subclassed. At runtime, an arbitrary value + is allowed as type argument to Literal[...], but type checkers may + impose restrictions. + """) +else: + from typing import Literal + +__all__ = [ + 'List', 'Tuple', 'Set', 'FrozenSet', 'Dict', + 'Optional', + 'Union', 'Type', + 'Literal', +] From a2bb9970117477e16e55eb8612d3843bae501421 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 14 Apr 2020 17:13:33 +0900 Subject: [PATCH 17/72] Feat: Utilize Literal in example --- ..._recognition.py => text_classification.py} | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) rename examples/{named_entity_recognition.py => text_classification.py} (83%) diff --git a/examples/named_entity_recognition.py b/examples/text_classification.py similarity index 83% rename from examples/named_entity_recognition.py rename to examples/text_classification.py index c2ff2cc..85658fb 100644 --- a/examples/named_entity_recognition.py +++ b/examples/text_classification.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import List, Type, Union +from typing import List, Type, Union, Tuple import torch from einops import rearrange @@ -9,6 +9,8 @@ from torch.nn.utils.rnn import pad_packed_sequence from torchglyph.vocab import Vocab +from aku import Literal + class WordEmbedding(nn.Embedding): def __init__(self, *, word_vocab: Vocab) -> None: @@ -36,7 +38,7 @@ def forward(self, embedding: PackedSequence) -> Tensor: class LstmEncoder(Encoder): - def __init__(self, hidden_dim: int, num_layers: int, *, embedding_layer: WordEmbedding) -> None: + def __init__(self, hidden_dim: int = 300, num_layers: int = 1, *, embedding_layer: WordEmbedding) -> None: super(LstmEncoder, self).__init__(embedding_layer=embedding_layer) self.rnn = nn.LSTM( input_size=embedding_layer.embedding_dim, @@ -52,7 +54,8 @@ def forward(self, embedding: PackedSequence) -> Tensor: class ConvEncoder(Encoder): - def __init__(self, kernel_sizes: List[int], hidden_dim: int, *, embedding_layer: WordEmbedding): + def __init__(self, kernel_sizes: Tuple[int, ...] = (3, 5, 7), hidden_dim: int = 200, *, + embedding_layer: WordEmbedding) -> None: super(ConvEncoder, self).__init__(embedding_layer=embedding_layer) self.convs = nn.ModuleList([ nn.Sequential( @@ -88,7 +91,7 @@ def __init__(self, Emb: Type[WordEmbedding] = WordEmbedding, Enc: Type[Union[LstmEncoder, ConvEncoder]] = LstmEncoder, Proj: Type[Projection] = Projection, - reduction: str = 'mean', *, + reduction: Literal['sum', 'mean'] = 'mean', *, word_vocab: Vocab, target_vocab: Vocab) -> None: super(TextClassifier, self).__init__() @@ -106,15 +109,10 @@ def forward(self, word: PackedSequence) -> Tensor: encoding = self.encoding_layer(embedding) return self.projection_layer(encoding) - def fit(self, word: PackedSequence, target: PackedSequence) -> Tensor: + def fit(self, word: PackedSequence, target: Tensor) -> Tensor: + projection = self(word) + return self.criterion(projection, target) + + def inference(self, word: PackedSequence) -> List[int]: projection = self(word) - return self.criterion(projection.data, target.data) - - def inference(self, word: PackedSequence) -> List[List[int]]: - predictions, lengths = pad_packed_sequence(self(word), batch_first=True) - predictions = predictions.detach().cpu().tolist() - lengths = lengths.detach().cpu().tolist() - return [ - predictions[i][:l] - for i, l in enumerate(lengths) - ] + return projection.detach().cpu().argmax(dim=-1).tolist() From 34097b52b06eb1ddb4dd308f8a1305bca0b1fa19 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 14 Apr 2020 17:18:41 +0900 Subject: [PATCH 18/72] Feat: Add optimizer SGD and Adam --- examples/text_classification.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/examples/text_classification.py b/examples/text_classification.py index 85658fb..93d570a 100644 --- a/examples/text_classification.py +++ b/examples/text_classification.py @@ -1,10 +1,11 @@ from abc import ABCMeta, abstractmethod -from typing import List, Type, Union, Tuple +from typing import Type +from typing import Union, List, Tuple import torch from einops import rearrange from torch import Tensor -from torch import nn +from torch import nn, optim from torch.nn.utils.rnn import PackedSequence from torch.nn.utils.rnn import pad_packed_sequence from torchglyph.vocab import Vocab @@ -116,3 +117,23 @@ def fit(self, word: PackedSequence, target: Tensor) -> Tensor: def inference(self, word: PackedSequence) -> List[int]: projection = self(word) return projection.detach().cpu().argmax(dim=-1).tolist() + + +class SGD(optim.SGD): + def __init__(self, lr: float = 1e-3, momentum: float = 0, dampening: float = 0, + weight_decay: float = 0, nesterov: bool = False, *, module: nn.Module) -> None: + super(SGD, self).__init__( + params=module.parameters(), + lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov, + ) + + +class Adam(optim.Adam): + def __init__(self, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, + weight_decay: float = 0, amsgrad: bool = False, *, module: nn.Module) -> None: + super(Adam, self).__init__( + params=module.parameters(), + lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, + ) From 21142bb6202421be79bba39915da917b76f1c8f3 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 14 Apr 2020 17:26:56 +0900 Subject: [PATCH 19/72] Feat: Add train_classifier --- examples/text_classification.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/examples/text_classification.py b/examples/text_classification.py index 93d570a..836d25c 100644 --- a/examples/text_classification.py +++ b/examples/text_classification.py @@ -13,6 +13,12 @@ from aku import Literal +class Corpus(object): + @classmethod + def new(cls, batch_size: int) -> None: + raise NotImplementedError + + class WordEmbedding(nn.Embedding): def __init__(self, *, word_vocab: Vocab) -> None: super(WordEmbedding, self).__init__( @@ -137,3 +143,23 @@ def __init__(self, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, ) + + +def train_classifier( + num_epochs: int = 100, + Data: Type[Corpus.new] = Corpus.new, + Cls: Type[TextClassifier] = TextClassifier, + Opt: Type[Union[SGD, Adam]] = Adam, +): + train, dev, test = Data() + classifier = Cls(word_vocab=..., target_vocab=...) + optimizer = Opt(module=classifier) + + for epoch in range(1, num_epochs + 1): + classifier.train() + for batch in train: + loss = classifier(batch) + + optimizer.zero_grad() + loss.backward() + optimizer.step() From 172ad0eb913b7bcbc88954f575f4490a7d2f3544 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 13:47:48 +0900 Subject: [PATCH 20/72] Feat: Support get_origin and get_args for Python 3.7 --- aku/__init__.py | 45 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/aku/__init__.py b/aku/__init__.py index d9bc81e..3d3677e 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -2,9 +2,11 @@ from typing import List, Tuple, Set, FrozenSet, Dict from typing import Optional from typing import Union, Type +from typing import get_type_hints if sys.version_info < (3, 8): - from typing import _SpecialForm + import collections + from typing import _SpecialForm, _GenericAlias, Generic Literal = _SpecialForm('Literal', doc= """Special typing form to define literal types (a.k.a. value types). @@ -27,6 +29,46 @@ def open_helper(file: str, mode: MODE) -> str: is allowed as type argument to Literal[...], but type checkers may impose restrictions. """) + + + def get_origin(tp): + """Get the unsubscripted version of a type. + + This supports generic types, Callable, Tuple, Union, Literal, Final and ClassVar. + Return None for unsupported types. Examples:: + + get_origin(Literal[42]) is Literal + get_origin(int) is None + get_origin(ClassVar[int]) is ClassVar + get_origin(Generic) is Generic + get_origin(Generic[T]) is Generic + get_origin(Union[T, int]) is Union + get_origin(List[Tuple[T, T]][int]) == list + """ + if isinstance(tp, _GenericAlias): + return tp.__origin__ + if tp is Generic: + return Generic + return None + + + def get_args(tp): + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. + Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + if isinstance(tp, _GenericAlias): + res = tp.__args__ + if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + return () else: from typing import Literal @@ -35,4 +77,5 @@ def open_helper(file: str, mode: MODE) -> str: 'Optional', 'Union', 'Type', 'Literal', + 'get_args', 'get_origin', 'get_type_hints', ] From e6935a956eeb94c38d538d5447afcedc2e2d6339 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 14:22:04 +0900 Subject: [PATCH 21/72] Feat: Add PrimitiveTp --- aku/parse_fn.py | 27 ++++++++++++++++++++++++++ aku/tp.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 aku/parse_fn.py create mode 100644 aku/tp.py diff --git a/aku/parse_fn.py b/aku/parse_fn.py new file mode 100644 index 0000000..57ddefa --- /dev/null +++ b/aku/parse_fn.py @@ -0,0 +1,27 @@ +import inspect + +registry = {} + + +def register_parse_fn(tp): + ret = inspect.getfullargspec(tp).annotations['return'] + + global registry + assert ret not in registry + + registry[ret] = tp + return tp + + +def get_parse_fn(tp): + return registry.get(tp, tp) + + +@register_parse_fn +def parse_bool(option_string: str) -> bool: + option_string = option_string.strip().lower() + if option_string in ('1', 'y', 'yes', 't', 'true'): + return True + if option_string in ('0', 'n', 'no', 'f', 'false'): + return False + raise ValueError(f'{option_string} is not a boolean value') diff --git a/aku/tp.py b/aku/tp.py new file mode 100644 index 0000000..c576cdf --- /dev/null +++ b/aku/tp.py @@ -0,0 +1,50 @@ +from abc import ABCMeta, abstractmethod +from typing import get_args, get_origin, Any + +from aku.parse_fn import get_parse_fn +from argparse import ArgumentParser + + +class Tp(object, metaclass=ABCMeta): + def __init__(self, origin, *args: 'Tp') -> None: + super(Tp, self).__init__() + self.origin = origin + self.args = args + + def __class_getitem__(cls, tp): + args = get_args(tp) + origin = get_origin(tp) + + if origin is None and args == (): + return PrimitiveTp(origin, *args) + + raise NotImplementedError(f'unsupported annotation {tp}') + + @property + @abstractmethod + def metavar(self) -> str: + raise NotImplementedError + + @abstractmethod + def parse_fn(self, option_string: str) -> Any: + raise NotImplementedError + + @abstractmethod + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): + raise NotImplementedError + + +class PrimitiveTp(Tp): + @property + def metavar(self) -> str: + return f'{self.origin.__name__.lower()}' + + def parse_fn(self, option_string: str) -> Any: + fn = get_parse_fn(self.origin) + return fn(option_string.strip()) + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): + return argument_parser.add_argument( + f'--{name}', required=True, help=f'{name}', + type=self.parse_fn, metavar=self.metavar, default=default, + ) From 4d52e8aed4b3307a9f15824f34c28d7bc03bc2fd Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 14:26:31 +0900 Subject: [PATCH 22/72] Feat: Add ListTp --- aku/tp.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/aku/tp.py b/aku/tp.py index c576cdf..1353890 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -3,6 +3,9 @@ from aku.parse_fn import get_parse_fn from argparse import ArgumentParser +import re + +COMMA = re.compile(r',\s*') class Tp(object, metaclass=ABCMeta): @@ -17,6 +20,8 @@ def __class_getitem__(cls, tp): if origin is None and args == (): return PrimitiveTp(origin, *args) + if origin is list and len(args) == 1: + return ListTp(origin, *args) raise NotImplementedError(f'unsupported annotation {tp}') @@ -48,3 +53,23 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) f'--{name}', required=True, help=f'{name}', type=self.parse_fn, metavar=self.metavar, default=default, ) + + +class ListTp(Tp): + @property + def metavar(self) -> str: + return f'[{self.args[0].metavar}]' + + def parse_fn(self, option_string: str) -> Any: + option_string = option_string.strip() + if not option_string.startswith('[') or not option_string.endswith(']'): + raise ValueError(f'{option_string} is not a list') + + option_strings = re.split(COMMA, option_string) + return [self.args[0].parse_fn(s) for s in option_strings] + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): + return argument_parser.add_argument( + f'--{name}', required=True, help=f'{name}', + type=self.parse_fn, metavar=self.metavar, default=default, + ) From d4dfb4dcee55d29d074ffaa4b8219efdd01777cc Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 14:33:39 +0900 Subject: [PATCH 23/72] Feat: Add HomoTupleTp and HeteroTupleTp --- aku/tp.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/aku/tp.py b/aku/tp.py index 1353890..19612c0 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -1,9 +1,9 @@ +import re from abc import ABCMeta, abstractmethod +from argparse import ArgumentParser from typing import get_args, get_origin, Any from aku.parse_fn import get_parse_fn -from argparse import ArgumentParser -import re COMMA = re.compile(r',\s*') @@ -22,6 +22,11 @@ def __class_getitem__(cls, tp): return PrimitiveTp(origin, *args) if origin is list and len(args) == 1: return ListTp(origin, *args) + if origin is tuple: + if len(args) == 2 and args[1] is ...: + return HomoTupleTp(origin, *args) + else: + return HeteroTupleTp(origin, *args) raise NotImplementedError(f'unsupported annotation {tp}') @@ -66,7 +71,47 @@ def parse_fn(self, option_string: str) -> Any: raise ValueError(f'{option_string} is not a list') option_strings = re.split(COMMA, option_string) - return [self.args[0].parse_fn(s) for s in option_strings] + return list(self.args[0].parse_fn(s) for s in option_strings) + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): + return argument_parser.add_argument( + f'--{name}', required=True, help=f'{name}', + type=self.parse_fn, metavar=self.metavar, default=default, + ) + + +class HomoTupleTp(Tp): + @property + def metavar(self) -> str: + return f'({self.args[0].metavar}, ...)' + + def parse_fn(self, option_string: str) -> Any: + option_string = option_string.strip() + if not option_string.startswith('(') or not option_string.endswith(')'): + raise ValueError(f'{option_string} is not a list') + + option_strings = re.split(COMMA, option_string) + return tuple(self.args[0].parse_fn(s) for s in option_strings) + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): + return argument_parser.add_argument( + f'--{name}', required=True, help=f'{name}', + type=self.parse_fn, metavar=self.metavar, default=default, + ) + + +class HeteroTupleTp(Tp): + @property + def metavar(self) -> str: + return f"({', '.join([f'{a.metavar}' for a in self.args])})" + + def parse_fn(self, option_string: str) -> Any: + option_string = option_string.strip() + if not option_string.startswith('(') or not option_string.endswith(')'): + raise ValueError(f'{option_string} is not a list') + + option_strings = re.split(COMMA, option_string) + return tuple(a.parse_fn(s) for s, a in zip(option_strings, self.args)) def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): return argument_parser.add_argument( From 3147ae1f1292eca236d57b390ca5feadecfa8da6 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 14:39:21 +0900 Subject: [PATCH 24/72] Refactor: Update details --- aku/tp.py | 52 +++++++++++++++++----------------------------------- 1 file changed, 17 insertions(+), 35 deletions(-) diff --git a/aku/tp.py b/aku/tp.py index 19612c0..708fd94 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -19,14 +19,14 @@ def __class_getitem__(cls, tp): origin = get_origin(tp) if origin is None and args == (): - return PrimitiveTp(origin, *args) + return PrimitiveTp(origin) if origin is list and len(args) == 1: - return ListTp(origin, *args) + return ListTp(origin, cls[args[0]]) if origin is tuple: if len(args) == 2 and args[1] is ...: - return HomoTupleTp(origin, *args) + return HomoTupleTp(origin, cls[args[0]]) else: - return HeteroTupleTp(origin, *args) + return HeteroTupleTp(origin, *[cls[a] for a in args]) raise NotImplementedError(f'unsupported annotation {tp}') @@ -39,9 +39,11 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: raise NotImplementedError - @abstractmethod def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - raise NotImplementedError + return argument_parser.add_argument( + f'--{name}', required=True, help=f'{name}', + type=self.parse_fn, metavar=self.metavar, default=default, + ) class PrimitiveTp(Tp): @@ -53,12 +55,6 @@ def parse_fn(self, option_string: str) -> Any: fn = get_parse_fn(self.origin) return fn(option_string.strip()) - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - return argument_parser.add_argument( - f'--{name}', required=True, help=f'{name}', - type=self.parse_fn, metavar=self.metavar, default=default, - ) - class ListTp(Tp): @property @@ -68,16 +64,10 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: option_string = option_string.strip() if not option_string.startswith('[') or not option_string.endswith(']'): - raise ValueError(f'{option_string} is not a list') + raise ValueError(f'{option_string} is not a {self.origin.__name__}') option_strings = re.split(COMMA, option_string) - return list(self.args[0].parse_fn(s) for s in option_strings) - - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - return argument_parser.add_argument( - f'--{name}', required=True, help=f'{name}', - type=self.parse_fn, metavar=self.metavar, default=default, - ) + return self.origin(self.args[0].parse_fn(s) for s in option_strings) class HomoTupleTp(Tp): @@ -88,16 +78,10 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: option_string = option_string.strip() if not option_string.startswith('(') or not option_string.endswith(')'): - raise ValueError(f'{option_string} is not a list') + raise ValueError(f'{option_string} is not a {self.origin.__name__}') option_strings = re.split(COMMA, option_string) - return tuple(self.args[0].parse_fn(s) for s in option_strings) - - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - return argument_parser.add_argument( - f'--{name}', required=True, help=f'{name}', - type=self.parse_fn, metavar=self.metavar, default=default, - ) + return self.origin(self.args[0].parse_fn(s) for s in option_strings) class HeteroTupleTp(Tp): @@ -108,13 +92,11 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: option_string = option_string.strip() if not option_string.startswith('(') or not option_string.endswith(')'): - raise ValueError(f'{option_string} is not a list') + raise ValueError(f'{option_string} is not a {self.origin.__name__}') option_strings = re.split(COMMA, option_string) - return tuple(a.parse_fn(s) for s, a in zip(option_strings, self.args)) + assert len(option_strings) == len(self.args), \ + f'the number of parameters is not correct, ' \ + f'got {len(option_strings)} instead of {len(self.args)}' - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - return argument_parser.add_argument( - f'--{name}', required=True, help=f'{name}', - type=self.parse_fn, metavar=self.metavar, default=default, - ) + return self.origin(a.parse_fn(s) for s, a in zip(option_strings, self.args)) From 4f3e3c3aa7784a17f157aacaa92350d49da5e0e4 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 14:44:19 +0900 Subject: [PATCH 25/72] Feat: Add SetTp and FrozenSetTp --- aku/tp.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/aku/tp.py b/aku/tp.py index 708fd94..59abd9d 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -27,6 +27,10 @@ def __class_getitem__(cls, tp): return HomoTupleTp(origin, cls[args[0]]) else: return HeteroTupleTp(origin, *[cls[a] for a in args]) + if origin is set: + return SetTp(origin, cls[args[0]]) + if origin is frozenset: + return FrozenSetTp(origin, cls[args[0]]) raise NotImplementedError(f'unsupported annotation {tp}') @@ -100,3 +104,31 @@ def parse_fn(self, option_string: str) -> Any: f'got {len(option_strings)} instead of {len(self.args)}' return self.origin(a.parse_fn(s) for s, a in zip(option_strings, self.args)) + + +class SetTp(Tp): + @property + def metavar(self) -> str: + return f'{{{self.args[0].metavar}}}' + + def parse_fn(self, option_string: str) -> Any: + option_string = option_string.strip() + if not option_string.startswith('{') or not option_string.endswith('}'): + raise ValueError(f'{option_string} is not a {self.origin.__name__}') + + option_strings = re.split(COMMA, option_string) + return self.origin(self.args[0].parse_fn(s) for s in option_strings) + + +class FrozenSetTp(Tp): + @property + def metavar(self) -> str: + return f'{{{self.args[0].metavar}}}' + + def parse_fn(self, option_string: str) -> Any: + option_string = option_string.strip() + if not option_string.startswith('{') or not option_string.endswith('}'): + raise ValueError(f'{option_string} is not a {self.origin.__name__}') + + option_strings = re.split(COMMA, option_string) + return self.origin(self.args[0].parse_fn(s) for s in option_strings) From b96231d5ae8da4987bdf350ab8b8b4bbcf5f8059 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 14:50:27 +0900 Subject: [PATCH 26/72] Fix: origin of PrimitiveTp --- aku/tp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aku/tp.py b/aku/tp.py index 59abd9d..745044d 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -19,7 +19,7 @@ def __class_getitem__(cls, tp): origin = get_origin(tp) if origin is None and args == (): - return PrimitiveTp(origin) + return PrimitiveTp(tp) if origin is list and len(args) == 1: return ListTp(origin, cls[args[0]]) if origin is tuple: From 06caf2d3615484c2077bbaed92bbc7c33999dc18 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 15:05:38 +0900 Subject: [PATCH 27/72] Feat: Support Literal --- aku/tp.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/aku/tp.py b/aku/tp.py index 745044d..cbb8185 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -1,18 +1,19 @@ import re from abc import ABCMeta, abstractmethod from argparse import ArgumentParser -from typing import get_args, get_origin, Any +from typing import get_args, get_origin, Any, Union, Literal from aku.parse_fn import get_parse_fn -COMMA = re.compile(r',\s*') +COMMA = re.compile(r'\s*,\s*') class Tp(object, metaclass=ABCMeta): - def __init__(self, origin, *args: 'Tp') -> None: + def __init__(self, origin, *args: Union['Tp', Any], **kwargs: ['Tp', Any]) -> None: super(Tp, self).__init__() self.origin = origin self.args = args + self.kwargs = kwargs def __class_getitem__(cls, tp): args = get_args(tp) @@ -20,15 +21,26 @@ def __class_getitem__(cls, tp): if origin is None and args == (): return PrimitiveTp(tp) + + if origin is Literal: + tp = type(args[0]) + + assert all(isinstance(a, tp) for a in args), \ + f'all parameters should have the same type {tp.__name__}' + return PrimitiveTp(tp, *set(args)) + if origin is list and len(args) == 1: return ListTp(origin, cls[args[0]]) + if origin is tuple: if len(args) == 2 and args[1] is ...: return HomoTupleTp(origin, cls[args[0]]) else: return HeteroTupleTp(origin, *[cls[a] for a in args]) + if origin is set: return SetTp(origin, cls[args[0]]) + if origin is frozenset: return FrozenSetTp(origin, cls[args[0]]) @@ -59,6 +71,13 @@ def parse_fn(self, option_string: str) -> Any: fn = get_parse_fn(self.origin) return fn(option_string.strip()) + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): + return argument_parser.add_argument( + f'--{name}', required=True, help=f'{name}', + type=self.parse_fn, metavar=self.metavar, default=default, + choices=self.args if len(self.args) > 0 else None, + ) + class ListTp(Tp): @property From b2f054426cd5dbabeb6f8e63b2340b3131128642 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 15:20:57 +0900 Subject: [PATCH 28/72] Feat: Add an example basic.py --- aku/tp.py | 40 +++++++++++++++++++++------------------- examples/basic.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 19 deletions(-) create mode 100644 examples/basic.py diff --git a/aku/tp.py b/aku/tp.py index cbb8185..5bf0aa6 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -26,7 +26,7 @@ def __class_getitem__(cls, tp): tp = type(args[0]) assert all(isinstance(a, tp) for a in args), \ - f'all parameters should have the same type {tp.__name__}' + f'all arguments should have the same type {tp.__name__}' return PrimitiveTp(tp, *set(args)) if origin is list and len(args) == 1: @@ -44,7 +44,7 @@ def __class_getitem__(cls, tp): if origin is frozenset: return FrozenSetTp(origin, cls[args[0]]) - raise NotImplementedError(f'unsupported annotation {tp}') + raise NotImplementedError(f'unsupported {cls.__name__} {tp}') @property @abstractmethod @@ -57,15 +57,17 @@ def parse_fn(self, option_string: str) -> Any: def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): return argument_parser.add_argument( - f'--{name}', required=True, help=f'{name}', - type=self.parse_fn, metavar=self.metavar, default=default, + f'--{name}', help=f'{name}', + type=self.parse_fn, metavar=self.metavar, default=repr(default), ) class PrimitiveTp(Tp): @property def metavar(self) -> str: - return f'{self.origin.__name__.lower()}' + if len(self.args) == 0: + return f'{self.origin.__name__.lower()}' + return f"{{{', '.join([f'{repr(a)}' for a in self.args])}}}" def parse_fn(self, option_string: str) -> Any: fn = get_parse_fn(self.origin) @@ -73,8 +75,8 @@ def parse_fn(self, option_string: str) -> Any: def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): return argument_parser.add_argument( - f'--{name}', required=True, help=f'{name}', - type=self.parse_fn, metavar=self.metavar, default=default, + f'--{name}', help=f'{name}', + type=self.parse_fn, metavar=self.metavar, default=repr(default), choices=self.args if len(self.args) > 0 else None, ) @@ -87,9 +89,9 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: option_string = option_string.strip() if not option_string.startswith('[') or not option_string.endswith(']'): - raise ValueError(f'{option_string} is not a {self.origin.__name__}') + raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string) + option_strings = re.split(COMMA, option_string[1:-1]) return self.origin(self.args[0].parse_fn(s) for s in option_strings) @@ -101,9 +103,9 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: option_string = option_string.strip() if not option_string.startswith('(') or not option_string.endswith(')'): - raise ValueError(f'{option_string} is not a {self.origin.__name__}') + raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string) + option_strings = re.split(COMMA, option_string[1:-1]) return self.origin(self.args[0].parse_fn(s) for s in option_strings) @@ -115,12 +117,12 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: option_string = option_string.strip() if not option_string.startswith('(') or not option_string.endswith(')'): - raise ValueError(f'{option_string} is not a {self.origin.__name__}') + raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string) + option_strings = re.split(COMMA, option_string[1:-1]) assert len(option_strings) == len(self.args), \ - f'the number of parameters is not correct, ' \ - f'got {len(option_strings)} instead of {len(self.args)}' + f'the number of arguments is not correct, ' \ + f'got {len(option_strings)} but excepted {len(self.args)}' return self.origin(a.parse_fn(s) for s, a in zip(option_strings, self.args)) @@ -133,9 +135,9 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: option_string = option_string.strip() if not option_string.startswith('{') or not option_string.endswith('}'): - raise ValueError(f'{option_string} is not a {self.origin.__name__}') + raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string) + option_strings = re.split(COMMA, option_string[1:-1]) return self.origin(self.args[0].parse_fn(s) for s in option_strings) @@ -147,7 +149,7 @@ def metavar(self) -> str: def parse_fn(self, option_string: str) -> Any: option_string = option_string.strip() if not option_string.startswith('{') or not option_string.endswith('}'): - raise ValueError(f'{option_string} is not a {self.origin.__name__}') + raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string) + option_strings = re.split(COMMA, option_string[1:-1]) return self.origin(self.args[0].parse_fn(s) for s in option_strings) diff --git a/examples/basic.py b/examples/basic.py new file mode 100644 index 0000000..42e3ff0 --- /dev/null +++ b/examples/basic.py @@ -0,0 +1,45 @@ +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + +from aku.tp import Tp +from typing import List, Tuple, Set, FrozenSet, Literal + +parser = ArgumentParser( + formatter_class=ArgumentDefaultsHelpFormatter, +) + +Tp[str].add_argument(parser, 'pa', 'a') +Tp[int].add_argument(parser, 'pb', 1) +Tp[bool].add_argument(parser, 'pc', False) +Tp[float].add_argument(parser, 'pd', 2.0) + +Tp[Literal['a', 'b']].add_argument(parser, 'la', 'a') +Tp[Literal[2, 3]].add_argument(parser, 'lb', 3) +Tp[Literal[True, False]].add_argument(parser, 'lc', False) +Tp[Literal[2.0, 3.0]].add_argument(parser, 'ld', 2.0) + +Tp[List[str]].add_argument(parser, 'sa', ['a', 'b']) +Tp[List[int]].add_argument(parser, 'sb', [2, 3]) +Tp[List[bool]].add_argument(parser, 'sc', [True, False]) +Tp[List[float]].add_argument(parser, 'sd', [2.0, 3.0]) + +Tp[Tuple[str, ...]].add_argument(parser, 'oa', ('a', 'b')) +Tp[Tuple[int, ...]].add_argument(parser, 'ob', (2, 3)) +Tp[Tuple[bool, ...]].add_argument(parser, 'oc', (True, False)) +Tp[Tuple[float, ...]].add_argument(parser, 'od', (2.0, 3.0)) + +Tp[Tuple[str, int]].add_argument(parser, 'ea', ('a', 1)) +Tp[Tuple[int, int]].add_argument(parser, 'eb', (2, 1)) +Tp[Tuple[bool, int]].add_argument(parser, 'ec', (True, 1)) +Tp[Tuple[float, int]].add_argument(parser, 'ed', (2.0, 1)) + +Tp[Set[str]].add_argument(parser, 'ta', {'a'}) +Tp[Set[int]].add_argument(parser, 'tb', {2}) +Tp[Set[bool]].add_argument(parser, 'tc', {True}) +Tp[Set[float]].add_argument(parser, 'td', {2.0}) + +Tp[FrozenSet[str]].add_argument(parser, 'fa', frozenset({'a'})) +Tp[FrozenSet[int]].add_argument(parser, 'fb', frozenset({2})) +Tp[FrozenSet[bool]].add_argument(parser, 'fc', frozenset({True})) +Tp[FrozenSet[float]].add_argument(parser, 'fd', frozenset({2.0})) + +print(parser.parse_args()) From 84749ec180223097cac446fe9c11de0d9cd99b5c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 15:25:47 +0900 Subject: [PATCH 29/72] Refactor: Update parsing functions --- aku/parse_fn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aku/parse_fn.py b/aku/parse_fn.py index 57ddefa..3ed1963 100644 --- a/aku/parse_fn.py +++ b/aku/parse_fn.py @@ -5,8 +5,6 @@ def register_parse_fn(tp): ret = inspect.getfullargspec(tp).annotations['return'] - - global registry assert ret not in registry registry[ret] = tp @@ -24,4 +22,4 @@ def parse_bool(option_string: str) -> bool: return True if option_string in ('0', 'n', 'no', 'f', 'false'): return False - raise ValueError(f'{option_string} is not a boolean value') + raise ValueError(f'{option_string} is not a {bool.__name__} value') From 3e17ff324076e4a415728a3b451bb64b197048c0 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 15:27:38 +0900 Subject: [PATCH 30/72] Refactor: Check unification of Literal --- examples/basic.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/basic.py b/examples/basic.py index 42e3ff0..46388a3 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -12,9 +12,9 @@ Tp[bool].add_argument(parser, 'pc', False) Tp[float].add_argument(parser, 'pd', 2.0) -Tp[Literal['a', 'b']].add_argument(parser, 'la', 'a') -Tp[Literal[2, 3]].add_argument(parser, 'lb', 3) -Tp[Literal[True, False]].add_argument(parser, 'lc', False) +Tp[Literal['a', 'b', 'a']].add_argument(parser, 'la', 'a') +Tp[Literal[2, 3, 2]].add_argument(parser, 'lb', 3) +Tp[Literal[True, False, True]].add_argument(parser, 'lc', False) Tp[Literal[2.0, 3.0]].add_argument(parser, 'ld', 2.0) Tp[List[str]].add_argument(parser, 'sa', ['a', 'b']) @@ -37,9 +37,9 @@ Tp[Set[bool]].add_argument(parser, 'tc', {True}) Tp[Set[float]].add_argument(parser, 'td', {2.0}) -Tp[FrozenSet[str]].add_argument(parser, 'fa', frozenset({'a'})) -Tp[FrozenSet[int]].add_argument(parser, 'fb', frozenset({2})) -Tp[FrozenSet[bool]].add_argument(parser, 'fc', frozenset({True})) -Tp[FrozenSet[float]].add_argument(parser, 'fd', frozenset({2.0})) +# Tp[FrozenSet[str]].add_argument(parser, 'fa', frozenset({'a'})) +# Tp[FrozenSet[int]].add_argument(parser, 'fb', frozenset({2})) +# Tp[FrozenSet[bool]].add_argument(parser, 'fc', frozenset({True})) +# Tp[FrozenSet[float]].add_argument(parser, 'fd', frozenset({2.0})) print(parser.parse_args()) From 7eaed85b75b7ccb7199b157a81f3e855a20a23eb Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 25 Apr 2020 16:05:23 +0900 Subject: [PATCH 31/72] Feat: Support SUPPRESS default --- aku/tp.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/aku/tp.py b/aku/tp.py index 5bf0aa6..4961840 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -1,6 +1,6 @@ import re from abc import ABCMeta, abstractmethod -from argparse import ArgumentParser +from argparse import ArgumentParser, SUPPRESS from typing import get_args, get_origin, Any, Union, Literal from aku.parse_fn import get_parse_fn @@ -58,7 +58,8 @@ def parse_fn(self, option_string: str) -> Any: def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): return argument_parser.add_argument( f'--{name}', help=f'{name}', - type=self.parse_fn, metavar=self.metavar, default=repr(default), + type=self.parse_fn, metavar=self.metavar, required=default == SUPPRESS, + default=repr(default) if default != SUPPRESS else SUPPRESS, ) @@ -70,14 +71,14 @@ def metavar(self) -> str: return f"{{{', '.join([f'{repr(a)}' for a in self.args])}}}" def parse_fn(self, option_string: str) -> Any: - fn = get_parse_fn(self.origin) - return fn(option_string.strip()) + return get_parse_fn(self.origin)(option_string.strip()) def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): return argument_parser.add_argument( f'--{name}', help=f'{name}', - type=self.parse_fn, metavar=self.metavar, default=repr(default), choices=self.args if len(self.args) > 0 else None, + type=self.parse_fn, metavar=self.metavar, required=default == SUPPRESS, + default=repr(default) if default != SUPPRESS else SUPPRESS, ) From 6f7387ac2cff7009064c25f4d78623a285f6ac9f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 30 Apr 2020 21:16:44 +0900 Subject: [PATCH 32/72] Feat: Add TypeTp --- aku/tp.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/aku/tp.py b/aku/tp.py index 4961840..a3df80a 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -1,13 +1,28 @@ import re from abc import ABCMeta, abstractmethod from argparse import ArgumentParser, SUPPRESS -from typing import get_args, get_origin, Any, Union, Literal +from inspect import getfullargspec +from itertools import zip_longest +from typing import get_args, get_origin, get_type_hints, Any, Union, Literal from aku.parse_fn import get_parse_fn COMMA = re.compile(r'\s*,\s*') +def get_type_annotations(tp): + tys = get_type_hints(tp) + spec = getfullargspec(tp) + + name_default = zip_longest( + reversed(spec.args), + reversed(spec.defaults or []), + fillvalue=SUPPRESS, + ) + for arg_name, arg_default in reversed(list(name_default)): + yield arg_name, arg_default, tys[arg_name] + + class Tp(object, metaclass=ABCMeta): def __init__(self, origin, *args: Union['Tp', Any], **kwargs: ['Tp', Any]) -> None: super(Tp, self).__init__() @@ -44,6 +59,16 @@ def __class_getitem__(cls, tp): if origin is frozenset: return FrozenSetTp(origin, cls[args[0]]) + if origin is type: + if get_origin(args[0]) is Union: + args = get_args(args[0]) + return UnionTp(str, **{a.__name__: a for a in args}) + return TypeTp(args[0]) + + if origin is Union: + args = [get_args(a)[0] for a in args] + return UnionTp(str, **{a.__name__: a for a in args}) + raise NotImplementedError(f'unsupported {cls.__name__} {tp}') @property @@ -63,6 +88,22 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) ) +class TypeTp(Tp): + @property + def metavar(self) -> str: + raise NotImplementedError + + def parse_fn(self, option_string: str) -> Any: + raise NotImplementedError + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): + for arg_name, arg_default, arg_tp in get_type_annotations(self.origin): + Tp[arg_tp].add_argument( + argument_parser=argument_parser, + name=arg_name, default=arg_default, + ) + + class PrimitiveTp(Tp): @property def metavar(self) -> str: @@ -154,3 +195,12 @@ def parse_fn(self, option_string: str) -> Any: option_strings = re.split(COMMA, option_string[1:-1]) return self.origin(self.args[0].parse_fn(s) for s in option_strings) + + +class UnionTp(Tp): + @property + def metavar(self) -> str: + return f"{{{', '.join(self.kwargs.keys())}}}" + + def parse_fn(self, option_string: str) -> Any: + return str(option_string.strip()) From 08b4646217ebf1df9528c07ee37cc59f31e24a5c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 30 Apr 2020 22:31:15 +0900 Subject: [PATCH 33/72] Feat: Add UnionTp --- aku/parse_fn.py | 10 +++--- aku/tp.py | 95 ++++++++++++++++++++++++++++--------------------- 2 files changed, 60 insertions(+), 45 deletions(-) diff --git a/aku/parse_fn.py b/aku/parse_fn.py index 3ed1963..39f8e7c 100644 --- a/aku/parse_fn.py +++ b/aku/parse_fn.py @@ -16,10 +16,10 @@ def get_parse_fn(tp): @register_parse_fn -def parse_bool(option_string: str) -> bool: - option_string = option_string.strip().lower() - if option_string in ('1', 'y', 'yes', 't', 'true'): +def parse_bool(string: str) -> bool: + string = string.strip().lower() + if string in ('1', 'y', 'yes', 't', 'true'): return True - if option_string in ('0', 'n', 'no', 'f', 'false'): + if string in ('0', 'n', 'no', 'f', 'false'): return False - raise ValueError(f'{option_string} is not a {bool.__name__} value') + raise ValueError(f'{string} is not a {bool.__name__} value') diff --git a/aku/tp.py b/aku/tp.py index a3df80a..35ff212 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -1,9 +1,9 @@ import re from abc import ABCMeta, abstractmethod -from argparse import ArgumentParser, SUPPRESS +from argparse import ArgumentParser, SUPPRESS, Action from inspect import getfullargspec from itertools import zip_longest -from typing import get_args, get_origin, get_type_hints, Any, Union, Literal +from typing import get_args, get_origin, get_type_hints, Any, Union, Literal, Type from aku.parse_fn import get_parse_fn @@ -77,7 +77,7 @@ def metavar(self) -> str: raise NotImplementedError @abstractmethod - def parse_fn(self, option_string: str) -> Any: + def parse_fn(self, string: str) -> Any: raise NotImplementedError def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): @@ -93,7 +93,7 @@ class TypeTp(Tp): def metavar(self) -> str: raise NotImplementedError - def parse_fn(self, option_string: str) -> Any: + def parse_fn(self, string: str) -> Any: raise NotImplementedError def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): @@ -111,8 +111,8 @@ def metavar(self) -> str: return f'{self.origin.__name__.lower()}' return f"{{{', '.join([f'{repr(a)}' for a in self.args])}}}" - def parse_fn(self, option_string: str) -> Any: - return get_parse_fn(self.origin)(option_string.strip()) + def parse_fn(self, string: str) -> Any: + return get_parse_fn(self.origin)(string.strip()) def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): return argument_parser.add_argument( @@ -128,13 +128,13 @@ class ListTp(Tp): def metavar(self) -> str: return f'[{self.args[0].metavar}]' - def parse_fn(self, option_string: str) -> Any: - option_string = option_string.strip() - if not option_string.startswith('[') or not option_string.endswith(']'): - raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') + def parse_fn(self, string: str) -> Any: + string = string.strip() + if not string.startswith('[') or not string.endswith(']'): + raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string[1:-1]) - return self.origin(self.args[0].parse_fn(s) for s in option_strings) + strings = re.split(COMMA, string[1:-1]) + return self.origin(self.args[0].parse_fn(s) for s in strings) class HomoTupleTp(Tp): @@ -142,13 +142,13 @@ class HomoTupleTp(Tp): def metavar(self) -> str: return f'({self.args[0].metavar}, ...)' - def parse_fn(self, option_string: str) -> Any: - option_string = option_string.strip() - if not option_string.startswith('(') or not option_string.endswith(')'): - raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') + def parse_fn(self, string: str) -> Any: + string = string.strip() + if not string.startswith('(') or not string.endswith(')'): + raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string[1:-1]) - return self.origin(self.args[0].parse_fn(s) for s in option_strings) + strings = re.split(COMMA, string[1:-1]) + return self.origin(self.args[0].parse_fn(s) for s in strings) class HeteroTupleTp(Tp): @@ -156,17 +156,17 @@ class HeteroTupleTp(Tp): def metavar(self) -> str: return f"({', '.join([f'{a.metavar}' for a in self.args])})" - def parse_fn(self, option_string: str) -> Any: - option_string = option_string.strip() - if not option_string.startswith('(') or not option_string.endswith(')'): - raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') + def parse_fn(self, string: str) -> Any: + string = string.strip() + if not string.startswith('(') or not string.endswith(')'): + raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string[1:-1]) - assert len(option_strings) == len(self.args), \ + strings = re.split(COMMA, string[1:-1]) + assert len(strings) == len(self.args), \ f'the number of arguments is not correct, ' \ - f'got {len(option_strings)} but excepted {len(self.args)}' + f'got {len(strings)} but excepted {len(self.args)}' - return self.origin(a.parse_fn(s) for s, a in zip(option_strings, self.args)) + return self.origin(a.parse_fn(s) for s, a in zip(strings, self.args)) class SetTp(Tp): @@ -174,13 +174,13 @@ class SetTp(Tp): def metavar(self) -> str: return f'{{{self.args[0].metavar}}}' - def parse_fn(self, option_string: str) -> Any: - option_string = option_string.strip() - if not option_string.startswith('{') or not option_string.endswith('}'): - raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') + def parse_fn(self, string: str) -> Any: + string = string.strip() + if not string.startswith('{') or not string.endswith('}'): + raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string[1:-1]) - return self.origin(self.args[0].parse_fn(s) for s in option_strings) + strings = re.split(COMMA, string[1:-1]) + return self.origin(self.args[0].parse_fn(s) for s in strings) class FrozenSetTp(Tp): @@ -188,13 +188,13 @@ class FrozenSetTp(Tp): def metavar(self) -> str: return f'{{{self.args[0].metavar}}}' - def parse_fn(self, option_string: str) -> Any: - option_string = option_string.strip() - if not option_string.startswith('{') or not option_string.endswith('}'): - raise ValueError(f'{option_string} is not a(n) {self.origin.__name__}') + def parse_fn(self, string: str) -> Any: + string = string.strip() + if not string.startswith('{') or not string.endswith('}'): + raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - option_strings = re.split(COMMA, option_string[1:-1]) - return self.origin(self.args[0].parse_fn(s) for s in option_strings) + strings = re.split(COMMA, string[1:-1]) + return self.origin(self.args[0].parse_fn(s) for s in strings) class UnionTp(Tp): @@ -202,5 +202,20 @@ class UnionTp(Tp): def metavar(self) -> str: return f"{{{', '.join(self.kwargs.keys())}}}" - def parse_fn(self, option_string: str) -> Any: - return str(option_string.strip()) + def parse_fn(self, string: str) -> Any: + return str(string.strip()) + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): + choices = self.kwargs + + class _UnionTpAction(Action): + def __call__(self, parser: ArgumentParser, namespace, values, option_string=...) -> None: + Tp[Type[choices[values]]].add_argument( + argument_parser=argument_parser, name=name, default=default) + argument_parser.parse_known_args() + + argument_parser.add_argument( + f'--{name}', help=f'{name}', required=default == SUPPRESS, + type=self.parse_fn, metavar=self.metavar, action=_UnionTpAction, + choices=list(choices.keys()), default=default.__name__, + ) From a6172ecf1473407e5840b894881d863af41787ff Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 16 May 2020 17:56:18 +0900 Subject: [PATCH 34/72] Refactor: Rename to parsing_fn --- aku/{parse_fn.py => parsing_fn.py} | 0 aku/tp.py | 3 +-- 2 files changed, 1 insertion(+), 2 deletions(-) rename aku/{parse_fn.py => parsing_fn.py} (100%) diff --git a/aku/parse_fn.py b/aku/parsing_fn.py similarity index 100% rename from aku/parse_fn.py rename to aku/parsing_fn.py diff --git a/aku/tp.py b/aku/tp.py index 35ff212..61d6080 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -5,7 +5,7 @@ from itertools import zip_longest from typing import get_args, get_origin, get_type_hints, Any, Union, Literal, Type -from aku.parse_fn import get_parse_fn +from aku.parsing_fn import get_parse_fn COMMA = re.compile(r'\s*,\s*') @@ -212,7 +212,6 @@ class _UnionTpAction(Action): def __call__(self, parser: ArgumentParser, namespace, values, option_string=...) -> None: Tp[Type[choices[values]]].add_argument( argument_parser=argument_parser, name=name, default=default) - argument_parser.parse_known_args() argument_parser.add_argument( f'--{name}', help=f'{name}', required=default == SUPPRESS, From b57d25304a65213e2354ecc0ed786b703fd0d554 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 16 May 2020 17:56:40 +0900 Subject: [PATCH 35/72] Refactor: Rename to parsing_fn and annotations --- aku/{tp.py => annotations.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename aku/{tp.py => annotations.py} (100%) diff --git a/aku/tp.py b/aku/annotations.py similarity index 100% rename from aku/tp.py rename to aku/annotations.py From 5f22f31158763435cd8fcdbb8163405a8c1e2846 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 16 May 2020 18:09:05 +0900 Subject: [PATCH 36/72] Feat: Add Aku --- aku/abc.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 aku/abc.py diff --git a/aku/abc.py b/aku/abc.py new file mode 100644 index 0000000..a0c6863 --- /dev/null +++ b/aku/abc.py @@ -0,0 +1,15 @@ +from argparse import ArgumentParser, Namespace +from typing import List + + +class Aku(ArgumentParser): + def parse_known_args(self, args: List[str] = None, namespace: Namespace = None): + last_actions_len = -1 + while last_actions_len != len(self._actions): + last_actions_len = len(self._actions) + namespace, args = super(Aku, self).parse_known_args(args=args, namespace=namespace) + return super(Aku, self).parse_known_args(args=args, namespace=namespace) + + def parse_args(self, args: List[str] = None, namespace: Namespace = None): + namespace, args = super(Aku, self).parse_known_args(args=args, namespace=namespace) + return super(Aku, self).parse_args(args=args, namespace=namespace) From f42c53d2885e075db3dcb0e9a0b52a23754f468f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 16 May 2020 18:17:24 +0900 Subject: [PATCH 37/72] Feat: Save name in UnionTp --- aku/annotations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aku/annotations.py b/aku/annotations.py index 61d6080..4a1fe80 100644 --- a/aku/annotations.py +++ b/aku/annotations.py @@ -212,6 +212,7 @@ class _UnionTpAction(Action): def __call__(self, parser: ArgumentParser, namespace, values, option_string=...) -> None: Tp[Type[choices[values]]].add_argument( argument_parser=argument_parser, name=name, default=default) + setattr(namespace, name, values) argument_parser.add_argument( f'--{name}', help=f'{name}', required=default == SUPPRESS, From 1927dedcac8a01b75a48462434922b6b3a67a8db Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 16 May 2020 18:20:18 +0900 Subject: [PATCH 38/72] Refactor: Simplify parse_known_args --- aku/abc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aku/abc.py b/aku/abc.py index a0c6863..0a49847 100644 --- a/aku/abc.py +++ b/aku/abc.py @@ -8,8 +8,8 @@ def parse_known_args(self, args: List[str] = None, namespace: Namespace = None): while last_actions_len != len(self._actions): last_actions_len = len(self._actions) namespace, args = super(Aku, self).parse_known_args(args=args, namespace=namespace) - return super(Aku, self).parse_known_args(args=args, namespace=namespace) + return namespace, args def parse_args(self, args: List[str] = None, namespace: Namespace = None): - namespace, args = super(Aku, self).parse_known_args(args=args, namespace=namespace) + namespace, args = self.parse_known_args(args=args, namespace=namespace) return super(Aku, self).parse_args(args=args, namespace=namespace) From 70c9efa23eb0ccf841a43f45c796eb161b2279c7 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 17 May 2020 16:10:00 +0900 Subject: [PATCH 39/72] Feat: Add fetch_actions --- utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 utils.py diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..26d932a --- /dev/null +++ b/utils.py @@ -0,0 +1,9 @@ +from argparse import ArgumentParser + + +def fetch_actions(argument_parser: ArgumentParser) -> str: + msg = ', '.join([ + action.option_strings[-1] + for action in argument_parser._actions + ]) + return f"[{msg}]" From 4e76dde09b96586e2b5b8bbd9f2b81b94eb54f41 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 17 May 2020 16:17:17 +0900 Subject: [PATCH 40/72] Refactor: Move fetch_annotations --- aku/annotations.py | 20 +++----------------- aku/utils.py | 25 +++++++++++++++++++++++++ utils.py | 9 --------- 3 files changed, 28 insertions(+), 26 deletions(-) create mode 100644 aku/utils.py delete mode 100644 utils.py diff --git a/aku/annotations.py b/aku/annotations.py index 4a1fe80..d5c5759 100644 --- a/aku/annotations.py +++ b/aku/annotations.py @@ -1,28 +1,14 @@ import re from abc import ABCMeta, abstractmethod from argparse import ArgumentParser, SUPPRESS, Action -from inspect import getfullargspec -from itertools import zip_longest -from typing import get_args, get_origin, get_type_hints, Any, Union, Literal, Type +from typing import get_args, get_origin, Any, Union, Literal, Type from aku.parsing_fn import get_parse_fn +from aku.utils import fetch_annotations COMMA = re.compile(r'\s*,\s*') -def get_type_annotations(tp): - tys = get_type_hints(tp) - spec = getfullargspec(tp) - - name_default = zip_longest( - reversed(spec.args), - reversed(spec.defaults or []), - fillvalue=SUPPRESS, - ) - for arg_name, arg_default in reversed(list(name_default)): - yield arg_name, arg_default, tys[arg_name] - - class Tp(object, metaclass=ABCMeta): def __init__(self, origin, *args: Union['Tp', Any], **kwargs: ['Tp', Any]) -> None: super(Tp, self).__init__() @@ -97,7 +83,7 @@ def parse_fn(self, string: str) -> Any: raise NotImplementedError def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - for arg_name, arg_default, arg_tp in get_type_annotations(self.origin): + for arg_name, arg_default, arg_tp in fetch_annotations(self.origin): Tp[arg_tp].add_argument( argument_parser=argument_parser, name=arg_name, default=arg_default, diff --git a/aku/utils.py b/aku/utils.py new file mode 100644 index 0000000..23bb1a6 --- /dev/null +++ b/aku/utils.py @@ -0,0 +1,25 @@ +from argparse import SUPPRESS, ArgumentParser +from inspect import getfullargspec +from itertools import zip_longest +from typing import get_type_hints + + +def fetch_annotations(tp): + arg_spec = getfullargspec(tp) + type_hints = get_type_hints(tp) + + name_default = zip_longest( + reversed(arg_spec.args), + reversed(arg_spec.defaults or []), + fillvalue=SUPPRESS, + ) + for name, default in reversed(list(name_default)): + yield name, default, type_hints[name] + + +def fetch_actions(argument_parser: ArgumentParser) -> str: + msg = ', '.join([ + action.option_strings[-1] + for action in argument_parser._actions + ]) + return f"[{msg}]" diff --git a/utils.py b/utils.py deleted file mode 100644 index 26d932a..0000000 --- a/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -from argparse import ArgumentParser - - -def fetch_actions(argument_parser: ArgumentParser) -> str: - msg = ', '.join([ - action.option_strings[-1] - for action in argument_parser._actions - ]) - return f"[{msg}]" From d03440c1458134a9ea6c5d605442ae8d10865034 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 17 Sep 2020 01:27:52 +0900 Subject: [PATCH 41/72] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index caca0de..1df93b1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Aku [![Actions Status](https://github.com/speedcell4/aku/workflows/unit-tests/badge.svg)](https://github.com/speedcell4/aku/actions) +[![Downloads](https://pepy.tech/badge/aku)](https://pepy.tech/project/aku) An Annotation-driven ArgumentParser Generator From ee08a0a14d925afeab0828c35f12c704045be1e3 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 14:02:44 +0900 Subject: [PATCH 42/72] Feat: Re-write aku from scratch --- .github/workflows/unit-tests.yml | 6 +- aku/__init__.py | 81 ---------- aku/__main__.py | 35 ++++ aku/abc.py | 15 -- aku/annotations.py | 207 ------------------------ aku/parsing_fn.py | 25 --- aku/tp.py | 264 +++++++++++++++++++++++++++++++ aku/utils.py | 25 --- examples/basic.py | 45 ------ examples/text_classification.py | 165 ------------------- 10 files changed, 302 insertions(+), 566 deletions(-) create mode 100644 aku/__main__.py delete mode 100644 aku/abc.py delete mode 100644 aku/annotations.py delete mode 100644 aku/parsing_fn.py create mode 100644 aku/tp.py delete mode 100644 aku/utils.py delete mode 100644 examples/basic.py delete mode 100644 examples/text_classification.py diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index fbecce8..63384e4 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,4 +1,4 @@ -name: unit-tests +name: Unit Tests on: [push] @@ -8,9 +8,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v2 - name: Set up Python 3.7 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: 3.7 - name: Install dependencies diff --git a/aku/__init__.py b/aku/__init__.py index 3d3677e..e69de29 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -1,81 +0,0 @@ -import sys -from typing import List, Tuple, Set, FrozenSet, Dict -from typing import Optional -from typing import Union, Type -from typing import get_type_hints - -if sys.version_info < (3, 8): - import collections - from typing import _SpecialForm, _GenericAlias, Generic - - Literal = _SpecialForm('Literal', doc= - """Special typing form to define literal types (a.k.a. value types). - - This form can be used to indicate to type checkers that the corresponding - variable or function parameter has a value equivalent to the provided - literal (or one of several literals): - - def validate_simple(data: Any) -> Literal[True]: # always returns True - ... - - MODE = Literal['r', 'rb', 'w', 'wb'] - def open_helper(file: str, mode: MODE) -> str: - ... - - open_helper('/some/path', 'r') # Passes type check - open_helper('/other/path', 'typo') # Error in type checker - - Literal[...] cannot be subclassed. At runtime, an arbitrary value - is allowed as type argument to Literal[...], but type checkers may - impose restrictions. - """) - - - def get_origin(tp): - """Get the unsubscripted version of a type. - - This supports generic types, Callable, Tuple, Union, Literal, Final and ClassVar. - Return None for unsupported types. Examples:: - - get_origin(Literal[42]) is Literal - get_origin(int) is None - get_origin(ClassVar[int]) is ClassVar - get_origin(Generic) is Generic - get_origin(Generic[T]) is Generic - get_origin(Union[T, int]) is Union - get_origin(List[Tuple[T, T]][int]) == list - """ - if isinstance(tp, _GenericAlias): - return tp.__origin__ - if tp is Generic: - return Generic - return None - - - def get_args(tp): - """Get type arguments with all substitutions performed. - - For unions, basic simplifications used by Union constructor are performed. - Examples:: - get_args(Dict[str, int]) == (str, int) - get_args(int) == () - get_args(Union[int, Union[T, int], str][int]) == (int, str) - get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) - get_args(Callable[[], T][int]) == ([], int) - """ - if isinstance(tp, _GenericAlias): - res = tp.__args__ - if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: - res = (list(res[:-1]), res[-1]) - return res - return () -else: - from typing import Literal - -__all__ = [ - 'List', 'Tuple', 'Set', 'FrozenSet', 'Dict', - 'Optional', - 'Union', 'Type', - 'Literal', - 'get_args', 'get_origin', 'get_type_hints', -] diff --git a/aku/__main__.py b/aku/__main__.py new file mode 100644 index 0000000..693b2e6 --- /dev/null +++ b/aku/__main__.py @@ -0,0 +1,35 @@ +from argparse import ArgumentDefaultsHelpFormatter, SUPPRESS +from typing import Type +from typing import Union, List, Tuple, Literal + +from aku.tp import Aku, _init_argument_parser, AkuTp + +if __name__ == '__main__': + parser = Aku( + prog='oar', usage=None, description=None, + formatter_class=ArgumentDefaultsHelpFormatter, + ) + _init_argument_parser(parser) + + + def foo(f1: int = 3, f2: str = '4'): + print(f'a => {f1}') + print(f'w => {f2}') + + + def bar(b1: Literal[1, 2, 3] = 2, b2: List[int] = [2, 3, 4]): + print(f'c => {b1}') + print(f'd => {b2}') + + + def baz(z1: Tuple[int, str], z2: Tuple[float, ...]): + pass + + + def nice(a: Type[Union[foo, bar, baz]]): + pass + + + AkuTp[Type[nice]].add_argument(parser, 'nice', SUPPRESS) + + print(parser.parse_args()) diff --git a/aku/abc.py b/aku/abc.py deleted file mode 100644 index 0a49847..0000000 --- a/aku/abc.py +++ /dev/null @@ -1,15 +0,0 @@ -from argparse import ArgumentParser, Namespace -from typing import List - - -class Aku(ArgumentParser): - def parse_known_args(self, args: List[str] = None, namespace: Namespace = None): - last_actions_len = -1 - while last_actions_len != len(self._actions): - last_actions_len = len(self._actions) - namespace, args = super(Aku, self).parse_known_args(args=args, namespace=namespace) - return namespace, args - - def parse_args(self, args: List[str] = None, namespace: Namespace = None): - namespace, args = self.parse_known_args(args=args, namespace=namespace) - return super(Aku, self).parse_args(args=args, namespace=namespace) diff --git a/aku/annotations.py b/aku/annotations.py deleted file mode 100644 index d5c5759..0000000 --- a/aku/annotations.py +++ /dev/null @@ -1,207 +0,0 @@ -import re -from abc import ABCMeta, abstractmethod -from argparse import ArgumentParser, SUPPRESS, Action -from typing import get_args, get_origin, Any, Union, Literal, Type - -from aku.parsing_fn import get_parse_fn -from aku.utils import fetch_annotations - -COMMA = re.compile(r'\s*,\s*') - - -class Tp(object, metaclass=ABCMeta): - def __init__(self, origin, *args: Union['Tp', Any], **kwargs: ['Tp', Any]) -> None: - super(Tp, self).__init__() - self.origin = origin - self.args = args - self.kwargs = kwargs - - def __class_getitem__(cls, tp): - args = get_args(tp) - origin = get_origin(tp) - - if origin is None and args == (): - return PrimitiveTp(tp) - - if origin is Literal: - tp = type(args[0]) - - assert all(isinstance(a, tp) for a in args), \ - f'all arguments should have the same type {tp.__name__}' - return PrimitiveTp(tp, *set(args)) - - if origin is list and len(args) == 1: - return ListTp(origin, cls[args[0]]) - - if origin is tuple: - if len(args) == 2 and args[1] is ...: - return HomoTupleTp(origin, cls[args[0]]) - else: - return HeteroTupleTp(origin, *[cls[a] for a in args]) - - if origin is set: - return SetTp(origin, cls[args[0]]) - - if origin is frozenset: - return FrozenSetTp(origin, cls[args[0]]) - - if origin is type: - if get_origin(args[0]) is Union: - args = get_args(args[0]) - return UnionTp(str, **{a.__name__: a for a in args}) - return TypeTp(args[0]) - - if origin is Union: - args = [get_args(a)[0] for a in args] - return UnionTp(str, **{a.__name__: a for a in args}) - - raise NotImplementedError(f'unsupported {cls.__name__} {tp}') - - @property - @abstractmethod - def metavar(self) -> str: - raise NotImplementedError - - @abstractmethod - def parse_fn(self, string: str) -> Any: - raise NotImplementedError - - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - return argument_parser.add_argument( - f'--{name}', help=f'{name}', - type=self.parse_fn, metavar=self.metavar, required=default == SUPPRESS, - default=repr(default) if default != SUPPRESS else SUPPRESS, - ) - - -class TypeTp(Tp): - @property - def metavar(self) -> str: - raise NotImplementedError - - def parse_fn(self, string: str) -> Any: - raise NotImplementedError - - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - for arg_name, arg_default, arg_tp in fetch_annotations(self.origin): - Tp[arg_tp].add_argument( - argument_parser=argument_parser, - name=arg_name, default=arg_default, - ) - - -class PrimitiveTp(Tp): - @property - def metavar(self) -> str: - if len(self.args) == 0: - return f'{self.origin.__name__.lower()}' - return f"{{{', '.join([f'{repr(a)}' for a in self.args])}}}" - - def parse_fn(self, string: str) -> Any: - return get_parse_fn(self.origin)(string.strip()) - - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - return argument_parser.add_argument( - f'--{name}', help=f'{name}', - choices=self.args if len(self.args) > 0 else None, - type=self.parse_fn, metavar=self.metavar, required=default == SUPPRESS, - default=repr(default) if default != SUPPRESS else SUPPRESS, - ) - - -class ListTp(Tp): - @property - def metavar(self) -> str: - return f'[{self.args[0].metavar}]' - - def parse_fn(self, string: str) -> Any: - string = string.strip() - if not string.startswith('[') or not string.endswith(']'): - raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - - strings = re.split(COMMA, string[1:-1]) - return self.origin(self.args[0].parse_fn(s) for s in strings) - - -class HomoTupleTp(Tp): - @property - def metavar(self) -> str: - return f'({self.args[0].metavar}, ...)' - - def parse_fn(self, string: str) -> Any: - string = string.strip() - if not string.startswith('(') or not string.endswith(')'): - raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - - strings = re.split(COMMA, string[1:-1]) - return self.origin(self.args[0].parse_fn(s) for s in strings) - - -class HeteroTupleTp(Tp): - @property - def metavar(self) -> str: - return f"({', '.join([f'{a.metavar}' for a in self.args])})" - - def parse_fn(self, string: str) -> Any: - string = string.strip() - if not string.startswith('(') or not string.endswith(')'): - raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - - strings = re.split(COMMA, string[1:-1]) - assert len(strings) == len(self.args), \ - f'the number of arguments is not correct, ' \ - f'got {len(strings)} but excepted {len(self.args)}' - - return self.origin(a.parse_fn(s) for s, a in zip(strings, self.args)) - - -class SetTp(Tp): - @property - def metavar(self) -> str: - return f'{{{self.args[0].metavar}}}' - - def parse_fn(self, string: str) -> Any: - string = string.strip() - if not string.startswith('{') or not string.endswith('}'): - raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - - strings = re.split(COMMA, string[1:-1]) - return self.origin(self.args[0].parse_fn(s) for s in strings) - - -class FrozenSetTp(Tp): - @property - def metavar(self) -> str: - return f'{{{self.args[0].metavar}}}' - - def parse_fn(self, string: str) -> Any: - string = string.strip() - if not string.startswith('{') or not string.endswith('}'): - raise ValueError(f'{string} is not a(n) {self.origin.__name__}') - - strings = re.split(COMMA, string[1:-1]) - return self.origin(self.args[0].parse_fn(s) for s in strings) - - -class UnionTp(Tp): - @property - def metavar(self) -> str: - return f"{{{', '.join(self.kwargs.keys())}}}" - - def parse_fn(self, string: str) -> Any: - return str(string.strip()) - - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any): - choices = self.kwargs - - class _UnionTpAction(Action): - def __call__(self, parser: ArgumentParser, namespace, values, option_string=...) -> None: - Tp[Type[choices[values]]].add_argument( - argument_parser=argument_parser, name=name, default=default) - setattr(namespace, name, values) - - argument_parser.add_argument( - f'--{name}', help=f'{name}', required=default == SUPPRESS, - type=self.parse_fn, metavar=self.metavar, action=_UnionTpAction, - choices=list(choices.keys()), default=default.__name__, - ) diff --git a/aku/parsing_fn.py b/aku/parsing_fn.py deleted file mode 100644 index 39f8e7c..0000000 --- a/aku/parsing_fn.py +++ /dev/null @@ -1,25 +0,0 @@ -import inspect - -registry = {} - - -def register_parse_fn(tp): - ret = inspect.getfullargspec(tp).annotations['return'] - assert ret not in registry - - registry[ret] = tp - return tp - - -def get_parse_fn(tp): - return registry.get(tp, tp) - - -@register_parse_fn -def parse_bool(string: str) -> bool: - string = string.strip().lower() - if string in ('1', 'y', 'yes', 't', 'true'): - return True - if string in ('0', 'n', 'no', 'f', 'false'): - return False - raise ValueError(f'{string} is not a {bool.__name__} value') diff --git a/aku/tp.py b/aku/tp.py new file mode 100644 index 0000000..68d8072 --- /dev/null +++ b/aku/tp.py @@ -0,0 +1,264 @@ +import inspect +import re +from argparse import ArgumentParser, Action, Namespace, SUPPRESS +from re import Pattern +from typing import Union, Tuple, Literal, Any +from typing import get_origin, get_args, get_type_hints + +NEW_ACTIONS = '_new_actions' + + +def tp_none(arg_strings: str) -> type(None): + arg_strings = arg_strings.lower().strip() + if arg_strings in ('nil', 'null', 'none'): + return None + raise ValueError + + +def tp_bool(arg_strings: str) -> bool: + arg_strings = arg_strings.lower().strip() + if arg_strings in ('t', 'true', 'y', 'yes', '1'): + return True + if arg_strings in ('f', 'false', 'n', 'no', '0'): + return False + raise ValueError + + +def register_type(fn, argument_parser: ArgumentParser): + tp = get_type_hints(fn)['return'] + registry = argument_parser._registries['type'] + if tp not in registry: + registry.setdefault(tp, fn) + return fn + + +def register_homo_tuple(tp: type, argument_parser: ArgumentParser, + pattern: Pattern = re.compile(r',\s*')) -> None: + def fn(arg_strings: str) -> Tuple[tp, ...]: + nonlocal tp + + tp = argument_parser._registry_get('type', tp, tp) + return tuple(tp(arg) for arg in re.split(pattern, arg_strings.strip())) + + return register_type(fn, argument_parser) + + +def register_hetero_tuple(tps: Tuple[type, ...], argument_parser: ArgumentParser, + pattern: Pattern = re.compile(r',\s*')) -> None: + def fn(arg_strings: str) -> Tuple[tps]: + nonlocal tps + + tps = [argument_parser._registry_get('type', tp, tp) for tp in tps] + return tuple(tp(arg) for tp, arg in zip(tps, re.split(pattern, arg_strings.strip()))) + + return register_type(fn, argument_parser) + + +def _init_argument_parser(argument_parser: ArgumentParser): + register_type(tp_bool, argument_parser) + + +def tp_iter(fn): + is_method = inspect.ismethod(fn) + if inspect.isclass(fn): + fn = fn.__init__ + is_method = True + + tps = get_type_hints(fn) + spec = inspect.getfullargspec(fn) + args = spec.args or [] + defaults = spec.defaults or [] + defaults = {a: d for a, d in zip(args[::-1], defaults[::-1])} + + for index, arg in enumerate(args[1:] if is_method else args): + yield arg, tps[arg], defaults.get(arg, SUPPRESS) + + +class AkuTp(object): + def __init__(self, tp, choices): + super(AkuTp, self).__init__() + self.tp = tp + self.choices = choices + + registry = [] + + def __init_subclass__(cls, **kwargs): + cls.registry.append(cls) + + def __class_getitem__(cls, tp): + origin = get_origin(tp) + args = get_args(tp) + for aku_ty in cls.registry: + try: + return aku_ty[origin, args] + except TypeError: + pass + raise TypeError(f'unsupported annotation {tp}') + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + raise NotImplementedError + + +class StorePrimitiveAction(Action): + def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): + setattr(namespace, self.dest, values) + self.required = False + + +class AkuPrimitive(AkuTp): + def __class_getitem__(cls, tp): + origin, args = tp + if origin is None: + return AkuPrimitive(tp, None) + raise TypeError + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + argument_parser.add_argument( + f'--{name}', type=self.tp, choices=self.choices, required=True, + action=StorePrimitiveAction, default=default, + ) + + +class AppendListAction(Action): + def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): + flag_name = '_aku_visited' + if not getattr(self, flag_name, False): + setattr(namespace, self.dest, []) + setattr(self, flag_name, True) + getattr(namespace, self.dest).append(values) + + +class AkuList(AkuTp): + def __class_getitem__(cls, tp): + origin, args = tp + if origin is list: + return AkuList(args[0], None) + raise TypeError + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + argument_parser.add_argument( + f'--{name}', type=self.tp, choices=self.choices, required=True, + action=AppendListAction, default=default, + ) + + +class AkuHomoTuple(AkuTp): + def __class_getitem__(cls, tp): + origin, args = tp + if origin is tuple: + if len(args) == 2 and args[1] is ...: + return AkuHomoTuple(args[0], None) + else: + return AkuHeteroTuple(args, None) + raise TypeError + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + argument_parser.add_argument( + f'--{name}', type=register_homo_tuple(self.tp, argument_parser), choices=self.choices, required=True, + action=StorePrimitiveAction, default=default, + ) + + +class AkuHeteroTuple(AkuTp): + def __class_getitem__(cls, tp): + origin, args = tp + if origin is tuple: + if len(args) == 2 and args[1] is ...: + return AkuHomoTuple(args[0], None) + else: + return AkuHeteroTuple(args, None) + raise TypeError + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + argument_parser.add_argument( + f'--{name}', type=register_hetero_tuple(self.tp, argument_parser), choices=self.choices, required=True, + action=StorePrimitiveAction, default=default, + ) + + +class AkuLiteral(AkuTp): + def __class_getitem__(cls, tp): + origin, args = tp + if origin is Literal: + if len(args) > 0: + tp = type(args[0]) + for arg in args[1:]: + assert isinstance(arg, tp), f'{type(arg)} is not {tp}' + return AkuLiteral(tp, args) + raise TypeError + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + argument_parser.add_argument( + f'--{name}', type=self.tp, choices=self.choices, required=True, + action=StorePrimitiveAction, default=default, + ) + + +class AkuType(AkuTp): + def __class_getitem__(cls, tp): + origin, args = tp + if origin is type: + if len(args) == 1: + if get_origin(args[0]) == Union: + return AkuUnion(str, get_args(args[0])) + else: + return AkuType(args[0], None) + raise TypeError + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + for arg, tp, df in tp_iter(self.tp): + tp = AkuTp[tp] + if name.endswith('_'): + arg = f'{name}{arg}' + tp.add_argument(argument_parser=argument_parser, name=arg, default=df) + + +class AkuUnion(AkuTp): + def __class_getitem__(cls, tp): + origin, args = tp + if origin is type: + if len(args) == 1: + if get_origin(args[0]) == Union: + return AkuUnion(str, get_args(args[0])) + else: + return AkuType(args[0], None) + elif origin is Union: + args = [ + get_args(arg)[0] + for arg in get_args(tp) + ] + return AkuUnion(str, args) + raise TypeError + + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + choices = {c.__name__: c for c in self.choices} + + class UnionAction(Action): + def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): + setattr(namespace, self.dest, values) + self.required = False + + num_actions = len(parser._actions) + AkuType(choices[values], None).add_argument(argument_parser=parser, name=name, default=None) + parser._actions, new_actions = parser._actions[:num_actions], parser._actions[num_actions:] + setattr(parser, NEW_ACTIONS, getattr(parser, NEW_ACTIONS, []) + new_actions) + + argument_parser.add_argument( + f'--{name}' if not name.endswith('_') else f'--{name[:-1]}', + type=self.tp, choices=tuple(choices.keys()), required=True, + action=UnionAction, + ) + + +class Aku(ArgumentParser): + def parse_args(self, args=None) -> Namespace: + namespace, args = None, None + while True: + namespace, args = self.parse_known_args(args=args, namespace=namespace) + if hasattr(self, NEW_ACTIONS): + self._actions = self._actions + getattr(self, NEW_ACTIONS) + delattr(self, NEW_ACTIONS) + else: + break + + return namespace diff --git a/aku/utils.py b/aku/utils.py deleted file mode 100644 index 23bb1a6..0000000 --- a/aku/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -from argparse import SUPPRESS, ArgumentParser -from inspect import getfullargspec -from itertools import zip_longest -from typing import get_type_hints - - -def fetch_annotations(tp): - arg_spec = getfullargspec(tp) - type_hints = get_type_hints(tp) - - name_default = zip_longest( - reversed(arg_spec.args), - reversed(arg_spec.defaults or []), - fillvalue=SUPPRESS, - ) - for name, default in reversed(list(name_default)): - yield name, default, type_hints[name] - - -def fetch_actions(argument_parser: ArgumentParser) -> str: - msg = ', '.join([ - action.option_strings[-1] - for action in argument_parser._actions - ]) - return f"[{msg}]" diff --git a/examples/basic.py b/examples/basic.py deleted file mode 100644 index 46388a3..0000000 --- a/examples/basic.py +++ /dev/null @@ -1,45 +0,0 @@ -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter - -from aku.tp import Tp -from typing import List, Tuple, Set, FrozenSet, Literal - -parser = ArgumentParser( - formatter_class=ArgumentDefaultsHelpFormatter, -) - -Tp[str].add_argument(parser, 'pa', 'a') -Tp[int].add_argument(parser, 'pb', 1) -Tp[bool].add_argument(parser, 'pc', False) -Tp[float].add_argument(parser, 'pd', 2.0) - -Tp[Literal['a', 'b', 'a']].add_argument(parser, 'la', 'a') -Tp[Literal[2, 3, 2]].add_argument(parser, 'lb', 3) -Tp[Literal[True, False, True]].add_argument(parser, 'lc', False) -Tp[Literal[2.0, 3.0]].add_argument(parser, 'ld', 2.0) - -Tp[List[str]].add_argument(parser, 'sa', ['a', 'b']) -Tp[List[int]].add_argument(parser, 'sb', [2, 3]) -Tp[List[bool]].add_argument(parser, 'sc', [True, False]) -Tp[List[float]].add_argument(parser, 'sd', [2.0, 3.0]) - -Tp[Tuple[str, ...]].add_argument(parser, 'oa', ('a', 'b')) -Tp[Tuple[int, ...]].add_argument(parser, 'ob', (2, 3)) -Tp[Tuple[bool, ...]].add_argument(parser, 'oc', (True, False)) -Tp[Tuple[float, ...]].add_argument(parser, 'od', (2.0, 3.0)) - -Tp[Tuple[str, int]].add_argument(parser, 'ea', ('a', 1)) -Tp[Tuple[int, int]].add_argument(parser, 'eb', (2, 1)) -Tp[Tuple[bool, int]].add_argument(parser, 'ec', (True, 1)) -Tp[Tuple[float, int]].add_argument(parser, 'ed', (2.0, 1)) - -Tp[Set[str]].add_argument(parser, 'ta', {'a'}) -Tp[Set[int]].add_argument(parser, 'tb', {2}) -Tp[Set[bool]].add_argument(parser, 'tc', {True}) -Tp[Set[float]].add_argument(parser, 'td', {2.0}) - -# Tp[FrozenSet[str]].add_argument(parser, 'fa', frozenset({'a'})) -# Tp[FrozenSet[int]].add_argument(parser, 'fb', frozenset({2})) -# Tp[FrozenSet[bool]].add_argument(parser, 'fc', frozenset({True})) -# Tp[FrozenSet[float]].add_argument(parser, 'fd', frozenset({2.0})) - -print(parser.parse_args()) diff --git a/examples/text_classification.py b/examples/text_classification.py deleted file mode 100644 index 836d25c..0000000 --- a/examples/text_classification.py +++ /dev/null @@ -1,165 +0,0 @@ -from abc import ABCMeta, abstractmethod -from typing import Type -from typing import Union, List, Tuple - -import torch -from einops import rearrange -from torch import Tensor -from torch import nn, optim -from torch.nn.utils.rnn import PackedSequence -from torch.nn.utils.rnn import pad_packed_sequence -from torchglyph.vocab import Vocab - -from aku import Literal - - -class Corpus(object): - @classmethod - def new(cls, batch_size: int) -> None: - raise NotImplementedError - - -class WordEmbedding(nn.Embedding): - def __init__(self, *, word_vocab: Vocab) -> None: - super(WordEmbedding, self).__init__( - num_embeddings=len(word_vocab), - embedding_dim=word_vocab.vec_dim, - padding_idx=word_vocab.pad_idx, - _weight=word_vocab.vectors, - ) - - def forward(self, word: PackedSequence) -> PackedSequence: - data = super(WordEmbedding, self).forward(word.data) - return word._replace(data=data) - - -class Encoder(nn.Module, metaclass=ABCMeta): - encoding_dim: int - - def __init__(self, *, embedding_layer: WordEmbedding) -> None: - super(Encoder, self).__init__() - - @abstractmethod - def forward(self, embedding: PackedSequence) -> Tensor: - raise NotImplementedError - - -class LstmEncoder(Encoder): - def __init__(self, hidden_dim: int = 300, num_layers: int = 1, *, embedding_layer: WordEmbedding) -> None: - super(LstmEncoder, self).__init__(embedding_layer=embedding_layer) - self.rnn = nn.LSTM( - input_size=embedding_layer.embedding_dim, - hidden_size=hidden_dim, num_layers=num_layers, - bias=True, batch_first=True, bidirectional=True, - ) - - self.encoding_dim = self.rnn.hidden_size * (2 if self.rnn.bidirectional else 1) - - def forward(self, embedding: PackedSequence) -> Tensor: - _, (hidden, _) = self.rnn(embedding) - return rearrange(hidden, '(l d) b h -> l b (d h)', l=self.rnn.num_layers)[-1] - - -class ConvEncoder(Encoder): - def __init__(self, kernel_sizes: Tuple[int, ...] = (3, 5, 7), hidden_dim: int = 200, *, - embedding_layer: WordEmbedding) -> None: - super(ConvEncoder, self).__init__(embedding_layer=embedding_layer) - self.convs = nn.ModuleList([ - nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=hidden_dim, - kernel_size=(kernel_size, embedding_layer.embedding_dim), - padding=(kernel_size // 2, 0), bias=True, - ), - nn.AdaptiveMaxPool2d(output_size=(1, 1)), - nn.ReLU(), - ) - for kernel_size in kernel_sizes - ]) - - self.encoding_dim = hidden_dim * len(kernel_sizes) - - def forward(self, embedding: PackedSequence) -> Tensor: - data, _ = pad_packed_sequence(embedding, batch_first=True) - return torch.cat([conv(data[:, None, :, :])[:, :, 0, 0] for conv in self.convs], dim=-1) - - -class Projection(nn.Sequential): - def __init__(self, *, target_vocab: Vocab, encoding_layer: Encoder) -> None: - super(Projection, self).__init__( - nn.Linear(encoding_layer.encoding_dim, encoding_layer.encoding_dim), - nn.ReLU(), - nn.Linear(encoding_layer.encoding_dim, len(target_vocab)), - ) - - -class TextClassifier(nn.Module): - def __init__(self, - Emb: Type[WordEmbedding] = WordEmbedding, - Enc: Type[Union[LstmEncoder, ConvEncoder]] = LstmEncoder, - Proj: Type[Projection] = Projection, - reduction: Literal['sum', 'mean'] = 'mean', *, - word_vocab: Vocab, target_vocab: Vocab) -> None: - super(TextClassifier, self).__init__() - - self.embedding_layer = Emb(word_vocab=word_vocab) - self.encoding_layer = Enc(embedding_layer=self.embedding_layer) - self.projection_layer = Proj(target_vocab=target_vocab, encoding_layer=self.encoding_layer) - - self.criterion = nn.CrossEntropyLoss( - ignore_index=target_vocab.pad_idx, - reduction=reduction, - ) - - def forward(self, word: PackedSequence) -> Tensor: - embedding = self.embedding_layer(word) - encoding = self.encoding_layer(embedding) - return self.projection_layer(encoding) - - def fit(self, word: PackedSequence, target: Tensor) -> Tensor: - projection = self(word) - return self.criterion(projection, target) - - def inference(self, word: PackedSequence) -> List[int]: - projection = self(word) - return projection.detach().cpu().argmax(dim=-1).tolist() - - -class SGD(optim.SGD): - def __init__(self, lr: float = 1e-3, momentum: float = 0, dampening: float = 0, - weight_decay: float = 0, nesterov: bool = False, *, module: nn.Module) -> None: - super(SGD, self).__init__( - params=module.parameters(), - lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov, - ) - - -class Adam(optim.Adam): - def __init__(self, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, - weight_decay: float = 0, amsgrad: bool = False, *, module: nn.Module) -> None: - super(Adam, self).__init__( - params=module.parameters(), - lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad, - ) - - -def train_classifier( - num_epochs: int = 100, - Data: Type[Corpus.new] = Corpus.new, - Cls: Type[TextClassifier] = TextClassifier, - Opt: Type[Union[SGD, Adam]] = Adam, -): - train, dev, test = Data() - classifier = Cls(word_vocab=..., target_vocab=...) - optimizer = Opt(module=classifier) - - for epoch in range(1, num_epochs + 1): - classifier.train() - for batch in train: - loss = classifier(batch) - - optimizer.zero_grad() - loss.backward() - optimizer.step() From 09301d5a403715aa483ae2c223c2e2151dfeb9a6 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 14:17:19 +0900 Subject: [PATCH 43/72] Fix: fix AkuPrimitive bug --- aku/tp.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/aku/tp.py b/aku/tp.py index 68d8072..9b724ac 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -80,6 +80,9 @@ def __init__(self, tp, choices): self.tp = tp self.choices = choices + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.tp.__name__}, {self.choices})' + registry = [] def __init_subclass__(cls, **kwargs): @@ -90,7 +93,7 @@ def __class_getitem__(cls, tp): args = get_args(tp) for aku_ty in cls.registry: try: - return aku_ty[origin, args] + return aku_ty[tp, origin, args] except TypeError: pass raise TypeError(f'unsupported annotation {tp}') @@ -107,7 +110,7 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_ class AkuPrimitive(AkuTp): def __class_getitem__(cls, tp): - origin, args = tp + tp, origin, args = tp if origin is None: return AkuPrimitive(tp, None) raise TypeError @@ -130,7 +133,7 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_ class AkuList(AkuTp): def __class_getitem__(cls, tp): - origin, args = tp + tp, origin, args = tp if origin is list: return AkuList(args[0], None) raise TypeError @@ -144,7 +147,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) class AkuHomoTuple(AkuTp): def __class_getitem__(cls, tp): - origin, args = tp + tp, origin, args = tp if origin is tuple: if len(args) == 2 and args[1] is ...: return AkuHomoTuple(args[0], None) @@ -161,7 +164,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) class AkuHeteroTuple(AkuTp): def __class_getitem__(cls, tp): - origin, args = tp + tp, origin, args = tp if origin is tuple: if len(args) == 2 and args[1] is ...: return AkuHomoTuple(args[0], None) @@ -178,7 +181,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) class AkuLiteral(AkuTp): def __class_getitem__(cls, tp): - origin, args = tp + tp, origin, args = tp if origin is Literal: if len(args) > 0: tp = type(args[0]) @@ -196,7 +199,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) class AkuType(AkuTp): def __class_getitem__(cls, tp): - origin, args = tp + tp, origin, args = tp if origin is type: if len(args) == 1: if get_origin(args[0]) == Union: @@ -215,7 +218,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) class AkuUnion(AkuTp): def __class_getitem__(cls, tp): - origin, args = tp + tp, origin, args = tp if origin is type: if len(args) == 1: if get_origin(args[0]) == Union: From cd3d0ebcc4b2782f6a43dc1a9ee78b61f5b2c6c7 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 14:27:34 +0900 Subject: [PATCH 44/72] Refactor: update Actions --- aku/tp.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/aku/tp.py b/aku/tp.py index 9b724ac..544b808 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -102,7 +102,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) raise NotImplementedError -class StorePrimitiveAction(Action): +class StoreAction(Action): def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): setattr(namespace, self.dest, values) self.required = False @@ -118,17 +118,17 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: argument_parser.add_argument( f'--{name}', type=self.tp, choices=self.choices, required=True, - action=StorePrimitiveAction, default=default, + action=StoreAction, default=default, ) class AppendListAction(Action): def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): - flag_name = '_aku_visited' - if not getattr(self, flag_name, False): + if not getattr(self, '_aku_visited', False): + setattr(self, '_aku_visited', True) setattr(namespace, self.dest, []) - setattr(self, flag_name, True) getattr(namespace, self.dest).append(values) + self.required = False class AkuList(AkuTp): @@ -158,7 +158,7 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: argument_parser.add_argument( f'--{name}', type=register_homo_tuple(self.tp, argument_parser), choices=self.choices, required=True, - action=StorePrimitiveAction, default=default, + action=StoreAction, default=default, ) @@ -175,7 +175,7 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: argument_parser.add_argument( f'--{name}', type=register_hetero_tuple(self.tp, argument_parser), choices=self.choices, required=True, - action=StorePrimitiveAction, default=default, + action=StoreAction, default=default, ) @@ -193,7 +193,7 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: argument_parser.add_argument( f'--{name}', type=self.tp, choices=self.choices, required=True, - action=StorePrimitiveAction, default=default, + action=StoreAction, default=default, ) From ee7193112db0ac78f1345e83a215164d86a261a9 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 14:55:39 +0900 Subject: [PATCH 45/72] Feat: support domain --- aku/__main__.py | 27 ++++++++---------- aku/tp.py | 75 ++++++++++++++++++++++++++++++++++--------------- 2 files changed, 64 insertions(+), 38 deletions(-) diff --git a/aku/__main__.py b/aku/__main__.py index 693b2e6..0fcdaf2 100644 --- a/aku/__main__.py +++ b/aku/__main__.py @@ -1,6 +1,6 @@ from argparse import ArgumentDefaultsHelpFormatter, SUPPRESS from typing import Type -from typing import Union, List, Tuple, Literal +from typing import Union, List, Literal from aku.tp import Aku, _init_argument_parser, AkuTp @@ -12,24 +12,19 @@ _init_argument_parser(parser) - def foo(f1: int = 3, f2: str = '4'): - print(f'a => {f1}') - print(f'w => {f2}') + def foo(x: int = 3, y: str = '4'): + print(f'x => {x}') + print(f'y => {y}') - def bar(b1: Literal[1, 2, 3] = 2, b2: List[int] = [2, 3, 4]): - print(f'c => {b1}') - print(f'd => {b2}') + def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4]): + print(f'x => {x}') + print(f'y => {y}') - def baz(z1: Tuple[int, str], z2: Tuple[float, ...]): - pass - - - def nice(a: Type[Union[foo, bar, baz]]): - pass - - - AkuTp[Type[nice]].add_argument(parser, 'nice', SUPPRESS) + AkuTp[Type[Union[foo, bar]]].add_argument( + argument_parser=parser, name='fn_', default=SUPPRESS, + prefixes=(), domain=(), + ) print(parser.parse_args()) diff --git a/aku/tp.py b/aku/tp.py index 544b808..2c82779 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -74,6 +74,16 @@ def tp_iter(fn): yield arg, tps[arg], defaults.get(arg, SUPPRESS) +def join_names(prefixes: Tuple[str, ...], name: str) -> str: + if name.endswith('_'): + name = name[:-1] + return '-'.join(prefixes + (name,)).lower() + + +def join_dests(domain: Tuple[str, ...], name: str) -> str: + return '.'.join(domain + (name,)).lower() + + class AkuTp(object): def __init__(self, tp, choices): super(AkuTp, self).__init__() @@ -98,7 +108,8 @@ def __class_getitem__(cls, tp): pass raise TypeError(f'unsupported annotation {tp}') - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, + prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: raise NotImplementedError @@ -115,9 +126,11 @@ def __class_getitem__(cls, tp): return AkuPrimitive(tp, None) raise TypeError - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, + prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( - f'--{name}', type=self.tp, choices=self.choices, required=True, + f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + type=self.tp, choices=self.choices, required=True, action=StoreAction, default=default, ) @@ -138,9 +151,11 @@ def __class_getitem__(cls, tp): return AkuList(args[0], None) raise TypeError - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, + prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( - f'--{name}', type=self.tp, choices=self.choices, required=True, + f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + type=self.tp, choices=self.choices, required=True, action=AppendListAction, default=default, ) @@ -155,9 +170,11 @@ def __class_getitem__(cls, tp): return AkuHeteroTuple(args, None) raise TypeError - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, + prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( - f'--{name}', type=register_homo_tuple(self.tp, argument_parser), choices=self.choices, required=True, + f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + type=register_homo_tuple(self.tp, argument_parser), choices=self.choices, required=True, action=StoreAction, default=default, ) @@ -172,9 +189,11 @@ def __class_getitem__(cls, tp): return AkuHeteroTuple(args, None) raise TypeError - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, + prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( - f'--{name}', type=register_hetero_tuple(self.tp, argument_parser), choices=self.choices, required=True, + f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + type=register_hetero_tuple(self.tp, argument_parser), choices=self.choices, required=True, action=StoreAction, default=default, ) @@ -190,14 +209,16 @@ def __class_getitem__(cls, tp): return AkuLiteral(tp, args) raise TypeError - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, + prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( - f'--{name}', type=self.tp, choices=self.choices, required=True, + f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + type=self.tp, choices=self.choices, required=True, action=StoreAction, default=default, ) -class AkuType(AkuTp): +class AkuFn(AkuTp): def __class_getitem__(cls, tp): tp, origin, args = tp if origin is type: @@ -205,15 +226,21 @@ def __class_getitem__(cls, tp): if get_origin(args[0]) == Union: return AkuUnion(str, get_args(args[0])) else: - return AkuType(args[0], None) + return AkuFn(args[0], None) raise TypeError - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, + prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: + if name.endswith('_'): + prefixes = prefixes + (name[:-1],) + + domain = domain + (name,) for arg, tp, df in tp_iter(self.tp): tp = AkuTp[tp] - if name.endswith('_'): - arg = f'{name}{arg}' - tp.add_argument(argument_parser=argument_parser, name=arg, default=df) + tp.add_argument( + argument_parser=argument_parser, name=arg, + prefixes=prefixes, domain=domain, default=df, + ) class AkuUnion(AkuTp): @@ -224,7 +251,7 @@ def __class_getitem__(cls, tp): if get_origin(args[0]) == Union: return AkuUnion(str, get_args(args[0])) else: - return AkuType(args[0], None) + return AkuFn(args[0], None) elif origin is Union: args = [ get_args(arg)[0] @@ -233,21 +260,25 @@ def __class_getitem__(cls, tp): return AkuUnion(str, args) raise TypeError - def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any) -> None: + def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, + prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: choices = {c.__name__: c for c in self.choices} class UnionAction(Action): def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): - setattr(namespace, self.dest, values) + setattr(namespace, self.dest, choices[values]) self.required = False num_actions = len(parser._actions) - AkuType(choices[values], None).add_argument(argument_parser=parser, name=name, default=None) + AkuFn(choices[values], None).add_argument( + argument_parser=parser, name=name, + prefixes=prefixes, domain=domain, default=None, + ) parser._actions, new_actions = parser._actions[:num_actions], parser._actions[num_actions:] setattr(parser, NEW_ACTIONS, getattr(parser, NEW_ACTIONS, []) + new_actions) argument_parser.add_argument( - f'--{name}' if not name.endswith('_') else f'--{name[:-1]}', + f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), type=self.tp, choices=tuple(choices.keys()), required=True, action=UnionAction, ) From 81c6ed9e633c7f60f247914c60b2774c733f18c5 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 15:33:12 +0900 Subject: [PATCH 46/72] Feat: run function --- aku/__main__.py | 22 ++++++++++++---------- aku/tp.py | 41 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/aku/__main__.py b/aku/__main__.py index 0fcdaf2..d813050 100644 --- a/aku/__main__.py +++ b/aku/__main__.py @@ -5,26 +5,28 @@ from aku.tp import Aku, _init_argument_parser, AkuTp if __name__ == '__main__': - parser = Aku( + aku = Aku( prog='oar', usage=None, description=None, formatter_class=ArgumentDefaultsHelpFormatter, ) - _init_argument_parser(parser) + _init_argument_parser(aku) - def foo(x: int = 3, y: str = '4'): - print(f'x => {x}') - print(f'y => {y}') + def foo(x: int = 3, y: str = '4', **kwargs): + for _ in range(x): + print(f'foo => {y}') + print(kwargs['@aku']) - def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4]): - print(f'x => {x}') - print(f'y => {y}') + def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], **kwargs): + print(f'bar.x => {x}') + print(f'bar.y => {y}') + print(kwargs['@aku']) AkuTp[Type[Union[foo, bar]]].add_argument( - argument_parser=parser, name='fn_', default=SUPPRESS, + argument_parser=aku, name='fn_', default=SUPPRESS, prefixes=(), domain=(), ) - print(parser.parse_args()) + print(aku.run()) diff --git a/aku/tp.py b/aku/tp.py index 2c82779..c14fa91 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -1,3 +1,4 @@ +import functools import inspect import re from argparse import ArgumentParser, Action, Namespace, SUPPRESS @@ -278,7 +279,7 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_ setattr(parser, NEW_ACTIONS, getattr(parser, NEW_ACTIONS, []) + new_actions) argument_parser.add_argument( - f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + f'--{join_names(prefixes, name)}', dest=join_dests(domain + (name,), '@fn'), type=self.tp, choices=tuple(choices.keys()), required=True, action=UnionAction, ) @@ -296,3 +297,41 @@ def parse_args(self, args=None) -> Namespace: break return namespace + + def run(self, namespace: Namespace = None): + if namespace is None: + namespace = self.parse_args() + if isinstance(namespace, Namespace): + namespace = namespace.__dict__ + + args = {} + for key, value in namespace.items(): + collection = args + *names, key = key.split('.') + for name in names: + collection = collection.setdefault(name, {}) + if key == '@fn': + collection[key] = value + else: + collection.setdefault('@args', {})[key] = value + + def recur(x): + if isinstance(x, dict): + if '@fn' in x: + kwargs = {key: recur(value) for key, value in x['@args'].items()} + return functools.partial(x['@fn'], **kwargs) + else: + return { + key: recur(value) + for key, value in x.items() + } + else: + return x + + ret = recur(args) + assert len(ret) == 1 + for _, fn in ret.items(): + if inspect.getfullargspec(fn).varkw is None: + return fn() + else: + return fn(**{'@aku': args}) From 831c7067908cf083398798316ed5160f49f203cb Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 15:40:10 +0900 Subject: [PATCH 47/72] Feat: extend ArgumentParser --- aku/__main__.py | 15 ++++++--------- aku/tp.py | 44 +++++++++++++++++++++++++++++--------------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/aku/__main__.py b/aku/__main__.py index d813050..8c5e65a 100644 --- a/aku/__main__.py +++ b/aku/__main__.py @@ -1,20 +1,17 @@ -from argparse import ArgumentDefaultsHelpFormatter, SUPPRESS +from argparse import SUPPRESS from typing import Type from typing import Union, List, Literal -from aku.tp import Aku, _init_argument_parser, AkuTp +from aku.tp import Aku, AkuTp if __name__ == '__main__': - aku = Aku( - prog='oar', usage=None, description=None, - formatter_class=ArgumentDefaultsHelpFormatter, - ) - _init_argument_parser(aku) + aku = Aku() - def foo(x: int = 3, y: str = '4', **kwargs): + def foo(x: int = 3, y: str = '4', z: bool = True, **kwargs): for _ in range(x): print(f'foo => {y}') + print(f'z => {z}') print(kwargs['@aku']) @@ -25,7 +22,7 @@ def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], **kwargs): AkuTp[Type[Union[foo, bar]]].add_argument( - argument_parser=aku, name='fn_', default=SUPPRESS, + argument_parser=aku, name='fn', default=SUPPRESS, prefixes=(), domain=(), ) diff --git a/aku/tp.py b/aku/tp.py index c14fa91..c2ad937 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -1,7 +1,7 @@ import functools import inspect import re -from argparse import ArgumentParser, Action, Namespace, SUPPRESS +from argparse import ArgumentParser, Action, Namespace, SUPPRESS, ArgumentDefaultsHelpFormatter from re import Pattern from typing import Union, Tuple, Literal, Any from typing import get_origin, get_args, get_type_hints @@ -9,13 +9,6 @@ NEW_ACTIONS = '_new_actions' -def tp_none(arg_strings: str) -> type(None): - arg_strings = arg_strings.lower().strip() - if arg_strings in ('nil', 'null', 'none'): - return None - raise ValueError - - def tp_bool(arg_strings: str) -> bool: arg_strings = arg_strings.lower().strip() if arg_strings in ('t', 'true', 'y', 'yes', '1'): @@ -131,7 +124,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), - type=self.tp, choices=self.choices, required=True, + type=self.tp, choices=self.choices, required=default == SUPPRESS, action=StoreAction, default=default, ) @@ -156,7 +149,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), - type=self.tp, choices=self.choices, required=True, + type=self.tp, choices=self.choices, required=default == SUPPRESS, action=AppendListAction, default=default, ) @@ -175,7 +168,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), - type=register_homo_tuple(self.tp, argument_parser), choices=self.choices, required=True, + type=register_homo_tuple(self.tp, argument_parser), choices=self.choices, required=default == SUPPRESS, action=StoreAction, default=default, ) @@ -194,7 +187,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), - type=register_hetero_tuple(self.tp, argument_parser), choices=self.choices, required=True, + type=register_hetero_tuple(self.tp, argument_parser), choices=self.choices, required=default == SUPPRESS, action=StoreAction, default=default, ) @@ -205,7 +198,8 @@ def __class_getitem__(cls, tp): if origin is Literal: if len(args) > 0: tp = type(args[0]) - for arg in args[1:]: + for arg in args: + assert get_origin(arg) is None, f'{arg} is not a primitive type' assert isinstance(arg, tp), f'{type(arg)} is not {tp}' return AkuLiteral(tp, args) raise TypeError @@ -214,7 +208,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), - type=self.tp, choices=self.choices, required=True, + type=self.tp, choices=self.choices, required=default == SUPPRESS, action=StoreAction, default=default, ) @@ -280,12 +274,32 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_ argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain + (name,), '@fn'), - type=self.tp, choices=tuple(choices.keys()), required=True, + type=self.tp, choices=tuple(choices.keys()), required=default == SUPPRESS, action=UnionAction, ) class Aku(ArgumentParser): + def __init__(self, prog=__file__, + usage=None, + description=None, + epilog=None, + parents=(), + formatter_class=ArgumentDefaultsHelpFormatter, + prefix_chars='-', + fromfile_prefix_chars=None, + argument_default=None, + conflict_handler='error', + add_help=True, + allow_abbrev=True, + exit_on_error=True) -> None: + super(Aku, self).__init__( + prog, usage, description, epilog, parents, formatter_class, prefix_chars, + fromfile_prefix_chars, argument_default, conflict_handler, add_help, allow_abbrev, + exit_on_error, + ) + _init_argument_parser(self) + def parse_args(self, args=None) -> Namespace: namespace, args = None, None while True: From 0ff3dc386afb7e63931ef44aed11b9cea1b7f70f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 15:51:51 +0900 Subject: [PATCH 48/72] Feat: metavar --- aku/__main__.py | 11 ++++++++--- aku/tp.py | 13 +++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/aku/__main__.py b/aku/__main__.py index 8c5e65a..f780097 100644 --- a/aku/__main__.py +++ b/aku/__main__.py @@ -1,5 +1,6 @@ from argparse import SUPPRESS -from typing import Type +from pathlib import Path +from typing import Type, Tuple from typing import Union, List, Literal from aku.tp import Aku, AkuTp @@ -8,16 +9,20 @@ aku = Aku() - def foo(x: int = 3, y: str = '4', z: bool = True, **kwargs): + def foo(x: int = 3, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): for _ in range(x): print(f'foo => {y}') print(f'z => {z}') + print(f'w => {w}') print(kwargs['@aku']) - def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], **kwargs): + def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], + z: Tuple[float, ...] = (), w: Tuple[float, str, int] = (1., '2', 3), **kwargs): print(f'bar.x => {x}') print(f'bar.y => {y}') + print(f'bar.z => {z}') + print(f'bar.w => {w}') print(kwargs['@aku']) diff --git a/aku/tp.py b/aku/tp.py index c2ad937..a197a68 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -125,7 +125,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), type=self.tp, choices=self.choices, required=default == SUPPRESS, - action=StoreAction, default=default, + action=StoreAction, default=default, metavar=self.tp.__name__.lower(), ) @@ -150,7 +150,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), type=self.tp, choices=self.choices, required=default == SUPPRESS, - action=AppendListAction, default=default, + action=AppendListAction, default=default, metavar=f'[{self.tp.__name__.lower()}]', ) @@ -169,7 +169,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), type=register_homo_tuple(self.tp, argument_parser), choices=self.choices, required=default == SUPPRESS, - action=StoreAction, default=default, + action=StoreAction, default=default, metavar=f'({self.tp.__name__.lower()}, ...)', ) @@ -185,10 +185,11 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: + metavars = ', '.join(t.__name__.lower() for t in self.tp) argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), type=register_hetero_tuple(self.tp, argument_parser), choices=self.choices, required=default == SUPPRESS, - action=StoreAction, default=default, + action=StoreAction, default=default, metavar=f'({metavars})', ) @@ -209,7 +210,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), type=self.tp, choices=self.choices, required=default == SUPPRESS, - action=StoreAction, default=default, + action=StoreAction, default=default, metavar=f'{self.tp.__name__.lower()}{set(self.choices)}', ) @@ -275,7 +276,7 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_ argument_parser.add_argument( f'--{join_names(prefixes, name)}', dest=join_dests(domain + (name,), '@fn'), type=self.tp, choices=tuple(choices.keys()), required=default == SUPPRESS, - action=UnionAction, + action=UnionAction, metavar=f'{{{", ".join(choices.keys())}}}[fn]' ) From a63f74763ae57a1013d30da73ca91241f72c0887 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 16:03:29 +0900 Subject: [PATCH 49/72] Feat: default value on help --- aku/__main__.py | 2 +- aku/tp.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/aku/__main__.py b/aku/__main__.py index f780097..f87367f 100644 --- a/aku/__main__.py +++ b/aku/__main__.py @@ -9,7 +9,7 @@ aku = Aku() - def foo(x: int = 3, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): + def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): for _ in range(x): print(f'foo => {y}') print(f'z => {z}') diff --git a/aku/tp.py b/aku/tp.py index a197a68..5c786d5 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -122,8 +122,9 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: + prefixes_name = join_names(prefixes, name) argument_parser.add_argument( - f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + f'--{prefixes_name}', dest=join_dests(domain, name), help=prefixes_name, type=self.tp, choices=self.choices, required=default == SUPPRESS, action=StoreAction, default=default, metavar=self.tp.__name__.lower(), ) @@ -147,8 +148,9 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: + prefixes_name = join_names(prefixes, name) argument_parser.add_argument( - f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + f'--{prefixes_name}', dest=join_dests(domain, name), help=prefixes_name, type=self.tp, choices=self.choices, required=default == SUPPRESS, action=AppendListAction, default=default, metavar=f'[{self.tp.__name__.lower()}]', ) @@ -166,8 +168,9 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: + prefixes_name = join_names(prefixes, name) argument_parser.add_argument( - f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + f'--{prefixes_name}', dest=join_dests(domain, name), help=prefixes_name, type=register_homo_tuple(self.tp, argument_parser), choices=self.choices, required=default == SUPPRESS, action=StoreAction, default=default, metavar=f'({self.tp.__name__.lower()}, ...)', ) @@ -186,8 +189,9 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: metavars = ', '.join(t.__name__.lower() for t in self.tp) + prefixes_name = join_names(prefixes, name) argument_parser.add_argument( - f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + f'--{prefixes_name}', dest=join_dests(domain, name), help=prefixes_name, type=register_hetero_tuple(self.tp, argument_parser), choices=self.choices, required=default == SUPPRESS, action=StoreAction, default=default, metavar=f'({metavars})', ) @@ -207,8 +211,9 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: + prefixes_name = join_names(prefixes, name) argument_parser.add_argument( - f'--{join_names(prefixes, name)}', dest=join_dests(domain, name), + f'--{prefixes_name}', dest=join_dests(domain, name), help=prefixes_name, type=self.tp, choices=self.choices, required=default == SUPPRESS, action=StoreAction, default=default, metavar=f'{self.tp.__name__.lower()}{set(self.choices)}', ) @@ -273,9 +278,10 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_ parser._actions, new_actions = parser._actions[:num_actions], parser._actions[num_actions:] setattr(parser, NEW_ACTIONS, getattr(parser, NEW_ACTIONS, []) + new_actions) + prefixes_name = join_names(prefixes, name) argument_parser.add_argument( - f'--{join_names(prefixes, name)}', dest=join_dests(domain + (name,), '@fn'), - type=self.tp, choices=tuple(choices.keys()), required=default == SUPPRESS, + f'--{prefixes_name}', dest=join_dests(domain + (name,), '@fn'), help=prefixes_name, + type=self.tp, choices=tuple(choices.keys()), required=True, default=SUPPRESS, action=UnionAction, metavar=f'{{{", ".join(choices.keys())}}}[fn]' ) From 90ba1e1a2bbb81ce05c113c87896428e26b62e55 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 16:10:50 +0900 Subject: [PATCH 50/72] Feat: get_max_help_position --- aku/__init__.py | 5 +++++ aku/tp.py | 7 ++++++- aku/utils.py | 8 ++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 aku/utils.py diff --git a/aku/__init__.py b/aku/__init__.py index e69de29..60156f5 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -0,0 +1,5 @@ +from aku.tp import Aku + +__all__ = [ + 'Aku', +] diff --git a/aku/tp.py b/aku/tp.py index 5c786d5..524123d 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -6,6 +6,8 @@ from typing import Union, Tuple, Literal, Any from typing import get_origin, get_args, get_type_hints +from aku.utils import get_max_help_position + NEW_ACTIONS = '_new_actions' @@ -292,7 +294,10 @@ def __init__(self, prog=__file__, description=None, epilog=None, parents=(), - formatter_class=ArgumentDefaultsHelpFormatter, + formatter_class=functools.partial( + ArgumentDefaultsHelpFormatter, + max_help_position=get_max_help_position(), + ), prefix_chars='-', fromfile_prefix_chars=None, argument_default=None, diff --git a/aku/utils.py b/aku/utils.py new file mode 100644 index 0000000..559f552 --- /dev/null +++ b/aku/utils.py @@ -0,0 +1,8 @@ +import os + + +def get_max_help_position() -> int: + try: + return os.get_terminal_size().columns + except OSError: + return 24 \ No newline at end of file From 32f972c4b5a4d1b36bfac64aeeb0388186b64d17 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 16:27:56 +0900 Subject: [PATCH 51/72] Feat: option --- aku/__main__.py | 34 ---------------------------------- aku/tp.py | 21 +++++++++++++++++---- examples/__init__.py | 0 examples/naive.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 38 deletions(-) delete mode 100644 aku/__main__.py create mode 100644 examples/__init__.py create mode 100644 examples/naive.py diff --git a/aku/__main__.py b/aku/__main__.py deleted file mode 100644 index f87367f..0000000 --- a/aku/__main__.py +++ /dev/null @@ -1,34 +0,0 @@ -from argparse import SUPPRESS -from pathlib import Path -from typing import Type, Tuple -from typing import Union, List, Literal - -from aku.tp import Aku, AkuTp - -if __name__ == '__main__': - aku = Aku() - - - def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): - for _ in range(x): - print(f'foo => {y}') - print(f'z => {z}') - print(f'w => {w}') - print(kwargs['@aku']) - - - def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], - z: Tuple[float, ...] = (), w: Tuple[float, str, int] = (1., '2', 3), **kwargs): - print(f'bar.x => {x}') - print(f'bar.y => {y}') - print(f'bar.z => {z}') - print(f'bar.w => {w}') - print(kwargs['@aku']) - - - AkuTp[Type[Union[foo, bar]]].add_argument( - argument_parser=aku, name='fn', default=SUPPRESS, - prefixes=(), domain=(), - ) - - print(aku.run()) diff --git a/aku/tp.py b/aku/tp.py index 524123d..0e0d941 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -3,7 +3,7 @@ import re from argparse import ArgumentParser, Action, Namespace, SUPPRESS, ArgumentDefaultsHelpFormatter from re import Pattern -from typing import Union, Tuple, Literal, Any +from typing import Union, Tuple, Literal, Any, Type from typing import get_origin, get_args, get_type_hints from aku.utils import get_max_help_position @@ -234,10 +234,12 @@ def __class_getitem__(cls, tp): def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, prefixes: Tuple[str, ...], domain: Tuple[str, ...]) -> None: - if name.endswith('_'): - prefixes = prefixes + (name[:-1],) - domain = domain + (name,) + if name is not None: + domain = domain + (name,) + if name.endswith('_'): + prefixes = prefixes + (name[:-1],) + for arg, tp, df in tp_iter(self.tp): tp = AkuTp[tp] tp.add_argument( @@ -312,7 +314,18 @@ def __init__(self, prog=__file__, ) _init_argument_parser(self) + self._functions = [] + + def option(self, fn): + self._functions.append(Type[fn]) + return fn + def parse_args(self, args=None) -> Namespace: + AkuTp[Union[tuple(self._functions)]].add_argument( + self, name='root', default=SUPPRESS, + prefixes=(), domain=(), + ) + namespace, args = None, None while True: namespace, args = self.parse_known_args(args=args, namespace=namespace) diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/naive.py b/examples/naive.py new file mode 100644 index 0000000..5ca4b48 --- /dev/null +++ b/examples/naive.py @@ -0,0 +1,29 @@ +from pathlib import Path +from typing import List, Literal +from typing import Tuple + +from aku.tp import Aku + +aku = Aku() + + +@aku.option +def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): + print(f'foo.x => {x}') + print(f'foo.y => {y}') + print(f'foo.z => {z}') + print(f'foo.w => {w}') + print(kwargs['@aku']) + + +@aku.option +def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], + z: Tuple[float, ...] = (), w: Tuple[float, str, int] = (1., '2', 3), **kwargs): + print(f'bar.x => {x}') + print(f'bar.y => {y}') + print(f'bar.z => {z}') + print(f'bar.w => {w}') + print(kwargs['@aku']) + + +print(aku.run()) From 3e9b12b8f1d1d1354d4f420ea609a67a6c9f8af2 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 16:47:24 +0900 Subject: [PATCH 52/72] Fix: compatibility with Python 3.8 --- aku/tp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/aku/tp.py b/aku/tp.py index 0e0d941..85c6677 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -305,12 +305,12 @@ def __init__(self, prog=__file__, argument_default=None, conflict_handler='error', add_help=True, - allow_abbrev=True, - exit_on_error=True) -> None: + allow_abbrev=True) -> None: super(Aku, self).__init__( - prog, usage, description, epilog, parents, formatter_class, prefix_chars, - fromfile_prefix_chars, argument_default, conflict_handler, add_help, allow_abbrev, - exit_on_error, + prog=prog, usage=usage, description=description, epilog=epilog, + parents=parents, formatter_class=formatter_class, prefix_chars=prefix_chars, + fromfile_prefix_chars=fromfile_prefix_chars, argument_default=argument_default, + conflict_handler=conflict_handler, add_help=add_help, allow_abbrev=allow_abbrev, ) _init_argument_parser(self) From af646ea5cfc8ab5df2b0030512dd1905466f2d2d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 16:59:54 +0900 Subject: [PATCH 53/72] Feat: compatibility with Python 3.7 --- aku/__init__.py | 4 ++ aku/compat.py | 103 ++++++++++++++++++++++++++++++ aku/tp.py | 7 +- aku/utils.py | 8 --- examples/{naive.py => naive37.py} | 5 +- examples/naive39.py | 28 ++++++++ 6 files changed, 140 insertions(+), 15 deletions(-) create mode 100644 aku/compat.py delete mode 100644 aku/utils.py rename examples/{naive.py => naive37.py} (87%) create mode 100644 examples/naive39.py diff --git a/aku/__init__.py b/aku/__init__.py index 60156f5..5849e1c 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -1,5 +1,9 @@ +from __future__ import annotations + +from aku.compat import Literal, get_origin, get_args from aku.tp import Aku __all__ = [ 'Aku', + 'Literal', 'get_origin', 'get_args', ] diff --git a/aku/compat.py b/aku/compat.py new file mode 100644 index 0000000..6688079 --- /dev/null +++ b/aku/compat.py @@ -0,0 +1,103 @@ +import collections +import sys + +__all__ = [ + 'Literal', + 'get_origin', 'get_args', +] + +if sys.version_info < (3, 8): + from typing import _SpecialForm, _GenericAlias, Generic, _tp_cache, _type_check, _remove_dups_flatten, Union + + + @_tp_cache + def __getitem__(self, parameters): + if self._name in ('ClassVar', 'Final'): + item = _type_check(parameters, f'{self._name} accepts only single type.') + return _GenericAlias(self, (item,)) + if self._name == 'Union': + if parameters == (): + raise TypeError("Cannot take a Union of no types.") + if not isinstance(parameters, tuple): + parameters = (parameters,) + msg = "Union[arg, ...]: each arg must be a type." + parameters = tuple(_type_check(p, msg) for p in parameters) + parameters = _remove_dups_flatten(parameters) + if len(parameters) == 1: + return parameters[0] + return _GenericAlias(self, parameters) + if self._name == 'Optional': + arg = _type_check(parameters, "Optional[t] requires a single type.") + return Union[arg, type(None)] + if self._name == 'Literal': + # There is no '_type_check' call because arguments to Literal[...] are + # values, not types. + return _GenericAlias(self, parameters) + raise TypeError(f"{self} is not subscriptable") + + + _SpecialForm.__getitem__ = __getitem__ + + Literal = _SpecialForm('Literal', doc= + """Special typing form to define literal types (a.k.a. value types). + + This form can be used to indicate to type checkers that the corresponding + variable or function parameter has a value equivalent to the provided + literal (or one of several literals): + + def validate_simple(data: Any) -> Literal[True]: # always returns True + ... + + MODE = Literal['r', 'rb', 'w', 'wb'] + def open_helper(file: str, mode: MODE) -> str: + ... + + open_helper('/some/path', 'r') # Passes type check + open_helper('/other/path', 'typo') # Error in type checker + + Literal[...] cannot be subclassed. At runtime, an arbitrary value + is allowed as type argument to Literal[...], but type checkers may + impose restrictions. + """) + + + def get_origin(tp): + """Get the unsubscripted version of a type. + + This supports generic types, Callable, Tuple, Union, Literal, Final and ClassVar. + Return None for unsupported types. Examples:: + + get_origin(Literal[42]) is Literal + get_origin(int) is None + get_origin(ClassVar[int]) is ClassVar + get_origin(Generic) is Generic + get_origin(Generic[T]) is Generic + get_origin(Union[T, int]) is Union + get_origin(List[Tuple[T, T]][int]) == list + """ + if isinstance(tp, _GenericAlias): + return tp.__origin__ + if tp is Generic: + return Generic + return None + + + def get_args(tp): + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. + Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + if isinstance(tp, _GenericAlias): + res = tp.__args__ + if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + return () +else: + from typing import Literal, get_origin, get_args diff --git a/aku/tp.py b/aku/tp.py index 85c6677..7f54814 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -3,10 +3,9 @@ import re from argparse import ArgumentParser, Action, Namespace, SUPPRESS, ArgumentDefaultsHelpFormatter from re import Pattern -from typing import Union, Tuple, Literal, Any, Type -from typing import get_origin, get_args, get_type_hints +from typing import Union, Tuple, Any, Type, get_type_hints -from aku.utils import get_max_help_position +from aku.compat import Literal, get_origin, get_args NEW_ACTIONS = '_new_actions' @@ -298,7 +297,7 @@ def __init__(self, prog=__file__, parents=(), formatter_class=functools.partial( ArgumentDefaultsHelpFormatter, - max_help_position=get_max_help_position(), + max_help_position=82, ), prefix_chars='-', fromfile_prefix_chars=None, diff --git a/aku/utils.py b/aku/utils.py deleted file mode 100644 index 559f552..0000000 --- a/aku/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -import os - - -def get_max_help_position() -> int: - try: - return os.get_terminal_size().columns - except OSError: - return 24 \ No newline at end of file diff --git a/examples/naive.py b/examples/naive37.py similarity index 87% rename from examples/naive.py rename to examples/naive37.py index 5ca4b48..3daf77a 100644 --- a/examples/naive.py +++ b/examples/naive37.py @@ -1,8 +1,7 @@ from pathlib import Path -from typing import List, Literal -from typing import Tuple +from typing import List, Tuple -from aku.tp import Aku +from aku import Aku, Literal aku = Aku() diff --git a/examples/naive39.py b/examples/naive39.py new file mode 100644 index 0000000..dac74bf --- /dev/null +++ b/examples/naive39.py @@ -0,0 +1,28 @@ +from pathlib import Path +from typing import Literal + +from aku import Aku + +aku = Aku() + + +@aku.option +def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): + print(f'foo.x => {x}') + print(f'foo.y => {y}') + print(f'foo.z => {z}') + print(f'foo.w => {w}') + print(kwargs['@aku']) + + +@aku.option +def bar(x: Literal[1, 2, 3] = 2, y: list[int] = [2, 3, 4], + z: tuple[float, ...] = (), w: tuple[float, str, int] = (1., '2', 3), **kwargs): + print(f'bar.x => {x}') + print(f'bar.y => {y}') + print(f'bar.z => {z}') + print(f'bar.w => {w}') + print(kwargs['@aku']) + + +print(aku.run()) From 5acf01912569ea0d26039f63faa706e82b3b4d79 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 20 Oct 2020 17:03:36 +0900 Subject: [PATCH 54/72] Refactor: modularize --- aku/__init__.py | 3 +- aku/actions.py | 16 +++++ aku/aku.py | 93 +++++++++++++++++++++++++ aku/tp.py | 182 ++---------------------------------------------- aku/utils.py | 75 ++++++++++++++++++++ 5 files changed, 189 insertions(+), 180 deletions(-) create mode 100644 aku/actions.py create mode 100644 aku/aku.py create mode 100644 aku/utils.py diff --git a/aku/__init__.py b/aku/__init__.py index 5849e1c..604498a 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -1,9 +1,8 @@ from __future__ import annotations from aku.compat import Literal, get_origin, get_args -from aku.tp import Aku +from aku.aku import Aku __all__ = [ - 'Aku', 'Literal', 'get_origin', 'get_args', ] diff --git a/aku/actions.py b/aku/actions.py new file mode 100644 index 0000000..4d70246 --- /dev/null +++ b/aku/actions.py @@ -0,0 +1,16 @@ +from argparse import Action, ArgumentParser, Namespace + + +class StoreAction(Action): + def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): + setattr(namespace, self.dest, values) + self.required = False + + +class AppendListAction(Action): + def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): + if not getattr(self, '_aku_visited', False): + setattr(self, '_aku_visited', True) + setattr(namespace, self.dest, []) + getattr(namespace, self.dest).append(values) + self.required = False diff --git a/aku/aku.py b/aku/aku.py new file mode 100644 index 0000000..b689afe --- /dev/null +++ b/aku/aku.py @@ -0,0 +1,93 @@ +import functools +import inspect +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, Namespace, SUPPRESS +from typing import Type, Union + +from aku.tp import AkuTp +from aku.utils import _init_argument_parser, NEW_ACTIONS + + +class Aku(ArgumentParser): + def __init__(self, prog=__file__, + usage=None, + description=None, + epilog=None, + parents=(), + formatter_class=functools.partial( + ArgumentDefaultsHelpFormatter, + max_help_position=82, + ), + prefix_chars='-', + fromfile_prefix_chars=None, + argument_default=None, + conflict_handler='error', + add_help=True, + allow_abbrev=True) -> None: + super(Aku, self).__init__( + prog=prog, usage=usage, description=description, epilog=epilog, + parents=parents, formatter_class=formatter_class, prefix_chars=prefix_chars, + fromfile_prefix_chars=fromfile_prefix_chars, argument_default=argument_default, + conflict_handler=conflict_handler, add_help=add_help, allow_abbrev=allow_abbrev, + ) + _init_argument_parser(self) + + self._functions = [] + + def option(self, fn): + self._functions.append(Type[fn]) + return fn + + def parse_args(self, args=None) -> Namespace: + AkuTp[Union[tuple(self._functions)]].add_argument( + self, name='root', default=SUPPRESS, + prefixes=(), domain=(), + ) + + namespace, args = None, None + while True: + namespace, args = self.parse_known_args(args=args, namespace=namespace) + if hasattr(self, NEW_ACTIONS): + self._actions = self._actions + getattr(self, NEW_ACTIONS) + delattr(self, NEW_ACTIONS) + else: + break + + return namespace + + def run(self, namespace: Namespace = None): + if namespace is None: + namespace = self.parse_args() + if isinstance(namespace, Namespace): + namespace = namespace.__dict__ + + args = {} + for key, value in namespace.items(): + collection = args + *names, key = key.split('.') + for name in names: + collection = collection.setdefault(name, {}) + if key == '@fn': + collection[key] = value + else: + collection.setdefault('@args', {})[key] = value + + def recur(x): + if isinstance(x, dict): + if '@fn' in x: + kwargs = {key: recur(value) for key, value in x['@args'].items()} + return functools.partial(x['@fn'], **kwargs) + else: + return { + key: recur(value) + for key, value in x.items() + } + else: + return x + + ret = recur(args) + assert len(ret) == 1 + for _, fn in ret.items(): + if inspect.getfullargspec(fn).varkw is None: + return fn() + else: + return fn(**{'@aku': args}) diff --git a/aku/tp.py b/aku/tp.py index 7f54814..2f31cf3 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -1,82 +1,9 @@ -import functools -import inspect -import re -from argparse import ArgumentParser, Action, Namespace, SUPPRESS, ArgumentDefaultsHelpFormatter -from re import Pattern -from typing import Union, Tuple, Any, Type, get_type_hints +from argparse import ArgumentParser, Action, Namespace, SUPPRESS +from typing import Union, Tuple, Any +from aku.actions import StoreAction, AppendListAction from aku.compat import Literal, get_origin, get_args - -NEW_ACTIONS = '_new_actions' - - -def tp_bool(arg_strings: str) -> bool: - arg_strings = arg_strings.lower().strip() - if arg_strings in ('t', 'true', 'y', 'yes', '1'): - return True - if arg_strings in ('f', 'false', 'n', 'no', '0'): - return False - raise ValueError - - -def register_type(fn, argument_parser: ArgumentParser): - tp = get_type_hints(fn)['return'] - registry = argument_parser._registries['type'] - if tp not in registry: - registry.setdefault(tp, fn) - return fn - - -def register_homo_tuple(tp: type, argument_parser: ArgumentParser, - pattern: Pattern = re.compile(r',\s*')) -> None: - def fn(arg_strings: str) -> Tuple[tp, ...]: - nonlocal tp - - tp = argument_parser._registry_get('type', tp, tp) - return tuple(tp(arg) for arg in re.split(pattern, arg_strings.strip())) - - return register_type(fn, argument_parser) - - -def register_hetero_tuple(tps: Tuple[type, ...], argument_parser: ArgumentParser, - pattern: Pattern = re.compile(r',\s*')) -> None: - def fn(arg_strings: str) -> Tuple[tps]: - nonlocal tps - - tps = [argument_parser._registry_get('type', tp, tp) for tp in tps] - return tuple(tp(arg) for tp, arg in zip(tps, re.split(pattern, arg_strings.strip()))) - - return register_type(fn, argument_parser) - - -def _init_argument_parser(argument_parser: ArgumentParser): - register_type(tp_bool, argument_parser) - - -def tp_iter(fn): - is_method = inspect.ismethod(fn) - if inspect.isclass(fn): - fn = fn.__init__ - is_method = True - - tps = get_type_hints(fn) - spec = inspect.getfullargspec(fn) - args = spec.args or [] - defaults = spec.defaults or [] - defaults = {a: d for a, d in zip(args[::-1], defaults[::-1])} - - for index, arg in enumerate(args[1:] if is_method else args): - yield arg, tps[arg], defaults.get(arg, SUPPRESS) - - -def join_names(prefixes: Tuple[str, ...], name: str) -> str: - if name.endswith('_'): - name = name[:-1] - return '-'.join(prefixes + (name,)).lower() - - -def join_dests(domain: Tuple[str, ...], name: str) -> str: - return '.'.join(domain + (name,)).lower() +from aku.utils import register_homo_tuple, register_hetero_tuple, tp_iter, join_names, join_dests, NEW_ACTIONS class AkuTp(object): @@ -108,12 +35,6 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, raise NotImplementedError -class StoreAction(Action): - def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): - setattr(namespace, self.dest, values) - self.required = False - - class AkuPrimitive(AkuTp): def __class_getitem__(cls, tp): tp, origin, args = tp @@ -131,15 +52,6 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, ) -class AppendListAction(Action): - def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): - if not getattr(self, '_aku_visited', False): - setattr(self, '_aku_visited', True) - setattr(namespace, self.dest, []) - getattr(namespace, self.dest).append(values) - self.required = False - - class AkuList(AkuTp): def __class_getitem__(cls, tp): tp, origin, args = tp @@ -287,89 +199,3 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_ type=self.tp, choices=tuple(choices.keys()), required=True, default=SUPPRESS, action=UnionAction, metavar=f'{{{", ".join(choices.keys())}}}[fn]' ) - - -class Aku(ArgumentParser): - def __init__(self, prog=__file__, - usage=None, - description=None, - epilog=None, - parents=(), - formatter_class=functools.partial( - ArgumentDefaultsHelpFormatter, - max_help_position=82, - ), - prefix_chars='-', - fromfile_prefix_chars=None, - argument_default=None, - conflict_handler='error', - add_help=True, - allow_abbrev=True) -> None: - super(Aku, self).__init__( - prog=prog, usage=usage, description=description, epilog=epilog, - parents=parents, formatter_class=formatter_class, prefix_chars=prefix_chars, - fromfile_prefix_chars=fromfile_prefix_chars, argument_default=argument_default, - conflict_handler=conflict_handler, add_help=add_help, allow_abbrev=allow_abbrev, - ) - _init_argument_parser(self) - - self._functions = [] - - def option(self, fn): - self._functions.append(Type[fn]) - return fn - - def parse_args(self, args=None) -> Namespace: - AkuTp[Union[tuple(self._functions)]].add_argument( - self, name='root', default=SUPPRESS, - prefixes=(), domain=(), - ) - - namespace, args = None, None - while True: - namespace, args = self.parse_known_args(args=args, namespace=namespace) - if hasattr(self, NEW_ACTIONS): - self._actions = self._actions + getattr(self, NEW_ACTIONS) - delattr(self, NEW_ACTIONS) - else: - break - - return namespace - - def run(self, namespace: Namespace = None): - if namespace is None: - namespace = self.parse_args() - if isinstance(namespace, Namespace): - namespace = namespace.__dict__ - - args = {} - for key, value in namespace.items(): - collection = args - *names, key = key.split('.') - for name in names: - collection = collection.setdefault(name, {}) - if key == '@fn': - collection[key] = value - else: - collection.setdefault('@args', {})[key] = value - - def recur(x): - if isinstance(x, dict): - if '@fn' in x: - kwargs = {key: recur(value) for key, value in x['@args'].items()} - return functools.partial(x['@fn'], **kwargs) - else: - return { - key: recur(value) - for key, value in x.items() - } - else: - return x - - ret = recur(args) - assert len(ret) == 1 - for _, fn in ret.items(): - if inspect.getfullargspec(fn).varkw is None: - return fn() - else: - return fn(**{'@aku': args}) diff --git a/aku/utils.py b/aku/utils.py new file mode 100644 index 0000000..29d4155 --- /dev/null +++ b/aku/utils.py @@ -0,0 +1,75 @@ +import inspect +import re +from argparse import ArgumentParser, SUPPRESS +from typing import get_type_hints, Pattern, Tuple + +NEW_ACTIONS = '_new_actions' + + +def tp_bool(arg_strings: str) -> bool: + arg_strings = arg_strings.lower().strip() + if arg_strings in ('t', 'true', 'y', 'yes', '1'): + return True + if arg_strings in ('f', 'false', 'n', 'no', '0'): + return False + raise ValueError + + +def register_type(fn, argument_parser: ArgumentParser): + tp = get_type_hints(fn)['return'] + registry = argument_parser._registries['type'] + if tp not in registry: + registry.setdefault(tp, fn) + return fn + + +def register_homo_tuple(tp: type, argument_parser: ArgumentParser, + pattern: Pattern = re.compile(r',\s*')) -> None: + def fn(arg_strings: str) -> Tuple[tp, ...]: + nonlocal tp + + tp = argument_parser._registry_get('type', tp, tp) + return tuple(tp(arg) for arg in re.split(pattern, arg_strings.strip())) + + return register_type(fn, argument_parser) + + +def register_hetero_tuple(tps: Tuple[type, ...], argument_parser: ArgumentParser, + pattern: Pattern = re.compile(r',\s*')) -> None: + def fn(arg_strings: str) -> Tuple[tps]: + nonlocal tps + + tps = [argument_parser._registry_get('type', tp, tp) for tp in tps] + return tuple(tp(arg) for tp, arg in zip(tps, re.split(pattern, arg_strings.strip()))) + + return register_type(fn, argument_parser) + + +def _init_argument_parser(argument_parser: ArgumentParser): + register_type(tp_bool, argument_parser) + + +def tp_iter(fn): + is_method = inspect.ismethod(fn) + if inspect.isclass(fn): + fn = fn.__init__ + is_method = True + + tps = get_type_hints(fn) + spec = inspect.getfullargspec(fn) + args = spec.args or [] + defaults = spec.defaults or [] + defaults = {a: d for a, d in zip(args[::-1], defaults[::-1])} + + for index, arg in enumerate(args[1:] if is_method else args): + yield arg, tps[arg], defaults.get(arg, SUPPRESS) + + +def join_names(prefixes: Tuple[str, ...], name: str) -> str: + if name.endswith('_'): + name = name[:-1] + return '-'.join(prefixes + (name,)).lower() + + +def join_dests(domain: Tuple[str, ...], name: str) -> str: + return '.'.join(domain + (name,)).lower() From d5695127ed8dd29f9efd1bd092f82519eb6fb690 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 30 Oct 2020 14:44:34 +0900 Subject: [PATCH 55/72] Refactor: Clean them up --- .github/workflows/python-publish.yml | 31 +++++++++++++ .github/workflows/unit-tests.yml | 3 +- README.md | 65 +++++----------------------- aku/__init__.py | 4 +- setup.py | 2 - 5 files changed, 45 insertions(+), 60 deletions(-) create mode 100644 .github/workflows/python-publish.yml diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 0000000..ded27a6 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,31 @@ +# This workflows will upload a Python Package using Twine when a release is created +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Upload Python Package + +on: + release: + types: [ created ] + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.7' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install setuptools wheel twine + - name: Build and publish + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python setup.py sdist bdist_wheel + twine upload dist/* \ No newline at end of file diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 63384e4..063f998 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,6 +1,6 @@ name: Unit Tests -on: [push] +on: [ push ] jobs: build: @@ -16,7 +16,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install torch python -m pip install -e '.[dev]' - name: Test with pytest run: | diff --git a/README.md b/README.md index 1df93b1..a39167d 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,14 @@ # Aku [![Actions Status](https://github.com/speedcell4/aku/workflows/unit-tests/badge.svg)](https://github.com/speedcell4/aku/actions) +[![PyPI version](https://badge.fury.io/py/aku.svg)](https://badge.fury.io/py/aku) [![Downloads](https://pepy.tech/badge/aku)](https://pepy.tech/project/aku) -An Annotation-driven ArgumentParser Generator +Aku is an interactive annotation-driven `ArgumentParser` generator. ## Requirements -Python 3.7 or higher +* Python 3.7 or higher ## Install @@ -15,16 +16,15 @@ Python 3.7 or higher python -m pip install aku --upgrade ``` -## Quick Start +## Usage ```python -# tests/test_single_function.py from aku import Aku aku = Aku() -@aku.register +@aku.option def add(a: int, b: int = 2): print(f'{a} + {b} => {a + b}') @@ -34,55 +34,12 @@ aku.run() `aku` will automatically add argument options according to your function signature. -``` -~ python tests/test_single_function.py --help -usage: aku [-h] --a A [--b B] +```shell script +python3 foo.py --help +usage: foo.py [-h] --a int [--b int] optional arguments: -h, --help show this help message and exit - --a A a (default: None) - --b B b (default: 2) - -``` - -Registering more than one function will make `aku` add them to sub-parser respectively (and lazily). - -```python -# file test_multi_functions.py -from aku import Aku - -aku = Aku() - - -@aku.register -def add(a: int, b: int = 2): - print(f'{a} + {b} => {a + b}') - - -@aku.register -def say_hello(name: str): - print(f'hello {name}') - - -aku.run() -``` - -Similarly, your command line interface will look like, - -``` -~ python tests/test_multi_functions.py --help -usage: aku [-h] {add,say_hello} ... - -positional arguments: - {add,say_hello} - -optional arguments: - -h, --help show this help message and exit - -~ python tests/test_multi_functions.py say_hello --help -usage: aku say_hello [-h] --name NAME - -optional arguments: - -h, --help show this help message and exit - --name NAME name (default: None) -``` + --a int a + --b int b (default: 2) +``` \ No newline at end of file diff --git a/aku/__init__.py b/aku/__init__.py index 604498a..a57d0da 100644 --- a/aku/__init__.py +++ b/aku/__init__.py @@ -1,8 +1,8 @@ from __future__ import annotations -from aku.compat import Literal, get_origin, get_args from aku.aku import Aku +from aku.compat import Literal, get_origin, get_args __all__ = [ - 'Literal', 'get_origin', 'get_args', + 'Aku', 'Literal', 'get_origin', 'get_args', ] diff --git a/setup.py b/setup.py index d27f575..7ef8706 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,6 @@ 'dev': [ 'pytest', 'hypothesis', - 'torchglyph', - 'einops', ], } ) From 83a827df1e16bd19b045ff5ef9d76ba3e5bfdf8b Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 30 Oct 2020 14:48:40 +0900 Subject: [PATCH 56/72] Chore: Update setup.py --- setup.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 7ef8706..991ace7 100644 --- a/setup.py +++ b/setup.py @@ -1,18 +1,12 @@ -from setuptools import setup +from setuptools import setup, find_packages + +name = 'aku' setup( - name='aku', - description='An Annotation-driven ArgumentParser Generator', + name=name, + description='An interactive annotation-driven ArgumentParser generator', version='0.2.0', - packages=['aku'], - classifiers=[ - 'Programming Language :: Python :: 3 :: Only', - 'Development Status :: 2 - Pre-Alpha', - 'License :: OSI Approved :: MIT License', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: Unix', - 'Topic :: Utilities', - ], + packages=[package for package in find_packages() if package.startswith(name)], url='https://github.com/speedcell4/aku', license='MIT', author='speedcell4', From 2833b87feff9a2741e48496dc7ef1949bc69792c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 30 Oct 2020 14:54:07 +0900 Subject: [PATCH 57/72] Chore: Add examples/foo.py --- README.md | 2 +- aku/aku.py | 4 ++-- aku/tp.py | 2 +- examples/foo.py | 11 +++++++++++ 4 files changed, 15 insertions(+), 4 deletions(-) create mode 100644 examples/foo.py diff --git a/README.md b/README.md index a39167d..1b32085 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ aku.run() `aku` will automatically add argument options according to your function signature. ```shell script -python3 foo.py --help +python examples/foo.py --help usage: foo.py [-h] --a int [--b int] optional arguments: diff --git a/aku/aku.py b/aku/aku.py index b689afe..d049863 100644 --- a/aku/aku.py +++ b/aku/aku.py @@ -8,7 +8,7 @@ class Aku(ArgumentParser): - def __init__(self, prog=__file__, + def __init__(self, prog=None, usage=None, description=None, epilog=None, @@ -39,7 +39,7 @@ def option(self, fn): def parse_args(self, args=None) -> Namespace: AkuTp[Union[tuple(self._functions)]].add_argument( - self, name='root', default=SUPPRESS, + argument_parser=self, name='root', default=SUPPRESS, prefixes=(), domain=(), ) diff --git a/aku/tp.py b/aku/tp.py index 2f31cf3..cf7d1c0 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -20,7 +20,7 @@ def __repr__(self) -> str: def __init_subclass__(cls, **kwargs): cls.registry.append(cls) - def __class_getitem__(cls, tp): + def __class_getitem__(cls, tp) -> 'AkuTp': origin = get_origin(tp) args = get_args(tp) for aku_ty in cls.registry: diff --git a/examples/foo.py b/examples/foo.py new file mode 100644 index 0000000..d6eedd5 --- /dev/null +++ b/examples/foo.py @@ -0,0 +1,11 @@ +from aku import Aku + +aku = Aku() + + +@aku.option +def add(a: int, b: int = 2): + print(f'{a} + {b} => {a + b}') + + +aku.run() From 3ee8c2e3d49b12883ccaf17316a3e42a09f8d304 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 17:36:29 +0900 Subject: [PATCH 58/72] Feat: Support both single and multiple functions --- aku/aku.py | 40 ++++++++++++++++++++++++++++++---------- examples/naive37.py | 20 ++++++++++---------- examples/naive39.py | 20 ++++++++++---------- 3 files changed, 50 insertions(+), 30 deletions(-) diff --git a/aku/aku.py b/aku/aku.py index d049863..8a1aeeb 100644 --- a/aku/aku.py +++ b/aku/aku.py @@ -1,7 +1,8 @@ import functools import inspect +import sys from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, Namespace, SUPPRESS -from typing import Type, Union +from typing import Type from aku.tp import AkuTp from aku.utils import _init_argument_parser, NEW_ACTIONS @@ -34,21 +35,39 @@ def __init__(self, prog=None, self._functions = [] def option(self, fn): - self._functions.append(Type[fn]) + self._functions.append(fn) return fn - def parse_args(self, args=None) -> Namespace: - AkuTp[Union[tuple(self._functions)]].add_argument( - argument_parser=self, name='root', default=SUPPRESS, + @staticmethod + def _add_root_function(argument_parser, fn): + AkuTp[Type[fn]].add_argument( + argument_parser=argument_parser, name='@root', default=SUPPRESS, prefixes=(), domain=(), ) + return Namespace(**{'@root.@fn': fn}) + + def parse_args(self, args=None) -> Namespace: + assert len(self._functions) > 0 + + namespace, args, argument_parser = None, sys.argv, self + if len(self._functions) == 1: + fn = self._functions[0] + namespace = self._add_root_function(argument_parser, fn) + else: + subparsers = self.add_subparsers() + functions = { + fn.__name__: (fn, subparsers.add_parser(name=fn.__name__)) + for fn in self._functions + } + if len(args) > 1 and args[1] in functions: + fn, argument_parser = functions[args[1]] + namespace = self._add_root_function(argument_parser, fn) - namespace, args = None, None while True: - namespace, args = self.parse_known_args(args=args, namespace=namespace) - if hasattr(self, NEW_ACTIONS): - self._actions = self._actions + getattr(self, NEW_ACTIONS) - delattr(self, NEW_ACTIONS) + namespace, args = argument_parser.parse_known_args(args=args, namespace=namespace) + if hasattr(argument_parser, NEW_ACTIONS): + argument_parser._actions = argument_parser._actions + getattr(argument_parser, NEW_ACTIONS) + delattr(argument_parser, NEW_ACTIONS) else: break @@ -85,6 +104,7 @@ def recur(x): return x ret = recur(args) + print(f'ret => {ret}') assert len(ret) == 1 for _, fn in ret.items(): if inspect.getfullargspec(fn).varkw is None: diff --git a/examples/naive37.py b/examples/naive37.py index 3daf77a..03dea58 100644 --- a/examples/naive37.py +++ b/examples/naive37.py @@ -8,21 +8,21 @@ @aku.option def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): - print(f'foo.x => {x}') - print(f'foo.y => {y}') - print(f'foo.z => {z}') - print(f'foo.w => {w}') - print(kwargs['@aku']) + print(f'{foo.__name__}.x => {x}') + print(f'{foo.__name__}.y => {y}') + print(f'{foo.__name__}.z => {z}') + print(f'{foo.__name__}.w => {w}') + print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') @aku.option def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], z: Tuple[float, ...] = (), w: Tuple[float, str, int] = (1., '2', 3), **kwargs): - print(f'bar.x => {x}') - print(f'bar.y => {y}') - print(f'bar.z => {z}') - print(f'bar.w => {w}') - print(kwargs['@aku']) + print(f'{bar.__name__}.x => {x}') + print(f'{bar.__name__}.y => {y}') + print(f'{bar.__name__}.z => {z}') + print(f'{bar.__name__}.w => {w}') + print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') print(aku.run()) diff --git a/examples/naive39.py b/examples/naive39.py index dac74bf..8c1015d 100644 --- a/examples/naive39.py +++ b/examples/naive39.py @@ -8,21 +8,21 @@ @aku.option def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): - print(f'foo.x => {x}') - print(f'foo.y => {y}') - print(f'foo.z => {z}') - print(f'foo.w => {w}') - print(kwargs['@aku']) + print(f'{foo.__name__}.x => {x}') + print(f'{foo.__name__}.y => {y}') + print(f'{foo.__name__}.z => {z}') + print(f'{foo.__name__}.w => {w}') + print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') @aku.option def bar(x: Literal[1, 2, 3] = 2, y: list[int] = [2, 3, 4], z: tuple[float, ...] = (), w: tuple[float, str, int] = (1., '2', 3), **kwargs): - print(f'bar.x => {x}') - print(f'bar.y => {y}') - print(f'bar.z => {z}') - print(f'bar.w => {w}') - print(kwargs['@aku']) + print(f'{bar.__name__}.x => {x}') + print(f'{bar.__name__}.y => {y}') + print(f'{bar.__name__}.z => {z}') + print(f'{bar.__name__}.w => {w}') + print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') print(aku.run()) From 8470f20344d67b5a3135ae96e896372900cdb8fc Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 18:01:04 +0900 Subject: [PATCH 59/72] Fix: Fix bug of nested function --- aku/aku.py | 37 ++++++++++++++++++++++++------------- aku/tp.py | 2 +- examples/naive37.py | 21 ++++++++++++++------- examples/naive39.py | 8 ++++---- 4 files changed, 43 insertions(+), 25 deletions(-) diff --git a/aku/aku.py b/aku/aku.py index 8a1aeeb..2f90bc1 100644 --- a/aku/aku.py +++ b/aku/aku.py @@ -39,12 +39,12 @@ def option(self, fn): return fn @staticmethod - def _add_root_function(argument_parser, fn): + def _add_root_function(argument_parser, fn, name): AkuTp[Type[fn]].add_argument( argument_parser=argument_parser, name='@root', default=SUPPRESS, prefixes=(), domain=(), ) - return Namespace(**{'@root.@fn': fn}) + return Namespace(**{'@root.@fn': (fn, name)}) def parse_args(self, args=None) -> Namespace: assert len(self._functions) > 0 @@ -52,7 +52,7 @@ def parse_args(self, args=None) -> Namespace: namespace, args, argument_parser = None, sys.argv, self if len(self._functions) == 1: fn = self._functions[0] - namespace = self._add_root_function(argument_parser, fn) + namespace = self._add_root_function(argument_parser, fn, fn.__name__) else: subparsers = self.add_subparsers() functions = { @@ -61,16 +61,19 @@ def parse_args(self, args=None) -> Namespace: } if len(args) > 1 and args[1] in functions: fn, argument_parser = functions[args[1]] - namespace = self._add_root_function(argument_parser, fn) + namespace = self._add_root_function(argument_parser, fn, args[1]) while True: namespace, args = argument_parser.parse_known_args(args=args, namespace=namespace) + print(f'namespace => {namespace}') + print(f'args => {args}') if hasattr(argument_parser, NEW_ACTIONS): argument_parser._actions = argument_parser._actions + getattr(argument_parser, NEW_ACTIONS) delattr(argument_parser, NEW_ACTIONS) else: break + print(f'namespace => {namespace}') return namespace def run(self, namespace: Namespace = None): @@ -79,22 +82,29 @@ def run(self, namespace: Namespace = None): if isinstance(namespace, Namespace): namespace = namespace.__dict__ - args = {} + curry, literal = {}, {} for key, value in namespace.items(): - collection = args + curry_co = curry + literal_co = literal *names, key = key.split('.') for name in names: - collection = collection.setdefault(name, {}) + curry_co = curry_co.setdefault(name, {}) + literal_co = literal_co.setdefault(name, {}) if key == '@fn': - collection[key] = value + curry_co[key] = value[0] + literal_co[key] = value[1] else: - collection.setdefault('@args', {})[key] = value + curry_co[key] = literal_co[key] = value + + print(f'curry => {curry}') + print(f'literal => {literal}') def recur(x): if isinstance(x, dict): if '@fn' in x: - kwargs = {key: recur(value) for key, value in x['@args'].items()} - return functools.partial(x['@fn'], **kwargs) + fn = x.pop('@fn') + kwargs = {key: recur(value) for key, value in x.items()} + return functools.partial(fn, **kwargs) else: return { key: recur(value) @@ -103,11 +113,12 @@ def recur(x): else: return x - ret = recur(args) + print(curry) + ret = recur(curry) print(f'ret => {ret}') assert len(ret) == 1 for _, fn in ret.items(): if inspect.getfullargspec(fn).varkw is None: return fn() else: - return fn(**{'@aku': args}) + return fn(**{'@aku': curry}) diff --git a/aku/tp.py b/aku/tp.py index cf7d1c0..3b2f8a0 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -182,7 +182,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, class UnionAction(Action): def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_string=None): - setattr(namespace, self.dest, choices[values]) + setattr(namespace, self.dest, (choices[values], values)) self.required = False num_actions = len(parser._actions) diff --git a/examples/naive37.py b/examples/naive37.py index 03dea58..f4f50c6 100644 --- a/examples/naive37.py +++ b/examples/naive37.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Tuple +from typing import List, Tuple, Union, Type from aku import Aku, Literal @@ -7,22 +7,29 @@ @aku.option -def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): +def foo(x: int, y: str = '4', z: bool = True, k: Path = Path.home(), **kwargs): print(f'{foo.__name__}.x => {x}') print(f'{foo.__name__}.y => {y}') print(f'{foo.__name__}.z => {z}') - print(f'{foo.__name__}.w => {w}') - print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') + print(f'{foo.__name__}.k => {k}') + if '@aku' in kwargs: + print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') @aku.option def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], - z: Tuple[float, ...] = (), w: Tuple[float, str, int] = (1., '2', 3), **kwargs): + z: Tuple[float, ...] = (), k: Tuple[float, str, int] = (1., '2', 3), **kwargs): print(f'{bar.__name__}.x => {x}') print(f'{bar.__name__}.y => {y}') print(f'{bar.__name__}.z => {z}') - print(f'{bar.__name__}.w => {w}') - print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') + print(f'{bar.__name__}.k => {k}') + if '@aku' in kwargs: + print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') + + +@aku.option +def both(work: Union[Type[foo], Type[bar]]): + work() print(aku.run()) diff --git a/examples/naive39.py b/examples/naive39.py index 8c1015d..fe66599 100644 --- a/examples/naive39.py +++ b/examples/naive39.py @@ -7,21 +7,21 @@ @aku.option -def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): +def foo(x: int, y: str = '4', z: bool = True, k: Path = Path.home(), **kwargs): print(f'{foo.__name__}.x => {x}') print(f'{foo.__name__}.y => {y}') print(f'{foo.__name__}.z => {z}') - print(f'{foo.__name__}.w => {w}') + print(f'{foo.__name__}.k => {k}') print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') @aku.option def bar(x: Literal[1, 2, 3] = 2, y: list[int] = [2, 3, 4], - z: tuple[float, ...] = (), w: tuple[float, str, int] = (1., '2', 3), **kwargs): + z: tuple[float, ...] = (), k: tuple[float, str, int] = (1., '2', 3), **kwargs): print(f'{bar.__name__}.x => {x}') print(f'{bar.__name__}.y => {y}') print(f'{bar.__name__}.z => {z}') - print(f'{bar.__name__}.w => {w}') + print(f'{bar.__name__}.k => {k}') print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') From 7ae1dd58fbcdb1e79c4cb4fc5c479b7b616fcff3 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 18:02:18 +0900 Subject: [PATCH 60/72] Fix: allow_abbrev = False by default --- aku/aku.py | 2 +- examples/naive37.py | 8 ++++---- examples/naive39.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/aku/aku.py b/aku/aku.py index 2f90bc1..7d04c9d 100644 --- a/aku/aku.py +++ b/aku/aku.py @@ -23,7 +23,7 @@ def __init__(self, prog=None, argument_default=None, conflict_handler='error', add_help=True, - allow_abbrev=True) -> None: + allow_abbrev=False) -> None: super(Aku, self).__init__( prog=prog, usage=usage, description=description, epilog=epilog, parents=parents, formatter_class=formatter_class, prefix_chars=prefix_chars, diff --git a/examples/naive37.py b/examples/naive37.py index f4f50c6..e311aec 100644 --- a/examples/naive37.py +++ b/examples/naive37.py @@ -7,22 +7,22 @@ @aku.option -def foo(x: int, y: str = '4', z: bool = True, k: Path = Path.home(), **kwargs): +def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): print(f'{foo.__name__}.x => {x}') print(f'{foo.__name__}.y => {y}') print(f'{foo.__name__}.z => {z}') - print(f'{foo.__name__}.k => {k}') + print(f'{foo.__name__}.w => {w}') if '@aku' in kwargs: print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') @aku.option def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], - z: Tuple[float, ...] = (), k: Tuple[float, str, int] = (1., '2', 3), **kwargs): + z: Tuple[float, ...] = (), w: Tuple[float, str, int] = (1., '2', 3), **kwargs): print(f'{bar.__name__}.x => {x}') print(f'{bar.__name__}.y => {y}') print(f'{bar.__name__}.z => {z}') - print(f'{bar.__name__}.k => {k}') + print(f'{bar.__name__}.w => {w}') if '@aku' in kwargs: print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') diff --git a/examples/naive39.py b/examples/naive39.py index fe66599..8c1015d 100644 --- a/examples/naive39.py +++ b/examples/naive39.py @@ -7,21 +7,21 @@ @aku.option -def foo(x: int, y: str = '4', z: bool = True, k: Path = Path.home(), **kwargs): +def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): print(f'{foo.__name__}.x => {x}') print(f'{foo.__name__}.y => {y}') print(f'{foo.__name__}.z => {z}') - print(f'{foo.__name__}.k => {k}') + print(f'{foo.__name__}.w => {w}') print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') @aku.option def bar(x: Literal[1, 2, 3] = 2, y: list[int] = [2, 3, 4], - z: tuple[float, ...] = (), k: tuple[float, str, int] = (1., '2', 3), **kwargs): + z: tuple[float, ...] = (), w: tuple[float, str, int] = (1., '2', 3), **kwargs): print(f'{bar.__name__}.x => {x}') print(f'{bar.__name__}.y => {y}') print(f'{bar.__name__}.z => {z}') - print(f'{bar.__name__}.k => {k}') + print(f'{bar.__name__}.w => {w}') print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') From 17b9ac8f5aa389691b2d881f4f487b302f3919bb Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 18:03:47 +0900 Subject: [PATCH 61/72] Style: Remove debugging functions --- aku/aku.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/aku/aku.py b/aku/aku.py index 7d04c9d..adb8e46 100644 --- a/aku/aku.py +++ b/aku/aku.py @@ -65,15 +65,12 @@ def parse_args(self, args=None) -> Namespace: while True: namespace, args = argument_parser.parse_known_args(args=args, namespace=namespace) - print(f'namespace => {namespace}') - print(f'args => {args}') if hasattr(argument_parser, NEW_ACTIONS): argument_parser._actions = argument_parser._actions + getattr(argument_parser, NEW_ACTIONS) delattr(argument_parser, NEW_ACTIONS) else: break - print(f'namespace => {namespace}') return namespace def run(self, namespace: Namespace = None): @@ -96,26 +93,18 @@ def run(self, namespace: Namespace = None): else: curry_co[key] = literal_co[key] = value - print(f'curry => {curry}') - print(f'literal => {literal}') - - def recur(x): - if isinstance(x, dict): - if '@fn' in x: - fn = x.pop('@fn') - kwargs = {key: recur(value) for key, value in x.items()} - return functools.partial(fn, **kwargs) + def recur(item): + if isinstance(item, dict): + if '@fn' in item: + func = item.pop('@fn') + kwargs = {k: recur(v) for k, v in item.items()} + return functools.partial(func, **kwargs) else: - return { - key: recur(value) - for key, value in x.items() - } + return {k: recur(v) for k, v in item.items()} else: - return x + return item - print(curry) ret = recur(curry) - print(f'ret => {ret}') assert len(ret) == 1 for _, fn in ret.items(): if inspect.getfullargspec(fn).varkw is None: From 1d846ee2f5877170fb436ddfb5c3670d573ba255 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 18:08:10 +0900 Subject: [PATCH 62/72] Refactor: Handle the number of arguments for HeteroTuple --- aku/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/aku/utils.py b/aku/utils.py index 29d4155..6f6c5c6 100644 --- a/aku/utils.py +++ b/aku/utils.py @@ -40,7 +40,11 @@ def fn(arg_strings: str) -> Tuple[tps]: nonlocal tps tps = [argument_parser._registry_get('type', tp, tp) for tp in tps] - return tuple(tp(arg) for tp, arg in zip(tps, re.split(pattern, arg_strings.strip()))) + args = re.split(pattern, arg_strings.strip()) + + if len(tps) != len(args): + raise ValueError(f'the number of arguments does not match, {len(tps)} != {len(args)}') + return tuple(tp(arg) for tp, arg in zip(tps, args)) return register_type(fn, argument_parser) From 4a286cb83f163c0ba12b022a21475f8b46c1970d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 18:41:08 +0900 Subject: [PATCH 63/72] Feat: Add fetch_name --- aku/aku.py | 17 +++++++++++------ aku/utils.py | 16 ++++++++++++++++ examples/naive37.py | 17 +++++++++++++++++ tests/test_hello.py | 2 -- tests/test_utils.py | 11 +++++++++++ tests/utils.py | 21 +++++++++++++++++++++ 6 files changed, 76 insertions(+), 8 deletions(-) delete mode 100644 tests/test_hello.py create mode 100644 tests/test_utils.py create mode 100644 tests/utils.py diff --git a/aku/aku.py b/aku/aku.py index adb8e46..53c008e 100644 --- a/aku/aku.py +++ b/aku/aku.py @@ -5,7 +5,7 @@ from typing import Type from aku.tp import AkuTp -from aku.utils import _init_argument_parser, NEW_ACTIONS +from aku.utils import _init_argument_parser, NEW_ACTIONS, fetch_name class Aku(ArgumentParser): @@ -52,13 +52,18 @@ def parse_args(self, args=None) -> Namespace: namespace, args, argument_parser = None, sys.argv, self if len(self._functions) == 1: fn = self._functions[0] - namespace = self._add_root_function(argument_parser, fn, fn.__name__) + name = fetch_name(fn) + namespace = self._add_root_function(argument_parser, fn, name) else: subparsers = self.add_subparsers() - functions = { - fn.__name__: (fn, subparsers.add_parser(name=fn.__name__)) - for fn in self._functions - } + functions = {} + for fn in self._functions: + name = fetch_name(fn) + if name not in functions: + functions[name] = (fn, subparsers.add_parser(name=name)) + else: + raise ValueError(f'{name} was already registered') + if len(args) > 1 and args[1] in functions: fn, argument_parser = functions[args[1]] namespace = self._add_root_function(argument_parser, fn, args[1]) diff --git a/aku/utils.py b/aku/utils.py index 6f6c5c6..7b3e6ac 100644 --- a/aku/utils.py +++ b/aku/utils.py @@ -69,6 +69,22 @@ def tp_iter(fn): yield arg, tps[arg], defaults.get(arg, SUPPRESS) +def fetch_name(fn) -> str: + if inspect.isfunction(fn): # function, static method + return fn.__name__ + if inspect.isclass(fn): # class + return fn.__name__.lower() + if inspect.ismethod(fn): # class method + __class__ = fn.__self__ + if not inspect.isclass(__class__): + __class__ = __class__.__class__ + + return f'{__class__.__name__.lower()}.{fn.__name__}' + if callable(fn): # __call__ + return f'{fn.__class__.__name__.lower()}' + raise NotImplementedError + + def join_names(prefixes: Tuple[str, ...], name: str) -> str: if name.endswith('_'): name = name[:-1] diff --git a/examples/naive37.py b/examples/naive37.py index e311aec..c2d59bc 100644 --- a/examples/naive37.py +++ b/examples/naive37.py @@ -32,4 +32,21 @@ def both(work: Union[Type[foo], Type[bar]]): work() +class A(object): + @classmethod + def baz(cls, x: int): + print(f'{A.__name__}.x => {x}') + + +aku.option(A.baz) + + +class B(object): + @classmethod + def baz(cls, x: int): + print(f'{B.__name__}.x => {x}') + + +aku.option(B.baz) + print(aku.run()) diff --git a/tests/test_hello.py b/tests/test_hello.py deleted file mode 100644 index 3983ffa..0000000 --- a/tests/test_hello.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_hello(): - assert 1 + 1 == 2 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d99bfd6 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,11 @@ +from aku.utils import fetch_name +from tests.utils import func, Class + + +def test_fetch_name(): + assert fetch_name(func) == 'func' + assert fetch_name(Class) == 'class' + assert fetch_name(Class()) == 'class' + assert fetch_name(Class().method) == 'class.method' + assert fetch_name(Class.class_method) == 'class.class_method' + assert fetch_name(Class.static_method) == 'static_method' diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..b0818a8 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,21 @@ +def func(): + pass + + +class Class(object): + def method(self): + pass + + @classmethod + def class_method(cls): + pass + + @staticmethod + def static_method(): + pass + + def __init__(self, *args, **kwargs): + pass + + def __call__(self, *args, **kwargs): + pass From fd963fad684572eb9549bd53a53fc131136243b1 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 19:05:18 +0900 Subject: [PATCH 64/72] Feat: AkuFn default value --- aku/aku.py | 19 ++++++++----------- aku/tp.py | 1 + aku/utils.py | 4 ++++ examples/naive37.py | 5 +++++ 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/aku/aku.py b/aku/aku.py index 53c008e..e952877 100644 --- a/aku/aku.py +++ b/aku/aku.py @@ -38,22 +38,16 @@ def option(self, fn): self._functions.append(fn) return fn - @staticmethod - def _add_root_function(argument_parser, fn, name): - AkuTp[Type[fn]].add_argument( - argument_parser=argument_parser, name='@root', default=SUPPRESS, - prefixes=(), domain=(), - ) - return Namespace(**{'@root.@fn': (fn, name)}) - def parse_args(self, args=None) -> Namespace: assert len(self._functions) > 0 namespace, args, argument_parser = None, sys.argv, self if len(self._functions) == 1: fn = self._functions[0] - name = fetch_name(fn) - namespace = self._add_root_function(argument_parser, fn, name) + AkuTp[Type[fn]].add_argument( + argument_parser=argument_parser, name='@root', default=SUPPRESS, + prefixes=(), domain=(), + ) else: subparsers = self.add_subparsers() functions = {} @@ -66,7 +60,10 @@ def parse_args(self, args=None) -> Namespace: if len(args) > 1 and args[1] in functions: fn, argument_parser = functions[args[1]] - namespace = self._add_root_function(argument_parser, fn, args[1]) + AkuTp[Type[fn]].add_argument( + argument_parser=argument_parser, name='@root', default=SUPPRESS, + prefixes=(), domain=(), + ) while True: namespace, args = argument_parser.parse_known_args(args=args, namespace=namespace) diff --git a/aku/tp.py b/aku/tp.py index 3b2f8a0..a20af18 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -151,6 +151,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, if name.endswith('_'): prefixes = prefixes + (name[:-1],) + argument_parser.set_defaults(**{join_dests(domain, f'@fn'): (self.tp, name)}) for arg, tp, df in tp_iter(self.tp): tp = AkuTp[tp] tp.add_argument( diff --git a/aku/utils.py b/aku/utils.py index 7b3e6ac..93a2546 100644 --- a/aku/utils.py +++ b/aku/utils.py @@ -5,6 +5,10 @@ NEW_ACTIONS = '_new_actions' +AKU = '@aku' +AKU_FN = '@fn' +AKU_ROOT = '@root' + def tp_bool(arg_strings: str) -> bool: arg_strings = arg_strings.lower().strip() diff --git a/examples/naive37.py b/examples/naive37.py index c2d59bc..a78d8ae 100644 --- a/examples/naive37.py +++ b/examples/naive37.py @@ -27,6 +27,11 @@ def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') +@aku.option +def delay(work: Type[foo]): + work() + + @aku.option def both(work: Union[Type[foo], Type[bar]]): work() From 8100c22db3eb2fdf7f5894f664b4e8a88aa93b65 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 19:06:36 +0900 Subject: [PATCH 65/72] Refactor: Update constants --- aku/aku.py | 14 +++++++------- aku/tp.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/aku/aku.py b/aku/aku.py index e952877..af6291a 100644 --- a/aku/aku.py +++ b/aku/aku.py @@ -5,7 +5,7 @@ from typing import Type from aku.tp import AkuTp -from aku.utils import _init_argument_parser, NEW_ACTIONS, fetch_name +from aku.utils import _init_argument_parser, NEW_ACTIONS, fetch_name, AKU, AKU_FN, AKU_ROOT class Aku(ArgumentParser): @@ -45,7 +45,7 @@ def parse_args(self, args=None) -> Namespace: if len(self._functions) == 1: fn = self._functions[0] AkuTp[Type[fn]].add_argument( - argument_parser=argument_parser, name='@root', default=SUPPRESS, + argument_parser=argument_parser, name=AKU_ROOT, default=SUPPRESS, prefixes=(), domain=(), ) else: @@ -61,7 +61,7 @@ def parse_args(self, args=None) -> Namespace: if len(args) > 1 and args[1] in functions: fn, argument_parser = functions[args[1]] AkuTp[Type[fn]].add_argument( - argument_parser=argument_parser, name='@root', default=SUPPRESS, + argument_parser=argument_parser, name=AKU_ROOT, default=SUPPRESS, prefixes=(), domain=(), ) @@ -89,7 +89,7 @@ def run(self, namespace: Namespace = None): for name in names: curry_co = curry_co.setdefault(name, {}) literal_co = literal_co.setdefault(name, {}) - if key == '@fn': + if key == AKU_FN: curry_co[key] = value[0] literal_co[key] = value[1] else: @@ -97,8 +97,8 @@ def run(self, namespace: Namespace = None): def recur(item): if isinstance(item, dict): - if '@fn' in item: - func = item.pop('@fn') + if AKU_FN in item: + func = item.pop(AKU_FN) kwargs = {k: recur(v) for k, v in item.items()} return functools.partial(func, **kwargs) else: @@ -112,4 +112,4 @@ def recur(item): if inspect.getfullargspec(fn).varkw is None: return fn() else: - return fn(**{'@aku': curry}) + return fn(**{AKU: curry}) diff --git a/aku/tp.py b/aku/tp.py index a20af18..5a3d0e6 100644 --- a/aku/tp.py +++ b/aku/tp.py @@ -3,7 +3,7 @@ from aku.actions import StoreAction, AppendListAction from aku.compat import Literal, get_origin, get_args -from aku.utils import register_homo_tuple, register_hetero_tuple, tp_iter, join_names, join_dests, NEW_ACTIONS +from aku.utils import register_homo_tuple, register_hetero_tuple, tp_iter, join_names, join_dests, NEW_ACTIONS, AKU_FN class AkuTp(object): @@ -151,7 +151,7 @@ def add_argument(self, argument_parser: ArgumentParser, name: str, default: Any, if name.endswith('_'): prefixes = prefixes + (name[:-1],) - argument_parser.set_defaults(**{join_dests(domain, f'@fn'): (self.tp, name)}) + argument_parser.set_defaults(**{join_dests(domain, AKU_FN): (self.tp, name)}) for arg, tp, df in tp_iter(self.tp): tp = AkuTp[tp] tp.add_argument( @@ -196,7 +196,7 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values, option_ prefixes_name = join_names(prefixes, name) argument_parser.add_argument( - f'--{prefixes_name}', dest=join_dests(domain + (name,), '@fn'), help=prefixes_name, + f'--{prefixes_name}', dest=join_dests(domain + (name,), AKU_FN), help=prefixes_name, type=self.tp, choices=tuple(choices.keys()), required=True, default=SUPPRESS, action=UnionAction, metavar=f'{{{", ".join(choices.keys())}}}[fn]' ) From bc969ca1b375d07e14b2edeefdc511a852353fb4 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 19:09:33 +0900 Subject: [PATCH 66/72] Doc: Update examples --- examples/naive37.py | 14 ++++++++++---- examples/naive39.py | 43 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/examples/naive37.py b/examples/naive37.py index a78d8ae..6463972 100644 --- a/examples/naive37.py +++ b/examples/naive37.py @@ -28,13 +28,19 @@ def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], @aku.option -def delay(work: Type[foo]): - work() +def delay(call: Type[foo]): + call() @aku.option -def both(work: Union[Type[foo], Type[bar]]): - work() +def one(call: Union[Type[foo], Type[bar]]): + call() + + +@aku.option +def both(a_: Type[foo], b_: Type[bar]): + a_() + b_() class A(object): diff --git a/examples/naive39.py b/examples/naive39.py index 8c1015d..5782c8b 100644 --- a/examples/naive39.py +++ b/examples/naive39.py @@ -1,7 +1,7 @@ from pathlib import Path -from typing import Literal +from typing import Union -from aku import Aku +from aku import Aku, Literal aku = Aku() @@ -12,7 +12,8 @@ def foo(x: int, y: str = '4', z: bool = True, w: Path = Path.home(), **kwargs): print(f'{foo.__name__}.y => {y}') print(f'{foo.__name__}.z => {z}') print(f'{foo.__name__}.w => {w}') - print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') + if '@aku' in kwargs: + print(f'{foo.__name__}.@aku => {kwargs["@aku"]}') @aku.option @@ -22,7 +23,41 @@ def bar(x: Literal[1, 2, 3] = 2, y: list[int] = [2, 3, 4], print(f'{bar.__name__}.y => {y}') print(f'{bar.__name__}.z => {z}') print(f'{bar.__name__}.w => {w}') - print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') + if '@aku' in kwargs: + print(f'{bar.__name__}.@aku => {kwargs["@aku"]}') +@aku.option +def delay(call: type[foo]): + call() + + +@aku.option +def one(call: Union[type[foo], type[bar]]): + call() + + +@aku.option +def both(a_: type[foo], b_: type[bar]): + a_() + b_() + + +class A(object): + @classmethod + def baz(cls, x: int): + print(f'{A.__name__}.x => {x}') + + +aku.option(A.baz) + + +class B(object): + @classmethod + def baz(cls, x: int): + print(f'{B.__name__}.x => {x}') + + +aku.option(B.baz) + print(aku.run()) From d90a9d0c1e5c8c209b7015b901c64fe64bf55d2e Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 19:44:13 +0900 Subject: [PATCH 67/72] Doc: Update README.md --- README.md | 91 ++++++++++++++++++++++++++++++++++---- examples/bar.py | 22 +++++++++ examples/equivalent_foo.py | 25 +++++++++++ examples/foo.py | 10 ++++- 4 files changed, 138 insertions(+), 10 deletions(-) create mode 100644 examples/bar.py create mode 100644 examples/equivalent_foo.py diff --git a/README.md b/README.md index 1b32085..faea2fd 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![PyPI version](https://badge.fury.io/py/aku.svg)](https://badge.fury.io/py/aku) [![Downloads](https://pepy.tech/badge/aku)](https://pepy.tech/project/aku) -Aku is an interactive annotation-driven `ArgumentParser` generator. +An interactive annotation-driven `ArgumentParser` generator ## Requirements @@ -12,34 +12,109 @@ Aku is an interactive annotation-driven `ArgumentParser` generator. ## Install -```bash +```shell script python -m pip install aku --upgrade ``` ## Usage +The key idea of aku to generate `ArgumentParser` according to the type annotations of functions. For example, to register single function with only primitive types, i.e., `int`, `bool`, `str`, `float`, `Path`, etc. + ```python +from pathlib import Path + from aku import Aku aku = Aku() @aku.option -def add(a: int, b: int = 2): - print(f'{a} + {b} => {a + b}') +def foo(a: int, b: bool = True, c: str = '3', d: float = 4.0, e: Path = Path.home()): + print(f'a => {a}') + print(f'b => {b}') + print(f'c => {c}') + print(f'd => {d}') + print(f'e => {e}') aku.run() ``` -`aku` will automatically add argument options according to your function signature. +`aku` will generate a `ArgumentParser` which provides your command line interface looks like below, ```shell script -python examples/foo.py --help -usage: foo.py [-h] --a int [--b int] +~ python examples/foo.py --help +usage: foo.py [-h] --a int [--b bool] [--c str] [--d float] [--e path] optional arguments: -h, --help show this help message and exit --a int a - --b int b (default: 2) + --b bool b (default: True) + --c str c (default: 3) + --d float d (default: 4.0) + --e path e (default: /Users/home) +``` + +Of course you can achieve the same functions by instantiating an `ArgumentParser`, but `aku` certainly makes such steps simple and efficient. + +```python +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, SUPPRESS +from pathlib import Path + + +def tp_bool(arg_strings: str) -> bool: + arg_strings = arg_strings.lower().strip() + if arg_strings in ('t', 'true', 'y', 'yes', '1'): + return True + if arg_strings in ('f', 'false', 'n', 'no', '0'): + return False + raise ValueError + + +argument_parser = ArgumentParser( + formatter_class=ArgumentDefaultsHelpFormatter, +) +argument_parser.add_argument('--a', type=int, metavar='int', default=SUPPRESS, required=True, help='a') +argument_parser.add_argument('--b', type=tp_bool, metavar='bool', default=True, help='b') +argument_parser.add_argument('--c', type=str, metavar='str', default='3', help='c') +argument_parser.add_argument('--d', type=float, metavar='float', default=4.0, help='d') +argument_parser.add_argument('--e', type=Path, metavar='path', default=Path.home(), help='e') + +args = argument_parser.parse_args().__dict__ +for key, value in args.items(): + print(f'{key} => {value}') +``` + +Moreover, if you register more than one functions, e.g., register function `add`, + +```python +@aku.option +def add(x: int, y: int): + print(f'{x} + {y} => {x + y}') +``` + +Then you can choose which one to run by passing its name as the first parameter, + +```shell script +~ python examples/bar.py foo --help +usage: bar.py foo [-h] --a int [--b bool] [--c str] [--d float] [--e path] + +optional arguments: + -h, --help show this help message and exit + --a int a + --b bool b (default: True) + --c str c (default: 3) + --d float d (default: 4.0) + --e path e (default: /Users/home) + +~ python examples/bar.py add --help +usage: bar.py add [-h] --x int --y int + +optional arguments: + -h, --help show this help message and exit + --x int x + --y int y + +~ python examples/bar.py add --x 1 --y 2 +1 + 2 => 3 ``` \ No newline at end of file diff --git a/examples/bar.py b/examples/bar.py new file mode 100644 index 0000000..fc0f2e6 --- /dev/null +++ b/examples/bar.py @@ -0,0 +1,22 @@ +from pathlib import Path + +from aku import Aku + +aku = Aku() + + +@aku.option +def foo(a: int, b: bool = True, c: str = '3', d: float = 4.0, e: Path = Path.home()): + print(f'a => {a}') + print(f'b => {b}') + print(f'c => {c}') + print(f'd => {d}') + print(f'e => {e}') + + +@aku.option +def add(x: int, y: int): + print(f'{x} + {y} => {x + y}') + + +aku.run() diff --git a/examples/equivalent_foo.py b/examples/equivalent_foo.py new file mode 100644 index 0000000..a033f22 --- /dev/null +++ b/examples/equivalent_foo.py @@ -0,0 +1,25 @@ +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, SUPPRESS +from pathlib import Path + + +def tp_bool(arg_strings: str) -> bool: + arg_strings = arg_strings.lower().strip() + if arg_strings in ('t', 'true', 'y', 'yes', '1'): + return True + if arg_strings in ('f', 'false', 'n', 'no', '0'): + return False + raise ValueError + + +argument_parser = ArgumentParser( + formatter_class=ArgumentDefaultsHelpFormatter, +) +argument_parser.add_argument('--a', type=int, metavar='int', default=SUPPRESS, required=True, help='a') +argument_parser.add_argument('--b', type=tp_bool, metavar='bool', default=True, help='b') +argument_parser.add_argument('--c', type=str, metavar='str', default='3', help='c') +argument_parser.add_argument('--d', type=float, metavar='float', default=4.0, help='d') +argument_parser.add_argument('--e', type=Path, metavar='path', default=Path.home(), help='e') + +args = argument_parser.parse_args().__dict__ +for key, value in args.items(): + print(f'{key} => {value}') diff --git a/examples/foo.py b/examples/foo.py index d6eedd5..00775d1 100644 --- a/examples/foo.py +++ b/examples/foo.py @@ -1,11 +1,17 @@ +from pathlib import Path + from aku import Aku aku = Aku() @aku.option -def add(a: int, b: int = 2): - print(f'{a} + {b} => {a + b}') +def foo(a: int, b: bool = True, c: str = '3', d: float = 4.0, e: Path = Path.home()): + print(f'a => {a}') + print(f'b => {b}') + print(f'c => {c}') + print(f'd => {d}') + print(f'e => {e}') aku.run() From c7ee10fcb476d6a03b79de3f5cc9a84432a2d31e Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 19:48:20 +0900 Subject: [PATCH 68/72] Doc: Update README.md (Type Annotations) --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index faea2fd..afe63d8 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,16 @@ An interactive annotation-driven `ArgumentParser` generator python -m pip install aku --upgrade ``` +## Type Annotations + +* primitive types, e.g., `int`, `bool`, `str`, `float`, `Path`, etc. +* list `List[T]` +* homogeneous tuple, e.g., `Tuple[T, ...]` +* heterogeneous tuple, e.g., `Tuple[T1, T2, T3]` +* literal, e.g., `Literal[42, 1905]` +* function, e.g., `Type[]` +* union of functions, e.g., `Union[Type[], Type[], Type[]]` + ## Usage The key idea of aku to generate `ArgumentParser` according to the type annotations of functions. For example, to register single function with only primitive types, i.e., `int`, `bool`, `str`, `float`, `Path`, etc. From f1eceecbc2d971017ddb6d152d71bc4a2259ebe3 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 20:00:46 +0900 Subject: [PATCH 69/72] Doc: Update README.md (Nested Types) --- README.md | 52 +++++++++++++++++++++++++++++++++++++++++++++++++ examples/baz.py | 17 ++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 examples/baz.py diff --git a/README.md b/README.md index afe63d8..1813e8d 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ python -m pip install aku --upgrade ## Usage +### Primitive Types + The key idea of aku to generate `ArgumentParser` according to the type annotations of functions. For example, to register single function with only primitive types, i.e., `int`, `bool`, `str`, `float`, `Path`, etc. ```python @@ -127,4 +129,54 @@ optional arguments: ~ python examples/bar.py add --x 1 --y 2 1 + 2 => 3 +``` + +### Nested Types + +```python +from typing import List, Tuple + +from aku import Aku, Literal + +aku = Aku() + + +@aku.option +def baz(a: List[int], b: Tuple[bool, ...], c: Tuple[int, bool, str], d: Literal[42, 1905]): + print(f'a => {a}') + print(f'b => {b}') + print(f'c => {c}') + print(f'd => {d}') + + +if __name__ == '__main__': + aku.run() +``` + +* argument `a` is annotated with `List[int]`, thus every `--a` appends one item at the end of existing list +* homogenous tuple holds arbitrary number of elements with the same type, while heterogeneous tuple holds specialized number of elements with specialized type +* literal arguments can be assigned value from the specified ones + +```shell script +~ python examples/baz.py --a 1 --a 2 --a 3 --b "true,true,false,false,true" --c 42,true,"yes" --d 42 +a => [1, 2, 3] +b => (True, True, False, False, True) +c => (42, True, 'yes') +d => 42 + +~ python examples/baz.py --a 1 --a 2 --a nice --b "true,wow" --c 42,true,"yes" --d 42 +usage: baz.py [-h] [--a [int]] --b bool, ...) --c (int, bool, str --d int{1905, 42} +baz.py: error: argument --a: invalid int value: 'nice' + +~ python examples/baz.py --a 1 --a 2 --a 3 --b "true,wow" --c 42,true,"yes" --d 42 +usage: baz.py [-h] [--a [int]] --b bool, ...) --c (int, bool, str --d int{1905, 42} +baz.py: error: argument --b: invalid fn value: 'true,wow' + +~ python examples/baz.py --a 1 --a 2 --a 3 --b "true,true,false,false,true" --c 42,true,"yes,43" --d 42 +usage: baz.py [-h] [--a [int]] [--b bool, ...)] --c (int, bool, str --d int{1905, 42} +baz.py: error: argument --c: invalid fn value: '42,true,yes,43' + +~ python examples/baz.py --a 1 --a 2 --a 3 --b "true,true,false,false,true" --c 42,true,"yes" --d 43 +usage: baz.py [-h] [--a [int]] [--b bool, ...)] [--c (int, bool, str] --d int{1905, 42} +baz.py: error: argument --d: invalid choice: 43 (choose from 42, 1905) ``` \ No newline at end of file diff --git a/examples/baz.py b/examples/baz.py new file mode 100644 index 0000000..5a3de8e --- /dev/null +++ b/examples/baz.py @@ -0,0 +1,17 @@ +from typing import List, Tuple + +from aku import Aku, Literal + +aku = Aku() + + +@aku.option +def baz(a: List[int], b: Tuple[bool, ...], c: Tuple[int, bool, str], d: Literal[42, 1905]): + print(f'a => {a}') + print(f'b => {b}') + print(f'c => {c}') + print(f'd => {d}') + + +if __name__ == '__main__': + aku.run() From 4652ef158ed7b935bab845dc531afe25dae1af10 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 20:16:18 +0900 Subject: [PATCH 70/72] Doc: Update README.md (Nested Types) --- README.md | 87 ++++++++++++++++++++++++++++++++++++++++++++- examples/naive37.py | 2 +- examples/naive39.py | 2 +- examples/qux.py | 28 +++++++++++++++ 4 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 examples/qux.py diff --git a/README.md b/README.md index 1813e8d..5d70ebc 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ optional arguments: 1 + 2 => 3 ``` -### Nested Types +### Container Types ```python from typing import List, Tuple @@ -158,6 +158,16 @@ if __name__ == '__main__': * literal arguments can be assigned value from the specified ones ```shell script +~ python examples/baz.py --help +usage: baz.py [-h] --a [int] --b bool, ...) --c (int, bool, str --d int{1905, 42} + +optional arguments: + -h, --help show this help message and exit + --a [int] a + --b (bool, ...) b + --c (int, bool, str) c + --d int{1905, 42} d + ~ python examples/baz.py --a 1 --a 2 --a 3 --b "true,true,false,false,true" --c 42,true,"yes" --d 42 a => [1, 2, 3] b => (True, True, False, False, True) @@ -179,4 +189,79 @@ baz.py: error: argument --c: invalid fn value: '42,true,yes,43' ~ python examples/baz.py --a 1 --a 2 --a 3 --b "true,true,false,false,true" --c 42,true,"yes" --d 43 usage: baz.py [-h] [--a [int]] [--b bool, ...)] [--c (int, bool, str] --d int{1905, 42} baz.py: error: argument --d: invalid choice: 43 (choose from 42, 1905) +``` + +### Nested Types + +Wrap your function in `Type` and then this can be passed as a higher-order type to annotations, then `aku` can recursively analysis them. For `Union` type, you can choose which type to run at command line interface. To avoid name conflicting, you can open a sub-namespace by adding a underline to your argument name. + +```python +from typing import Type, Union +from aku import Aku + + +def add(x: int, y: int): + print(f'{x} + {y} => {x + y}') + + +def sub(x: int, y: int): + print(f'{x} - {y} => {x - y}') + + +aku = Aku() + + +@aku.option +def one(op: Union[Type[add], Type[sub]]): + op() + + +@aku.option +def both(lhs_: Type[add], rhs_: Type[sub]): + lhs_() + rhs_() + + +if __name__ == '__main__': + aku.run() +``` + +```shell script +~ python examples/qux.py one --op add --help +usage: qux.py one [-h] [--op {add, sub}[fn]] + +optional arguments: + -h, --help show this help message and exit + --op {add, sub}[fn] op (default: (, 'op')) + --x int x + --y int y + +~ python examples/qux.py one --op add --x 1 --y 2 +1 + 2 => 3 + +~ python examples/qux.py one --op sub --help +usage: qux.py one [-h] [--op {add, sub}[fn]] + +optional arguments: + -h, --help show this help message and exit + --op {add, sub}[fn] op (default: (, 'op')) + --x int x + --y int y + +~ python examples/qux.py one --op sub --x 1 --y 2 +1 - 2 => -1 + +~ python examples/qux.py both --help +usage: qux.py both [-h] --lhs-x int --lhs-y int --rhs-x int --rhs-y int + +optional arguments: + -h, --help show this help message and exit + --lhs-x int lhs-x + --lhs-y int lhs-y + --rhs-x int rhs-x + --rhs-y int rhs-y + +~ python examples/qux.py both --lhs-x 1 --lhs-y 2 --rhs-x 3 --rhs-y 4 +1 + 2 => 3 +3 - 4 => -1 ``` \ No newline at end of file diff --git a/examples/naive37.py b/examples/naive37.py index 6463972..c2d14eb 100644 --- a/examples/naive37.py +++ b/examples/naive37.py @@ -28,7 +28,7 @@ def bar(x: Literal[1, 2, 3] = 2, y: List[int] = [2, 3, 4], @aku.option -def delay(call: Type[foo]): +def delegate(call: Type[foo]): call() diff --git a/examples/naive39.py b/examples/naive39.py index 5782c8b..9ecdf4c 100644 --- a/examples/naive39.py +++ b/examples/naive39.py @@ -28,7 +28,7 @@ def bar(x: Literal[1, 2, 3] = 2, y: list[int] = [2, 3, 4], @aku.option -def delay(call: type[foo]): +def delegate(call: type[foo]): call() diff --git a/examples/qux.py b/examples/qux.py new file mode 100644 index 0000000..f882370 --- /dev/null +++ b/examples/qux.py @@ -0,0 +1,28 @@ +from typing import Type, Union +from aku import Aku + + +def add(x: int, y: int): + print(f'{x} + {y} => {x + y}') + + +def sub(x: int, y: int): + print(f'{x} - {y} => {x - y}') + + +aku = Aku() + + +@aku.option +def one(op: Union[Type[add], Type[sub]]): + op() + + +@aku.option +def both(lhs_: Type[add], rhs_: Type[sub]): + lhs_() + rhs_() + + +if __name__ == '__main__': + aku.run() From 176343a2125b30739573d66788ce4d39a05cad99 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 20:17:53 +0900 Subject: [PATCH 71/72] Doc: Update README.md --- README.md | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 5d70ebc..d03c7f0 100644 --- a/README.md +++ b/README.md @@ -18,19 +18,22 @@ python -m pip install aku --upgrade ## Type Annotations -* primitive types, e.g., `int`, `bool`, `str`, `float`, `Path`, etc. -* list `List[T]` -* homogeneous tuple, e.g., `Tuple[T, ...]` -* heterogeneous tuple, e.g., `Tuple[T1, T2, T3]` -* literal, e.g., `Literal[42, 1905]` -* function, e.g., `Type[]` -* union of functions, e.g., `Union[Type[], Type[], Type[]]` +* primitive types, + - e.g., `int`, `bool`, `str`, `float`, `Path`, etc. +* container types + - list `List[T]` + - homogeneous tuple, e.g., `Tuple[T, ...]` + - heterogeneous tuple, e.g., `Tuple[T1, T2, T3]` + - literal, e.g., `Literal[42, 1905]` +* nested types + - function, e.g., `Type[]` + - union of functions, e.g., `Union[Type[], Type[], Type[]]` ## Usage ### Primitive Types -The key idea of aku to generate `ArgumentParser` according to the type annotations of functions. For example, to register single function with only primitive types, i.e., `int`, `bool`, `str`, `float`, `Path`, etc. +The key idea of aku to generate `ArgumentParser` according to the type annotations of functions. For example, to register single function with only primitive types, ```python from pathlib import Path From ca2bd2ddb323b974598593772f9bf548b5743a44 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 7 Nov 2020 20:20:00 +0900 Subject: [PATCH 72/72] Doc: Update README.md --- README.md | 13 ++++++++++--- examples/equivalent_foo.py | 13 ++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d03c7f0..67ac509 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,14 @@ def tp_bool(arg_strings: str) -> bool: raise ValueError +def foo(a: int, b: bool = True, c: str = '3', d: float = 4.0, e: Path = Path.home()): + print(f'a => {a}') + print(f'b => {b}') + print(f'c => {c}') + print(f'd => {d}') + print(f'e => {e}') + + argument_parser = ArgumentParser( formatter_class=ArgumentDefaultsHelpFormatter, ) @@ -95,9 +103,8 @@ argument_parser.add_argument('--c', type=str, metavar='str', default='3', help=' argument_parser.add_argument('--d', type=float, metavar='float', default=4.0, help='d') argument_parser.add_argument('--e', type=Path, metavar='path', default=Path.home(), help='e') -args = argument_parser.parse_args().__dict__ -for key, value in args.items(): - print(f'{key} => {value}') +args = argument_parser.parse_args() +foo(a=args.a, b=args.b, c=args.c, d=args.d, e=args.e) ``` Moreover, if you register more than one functions, e.g., register function `add`, diff --git a/examples/equivalent_foo.py b/examples/equivalent_foo.py index a033f22..e4bb33b 100644 --- a/examples/equivalent_foo.py +++ b/examples/equivalent_foo.py @@ -11,6 +11,14 @@ def tp_bool(arg_strings: str) -> bool: raise ValueError +def foo(a: int, b: bool = True, c: str = '3', d: float = 4.0, e: Path = Path.home()): + print(f'a => {a}') + print(f'b => {b}') + print(f'c => {c}') + print(f'd => {d}') + print(f'e => {e}') + + argument_parser = ArgumentParser( formatter_class=ArgumentDefaultsHelpFormatter, ) @@ -20,6 +28,5 @@ def tp_bool(arg_strings: str) -> bool: argument_parser.add_argument('--d', type=float, metavar='float', default=4.0, help='d') argument_parser.add_argument('--e', type=Path, metavar='path', default=Path.home(), help='e') -args = argument_parser.parse_args().__dict__ -for key, value in args.items(): - print(f'{key} => {value}') +args = argument_parser.parse_args() +foo(a=args.a, b=args.b, c=args.c, d=args.d, e=args.e)