15-859(B) Machine Learning Theory                                 04/09/02

Learning over the uniform distribution: monotone boolean functions
  * Lower bound
  * Upper bound
===========================================================================

Preliminaries
=============

A while back, in talking about SQ algorithms, we mentioned that you can
*weak-learn* any monotone function over {0,1}^n under the *uniform
distribution*.

Idea: First of all, we can assume the target function f is roughly 50/50,
or else we're done. Now, think of f as labeling each point in the
hypercube as + or -. If f is roughly 50/50, then (whether monotone or not)
we showed there must be Omega(2^n) edges in the hypercube with one
endpoint labeled + and one endpoint labeled -. (We showed this using a
"canonical path" argument, but there are other proofs too.) This means
there must be some coordinate direction i with Omega((2^n)/n) of these
edges. Finally, since f is monotone, this means x_i has correlation
Omega(1/n) with f.

What about hardness results? Think of a function that is negative on any
x with < n/2 ones in it, positive on any x with > n/2 ones in it, and
assigns *random* values to the "slice" of x's with exactly n/2 ones.
(This is still monotone, since no two points in that slice are
comparable.) Since that middle slice has roughly a 1/sqrt(n) fraction of
the probability mass, it's clear that no algorithm will be able to get
correlation better than 1 - 1/sqrt(n) from looking at only a polynomial
number of examples.

Today: can we bring these bounds closer together? (yes)

Results for today
=================

-> There's a simple algorithm that uses random examples only and gives
   correlation Omega(1/sqrt(n)).

-> No algorithm (even with membership queries) can get correlation
   omega(log(n)/sqrt(n)).

In fact, the algorithm is really simple: first, see if f is mostly
positive or mostly negative (if so, just predict that label). If not,
then the claim is that the majority function MAJ(x1,...,xn) must have
correlation Omega(1/sqrt(n)). We'll get to proving that later...
[BTW, I'm going to be dropping the Omegas and omegas for simplicity.]

Big open problem: can you learn monotone functions (w/o MQs) in time
polynomial in their DNF-size? Remember that with MQs this is easy, and in
fact wrt the uniform distribution, with MQs you can learn *arbitrary*
functions in time poly in their DNF-size (this is what Shuchi and Nikhil
presented last time).

Lower bound (i.e., hardness result)
===================================

- Let U = the uniform distribution on {0,1}^n. The lower bound will hold
  even for algorithms that ask membership queries. So in fact, we can
  assume that the algorithm *only* makes queries (since you can simulate
  random examples over U with MQs).

- We'll prove the following theorem: for any number of queries s, there
  exists a distribution P_s over monotone functions such that for any
  algorithm A asking s queries,

     Pr_{c <- P_s, x <- U} (A makes a mistake on x) >= 1/2 - O(log(sn)/sqrt(n)).

- Here's the distribution over target functions. First of all, let
  t = lg(3sn). We'll put each conjunction of size t into the function
  (which we will think of as a monotone DNF) independently with
  probability p, where p is defined so that an example with n/2 ones has
  a 50/50 chance of being positive. I.e., (1-p)^{(n/2 choose t)} = 1/2.

- Let |x| be the number of 1s in x.

- To prove our theorem, consider an augmented MQ oracle that, when you
  give it some x, looks through all the (|x| choose t) possible
  conjunctions that x might satisfy in lexicographic order, and outputs
  the first one that actually appears in the target, or else outputs "0"
  if none are in the target. (A toy sketch of such an oracle is given
  below.)
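  Here's a toy Python sketch of such an oracle (my own illustration, not
  from the lecture or the paper): for small n it samples the random
  monotone t-DNF explicitly and answers a query x with the
  lexicographically first size-t term that x satisfies. The names
  sample_target and augmented_mq and the toy values n = 20, t = 3 are
  made up for illustration; the real argument takes t = lg(3sn).

    import itertools, math, random

    n, t = 20, 3                                # toy sizes; lecture uses t = lg(3sn)
    p = 1 - 2 ** (-1.0 / math.comb(n // 2, t))  # chosen so (1-p)^{C(n/2,t)} = 1/2

    def sample_target(rng):
        # include each size-t conjunction (a t-subset of variables)
        # independently with probability p
        return {T for T in itertools.combinations(range(n), t) if rng.random() < p}

    def augmented_mq(target, x):
        # return the lexicographically first term of the target that x
        # satisfies, or 0 if x satisfies none (i.e., x is a negative example)
        ones = [i for i in range(n) if x[i] == 1]
        for T in itertools.combinations(ones, t):   # generated in lex order
            if T in target:
                return T
        return 0

    rng = random.Random(0)
    target = sample_target(rng)
    x = tuple(rng.randint(0, 1) for _ in range(n))
    print(augmented_mq(target, x))                  # a term of the target, or 0

  In the actual proof the target is never sampled up front: the adversary
  flips the coins lazily, which is exactly the point of the vector V
  described next.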
  Notice that this oracle gives *more* information than a standard MQ
  oracle that just outputs "1" if x is positive: it actually spits out one
  of the terms (the lexicographically smallest one that x satisfies). Can
  anyone think why this might be useful for analysis?

- The nice thing about this MQ oracle is that we have a clean description
  of the current probability distribution over target functions,
  conditioned on the results of the queries so far. We can write it as a
  vector V with one entry for each possible conjunction of size t.
  Initially, every entry is set to p. When a query is made, we can think
  of the adversary as flipping the coins for the terms in lexicographic
  order, and telling us the result of each flip, until it either gets a
  heads or else runs out of coins. We can then set the corresponding
  entries in our vector to 0 or 1.

- On any individual query, the expected number of coins flipped is
  <= 1/p. Why?

- In s queries, the expected number flipped is <= s/p. The number of
  heads is <= s for sure. We can also argue that whp the total number of
  coins flipped is <= 2s/p. This is easier to see by thinking in the
  other direction: if we flipped s/p coins of bias p, we'd expect s
  heads. If we flipped 2s/p, we'd expect 2s heads. By Chernoff,
  Pr(s or fewer heads) < e^{-(2s/8)}. So, in our actual experiment,
  Pr(we make 2s/p or more flips) is also < e^{-(s/4)}.

- The rest of the proof is to show that (A) with zero queries, we
  couldn't get much advantage, and (B) with s queries, our advantage
  hasn't changed by much.

- Define: V(x) = Pr_{c <- V}(x is positive).

- Given the results of the queries so far, what is the Bayes-optimal
  prediction? If V(x) > 1/2 then predict positive, else predict negative.
  (A toy sketch of computing V(x) from the vector V is given below, just
  before the proof of the claim.)

- Claim: For any V with <= 2s/p zeroes, <= s ones, and the rest equal to
  p, for a random x,

     Pr_{x <- U}[ |V(x) - 1/2| > (c+1)t/sqrt(n) ] <= 1/n + 1/sqrt(n) + 2e^{-c^2},

  where t = lg(3sn).

- This claim immediately gives us a slightly weaker version of the
  theorem. Just plug in c = sqrt(lg n). We get that there's at most a
  2/n + 1/sqrt(n) chance that |V(x) - 1/2| > [lg(3sn)]^{3/2}/sqrt(n).
  I.e., there's a small chance that we'll be sure about x, but for most x
  we'll have only a small bias, and overall, our probability of getting x
  correct is at most 1/2 + O(log(sn)^{3/2} / sqrt(n)). To get rid of the
  extra power of log(n)^{1/2} you can just do a telescoping argument over
  the c's. So, we're almost done. We just need to prove the claim.
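  Before the proof, here's a small sketch (again my own, purely
  illustrative) of what V(x) and the Bayes-optimal rule look like,
  treating V as a dictionary that holds only the entries already forced
  to 0 or 1, with untouched entries implicitly equal to p. It brute-forces
  over all (|x| choose t) relevant terms, so it only makes sense for toy
  sizes.

    import itertools, math

    def V_of_x(V, x, t, p):
        # V(x) = Pr_{c <- V}(x positive)
        #      = 1 - prod over relevant size-t terms T of (1 - V[T]),
        # where untouched entries of V are still equal to p
        ones = [i for i in range(len(x)) if x[i] == 1]
        prob_all_absent = 1.0
        for T in itertools.combinations(ones, t):
            prob_all_absent *= 1.0 - V.get(T, p)
            if prob_all_absent == 0.0:      # some relevant term is known present
                return 1.0
        return 1.0 - prob_all_absent

    def bayes_predict(V, x, t, p):
        # Bayes-optimal prediction given the queries so far
        return +1 if V_of_x(V, x, t, p) > 0.5 else -1

    # with no queries yet (empty V) and |x| = n/2, V(x) = 1/2 exactly,
    # which is how p was chosen
    n, t = 20, 3
    p = 1 - 2 ** (-1.0 / math.comb(n // 2, t))
    x = tuple([1] * (n // 2) + [0] * (n - n // 2))
    print(V_of_x({}, x, t, p))              # ~ 0.5 (up to floating point)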
- Proof of claim: Pick a random x. What is the chance that x satisfies
  one of the terms we know about? For each one, we get a probability
  1/2^t = 1/(3sn). Adding over all of them, we get chance at most
  1/(3n) < 1/n. What about the 0 entries? The expected number of
  0-entries that are relevant to x (i.e., contained in x) is at most
  (2s/p)(1/(3sn)). By Markov's inequality, there's at most a 1/sqrt(n)
  chance that the actual number is >= sqrt(n) times larger. Finally, by
  Hoeffding bounds, there's at most a 2e^{-c^2} chance that |x| is not in
  [n/2 - c*sqrt(n/2), n/2 + c*sqrt(n/2)].

  Putting these together, we just need to do a calculation. Here's the
  point: ideally (from the hardness point of view) the number of coins we
  would have to flip for x would be exactly (n/2 choose t), since then
  V(x) = 1/2 exactly. But it might be less or might be more. It might be
  *less* because (A) x might have fewer than n/2 ones in it, and (B)
  because some of those entries might already be 0.

  Let's look at (B): how much did we lose due to those missing coin
  flips? The chance they would all have been tails anyway is
  (1-p)^{sqrt(n)*(2s/p)*(1/(3sn))} ~ (1/e)^{2/(3*sqrt(n))}, which is at
  least 1 - 1/sqrt(n), so we lose at most another O(1/sqrt(n)) here. The
  other part is due to (A) - it's just a standard (but annoying)
  calculation. See the paper....

================================

Upper bound
===========

Under the assumption that f is monotone and roughly 50/50, we want to
show that f must have correlation Omega(1/sqrt(n)) with the majority
function. For simplicity, let's assume n is odd and that f is *exactly*
50/50. This will help get rid of a little clutter. [Assuming n is odd
makes the majority function exactly 50/50 also.]

Let's define

   P_f(k) = Pr(f(x) = 1 | x has exactly k 1's in it)   [the paper calls this p_k]

Since f is a monotone function, it seems reasonable that P_f should be
monotone too (i.e., non-decreasing with k). It's true. Here's a simple
proof: think of starting with the all-zeros example and randomly putting
1s into it until you get to the all-ones example. In this process, the
example starts as negative and eventually becomes positive, but can't go
from positive to negative. Therefore, Pr(example is positive after k
steps) <= Pr(example is positive after k+1 steps). But notice that after
k steps we are at a random location on the kth level (same for k+1). So
this means P_f(k) <= P_f(k+1). This fact is called the "local LYM
inequality" in combinatorics.

--> draw out graph with P_f(k), P_MAJ(k), and the binomial distribution.

Our fact so far means that the MAJ function at least has >= 0 correlation
with f (i.e., it can't be negative). In particular, one way to write the
correlation between f and MAJ is:

       n
      SUM      2*Pr(x has k ones)*[P_f(k) - P_f(n-k)]
   k=(n+1)/2

(A quick brute-force check of P_f and this formula on a small example is
sketched below, after the Kruskal-Katona discussion.)

To get the result we want, all we need to show is that not only is
P_f(k) non-decreasing, but that it actually has to increase at a
reasonable rate. In particular, we can use the fact that a constant
fraction of the probability mass in the hypercube occurs in the range
k in [n/2 + sqrt(n), n/2 + 2*sqrt(n)]. So, to get our result, all we need
is that:

   P_f(n/2 + sqrt(n)) >= P_f(n/2 - sqrt(n)) + c/sqrt(n).

For example, consider the case f(x) = x_i. What is P_f(k)? (It's k/n,
which rises by 2/sqrt(n) across that range.)

It turns out that analyzing the P_f's for monotone functions is a
well-studied problem in combinatorics. Usually the results are stated in
terms of the "upper shadow" or "lower shadow" of a set of points at some
kth level of the hypercube. In particular, there's a theorem called the
Kruskal-Katona theorem, from the 1960s, that gives us what we want.

[[ Here are some specifics. Probably skip. A corollary of Kruskal-Katona
says that for any monotone f, and for i < j around the middle levels
(take i = n/2 - sqrt(n) and j = n/2 + sqrt(n)), we get
P_f(j) >= P_f(i) * (3/2)^{4/sqrt(n)}; now use the fact that e^x is
approximately 1+x for small x. The proof of Kruskal-Katona looks a bit
like the proof of Sauer's lemma, but messier. ]]
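Here's the promised brute-force sanity check in Python (my own toy
example; the particular DNF and the names like terms, P, corr_formula are
arbitrary): it computes P_f(k) for a small monotone DNF, checks that it
is non-decreasing, and confirms that the displayed sum matches E[f*MAJ]
computed directly. (The identity itself doesn't need f to be exactly
50/50; that assumption just reduces clutter elsewhere in the argument.)

    import itertools, math

    n = 9                                      # small and odd, so MAJ is well defined
    terms = [(0, 1), (2, 3, 4), (5,)]          # an arbitrary small monotone DNF

    def f(x):   return 1 if any(all(x[i] for i in T) for T in terms) else -1
    def maj(x): return 1 if sum(x) > n / 2 else -1

    cube = list(itertools.product((0, 1), repeat=n))

    # P_f(k) = Pr(f(x)=1 | x has exactly k ones); monotone f => non-decreasing in k
    P = [sum(f(x) == 1 for x in cube if sum(x) == k) / math.comb(n, k)
         for k in range(n + 1)]
    assert all(P[k] <= P[k + 1] + 1e-12 for k in range(n))

    # correlation two ways: directly as E[f*MAJ], and via the displayed formula
    corr_direct  = sum(f(x) * maj(x) for x in cube) / 2 ** n
    corr_formula = sum(2 * (math.comb(n, k) / 2 ** n) * (P[k] - P[n - k])
                       for k in range((n + 1) // 2, n + 1))
    print(corr_direct, corr_formula)           # the two should agree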
Here is a different argument: remember we used canonical paths to argue
that if f is 50/50, then at least a 1/n fraction of all edges have one
endpoint labeled - and one labeled +. If we can show that this is true
for the middle [n/2 - sqrt(n), n/2 + sqrt(n)] region too, then we are
done. In particular (since all levels in this range have roughly the same
number of edges), it would mean that if you picked a random point at the
bottom of the range and took a random path up to the top of the range,
the expected number of (-,+) edges traversed would be c/sqrt(n). This is
the same as the chance of traversing one of them, since you can't
possibly traverse more than one.

We can do this by arguing that if we pick a random pair of nodes (u,v)
and take the canonical path, there's a reasonable probability that u is a
- point in the middle region, v is a + point in the middle region, and
the entire canonical path stays in this region.

[Give argument based on flipping 2n coins: whp you never have >> sqrt(n)
more heads than tails. A tiny Monte Carlo illustration is below.]

[Might need to slightly widen the middle region for this argument.]
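For that bracketed coin-flipping fact, here's a tiny Monte Carlo
illustration (my own; the parameters n = 400, 2000 trials, and the band
4*sqrt(n) are arbitrary): over 2n fair flips, the running imbalance
between heads and tails only rarely exceeds a few multiples of sqrt(n),
which is what keeps a random canonical path inside a slightly widened
middle band.

    import math, random

    def max_imbalance(num_flips, rng):
        # largest |#heads - #tails| seen along a run of fair coin flips
        walk, worst = 0, 0
        for _ in range(num_flips):
            walk += 1 if rng.random() < 0.5 else -1
            worst = max(worst, abs(walk))
        return worst

    rng = random.Random(0)
    n, trials = 400, 2000
    band = 4 * math.isqrt(n)                 # "a few times sqrt(n)"
    frac = sum(max_imbalance(2 * n, rng) > band for _ in range(trials)) / trials
    print(f"fraction of 2n-flip runs ever exceeding imbalance {band}: {frac:.3f}")
    # should come out small (on the order of a percent for these numbers)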