Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
vertefra committed Sep 25, 2024
1 parent 87bc5c2 commit a191f1a
Showing 1 changed file with 40 additions and 3 deletions.
43 changes: 40 additions & 3 deletions examples/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from qcog_python_client import AsyncQcogClient, QcogClient
from qcog_python_client.schema import GradOptimizationParameters, GradStateParameters
from qcog_python_client.schema.generated_schema.models import AnalyticOptimizationParameters, LOBPCGFastStateParameters

API_TOKEN = os.environ["API_TOKEN"]

Expand Down Expand Up @@ -111,16 +112,15 @@ async def big_data_test() -> None:
)

if DATASET_ID is None:
dataset_id = os.environ["DATASET_NAME"]

dataset_name = os.environ["DATASET_NAME"]
big_df = _get_test_df(100)
size = big_df.memory_usage(deep=True).sum() / 1024**2
print("Testing Size of big_df MB: ", size)

print("Testing upload_data")

start = time.time()
await client.upload_data(big_df, dataset_id)
await client.upload_data(big_df, dataset_name)
end = time.time()
print(f"`upload_data` Time taken to upload {size} MB of data: ", end - start)
else:
Expand Down Expand Up @@ -158,6 +158,41 @@ async def check_status() -> None:
await client.preloaded_model(MODEL_ID)
status = await client.progress()
print(status)
loss = await client.get_loss()
print(loss)


async def case_ensemble() -> None:
"""Test case ensemble."""
client = await AsyncQcogClient.create(
token=API_TOKEN,
hostname=HOST,
port=PORT,
)

dataset_id = "ab1aae7c-28d7-37eb-a251-1479f61818ab"

client.ensemble(
operators=["X", "Y", "Z"],
dim=4,
num_axes=4,
seed=1
)

await client.preloaded_data(dataset_id)

await client.train(
batch_size=1,
num_passes=1,
weight_optimization=AnalyticOptimizationParameters(),
get_states_extra=LOBPCGFastStateParameters(
iterations=1,
tol=0.05
)
)

print(client.trained_model)



if __name__ == "__main__":
Expand All @@ -172,5 +207,7 @@ async def check_status() -> None:
asyncio.run(big_data_test())
elif cmd == "status":
asyncio.run(check_status())
elif cmd == "case_ensemble":
asyncio.run(case_ensemble())
else:
raise ValueError(f"Invalid command: {cmd}")

0 comments on commit a191f1a

Please sign in to comment.