diff --git a/.gitignore b/.gitignore index b31e3e1..e5479ac 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ **/.DS_Store # User generated files +**/temp/ **/design-*.json **/toy_*.json diff --git a/citrination_api_examples/clients_sequence/1_data_client_api_tutorial.ipynb b/citrination_api_examples/clients_sequence/1_data_client_api_tutorial.ipynb index feb8e27..aec4c45 100644 --- a/citrination_api_examples/clients_sequence/1_data_client_api_tutorial.ipynb +++ b/citrination_api_examples/clients_sequence/1_data_client_api_tutorial.ipynb @@ -15,7 +15,7 @@ "\n", "*Authors: Enze Chen, Max Gallant*\n", "\n", - "In this notebook, we will cover how to use the [Citrination API](http://citrineinformatics.github.io/python-citrination-client/) to upload and manage datasets on Citrination. Getting your data on Citrination will allow you to keep your data organized in one place and enable you to perform machine learning (ML) on the data. The application program interface (API) aims to facilitate the process for those who prefer writing Python scripts and wish to avoid the web user interface (UI). As a sanity check, however, it might be helpful for you to keep the UI open and follow along with the tutorial to verify the changes are what you expect." + "In this notebook, we will cover how to use the [Citrination API](http://citrineinformatics.github.io/python-citrination-client/) to upload and manage datasets on Citrination. Getting your data on Citrination will allow you to keep your data organized in one place and enable you to perform machine learning (ML) on the data. The application program interface (API) aims to facilitate the process for those who prefer writing Python scripts. 
" ] }, { @@ -78,9 +78,9 @@ "outputs": [], "source": [ "# Standard packages\n", - "import os\n", - "import time\n", - "import uuid # generating random IDs\n", + "from os import environ # get environment variables\n", + "from time import sleep # wait time\n", + "from uuid import uuid4 # generating random IDs\n", "\n", "# Third-party packages\n", "from citrination_client import *" @@ -94,11 +94,11 @@ "\n", "[Back to ToC](#Table-of-contents)\n", "\n", - "Assuming that this is the very first time you're interacting with the Citrination API, we will first go over how to properly initialize the client that handles all communication. Most APIs require a key for access, and the PyCC is no exception. You can find your API key by navigating to [Citrination](https://citrination.com), clicking your username in the top-right corner, clicking \"Account Settings,\" and then looking under your Email. Copy this key to your clipboard.\n", + "Assuming that this is the very first time you're interacting with the Citrination API, we will first go over how to properly initialize the client that handles all communication. Most APIs require a key for access, and the PyCC is no exception. You can find your API key by navigating to [Citrination](https://citrination.com), clicking your username in the top-right corner, clicking \"Account Settings,\" and then looking under your Email. 
Copy this key to your clipboard (`Ctrl+C`).\n", "\n", "Since the key is linked to your specific user profile, *you should never hard-code or expose your API key in your code.* Instead, first store the API key in your [environment variables](https://medium.com/@himanshuagarwal1395/setting-up-environment-variables-in-macos-sierra-f5978369b255) like so (for Macs):\n", "* In Terminal, type `vim ~/.bash_profile` (or use an editor of your choice).\n", - "* In that file, press `i` (edit mode) and add the line `export CITRINATION_API_KEY=\"your_api_key\"`.\n", + "* In that file, press `i` (edit mode) and add the line `export CITRINATION_API_KEY=\"paste_your_api_key\"`.\n", "* Save and exit (`Esc`, `:wq`, `Enter`).\n", "* Open up a new Terminal and load this notebook one more time.\n", "\n", @@ -113,8 +113,8 @@ "metadata": {}, "outputs": [], "source": [ - "site = \"https://citrination.com\" # site you want to access; we'll use the public site\n", - "client = CitrinationClient(api_key=os.environ.get('CITRINATION_API_KEY'), \n", + "site = \"https://citrination.com\" # site you want to access; we'll use the public site\n", + "client = CitrinationClient(api_key=environ.get('CITRINATION_API_KEY'), \n", " site=site)\n", "client # reveal the attributes" ] @@ -136,7 +136,7 @@ "\n", "[Back to ToC](#Table-of-contents)\n", "\n", - "Once the base client is initialized, the [`DataClient`](http://citrineinformatics.github.io/python-citrination-client/tutorial/data_examples.html) can be easily accessed as an attribute." + "Once the base client is initialized, the [`DataClient`](http://citrineinformatics.github.io/python-citrination-client/tutorial/data_examples.html) can be easily accessed using the `.data` attribute." 
] }, { @@ -146,7 +146,7 @@ "outputs": [], "source": [ "data_client = client.data\n", - "data_client # reveal the methods" + "data_client # reveal the methods" ] }, { @@ -154,10 +154,10 @@ "metadata": {}, "source": [ "### Create a dataset\n", - "Before you can upload data, you have to create an empty dataset to store the files in. The `create_dataset()` method does exactly this and returns a [`Dataset`](http://citrineinformatics.github.io/python-citrination-client/modules/data/datasets.html) object. The method has the following inputs:\n", - "* **name**: A string for the name of the dataset. It cannot be the same as that of an existing dataset that you own.\n", - "* **description**: A string for the description of the dataset.\n", - "* **public**: A Boolean indicating whether to make the dataset public (`default=False`)." + "Before you can upload data, you have to create an empty dataset to store the files in. The `create_dataset()` method of the `DataClient` does exactly this and returns a [`Dataset`](http://citrineinformatics.github.io/python-citrination-client/modules/data/datasets.html) object. The method has the following inputs:\n", + "* `name`: A string for the name of the dataset. It cannot be the same as that of an existing dataset that you own.\n", + "* `description`: A string for the description of the dataset.\n", + "* `public`: A Boolean indicating whether to make the dataset public (`default=False`)." 
] }, { @@ -168,16 +168,17 @@ }, "outputs": [], "source": [ - "data_name = 'PyCC Dataset ' + str(uuid.uuid4())[:6]\n", + "data_name = 'PyCC Dataset ' + str(uuid4())[:6]\n", "data_desc = 'This dataset was created by the PyCC API tutorial.'\n", - "dataset = data_client.create_dataset(name=data_name, description=data_desc)" + "dataset = data_client.create_dataset(name=data_name, \n", + " description=data_desc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Once you've created the `Dataset` object, you can obtain from its attributes the dataset ID, which you will need for subsequent operations." + "Once you've created the `Dataset` object, you can obtain the dataset ID from the `.id` attribute of a `Dataset`. You will need this ID for subsequent operations." ] }, { @@ -192,15 +193,22 @@ "print('It can be accessed at {}/datasets/{}'.format(site, dataset_id))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you click on the above URL, it will take you to the dataset on Citrination, which at this point should be empty." + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload data to a dataset\n", - "The `upload()` method allows you to upload a file or a directory to a dataset. The method has the following inputs:\n", - "* **dataset_id**: The integer value of the ID of the dataset to which you will be uploading data.\n", - "* **source_path**: The path to the file or directory you want to upload.\n", - "* **dest_path**: The name of the file or directory as it should appear on Citrination (`default=None`).\n", + "The `upload()` method of the `DataClient` allows you to upload a file or a directory to a dataset. 
The method has the following inputs:\n", + "* `dataset_id`: The integer value of the ID of the dataset to which you will be uploading data.\n", + "* `source_path`: The path to the file or directory you want to upload.\n", + "* `dest_path`: The name of the file or directory as it should appear on Citrination (`default=None`).\n", "\n", "The returned [`UploadResult`](http://citrineinformatics.github.io/python-citrination-client/modules/data/data_client.html#citrination_client.data.upload_result.UploadResult) object tracks the number of successful and failed uploads. You can also use the function `get_ingest_status()` to check the status of ingest.\n", "\n", @@ -214,22 +222,32 @@ "outputs": [], "source": [ "# Upload a single file\n", - "upload_result = data_client.upload(dataset_id=dataset_id, source_path='test_pif.json')\n", + "upload_result = data_client.upload(dataset_id=dataset_id, \n", + " source_path='test_pif.json')\n", "print('Successful upload? {}'.format(upload_result.successful())) # Boolean; True if none fail\n", "\n", - "# Upload a directory; each file is recursively added and has the folder name as prefix\n", - "upload_result = data_client.upload(dataset_id=dataset_id, source_path='test_pif_dir/')\n", + "# Upload a directory; each file is recursively added and has the folder name as a prefix\n", + "upload_result = data_client.upload(dataset_id=dataset_id, \n", + " source_path='test_pif_dir/')\n", "print('Number of successful uploads: {}'.format(len(upload_result.successes))) # list of successful files\n", "\n", "# Check ingest status with loop\n", - "while (True):\n", - " ingest_status = data_client.get_ingest_status(dataset_id)\n", + "while True:\n", + " ingest_status = data_client.get_ingest_status(dataset_id=dataset_id)\n", " if (ingest_status == 'Finished'):\n", " print('Ingestion complete!')\n", + " print('Dataset URL: {}/datasets/{}'.format(site, dataset_id))\n", " break\n", " else:\n", " print('Waiting for data ingest...')\n", - " time.sleep(10)" 
+ " sleep(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Verify**: If you go back to the dataset in the UI and refresh the page, you should find it populated with PIF records!" ] }, { @@ -238,10 +256,10 @@ "source": [ "### Retrieving data: File download URLs\n", "The more common way to retrieve data from datasets on Citrination is to request download URLs. The `get_dataset_files()` function can be used to get a list of [`DatasetFile`](http://citrineinformatics.github.io/python-citrination-client/modules/data/datasets.html#citrination_client.data.dataset_file.DatasetFile) objects from a dataset. The method has the following inputs:\n", - "* **dataset_id**: The integer value of the ID of the dataset that you're retrieving data from.\n", - "* **glob**: A [regex](https://ryanstutorials.net/regular-expressions-tutorial/) used to select one or more files in the dataset (`default='.'`).\n", - "* **is_dir**: A Boolean indicating whether or not the supplied pattern should be treated as a directory to search in (`default=False`).\n", - "* **version_number**: The integer value of the version number of the dataset to retrieve files from (`default=None`)." + "* `dataset_id`: The integer value of the ID of the dataset that you're retrieving data from.\n", + "* `glob`: A [regex](https://ryanstutorials.net/regular-expressions-tutorial/) used to select one or more files in the dataset (`default='.'`).\n", + "* `is_dir`: A Boolean indicating whether or not the supplied pattern should be treated as a directory to search in (`default=False`).\n", + "* `version_number`: The integer value of the version number of the dataset to retrieve files from (`default=None`)." 
] }, { @@ -250,9 +268,12 @@ "metadata": {}, "outputs": [], "source": [ - "regex = 'pif' # matches files with 'pif' in the name\n", - "dataset_files = data_client.get_dataset_files(dataset_id, glob=regex)\n", - "print('The regex \\'{}\\' matched {} files in dataset {}.'.format(regex, len(dataset_files), dataset_id))" + "regex = 'pif' # matches files with 'pif' in the name\n", + "dataset_files = data_client.get_dataset_files(dataset_id=dataset_id, \n", + " glob=regex)\n", + "print('The regex \\'{}\\' matched {} files in dataset {}.'.format(regex, \n", + " len(dataset_files), \n", + " dataset_id))" ] }, { @@ -260,8 +281,8 @@ "metadata": {}, "source": [ "[`DatasetFile`](http://citrineinformatics.github.io/python-citrination-client/modules/data/datasets.html#citrination_client.data.dataset_file.DatasetFile) objects have `path` and `url` attributes that can then be accessed. There is also a `download_files()` method with the following parameters:\n", - "* **dataset_files**: A list of `DatasetFile` objects.\n", - "* **destination**: The path to the desired local download destination (`default='.'`)." + "* `dataset_files`: A list of `DatasetFile` objects.\n", + "* `destination`: The path to the desired local download destination (`default='.'`)." ] }, { @@ -273,7 +294,8 @@ "print('The first file in the dataset is \"{}\"'.format(dataset_files[0].path))\n", "\n", "# Download all files, preserving the same file organization\n", - "data_client.download_files(dataset_files, destination='./downloads/')" + "data_client.download_files(dataset_files=dataset_files, \n", + " destination='./downloads/')" ] }, { @@ -281,10 +303,10 @@ "metadata": {}, "source": [ "### Retrieving data: PIF retrieval\n", - "Another way to retrieve data is to request the contents of a single PIF record in JSON format. 
The `get_pif()` method takes in the following parameters and returns a [pypif](https://github.com/CitrineInformatics/pypif) [PIF](http://citrineinformatics.github.io/pif-documentation/schema_definition/index.html) object.\n", - "* **dataset_id**: The integer value of the ID of the dataset that you're retrieving data from.\n", - "* **uid**: A string representing the uid of the PIF to retrieve.\n", - "* **dataset_version**: The integer value of the version number of the dataset to retrieve files from (`default=None`).\n", + "Another way to retrieve data is to request the contents of a single PIF record in JSON format. The `get_pif()` method takes in the following parameters and returns a [pypif](https://github.com/CitrineInformatics/pypif) [`pif`](http://citrineinformatics.github.io/pif-documentation/schema_definition/index.html) object.\n", + "* `dataset_id`: The integer value of the ID of the dataset that you're retrieving data from.\n", + "* `uid`: A string representing the uid of the PIF to retrieve.\n", + "* `dataset_version`: The integer value of the version number of the dataset to retrieve files from (`default=None`).\n", "\n", "*Note*: Because the `uid` is only revealed through the web UI and [`SearchClient`](http://citrineinformatics.github.io/python-citrination-client/tutorial/search_examples.html), `get_pif()` is not commonly used when working solely with the `DataClient`." ] @@ -295,8 +317,9 @@ "metadata": {}, "outputs": [], "source": [ - "pif_uid = 'test_uid' # this was set in the PIF\n", - "my_pif = data_client.get_pif(dataset_id, pif_uid)\n", + "pif_uid = 'test_uid' # this UID was set in the PIF\n", + "my_pif = data_client.get_pif(dataset_id=dataset_id, \n", + " uid=pif_uid)\n", "print('The chemical formula of this PIF is {}.'.format(my_pif.chemical_formula))" ] }, @@ -306,10 +329,10 @@ "source": [ "### Modify a dataset\n", "You can easily modify datasets on Citrination with the `update_dataset()` function. 
It takes as inputs:\n", - "* **dataset_id**: The integer value of the ID of the dataset that you're retrieving data from.\n", - "* **name**: A string for the new name of the dataset (`default=None`).\n", - "* **description**: A string for the new description of the dataset (`default=None`).\n", - "* **public**: A Boolean indicating whether the dataset should be public (`default=None`)." + "* `dataset_id`: The integer value of the ID of the dataset that you're retrieving data from.\n", + "* `name`: A string for the new name of the dataset (`default=None`).\n", + "* `description`: A string for the new description of the dataset (`default=None`).\n", + "* `public`: A Boolean indicating whether the dataset should be public (`default=None`)." ] }, { @@ -318,10 +341,12 @@ "metadata": {}, "outputs": [], "source": [ - "new_name = 'PyCC Dataset New Name ' + str(uuid.uuid4())[:6]\n", + "new_name = 'PyCC Dataset New Name ' + str(uuid4())[:6]\n", "public_flag = False\n", - "new_dataset = data_client.update_dataset(dataset_id, name=new_name, public=public_flag)\n", - "print('Dataset {} is now named \"{}.'.format(dataset_id, new_dataset.name))" + "new_dataset = data_client.update_dataset(dataset_id=dataset_id, \n", + " name=new_name, \n", + " public=public_flag)\n", + "print('Dataset {} is now named \"{}.\"'.format(dataset_id, new_dataset.name))" ] }, { @@ -337,14 +362,15 @@ "metadata": {}, "outputs": [], "source": [ - "print('Files list: {0}.'.format(data_client.list_files(dataset_id, glob='.')))" + "print('Files list: {0}.'.format(data_client.list_files(dataset_id=dataset_id, \n", + " glob='.')))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The `create_dataset_version()` function creates a new version of a data set. Note that creating a new version deletes all records from the old version, so handle with care!" + "The `create_dataset_version()` method of the `DataClient` creates a new version of a data set. 
Note that creating a new version deletes all records from the old version, so handle with care!" ] }, { @@ -353,10 +379,28 @@ "metadata": {}, "outputs": [], "source": [ - "dataset_version = data_client.create_dataset_version(dataset_id)\n", + "dataset_version = data_client.create_dataset_version(dataset_id=dataset_id)\n", "print('Dataset {} is now version {}.'.format(dataset_id, dataset_version.number))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Delete a dataset\n", + "\n", + "Finally, if you wish to delete a dataset that you own, you can use the `delete_dataset()` method of the `DataClient`. As this is a permanent deletion, please handle with care!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# data_client.delete_dataset(dataset_id=dataset_id)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -370,7 +414,8 @@ "* How to create a new dataset.\n", "* How to upload data to the dataset.\n", "* How to retrieve data from the dataset.\n", - "* How to modify the properties of the dataset." + "* How to modify the properties of the dataset.\n", + "* How to delete a dataset." ] }, { diff --git a/citrination_api_examples/clients_sequence/2_data_views_client_api_tutorial.ipynb b/citrination_api_examples/clients_sequence/2_data_views_client_api_tutorial.ipynb index d8773ce..5037593 100644 --- a/citrination_api_examples/clients_sequence/2_data_views_client_api_tutorial.ipynb +++ b/citrination_api_examples/clients_sequence/2_data_views_client_api_tutorial.ipynb @@ -15,7 +15,7 @@ "\n", "*Authors: Enze Chen, Eric Lundberg*\n", "\n", - "In this notebook, we will cover how to *create* a data view using the [Citrination API](http://citrineinformatics.github.io/python-citrination-client/). Data views provide the configuration necessary in order to perform machine learning and identify relationships in your data. 
We will demonstrate this functionality using the [Band gaps from Strehlow and Cook](https://citrination.com/datasets/1160/show_search?searchMatchOption=fuzzyMatch) dataset, where we will create a view mapping: \n", + "In this notebook, we will cover how to *create* a data view using the [Citrination API](http://citrineinformatics.github.io/python-citrination-client/). Data views provide the configuration necessary in order to perform machine learning and data analysis. We will demonstrate this functionality using the [Band gaps from Strehlow and Cook](https://citrination.com/datasets/1160/show_search?searchMatchOption=fuzzyMatch) dataset, where we will create a view mapping: \n", "\n", "$$\\text{Chemical formula (inorganic) + Crystallinity (categorical)} \\longrightarrow \\boxed{\\text{ML model}} \\longrightarrow \\text{Band gap (real)}$$" ] @@ -79,10 +79,9 @@ "outputs": [], "source": [ "# Standard packages\n", - "import json\n", - "import os\n", - "import time\n", - "import uuid # generating random IDs\n", + "from os import environ # get environment variables\n", + "from time import sleep # wait time\n", + "from uuid import uuid4 # generating random IDs\n", "\n", "# Third-party packages\n", "from citrination_client import *\n", @@ -97,12 +96,18 @@ "\n", "[Back to ToC](#Table-of-contents)\n", "\n", - "The [`DataViewBuilder`](http://citrineinformatics.github.io/python-citrination-client/modules/views/ml_config_builder.html) class handles the configuration for data views and returns a **configuration** object that is an input for the `DataViewsClient`. The configuration specifies the datasets, model, and descriptors. 
Some of the important parameters to note are:\n", - "* **dataset_ids**: An array of strings, one for each dataset ID that should be included in the view.\n", - "* **descriptors**: A descriptor instance, which could be `{RealDescriptor, InorganicDescriptor, OrganicDescriptor, CategoricalDescriptor,` or `AlloyCompositionDescriptor}`.\n", - " * **Note 1**: Chemical formulas for the API take the key `formula`.\n", - " * **Note 2**: Properties take the key `Property `.\n", - "* **roles**: A role for each descriptor, as a string, which could be `{input, output, latentVariable, ignored}`." + "The [`DataViewBuilder`](http://citrineinformatics.github.io/python-citrination-client/modules/views/ml_config_builder.html) class handles the configuration for data views and returns a **configuration** object that is an input for the `DataViewsClient`. The configuration specifies:\n", + "* The datasets you want to include.\n", + "* The ML model you want to use.\n", + "* Which properties you want to use as descriptors. \n", + "\n", + "Some of the important parameters to note are:\n", + "* `dataset_ids`: An array of strings, one for each dataset ID that should be included in the view.\n", + "* `descriptors`: A descriptor instance, which is one of `{RealDescriptor, InorganicDescriptor, OrganicDescriptor, CategoricalDescriptor,` or `AlloyCompositionDescriptor}`.\n", + " * *Note 1*: Chemical formulas for the API take the key `\"formula\"`.\n", + " * *Note 2*: Properties take the key `\"Property [property name]\"`.\n", + " * *Note 3*: Strings are **case-sensitive**!\n", + "* `roles`: A role for each descriptor, as a string, which is one of `{'input', 'output', 'latentVariable',` or `'ignored'}`." 
] }, { @@ -115,16 +120,26 @@ "dv_builder = DataViewBuilder()\n", "dv_builder.dataset_ids(['172242']) # ID number for band gaps dataset\n", "\n", - "# Define descriptors\n", + "# Define crystallinity descriptor\n", "crystallinity = ['Single crystalline', 'Polycrystalline', 'Amorphous'] # Obtained from dataset\n", - "desc_crystal = CategoricalDescriptor(key='Property Crystallinity', categories=crystallinity)\n", - "dv_builder.add_descriptor(descriptor=desc_crystal, role='input')\n", + "desc_crystal = CategoricalDescriptor(key='Property Crystallinity', \n", + " categories=crystallinity)\n", + "dv_builder.add_descriptor(descriptor=desc_crystal, \n", + " role='input')\n", "\n", - "desc_formula = InorganicDescriptor(key='formula', threshold=1.0) # threshold <= 1.0; default in future releases\n", - "dv_builder.add_descriptor(descriptor=desc_formula, role='input')\n", + "# Define chemical formula descriptor\n", + "desc_formula = InorganicDescriptor(key='formula', \n", + " threshold=1.0)\n", + "dv_builder.add_descriptor(descriptor=desc_formula, \n", + " role='input')\n", "\n", - "desc_bandgap = RealDescriptor(key='Property Band gap', lower_bound=0.0, upper_bound=1e2, units='eV')\n", - "dv_builder.add_descriptor(descriptor=desc_bandgap, role='output')\n", + "# Define band gap descriptor\n", + "desc_bandgap = RealDescriptor(key='Property Band gap', \n", + " lower_bound=0.0, \n", + " upper_bound=1e3, \n", + " units='eV')\n", + "dv_builder.add_descriptor(descriptor=desc_bandgap, \n", + " role='output')\n", "\n", "# Build the configuration once all the pieces are in place\n", "view_config = dv_builder.build()" @@ -138,7 +153,7 @@ "\n", "[Back to ToC](#Table-of-contents)\n", "\n", - "After obtaining your customized configuration, you have to initialize a [`DataViewsClient`](http://citrineinformatics.github.io/python-citrination-client/modules/views/data_views_client.html) instance in order to create a data view from the configuration you built. 
The `create()` method returns the ID for the data view, which you will need for subsequent analysis and retraining." + "After obtaining your customized configuration, you have to initialize a [`DataViewsClient`](http://citrineinformatics.github.io/python-citrination-client/modules/views/data_views_client.html) instance in order to create a data view from the configuration you built." ] }, { @@ -148,14 +163,27 @@ "outputs": [], "source": [ "# Instantiate the base CitrinationClient\n", - "site = 'https://citrination.com' # site you want to access; we'll use the public site\n", - "client = CitrinationClient(api_key=os.environ.get('CITRINATION_API_KEY'), site=site)\n", + "site = 'https://citrination.com' # site you want to access; we'll use the public site\n", + "client = CitrinationClient(api_key=environ.get('CITRINATION_API_KEY'), \n", + " site=site)\n", "\n", "# Instantiate the DataViewsClient\n", "views_client = client.data_views\n", "views_client # reveal the methods" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `create()` method for the `DataViewClient` takes as input:\n", + "* `configuration`: A view configuration, like the template you created above.\n", + "* `name`: A name for the data view (must be unique among your data views).\n", + "* `description`: A description for the data view.\n", + "\n", + "and returns the ID for the data view, which you will need for subsequent analysis and retraining." 
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -163,11 +191,20 @@ "outputs": [], "source": [ "# Create a data view using the above configuration and store the ID\n", - "view_name = 'PyCC View ' + str(uuid.uuid4()) # random name to avoid clashes\n", + "view_name = 'PyCC View ' + str(uuid4())[:6] # random name to avoid clashes\n", "view_desc = 'This view was created by the PyCC API tutorial.'\n", - "view_id = views_client.create(configuration=view_config, name=view_name, description=view_desc)\n", + "view_id = views_client.create(configuration=view_config, \n", + " name=view_name, \n", + " description=view_desc)\n", "print('Data view {} was successfully created.'.format(view_id))\n", - "print('It can be accessed at {}/data_views/{}.'.format(site, view_id))" + "print('It can be accessed at {}/data_views/{}'.format(site, view_id))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Clicking the above URL will take you to the data view you just created on your deployment of Citrination." ] }, { @@ -198,7 +235,7 @@ "metadata": {}, "source": [ "### Check status of services\n", - "If there's a lot of data, training might take some time, and you might want to check when `predict` services are ready. Other possible services include `experimental_design`, `data_reports`, and `model_reports`." + "If there's a lot of data, training might take some time, and you might want to check when certain services are ready. Possible services enabled by data views include `predict`, `experimental_design`, `data_reports`, and `model_reports`." 
] }, { @@ -207,13 +244,19 @@ "metadata": {}, "outputs": [], "source": [ - "# Use a loop to monitor status\n", + "# Use a loop to monitor view status\n", "while True:\n", - " predict_state = views_client.get_data_view_service_status(view_id).predict.reason\n", - " print(predict_state)\n", - " if predict_state == 'Predict services are ready.':\n", + " view_status = views_client.get_data_view_service_status(data_view_id=view_id)\n", + " \n", + " # Design and Predict are most important endpoints to check\n", + " if (view_status.experimental_design.ready and\n", + " view_status.predict.event.normalized_progress == 1.0):\n", + " print(\"Data view ready!\")\n", + " print(\"Data view URL: {}/data_views/{}\".format(site, view_id))\n", " break\n", - " time.sleep(10)" + " else:\n", + " print(\"Waiting for data view services...\")\n", + " sleep(10)" ] }, { @@ -271,7 +314,7 @@ "To recap, this notebook went through the steps for creating a data view using the API.\n", "1. First, we used the `DataViewBuilder` object to specify the configuration.\n", "2. Then, we trained the model, which is simple as long as the configuration is correct.\n", - "3. Lastly, we explored some of the post-processing capabilities, such as retraining and submitting predictions." + "3. We showed how to monitor the status of various endpoints enabled by data views." 
] }, { diff --git a/citrination_api_examples/clients_sequence/3_models_client_api_tutorial.ipynb b/citrination_api_examples/clients_sequence/3_models_client_api_tutorial.ipynb index 0291c76..cdcc50d 100644 --- a/citrination_api_examples/clients_sequence/3_models_client_api_tutorial.ipynb +++ b/citrination_api_examples/clients_sequence/3_models_client_api_tutorial.ipynb @@ -15,8 +15,6 @@ "\n", "*Authors: Enze Chen, Eddie Kim*\n", "\n", - "**Note**: The [`ModelsClient`](http://citrineinformatics.github.io/python-citrination-client/modules/models/models_client.html) is now linked as an attribute of the [`DataViewsClient`](http://citrineinformatics.github.io/python-citrination-client/modules/views/data_views_client.html). Since this sub-client has many capabilties, this tutorial will still exist as a stanadalone reference.\n", - "\n", "In this notebook, we will cover how to use the `ModelsClient` to interface with *existing* data views and ML models through the [Citrination API](http://citrineinformatics.github.io/python-citrination-client/). 
We will demonstrate how to analyze ML models and leverage them for prediction and design using the [Band gaps from Strehlow and Cook](https://citrination.com/datasets/1160/show_search?searchMatchOption=fuzzyMatch) dataset, where we will have created a model mapping:\n", "\n", "$$\\text{Chemical formula (inorganic) + Crystallinity (categorical)} \\longrightarrow \\boxed{\\text{ML model}} \\longrightarrow \\text{Band gap (real)}$$" @@ -82,15 +80,16 @@ "outputs": [], "source": [ "# Standard packages\n", - "import os\n", - "import time\n", - "import uuid # generating random IDs\n", + "from os import environ # get environment variables\n", + "from time import sleep # wait time\n", + "from uuid import uuid4 # generating random IDs\n", "\n", "# Third-party packages\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", + "# Citrine packages\n", "from citrination_client import *\n", "from citrination_client.models.design import Target" ] @@ -103,13 +102,7 @@ "\n", "[Back to ToC](#Table-of-contents)\n", "\n", - "We will start by initializing the `ModelsClient` from the `CitrinationClient` and look at some basic properties of the view using `get_data_view()`. The returned `DataView` object has the following properties:\n", - "* `id`: The view ID.\n", - "* `name`: The name of the view.\n", - "* `description`: The description of the view.\n", - "* `datasets`: A list of datasets used in the view.\n", - "* `column_names`: A list of column names in the view.\n", - "* `columns`: A list of columns in the view (objects extend [`BaseColumn`](https://github.com/CitrineInformatics/python-citrination-client/tree/master/citrination_client/models/columns))." + "We will start by initializing the `ModelsClient` from the `CitrinationClient`." 
] }, { @@ -119,14 +112,28 @@ "outputs": [], "source": [ "# Instantiate the base CitrinationClient\n", - "site = 'https://citrination.com' # site you want to access; we'll use the public site\n", - "client = CitrinationClient(api_key=os.environ.get('CITRINATION_API_KEY'), site=site)\n", + "site = 'https://citrination.com' # site you want to access; we'll use the public site\n", + "client = CitrinationClient(api_key=environ.get('CITRINATION_API_KEY'), \n", + " site=site)\n", "\n", "# Instantiate the ModelsClient\n", "models_client = client.models\n", "models_client # reveal some methods" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can look at some basic properties of the view using the `get_data_view()` method. The returned `DataView` object has the following properties:\n", + "* `id`: The view ID.\n", + "* `name`: The name of the view.\n", + "* `description`: The description of the view.\n", + "* `datasets`: A list of datasets used in the view.\n", + "* `column_names`: A list of column names in the view.\n", + "* `columns`: A list of columns in the view (objects extend [`BaseColumn`](https://github.com/CitrineInformatics/python-citrination-client/tree/master/citrination_client/models/columns))." 
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -134,12 +141,12 @@ "outputs": [], "source": [ "# Look up the data view ID and analyze the view\n", - "view_id = 7753 # Band gaps model with only 100 data points for faster demonstration\n", - "data_view = models_client.get_data_view(view_id)\n", + "view_id = 7753 # Band gaps model with only 100 data points for faster demonstration\n", + "data_view = models_client.get_data_view(data_view_id=view_id)\n", "print('Data view name: {}.'.format(data_view.name))\n", "print('Data view description: {}'.format(data_view.description))\n", "print('Names of included datasets: {}.'.format([data_view.datasets[i].name for i in range(len(data_view.datasets))]))\n", - "print('Data view URL: {}/data_views/{}.'.format(site, view_id))" + "print('Data view URL: {}/data_views/{}'.format(site, view_id))" ] }, { @@ -165,7 +172,7 @@ "metadata": {}, "source": [ "### Check status of services\n", - "You can check on the various services in your view, which includes `predict`, `experimental_design`, `data_reports`, `model_reports`, using `get_data_view_service_status()`. A `ServiceStatus` object has the following properties:\n", + "You can check on the various services in your view, which includes `predict`, `experimental_design`, `data_reports`, `model_reports`, using `get_data_view_service_status()`. 
A [`ServiceStatus`](https://github.com/CitrineInformatics/python-citrination-client/blob/master/citrination_client/models/service_status.py) object has the following properties:\n", "* `ready`: A Boolean indicating whether or not the service can be used.\n", "* `context`: A contextual description of the current status: `notice`, `success`, `error`.\n", "* `reason`: A full sentence explanation of the service's status.\n", @@ -179,13 +186,13 @@ "outputs": [], "source": [ "# Check status of services in a loop\n", - "time.sleep(5)\n", - "while (True):\n", - " view_status = models_client.get_data_view_service_status(view_id)\n", + "sleep(5)\n", + "while True:\n", + " view_status = models_client.get_data_view_service_status(data_view_id=view_id)\n", " model_report_progress = view_status.model_reports.event.normalized_progress\n", " print('Model reports are still being generated, progress: {0:.1f}%.'.format(100 * model_report_progress))\n", " if (model_report_progress < 0.99):\n", - " time.sleep(15)\n", + " sleep(15)\n", " else:\n", " print('Model reports generated!')\n", " break" @@ -218,13 +225,13 @@ "outputs": [], "source": [ "# Get the Tsne object\n", - "tsne = models_client.tsne(view_id)\n", + "tsne = models_client.tsne(data_view_id=view_id)\n", "\n", "# Get first output Property in dict_keys object\n", "projection_key = list(tsne.projections())[0]\n", "\n", "# Get the t-SNE projection from the key\n", - "projection = tsne.get_projection(projection_key)\n", + "projection = tsne.get_projection(key=projection_key)\n", "max_index, max_value = (np.argmax(projection.responses), max(projection.responses))\n", "print('Highest band gap material: \\t{0}.'.format(projection.tags[max_index]))\n", "print('It has projected coordinates: \\t({0:.3f}, {1:.3f}).'.format(\n", @@ -267,11 +274,12 @@ "candidates = [{'formula':'MgO'}, {'formula':'GaN'}]\n", "\n", "# Predict endpoint\n", - "prediction_results = models_client.predict(view_id, candidates)\n", + "prediction_results = 
models_client.predict(data_view_id=view_id, \n", + " candidates=candidates)\n", "target_prop = projection_key\n", "\n", "# Get predicted value for first candidate\n", - "prediction_value = prediction_results[0].get_value(target_prop)\n", + "prediction_value = prediction_results[0].get_value(key=target_prop)\n", "print('{0} has a predicted {1} value of {2:.3f} +/- {3:.3f}.'.format(\n", " prediction_results[0].get_value('formula').value,\n", " prediction_value.key,\n", @@ -287,10 +295,12 @@ "\n", "[Back to ToC](#Table-of-contents)\n", "\n", + "*Note*: In order to submit design runs on Public Citrination, you will need an Admin account.\n", + "\n", "Once ML models have been trained, you can generate a list of candidate materials designed to achieve your target objectives. We can submit a new experimental design run using `submit_design_run()`, which takes as inputs:\n", "* `data_view_id`: The view ID.\n", "* `num_candidates`: The number of candidates to return.\n", - "* `effort`: A value $\\le 30$ indicating how much resource (time) to allocate towards design.\n", + "* `effort`: A value $\\le 30$ indicating how much resource to allocate towards design.\n", "* `target`: A [`Target`](https://github.com/CitrineInformatics/python-citrination-client/blob/master/citrination_client/models/design/target.py) instance, which consists of the name of the output column and the objective (`Max` or `Min`).\n", "* `constraints`: A list of [design constraints](https://github.com/CitrineInformatics/python-citrination-client/tree/master/citrination_client/models/design/constraints) that extend the [`BaseConstraint`](https://github.com/CitrineInformatics/python-citrination-client/blob/master/citrination_client/models/design/constraints/base.py) class.\n", "* `sampler`: The name of the sampler to use as a string, either `Default` or `This view`.\n", @@ -304,7 +314,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Submit the design run and obtain design run uuid\n", + "# Submit the 
design run and obtain a design run uuid\n", "design_run = models_client.submit_design_run(\n", " data_view_id=view_id,\n", " num_candidates=10,\n", @@ -328,7 +338,7 @@ "* `status`: The status string of the process, which can be `Accepted`, `Finished`, or `Killed`.\n", "* `messages`: A list of messages representing the steps the process has already progressed through.\n", "\n", - "If a design run is taking too long, you can end it with `kill_design_run()`." + "If a design run is taking too long, you can end it with the `kill_design_run()` method." ] }, { @@ -338,19 +348,20 @@ "outputs": [], "source": [ "# Check status of design in a loop\n", - "design_running = True\n", - "while (design_running):\n", - " process_status = models_client.get_design_run_status(view_id, design_id)\n", + "while True:\n", + " process_status = models_client.get_design_run_status(data_view_id=view_id, \n", + " run_uuid=design_id)\n", " design_status = process_status.status\n", " design_progress = process_status.progress\n", " print('Design is running, progress: {0:.1f}%.'.format(design_progress))\n", " if (design_status != 'Finished'):\n", - " time.sleep(15)\n", + " sleep(15)\n", " else:\n", " print('Design complete!')\n", - " design_running = False\n", + " break\n", " \n", - "# models_client.kill_design_run(view_id, design_id)" + "# models_client.kill_design_run(data_view_id=view_id, \n", + "# run_uuid=design_id)" ] }, { @@ -371,7 +382,8 @@ "metadata": {}, "outputs": [], "source": [ - "design_results = models_client.get_design_run_results(view_id, design_id)\n", + "design_results = models_client.get_design_run_results(data_view_id=view_id, \n", + " run_uuid=design_id)\n", "best_material = design_results.best_materials[0]\n", "print('The best material is {0} with a predicted target value of {1}.'.format(\n", " best_material['descriptor_values']['formula'], \n", @@ -388,7 +400,8 @@ "\n", "To recap, this notebook demonstrated the functionalities enabled by the `ModelsClient`, which means you 
can use the API to:\n", "* Interface with an existing data view that already has ML configured.\n", - "* Query t-SNE and Predict endpoints for data visualization and making predictions on new materials.\n", + "* Query the t-SNE endpoint for data visualization.\n", + "* Make predictions on new materials.\n", "* Submit design runs and generate optimized material candidates." ] }, diff --git a/citrination_api_examples/clients_sequence/4_search_client_api_tutorial.ipynb b/citrination_api_examples/clients_sequence/4_search_client_api_tutorial.ipynb index 63025e9..8ae4e3e 100644 --- a/citrination_api_examples/clients_sequence/4_search_client_api_tutorial.ipynb +++ b/citrination_api_examples/clients_sequence/4_search_client_api_tutorial.ipynb @@ -105,9 +105,10 @@ "source": [ "# Initialize the base CitrinationClient\n", "site = \"https://citrination.com\" # site you want to access; we'll use the public site\n", - "client = CitrinationClient(api_key=os.environ.get('CITRINATION_API_KEY'), site=site)\n", + "client = CitrinationClient(api_key=os.environ.get('CITRINATION_API_KEY'), \n", + " site=site)\n", "\n", - "# Access the SearchClient as an attribute\n", + "# Access the SearchClient from the attribute\n", "search_client = client.search\n", "search_client # reveal the methods" ] @@ -129,12 +130,12 @@ "\n", "Before we discuss the specifics of each method, we'll provide a high-level discussion about the structure of [`Query`](https://github.com/CitrineInformatics/python-citrination-client/tree/64aab061500811fae4767491e5b069bb4a4af068/citrination_client/search/core/query) objects. There are two generic types of queries used by the `SearchClient`:\n", "\n", - "1. `ReturningQuery` objects that actually returns specific objects (e.g. PIFs, datasets).\n", + "1. `ReturningQuery` objects that actually return specific objects with data (e.g. PIFs, datasets).\n", " * These are inputs to the search methods listed above.\n", "\n", "\n", - "1. 
Other `Query` objects that just match for specific fields (e.g. datasets, formulas).\n", - " * There is approximately a `Query` object for each PIF object ([see here](http://citrineinformatics.github.io/python-citrination-client/modules/search/pif_query_core.html))." + "2. Other `Query` objects that just match for specific fields (e.g. datasets, formulas).\n", + " * Roughly speaking, there is a `Query` object corresponding to each PIF object ([see here](http://citrineinformatics.github.io/python-citrination-client/modules/search/pif_query_core.html))." ] }, { @@ -170,7 +171,9 @@ "source": [ "### `extract_as`\n", "\n", - "`extract_as` is a powerful keyword that facilitates the aggregation of data from multiple sources. It takes a `string` with the alias to save a field under, and is useful when different datasets use slightly different names to describe the same Property. It will return the PIF records and relevant field all under the same `extract_as` name. [See here](../tutorial_sequence/3_IntroQueries.ipynb) for an example and discussion." + "`extract_as` is a powerful keyword that facilitates the aggregation of data from multiple sources. It takes a `string` with the alias to save a field under, and is useful when different datasets use slightly different names to describe the same Property. \n", + "\n", + "It will return the PIF records and relevant field all under the same `extract_as` name. This flattens the data from the hierarchical PIF format to facilitate analysis. [See here](../tutorial_sequence/3_IntroQueries.ipynb) for an example and discussion." 
] }, { @@ -199,11 +202,12 @@ "print(\"The dataset URL is: {}/datasets/{}\".format(site, dataset_id))\n", "\n", "system_query = PifSystemReturningQuery(\n", - " size=5,\n", + " size=500, # Maximum number of matching hits to return.\n", " query=DataQuery(\n", " dataset=DatasetQuery(\n", " id=Filter(\n", " equal=str(dataset_id)))))\n", + "\n", "search_result = search_client.pif_search(system_query)\n", "print(\"Found {} PIFs in dataset {}.\".format(search_result.total_num_hits, dataset_id))" ] @@ -225,6 +229,73 @@ "print(pif.dumps(search_result.hits[0].system, indent=4))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example: Filter a range of values\n", + "Whereas the `.system` attribute above returned the entire PIF, we can use the `.extracted` attribute to return only the fields of interest specified in the query." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "system_query = PifSystemReturningQuery(\n", + " size=500,\n", + " query=DataQuery(\n", + " dataset=DatasetQuery(id=Filter(equal=dataset_id)),\n", + " system=PifSystemQuery(\n", + " chemical_formula=ChemicalFieldQuery(\n", + " extract_as='Chemical formula',\n", + " filter=ChemicalFilter(equal='?x?y')),\n", + " properties=PropertyQuery(\n", + " name=FieldQuery(\n", + " filter=Filter(equal='Band gap')),\n", + " value=FieldQuery(\n", + " filter=Filter(min=3.0, max=6.0),\n", + " extract_as='Band gap')))))\n", + " \n", + "\n", + "search_result = search_client.pif_search(system_query)\n", + "print(\"Found {} PIFs in dataset {}.\".format(search_result.total_num_hits, dataset_id))\n", + "print([x.extracted for x in search_result.hits][:2])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example: Logic\n", + "We can search for materials that `SHOULD` be oxides but `MUST NOT` have only 1 oxygen atom." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_size = 5\n", + "query_logical = PifSystemReturningQuery(\n", + " size=query_size,\n", + " query=DataQuery(\n", + " dataset=DatasetQuery(\n", + " id=Filter(equal=str(dataset_id))),\n", + " system=PifSystemQuery(\n", + " chemical_formula=ChemicalFieldQuery(\n", + " extract_as='formula',\n", + " filter=[ChemicalFilter(equal='?xOy', logic=\"SHOULD\"),\n", + " ChemicalFilter(equal='?xO1', logic=\"MUST_NOT\")]))))\n", + "\n", + "search_result = search_client.pif_search(query_logical)\n", + "print(\"{} total hits, the first {} of which are:\".format(search_result.total_num_hits, query_size))\n", + "for i in range(query_size):\n", + " print(pif.dumps(search_result.hits[i].extracted))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -255,6 +326,7 @@ " chemical_formula=ChemicalFieldQuery(\n", " filter=ChemicalFilter(\n", " equal='As2S3')))))\n", + "\n", "search_result = search_client.dataset_search(dataset_query)\n", "print('{} datasets matched this query.'.format(search_result.total_num_hits))" ] @@ -274,7 +346,7 @@ "source": [ "first = search_result.hits[0]\n", "print('A matching dataset is \"{}\" with ID {}.\\nIt was made by {} at {}.'.format(\n", - " first.name, first.id, first.owner, first.updated_at, first.num_pifs))" + " first.name, first.id, first.owner, first.updated_at))" ] }, { @@ -328,9 +400,9 @@ "[Back to ToC](#Table-of-contents)\n", "\n", "Some other topics that might interest you include:\n", + "* Other examples on [learn-citrination](https://github.com/CitrineInformatics/learn-citrination), including [Intro](../tutorial_sequence/3_IntroQueries.ipynb) and [Advanced](../tutorial_sequence/AdvancedQueries.ipynb) queries.\n", "* [DataClient](http://citrineinformatics.github.io/python-citrination-client/tutorial/data_examples.html) - This allows you to create datasets and upload PIF data (only) using the API.\n", - " * There is also a 
corresponding [tutorial](1_data_client_api_tutorial.ipynb).\n", - "* Other examples on [learn-citrination](https://github.com/CitrineInformatics/learn-citrination), including [Intro](../tutorial_sequence/3_IntroQueries.ipynb) and [Advanced](../tutorial_sequence/AdvancedQueries.ipynb) queries." + " * There is also a corresponding [tutorial](1_data_client_api_tutorial.ipynb)." ] } ], diff --git a/citrination_api_examples/clients_sequence/5_sequential_learning_api_tutorial.ipynb b/citrination_api_examples/clients_sequence/5_sequential_learning_api_tutorial.ipynb index 6bb3f1f..072e107 100644 --- a/citrination_api_examples/clients_sequence/5_sequential_learning_api_tutorial.ipynb +++ b/citrination_api_examples/clients_sequence/5_sequential_learning_api_tutorial.ipynb @@ -17,6 +17,10 @@ "\n", "In this notebook, we will cover how to perform **sequential learning** (SL) using the [Citrination API](http://citrineinformatics.github.io/python-citrination-client/). [Sequential learning](https://citrine.io/platform/sequential-learning/) is the key workflow which allows machine learning algorithms and in-lab experiments to iteratively inform each other.\n", "\n", + "**Objective**: We want to optimize the radius of CdSe nanoparticles to achieve a target bandgap.\n", + "\n", + "![Band gap graphic](https://raw.githubusercontent.com/CitrineInformatics/community-tools/master/templates/fig/bandgap_graphic.png)\n", + "\n", "To replace the need for an actual laboratory, this notebook uses a simple *toy function* that allows for \"measurements\" on the data.\n", "\n", "**NOTE**: If you want to run the sequential learning code in the final part of this tutorial on the public version of Citrination (https://citrination.com), then you will need an Admin account to run design jobs." 
@@ -81,17 +85,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Standard packages\n", - "import os\n", - "import uuid\n", + "from os import environ # get environment variables\n", + "from time import sleep # wait time\n", + "from uuid import uuid4 # generate random strings\n", "\n", "# Third-party packages\n", "from sequential_learning_wrappers import * # Helper functions to wrap several API endpoints together\n", - "%matplotlib inline" + "\n", + "# Magic settings\n", + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2" ] }, { @@ -107,12 +116,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "site = \"https://citrination.com\" # site you want to access; we'll use the public site\n", - "client = CitrinationClient(api_key=os.environ.get(\"CITRINATION_API_KEY\"), site=site)" + "site = \"https://citrination.com\" # site you want to access; we'll use the public site\n", + "client = CitrinationClient(api_key=environ.get(\"CITRINATION_API_KEY\"), \n", + " site=site)" ] }, { @@ -123,38 +133,40 @@ "\n", "[Back to ToC](#Table-of-contents)\n", "\n", - "Since we aren't using a real laboratory, we need access to a quick way to generate \"correct\" measurements. A simple placeholder here is to use a function that sums the squares of its inputs. The goal, in this case, will be to find the global minimum located at the origin. \n", + "Since we aren't using a real laboratory, we need access to a quick way to generate \"correct\" measurements. \n", "\n", - "In a real example, we could minimize or maximize any output: compressive strengths, conductivities, and so on." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### `toy_func()`\n", - "This function takes a list of values and sums the squares of the values." 
+ "In this example, we are modelling the band gap of ellipsoidal CdSe nanoparticles, where the band gap can be tuned by nanoparticle size (due to quantum confinement)." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "def toy_func(inputs):\n", - " return np.sum(np.square(inputs))" + "# A toy function to generate \"ground truth measurements\"\n", + "def bandgap_diff(radii, target=1.9, bulk_bg=1.76): # Close-to-actual values for CdSe nano band gaps\n", + " nano_bg = bulk_bg + sum(10/(r**2) for r in radii) # the last term here comes from particle-in-a-box-like energy\n", + " return abs(target - nano_bg) # get how close you are to the target" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best (lowest) value in initial training set: 0.1276917166095386\n" + ] + } + ], "source": [ - "# Generate random inputs and outputs\n", - "toy_x = [np.random.normal(loc=3.0, scale=1.0, size=(1, 2))[0] for x in range(20)]\n", - "toy_y = [toy_func(x) for x in toy_x]\n", + "# Generate random initial inputs and outputs\n", + "toy_x = [np.random.normal(loc=100.0, scale=30.0, size=(1, 2))[0] for x in range(16)]\n", + "toy_y = [bandgap_diff(x) for x in toy_x]\n", "\n", "initial_best_measured_value = min(toy_y)\n", "\n", @@ -165,72 +177,141 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "### Plot the data\n", "Now we can plot the initial training set, and color it by the function value." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "# Plot initial training set, colored by toy function values\n", "plt.rcParams.update({'figure.figsize':(8, 6), 'font.size':18, 'lines.markersize':8})\n", "plt.scatter(np.array(toy_x)[:,0], np.array(toy_x)[:,1],\n", " c=toy_y, cmap=plt.cm.plasma)\n", - "plt.colorbar(label='toy function value')\n", + "plt.colorbar(label='Band gap absolute difference (eV)')\n", "plt.xlabel(r'$x_1$')\n", "plt.ylabel(r'$x_2$')\n", - "plt.xlim(-5,5)\n", - "plt.ylim(-5,5)\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Save the dataset to a PIF" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\"toy_initial_dataset_017021.json\" file successfully created.\n" + ] + } + ], "source": [ "# Write a PIF JSON dataset file\n", - "random_string = str(uuid.uuid4())[:6]\n", + "random_string = str(uuid4())[:6]\n", "output_file = 'toy_initial_dataset_{}.json'.format(random_string)\n", - "write_dataset_from_func(toy_func, output_file, toy_x)" + "write_dataset_from_func(test_function=bandgap_diff,\n", + " filename=output_file,\n", + " input_vals=toy_x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Upload the data to Citrination" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset created: 181513\n", + "Dataset URL: https://citrination.com/datasets/181513\n" + ] + } + ], "source": [ "# Make a dataset, upload to Citrination, return/print the ID\n", "dataset_name = output_file.split('.')[0]\n", "dataset_id = upload_data_and_get_id(\n", - " client,\n", - " dataset_name,\n", - " output_file,\n", - " 
create_new_version=True,\n", - ")\n", + " client=client,\n", + " dataset_name=dataset_name,\n", + " dataset_local_fpath=output_file,\n", + " create_new_version=True)\n", + "\n", "print(\"Dataset created: {}\".format(dataset_id))\n", "print(\"Dataset URL: {}/datasets/{}\".format(site, dataset_id))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Make a data view" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data view created: 10787\n", + "Data view URL: https://citrination.com/data_views/10787\n" + ] + } + ], "source": [ "# Make a data view on Citrination, return/print the ID\n", "view_name = 'toy_view_{}'.format(random_string)\n", - "view_id = build_view_and_get_id(client, dataset_id, \n", - " input_keys=[\"Property x1\", \"Property x2\"], output_keys=[\"Property y\"],\n", - " view_name=view_name, view_desc=\"toy test view\")\n", + "view_desc = 'toy test view'\n", + "input_keys = [\"Property x1\", \"Property x2\"]\n", + "target_property = 'Property Band gap difference'\n", + "\n", + "view_id = build_view_and_get_id(\n", + " client=client, \n", + " dataset_id=dataset_id, \n", + " input_keys=input_keys, \n", + " output_keys=[target_property],\n", + " view_name=view_name, \n", + " view_desc=view_desc)\n", "\n", "print(\"Data view created: {}\".format(view_id))\n", "print(\"Data view URL: {}/data_views/{}\".format(site, view_id))" @@ -245,7 +326,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -277,11 +358,87 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "---STARTING SL ITERATION #1---\n", + "Design ready\n", + "Created 
design run with ID 4a5516d1-ec4c-4fc8-b73e-c288753f0029\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Finished\n", + "SL iter #1, best predicted (value, uncertainty) = ('0.1313', '0.0072')\n", + "\"design-4a5516d1-ec4c-4fc8-b73e-c288753f0029.json\" file successfully created.\n", + "Dataset updated: 10 candidates added.\n", + "New dataset contains 26 PIFs.\n", + "Design ready\n", + "\n", + "---STARTING SL ITERATION #2---\n", + "Design ready\n", + "Created design run with ID 0b8fe17a-ec86-4a48-a997-6d860edf6fac\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Finished\n", + "SL iter #2, best predicted (value, uncertainty) = ('0.1211', '0.0025')\n", + "\"design-0b8fe17a-ec86-4a48-a997-6d860edf6fac.json\" file successfully created.\n", + "Dataset updated: 10 candidates added.\n", + "New dataset contains 36 PIFs.\n", + "Design ready\n", + "\n", + "---STARTING SL ITERATION #3---\n", + "Design ready\n", + "Created design run with ID 4d8817d8-1332-4625-8945-472086c1e5b6\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Finished\n", + "SL iter #3, best predicted (value, uncertainty) = ('0.113', '0.017')\n", + "\"design-4d8817d8-1332-4625-8945-472086c1e5b6.json\" file successfully created.\n", + "Dataset updated: 10 candidates added.\n", + "New dataset contains 46 PIFs.\n", + "Design ready\n", + "\n", + "---STARTING SL ITERATION #4---\n", + "Design ready\n", + "Created design run with ID 8d9b069a-aa22-4ef5-a5c8-508539793ff1\n", + "Design run status: Accepted\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Finished\n", + "SL iter #4, best predicted (value, uncertainty) = ('0.072', '0.073')\n", + 
"\"design-8d9b069a-aa22-4ef5-a5c8-508539793ff1.json\" file successfully created.\n", + "Dataset updated: 10 candidates added.\n", + "New dataset contains 56 PIFs.\n", + "Design ready\n", + "\n", + "---STARTING SL ITERATION #5---\n", + "Design ready\n", + "Created design run with ID 39c78154-02d6-489d-8f66-d826ff04b746\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Processing\n", + "Design run status: Finished\n", + "SL iter #5, best predicted (value, uncertainty) = ('0.039', '0.015')\n", + "\"design-39c78154-02d6-489d-8f66-d826ff04b746.json\" file successfully created.\n", + "Dataset updated: 10 candidates added.\n", + "New dataset contains 66 PIFs.\n", + "Design ready\n", + "SL finished!\n", + "\n" + ] + } + ], "source": [ "best_sl_pred_vals, best_sl_measured_vals = run_sequential_learning(\n", " client=client,\n", @@ -291,12 +448,11 @@ " design_effort=10,\n", " wait_time=10,\n", " num_sl_iterations=5,\n", - " input_properties=[\"Property x1\", \"Property x2\"],\n", - " target=[\"Property y\", \"Min\"],\n", + " input_properties=input_keys,\n", + " target=[target_property, \"Min\"],\n", " print_output=True,\n", - " true_function=toy_func,\n", - " score_type=\"MLI\"\n", - ")" + " true_function=bandgap_diff,\n", + " score_type=\"MLI\")" ] }, { @@ -316,9 +472,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "plot_sl_results(best_sl_measured_vals, best_sl_pred_vals, initial_best_measured_value)" ] diff --git a/citrination_api_examples/clients_sequence/6_sequential_learning_steel_fatigue.ipynb b/citrination_api_examples/clients_sequence/6_sequential_learning_steel_fatigue.ipynb new file mode 100644 index 0000000..7d27112 --- /dev/null +++ b/citrination_api_examples/clients_sequence/6_sequential_learning_steel_fatigue.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![logo](https://github.com/CitrineInformatics/community-tools/blob/master/templates/fig/citrine_banner_2.png?raw=true)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sequential Learning Workshop\n", + "*Authors: Edward Kim, Enze Chen, Nils Persson*\n", + "\n", + "In this notebook, we will cover how to perform **sequential learning** (SL) using the [Citrination API](http://citrineinformatics.github.io/python-citrination-client/). [Sequential learning](https://citrine.io/platform/sequential-learning/) is the key workflow which allows machine learning algorithms and in-lab experiments to iteratively inform each other.\n", + "\n", + "To replace the need for an actual laboratory or simulation, this demo uses an existing dataset from the Open Citrination platform, with measurements of *steel fatigue strength across 437 experiments spanning 23 processing and formulation variables*. \n", + "\n", + "To simulate this experiment, we will redact the output measurement (Fatigue Strength) from all but 25 random experiments from the bottom quartile of performance. Each new experiment will be selected from the list of 412 other \"unmeasured\" points using the Citrination platform's design algorithm, with the goal of *maximizing Fatigue Strength*." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Table of contents\n", + "1. [Setup](#1)\n", + "1. [Get Training Data](#2)\n", + "1. [Initial Measurements](#3)\n", + "1. [Run Sequential Learning](#4)\n", + " 1. [Design](#4.1)\n", + " 1. [Measure (and re-train)](#4.2)\n", + " 1. [Repeat](#4.3)\n", + "1. [Conclusion](#5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Setup\n", + "---\n", + "[Back to TOC](#toc)\n", + "\n", + "This notebook uses some convenience functions to wrap several API endpoints. These are contained in the file `sequential_learning_wrappers_class.py` and imported below. Review the docstrings and code in that file to learn more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# IPython magic settings\n", + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# Third-party packages\n", + "from steel_fatigue_wrapper_class import * # Helper functions to wrap several API endpoints together" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize the CitrinationClient\n", + "\n", + "Initializing a `CitrinationClient` requires two arguments, `api_key` and `site`.\n", + "\n", + "If the following cell runs successfully, you will see `Client created successfully!`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the CitrinationClient with your API key and deployment\n", + "site = \"https://citrination.com\"\n", + "client = CitrinationClient(api_key=os.environ.get('CITRINATION_API_KEY'),\n", + " site=site)\n", + "verify_client(client)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Get Training Data\n", + "---\n", + "[Back to TOC](#toc)\n", + "\n", + "Since we don't have access to an actual experiment or simulation, we will use data from an existing public dataset on steel fatigue." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pd.set_option('display.max_columns', 500)\n", + "orig_dataset_id = 150670 \n", + "df_steel = get_steel_dataset(client, orig_dataset_id)\n", + "ordered_cols = df_steel.columns.to_list()\n", + "print(\"{} entries spanning {} dimensions.\".format(df_steel.shape[0], df_steel.shape[1]-5))\n", + "df_steel.sample(4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot histogram of Fatigue Strength values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'figure.figsize':(8, 7), 'font.size':14, 'lines.markersize':8})\n", + "df_steel['Fatigue Strength'].hist(bins=20)\n", + "plt.xlabel('Fatigue Strength (MPa)')\n", + "plt.ylabel('Number of Entries')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Analyze simple statistics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_steel['Fatigue Strength'].describe()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generate Training Set\n", + "\n", + "We will select 25 points from the bottom 25% of the dataset (in terms of Fatigue Strength) to simulate an initial experimental design space. We will have access to the Fatigue Strength of these 25 training points, but it will be redacted from the remaining 412. Thus, our initial model will be constructed on these below-average candidates.\n", + "\n", + "To \"measure\" a new candidate, we simply look up its Fatigue Strength from the original dataset. 
This process (the splitting and the \"measurement\") uses the functionality of our `SearchClient` under the hood." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set a cutoff value for Fatigue Strength\n", + "target_col = 'Fatigue Strength'\n", + "target_max = np.percentile(df_steel['Fatigue Strength'], 25) # 25th percentile of fatigue strength\n", + "\n", + "# Split and redact original dataset\n", + "all_pifs = split_dataset(client,\n", + " orig_dataset_id,\n", + " target_col,\n", + " target_max,\n", + " num_train=25)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Initial Measurements\n", + "---\n", + "[Back to TOC](#toc)\n", + "\n", + "We'll now write our initial training data to a JSON file and upload it to Citrination using our `client`. This involves creating a new Dataset, then defining a DataView to run predict and design services.\n", + "\n", + "**Start from here if you want to start over an SL run**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random_string = str(uuid4())[:6]\n", + "meas_dataset_name = \"SL_demo_dataset_{}\".format(random_string)\n", + "\n", + "# Write to file\n", + "if not os.path.exists('temp'):\n", + " os.makedirs('temp')\n", + "dataset_file = os.path.join(\"temp\", meas_dataset_name+\".json\")\n", + "with open(dataset_file, \"w\") as f:\n", + " f.write(pif.dumps(all_pifs, indent=4))\n", + "\n", + "# Upload to Citrination\n", + "dataset_id = upload_data_and_get_id(client,\n", + " meas_dataset_name,\n", + " dataset_file,\n", + " create_new_version=True)\n", + "\n", + "print(\"Dataset created: {}/datasets/{}\".format(site, dataset_id))\n", + "print('The name is \"{}.\"'.format(meas_dataset_name))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a DataView\n", + "\n", + "We now create a DataView to model our initial training data and run design services. We will select the `chemical formula` and the processing variables as inputs, and set the `Fatigue Strength` as an output." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "search_template_client = client.data_views.search_template_client\n", + "avail_cols = search_template_client.get_available_columns(dataset_id)\n", + "\n", + "excluded_cols = \\\n", + " [col for col in avail_cols if ('Area Proportion' in col\n", + " or 'Reduction Ratio' in col\n", + " or 'composition' in col\n", + " or 'Fatigue Strength' in col\n", + " or 'Sample Number' in col)]\n", + "\n", + "input_cols = [col for col in avail_cols if col not in excluded_cols]\n", + "\n", + "print('Inputs:\\n{}'.format(input_cols))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Make a data view on Citrination, return/print the ID\n", + "\n", + "view_name = \"SL_demo_view_{}\".format(random_string)\n", + "\n", + "view_id = build_view_and_get_id(client,\n", + " dataset_id,\n", + " view_name,\n", + " view_desc='DataView for SL demo for Fatigue Strength.',\n", + " input_keys=input_cols,\n", + " output_keys=['Property Fatigue Strength',\n", + " 'Property Sample Number'],\n", + " model_type='default')\n", + "\n", + "print(\"Data view created: {}/data_views/{}\".format(site, view_id))\n", + "print('The name is \"{}.\"'.format(view_name))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "While model training proceeds, we can explore the DataView we just created." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Run Sequential Learning\n", + "---\n", + "[Back to TOC](#toc)\n", + "\n", + "We're now ready to run sequential learning (SL). 
SL consists of three main phases that repeat in a loop:\n", + "\n", + "- **design**: generate new candidates to test in the lab (or *in silico*)\n", + "- **measure**: test those new candidates and add the results to your dataset\n", + "- **retrain**: re-train the machine learning model using the new measurements\n", + "- **repeat**\n", + "\n", + "That's really all there is to it! We will manage the entire sequential learning process through an object called an `SL_run`, which has methods (`.design()` and `.measure()`) to run each of these steps. Let's instantiate one of those right now." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "meas_cols = ordered_cols+['iter']\n", + "SLR = SL_run(client=client,\n", + " view_id=str(view_id),\n", + " dataset_id=str(dataset_id),\n", + " orig_dataset_id=orig_dataset_id,\n", + " all_dataset_cols=meas_cols,\n", + " target=[\"Property Fatigue Strength\", \"Max\"],\n", + " score_type=\"MLI\",\n", + " sampler='This view')\n", + "\n", + "SLR.measurements[meas_cols].head(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.1 Design" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "design_effort = 10 # An integer between 1 and 30\n", + "SLR.design(design_effort=design_effort)\n", + "cand_cols = input_cols + ['Property Fatigue Strength', \n", + " 'Uncertainty in Property Fatigue Strength', \n", + " 'citrine_score']\n", + "SLR.candidates[cand_cols]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.2 Measure (and re-train)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SLR.measure()\n", + "SLR.measurements[meas_cols].tail(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.3 Repeat\n", + "---\n", + "[Back to TOC](#toc)\n", + "\n", + "From here 
on out, we can repeat this process as much as we want (and ultimately run until it converges to a specified tolerance). This basically consists of `.design()` and `.measure()` cycles. To wrap things up, we'll run a loop of a few more iterations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "repeat_iters = 6\n", + "for i in range(repeat_iters):\n", + " SLR.design(design_effort=design_effort)\n", + " SLR.measure()\n", + " SLR.plot_sl_results();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Inspect the new data in a DataFrame\n", + "SLR.measurements[meas_cols].tail(repeat_iters+3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "---\n", + "[Back to ToC](#toc)\n", + "\n", + "After running this demo, you should have a sense for the steps involved in a sequential learning cycle, namely that it consists of **design** and **measure** phases. Design is run on the Citrination platform by training a model to fit your existing data, and returns candidates to measure. Measure is the phase where you (the experimentalist or computationalist) go and run your experiment!\n", + "\n", + "A few key takeaways from this demo:\n", + "* Building a model on Citrination is as easy as defining inputs, outputs, and latent variables.\n", + "* Design runs return candidates based on predicted output *and* uncertainty.\n", + "* Well-calibrated prediction uncertainties are vital to this process." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/citrination_api_examples/clients_sequence/sequential_learning_wrappers.py b/citrination_api_examples/clients_sequence/sequential_learning_wrappers.py index 723e9d3..dabcb01 100644 --- a/citrination_api_examples/clients_sequence/sequential_learning_wrappers.py +++ b/citrination_api_examples/clients_sequence/sequential_learning_wrappers.py @@ -1,17 +1,20 @@ ''' Authors: Eddie Kim, Enze Chen -This file contains wrapper functions that are used in the sequential learning API tutorial notebook. Detailed docstrings with method fuctions and parameters are given below. +This file contains wrapper functions that are used in the sequential learning +API tutorial notebook. Detailed docstrings with method functions and parameters +are given below. 
''' -import json -from collections import OrderedDict +# Standard packages +import os from time import sleep +# Third-party packages import numpy as np import matplotlib.pyplot as plt -from citrination_client import (CitrinationClient, DataQuery, DatasetQuery, - Filter, PifSystemReturningQuery, - RealDescriptor) + +# Citrine packages +from citrination_client import * from citrination_client.models.design import Target from citrination_client.views.data_view_builder import DataViewBuilder from pypif import pif @@ -19,16 +22,13 @@ def write_dataset_from_func(test_function, filename, input_vals): - '''Given a function, write a dataset evaluated on given input values - - :param test_function: Function for generating dataset - :type test_function: Callable[[np.ndarray], float] - :param filename: Name of file for saving CSV dataset - :type filename: str - :param input_vals: List of input values to eval function over - :type input_vals: np.ndarray - :return: Doesn't return anything - :rtype: None + ''' + Given a function, write a dataset evaluated on given input values. + + :param test_function: Function for generating dataset. + :param filename: Name of file as a string for saving CSV dataset. + :param input_vals: List of input values as numpy array to evaluate function over. + :return: None. 
''' pif_systems = [] @@ -44,31 +44,30 @@ def write_dataset_from_func(test_function, filename, input_vals): system.properties.append(func_input) func_output = Property() - func_output.name = 'y' + func_output.name = 'Band gap difference' func_output.scalars = test_function(val_row) system.properties.append(func_output) pif_systems.append(system) - with open(filename, "w") as f: + if not os.path.exists('temp'): + os.makedirs('temp') + + with open(os.path.join('temp', filename), "w") as f: f.write(pif.dumps(pif_systems, indent=4)) + print('"{}" file successfully created.'.format(filename)) def upload_data_and_get_id(client, dataset_name, dataset_local_fpath, - create_new_version = False, given_dataset_id = None): - '''Uploads data to a new/given dataset and returns its ID - - :param client: Client API object to pass in - :type client: CitrinationClient - :param dataset_name: Name of dataset - :type dataset_name: str - :param dataset_local_fpath: Local data filepath - :type dataset_local_fpath: str - :param create_new_version: Whether or not to make a new version - :param create_new_version: bool - :param given_dataset_id: ID if using existing dataset, defaults to None - :param given_dataset_id: int - :return: ID of the dataset - :rtype: int + create_new_version = False, given_dataset_id = None): + ''' + Uploads data to a new/given dataset and returns its ID. + + :param client: CitrinationClient API object to pass in. + :param dataset_name: Name of dataset as a string. + :param dataset_local_fpath: Local data filepath as a string. + :param create_new_version: Boolean flag for whether or not to make a new version. + :param given_dataset_id: Integer ID if using existing dataset; default = None. + :return dataset_id: Integer ID of the dataset. 
''' if given_dataset_id is None: @@ -79,46 +78,37 @@ def upload_data_and_get_id(client, dataset_name, dataset_local_fpath, if create_new_version: client.data.create_dataset_version(dataset_id) - client.data.upload(dataset_id, dataset_local_fpath) + client.data.upload(dataset_id, os.path.join('temp', dataset_local_fpath)) assert (client.data.matched_file_count(dataset_id) >= 1), "Upload failed." return dataset_id -def build_view_and_get_id(client, dataset_id, input_keys, output_keys, view_name, view_desc = "", - wait_time = 2, print_output = False): - '''Builds a new data view and returns the view ID - - :param client: Client object - :type client: CitrinationClient - :param dataset_id: Dataset to build view from - :type dataset_id: int - :param view_name: Name of the new view - :type view_name: str - :param input_keys: Input key names - :type input_keys: List[str] - :param output_keys: Output key names - :type output_keys: List[str] - :param view_desc: Description for the view, defaults to "" - :param view_desc: str, optional - :param wait_time: Wait time in seconds before polling API - :type wait_time: int - :param print_output: Whether or not to print outputs - :type print_output: bool - :return: ID of the view - :rtype: int +def build_view_and_get_id(client, dataset_id, input_keys, output_keys, view_name, + view_desc = '', wait_time = 2, print_output = False): + ''' + Builds a new data view and returns the view ID. + + :param client: CitrinationClient object. + :param dataset_id: Integer ID of the dataset to build data view from. + :param view_name: Name of the new data view as a string. + :param input_keys: List of string representing input key names. + :param output_keys: List of string representing output key names. + :param view_desc: String description for the data view. + :param wait_time: Wait time in seconds (int) before polling API. + :param print_output: Boolean flag for whether or not to print outputs. + :return dv_id: Integer ID of the data view. 
''' dv_builder = DataViewBuilder() dv_builder.dataset_ids([str(dataset_id)]) - dv_builder.model_type('default') for key_name in input_keys: - desc_x = RealDescriptor(key=key_name, lower_bound=-1e6, upper_bound=1e6) - dv_builder.add_descriptor(desc_x, role='input') + desc_x = RealDescriptor(key=key_name, lower_bound=-1e3, upper_bound=1e3) + dv_builder.add_descriptor(descriptor=desc_x, role='input') for key_name in output_keys: - desc_y = RealDescriptor(key=key_name, lower_bound=-1e6, upper_bound=1e6) - dv_builder.add_descriptor(desc_y, role='output') + desc_y = RealDescriptor(key=key_name, lower_bound=0, upper_bound=1e2) + dv_builder.add_descriptor(descriptor=desc_y, role='output') dv_config = dv_builder.build() @@ -127,49 +117,33 @@ def build_view_and_get_id(client, dataset_id, input_keys, output_keys, view_name dv_id = client.data_views.create( configuration=dv_config, name=view_name, - description=view_desc - ) + description=view_desc) + return dv_id -def run_sequential_learning(client, view_id, dataset_id, - num_candidates_per_iter, - design_effort, wait_time, - num_sl_iterations, input_properties, - target, print_output, - true_function, - score_type): - '''Runs SL design - - :param client: Client object - :type client: CitrinationClient - :param view_id: View ID - :type view_id: int - :param dataset_id: Dataset ID - :type dataset_id: int - :param num_candidates_per_iter: Candidates in a batch - :type num_candidates_per_iter: int - :param design_effort: Effort from 1-30 - :type design_effort: int - :param wait_time: Wait time in seconds before polling API - :type wait_time: int - :param num_sl_iterations: SL iterations to run - :type num_sl_iterations: int - :param input_properties: Inputs - :type input_properties: List[str] - :param target: ("Output property", {"Min", "Max"}) - :type target: List[str] - :param print_output: Whether or not to print outputs - :type print_output: bool - :param true_function: Actual function for evaluating measured/true values - 
:type true_function: Callable[[np.ndarray], float] - :param score_type: MLI or MEI - :type score_type: str - :return: 2-tuple: list of predicted scores/uncertainties; list of measured scores/uncertainties - :rtype: Tuple[List[float], List[float]] +def run_sequential_learning(client, view_id, dataset_id, num_candidates_per_iter, + design_effort, wait_time, num_sl_iterations, + input_properties, target, print_output, true_function, + score_type): + ''' + Runs SL design. + + :param client: CitrinationClient object. + :param view_id: Integer ID for the data view. + :param dataset_id: Integer ID for the data set. + :param num_candidates_per_iter: Integer number of candidates in a batch. + :param design_effort: Integer from 1 to 30 representing design effort. + :param wait_time: Wait time in seconds (int) before polling API. + :param num_sl_iterations: Integer number of SL iterations to run. + :param input_properties: List of strings representing input property keys. + :param target: List of strings for target property key and optimization goal. + :param print_output: Boolean flag for whether or not to print outputs. + :param true_function: Actual function for evaluating measured/true values. + :param score_type: String for candidate selection strategy: 'MLI' or 'MEI'. 
+ :return: 2-tuple: (List of floats for predicted scores/uncertainties, + List of floats for measured scores/uncertainties) ''' - - best_sl_pred_vals = [] best_sl_measured_vals = [] @@ -181,7 +155,7 @@ def run_sequential_learning(client, view_id, dataset_id, print("\n---STARTING SL ITERATION #{}---".format(i+1)) _wait_on_ingest(client, dataset_id, wait_time, print_output) - _wait_on_data_view(client, dataset_id, view_id, wait_time, print_output) + _wait_on_data_view(client, view_id, wait_time, print_output) # Submit a design run design_id = client.models.submit_design_run( @@ -189,9 +163,7 @@ def run_sequential_learning(client, view_id, dataset_id, num_candidates=num_candidates_per_iter, effort=design_effort, target=Target(*target), - constraints=[], - sampler="Default" - ).uuid + constraints=[]).uuid if print_output: print("Created design run with ID {}".format(design_id)) @@ -266,7 +238,7 @@ def run_sequential_learning(client, view_id, dataset_id, # Retrain model w/ wait times client.models.retrain(view_id) - _wait_on_data_view(client, dataset_id, view_id, wait_time, print_output) + _wait_on_data_view(client, view_id, wait_time, print_output) if print_output: print("SL finished!\n") @@ -275,7 +247,16 @@ def run_sequential_learning(client, view_id, dataset_id, def _wait_on_ingest(client, dataset_id, wait_time, print_output = True): - # Wait for ingest to finish + ''' + Utility function to check for data ingest completion. + + :param client: CitrinationClient API object. + :param dataset_id: Integer ID for the dataset to check. + :param wait_time: Wait time in seconds (int) before polling API again. + :param print_output: Boolean flag for whether to display status messages. + :return: None. 
+ ''' + sleep(wait_time) while (client.data.get_ingest_status(dataset_id) != "Finished"): if print_output: @@ -283,37 +264,63 @@ def _wait_on_ingest(client, dataset_id, wait_time, print_output = True): sleep(wait_time) -def _wait_on_data_view(client, dataset_id, view_id, wait_time, print_output = True): - is_view_ready = False +def _wait_on_data_view(client, view_id, wait_time, print_output = True): + ''' + Utility function to check for data view creation completion. + + :param client: CitrinationClient API object. + :param view_id: Integer ID for the data view to check. + :param wait_time: Wait time in seconds (int) before polling API again. + :param print_output: Boolean flag for whether to display status messages. + :return: None. + ''' + sleep(wait_time) - while (not is_view_ready): + while True: sleep(wait_time) design_status = client.data_views.get_data_view_service_status(view_id) if (design_status.experimental_design.ready and - design_status.predict.event.normalized_progress == 1.0): - is_view_ready = True + design_status.predict.event.normalized_progress == 1.0): if print_output: - print("Design ready") + print("Design ready.") + break else: print("Waiting for design services...") + sleep(2) def _wait_on_design_run(client, design_id, view_id, wait_time, print_output = True): - design_processing = True + ''' + Utility function to check for design run completion. + + :param client: CitrinationClient API object. + :param design_id: Integer ID of the submitted design run. + :param view_id: Integer ID for the data view to check. + :param wait_time: Wait time in seconds (int) before polling API again. + :param print_output: Boolean flag for whether to display status messages. + :return: None. 
+ ''' + sleep(wait_time) - while design_processing: + while True: status = client.models.get_design_run_status(view_id, design_id).status if print_output: - print("Design run status: {}".format(status)) - + print("Design run status: {}.".format(status)) if status != "Finished": sleep(wait_time) else: - design_processing = False - + break + def plot_sl_results(measured, predicted, init_best): -# plt.rcParams.update({'figure.figsize':(8, 6), 'font.size':18}) + ''' + Helper function to plot the SL results for each iteration. + + :param measured: True/measured values of a property (float). + :param predicted: Predicted values of a property (float). + :param init_best: The best value of the property from the training set (float). + :return: None. + ''' # Measured results plt.plot( diff --git a/citrination_api_examples/clients_sequence/steel_fatigue_wrapper_class.py b/citrination_api_examples/clients_sequence/steel_fatigue_wrapper_class.py new file mode 100644 index 0000000..777bee6 --- /dev/null +++ b/citrination_api_examples/clients_sequence/steel_fatigue_wrapper_class.py @@ -0,0 +1,706 @@ +''' +Authors: Eddie Kim, Enze Chen, Nils Persson +This file contains wrapper functions that are used in the sequential learning +API tutorial notebook. Detailed docstrings with method functions and parameters +are given below. +''' + +# Standard packages +import os +from uuid import uuid4 +from time import time, sleep + +# Third party packages +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt + +# Citrine packages +from citrination_client import * +from citrination_client.models.design import Target +from citrination_client.views.data_view_builder import DataViewBuilder +from pypif import pif +from pypif.obj import * + + +def verify_client(client): + ''' + Verify that the Client is able to create datasets. + + :param client: An instance of CitrinationClient. + :return: None. 
+ ''' + dataset_id = None + try: + dataset = client.data.create_dataset( + name='Test valid API key '+str(uuid4()), + description='This empty dataset was created to test the Client connection.') + dataset_id = dataset.id + except: + print("The Client could not connect.\nPlease double check the deployment name and API key.") + return + print("Client created successfully!") + client.data.delete_dataset(dataset_id=dataset_id) + return + + +def get_steel_dataset(client, orig_dataset_id): + ''' + Retrieve and format the steel fatigue dataset as a DataFrame. + + :param client: An instance of CitrinationClient. + :param orig_dataset_id: An int representing the dataset ID. + :return df_steel: A pandas DataFrame for the dataset. + ''' + dataset_json = client.data.get_dataset_files(orig_dataset_id)[0] + df = pd.read_json(dataset_json._url) + + # Create Composition and Property columns and combine into one DataFrame + df_form = df['composition'].apply(lambda dl: pd.Series({d['element']:d['actualWeightPercent']['value'] for d in dl})) + df_proc = df['properties'].apply(lambda dl: pd.Series({d['name']:d['scalars'][0]['value'] for d in dl})) + df_steel = pd.concat([df_form, df_proc], axis=1) + + return df_steel + + +def split_dataset(client, dataset_id, target_col, target_max, + num_train = 20, random_seed = 1): + ''' + Splits an existing dataset such that num_train entries with + target_col below target_max have their value of target_col + retained, while the rest have it redacted. + + :param client: Client API object to pass in. + :param dataset_id: An int representing the ID of dataset to split. + :param target_col: A string for column name to filter and split on. + :param target_max: Max float value of target_col to allow in training set. + :param num_train: An int for the number of training points to keep. + :param random_seed: A random seed (int) to fix things for testing. + :return all_pifs: A list of PIF Systems. 
+ ''' + + # Get PIFs with Fatigue Strength above cutoff + system_query_high = PifSystemReturningQuery( + size=9999, + query=DataQuery( + dataset=DatasetQuery( + id=Filter( + equal=dataset_id)), + system=PifSystemQuery( + properties=PropertyQuery( + name=FieldQuery(filter=Filter(equal=target_col)), + value=FieldQuery(filter=Filter(min=target_max+1e-6)) + ) + ) + ) + ) + + query_result_high = client.search.pif_search(system_query_high) + print('Entries in top split:', query_result_high.total_num_hits) + + # Get PIFs with Fatigue Strength below cutoff + system_query_low = PifSystemReturningQuery( + size=9999, + query=DataQuery( + dataset=DatasetQuery( + id=Filter( + equal=dataset_id)), + system=PifSystemQuery( + properties=PropertyQuery( + name=FieldQuery(filter=Filter(equal=target_col)), + value=FieldQuery(filter=Filter(max=target_max)) + ) + ) + ) + ) + + query_result_low = client.search.pif_search(system_query_low) + print('Entries in bottom split:', query_result_low.total_num_hits) + + # Choose some number of the low values to use as training points + # np.random.seed(random_seed) + low_hits = query_result_low._hits + np.random.shuffle(low_hits) + low_hits_split = np.split(low_hits, [num_train]) + train_pifs = [h._system for h in low_hits_split[0]] + unmeasured_pifs = [h._system for h in low_hits_split[1]] \ + + [h._system for h in query_result_high._hits] + print("{} train PIFs and {} possible candidate PIFs.".format( + len(train_pifs), len(unmeasured_pifs))) + + # Redact the target_col values from the "unmeasured" candidates + for system in unmeasured_pifs: + fatigue = [p for p in system.properties if p._name == target_col][0] + fatigue.scalars[0]._value = None + + all_pifs = train_pifs + unmeasured_pifs + return all_pifs + + +def upload_data_and_get_id(client, dataset_name, dataset_local_fpath, + create_new_version = False, given_dataset_id = None): + ''' + Uploads data to a new/given dataset and returns its ID. 
+ + :param client: CitrinationClient API object to pass in. + :param dataset_name: Name of dataset as a string. + :param dataset_local_fpath: Local data filepath as a string. + :param create_new_version: Boolean flag for whether or not to make a new version. + :param given_dataset_id: ID (int) if using existing dataset, defaults to None. + :return dataset_id: Integer ID of the dataset. + ''' + + if given_dataset_id is None: + dataset = client.data.create_dataset(dataset_name) + dataset_id = dataset.id + else: + dataset_id = given_dataset_id + if create_new_version: + client.data.create_dataset_version(dataset_id) + + # Guard against AWS timeout + start = time() + timeout = 240 + while time() - start < timeout: + sleep(1) + try: + print('Uploading data...') + client.data.upload(dataset_id, dataset_local_fpath) + break + except: + if time() - start >= timeout: + raise RuntimeError("Possible AWS timeout, try re-running.") + continue + + _wait_on_ingest(client, dataset_id, wait_time=15, print_output=True) + assert (client.data.matched_file_count(dataset_id) >= 1), "Upload failed." + return dataset_id + + +def build_view_and_get_id(client, dataset_id, view_name, input_keys, output_keys, + ignore_keys = [], view_desc = '', wait_time = 2, + print_output = False, model_type = 'default'): + ''' + Builds a new data view and returns the view ID. + + :param client: CitrinationClient object. + :param dataset_id: Integer ID of the dataset to build data view from. + :param view_name: Name of the new view as a string. + :param input_keys: Input key names as a list of strings. + :param output_keys: Output key names as a list of strings. + :param ignore_keys: Ignore (dummy) key names in this list of strings. + :param view_desc: String description for the view, defaults to ''. + :param wait_time: Wait time in seconds (int) before polling API. + :param print_output: Boolean flag for whether or not to print outputs. + :param model_type: 'default' or 'linear' ML model. 
+ :return dv_id: Integer ID of the data view. + ''' + + dv_builder = DataViewBuilder() + dv_builder.dataset_ids([str(dataset_id)]) + dv_builder.model_type(model_type) + + for key_name in input_keys: + if 'formula' in key_name: + desc_x = InorganicDescriptor(key=key_name, + threshold=1) + dv_builder.add_descriptor(descriptor=desc_x, + role='input') + else: + desc_x = RealDescriptor(key=key_name, + lower_bound=-9999.0, + upper_bound=9999.0) + dv_builder.add_descriptor(desc_x, role='input') + + + for key_name in output_keys: + desc_y = RealDescriptor(key=key_name, + lower_bound=-9999.0, + upper_bound=9999.0) + dv_builder.add_descriptor(desc_y, role='output') + + for key_name in ignore_keys: + desc_i = RealDescriptor(key=key_name, + lower_bound=-9999.0, + upper_bound=9999.0) + dv_builder.add_descriptor(desc_i, role='ignore') + + dv_config = dv_builder.build() + + _wait_on_ingest(client, dataset_id, wait_time, print_output) + + dv_id = client.data_views.create( + configuration=dv_config, + name=view_name, + description=view_desc) + + return dv_id + + +def _wait_on_ingest(client, dataset_id, wait_time, print_output = True): + ''' + Utility function to check for data ingest completion. + + :param client: CitrinationClient API object. + :param dataset_id: Integer ID for the dataset to check. + :param wait_time: Wait time in seconds (int) before polling API again. + :param print_output: Boolean flag for whether to display status messages. + :return: None. + ''' + + sleep(wait_time) + while (client.data.get_ingest_status(dataset_id) != "Finished"): + if print_output: + print("Waiting for data ingest to complete...") + sleep(wait_time) + sleep(2) + + +def _wait_on_data_view(client, view_id, wait_time, print_output = True): + ''' + Utility function to check for data view creation completion. + + :param client: CitrinationClient API object. + :param view_id: Integer ID for the data view to check. + :param wait_time: Wait time in seconds (int) before polling API again. 
+ :param print_output: Boolean flag for whether to display status messages. + :return: None. + ''' + + sleep(wait_time) + while True: + sleep(wait_time) + design_status = client.data_views.get_data_view_service_status(view_id) + if (design_status.experimental_design.ready and + design_status.predict.event.normalized_progress == 1.0): + if print_output: + print("Design ready.") + break + else: + print("Waiting for design services...") + sleep(2) + + +def candidates_to_df(candidates): + ''' + This function turns design candidates into a DataFrame. + + :param candidates: A list of materials candidates output from design runs. + :return df_cand: A pandas DataFrame of the candidates. + ''' + + df_cand = pd.DataFrame(candidates) + df_cand[list(df_cand['descriptor_values'].iloc[0].keys())] = \ + df_cand['descriptor_values'].apply(pd.Series) + df_cand = df_cand.drop(['descriptor_values', 'constraint_likelihoods'], axis=1) + for col in df_cand.columns: + try: + df_cand[col] = df_cand[col].astype(float) + except: + pass + return df_cand + + +def query_results_to_df(query_result): + ''' + This function puts query results from the SearchClient into a pandas DataFrame. + + :param query_result: Query results from the SearchClient. + :return df_query: A pandas DataFrame with data from the query results. + ''' + try: + query_result_list = query_result._hits + except: + query_result_list = query_result + + result_list = [{p._name:p._scalars[0]._value for p in h._system.properties} + for h in query_result_list] + formula_list = [{c._element:c._actual_weight_percent._value for c in h._system._composition} + for h in query_result_list] + full_results_dict = [{**d1, **d2} for d1,d2 in zip(formula_list, result_list)] + df_query = pd.DataFrame(full_results_dict).astype(float) + return df_query + + +class SL_run: + ''' + This class wraps the various steps of sequential learning to facilitate + running multiple SL iterations. 
+    '''
+    def __init__(self, client, view_id, dataset_id, orig_dataset_id,
+                 all_dataset_cols, target, score_type,
+                 design_effort = 25, wait_time = 10,
+                 sampler = 'Default', print_output = True):
+        '''
+        Constructor.
+
+        :param client: CitrinationClient object.
+        :param view_id: Integer ID of data view.
+        :param dataset_id: Integer ID of dataset.
+        :param orig_dataset_id: Integer ID of original dataset with all
+            measurements filled in.
+        :param all_dataset_cols: Full list of string column names expected for
+            measurements.
+        :param target: A list of strings for the target property and optimization
+            objective ('Min' or 'Max').
+        :param score_type: String for candidate selection strategy. 'MLI' or 'MEI'
+        :param wait_time: Wait time in seconds (int) before polling API.
+        :param sampler: What type of sampling to use for design.
+            ['Default' or 'This view']
+        :param print_output: Boolean flag for whether or not to print outputs.
+        :return: An SL_run object.
+        '''
+
+        # Attributes from arguments
+        self.client = client
+        self.view_id = view_id
+        self.dataset_id = dataset_id
+        self.orig_dataset_id = orig_dataset_id
+        self.target = target
+        self.score_type = score_type
+        self.wait_time = wait_time
+        self.sampler = sampler
+        self.print_output = print_output
+        self.y_col = self.target[0].replace('Property ','')
+
+        # Empty attributes to add onto later
+        self.curr_iter = 0
+        self.curr_design_id = None
+        self.measurements = pd.DataFrame(columns=all_dataset_cols)
+        self.candidates = pd.DataFrame()
+
+        # Dataset should have training data (iteration "0")
+        # Stick this in the measurements DataFrame as iter 0
+        query_dataset = \
+            PifSystemReturningQuery(size=9999,
+                query=DataQuery(
+                    dataset=DatasetQuery(
+                        id=Filter(equal=str(self.dataset_id))),
+                    system=PifSystemQuery(
+                        properties=PropertyQuery(
+                            name=FieldQuery(filter=Filter(equal=self.y_col)),
+                            value=FieldQuery(filter=Filter(exists=True))
+                        )
+                    )
+                )
+            )
+        query_result = self.client.search.pif_search(query_dataset)
+        training_measurements = query_results_to_df(query_result)
+        training_measurements['iter'] = 0
+        self.measurements = pd.concat([self.measurements, training_measurements],
+                                      sort=True)
+        self.measurements['iter'] = self.measurements['iter'].astype(int)
+        self.last_op = 'measure'
+
+
+    def _wait_on_ingest(self):
+        '''
+        Utility function to check for data ingest completion.
+
+        :return: None.
+        '''
+
+        sleep(self.wait_time)
+        while (self.client.data.get_ingest_status(self.dataset_id) != "Finished"):
+            if self.print_output:
+                print("Waiting for data ingest to complete...")
+            sleep(self.wait_time)
+        print("Ingest finished.")
+        sleep(2)
+
+
+    def _wait_on_data_view(self):
+        '''
+        Utility function to check for data view creation completion.
+
+        :return: None.
+        '''
+
+        sleep(self.wait_time)
+        while True:
+            design_status = self.client.data_views.get_data_view_service_status(self.view_id)
+            if (design_status.experimental_design.ready and
+                design_status.predict.event.normalized_progress == 1.0):
+                if self.print_output:
+                    print("Design ready.")
+                break
+            else:
+                print("Waiting for design services...")
+        sleep(2)
+
+
+    def _wait_on_design_run(self, design_id):
+        '''
+        Utility function to check for design run completion.
+
+        :param design_id: Integer ID of the submitted design run.
+        :return: None.
+        '''
+
+        sleep(self.wait_time)
+        while True:
+            status = self.client.models.get_design_run_status(self.view_id, design_id).status
+            if self.print_output:
+                print("Design run status: {}.".format(status))
+            if status != "Finished":
+                sleep(self.wait_time)
+            else:
+                break
+        sleep(2)
+
+
+    def _get_valid_candidates(self, num_candidates = 1, num_seeds = 20,
+                              design_effort = 5):
+        '''
+        Wrapper function for design runs that ensures we get back valid
+        candidates with non-zero uncertainty.
+
+        :param num_candidates: The number of candidates to return.
+        :param num_seeds: The number of candidates to request from Design.
+        :param design_effort: The effort for Design runs.
+        :return valid_candidates: A list of material candidates from Design.
+        '''
+
+        valid_candidates = []
+        while len(valid_candidates)1e-6]
+            valid_candidates.extend(candidates_filtered)
+            [vc.update(design_id=design_id) for vc in valid_candidates]
+
+        if self.print_output:
+            print("{} candidates obtained.".format(num_candidates))
+
+        return valid_candidates
+
+
+    def design(self, num_candidates = 1, design_effort = 5):
+        '''
+        Submit a design run, get candidates, and add them to self.candidates.
+
+        :param num_candidates: Integer number of candidates to return for this run.
+        :param design_effort: Effort as an integer from 1 to 30.
+        :return: None.
+        '''
+
+        if self.last_op=='design':
+            raise Exception('Design was already run for this iteration')
+        self.curr_iter += 1
+
+        # Ensure ingest and view creation are complete
+        self._wait_on_ingest()
+        self._wait_on_data_view()
+
+        # Get candidates (wrapper for submit_design_run)
+        candidates = self._get_valid_candidates(num_candidates=num_candidates,
+                                                design_effort=design_effort)
+
+        # Candidate DataFrame
+        df_cand = candidates_to_df(candidates)
+        df_cand['iter'] = self.curr_iter
+        df_cand['design_effort'] = design_effort
+        df_cand = df_cand.sort_values('citrine_score', ascending=False).iloc[:num_candidates]
+        self.candidates = pd.concat([self.candidates, df_cand.copy()])
+
+        if self.print_output:
+            best_val_w_uncertainty = \
+                df_cand[[self.target[0],
+                         'Uncertainty in '+self.target[0]]].iloc[0].values
+            print("SL iter #{}, best predicted (value, uncertainty) = {}".format(
+                self.curr_iter, best_val_w_uncertainty))
+
+        self.last_op = 'design'
+
+
+    def measure(self):
+        '''
+        Measure the true_function for the most recent batch of candidates
+        and add them to self.measurements and the online dataset.
+
+        :return: None.
+        '''
+
+        if self.last_op == 'measure':
+            raise Exception('Candidates were already measured for this iteration.')
+        if len(self.candidates) == 0:
+            raise Exception('No candidates to measure, please run design.')
+
+        # Get candidates DataFrame for current iteration
+        curr_candidates_df = self.candidates.query("iter==@self.curr_iter")
+        search_results = []
+
+        # Search original dataset for the chosen samples
+        if self.print_output:
+            print("Measuring new candidates...")
+        for ii,cand in curr_candidates_df.iterrows():
+            cand_prop_query = [PropertyQuery(name=FieldQuery(filter=Filter(equal='Sample Number')),
+                                             logic='MUST',
+                                             value=FieldQuery(filter=
+                                                 Filter(min=cand['Property Sample Number']-0.1,
+                                                        max=cand['Property Sample Number']+0.1)))]
+            system_query_cand = PifSystemReturningQuery(
+                size=9999,
+                query=DataQuery(
+                    dataset=DatasetQuery(
+                        id=Filter(
+                            equal=self.orig_dataset_id)),
+                    system=PifSystemQuery(
+                        properties=cand_prop_query,
+                    )
+                )
+            )
+            query_result_cand = self.client.search.pif_search(system_query_cand)
+            search_results.append(query_result_cand.hits[0])
+
+
+        # Store new measurements
+        curr_measurements = query_results_to_df(search_results)
+        curr_measurements['iter'] = self.curr_iter
+        self.measurements = pd.concat([self.measurements,
+                                       curr_measurements],
+                                      sort=True)
+
+
+        # Write measurements to dataset
+        self.curr_design_id = curr_candidates_df["design_id"].iloc[0]
+        temp_dataset_fpath = os.path.join('temp',
+                                          "design-{}.json".format(self.curr_design_id))
+
+        with open(temp_dataset_fpath, "w") as f:
+            f.write(pif.dumps([h._system for h in search_results],
+                              indent=4))
+
+
+        # Upload results and re-train model
+        upload_data_and_get_id(
+            self.client,
+            "", # No name needed for updating a dataset
+            temp_dataset_fpath,
+            given_dataset_id=self.dataset_id
+        )
+        self._wait_on_ingest()
+
+        if self.print_output:
+            print("Dataset updated: {} candidates added.".format(len(search_results)))
+            print("New dataset contains {} PIFs.".format(len(self.measurements)))
+            print('Retraining model...')
+
+        # Re-train the model with the new data
+        self.client.data_views.models.retrain(self.view_id)
+        self._wait_on_data_view()
+
+        self.last_op = 'measure'
+
+
+    def plot_sl_results(self, figsize = (8,7)):
+        '''
+        Helper function to plot the SL results for each iteration.
+
+        :param figsize: How large to make each rendered plot.
+        :return fig: A matplotlib figure object.
+        '''
+
+        # Get best point in initial training set
+        if self.target[1] == 'Min':
+            init_best = self.measurements.query("iter==0")[self.y_col].min()
+        else:
+            init_best = self.measurements.query("iter==0")[self.y_col].max()
+
+        df_meas = self.measurements.reset_index().copy()
+        df_pred = self.candidates.reset_index().copy()
+
+        # Data aggregation
+        if self.target[1]=='Min':
+            df_meas['best'] = df_meas[self.y_col].cummin()
+            df_best_cum = \
+                df_meas.loc[df_meas.groupby('iter')['best'].idxmin()]
+            df_best_meas = \
+                df_meas.loc[df_meas.groupby('iter')[self.y_col].idxmin()]
+            df_best_pred = \
+                df_pred.loc[df_pred.groupby('iter')[self.target[0]].idxmin()]
+        else:
+            df_meas['best'] = df_meas[self.y_col].cummax()
+            df_best_cum = \
+                df_meas.loc[df_meas.groupby('iter')['best'].idxmax()]
+            df_best_meas = \
+                df_meas.loc[df_meas.groupby('iter')[self.y_col].idxmax()]
+            df_best_pred = \
+                df_pred.loc[df_pred.groupby('iter')[self.target[0]].idxmax()]
+
+        # Create Figure
+        fig, ax = plt.subplots(1, 1, figsize=figsize, tight_layout=True)
+
+        # Cumulative Best Measurements
+        plt.sca(ax)
+        plt.plot('iter',
+                 'best',
+                 data=df_best_cum,
+                 color='xkcd:steel blue',
+                 linewidth=3,
+                 linestyle='-',
+                 label="Best Measured Candidate (Cumulative)")
+
+        # Best candidate in training set
+        plt.plot(np.arange(0, len(df_best_pred)+1),
+                 [init_best] * (len(df_best_pred)+1),
+                 color='xkcd:black',
+                 linestyle='--',
+                 linewidth=3,
+                 label="Best Initial Point",
+                 alpha=0.7)
+
+        # Candidate Predictions with Error Bars per iteration
+        ax.errorbar(x='iter',
+                    y=self.target[0],
+                    fmt='o',
+                    yerr="Uncertainty in "+self.target[0],
+                    data=df_best_pred,
+                    linewidth=3,
+                    color="xkcd:orange",
+                    label="Candidate Predictions w/ Uncertainty")
+
+        # Candidate Measurements per iteration
+        plt.plot('iter',
+                 self.y_col,
+                 's',
+                 data=df_best_meas,
+                 color='xkcd:maroon',
+                 label="Candidate Measurements")
+
+        plt.xlabel("SL iteration #")
+        plt.xticks(df_best_meas['iter'])
+        plt.ylabel("Fatigue Strength (MPa)")
+        plt.title("Optimizing using MLI")
+        plt.legend(loc='best')
+        plt.grid(b=False, axis='x')
+        plt.show()
+
+        return fig
diff --git a/citrination_api_examples/tutorial_sequence/1_ImportVASP.ipynb b/citrination_api_examples/tutorial_sequence/1_ImportVASP.ipynb
index 356d927..feb3ddf 100644
--- a/citrination_api_examples/tutorial_sequence/1_ImportVASP.ipynb
+++ b/citrination_api_examples/tutorial_sequence/1_ImportVASP.ipynb
@@ -177,7 +177,8 @@
    "outputs": [],
    "source": [
     "site = 'https://citrination.com' # public site\n",
-    "client = CitrinationClient(api_key=os.environ['CITRINATION_API_KEY'], site=site)"
+    "client = CitrinationClient(api_key=os.environ['CITRINATION_API_KEY'], \n",
+    " site=site)"
    ]
   },
   {
@@ -224,7 +225,9 @@
    "source": [
     "# Comment this cell if you have an ID from a dataset you created via the website\n",
     "dataset_name = \"Tutorial dataset \" + str(uuid.uuid4())[:6]\n",
-    "dataset = client.data.create_dataset(name=dataset_name, description=\"Dataset for VASP tutorial.\", public=False)\n",
+    "dataset = client.data.create_dataset(name=dataset_name, \n",
+    " description=\"Dataset for VASP tutorial.\", \n",
+    " public=False)\n",
     "dataset_id = dataset.id\n",
     "print('Dataset created! {}/datasets/{}'.format(site, dataset_id))"
    ]
diff --git a/citrination_api_examples/tutorial_sequence/2_WorkingWithPIFs.ipynb b/citrination_api_examples/tutorial_sequence/2_WorkingWithPIFs.ipynb
index a53abd5..7783f7a 100644
--- a/citrination_api_examples/tutorial_sequence/2_WorkingWithPIFs.ipynb
+++ b/citrination_api_examples/tutorial_sequence/2_WorkingWithPIFs.ipynb
@@ -342,7 +342,7 @@
    },
    "outputs": [],
    "source": [
-    "plt.rcParams.update({'font.size': 18, 'figure.figsize':(8, 6), 'lines.markersize':100})\n",
+    "plt.rcParams.update({'font.size': 18, 'figure.figsize':(8, 6), 'lines.markersize':10})\n",
     "plt.scatter(*zip(*points))\n",
     "plt.xlim(0, 1)\n",
     "plt.xlabel(\"Cu fraction\")\n",
diff --git a/citrination_api_examples/tutorial_sequence/3_IntroQueries.ipynb b/citrination_api_examples/tutorial_sequence/3_IntroQueries.ipynb
index c71d668..0998467 100644
--- a/citrination_api_examples/tutorial_sequence/3_IntroQueries.ipynb
+++ b/citrination_api_examples/tutorial_sequence/3_IntroQueries.ipynb
@@ -526,7 +526,8 @@
     " filter=ChemicalFilter(equal='AlxCuy')),\n",
     " properties=PropertyQuery(\n",
     " name=FieldQuery(\n",
-    " filter=[Filter(equal=\"Formation energy\"), Filter(equal=\"Enthalpy of Formation\")]),\n",
+    " filter=[Filter(equal=\"Formation energy\"), \n",
+    " Filter(equal=\"Enthalpy of Formation\")]),\n",
     " value=FieldQuery(\n",
     " extract_as=\"formation_enthalpy\")))))\n",
     "\n",
diff --git a/citrination_api_examples/tutorial_sequence/4_MLonCitrination.ipynb b/citrination_api_examples/tutorial_sequence/4_MLonCitrination.ipynb
index fd3dd09..fe3d946 100644
--- a/citrination_api_examples/tutorial_sequence/4_MLonCitrination.ipynb
+++ b/citrination_api_examples/tutorial_sequence/4_MLonCitrination.ipynb
@@ -174,7 +174,7 @@
      "output_type": "stream",
      "text": [
       "We found 500 records.\n",
-      "[{'density': ['2.849276907145639'], 'formula': 'LiFeSiO4'}, {'density': ['2.6366293129465217'], 'formula': 'K3NiO2'}]\n"
+      "[{'density': ['4.032147167706144'], 'formula': 'Li2ZrO3'}, {'density': ['5.067462110078542'], 'formula': 'KSr2Cd2Sb3'}]\n"
      ]
     }
    ],
@@ -199,7 +199,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {
     "slideshow": {
      "slide_type": "fragment"
@@ -270,7 +270,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "metadata": {
     "slideshow": {
      "slide_type": "fragment"
@@ -310,13 +310,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {
     "slideshow": {
      "slide_type": "fragment"
     }
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "We found 5000 records.\n",
+      "[{'density': ['4.032147167706144'], 'formula': 'Li2ZrO3'}, {'density': ['5.067462110078542'], 'formula': 'KSr2Cd2Sb3'}]\n"
+     ]
+    }
+   ],
    "source": [
     "dataset_id = 150675\n",
     "query_size = 5000\n",
@@ -352,7 +361,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 8,
    "metadata": {
     "slideshow": {
      "slide_type": "fragment"
@@ -363,7 +372,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "We predict the density of AlCu to be 8.1988 +/- 1.2069.\n"
+      "We predict the density of AlCu to be 8.0285 +/- 1.2208.\n"
      ]
     }
    ],
@@ -427,14 +436,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Highest density compound is CuAu3Al96 with rho = 14.1870 +/- 0.9619.\n"
+      "Highest density compound is PtRe2RhAl96 with rho = 14.3593 +/- 1.3423.\n"
      ]
     }
    ],
@@ -451,6 +460,14 @@
     " print(\"Highest density compound is {0} with rho = {1:.4f} +/- {2:.4f}.\".format(\n",
     " best['formula'], best['value'], best['loss']))"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Conclusion\n",
+    "This concludes the tutorial sequence on working with DFT data."
+   ]
   }
  ],
  "metadata": {