Binary Exponentiation

Written by Raouf Ould Ali

Problem: Modular Exponentiation 💡

Consider the following problem:

  • Write a program that outputs the remainder of the Euclidean division of $3^N$ by $M$, where $N$ and $M$ are integers read from the input.

Straightforward Algorithm 📏

The straightforward algorithm would be the following:

cpp
#include <bits/stdc++.h>
using namespace std;
int main(){
    long long N;
    int M;
    cin >> N >> M;

    long long result = 1;
    for(long long i = 0; i < N; i++){ // long long counter: N may exceed the int range
        result = result * 3;
    }
    
    cout << result % M << endl;
}

After trying a small value of $N$, we may think that we're done with the problem. But with higher values, we start getting results that don't make sense at all! This has a very simple explanation: as $3^N$ becomes larger and larger, it exceeds the maximum value a long long can store, and the variable result overflows. Do you have any idea how to solve this issue?
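
To see where the straightforward loop breaks, here is a minimal sketch that finds the largest exponent for which $3^N$ still fits in a long long. The helper name `largestSafePowerOfThree` is ours (it is not part of the original solution), and the result quoted below assumes the usual 64-bit long long.

```cpp
#include <limits>

// Hypothetical helper: returns the largest exponent e such that 3^e
// still fits in a long long. The check happens before each
// multiplication, so this computation itself never overflows.
int largestSafePowerOfThree(){
    const long long maxValue = std::numeric_limits<long long>::max();
    long long power = 1; // 3^0
    int exponent = 0;
    while(power <= maxValue / 3){ // power * 3 is guaranteed to fit
        power *= 3;
        exponent++;
    }
    return exponent;
}
```

With a 64-bit long long this returns 39, so the straightforward program can only be trusted for $N \le 39$. From $N = 40$ on, result overflows, which is undefined behavior for signed integers in C++.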

Modular Arithmetic to the Rescue 🦸

Remember that we're not searching for the exact value of $3^N$, but only its remainder when divided by $M$. Hence, we can use the fact that:

$$(a \times b) \bmod M = \big((a \bmod M) \times (b \bmod M)\big) \bmod M$$

This allows us to replace the variable result by its remainder modulo $M$ at each step of the computation:

cpp
#include <bits/stdc++.h>
using namespace std;
int main(){
    long long N;
    int M;
    cin >> N >> M;

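    long long result = 1 % M; // 1 % M handles M == 1
    for(long long i = 0; i < N; i++){ // long long counter: N may exceed the int range
        result = result * 3;
        result = result % M; // keep only the remainder at each step
    }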
    
    cout << result << endl;
}

Retrying the input that failed before now gives us the intended answer (2001). Let's goooo, we solved it! 🎉
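
As a sanity check of the modular identity, the sketch below compares the two approaches for exponents small enough that $3^N$ is still exact in a long long (we assume a 64-bit long long, for which that means $N \le 39$). The function names `exactThenMod`, `modEachStep` and `identityHolds` are ours, for illustration only.

```cpp
// Computes 3^N exactly, then reduces once at the end.
// Only valid while 3^N fits in a long long (N <= 39 with 64 bits).
long long exactThenMod(int N, int M){
    long long result = 1;
    for(int i = 0; i < N; i++){
        result = result * 3;
    }
    return result % M;
}

// Reduces modulo M after every multiplication, so the intermediate
// product never exceeds 3 * (M - 1).
long long modEachStep(int N, int M){
    long long result = 1 % M; // 1 % M handles M == 1
    for(int i = 0; i < N; i++){
        result = (result * 3) % M;
    }
    return result;
}

// Returns true if both functions agree for every N in [0, 39]
// and every modulus M in [1, maxM].
bool identityHolds(int maxM){
    for(int M = 1; M <= maxM; M++){
        for(int N = 0; N <= 39; N++){
            if(exactThenMod(N, M) != modEachStep(N, M)) return false;
        }
    }
    return true;
}
```

Calling `identityHolds(1000)`, for example, checks 40,000 pairs. Beyond $N = 39$ only `modEachStep` stays meaningful, since the exact value no longer fits.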

Optimizing for Large $N$ 🚀

Remember that the constraints allow very large values of $N$, but our algorithm runs in $O(N)$ time, which is too slow (you can notice it yourself by trying a huge $N$). Can you figure out a solution?

Instead of multiplying the result variable by $3$ a total of $N$ times, you can notice the following facts:

$$3^{2k} = (3^k)^2 \qquad \text{and} \qquad 3^{2k+1} = (3^k)^2 \times 3$$

This allows us to perform our calculation in $O(\log N)$ instead of $O(N)$: to compute $3^{2k}$, we only need to compute $3^k$ once and just square it, as shown in the algorithm below:

cpp
#include <bits/stdc++.h>
using namespace std;

long long ThreePow(long long k, int M){
    if(k == 0) return 1 % M; // base case: 3^0 = 1 (1 % M handles M == 1)
    long long subRes = ThreePow(k / 2, M); // If k is odd, k/2 computes its floor, aka (2j+1)/2 = j
    if(k % 2) // k is odd
        return (((subRes * subRes) % M) * 3) % M;
    else // k is even
        return (subRes * subRes) % M;
}


int main(){
    long long N;
    int M;
    cin >> N >> M;
    cout << ThreePow(N, M) << endl;
}

As we're dividing $k$ by $2$ at each call, the number of calls to our function is $\lfloor \log_2 N \rfloor + 2$ for $N \ge 1$ (the last call being the one with $k = 0$), giving a time complexity of $O(\log N)$. (You can try some test cases by hand on a piece of paper to figure it out.)
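
The same idea can also be written without recursion. The sketch below is a common iterative variant, not the author's code, and the name `threePowIterative` is ours. It walks through the binary digits of $k$ and squares a running base at each step. It assumes $k \ge 0$ and that $M \ge 1$ fits in an int, so every product fits in a long long.

```cpp
// Iterative binary exponentiation: returns 3^k mod M.
// Assumes k >= 0 and 1 <= M <= INT_MAX, so base * base < 2^62.
long long threePowIterative(long long k, int M){
    long long result = 1 % M; // 1 % M handles M == 1
    long long base = 3 % M;   // holds 3^(2^i) mod M at iteration i
    while(k > 0){
        if(k % 2 == 1)        // current binary digit of k is 1
            result = (result * base) % M;
        base = (base * base) % M;
        k /= 2;               // move to the next binary digit
    }
    return result;
}
```

The loop runs once per binary digit of $k$, so the complexity is still $O(\log N)$, and no call stack is needed.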

Your Favorite Section: Problems 📚

If You Wanna Read More 📖

These articles give further explanations but also more advanced examples: