youtubebeat/vendor/github.com/gocolly/colly/http_backend.go

228 lines
5.5 KiB
Go
Raw Normal View History

2018-11-18 15:32:28 +01:00
// Copyright 2018 Adam Tauber
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package colly
import (
"crypto/sha1"
"encoding/gob"
"encoding/hex"
"io"
"io/ioutil"
"math/rand"
"net/http"
"os"
"path"
"regexp"
"sync"
"time"
"compress/gzip"
"github.com/gobwas/glob"
)
type httpBackend struct {
LimitRules []*LimitRule
Client *http.Client
lock *sync.RWMutex
}
// LimitRule provides connection restrictions for domains.
// Both DomainRegexp and DomainGlob can be used to specify
// the included domains patterns, but at least one is required.
// There can be two kind of limitations:
// - Parallelism: Set limit for the number of concurrent requests to matching domains
// - Delay: Wait specified amount of time between requests (parallelism is 1 in this case)
type LimitRule struct {
// DomainRegexp is a regular expression to match against domains
DomainRegexp string
// DomainRegexp is a glob pattern to match against domains
DomainGlob string
// Delay is the duration to wait before creating a new request to the matching domains
Delay time.Duration
// RandomDelay is the extra randomized duration to wait added to Delay before creating a new request
RandomDelay time.Duration
// Parallelism is the number of the maximum allowed concurrent requests of the matching domains
Parallelism int
waitChan chan bool
compiledRegexp *regexp.Regexp
compiledGlob glob.Glob
}
// Init initializes the private members of LimitRule
func (r *LimitRule) Init() error {
waitChanSize := 1
if r.Parallelism > 1 {
waitChanSize = r.Parallelism
}
r.waitChan = make(chan bool, waitChanSize)
hasPattern := false
if r.DomainRegexp != "" {
c, err := regexp.Compile(r.DomainRegexp)
if err != nil {
return err
}
r.compiledRegexp = c
hasPattern = true
}
if r.DomainGlob != "" {
c, err := glob.Compile(r.DomainGlob)
if err != nil {
return err
}
r.compiledGlob = c
hasPattern = true
}
if !hasPattern {
return ErrNoPattern
}
return nil
}
func (h *httpBackend) Init(jar http.CookieJar) {
rand.Seed(time.Now().UnixNano())
h.Client = &http.Client{
Jar: jar,
Timeout: 10 * time.Second,
}
h.lock = &sync.RWMutex{}
}
// Match checks that the domain parameter triggers the rule
func (r *LimitRule) Match(domain string) bool {
match := false
if r.compiledRegexp != nil && r.compiledRegexp.MatchString(domain) {
match = true
}
if r.compiledGlob != nil && r.compiledGlob.Match(domain) {
match = true
}
return match
}
func (h *httpBackend) GetMatchingRule(domain string) *LimitRule {
if h.LimitRules == nil {
return nil
}
h.lock.RLock()
defer h.lock.RUnlock()
for _, r := range h.LimitRules {
if r.Match(domain) {
return r
}
}
return nil
}
func (h *httpBackend) Cache(request *http.Request, bodySize int, cacheDir string) (*Response, error) {
if cacheDir == "" || request.Method != "GET" {
return h.Do(request, bodySize)
}
sum := sha1.Sum([]byte(request.URL.String()))
hash := hex.EncodeToString(sum[:])
dir := path.Join(cacheDir, hash[:2])
filename := path.Join(dir, hash)
if file, err := os.Open(filename); err == nil {
resp := new(Response)
err := gob.NewDecoder(file).Decode(resp)
file.Close()
if resp.StatusCode < 500 {
return resp, err
}
}
resp, err := h.Do(request, bodySize)
if err != nil || resp.StatusCode >= 500 {
return resp, err
}
if _, err := os.Stat(dir); err != nil {
if err := os.MkdirAll(dir, 0750); err != nil {
return resp, err
}
}
file, err := os.Create(filename + "~")
if err != nil {
return resp, err
}
if err := gob.NewEncoder(file).Encode(resp); err != nil {
file.Close()
return resp, err
}
file.Close()
return resp, os.Rename(filename+"~", filename)
}
func (h *httpBackend) Do(request *http.Request, bodySize int) (*Response, error) {
r := h.GetMatchingRule(request.URL.Host)
if r != nil {
r.waitChan <- true
defer func(r *LimitRule) {
randomDelay := time.Duration(0)
if r.RandomDelay != 0 {
randomDelay = time.Duration(rand.Int63n(int64(r.RandomDelay)))
}
time.Sleep(r.Delay + randomDelay)
<-r.waitChan
}(r)
}
res, err := h.Client.Do(request)
if err != nil {
return nil, err
}
if res.Request != nil {
*request = *res.Request
}
var bodyReader io.Reader = res.Body
if bodySize > 0 {
bodyReader = io.LimitReader(bodyReader, int64(bodySize))
}
if !res.Uncompressed && res.Header.Get("Content-Encoding") == "gzip" {
bodyReader, err = gzip.NewReader(bodyReader)
if err != nil {
return nil, err
}
}
body, err := ioutil.ReadAll(bodyReader)
defer res.Body.Close()
if err != nil {
return nil, err
}
return &Response{
StatusCode: res.StatusCode,
Body: body,
Headers: &res.Header,
}, nil
}
func (h *httpBackend) Limit(rule *LimitRule) error {
h.lock.Lock()
if h.LimitRules == nil {
h.LimitRules = make([]*LimitRule, 0, 8)
}
h.LimitRules = append(h.LimitRules, rule)
h.lock.Unlock()
return rule.Init()
}
func (h *httpBackend) Limits(rules []*LimitRule) error {
for _, r := range rules {
if err := h.Limit(r); err != nil {
return err
}
}
return nil
}