Last active
December 25, 2021 06:02
-
-
Save johnchen902/44d9c5be53154aec4acf685c41c88a81 to your computer and use it in GitHub Desktop.
$O(n\log n)$ Matrix Chain Multiplication
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Description | |
--- | |
Given the dimensions of a chain of matrices, compute the minimum number of scalar multiplications needed to evaluate their product.
Input format | |
--- | |
There may be multiple test cases. | |
Each test case consists of two lines. | |
On the first line is an integer, $n$, the number of matrices. | |
On the second line are $n + 1$ integers $d_0, d_1, \ldots, d_n$, the dimensions of the matrices: the $i$-th matrix is $d_{i-1} \times d_i$.
Output format | |
--- | |
For each test case, output on its own line a single integer: the minimum number of scalar multiplications needed.
Constraints | |
--- | |
Time limit: 1 second | |
Memory limit: 128 MiB
There are at most $200$ test cases. | |
For each test case, the *maximum* number of scalar multiplications over all parenthesizations does not exceed $2^{63}-1$.
Subtask 1 (25%): | |
--- | |
There are at most $500$ matrices in total. | |
Subtask 2 (45%): | |
--- | |
There are at most $20000$ matrices in total. | |
Subtask 3 (30%): | |
--- | |
Only the general constraints apply; there are at most $200000$ matrices in total.
Solutions: | |
--- | |
Subtask 1 can be solved with the classic $O(n^3)$ interval dynamic programming.
For subtasks 2 and 3, see the papers by T. C. Hu and M. T. Shing:
[Computation of Matrix Chain Products. Part I](https://epubs.siam.org/doi/abs/10.1137/0211028) and
[Computation of Matrix Chain Products. Part II](https://epubs.siam.org/doi/abs/10.1137/0213017). | |
With a mergeable priority queue (e.g. a pairing heap), the algorithm can be implemented in $O(n\log n)$.
Otherwise, it can be implemented in $O(n^2)$. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm> | |
#include <cstdio> | |
#include <limits> | |
#include <vector> | |
/// Minimum number of scalar multiplications needed to multiply a chain of
/// matrices, computed with the classic O(n^3) interval dynamic programming.
///
/// @param a  dimensions d_0..d_n; matrix i (0-based) is a[i] x a[i + 1].
/// @return   the optimal cost; 0 when the chain has fewer than two matrices
///           (including a degenerate, empty dimension list).
long solve(const std::vector<long> &a) {
    // With fewer than two matrices there is nothing to multiply. Without this
    // guard `n - 2` below would wrap around (size_t) and the final lookup
    // would index past the end of `dp`.
    if (a.size() < 3)
        return 0;
    const size_t n = a.size() - 1;  // number of matrices
    // n x n table used symmetrically: cost(i..j) for i < j is stored both at
    // dp[i * n + j] and at dp[j * n + i]. In the inner loop dp[i * n + k] is
    // cost(i..k) and dp[j * n + k + 1] is cost(k+1..j); both indices advance
    // by one per k, so both rows are scanned contiguously. The diagonal
    // (single matrices) stays 0.
    std::vector<long> dp(n * n);
    for (size_t i = n - 1; i-- > 0;) {  // i = n-2 down to 0
        for (size_t j = i + 1; j < n; j++) {
            long best = std::numeric_limits<long>::max();
            for (size_t k = i; k < j; k++) {
                best = std::min(best,
                                dp[i * n + k] +
                                dp[j * n + k + 1] +
                                a[i] * a[k + 1] * a[j + 1]);
            }
            dp[i * n + j] = dp[j * n + i] = best;
        }
    }
    return dp[n - 1];  // cost(0..n-1)
}
int main() { | |
size_t n; | |
while (std::scanf("%zu", &n) == 1) { | |
n++; | |
std::vector<long> a(n); | |
for (size_t i = 0; i < n; i++) | |
scanf("%ld", &a[i]); | |
std::printf("%ld\n", solve(a)); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cstdio>
#include <list>
#include <utility>
#include <vector>
/// Rational number a / b, compared exactly.
///
/// The ordering cross-multiplies in 128-bit arithmetic, so products of two
/// `long`s cannot overflow. The comparison is only a valid ordering when both
/// denominators are positive. The constructor is intentionally implicit so a
/// plain `long` can be compared against a Fraction.
struct Fraction {
    long a, b;
    Fraction (long aa, long bb = 1) : a(aa), b(bb) {}
    bool operator < (Fraction rhs) const {
        const __int128 left = static_cast<__int128>(a) * rhs.b;
        const __int128 right = static_cast<__int128>(b) * rhs.a;
        return left < right;
    }
    bool operator > (Fraction rhs) const { return rhs.operator<(*this); }
    bool operator <= (Fraction rhs) const { return !operator>(rhs); }
    bool operator >= (Fraction rhs) const { return !operator<(rhs); }
};
struct Arc { | |
size_t first, second; | |
std::list<Arc *> ceiling; | |
std::list<Arc *>::iterator hm_it; | |
long numerator, denominator; | |
Fraction support() const { | |
return Fraction(numerator, denominator); | |
} | |
void init() { | |
hm_it = std::max_element(ceiling.begin(), ceiling.end(), | |
[](const Arc *lhs, const Arc *rhs) { | |
return lhs->support() < rhs->support(); | |
}); | |
} | |
const Arc *get_hm() const { | |
return *hm_it; | |
} | |
Arc *get_hm() { | |
return const_cast<Arc *>(const_cast<const Arc *>(this)->get_hm()); | |
} | |
void merge_hm() { | |
auto it = hm_it; | |
Arc *hm = *it; | |
it = ceiling.erase(it); | |
ceiling.splice(it, hm->ceiling); | |
hm_it = std::max_element(ceiling.begin(), ceiling.end(), | |
[](const Arc *lhs, const Arc *rhs) { | |
return lhs->support() < rhs->support(); | |
}); | |
} | |
}; | |
void get_arcs(const std::vector<long> &a, std::vector<Arc> &arcs) { | |
size_t n = a.size() - 1; | |
std::vector<long> v; | |
std::vector<Arc *> w; | |
for (size_t i = 0, j = 0; i <= n && j < n - 3; i++) { | |
while (j < n - 3 && v.size() >= 2 && a[i] <= a[v.back()]) { | |
arcs[j].first = v[v.size() - 2]; | |
arcs[j].second = i; | |
while (!w.empty() && | |
arcs[j].first <= w.back()->first && | |
w.back()->second <= arcs[j].second) { | |
arcs[j].ceiling.push_front(w.back()); | |
w.pop_back(); | |
} | |
w.push_back(&arcs[j]); | |
j++; | |
v.pop_back(); | |
} | |
v.push_back(i); | |
} | |
arcs[n - 3].first = 0; | |
arcs[n - 3].second = n; | |
arcs[n - 3].ceiling.assign(w.begin(), w.end()); | |
} | |
long solve(std::vector<long> a) { | |
size_t n = a.size(); | |
if (n <= 2) | |
return 0; | |
if (n == 3) | |
return a[0] * a[1] * a[2]; | |
std::rotate(a.begin(), std::min_element(a.begin(), a.end()), a.end()); | |
a.push_back(a[0]); | |
std::vector<long> accum(n + 1); | |
for (size_t i = 1; i <= n; i++) | |
accum[i] = accum[i - 1] + a[i] * a[i - 1]; | |
std::vector<Arc> arcs(n - 2); | |
get_arcs(a, arcs); | |
long ans = 0; | |
for (Arc &arc : arcs) { | |
if (arc.first + 2 == arc.second) { | |
// leaf nodes | |
arc.numerator = a[arc.first] * a[arc.first + 1] * a[arc.second]; | |
arc.denominator = accum[arc.second] - accum[arc.first] - | |
a[arc.first] * a[arc.second]; | |
ans += arc.numerator; | |
continue; | |
} | |
arc.init(); | |
arc.denominator = accum[arc.second] - accum[arc.first] - | |
a[arc.first] * a[arc.second]; | |
for (Arc *jp : arc.ceiling) { | |
size_t jf = jp->first, js = jp->second; | |
arc.denominator -= accum[js] - accum[jf] - a[jf] * a[js]; | |
} | |
// step 1 | |
while (!arc.ceiling.empty() && arc.get_hm()->support() >= | |
std::min(a[arc.first], a[arc.second])) { | |
Arc &hm = *arc.get_hm(); | |
ans -= hm.numerator; | |
arc.denominator += hm.denominator; | |
arc.merge_hm(); | |
} | |
// calculate support | |
size_t c1 = a[arc.first] <= a[arc.second] ? arc.first : arc.second; | |
size_t c2 = c1 == 0 ? n : c1; | |
arc.numerator = arc.denominator + a[arc.first] * a[arc.second]; | |
if (arc.first == c1) | |
arc.numerator -= a[c1] * a[c1 + 1]; | |
if (arc.second == c2) | |
arc.numerator -= a[c2] * a[c2 - 1]; | |
if (!arc.ceiling.empty()) { | |
Arc *jp = arc.ceiling.front(); | |
if (jp->first == c1) { | |
arc.numerator += a[c1] * a[c1 + 1]; | |
arc.numerator -= a[jp->first] * a[jp->second]; | |
} | |
Arc *jq = arc.ceiling.back(); | |
if (jq->second == c2) { | |
arc.numerator += a[c2] * a[c2 - 1]; | |
arc.numerator -= a[jq->first] * a[jq->second]; | |
} | |
} | |
arc.numerator *= a[c1]; | |
ans += arc.numerator; | |
// step 2 | |
while (!arc.ceiling.empty() && | |
arc.support() <= arc.get_hm()->support()) { | |
Arc &hm = *arc.get_hm(); | |
arc.numerator += hm.numerator; | |
arc.denominator += hm.denominator; | |
arc.merge_hm(); | |
} | |
} | |
return ans; | |
} | |
/// Reads one unsigned decimal integer from stdin into `ref`.
///
/// Leading blanks (space, tab, CR, LF) are skipped. A run of digits is then
/// consumed, together with the single character that ends the run. No sign
/// is accepted.
///
/// @return 1 on success. EOF if input ends before a digit is found. 0 if
///         the next non-blank character is not a digit; in that case `ref`
///         is left untouched and the character is pushed back onto stdin.
template<typename T>
int strict_read_int(T &ref) {
    int c = std::getchar();
    // Without this, CRLF line endings, doubled separators or a trailing blank
    // line would be parsed as a digit and produce a garbage value.
    while (c == ' ' || c == '\t' || c == '\n' || c == '\r')
        c = std::getchar();
    if (c == EOF)
        return EOF;
    if (c < '0' || c > '9') {
        std::ungetc(c, stdin);
        return 0;
    }
    T t = c - '0';
    while ((c = std::getchar()) >= '0' && c <= '9')
        t = t * 10 + c - '0';
    ref = t;
    return 1;
}
int main() { | |
for (size_t n; strict_read_int(n) == 1; ) { | |
n++; | |
std::vector<long> a(n); | |
for (size_t i = 0; i < n; i++) | |
strict_read_int(a[i]); | |
std::printf("%ld\n", solve(a)); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- quadratic.cpp 2018-05-17 15:11:14.723852886 +0800 | |
+++ linearithmic.cpp 2018-05-17 15:10:39.373853655 +0800 | |
@@ -3,6 +3,9 @@ | |
#include <list> | |
#include <vector> | |
+#include <ext/pool_allocator.h> | |
+#include <ext/pb_ds/priority_queue.hpp> | |
+ | |
struct Fraction { | |
long a, b; | |
Fraction (long aa, long bb = 1) : a(aa), b(bb) {} | |
@@ -14,38 +17,51 @@ | |
bool operator >= (Fraction rhs) const { return !(*this < rhs); } | |
}; | |
+struct Arc; | |
+ | |
+struct Arc_ite_cmp { | |
+ bool operator () (const std::list<Arc *>::iterator lhs, | |
+ const std::list<Arc *>::iterator rhs) const; | |
+}; | |
+ | |
+template<typename T> | |
+using Alloc = __gnu_cxx::__pool_alloc<T>; | |
struct Arc { | |
size_t first, second; | |
- std::list<Arc *> ceiling; | |
- std::list<Arc *>::iterator hm_it; | |
+ std::list<Arc *, Alloc<Arc *>> ceiling; | |
+ using iterator = decltype(ceiling)::iterator; | |
+ __gnu_pbds::priority_queue<iterator, Arc_ite_cmp, | |
+ __gnu_pbds::pairing_heap_tag, Alloc<iterator>> pq; | |
long numerator, denominator; | |
Fraction support() const { | |
return Fraction(numerator, denominator); | |
} | |
void init() { | |
- hm_it = std::max_element(ceiling.begin(), ceiling.end(), | |
- [](const Arc *lhs, const Arc *rhs) { | |
- return lhs->support() < rhs->support(); | |
- }); | |
+ for (auto it = ceiling.begin(); it != ceiling.end(); ++it) | |
+ pq.push(it); | |
} | |
const Arc *get_hm() const { | |
- return *hm_it; | |
+ return *pq.top(); | |
} | |
Arc *get_hm() { | |
return const_cast<Arc *>(const_cast<const Arc *>(this)->get_hm()); | |
} | |
void merge_hm() { | |
- auto it = hm_it; | |
+ auto it = pq.top(); | |
Arc *hm = *it; | |
+ pq.pop(); | |
+ pq.join(hm->pq); | |
it = ceiling.erase(it); | |
ceiling.splice(it, hm->ceiling); | |
- hm_it = std::max_element(ceiling.begin(), ceiling.end(), | |
- [](const Arc *lhs, const Arc *rhs) { | |
- return lhs->support() < rhs->support(); | |
- }); | |
} | |
}; | |
+inline bool Arc_ite_cmp::operator () ( | |
+ const std::list<Arc *>::iterator lhs, | |
+ const std::list<Arc *>::iterator rhs) const { | |
+ return (*lhs)->support() < (*rhs)->support(); | |
+} | |
+ | |
void get_arcs(const std::vector<long> &a, std::vector<Arc> &arcs) { | |
size_t n = a.size() - 1; | |
std::vector<long> v; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""Write random matrix-chain test cases to stdout.

Emits ``-t`` test cases. Each is a line holding the matrix count ``-n``,
followed by a line of ``n + 1`` dimensions drawn uniformly from
``[1, -c]``. Pass ``-s`` to seed the generator for a reproducible run.
"""
import argparse
import random

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-t", default=200, type=int,
                    help="number of test cases")
parser.add_argument("-n", default=1000, type=int,
                    help="number of matrices")
parser.add_argument("-c", default=4000, type=int,
                    help="max columns/rows")
parser.add_argument("-s", default=None, type=int,
                    help="random seed")
args = parser.parse_args()

if args.s is not None:
    random.seed(args.s, version=2)

for _case in range(args.t):
    print(args.n)
    dims = [random.randint(1, args.c) for _dim in range(args.n + 1)]
    print(" ".join(map(str, dims)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""Print one structured matrix-chain test case with ``-n`` matrices.

The ``n + 1`` dimensions are the concatenation of ``l1`` and
``reversed(l2)``. Starting from ``l1 = [1, 2]`` and ``l2 = [2]``, each of the
remaining ``n - 2`` steps adds one value, cycling through four moves:
"insert 2*k before the last element of l1", "append k to l2",
"insert 2*k before the last element of l2", "append k to l1".
Here ``k`` starts at 3 and grows by one every two steps.

NOTE(review): presumably designed as a hard case for the arc-tree solvers;
confirm before relying on that.
"""
import argparse, itertools

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-n", default=200000, type=int,
                    help="number of matrices")
args = parser.parse_args()

# The seed lists already hold three dimensions (two matrices). For n < 2 the
# output would still contain three dimensions and not match the announced n.
if args.n < 2:
    parser.error("-n must be at least 2")

l1 = [1, 2]
l2 = [2]
for i in range(args.n - 2):
    num = 3 + i // 2
    if i % 4 == 0:
        l1.insert(-1, num * 2)
    elif i % 4 == 1:
        l2.append(num)
    elif i % 4 == 2:
        l2.insert(-1, num * 2)
    else:
        l1.append(num)
print(args.n)
print(*itertools.chain(l1, reversed(l2)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment