Skip to content

Instantly share code, notes, and snippets.

@johnchen902
Last active December 25, 2021 06:02
Show Gist options
  • Save johnchen902/44d9c5be53154aec4acf685c41c88a81 to your computer and use it in GitHub Desktop.
Save johnchen902/44d9c5be53154aec4acf685c41c88a81 to your computer and use it in GitHub Desktop.
$O(n\log n)$ Matrix Chain Multiplication
Description
---
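Given a chain of matrices, compute the minimum number of scalar multiplications needed to evaluate their product (the classic matrix chain multiplication problem).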
Input format
---
There may be multiple test cases.
Each test case consists of two lines.
On the first line is an integer, $n$, the number of matrices.
On the second line are $n + 1$ integers $d_0, d_1, \ldots, d_n$, the dimensions of the matrices: the $i$-th matrix has size $d_{i-1} \times d_i$.
Output format
---
For each test case, output one integer: the minimum number of scalar multiplications needed.
Constraints
---
Time limit: 1 second
Memory limit: 128 MiB
There are at most $200$ test cases.
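For each test case, the *maximum* number of multiplications needed (over all parenthesizations) does not exceed $2^{63}-1$, so the cost of every parenthesization fits in a signed 64-bit integer.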
Subtask 1 (25%):
---
There are at most $500$ matrices in total.
Subtask 2 (45%):
---
There are at most $20000$ matrices in total.
Subtask 3 (30%):
---
There are no additional constraints beyond the global ones above;
in particular, there are at most $200000$ matrices in total.
Solutions:
---
Subtask 1 can be solved by the simple $O(n^3)$ interval dynamic programming.
For subtask 2 and 3, please read the papers
[Computation of Matrix Chain Products. Part I](https://epubs.siam.org/doi/abs/10.1137/0211028) and
[Computation of Matrix Chain Products. Part II](https://epubs.siam.org/doi/abs/10.1137/0213017).
With a mergeable priority queue (e.g. a pairing heap), the algorithm can be implemented in $O(n\log n)$.
Otherwise, it can be implemented in $O(n^2)$.
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>
// Classic interval dynamic programming for the matrix-chain problem.
//
// `a` holds the chain dimensions: matrix i has size a[i] x a[i + 1], so
// there are a.size() - 1 matrices. Returns the minimum number of scalar
// multiplications needed to compute the whole product; a chain of zero or
// one matrix costs 0.
//
// The table is a flattened n x n matrix. The cost of sub-chain i..j (i < j)
// is stored twice, at dp[i * n + j] and at its mirror dp[j * n + i], so both
// lookups in the innermost loop walk along a contiguous row.
long solve(const std::vector<long> &a) {
    // Zero or one matrix: nothing to multiply. This guard also keeps the
    // unsigned arithmetic below (a.size() - 1, n - 1, n - 2) from wrapping
    // around on an empty or single-element input.
    if (a.size() <= 2)
        return 0;
    const size_t n = a.size() - 1;  // number of matrices, n >= 2 here
    std::vector<long> dp(n * n);    // value-initialized: cost(i, i) == 0
    // i runs from n - 2 down to 0 and j upwards from i + 1. That way
    // cost(i, k) for k < j and cost(k + 1, j) for k + 1 > i are already final
    // when they are read.
    for (size_t i = n - 1; i-- > 0; ) {
        for (size_t j = i + 1; j < n; j++) {
            long best = std::numeric_limits<long>::max();
            for (size_t k = i; k < j; k++) {
                // Split as (i..k)(k+1..j): cost(i, k) + cost(k + 1, j) plus
                // the final product of an a[i] x a[k+1] and an
                // a[k+1] x a[j+1] matrix.
                best = std::min(best,
                                dp[i * n + k] +
                                dp[j * n + k + 1] +
                                a[i] * a[k + 1] * a[j + 1]);
            }
            dp[i * n + j] = dp[j * n + i] = best;
        }
    }
    return dp[n - 1];  // cost(0, n - 1)
}
// Reads test cases until end of input. Each case is a matrix count n
// followed by n + 1 dimensions. Prints the minimum multiplication cost of
// each chain, one per line.
int main() {
    size_t n;
    while (std::scanf("%zu", &n) == 1) {
        std::vector<long> a(n + 1);
        for (long &dim : a) {
            // A truncated test case has no meaningful answer; stop rather
            // than solving a chain padded with zero dimensions.
            if (std::scanf("%ld", &dim) != 1)
                return 1;
        }
        std::printf("%ld\n", solve(a));
    }
}
#include <algorithm>
#include <cstdio>
#include <list>
#include <vector>
// An exact rational number a / b, so that supports can be compared without
// any floating-point rounding.
//
// Every comparison cross-multiplies in 128-bit arithmetic; the product of
// two longs always fits in __int128. Cross-multiplication only agrees with
// the rational order when b * rhs.b > 0, so both denominators are expected
// to be positive. NOTE(review): nothing here enforces that.
struct Fraction {
    long a, b;

    // Deliberately implicit, so a plain integer can be compared against a
    // Fraction directly (it becomes integer / 1).
    Fraction (long num, long den = 1) : a(num), b(den) {}

    bool operator < (Fraction rhs) const {
        const __int128 lhs_scaled = static_cast<__int128>(a) * rhs.b;
        const __int128 rhs_scaled = static_cast<__int128>(rhs.a) * b;
        return lhs_scaled < rhs_scaled;
    }
    // The remaining relations are all derived from operator<.
    bool operator > (Fraction rhs) const { return rhs.operator<(*this); }
    bool operator <= (Fraction rhs) const { return !rhs.operator<(*this); }
    bool operator >= (Fraction rhs) const { return !operator<(rhs); }
};
// One arc (chord) of the polygon, together with the arcs directly beneath it.
//
// `first`/`second` are the endpoint vertex indices. `ceiling` lists the
// child arcs immediately under this one, in left-to-right order. The arc's
// support is the exact fraction numerator / denominator. `hm_it` points at
// the ceiling arc with the largest support; it is ceiling.end() when the
// ceiling is empty, so callers must check ceiling.empty() before get_hm().
struct Arc {
    size_t first, second;
    std::list<Arc *> ceiling;
    std::list<Arc *>::iterator hm_it;
    long numerator, denominator;

    Fraction support() const {
        return Fraction(numerator, denominator);
    }

    // Locates the maximum-support ceiling arc for the first time.
    void init() {
        locate_hm();
    }

    // The ceiling arc with maximum support (the ceiling must be non-empty).
    const Arc *get_hm() const {
        return *hm_it;
    }
    Arc *get_hm() {
        return *hm_it;
    }

    // Removes the maximum-support arc from the ceiling and splices its own
    // ceiling into the vacated position, so left-to-right order is kept.
    // Then locates the new maximum.
    void merge_hm() {
        Arc *const absorbed = *hm_it;
        const auto gap = ceiling.erase(hm_it);
        ceiling.splice(gap, absorbed->ceiling);
        locate_hm();
    }

private:
    // Points hm_it at the first ceiling arc of greatest support, or at
    // ceiling.end() if the ceiling is empty.
    void locate_hm() {
        const auto lower_support = [](const Arc *lhs, const Arc *rhs) {
            return lhs->support() < rhs->support();
        };
        hm_it = std::max_element(ceiling.begin(), ceiling.end(), lower_support);
    }
};
// Builds the arc tree of the polygon whose vertex weights are a[0..n], with
// n = a.size() - 1.
//
// A single left-to-right sweep keeps a stack `v` of vertex indices. Whenever
// the stack holds at least two vertices and the incoming weight a[i] is <=
// the weight of the top vertex, that vertex is cut off by the arc
// (second-from-top, i) and popped. Arcs therefore come out inner-first into
// arcs[0], arcs[1], ...; the sweep stops once n - 3 of them exist.
//
// `w` holds the arcs that have no enclosing parent yet. Each new arc adopts
// every pending arc nested inside it as its ceiling. Whatever is still
// pending at the end becomes the ceiling of the root arc arcs[n - 3] = (0, n).
//
// Preconditions: a.size() >= 4 (so n - 3 does not wrap) and
// arcs.size() >= n - 2.
void get_arcs(const std::vector<long> &a, std::vector<Arc> &arcs) {
    size_t n = a.size() - 1;
    std::vector<size_t> v;   // stack of vertex indices, increasing upwards
    std::vector<Arc *> w;    // arcs still waiting for an enclosing parent
    for (size_t i = 0, j = 0; i <= n && j < n - 3; i++) {
        while (j < n - 3 && v.size() >= 2 && a[i] <= a[v.back()]) {
            arcs[j].first = v[v.size() - 2];
            arcs[j].second = i;
            // Adopt every pending arc nested inside (first, second). w.back()
            // is the rightmost pending arc, so pushing to the front while
            // popping leaves the ceiling in left-to-right order.
            while (!w.empty() &&
                   arcs[j].first <= w.back()->first &&
                   w.back()->second <= arcs[j].second) {
                arcs[j].ceiling.push_front(w.back());
                w.pop_back();
            }
            w.push_back(&arcs[j]);
            j++;
            v.pop_back();
        }
        v.push_back(i);
    }
    // The root arc spans the whole polygon and takes all remaining arcs.
    arcs[n - 3].first = 0;
    arcs[n - 3].second = n;
    arcs[n - 3].ceiling.assign(w.begin(), w.end());
}
// Minimum scalar-multiplication cost of the matrix chain with dimensions
// a[0..n-1] (matrix i is a[i] x a[i + 1]). It is computed on the polygon
// whose vertex weights are the dimensions, using the arc tree built by
// get_arcs. NOTE(review): this appears to follow the arc/support scheme of
// the Hu–Shing papers cited in the task description; confirm the step
// numbering against them.
//
// `a` is taken by value because it is rotated and extended locally.
// Each arc carries a support numerator / denominator. `ans` accumulates the
// numerators of the arcs processed so far, corrected as arcs are merged.
long solve(std::vector<long> a) {
size_t n = a.size();
// At most two dimensions means at most one matrix: nothing to multiply.
if (n <= 2)
return 0;
// Exactly two matrices: the single product costs a[0] * a[1] * a[2].
if (n == 3)
return a[0] * a[1] * a[2];
// Treat the weights as a cyclic polygon: rotate the first minimum weight to
// vertex 0, then append a copy of it so that vertex n is vertex 0 again.
std::rotate(a.begin(), std::min_element(a.begin(), a.end()), a.end());
a.push_back(a[0]);
// accum[i] = sum over k in [1, i] of a[k] * a[k - 1]. Hence
// accum[s] - accum[f] is the sum of adjacent-weight products along the
// polygon boundary from vertex f to vertex s.
std::vector<long> accum(n + 1);
for (size_t i = 1; i <= n; i++)
accum[i] = accum[i - 1] + a[i] * a[i - 1];
// n - 3 inner arcs plus the root arc (0, n). get_arcs creates children
// before their parents, so the forward loop below visits arcs bottom-up.
std::vector<Arc> arcs(n - 2);
get_arcs(a, arcs);
long ans = 0;
for (Arc &arc : arcs) {
if (arc.first + 2 == arc.second) {
// leaf nodes
// Exactly one vertex (first + 1) lies under the arc. Its numerator is the
// cost of the triangle (first, first + 1, second). Its denominator is
// a[f]a[f+1] + a[f+1]a[s] - a[f]a[s].
arc.numerator = a[arc.first] * a[arc.first + 1] * a[arc.second];
arc.denominator = accum[arc.second] - accum[arc.first] -
a[arc.first] * a[arc.second];
ans += arc.numerator;
continue;
}
arc.init();
// denominator = (sum of endpoint products over the arc's ceiling edges, i.e.
// the boundary edges not covered by a child plus each child's chord)
// - a[first] * a[second]. Each child removes the boundary span it covers
// and puts back its own chord product.
arc.denominator = accum[arc.second] - accum[arc.first] -
a[arc.first] * a[arc.second];
for (Arc *jp : arc.ceiling) {
size_t jf = jp->first, js = jp->second;
arc.denominator -= accum[js] - accum[jf] - a[jf] * a[js];
}
// step 1
// While the best ceiling arc's support is >= the lighter endpoint weight of
// this arc, remove it: its numerator is taken back out of ans, and its face
// is merged into this arc's face (its denominator is added and its own
// ceiling is spliced in by merge_hm).
while (!arc.ceiling.empty() && arc.get_hm()->support() >=
std::min(a[arc.first], a[arc.second])) {
Arc &hm = *arc.get_hm();
ans -= hm.numerator;
arc.denominator += hm.denominator;
arc.merge_hm();
}
// calculate support
// c1 is the lighter endpoint (arc.first on ties). c2 names the same vertex,
// except that vertex 0 is also vertex n (a[n] is a copy of a[0]); so when
// c1 == 0 it is written as n, to match an arc that ends at n.
size_t c1 = a[arc.first] <= a[arc.second] ? arc.first : arc.second;
size_t c2 = c1 == 0 ? n : c1;
// numerator = a[c1] * (sum of endpoint products over the ceiling edges that
// are not incident to c1), i.e. the cost of triangulating this arc's face
// as a fan from c1. Start from the full ceiling-edge sum and drop the edge
// touching c1 on each side. That edge is the boundary edge next to c1, or
// the chord of the first/last ceiling arc when that arc itself starts/ends
// at c1.
arc.numerator = arc.denominator + a[arc.first] * a[arc.second];
if (arc.first == c1)
arc.numerator -= a[c1] * a[c1 + 1];
if (arc.second == c2)
arc.numerator -= a[c2] * a[c2 - 1];
if (!arc.ceiling.empty()) {
Arc *jp = arc.ceiling.front();
if (jp->first == c1) {
arc.numerator += a[c1] * a[c1 + 1];
arc.numerator -= a[jp->first] * a[jp->second];
}
Arc *jq = arc.ceiling.back();
if (jq->second == c2) {
arc.numerator += a[c2] * a[c2 - 1];
arc.numerator -= a[jq->first] * a[jq->second];
}
}
arc.numerator *= a[c1];
ans += arc.numerator;
// step 2
// While this arc's support does not exceed the best ceiling arc's support,
// fold that arc into this one: its numerator and denominator are added
// here, which changes arc.support() for the next test. ans is unchanged,
// because the folded arc's numerator was already added to ans.
while (!arc.ceiling.empty() &&
arc.support() <= arc.get_hm()->support()) {
Arc &hm = *arc.get_hm();
arc.numerator += hm.numerator;
arc.denominator += hm.denominator;
arc.merge_hm();
}
}
return ans;
}
// Reads one unsigned decimal integer from stdin into `ref`.
//
// Any non-digit characters before the number (blanks, newlines, '\r', ...)
// are skipped. The single character that terminates the number is consumed.
// Signs are not supported: a '-' is skipped like any other separator.
// Returns 1 on success, or EOF if input ends before a digit is found; in
// that case `ref` is left untouched.
template<typename T>
int strict_read_int(T &ref) {
    int c = std::getchar();
    // Skip separators, so repeated blanks, CRLF line endings or a trailing
    // blank line are not parsed as (garbage) digits.
    while (c != EOF && (c < '0' || c > '9'))
        c = std::getchar();
    if (c == EOF)
        return EOF;
    T t = c - '0';
    while ((c = std::getchar()) >= '0' && c <= '9')
        t = t * 10 + c - '0';
    ref = t;
    return 1;
}
// Reads test cases until end of input. Each case is a matrix count n
// followed by n + 1 dimensions. Prints the minimum multiplication cost of
// each chain, one per line.
int main() {
    for (size_t n; strict_read_int(n) == 1; ) {
        std::vector<long> a(n + 1);
        for (long &dim : a) {
            // A truncated test case has no meaningful answer; stop rather
            // than solving a chain padded with zero dimensions.
            if (strict_read_int(dim) != 1)
                return 1;
        }
        std::printf("%ld\n", solve(a));
    }
}
--- quadratic.cpp 2018-05-17 15:11:14.723852886 +0800
+++ linearithmic.cpp 2018-05-17 15:10:39.373853655 +0800
@@ -3,6 +3,9 @@
#include <list>
#include <vector>
+#include <ext/pool_allocator.h>
+#include <ext/pb_ds/priority_queue.hpp>
+
struct Fraction {
long a, b;
Fraction (long aa, long bb = 1) : a(aa), b(bb) {}
@@ -14,38 +17,51 @@
bool operator >= (Fraction rhs) const { return !(*this < rhs); }
};
+struct Arc;
+
+struct Arc_ite_cmp {
+ bool operator () (const std::list<Arc *>::iterator lhs,
+ const std::list<Arc *>::iterator rhs) const;
+};
+
+template<typename T>
+using Alloc = __gnu_cxx::__pool_alloc<T>;
struct Arc {
size_t first, second;
- std::list<Arc *> ceiling;
- std::list<Arc *>::iterator hm_it;
+ std::list<Arc *, Alloc<Arc *>> ceiling;
+ using iterator = decltype(ceiling)::iterator;
+ __gnu_pbds::priority_queue<iterator, Arc_ite_cmp,
+ __gnu_pbds::pairing_heap_tag, Alloc<iterator>> pq;
long numerator, denominator;
Fraction support() const {
return Fraction(numerator, denominator);
}
void init() {
- hm_it = std::max_element(ceiling.begin(), ceiling.end(),
- [](const Arc *lhs, const Arc *rhs) {
- return lhs->support() < rhs->support();
- });
+ for (auto it = ceiling.begin(); it != ceiling.end(); ++it)
+ pq.push(it);
}
const Arc *get_hm() const {
- return *hm_it;
+ return *pq.top();
}
Arc *get_hm() {
return const_cast<Arc *>(const_cast<const Arc *>(this)->get_hm());
}
void merge_hm() {
- auto it = hm_it;
+ auto it = pq.top();
Arc *hm = *it;
+ pq.pop();
+ pq.join(hm->pq);
it = ceiling.erase(it);
ceiling.splice(it, hm->ceiling);
- hm_it = std::max_element(ceiling.begin(), ceiling.end(),
- [](const Arc *lhs, const Arc *rhs) {
- return lhs->support() < rhs->support();
- });
}
};
+inline bool Arc_ite_cmp::operator () (
+ const std::list<Arc *>::iterator lhs,
+ const std::list<Arc *>::iterator rhs) const {
+ return (*lhs)->support() < (*rhs)->support();
+}
+
void get_arcs(const std::vector<long> &a, std::vector<Arc> &arcs) {
size_t n = a.size() - 1;
std::vector<long> v;
#!/usr/bin/env python3
# Random test-case generator for the matrix-chain problem.
#
# Prints -t test cases. Each is a line holding the matrix count -n, followed
# by a line of n + 1 dimensions drawn uniformly from [1, -c]. Pass -s to make
# the output reproducible.
import argparse
import random

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# (flag, default, help) for every option; all options are integers.
_OPTIONS = (
    ("-t", 200, "number of test cases"),
    ("-n", 1000, "number of matrices"),
    ("-c", 4000, "max columns/rows"),
    ("-s", None, "random seed"),
)
for _flag, _default, _help in _OPTIONS:
    parser.add_argument(_flag, default=_default, type=int, help=_help)
args = parser.parse_args()

if args.s is not None:
    random.seed(args.s, version=2)

for _ in range(args.t):
    # Lazy: the random draws happen while the line is being joined.
    dims = (random.randint(1, args.c) for _k in range(args.n + 1))
    print(args.n)
    print(" ".join(map(str, dims)))
#!/usr/bin/env python3
# Deterministic generator for a single structured matrix-chain test case.
#
# The dimension sequence starts from l1 = [1, 2] and l2 = [2]. Each step
# i = 0, 1, 2, ... (with num = 3 + i // 2) cycles through four moves:
#   - insert 2*num before the last element of l1,
#   - append num to l2,
#   - insert 2*num before the last element of l2,
#   - append num to l1.
# The output is l1 followed by reversed(l2). NOTE(review): presumably meant
# as a stress case for the arc-based solvers; confirm its intent.
import argparse
import itertools

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-n", default=200000, type=int,
                    help="number of matrices")
args = parser.parse_args()
# The seed lists already hold 3 dimensions (two matrices). A smaller -n
# would print a count that disagrees with the number of dimensions.
if args.n < 2:
    parser.error("-n must be at least 2")

l1 = [1, 2]
l2 = [2]
for i in range(args.n - 2):
    num = 3 + i // 2
    if i % 4 == 0:
        l1.insert(-1, num * 2)
    elif i % 4 == 1:
        l2.append(num)
    elif i % 4 == 2:
        l2.insert(-1, num * 2)
    else:
        l1.append(num)
print(args.n)
print(*itertools.chain(l1, reversed(l2)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment