Using Create ML on iOS to auto-complete forms

Saved you a click

Nothing frustrates me more than filling out a long form on my phone. The keyboard is too small and questions always feel redundant. Streamlining UX with helpful suggestions and auto-completion is a great way to reduce friction, increase conversion, and make your app feel more fluid.

Amazon realized this early, resulting in their infamous “1-Click” patent. Google reports that over 12% of all emails sent are composed via their single-click Smart Reply option. In this tutorial, we’ll use Apple’s Create ML tool to train our own machine learning model to anticipate a user’s intentions, automatically complete a form, and save them a click in the process.

I’ll be using Reddit’s link submission flow as an example, but the same process can be applied to forms in your own app, provided you have some data on past responses by users.

Jump straight to the code on GitHub.

Scoping out the problem

Submitting a link to Reddit currently involves three steps: tap the icon to create a post and select the post type (link, image, etc.), search for and select the subreddit you’d like to post to, and type out the title for your post. It’s not the most onerous UX in the world, but it’s not perfect. We can do better with that middle step, selecting a subreddit.

Subreddits are topical, and many have their own unique conventions and syntax for post titles. In theory, we should be able to guess where a post belongs based on its title, thus saving users a few clicks. In our improved UX, a user will simply enter a post title, and we’ll suggest a subreddit for their post and fill out the form automatically.

Now that we have our desired UX in mind, we need to frame this problem like a machine learning engineer. We want to predict which subreddit a post belongs to based on its title.

Reddit itself contains a whole bunch of posts grouped by their proper subreddit labels, so we can frame this as a classification problem. We’ll train a model to identify which class (subreddit) an example (post title) belongs to based on observations where the proper class labels are known (historical reddit posts).

This is a fairly common task in natural language processing, and it turns out Apple’s new Create ML tool provides exactly what we’re looking for with their MLTextClassifier model.

What you’ll need

  • Xcode 10.0 or later. This version of Xcode includes Create ML and Swift Playgrounds.
  • A little bit of Python. While more and more data science tools are coming to Swift, it’s still not quite an end-to-end language for us. We’ll use some basic Python code and the popular Pandas library to gather and clean data. If you’re not familiar with Python, don’t worry, this won’t be too complicated.
  • About 30 minutes.

Getting some data

In order to train our model we need labeled training data. That means a bunch of post titles and the subreddits they belong to. Thankfully, Reddit’s API is fantastic and will supply us with all the data we need.

Our choice of training data will impact how well our model performs and introduce some biases. Reddit provides multiple ways of sorting posts in a subreddit. We can ask the API for new posts, popular posts within a time period, controversial ones, etc. What’s the best data to scrape? Using popular posts from this week might bias our model to more recent events. Using only new submissions might include posts that don’t belong and will never be upvoted.

There are no perfect answers, but top posts within the past year should give us a reasonable cross section of what belongs in each subreddit. Here is what that request looks like for the r/machinelearning subreddit:

https://www.reddit.com/r/machinelearning/top.json?t=year

By default, the API returns the top 25 posts, but we can increase this to 100 with the limit=100 URL parameter. Each JSON response includes an after field that’s used to fetch additional pages of results. For example, this URL gets the next 100 posts after the post with the identifier t3_81050c.

https://www.reddit.com/r/machinelearning/top.json?t=year&limit=100&after=t3_81050c

This gist contains code that takes a list of subreddits and collects the top 1000 post titles from each over the past year. I picked 28 subreddits made up primarily of the most popular communities.

import requests
import json

def get_listing(subreddit, sort='top', after=None, limit=25):
    # Set the user agent on the request so we don't get rate limited
    headers = {'User-agent': 'Subreddit Title Bot'}
    url_fmt = 'https://www.reddit.com/r/{subreddit}/{sort}.json'
    url = url_fmt.format(subreddit=subreddit, sort=sort)
    params = {'after': after, 'limit': limit, 't': 'year'}
    response = requests.get(url, params=params, headers=headers)
    print(response.url)
    return response.json()

def parse_listing(data):
    items = []
    for child in data['data']['children']:
        post = child['data']
        item = (post['subreddit'], post['title'])
        items.append(item)
    after = data['data']['after']
    return items, after

def get_n_posts(num_posts, subreddit, sort='top', limit=100):
    posts = []
    after = None
    while len(posts) < num_posts:
        data = get_listing(subreddit, sort=sort, after=after, limit=limit)
        items, after = parse_listing(data)
        # Only keep as many posts as we need
        keep = min(num_posts - len(posts), len(items))
        posts.extend(items[:keep])
        if not after:
            break
    return posts

# Top 25 subreddits plus a few others
subreddits = [
    'machinelearning',
    'androiddev',
    'iosprogramming',
    'learnmachinelearning',
    'datascience',
    'funny',
    'AskReddit',
    'todayilearned',
    'science',
    'worldnews',
    'pics',
    'IAmA',
    'gaming',
    'videos',
    'movies',
    'aww',
    'Music',
    'blog',
    'gifs',
    'explainlikeimfive',
    'askscience',
    'EarthPorn',
    'books',
    'television',
    'mildlyinteresting',
    'LifeProTips',
    'Showerthoughts',
    'space'
]

# Get the top 1000 post titles from each subreddit
max_posts = 1000
posts = []
for subreddit in subreddits:
    posts.extend(get_n_posts(max_posts, subreddit, sort='top'))

The last step before training is cleaning and preprocessing. Many subreddits have their own posting conventions that use patterns like [tag] to flag certain topics. We want to make sure the model is making suggestions based on the content of titles, not these syntactic markers, so we’ll remove these tags along with all other punctuation. This requires a bit of Pandas and regex fu. Finally, we’ll save the data in the JSON format expected by Create ML.

import pandas
import re
import json

# Use pandas and regex to clean up the post titles.
df = pandas.DataFrame(posts, columns=['subreddit', 'title'])
# Remove any [tag] markers in a post title
df.title = df.title.apply(lambda x: re.sub(r'\[.*?\]', '', x))
# Remove all other punctuation except spaces
df.title = df.title.apply(lambda x: re.sub(r'\W(?<![ ])', '', x))

# Save the data in the exact format CreateML expects
output = []
for idx, row in df.iterrows():
    output.append({'text': row.title, 'label': row.subreddit})

filename = 'PATH/TO/data.json'
with open(filename, 'w') as fid:
    fid.write(json.dumps(output))
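
Before handing this file to Create ML, it’s worth a quick sanity check that no single subreddit dominates the training set. A minimal check using the df frame from above:

# Each subreddit should contribute roughly max_posts titles; a heavily
# imbalanced class count would bias the classifier toward the big ones.
print(df.subreddit.value_counts())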

Training our model

It’s time to leave Python behind and switch over to Swift, where we’ll use Apple’s Create ML to train a mobile-ready subreddit suggester with our data. Create ML provides high-level APIs that make it easy to train machine learning models for common tasks like text classification. While there isn’t much flexibility (yet), in about 30 lines of code, we can train a decent model.

To get started, create a new Swift Playground in Xcode. Make sure you select the macOS option for the Playground, since Create ML is only available on macOS.

Here’s a gist with all the code you’ll need to train and export your model. I’ve included some instructions and sample output in the comments. You can learn more about MLTextClassifier and other Create ML models here.

import CreateML
import Foundation

// Load our data into an MLDataTable object.
let dataFilename = "PATH/TO/data.json"
let data = try MLDataTable(contentsOf: URL(fileURLWithPath: dataFilename))
print(data.description)
/*
Columns:
    label	string
    text	string
Rows: 26985
Data:
+----------------+----------------+
| label          | text           |
+----------------+----------------+
| MachineLearn...|  Realtime mu...|
| MachineLearn...|  Keras Imple...|
| MachineLearn...|  Generative ...|
| MachineLearn...|  Landing the...|
| MachineLearn...|  If you had ...|
| MachineLearn...|  Realtime Ma...|
| MachineLearn...|  Dedicated t...|
| MachineLearn...|  StarGAN Uni...|
| MachineLearn...|  Deep Image ...|
| MachineLearn...|  Overview of...|
+----------------+----------------+
[26985 rows x 2 columns]
*/

// Split the dataset into two parts, training and testing.
// We make sure to hold out some data for testing so we can
// identify overfitting.
let (trainingData, testingData) = data.randomSplit(by: 0.8, seed: 5)

// Train the model itself.
let subredditClassifier = try MLTextClassifier(trainingData: trainingData,
                                               textColumn: "text",
                                               labelColumn: "label")

// Training accuracy as a percentage
let trainingAccuracy = (1.0 - subredditClassifier.trainingMetrics.classificationError) * 100
print("Training Accuracy: (trainingAccuracy)")
/*
Training Accuracy: 99.37782530501143
This is really high and suggests the model is overfitting.
*/

// Evaluate the model on the testing data we kept secret from the model.
let evaluationMetrics = subredditClassifier.evaluation(on: testingData)

// Evaluation accuracy as a percentage
let evaluationAccuracy = (1.0 - evaluationMetrics.classificationError) * 100
print("Evaluation Accuracy: (evaluationAccuracy)")
/*
Evaluation Accuracy: 63.9894419306184
Much lower than our training accuracy, but not bad considering there are 28
potential subreddits to choose from.
*/

// Test the model on a single example
let title = "Saw this good boy at the park today with TensorFlow."
let predictedSubreddit = try subredditClassifier.prediction(from: title)
print("Suggested subreddit: r/(predictedSubreddit)")
/*
Suggested subreddit: r/aww
*/

// Add some metadata
let metadata = MLModelMetadata(author: "Jameson Toole",
                               shortDescription: "Predict which subreddit a post should go in based on a title.",
                               version: "1.0")
// Save the model
try subredditClassifier.write(to: URL(fileURLWithPath: "PATH/TO/subredditClassifier.mlmodel"),
                              metadata: metadata)

One quick aside for those just getting started with machine learning. Notice that we split our dataset into two parts: 80% of our data goes to a training set and 20% to a testing set. It’s extremely important that the model never sees the testing data during training. This way, we can guard against overfitting and make sure our model will generalize to new data in the future.

We can see clear evidence of overfitting by comparing accuracy on the training data (~99%) to the testing data (~64%). To improve things in the future, we should probably gather more data, but for now, it’s a decent start.

A single accuracy number alone, though, shouldn’t make us feel comfortable enough to release the model. More testing is needed. To get a better understanding of what the experience is going to be like for users, let’s go back to Python and dig into this accuracy a bit more.

Testing it out

Imagine we rolled out this autosuggestion feature today, but in a parallel universe. In this parallel universe, the exact same set of users submits the exact same set of posts, but they submit them to the subreddit suggested to them by our model. How many of the posts in this parallel universe end up in the same subreddit as their counterparts in the real world? This is a good measure of how many users we would have saved a click.

Using the same API scraper as before, grab another 100 posts from each subreddit, but this time, they are the 100 most recent submissions. Unless a top 1000 post was submitted within the last 24 hours (unlikely), there shouldn’t be any overlap with the training data. We’ll predict the subreddit each should be submitted to and see how accurate our model is.

import coremltools

# Scrape the 100 newest posts from each subreddit
max_posts = 100
posts = []
for subreddit in subreddits:
    posts.extend(get_n_posts(max_posts, subreddit, sort='new'))

# Apply the same preprocessing to the data
new_df = pandas.DataFrame(posts, columns=['subreddit', 'title'])
new_df.title = new_df.title.apply(lambda x: re.sub(r'\[.*?\]', '', x))
new_df.title = new_df.title.apply(lambda x: re.sub(r'\W(?<![ ])', '', x))

# Load the mlmodel with coremltools for Python. Note you need to be
# on macOS for the predict function to work.
mlmodel = coremltools.models.MLModel('PATH/TO/subredditClassifier.mlmodel')

# Predict a subreddit for each title.
new_df['predicted'] = new_df.title.apply(lambda x: mlmodel.predict({'text': x})['label'])

# Mark the model correct if the predicted matches the actual subreddit
new_df['correct'] = new_df.predicted == new_df.subreddit

# Compute the fraction correct.
new_df.correct.sum() / new_df.shape[0]
# Output: 0.55
# We got 55% correct.

It looks like we correctly suggested the subreddit for 55% of posts! Because subreddits receive posts at different rates, it’s tough to say exactly what fraction of all posts in a day we’d save users a click on, but it’s probably pretty high. Digging even further, let’s take a look at a confusion matrix to figure out where our model is doing well and where it’s going wrong.

Each row denotes the actual subreddit a post ended up in, while columns are subreddits suggested by our model. The color of the square tells us what fraction of posts in row X were predicted to end up in column Y.
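
The chart itself isn’t reproduced here, but the same normalized matrix is straightforward to compute with pandas. A minimal sketch using the new_df frame from the previous snippet:

# Rows are the actual subreddits, columns the predicted ones; normalizing
# by row gives the fraction of each subreddit's posts sent to each prediction.
cm = pandas.crosstab(new_df.subreddit, new_df.predicted, normalize='index')

# The diagonal of cm is per-subreddit accuracy; sorting it surfaces the
# communities the model handles worst and best.
print(new_df.groupby('subreddit').correct.mean().sort_values())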

Based on our chart, the model does really well with subreddits that have more uniform syntactic structures (e.g. AskReddit titles contain question words) as well as subreddits with clue words (e.g. TIL or IAmA). It gets confused by posts destined for generic places like r/funny or r/videos and has a tough time distinguishing between subreddits with similar content, like r/MachineLearning and r/learnmachinelearning.

Finally, I came up with some titles of my own to do some anecdotal testing. I tried to write these with a specific subreddit in mind, based on my knowledge of the site. Overall, the model gave me great suggestions.

The only mistake was the last example, where I think it saw the word Swift and assumed I was talking about the programming language instead of the pop star.

# Test the model on some titles I wrote with specific subreddits in mind.
sample_titles = [
    'Saw this good boy at the park today.',
    'Latest NIPS submission from OpenAI',
    'TIL you can use Core ML to suggest subreddits to users',
    'I made a tutorial using CreateML AMA',
    'Westworld and Game of Thrones coming to netflix',
    'We park in driveways, but drive on parkways',
    'From the top of Mt. Fuji, Japan',
    "What's the first thing you do every morning?",
    'Taylor Swift announces additional Reputation tour dates'
]

for title in sample_titles:
    result = mlmodel.predict({'text': title})
    print(title, ' | ', result['label'])

# Saw this good boy at the park today.  |  aww
# Latest NIPS submission from OpenAI  |  MachineLearning
# TIL you can use Core ML to suggest subreddits to users  |  todayilearned
# I made a tutorial using CreateML AMA  |  IAmA
# Westworld and Game of Thrones coming to netflix  |  television
# We park in driveways, but drive on parkways  |  Showerthoughts
# From the top of Mt. Fuji, Japan  |  EarthPorn
# What's the first thing you do every morning?  |  AskReddit
# Taylor Swift announces additional Reputation tour dates  |  iOSProgramming

Based on these tests, we can be confident that our model is going to save a large fraction of users a click when submitting posts. We also have some good ideas on where we can make improvements in future versions, like looking for ways to improve performance in generic subreddits like r/funny.

It’d also be nice if the model could output a confidence score telling us how sure it was of its suggestion. We could use that to decide whether a suggestion is worth showing to the user at all. That’s a limitation of Create ML today, but it’s something we could add ourselves with a little more custom work, as in the sketch below.
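
As an illustration, here’s a minimal sketch of what that custom work might look like, using a hypothetical scikit-learn baseline (separate from the Create ML model, and not part of the original pipeline) whose predict_proba exposes a confidence we can threshold:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Hypothetical stand-in for the Create ML model: a TF-IDF + logistic
# regression classifier trained on the same (title, subreddit) data.
clf = make_pipeline(TfidfVectorizer(), LogisticRegression(max_iter=1000))
clf.fit(df.title, df.subreddit)

# predict_proba gives a score for every subreddit; only surface the
# suggestion when the model is reasonably sure.
probs = clf.predict_proba(['Saw this good boy at the park today.'])[0]
best = probs.argmax()
if probs[best] > 0.5:  # the 0.5 threshold is an assumption to tune
    print('Suggest r/{}'.format(clf.classes_[best]))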

Adding the model to your app

Adding the model to an app is as simple as dragging the .mlmodel file into the Xcode project navigator and using the Swift classes Xcode generates to incorporate it into your UX. The following code shows how to call the model.

// Drag your subredditClassifier.mlmodel to the Xcode project navigator.
// Use the model with the following code.
import NaturalLanguage

let subredditPredictor = try NLModel(mlModel: subredditClassifier().model)
subredditPredictor.predictedLabel(for: "TIL you can use Core ML to suggest subreddits to users.")

Final thoughts

Saving users a click here and there can seem small, but these UX improvements can have big impacts on conversion and engagement. Thanks to tools like Apple’s Create ML, you don’t need to be an expert in machine learning to save your users some time. I hope more mobile developers feel empowered to make these features a part of their apps.

All of the code can be found in this GitHub repository.

Finally, as a machine learning engineer, it’s a fun exercise to try out tools made for a different developer audience. Swift is a promising language that I plan on learning more of. That said, it’s not quite ready for data science primetime. Here are a few items that are now on my wishlist:

  1. More flexible JSON tools. I initially tried to write my data scraper in Swift, not Python, but Swift’s JSONDecoder requires building a complete data model of the payload, which for the Reddit API would have been extremely complicated. There’s no option to just decode into a dictionary. In the end, I found myself back in Python land.
  2. Access to lower-level metrics. Create ML is a great high-level abstraction, but it goes a little bit too far. There aren’t enough tools to troubleshoot a model or test it to the point that I’d feel comfortable putting it into production. For example, we measured training, validation, and testing accuracy directly in Swift, but there’s no easy way to compute something like a confusion matrix without writing a lot of code yourself.
  3. More descriptive output. I wish the model offered by Create ML also provided a prediction confidence. If we knew when the model was unsure, we could build some better fallbacks into the UX.

Discuss this post on Hacker News and Reddit.

