From 09f824ae74af6a95e6eded6dbffb6d24cfc5b377 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Wed, 19 Jun 2024 14:03:05 +0200 Subject: [PATCH] custom model checkbox --- src/autotrain/__init__.py | 2 +- .../scripts/fetch_data_and_update_models.js | 7 ++++- src/autotrain/app/static/scripts/utils.js | 6 +++- src/autotrain/app/templates/index.html | 31 +++++++++++++++++-- 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/autotrain/__init__.py b/src/autotrain/__init__.py index cd7e891619..1135e90539 100644 --- a/src/autotrain/__init__.py +++ b/src/autotrain/__init__.py @@ -41,7 +41,7 @@ warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub") logger = Logger().get_logger() -__version__ = "0.7.125.dev0" +__version__ = "0.7.126.dev0" def is_colab(): diff --git a/src/autotrain/app/static/scripts/fetch_data_and_update_models.js b/src/autotrain/app/static/scripts/fetch_data_and_update_models.js index 357d7c704e..20359f4f5f 100644 --- a/src/autotrain/app/static/scripts/fetch_data_and_update_models.js +++ b/src/autotrain/app/static/scripts/fetch_data_and_update_models.js @@ -4,6 +4,8 @@ document.addEventListener('DOMContentLoaded', function () { const baseModelSelect = document.getElementById('base_model'); const queryParams = new URLSearchParams(window.location.search); const customModelsValue = queryParams.get('custom_models'); + const baseModelInput = document.getElementById('base_model_input'); + const baseModelCheckbox = document.getElementById('base_model_checkbox'); let fetchURL = `/ui/model_choices/${taskValue}`; if (customModelsValue) { @@ -14,6 +16,9 @@ document.addEventListener('DOMContentLoaded', function () { .then(response => response.json()) .then(data => { const baseModelSelect = document.getElementById('base_model'); + baseModelCheckbox.checked = false; + baseModelSelect.classList.remove('hidden'); + baseModelInput.classList.add('hidden'); baseModelSelect.innerHTML = ''; // Clear existing options data.forEach(model => { let option = document.createElement('option'); @@ -21,7 +26,7 @@ document.addEventListener('DOMContentLoaded', function () { option.textContent = model.name; // Assuming each model has a 'name' baseModelSelect.appendChild(option); }); - }) + })gi .catch(error => console.error('Error:', error)); } document.getElementById('task').addEventListener('change', fetchDataAndUpdateModels); diff --git a/src/autotrain/app/static/scripts/utils.js b/src/autotrain/app/static/scripts/utils.js index c5f7cb6b51..a81267b807 100644 --- a/src/autotrain/app/static/scripts/utils.js +++ b/src/autotrain/app/static/scripts/utils.js @@ -83,9 +83,13 @@ document.addEventListener('DOMContentLoaded', function () { } else { params = paramsJsonElement.value; } + const baseModelValue = document.getElementById('base_model_checkbox').checked + ? document.getElementById('base_model_input').value + : document.getElementById('base_model').value; + + formData.append('base_model', baseModelValue); formData.append('project_name', document.getElementById('project_name').value); formData.append('task', document.getElementById('task').value); - formData.append('base_model', document.getElementById('base_model').value); formData.append('hardware', document.getElementById('hardware').value); formData.append('params', params); formData.append('autotrain_user', document.getElementById('autotrain_user').value); diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html index 397df5a410..85eddde444 100644 --- a/src/autotrain/app/templates/index.html +++ b/src/autotrain/app/templates/index.html @@ -371,9 +371,19 @@

Base Model

- +
+ + +
+ + +
+
@@ -624,6 +634,21 @@

}); }); + \ No newline at end of file