Asymmetric Topic Models

Josua Krause, PhD
11 min read · Jun 29, 2023

If you are searching for information on the internet these days, there is a high probability that your search is, at least in part, answered using Topic Models to provide you with semantically aligned results related to your query. Topic Modeling is a natural language processing technique that groups together semantically similar documents. This is often done by embedding documents into a high dimensional space. This latent space has the property that elements (e.g., documents, sentences, queries) are close to each other if their content is semantically similar. We can look up information in this manifold by computing the embedding corresponding to our search query and determining which documents are in its neighborhood. For example, if we query for “highest mountain in the world” we will find documents related to Mount Everest, and even the embedding of the name “Mount Everest” will be close to the embedding of the query. This approach is fully symmetric. That is, each query can also be a result and there is no distinction between a question and an answer (i.e., in our example, the query and “Mount Everest” are close to each other, so querying for “Mount Everest” would return “highest mountain in the world” if this phrase were included in the candidate set). The reason input queries are typically not returned as results is simply that they are not indexed and made available.
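As an aside, here is what such a symmetric lookup can look like in code. This is a minimal sketch assuming the sentence-transformers library and its all-MiniLM-L6-v2 model as a stand-in embedding model (neither is used in the experiments later); the document list is made up.

```python
# Minimal sketch of a symmetric embedding lookup: the same model embeds
# queries and documents, so either one could play either role.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")  # stand-in embedding model

documents = [
    "Mount Everest is Earth's highest mountain above sea level.",
    "The Mariana Trench is the deepest oceanic trench on Earth.",
    "K2 is the second-highest mountain in the world.",
]
doc_embeddings = model.encode(documents, convert_to_tensor=True)

query = "highest mountain in the world"
query_embedding = model.encode(query, convert_to_tensor=True)

# Same embedding space for query and documents -> fully symmetric lookup.
scores = util.cos_sim(query_embedding, doc_embeddings)[0]
best = scores.argmax().item()
print(documents[best], float(scores[best]))
```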

In contrast, in a conversation, questions and answers or sequences of back-and-forths are typically not interchangeable, and their order is important. Take, for example, this conversation from the book Mostly Harmless [1] by Douglas Adams. The character Random has just found a new “AI enabled” version of “The Hitchhiker’s Guide to the Galaxy” (H2G2 for short), which prompts H2G2 to calibrate itself. Part of the calibration is finding out the order of causality, resulting in a conversation that illustrates why it is important that a question comes before its answer for a dialog to make sense:

“Well, you’re sort of…” Random gestured helplessly off into the distance.
“I see, still infinite in extent, but at least we’re homing in on the right dimensional matrix. Good. No, the answer is an orange and two lemons.”
“Lemons?”
“If I have three lemons and three oranges and I lose two oranges and a lemon what do I have left?”
“Huh?”
“OK, so you think that time flows that way, do you? Interesting. Am I still infinite?” it asked, ballooning this way and that in space. “Am I infinite now? How yellow am I?”

As H2G2 is not bound by the same rules regarding causality as Random, it needs to understand that questions have to strictly come before corresponding answers and not the other way around. Since we live in a similar reality, this rule applies to us as well.

If we want to add functionality to a dialog model that allows it to semantically query arbitrary information while formulating its output, Topic Models seem to be a reasonable choice. However, given the sequential nature of dialogs demonstrated above, we want a Topic Model that obeys this sequentiality as well.

Created using MidJourney

For this, we need to overcome the symmetry of Topic Models by creating an “Asymmetric Topic Model”. Instead of one embedding for both queries and answers, we choose different embeddings depending on the context a given text is in. So, a parent or question is embedded in a way that lines up with the embedding of a corresponding child or answer. And as quickly as that, we just created (Cross-)Attention. Let me explain. “Attention” [2], the “technology” that makes Transformer Models work, has three components: Q, the query, K, the keys, and V, the result values. In Self-Attention, for example, the keys are the embeddings (let’s call this set Eₖ) of the input tokens. The model now formulates a query. In the case of Self-Attention, the query is exactly one linear transformation away from the original embeddings Eₖ, resulting in the secondary embeddings Eₚ. The Attention Head then (softly) selects the keys closest to the query and returns the corresponding values V, which in the case of Self-Attention are again derived from the input token embeddings Eₖ. In our case, Eₖ and Eₚ can be arbitrary embeddings, and the key set K consists of embeddings of all the documents (or snippets of information) we want the model to have access to. From the perspective of the model, the key set K might as well be infinitely large. Keep that in mind for later. Also, for now, let’s not worry too much about what V is going to be; we need to get Q to look up reasonable keys first. Making this concrete, giving a language model the ability to query arbitrary information to generate its output would result in an architecture looking something like this:

An architecture to enable a model to look up information.

Here Mₚ and Mₖ are two separate transformer models that each output their own embedding. Those embeddings take on the roles of query and key. We use Q to find the closest K in our knowledge database. Again, how the final output of the model is generated is not too concerning for us at the moment. This architecture looks reasonable, but do we need all of this complexity? How similar can Mₚ and Mₖ be? Can we even use the same embedding for both? If that were the case, we could just use a regular Topic Model without worrying about making it asymmetric. The sequentiality argument from the beginning suggests we need two separate embeddings, but we cannot know for sure unless we try it out.
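To make the two-encoder lookup concrete, here is a minimal sketch using Hugging Face transformers with BERT standing in for both Mₚ and Mₖ. The in-memory knowledge base and all variable names are illustrative, and a real system would of course train the two encoders rather than use them off the shelf.

```python
# Sketch of the asymmetric lookup: two separate encoders produce the query
# embedding Q and the key embeddings K; the closest key wins.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
m_parent = AutoModel.from_pretrained("bert-base-uncased")  # Mp: embeds the parent / question
m_key = AutoModel.from_pretrained("bert-base-uncased")     # Mk: embeds knowledge snippets

def embed(model, texts):
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(**batch)
    return out.last_hidden_state[:, 0]  # [CLS] token output embedding

knowledge = [
    "Mount Everest is the highest mountain on Earth.",
    "The Pacific Ocean is the largest ocean.",
]
K = embed(m_key, knowledge)                               # key embeddings
Q = embed(m_parent, ["highest mountain in the world"])    # query embedding

scores = Q @ K.T                 # dot-product attention scores
best = scores.argmax(dim=-1)     # nearest key in the knowledge base
print(knowledge[best.item()])
```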

In order to verify that an asymmetric lookup embedding pair is actually needed we can test the different approaches on a condensed architecture:

An architecture to train embedding pairings by contrasting real to fake comment pairs.

The task for this model is to distinguish between two example pairs of questions and answers. One of them is a real pair taken from the data set, containing a valid answer from our knowledge base to a given question. The other is a random pair of questions and answers. The model does not know which side the real example is on, and its task is to decide which side contains it. Since there are no trainable parameters after the similarity score function, this task in turn trains our embeddings to produce high similarity for real pairs while keeping invalid pairs apart.
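A sketch of this objective, reusing the m_parent, m_key, and embed helpers from the sketch above (with the torch.no_grad() guard removed so gradients can flow). The helper names and the cross-entropy formulation are my assumptions, but they capture the property that nothing after the similarity score has trainable parameters:

```python
# Sketch of the contrastive task: one side holds a real (parent, child) pair,
# the other side an invalid pair, and the model must pick the real side.
import torch
import torch.nn.functional as F

def side_score(parent_text, child_text):
    q = embed(m_parent, [parent_text])  # parent / question embedding, shape (1, d)
    k = embed(m_key, [child_text])      # child / answer embedding, shape (1, d)
    return (q * k).sum(dim=-1)          # dot-product similarity, shape (1,)

def row_loss(real_pair, fake_pair, real_side):
    """real_side is 0 or 1 and indicates which side holds the real pair."""
    left, right = (real_pair, fake_pair) if real_side == 0 else (fake_pair, real_pair)
    logits = torch.cat([side_score(*left), side_score(*right)]).unsqueeze(0)
    # Everything after the similarity scores is parameter-free, so the
    # gradient only shapes the two embedding models.
    return F.cross_entropy(logits, torch.tensor([real_side]))
```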

We will use BERT [3] as our base model and explore four different settings: 1) the symmetric Topic Model baseline, where Mₚ and Mₖ are the same, 2) the Attention baseline, where Mₚ and Mₖ are the same except that Mₚ has a single additional dense layer at the end, 3) the asymmetric Topic Model, where Mₚ and Mₖ are unrestricted, and 4) the same as 3) except that the embeddings are the mean of the BERT output embeddings instead of just the [CLS] token output embedding. (The full matrix of experiments with mean aggregation vs. [CLS] only aggregation, dot product vs. cosine similarity, and other hyperparameters was performed, but including the outcomes would not offer further insights beyond the cases mentioned above.)
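The difference between settings 3) and 4) is only in how the BERT output is reduced to a single embedding vector. A short sketch of both aggregations (again with Hugging Face transformers; masking out the padding tokens in the mean is my assumption of how that would be handled):

```python
# [CLS]-only vs. mean aggregation of the BERT output embeddings.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased")

batch = tokenizer(["highest mountain in the world"], padding=True, return_tensors="pt")
with torch.no_grad():
    hidden = bert(**batch).last_hidden_state           # (batch, seq_len, dim)

cls_embedding = hidden[:, 0]                           # setting 3): [CLS] token only

mask = batch["attention_mask"].unsqueeze(-1)           # ignore padding tokens
mean_embedding = (hidden * mask).sum(1) / mask.sum(1)  # setting 4): mean over tokens
```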

Before analyzing the results of the experiments, let’s have a look at the data first. The data for the experiments were obtained from reddit [4] early in 2023 (several months before the proposed API access changes that would limit access to conversation data). To that end, the Python library praw [5] made it easy to download comment graphs from reddit. We use the subreddit r/all [6] as the source of our training data, while a combination of r/askscience, r/askreddit, r/explainlikeimfive, and r/todayilearned [7], downloaded during a different time range, serves as test data. This ensures a) that there is no accidental overlap of comment sections, and b) that the learned embedding transfers well to different conversation topics. We only consider comments with at least 2 karma points as valid parent / child pairs (the training data for GPT-2 used reddit outbound links with at least 3 karma points [8] instead; maybe this is something to investigate in the future, but for now this additional refinement is not necessary). Each comment gets 1 karma point by default, which means that comments with only 1 karma point have no external quality verification. Additionally, negative karma can mean different things: the comment could be stating an unpopular opinion, the comment could be factually wrong, or the comment could be nonsense / unrelated to the conversation. Some of those categories might be worth including, but since there is no way of distinguishing the cases it is safer to just ignore those comment pairings. Note, however, that a comment with negative karma can still show up as a parent comment, and thus rebuttals to a bad comment are still likely to be included.
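For reference, here is a sketch of how such comment pairs can be collected with praw. The credentials are placeholders, and the choice of listing (“hot”) as well as the exact filtering are assumptions that only illustrate the karma threshold described above.

```python
# Sketch of collecting (parent, child) comment pairs with praw.
import praw

reddit = praw.Reddit(
    client_id="YOUR_CLIENT_ID",          # placeholder credentials
    client_secret="YOUR_CLIENT_SECRET",
    user_agent="asymmetric-topic-model-crawler",
)

pairs = []
for submission in reddit.subreddit("all").hot(limit=100):
    submission.comments.replace_more(limit=0)   # drop "load more comments" stubs
    comments = {c.id: c for c in submission.comments.list()}
    for comment in comments.values():
        # parent_id looks like "t1_<id>" for comments; top-level comments
        # point at the submission instead and are skipped here.
        parent = comments.get(comment.parent_id.split("_", 1)[-1])
        # Keep children with at least 2 karma; the parent may have any karma,
        # so rebuttals to bad comments are still included.
        if parent is not None and comment.score >= 2:
            pairs.append((parent.body, comment.body))
```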

Overall, the training data set contains 601,703 comments (605,713 links and 275,808 strong links) and the test data set contains 488,837 comments (490,724 links and 212,256 strong links). We obtain valid conversation pairs in two ways: 1) taking a randomly selected comment as child and its parent as parent (the other way around does not work, as a comment might not have any child comment, or it might have multiple child comments, which would force us to select one at random), and 2) randomly walking the comment graph and using the comment pair at which the walk stops. The second variety of valid links favors comments that are closer to the original topic of a post and results in comments with generally higher karma. For generating contrasting invalid comment pairs, we can select two comments at random or pick a valid pair and flip parent and child around. We randomly select combinations of those approaches (one valid and one invalid each) for every row fed to the model. The training epochs consist of 20,000 such rows. For analyzing the performance of the models we use a fixed 10,000 row validation set drawn from the training data and a 10,000 row test set from the test data. Over 120 epochs the performance of the models evolved as follows:

Train and test performance of the four experiments.

As expected, the symmetric Topic Model baseline (1) lags significantly behind the other approaches. This confirms our hypothesis and shows that we are on the right track. But there is more to see. Choosing the mean aggregation of the BERT output embeddings (4) over the [CLS] token output embedding (3) has essentially no impact on performance, and the mean aggregation model (4) appears to be overall less stable (see, e.g., epochs 70 to 80). This is good news, since computing only the [CLS] token output embedding is much cheaper than computing all output embeddings and aggregating them. The Attention baseline (2) lags only minimally behind the unrestricted embeddings, suggesting that a single dense layer might be enough in certain cases (the lag is more easily visible in the training performance, but it exists in the test performance as well when looking at rolling averages).

So far, we have focused on comparing the different experiments to each other. This helped us confirm our main assumption: (Cross-)Attention performs better than a singular embedding. However, we have not yet discussed the actual performance of the models. Most noticeably, the models appear to overfit slightly on the training data, but this can also be explained by the different nature of training and test data, which are selected from different subreddits and time ranges (remember also that the training data used for evaluation does not overlap with the training data used for the actual training).

Furthermore, the overall performance of the good models hovers around 65% accuracy on the test data. This, on its own, is not exactly an impressive performance. However, there are several potential reasons for this, which could be improved in future experiments. As mentioned earlier, we included comments with a karma of 2, which might not be a very strong external validation of comment quality. In future experiments, increasing the lower limit to 3, as was the case for GPT-2’s training data, might improve the overall quality of the data.

Additionally, the model is looking at comment pairs out of context. That is, the conversation preceding the comments is missing. Much like predicting the next word of a sentence using only the previous word is hard to do well, so is predicting (or confirming) a child comment given only its parent comment. Solutions to this could come via including further comments up the conversation chain (equivalent to using n-grams for word predictions) or via back-referencing previous comments in the conversation using a similar (Self-)Attention mechanism (equivalent to using Self-Attention for word predictions). I particularly like the second approach for its meta-ness, and it is something worth exploring further in the future.

However, even with context, comment pairings are sometimes ambiguous. Consider the following pair:

“Chuck Norris has entered the thread”

“Chuck Norris caught COVID and the prognosis is not good. Anyone wanting to say goodbye to the virus should visit the hospital tonight.”

There is no way of telling which of those comments is the parent and which one is the child. As there is no reasonable way of detecting and removing those ambiguous pairs, the performance of the models will always be affected.

Lastly, the architecture that we used for our experiments was condensed down with the sole goal of verifying our hypothesis. Now that the hypothesis has been confirmed, proceeding with a more complete architecture (e.g., the first one shown above) will very likely yield better overall results by learning the full dialog modeling task end to end.

Bonus

As a short bonus, here are visualizations of the computed embeddings of the final model (using the [CLS] token based configuration). The embeddings are projected to two dimensions via t-SNE [9]. Due to technical limitations, only 25,000 randomly selected comments are shown instead of the full comment data set. To keep the spatial relations of the embedded comments equivalent to how the model perceives them, we do not use the L2 norm as the distance function but instead adjust the dot product to serve as one:

Formula to convert a dot product into a distance function: `distance(a, b) = e^(-a·b)`.
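A sketch of how this distance can be handed to t-SNE as a precomputed distance matrix. Using scikit-learn here is my assumption, not necessarily the setup behind the figures, and the random E is only a stand-in for the stacked parent and child embeddings:

```python
# Sketch: project embeddings with t-SNE using distance(a, b) = e^(-a.b).
import numpy as np
from sklearn.manifold import TSNE

E = np.random.randn(500, 64)     # stand-in for the real parent + child embeddings

def dot_product_distances(E):
    D = np.exp(-(E @ E.T))       # e^(-a.b) for every pair of embeddings
    np.fill_diagonal(D, 0.0)     # t-SNE expects zero self-distance
    return D

projected = TSNE(
    n_components=2,
    metric="precomputed",        # use our distance matrix instead of the L2 norm
    init="random",               # PCA init is not available with precomputed distances
).fit_transform(dot_product_distances(E))
```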

Both parent and child embeddings are projected together and are distinguished in the visualization using color.

As a secondary visualization, we show how parent and child embeddings are connected. Here, a line is drawn between each parent and child. In order to manage the overplotting and keep connections visible, the color hue is chosen from the angle of each line (going from parent to child: e.g., red points towards the left horizontally; green towards the top vertically; cyan towards the right horizontally; blue towards the bottom vertically).
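One way such an angle-based coloring can be produced with matplotlib is sketched below; the arrays P and C are stand-ins for the projected parent and child positions, and the exact hue mapping used for the figures may differ:

```python
# Sketch: draw parent -> child connections with the hue taken from the angle
# of each line. P and C are (n, 2) arrays of projected parent / child points.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

P = np.random.randn(1000, 2)              # stand-in for projected parent embeddings
C = P + np.random.randn(1000, 2) * 0.3    # stand-in for projected child embeddings

delta = C - P
angle = np.arctan2(delta[:, 1], delta[:, 0])  # direction from parent to child
hue = (angle + np.pi) / (2 * np.pi)           # map [-pi, pi] to [0, 1]
colors = plt.cm.hsv(hue)
colors[:, 3] = 0.2                            # transparency against overplotting

segments = np.stack([P, C], axis=1)           # (n, 2 points, 2 coordinates)
fig, ax = plt.subplots()
ax.add_collection(LineCollection(segments, colors=colors, linewidths=0.5))
ax.autoscale()
plt.show()
```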

Interestingly, parent and child embeddings form separate, distinct regions with no overlap. Most comment embeddings stay close to each other in their respective region, forming what looks like tangled up strings. There are a few notable (round) clusters in the child embeddings; however, those represent only a small amount of data. As those clusters are far away from the rest of the embeddings and collapse into very close embeddings on the parent side (e.g., the top-center child cluster in the training data has almost all of its parent embeddings in a very small region at the left side), they might be variations of semantically equivalent phrases that evoke similar responses in the comment chain.

Visualizations of the embeddings of the training data. In the left visualization, each dot represents one comment positioned by either the child (blue) or the parent (orange) embedding. In the right visualization, lines connect positions of the same comment embedded in both spaces.
Visualizations of the embeddings of the test data. In the left visualization, each dot represents one comment positioned by either the child (blue) or the parent (orange) embedding. In the right visualization, lines connect positions of the same comment embedded in both spaces.

Links

[1] https://en.wikipedia.org/wiki/Mostly_Harmless
[2] https://arxiv.org/abs/1706.03762
[3] https://arxiv.org/abs/1810.04805
[4] https://www.reddit.com/
[5] https://praw.readthedocs.io/
[6] https://www.reddit.com/r/all/
[7] https://www.reddit.com/r/AskReddit+askscience+explainlikeimfive+todayilearned/
[8] https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
[9] https://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf


Josua has led Data Science teams focused on deep representation learning, natural language processing, and adaptive learning. His PhD focused on explainable AI.