Binary Exponentiation
Written by Raouf Ould Ali
Problem: Modular Exponentiation 💡
Consider the following problem:
- Write a program that outputs the remainder of the Euclidean division of $3^N$ by $M$, where $N \ge 0$ can be very large and $M \ge 1$.
Straightforward Algorithm 📏
The straightforward algorithm would be the following:
```cpp
#include <bits/stdc++.h>
using namespace std;

int main(){
    long long N;
    int M;
    cin >> N >> M;
    long long result = 1;
    for(long long i = 0; i < N; i++){ // multiply by 3, N times
        result = result * 3;          // overflows once 3^N no longer fits in a long long
    }
    cout << result % M << endl;
}
```
After trying a few small inputs, we may think that we're done with the problem. But when trying higher values of $N$, we start getting nonsense values that don't make sense at all! This has a very simple explanation: as $3^N$ becomes larger and larger, it exceeds the largest value a `long long` can hold ($2^{63} - 1 \approx 9.2 \times 10^{18}$), which happens as soon as $N \ge 40$. Signed overflow is undefined behavior in C++, so the printed value is meaningless. Do you have any idea about how to solve this issue?
Modular Arithmetic to the Rescue 🦸
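Remember that we're not searching for the exact value of $3^N$, but only its remainder when divided by $M$. Hence, we can use the fact that, for any non-negative integers $a$ and $b$:
$$(a \cdot b) \bmod M = \big((a \bmod M) \cdot (b \bmod M)\big) \bmod M$$
This allows us to perform the computations by replacing, at each step, the variable `result` by its remainder modulo $M$, so it always stays below $M$: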
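```cpp
#include <bits/stdc++.h>
using namespace std;

int main(){
    long long N;
    int M;
    cin >> N >> M;
    long long result = 1 % M; // 3^0 mod M (this is 0 when M = 1)
    for(long long i = 0; i < N; i++){
        result = result * 3; // result < M before this line, so no overflow
        result = result % M; // keep only the remainder modulo M
    }
    cout << result << endl;
}
```
Retrying the earlier input now gives us the intended answer (2001). Let's goooo, we solved it! 🎉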
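A hand trace of the loop above on small, arbitrarily chosen values ($N = 5$, $M = 7$) shows `result` staying below $M$ after every step:

```latex
\begin{aligned}
1 \cdot 3 &= 3  &&\to\ 3 \bmod 7 = 3\\
3 \cdot 3 &= 9  &&\to\ 9 \bmod 7 = 2\\
2 \cdot 3 &= 6  &&\to\ 6 \bmod 7 = 6\\
6 \cdot 3 &= 18 &&\to\ 18 \bmod 7 = 4\\
4 \cdot 3 &= 12 &&\to\ 12 \bmod 7 = 5
\end{aligned}
```

This matches the exact computation $3^5 = 243 = 34 \cdot 7 + 5$.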
Optimizing for Large $N$ 🚀
Remember that the constraints allow $N$ to be very large, but our algorithm obviously runs in $O(N)$, which is too slow (you can notice it yourself by trying a very large $N$). Can you figure out a solution?
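Instead of multiplying the `result` variable by $3$, $N$ times, you can notice the following facts:
$$3^0 = 1, \qquad 3^{2j} = \left(3^{j}\right)^2, \qquad 3^{2j+1} = \left(3^{j}\right)^2 \cdot 3$$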
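As a quick illustration (the exponent 13 is chosen arbitrarily), applying these facts repeatedly to $3^{13}$ only involves the exponents $13, 6, 3, 1, 0$:

```latex
\begin{aligned}
3^{13} &= \left(3^{6}\right)^2 \cdot 3 = 729^2 \cdot 3 = 1594323\\
3^{6}  &= \left(3^{3}\right)^2 = 27^2 = 729\\
3^{3}  &= \left(3^{1}\right)^2 \cdot 3 = 9 \cdot 3 = 27\\
3^{1}  &= \left(3^{0}\right)^2 \cdot 3 = 3
\end{aligned}
```

That is 4 squarings and 3 extra multiplications by 3, instead of 13 multiplications.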
This allows us to perform our calculation in $O(\log N)$ instead of $O(N)$: to compute $3^N$, we only need to compute $3^{\lfloor N/2 \rfloor}$ once and square it (multiplying by $3$ once more when $N$ is odd), as shown in the algorithm below:
```cpp
#include <bits/stdc++.h>
using namespace std;

// Returns 3^k mod M using 3^(2j) = (3^j)^2 and 3^(2j+1) = (3^j)^2 * 3
long long ThreePow(long long k, int M){
    if(k == 0) return 1 % M; // 3^0 = 1, reduced so that M = 1 gives 0
    long long subRes = ThreePow(k / 2, M); // if k is odd, k / 2 computes its floor, i.e. (2j+1)/2 = j
    if(k % 2) // k is odd
        return (((subRes * subRes) % M) * 3) % M;
    else // k is even
        return (subRes * subRes) % M;
}

int main(){
    long long N;
    int M;
    cin >> N >> M;
    cout << ThreePow(N, M) << endl;
}
```
As we divide $k$ by $2$ at each call, the number of calls to our function is $\lfloor \log_2 N \rfloor + 2$ for $N \ge 1$ (counting the final call with $k = 0$), giving a time complexity of $O(\log N)$. (You can try some test cases by hand on a piece of paper to convince yourself.)
Your Favorite Section: Problems 📚
- UVA 374 - Big Mod
- LeetCode 1922 - Count Good Numbers
- Codeforces 630I - Again Twenty Five!
- Math in Istanbul (made by me 😉)
If You Wanna Read More 📖
These articles give further explanations but also more advanced examples:

