一道计数类\(DP\)

原题链接

我们可以先计算从左上角到右下角总的路径,再减去经过黑色方格的路径即是答案。

总路径数可以用组合数直接计算:\(C_{H+W-2}^{H-1}\)

因为从左上角到右下角必须走\(H+W-2\)步,而其中必须向右走\(H-1\)步,向下走\(W-1\)步,所以这就相当于是从\(H+W-2\)步中取出\(H-1\)步来向右走,剩下的向下走,这就是一个排列组合问题。

然后先按\(x,y\)递增的顺序对黑色方块进行排序,并设右下角为第\(n+1\)个黑色方块,第\(i\)个方块坐标为\((x_i,y_i)\)。

定义\(f[i]\)表示从左上角走到第\(i\)个黑色方块,且途中不经过其他黑色方块的路径总数。

于是有状态转移方程:

\(\qquad\qquad f[i]=C_{x_i+y_i-2}^{x_i-1}-\sum\limits_{j=1}^{i-1}f[j]\times C_{x_i-x_j+y_i-y_j}^{x_i-x_j},\text{且}x_i\geqslant x_j,y_i\geqslant y_j\)

其中第一个组合数是求从左上角到第\(i\)个黑色方块总的路径数,而这就需要减去这些路径中经过黑色方块的路径数,\(f[j]\)是从左上角到第\(j\)个黑色方块,且途中不经过其他黑色方块的路径数,而后面的组合数即是求从第\(j\)个黑色方块到第\(i\)个黑色方块的总路径数,两者满足乘法原理,乘起来就是从左上角到第\(i\)个黑色方块的路径中经过第\(j\)个黑色方格的路径数,而因为在\(j\)循环的过程中保证了第一个经过的黑色方格不同,所以计数时不会重复,直接累加减去即可。

最后组合数的计算可以先预处理阶乘和对应的逆元来\(O(1)\)计算。

#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10;
const int M = 2010;
const int mod = 1e9 + 7;
struct dd {
int x, y;
};
dd a[M];
int f[M];
ll inv[N << 1], fac[N << 1];
int re()
{
int x = 0;
char c = getchar();
bool p = 0;
for (; c<'0' || c>'9'; c = getchar())
p = (c == '-' || p) ? 1 : 0;
for (; c >= '0'&&c <= '9'; c = getchar())
x = x * 10 + (c - '0');
return p ? -x : x;
}
int comp(dd x, dd y)
{
if (x.x == y.x)
return x.y < y.y;
return x.x < y.x;
}
int ksm(int x, int y)
{
int s = 1;
for (; y; y >>= 1, x = 1LL * x*x%mod)
if (y & 1)
s = 1LL * s*x%mod;
return s;
}
int C(int x, int y)
{
return fac[y] * inv[x] % mod*inv[y - x] % mod;
}
int main()
{
int i, j, h, w, n, o;
h = re();
w = re();
n = re();
for (i = 1; i <= n; i++)
{
a[i].x = re();
a[i].y = re();
}
sort(a + 1, a + n + 1, comp);
for (fac[0] = i = 1, o = h + w; i <= o; i++)
fac[i] = fac[i - 1] * i%mod;
inv[o] = ksm(fac[o], mod - 2);
for (i = o - 1; i >= 0; i--)
inv[i] = inv[i + 1] * (i + 1) % mod;
a[n + 1].x = h;
a[n + 1].y = w;
for (i = 1; i <= n + 1; i++)
{
f[i] = C(a[i].x - 1, a[i].x + a[i].y - 2);
for (j = 1; j < i; j++)
if (a[j].x <= a[i].x&&a[j].y <= a[i].y)
f[i] = (f[i] - 1LL * f[j] * C(a[i].x - a[j].x, a[i].x + a[i].y - a[j].x - a[j].y) % mod) % mod;
}
printf("%d", (f[n + 1] + mod) % mod);
return 0;
}
05-11 19:50