Large Language Models (LLMs) are all the buzz right now. They are used for a variety of tasks, including text classification, question answering, and text generation. In this tutorial, we will show how to conformalize a transformer language model for text classification using [ConformalPrediction.jl](https://juliatrustworthyai.github.io/ConformalPrediction.jl/dev/).
👀 At a Glance
In particular, we are interested in the task of intent classification as illustrated in the sketch below. First, we feed a customer query into an LLM to generate embeddings. Next, we train a classifier to match these embeddings to possible intents. Of course, for this supervised learning problem we need training data consisting of inputs (queries) and outputs (labels indicating the true intent). Finally, we apply Conformal Prediction to quantify the predictive uncertainty of our classifier.
Conformal Prediction (CP) is a rapidly emerging methodology for Predictive Uncertainty Quantification. If you’re unfamiliar with CP, you may want to first check out my 3-part introductory series on the topic starting with this post.
🤗 HuggingFace
We will use the Banking77 dataset (Casanueva et al., 2020), which consists of 13,083 queries from 77 intents related to banking. On the model side, we will use a DistilRoBERTa model (a distilled version of RoBERTa; Liu et al., 2019) that has been fine-tuned on the Banking77 dataset.
The model can be loaded from HuggingFace (HF) straight into our running Julia session using the [Transformers.jl](https://github.com/chengchingwen/Transformers.jl/tree/master) package.
This package makes working with HF models remarkably easy in Julia. Kudos to the devs! 🙏
Below we load the tokenizer `tkr` and the model `mod`. The tokenizer converts the text into a sequence of integers (token IDs), which is then fed into the model. The model outputs a hidden state, which is passed to a classifier head to get the logits for each class. Finally, the logits are passed through a softmax function to obtain the corresponding predicted probabilities. Below we run a few queries through the model to see how it performs.
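```julia
using Transformers, Transformers.HuggingFace, Transformers.TextEncoders
using Flux   # for `softmax` and `onecold`

# Load model from HF 🤗:
tkr = hgf"mrm8488/distilroberta-finetuned-banking77:tokenizer"
mod = hgf"mrm8488/distilroberta-finetuned-banking77:ForSequenceClassification"

# Test model:
query = [
    "What is the base of the exchange rates?",
    "Why is my card not working?",
    "My Apple Pay is not working, what should I do?",
]
a = encode(tkr, query)        # tokenize the queries
b = mod.model(a)              # forward pass through the encoder
c = mod.cls(b.hidden_state)   # classifier head: logits for each of the 77 intents
d = softmax(c.logit)          # logits -> probabilities (one column per query)
# `labels` is assumed to be a vector of the 77 intent names, in the model's label order:
[labels[i] for i in Flux.onecold(d)]
```

```
3-element Vector{String}:
"exchange_rate"
"card_not_working"
"apple_pay_or_google_pay"
```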
🔁 MLJ Interface
Since ConformalPrediction.jl is interfaced to [MLJ.jl](https://alan-turing-institute.github.io/MLJ.jl/dev/), we need to define a wrapper model that conforms to the `MLJ` interface. To add the model for general use, we would probably go through [MLJFlux.jl](https://github.com/FluxML/MLJFlux.jl), but for this tutorial we will make our life easy and simply overload the `MLJBase.fit` and `MLJBase.predict` methods.
Since the model from HF is already pre-trained and we are not interested in further fine-tuning, `MLJBase.fit` does not train anything: it simply returns the model's classifier head together with the class labels. `MLJBase.predict` then takes this fitted result and the query and returns the predicted probabilities (as `UnivariateFinite` distributions). We also need to define the `MLJBase.target_scitype` and `MLJBase.predict_mode` methods. The former tells `MLJ` the scientific type of the target, and the latter can be used to retrieve the label with the highest predicted probability.
```julia
using MLJBase   # Probabilistic, UnivariateFinite, Finite, levels, mode
using Flux: softmax
using Transformers, Transformers.HuggingFace, Transformers.TextEncoders

# Wrapper holding the pre-trained tokenizer and model:
struct IntentClassifier <: MLJBase.Probabilistic
    tkr::TextEncoders.AbstractTransformerTextEncoder
    mod::HuggingFace.HGFRobertaForSequenceClassification
end

# Keyword constructor:
function IntentClassifier(;
    tkr::TextEncoders.AbstractTransformerTextEncoder,
    mod::HuggingFace.HGFRobertaForSequenceClassification,
)
    IntentClassifier(tkr, mod)
end

# Tokenize the query and run it through the encoder to get the hidden state:
function get_hidden_state(clf::IntentClassifier, query::Union{AbstractString, Vector{<:AbstractString}})
    token = encode(clf.tkr, query)
    hidden_state = clf.mod.model(token).hidden_state
    return hidden_state
end

# This doesn't actually retrain the model; it retrieves the classifier head and the labels:
function MLJBase.fit(clf::IntentClassifier, verbosity, X, y)
    cache = nothing
    report = nothing
    fitresult = (clf = clf.mod.cls, labels = levels(y))
    return fitresult, cache, report
end

# Apply the classifier head and turn the logits into probabilities:
function MLJBase.predict(clf::IntentClassifier, fitresult, Xnew)
    output = fitresult.clf(get_hidden_state(clf, Xnew))
    # `softmax` returns a (classes × queries) matrix; transpose to one row per query:
    p̂ = UnivariateFinite(fitresult.labels, softmax(output.logit)', pool=missing)
    return p̂
end

MLJBase.target_scitype(clf::IntentClassifier) = AbstractVector{<:Finite}
MLJBase.predict_mode(clf::IntentClassifier, fitresult, Xnew) = mode.(MLJBase.predict(clf, fitresult, Xnew))
```
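The wrapper above boils down to a simple contract: `fit` packages the frozen classifier head with the labels, `predict` returns one row of class probabilities per query, and `predict_mode` picks the most likely label in each row. The following sketch mimics that logic in plain Julia, without MLJ or Transformers.jl. `FrozenHead`, `toy_fit`, `toy_predict` and `toy_predict_mode` are hypothetical stand-ins, and hidden states are simplified to a d × n matrix with one column per query, which glosses over the sequence dimension of the real hidden state.

```julia
# Hypothetical stand-in for the pre-trained classifier head `mod.cls`.
struct FrozenHead
    W::Matrix{Float64}   # K × d weights (K classes, d hidden dimensions)
    b::Vector{Float64}   # K biases
end
(h::FrozenHead)(H::AbstractMatrix) = h.W * H .+ h.b   # K × n logits for n queries

# Column-wise softmax (one column per query).
function softmax_columns(Z::AbstractMatrix)
    E = exp.(Z .- maximum(Z; dims=1))
    return E ./ sum(E; dims=1)
end

# "Fitting" does no training: it only packages the frozen head with the labels.
toy_fit(head::FrozenHead, labels::Vector{String}) = (clf = head, labels = labels)

# One row of class probabilities per query, hence the transpose of the K × n output.
toy_predict(fitresult, H::AbstractMatrix) = permutedims(softmax_columns(fitresult.clf(H)))

# Most likely label for each query, the analogue of `predict_mode`.
toy_predict_mode(fitresult, H::AbstractMatrix) =
    [fitresult.labels[argmax(row)] for row in eachrow(toy_predict(fitresult, H))]

# Usage: 3 classes, 2-dimensional "hidden states", 2 queries (one per column).
head = FrozenHead([1.0 0.0; 0.0 1.0; -1.0 -1.0], zeros(3))
toy_fr = toy_fit(head, ["a", "b", "c"])
H = [2.0 0.0;
     0.0 2.0]
P = toy_predict(toy_fr, H)
modes = toy_predict_mode(toy_fr, H)
```

The transpose is the same one applied to `softmax(output.logit)` in `MLJBase.predict` above: the head produces one column per query, whereas a matrix of probabilities passed to `UnivariateFinite` is read as one row per observation.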
To test that everything is working as expected, we fit the model and generate predictions for a subset of the test data:
```julia
# `queries_test` and `y_test` are assumed to hold the Banking77 test queries and labels:
clf = IntentClassifier(tkr, mod)
top_n = 10
fitresult, _, _ = MLJBase.fit(clf, 1, nothing, y_test[1:top_n])
@time ŷ = MLJBase.predict(clf, fitresult, queries_test[1:top_n]);
```

```
6.818024 seconds (11.29 M allocations: 799.165 MiB, 2.47% gc time, 91.04% compilation time)
```
Note that even though the LLM we're using here is not particularly large, even a simple forward pass takes considerable time (although, as the `@time` output shows, much of this first call is spent on compilation).
🤖 Conformal Chatbot
To turn the wrapped, pre-trained model into a conformal intent classifier, we can now rely on standard API calls. We first wrap our atomic model, specifying the desired coverage rate and method. Since even simple forward passes are computationally expensive for our (small) LLM, we rely on Simple Inductive Conformal Classification, which needs only one forward pass per calibration example rather than repeatedly refitting the model as resampling-based methods do.
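```julia
using MLJBase, ConformalPrediction
conf_model = conformal_model(clf; coverage=0.95, method=:simple_inductive, train_ratio=train_ratio)
mach = machine(conf_model, queries, y)
fit!(mach)  # calibrate before predicting; `queries`, `y` and `train_ratio` are defined earlier
```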
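For readers curious about what simple inductive (split) conformal classification does, here is a minimal plain-Julia sketch. It assumes the common nonconformity score 1 − p̂(y), i.e. one minus the softmax probability of the true class; the function names and numbers are hypothetical, and this is not the package's implementation. The calibration scores are sorted and the ⌈(n+1)·coverage⌉-th smallest one becomes the threshold `qhat`; a class enters the prediction set whenever its own score does not exceed `qhat`.

```julia
# Split (inductive) conformal classification: a minimal sketch.
# Assumed nonconformity score: 1 - (probability assigned to the true class).

# `probs` is n × K (one row of class probabilities per calibration example),
# `y` holds the true class index of each example.
nonconformity(probs::AbstractMatrix, y::AbstractVector{<:Integer}) =
    [1 - probs[i, y[i]] for i in eachindex(y)]

# Threshold qhat: the ⌈(n+1)·coverage⌉-th smallest calibration score.
function calibrate(cal_probs::AbstractMatrix, cal_y::AbstractVector{<:Integer}; coverage=0.95)
    s = sort(nonconformity(cal_probs, cal_y))
    n = length(s)
    k = ceil(Int, (n + 1) * coverage)
    return k > n ? Inf : s[k]   # too few calibration points: include every class
end

# Prediction set (class indices) for a new example with class probabilities `p`.
prediction_set_indices(p::AbstractVector, qhat) = findall(pk -> 1 - pk <= qhat, p)

# Usage with made-up numbers: nine calibration examples whose true class is 1.
p_true = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
cal_probs = hcat(p_true, (1 .- p_true) ./ 2, (1 .- p_true) ./ 2)
cal_y = fill(1, 9)
qhat = calibrate(cal_probs, cal_y; coverage=0.75)
pred_set = prediction_set_indices([0.7, 0.25, 0.05], qhat)
```

The (n+1) correction is what yields the marginal coverage guarantee under exchangeability, and the model is evaluated only once per calibration example, which is why this variant suits an expensive LLM.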
Finally, we use our conformal LLM to build a simple yet powerful chatbot that runs directly in the Julia REPL. Without dwelling too much on the details, the `conformal_chatbot` works as follows:
- Prompt user to explain their intent.
- Feed user input through conformal LLM and present the output to the user.
- If the Conformal Prediction set includes more than one label, prompt the user to either refine their input or choose one of the options included in the set.
The following code implements these ideas:
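```julia
using MLJBase, UnicodePlots

# Return the labels in the conformal prediction set, sorted by predicted probability,
# together with a bar plot. Class indices 1:77 are mapped to intent names via `labels`.
function prediction_set(mach, query::String)
    p̂ = MLJBase.predict(mach, query)[1]
    probs = pdf.(p̂, collect(1:77))
    in_set = findall(probs .!= 0)                 # labels outside the set have zero mass
    labels_in_set = labels[in_set]
    probs_in_set = probs[in_set]
    _order = sortperm(probs_in_set; rev=true)
    plt = UnicodePlots.barplot(labels_in_set[_order], probs_in_set[_order], title="Possible Intents")
    return labels_in_set[_order], plt             # same order as the plot, so option numbers match
end

function conformal_chatbot()
    println("👋 Hi, I'm Julia, your conformal chatbot. I'm here to help you with your banking query. Ask me anything or type 'exit' to exit ...\n")
    context = ""          # everything the user has typed so far
    options = String[]    # labels shown to the user in the previous round
    while true
        query = strip(readline())
        # Exit:
        if query == "exit"
            println("👋 Bye!")
            break
        end
        # Did the user pick one of the options shown in the previous round?
        choice = tryparse(Int, query)
        if choice !== nothing && 1 <= choice <= length(options)
            println("👍 Great! You've chosen '$(options[choice])'. I'm glad I could help you. Have a nice day!")
            break
        end
        # Otherwise, add the input to the context and update the prediction set:
        context = isempty(context) ? String(query) : context * ", " * query
        options, plt = prediction_set(mach, context)
        if length(options) == 1
            println("🥳 I think you mean $(options[1]). Correct?")
        else
            println("🤔 Hmmm ... I can think of several options here. If any of these applies, simply type the corresponding number (e.g. '1' for the first option). Otherwise, can you refine your question, please?\n")
            println(plt)
        end
    end
end
```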
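The input handling in the loop is the easiest part to get wrong, since a typed number must refer to the options shown in the previous round rather than to a freshly computed prediction set. A small, hypothetical helper `handle_turn` (not part of the original code) isolates that decision so it can be tested without a model or a REPL session:

```julia
# Hypothetical helper: decide how to react to one line of user input, given the
# options shown in the previous round (in the order they were displayed).
function handle_turn(input::AbstractString, options::AbstractVector{<:AbstractString})
    text = strip(input)
    text == "exit" && return (:exit, nothing)
    n = tryparse(Int, text)
    if n !== nothing && 1 <= n <= length(options)
        return (:chosen, String(options[n]))   # a valid option number was typed
    end
    return (:refine, String(text))             # anything else refines the query
end

shown = ["failed_transfer", "beneficiary_not_allowed"]
handle_turn("2", shown)
handle_turn("I tried to send money", shown)
```

Out-of-range numbers are treated as a refinement rather than raising an index error.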
Below we show the output for two example queries. The first one is very ambiguous (and, as I just realised, misspelled): "transfer mondey?". As expected, the resulting prediction set is large.
```julia
ambiguous_query = "transfer mondey?"
prediction_set(mach, ambiguous_query)[2]
```

```
Possible Intents
┌ ┐
beneficiary_not_allowed ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.150517
balance_not_updated_after_bank_transfer ┤■■■■■■■■■■■■■■■■■■■■■■ 0.111409
transfer_into_account ┤■■■■■■■■■■■■■■■■■■■ 0.0939535
transfer_not_received_by_recipient ┤■■■■■■■■■■■■■■■■■■ 0.091163
top_up_by_bank_transfer_charge ┤■■■■■■■■■■■■■■■■■■ 0.089306
failed_transfer ┤■■■■■■■■■■■■■■■■■■ 0.0888322
transfer_timing ┤■■■■■■■■■■■■■ 0.0641952
transfer_fee_charged ┤■■■■■■■ 0.0361131
pending_transfer ┤■■■■■ 0.0270795
receiving_money ┤■■■■■ 0.0252126
declined_transfer ┤■■■ 0.0164443
cancel_transfer ┤■■■ 0.0150444
└ ┘
```
The following is a more refined version of the prompt: "I tried to transfer money to my friend, but it failed". It yields a smaller prediction set, since less ambiguous prompts result in lower predictive uncertainty.
```julia
refined_query = "I tried to transfer money to my friend, but it failed."
prediction_set(mach, refined_query)[2]
```

```
Possible Intents
┌ ┐
failed_transfer ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.59042
beneficiary_not_allowed ┤■■■■■■■ 0.139806
transfer_not_received_by_recipient ┤■■ 0.0449783
balance_not_updated_after_bank_transfer ┤■■ 0.037894
declined_transfer ┤■ 0.0232856
transfer_into_account ┤■ 0.0108771
cancel_transfer ┤ 0.00876369
└ ┘
```
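The effect of ambiguity on set size can be reproduced with toy numbers. Assuming the common score 1 − p̂_k for simple inductive conformal classification, a class enters the set when its probability is at least 1 − q̂, so with the same threshold a flat distribution (an ambiguous query) admits many classes while a peaked one (a clear query) admits few. The distributions and the threshold `qhat` below are made up and not taken from the model above; `set_size` is a hypothetical helper.

```julia
# Size of the prediction set: number of classes k with score 1 - p[k] <= qhat.
set_size(p::AbstractVector, qhat) = count(pk -> 1 - pk <= qhat, p)

qhat = 0.95                       # hypothetical calibrated threshold
flat   = fill(0.1, 10)            # ambiguous query: mass spread over 10 intents
peaked = [0.82; fill(0.02, 9)]    # clear query: one dominant intent

set_size(flat, qhat), set_size(peaked, qhat)
```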
The video below shows the REPL-based chatbot in action. You can recreate this yourself and run the bot right from your terminal. To do so, just check out the original post on my blog to find the full source code.
🌯 Wrapping Up
This work was done in collaboration with colleagues at ING as part of the ING Analytics 2023 Experiment Week. Our team demonstrated that Conformal Prediction provides a powerful and principled alternative to top-K intent classification. We won the first prize by popular vote.
Of course, there are a lot of things that can be improved here. As far as LLMs are concerned, we have used a relatively small one. In terms of Conformal Prediction, we have only looked at simple inductive conformal classification. This is a good starting point, but more advanced methods are available; these are implemented in the package and were investigated during the competition. Another thing we did not take into consideration here is that we have many outcome classes and may in practice be interested in achieving class-conditional coverage. Stay tuned for more on this in future posts.
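To illustrate the last point: standard split conformal prediction guarantees coverage only on average over all classes, so rare intents can end up under-covered. One common remedy, class-conditional (often called Mondrian) conformal prediction, calibrates a separate threshold for each class using only that class's calibration examples. The sketch below is a hypothetical plain-Julia illustration of this idea under the same 1 − p̂ score assumption; it is not taken from ConformalPrediction.jl, and the scores are made up.

```julia
# Class-conditional ("Mondrian") calibration: a minimal sketch.
# `scores[i]` is the calibration score 1 - p̂(y_i) of example i, `y[i]` its class.
function classwise_thresholds(scores::AbstractVector{<:Real}, y::AbstractVector{<:Integer}, K::Integer; coverage=0.95)
    qhats = fill(Inf, K)   # default: too few examples, so the class is always included
    for k in 1:K
        s = sort(scores[y .== k])   # scores of calibration examples from class k only
        n = length(s)
        r = ceil(Int, (n + 1) * coverage)
        if r <= n
            qhats[k] = s[r]
        end
    end
    return qhats
end

# Class k enters the set if its score 1 - p[k] does not exceed its own threshold.
classwise_set(p::AbstractVector, qhats) = [k for k in eachindex(p) if 1 - p[k] <= qhats[k]]

# Usage with made-up scores: three classes, the third has no calibration data.
cal_scores = [0.1, 0.2, 0.3, 0.5, 0.6, 0.7]
cal_y = [1, 1, 1, 2, 2, 2]
qhats = classwise_thresholds(cal_scores, cal_y, 3; coverage=0.5)
classwise_set([0.85, 0.1, 0.05], qhats)
```

The price of per-class guarantees is that each threshold is estimated from fewer examples, which matters with 77 intents and a limited calibration set.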
If you’re interested in finding out more about Conformal Prediction in Julia, go ahead and check out the repo and docs.
🎉 JuliaCon 2023 is around the corner and this year I will be giving a talk about ConformalPrediction.jl. Check out the details of my talk here and have a look at the full jam-packed conference schedule.
🎓 References
Casanueva, Iñigo, Tadas Temčinas, Daniela Gerz, Matthew Henderson, and Ivan Vulić. 2020. "Efficient Intent Detection with Dual Sentence Encoders." In Proceedings of the 2nd Workshop on Natural Language Processing for Conversational AI, 38–45. Online: Association for Computational Linguistics. https://doi.org/10.18653/v1/2020.nlp4convai-1.5.
Liu, Yinhan, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. 2019. "RoBERTa: A Robustly Optimized BERT Pretraining Approach." arXiv. https://doi.org/10.48550/arXiv.1907.11692.
💾 Data and Model
The Banking77 dataset was retrieved from HuggingFace. It is curated by PolyAI, published under the Creative Commons Attribution 4.0 International license (CC BY 4.0), and was originally published by Casanueva et al. (2020). With thanks also to Manuel Romero, who contributed the fine-tuned DistilRoBERTa to HuggingFace.
Originally published at https://www.paltmeyer.com on July 5, 2023.